首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用函数返回的pickle保存对象

使用函数返回的pickle保存对象
EN

Stack Overflow用户
提问于 2021-09-30 12:24:15
回答 1查看 65关注 0票数 1

如何保存函数返回的关于已定义方法类的模型?我想为许多类制作相同的包装器,类似于(在我的例子中) Rocket类。

下面的代码产生一个错误: Can't pickle local object 'sktime_wrapper..SKtimeWrapper‘

代码语言:javascript
复制
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested

def sktime_wrapper(method_class):
    class SKtimeWrapper(method_class):
        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)

        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)

    return SKtimeWrapper


model = sktime_wrapper(Rocket)

with open('model.pkl','wb') as f:
    pickle.dump(model, f)

在类被定义为顶级对象的情况下,pickle可以很好地工作。下面的代码就像一个护身符,保存模型时没有任何问题:

代码语言:javascript
复制
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested

class SKtimeWrapper(Rocket):
    def transform(self, X):
        X = from_2d_array_to_nested(X)
        return super().transform(X)

    def fit(self, X, Y):
        X = from_2d_array_to_nested(X)
        return super().fit(X, Y)

model = SKtimeWrapper


with open('model.pkl','wb') as f:
    pickle.dump(model, f)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-10-01 07:10:42

按照答案部分,我设法让它工作!我希望有人会觉得这很有用。诀窍是使用__reduce__()函数。

Bellow就是一个很好的例子。请注意,对象必须在保存前初始化。

代码语言:javascript
复制
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested

def sktime_wrapper(method_class):
    class SKtimeWrapper(method_class):
        PARAM = method_class
        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)

        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)

        def __reduce__(self):
            return (_InitializeParameterized(), (self.PARAM,), self.__dict__)

    return SKtimeWrapper

class _InitializeParameterized(object):
    """
    When called with the param value as the only argument, returns an
    un-initialized instance of the parameterized class. Subsequent __setstate__
    will be called by pickle.
    """
    def __call__(self, method_class):
        # make a simple object which has no complex __init__ (this one will do)
        obj = _InitializeParameterized()
        obj.__class__ = sktime_wrapper(method_class)
        return obj


model = sktime_wrapper(Rocket)()

with open('model.pkl','wb') as f:
    pickle.dump(model, f)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69392199

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档