首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TypeError: forward()得到了一个意外的关键字参数“基线值”。如何在Skorch中正确加载保存的模型?

TypeError: forward()得到了一个意外的关键字参数“基线值”。如何在Skorch中正确加载保存的模型?
EN

Stack Overflow用户
提问于 2022-05-05 13:22:04
回答 1查看 151关注 0票数 0

我使用以下代码保存了Skorch神经网络模型:

代码语言:javascript
复制
net_b = NeuralNetClassifier(
    Classifier_b,
    max_epochs=50,
    optimizer__momentum= 0.9,
    lr=0.1,
    device=device,
)

#Fit the model on the full data
net_b.fit(merged_X_train, merged_Y_train);

#Test saving
import pickle
with open('MLP.pkl', 'wb') as f:
    pickle.dump(net_b, f)

当我再次加载这个模型并针对测试数据运行它时,我会收到以下错误:

代码语言:javascript
复制
TypeError: forward() got an unexpected keyword argument 'baseline value'

这是我的密码:

代码语言:javascript
复制
#Split the data
X_train, y_train, X_valid, y_valid,X_test, y_test = train_valid_test_split(rescaled_data, target = 'fetal_health',
                                        train_size=0.8, valid_size=0.1, test_size=0.1)

input_dim = f_df_toscale.shape[1]
output_dim = len(np.unique(f_target))
hidden_dim_a = 20
hidden_dim_b = 12
device = 'cpu'

class Classifier_b(nn.Module):
    def __init__(self,
                 input_dim = input_dim,
                 hidden_dim_a = hidden_dim_b,
                 output_dim = output_dim):
        
        super(Classifier_b, self).__init__()

        #Take the inputs and pass these to a hidden layer
        self.hidden = nn.Linear(input_dim,hidden_dim_b)
        
        #Take the hidden layer and pass it through an additional hidden layer
        self.hidden_b = nn.Linear(hidden_dim_a,hidden_dim_b)
        
        #Take the hidden layer and pass to a multi nerouon output
        self.output = nn.Linear(hidden_dim_b,output_dim)

    def forward(self, x):
        hidden = F.relu(self.hidden(x))
        hidden = F.relu(self.hidden_b(hidden))
        output = F.softmax(self.output(hidden))     
        return output

#load the model
with open('MLP.pkl', 'rb') as f:
    model_MLP = pickle.load(f)

#Test the model
y_pred = model_MLP.predict(X_test)
ML = accuracy_score(y_test, y_pred)
print('The accuracy score for the MLP is ', ML)

当我在原始笔记本上正常运行这个型号时,一切都会被罚款。但是,当我试图从保存的状态加载我的模型时,我会得到错误。知道为什么吗?我没有什么叫“基线值”。

谢谢

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-05-05 13:35:57

如果代码更改,则保存和加载模型可能会出现问题。所以用它更好

save_params()load_params()

在你的情况下

net_b.save_params(f_params='some-file.pkl')

加载模型首先初始化(初始化非常重要),然后加载参数

new_net.initialize()

new_net.load_params(f_params='some-file.pkl')

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72127915

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档