首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在pickle.load()之后,XGBRegressor.predict()返回错误

在pickle.load()之后,XGBRegressor.predict()返回错误
EN

Stack Overflow用户
提问于 2017-07-31 16:52:52
回答 1查看 943关注 0票数 2

我已经使用sklearn界面训练了XGBRegressor模型。相关代码如下:

代码语言:javascript
复制
def xgb_regressor_wrapper(X_train, y_train):
    xgb_regressor = XGBRegressor(objective='reg:linear', n_estimators=1000, learning_rate=0.01, base_score=0.005)
    xgb_regressor.fit(X=X_train, y=y_train) #, eval_set=[(X_test, y_test)], verbose=True)
    return xgb_regressor

def save_regressor(station, feature, regressor):
    fname = generate_regressor_fname(station, feature)
    pickle.dump(regressor, open(fname, "wb" ))

# regressor_list dict contains wrapper functions
# I currently have XGBRegressor and CatBoostRegressor in the list.
regressor_wrapper = regressor_list.get(name) 

# Create and fit XGBRegressor
regressor = regressor_wrapper(X_train, y_train)

# Save regressor
save_regressor(station_id, feature, best_regressor)

一段时间后,我使用以下代码重新加载回归器,并进行预测:

代码语言:javascript
复制
def load_regressor(station, feature):
    fname = generate_regressor_fname(station, feature)
    return pickle.load(open(fname, "rb" ))

# Load the regressor
regressor = load_regressor(station_id, feature)

# Do the prediction
y_predict = regressor.predict(X_test)

我得到以下错误:

代码语言:javascript
复制
  File "regressor_stuff.py", line 169, in regressor_check_for_station_feature
    y_predict = regressor.predict(X_test)
  File "D:\Anaconda\envs\Deep\lib\site-packages\xgboost\sklearn.py", line 268, in predict
    return self.booster().predict(test_dmatrix,
TypeError: 'str' object is not callable

经过一些调试,我发现self.booster实际上存储了字符串'gbtree‘。在训练了回归器的特性后(顺便说一句,这花了几天时间),这并不酷。

对于为什么会发生这种情况,有什么建议吗?

我目前的解决方法是按如下方式重建XGBBooster:

代码语言:javascript
复制
# Load the regressor
if isinstance(regressor, XGBRegressor):
    regressor = XGBRegressor()
    r = pickle.load(open(fname, "rb" ))
    print r.get_xgb_params()
    regressor._Booster = r._Booster
    regressor.set_params(**r.get_xgb_params())

# Do the prediction
y_predict = regressor.predict(X_test)

谢谢

库尔萨特

EN

回答 1

Stack Overflow用户

发布于 2017-08-22 07:58:37

我认为在您的训练和评分环境中,您可能会遇到xgboost版本不匹配的问题。我遇到了同样的问题,并发现我使用xgboost==0.6进行训练,而不是使用xgboost==0.6a2进行评分。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45411357

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档