首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >测试直线的CatBoostRegression预测

测试直线的CatBoostRegression预测
EN

Stack Overflow用户
提问于 2017-09-19 21:48:11
回答 1查看 1.2K关注 0票数 2

测试数据集中的CatBoostRegressor符合直线

第一个图是训练数据集( CatBoostRegressor基于噪声的sin训练),第二个图是测试数据集。

为什么它是一条直线?其他函数的相同(如f(x)=x等)

代码语言:javascript
复制
x = np.linspace(0, 2*np.pi, 100)
y = func(x) + np.random.normal(0, 3, len(x))

x_test = np.linspace(0*np.pi, 4*np.pi, 200)
y_test = func(x_test)

train_pool = Pool(x.reshape((-1,1)), y)
test_pool = Pool(x_test.reshape((-1,1))) 

model = CatBoostRegressor(iterations=100, depth=2, loss_function="RMSE",
                          verbose=True
                          )
model.fit(train_pool)

y_pred = model.predict(x.reshape((-1,1)))
y_test_pred = model.predict(test_pool)

poly = Polynomial(4)
p = poly.fit(x,y);


plt.plot(x, y, 'ko')
plt.plot(x, func(x), 'k')
plt.plot(x, y_pred, 'r')
plt.plot(x, poly.evaluate(p, x), 'b')

plt.show()

plt.plot(x_test, y_test, 'k')
plt.plot(x_test, y_test_pred, 'r')
plt.show()
plt.plot(x_test, y_test, 'k')
plt.plot(x_test, poly.evaluate(p, x_test), 'b')
plt.show()
EN

回答 1

Stack Overflow用户

发布于 2018-01-18 14:30:14

这是因为决策树是分段常数函数,而Catboost完全是基于决策树的。所以,catboost总是用一个常数来推断。

因此,Catboost (以及其他基于树的算法,如XGBoost或随机森林的所有实现)在外推方面都很差(除非您做了一个聪明的特性工程,实际上它是自己推断出来的)。

在您的例子中,Catboost用常数推断正弦,这是不酷的。但是多项式拟合更糟糕:它很快就会无限大!

这是生成图片的完整代码:

代码语言:javascript
复制
import numpy as np
func = np.sin
from catboost import Pool, CatBoostRegressor
from numpy.polynomial.polynomial import Polynomial
import matplotlib.pyplot as plt

np.random.seed(1)

x = np.linspace(0, 2*np.pi, 100)
y = func(x) + np.random.normal(0, 3, len(x))

x_test = np.linspace(0*np.pi, 4*np.pi, 200)
y_test = func(x_test)

train_pool = Pool(x.reshape((-1,1)), y)
test_pool = Pool(x_test.reshape((-1,1))) 

model = CatBoostRegressor(iterations=100, depth=2, loss_function="RMSE",verbose=False)
model.fit(train_pool, verbose=False)

y_pred = model.predict(x.reshape((-1,1)))
y_test_pred = model.predict(test_pool)

p = np.polyfit(x, y, deg=4)

plt.scatter(x, y, s=3, c='k')
plt.plot(x_test, y_test, 'k')
plt.plot(x_test, y_test_pred, 'r')
plt.plot(x_test, np.polyval(p, x_test), 'b')
plt.title('Out-of-sample performance of trees and polynomials')
plt.legend(['training data', 'true', 'catboost', 'polynomial'])
plt.ylim([-4, 4])
plt.show()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46310237

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档