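How can I pickle a model whose class is defined inside, and returned by, a function that wraps a given method class? I want to build the same wrapper for many classes, such as (in my case) the Rocket class.
The code below raises an error: Can't pickle local object 'sktime_wrapper.<locals>.SKtimeWrapper'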
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
def sktime_wrapper(method_class):
    class SKtimeWrapper(method_class):
        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)
        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)
    return SKtimeWrapper
model = sktime_wrapper(Rocket)
with open('model.pkl','wb') as f:
    pickle.dump(model, f)
When the class is defined as a top-level object, pickle works fine. The code below works like a charm and saves the model without any problems:
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
class SKtimeWrapper(Rocket):
    def transform(self, X):
        X = from_2d_array_to_nested(X)
        return super().transform(X)
    def fit(self, X, Y):
        X = from_2d_array_to_nested(X)
        return super().fit(X, Y)
model = SKtimeWrapper
with open('model.pkl','wb') as f:
    pickle.dump(model, f)
Posted on 2021-10-01 07:10:42
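Why the first version fails and the second works can be reproduced without sktime. pickle saves a class by reference, as its module plus `__qualname__`, and looks it up again by that name when loading. A class created inside a function gets a qualname containing `<locals>`, which cannot be looked up, so pickling it (or an instance of it) fails. A minimal sketch, assuming only the standard library; `Base`, `TopLevel`, `make_wrapper` and `can_pickle` are hypothetical stand-ins for `Rocket`, the top-level `SKtimeWrapper` and `sktime_wrapper`:

```python
import pickle


class Base:
    """Hypothetical stand-in for Rocket."""


class TopLevel(Base):
    """Defined at module level, so pickle can find it again by name."""


def make_wrapper(method_class):
    # A new class is created on every call; its __qualname__ is
    # 'make_wrapper.<locals>.Wrapper', which pickle cannot look up.
    class Wrapper(method_class):
        pass
    return Wrapper


def can_pickle(obj):
    """Return True if pickle.dumps succeeds, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        # The exact exception type for local objects differs between Python versions.
        return False


Local = make_wrapper(Base)
print(Local.__qualname__)    # make_wrapper.<locals>.Wrapper
print(can_pickle(TopLevel))  # True: stored as a reference to a module-level name
print(can_pickle(Local))     # False: the '<locals>' name cannot be resolved
print(can_pickle(Local()))   # False: an instance references its class the same way
```

Pickling `TopLevel` only stores a reference, which is why the second snippet in the question "works": it saves the class name, not a fitted model.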
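Following the answers, I managed to get it working! I hope someone finds this useful. The trick is to implement the __reduce__() method.
Below is a working example. Note that the object must be instantiated before saving: __reduce__ is only used when pickling an instance, so pickling the class returned by sktime_wrapper itself still fails.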
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
def sktime_wrapper(method_class):
    class SKtimeWrapper(method_class):
        PARAM = method_class  # remember the base class so it can be pickled by reference
        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)
        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)
        def __reduce__(self):
            # (callable, args, state): the callable and the base class are top-level,
            # so pickle can store both by reference; state is the instance __dict__.
            return (_InitializeParameterized(), (self.PARAM,), self.__dict__)
    return SKtimeWrapper
class _InitializeParameterized(object):
    """
    When called with the param value as the only argument, returns an
    un-initialized instance of the parameterized class. pickle then restores
    the saved state (via __setstate__ if defined, otherwise by updating __dict__).
    """
    def __call__(self, method_class):
        # make a simple object which has no complex __init__ (this one will do)
        obj = _InitializeParameterized()
        obj.__class__ = sktime_wrapper(method_class)
        return obj
model = sktime_wrapper(Rocket)()
with open('model.pkl','wb') as f:
    pickle.dump(model, f)
https://stackoverflow.com/questions/69392199
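The same `__reduce__` idea as a dependency-free sketch that can be run without sktime, including the round trip through `pickle.loads`. `Base` is a hypothetical stand-in for `Rocket`, `make_wrapper` for `sktime_wrapper`, and `list(X)` is a placeholder for `from_2d_array_to_nested`. One deliberate swap from the answer: instead of a callable placeholder object whose `__class__` is reassigned, the rebuild step here is a plain module-level function, `_rebuild` (hypothetical name), that calls `cls.__new__(cls)`. Both create an instance without running `__init__`, and both are picklable because they live at module level.

```python
import pickle


class Base:
    """Hypothetical stand-in for Rocket: takes a parameter, learns state in fit()."""

    def __init__(self, scale=1):
        self.scale = scale
        self.offset_ = None

    def fit(self, X, y=None):
        self.offset_ = min(X)
        return self

    def transform(self, X):
        return [(x - self.offset_) * self.scale for x in X]


def make_wrapper(method_class):
    class Wrapper(method_class):
        PARAM = method_class  # remember which base class this wrapper was built from

        def transform(self, X):
            # Placeholder for the input conversion done in the question.
            return super().transform(list(X))

        def __reduce__(self):
            # (callable, args, state): _rebuild and the base class are pickled
            # by reference; the instance __dict__ is pickled as the state.
            return (_rebuild, (self.PARAM,), self.__dict__)

    return Wrapper


def _rebuild(method_class):
    """Recreate the wrapper class and return an uninitialized instance of it.

    pickle then copies the saved state into the returned object's __dict__.
    """
    cls = make_wrapper(method_class)
    return cls.__new__(cls)  # skips __init__, like the answer's __class__ swap


model = make_wrapper(Base)(scale=2).fit([3, 5, 8])
data = pickle.dumps(model)  # works: the instance defines __reduce__
restored = pickle.loads(data)

print(restored.transform((3, 5, 8)))  # [0, 4, 10]
print(isinstance(restored, Base))     # True
```

Each unpickling calls the factory again, so `type(restored)` is a freshly created class, not `type(model)`; `isinstance` checks against the base class still hold. Caching the factory (for example with `functools.lru_cache`) would make repeated calls within one process return the same class.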