首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从测试集的最后一个数据点进行预测

如何从测试集的最后一个数据点进行预测
EN

Stack Overflow用户
提问于 2019-06-14 21:11:35
回答 1查看 103关注 0票数 0

我正在进行一个时间序列预测项目。我的任务是在有1月到11月的数据的情况下预测12月份的销售额。我将数据拆分成训练集和测试集。我已经应用随机森林回归在测试集上进行预测。然而,我不知道如何使用该模型来预测12月份的销售量。你能告诉我怎么做吗?提前谢谢你。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-06-14 21:50:05

如果您已经完成了数据清理,并且已经将它们拆分为trainingtesting数据集。您可以简单地将它们放入我创建的这个pipline函数中。该generic function将任何算法和数据作为输入并建立模型,对testing数据集执行交叉验证并生成预测。

代码语言:javascript
复制
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
import pandas as pd
import plotly.plotly as ply
import cufflinks as cf

cf.go_offline()


#Define target and ID columns:
target = 'sales'
IDcol = ['months']
predictors = [x for x in training.columns if x not in [target]+IDcol]

alg = RandomForestRegressor(n_estimators=200,max_depth=5, min_samples_leaf=100,n_jobs=4)
test = modelfitting(alg, training, testing, predictors, target)
coef5 = pd.Series(alg.feature_importances_, predictors).sort_values(ascending=False)
coef5.iplot(kind='bar', title='Feature Importances')

for_plot = test
for_plot = for_plot[['sales prediction']]
for_plot.iplot()


def modelfitting(alg, training, testing, predictors, target):
    # Fit the algorithm on the data
    alg.fit(training[predictors], training[target])

    # Predict training set:
    dtrain_predictions = alg.predict(training[predictors])

    # Perform cross-validation:
    cv_score = cross_val_score(alg, training[predictors], training[target], cv=20, scoring='neg_mean_squared_error')
    cv_score = np.sqrt(np.abs(cv_score))

    # Print model report:
    print "\nModel Report"
    print "RMSE : %.4g" % np.sqrt(metrics.mean_squared_error(training[target].values, dtrain_predictions))
    print "CV Score : Mean - %.4g | Std - %.4g | Min - %.4g | Max - %.4g" % (
    np.mean(cv_score), np.std(cv_score), np.min(cv_score), np.max(cv_score))

    # Predict on testing data:
    testing["sales prediction"] = alg.predict(testing[predictors])

    return testing

我已经添加了不言自明的注释。如果您在理解代码时遇到困难,请在评论中自由讨论。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56598926

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档