因为SVR只支持一个输出,所以我尝试在我的模型上使用SVR,它有6个输入和19个输出,使用MultiOutputRegressor。
我从超参数调优开始。然而,我得到了下面的错误。如何修改代码以支持MultiOutputRegressor
from sklearn.svm import SVR
from sklearn.model_selection import RandomizedSearchCV
svr = SVR()
svr_regr = MultiOutputRegressor(svr)
from sklearn.model_selection import KFold
kfold_splitter = KFold(n_splits=6, random_state = 0,shuffle=True)
#On each iteration, the algorithm will choose a difference combination of the features.
svr_random = RandomizedSearchCV(svr_regr,
param_distributions = {'kernel': ('linear','poly','rbf','sigmoid'),
'C': [1,1.5,2,2.5,3,3.5,4,4.5,5,5.5,6,6.5,7,7.5,8,8.5,9,9.5,10],
'degree': [3,8],
'coef0': [0.01,0.1,0.5],
'gamma': ('auto','scale')
'tol': [1e-3, 1e-4, 1e-5, 1e-6]},
n_iter=100,
cv=kfold_splitter,
n_jobs=-1,
random_state=42,
scoring='r2')
svr_random.fit(X_train, y_train)
print(svr_random.best_params_)错误:
ValueError: Invalid parameter kernel for estimator MultiOutputRegressor(estimator=SVR()). Check the list of available parameters with `estimator.get_params().keys()`.得到最优参数后的:
SVR_model = svr_regr (kernel='rbf',C=10,
coef0=0.01,degree=3,
gamma='auto',tol=1e-6,random_state=42)
SVR_model.fit(X_train, y_train)
SVR_model_y_predict = SVR_model.predict((X_test))
SVR_model_y_predict获得最佳参数后的误差:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/var/folders/mm/r4gnnwl948zclfyx12w803040000gn/T/ipykernel_96269/769104914.py in <module>
----> 1 SVR_model = svr_regr (estimator__kernel='rbf',estimator__C=10,
2 estimator__coef0=0.01,estimator__degree=3,
3 estimator__gamma='auto',estimator__tol=1e-6,random_state=42)
4
5
TypeError: 'MultiOutputRegressor' object is not callable发布于 2022-08-18 01:56:42
我试图复制一个简单的MultiOutputRegressor示例,而不使用GridSearchCV (即仅仅是拟合和预测方法),这似乎很好。错误信息:
使用estimator.get_params().keys()检查可用参数列表
建议您在GridSearchCV中优化的参数(即通过param_distributions )与MultiOutputRegressor接受的参数不匹配。查看API参考,MultiOutputRegressor只接受几个参数,您试图传递给SVR的参数,例如C和tol属于支持向量机估计器。
您可以通过类似于SVR的嵌套参数将参数传递给它是如何在管道中完成的。
https://stackoverflow.com/questions/73396411
复制相似问题