首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >skflow回归预测多值

skflow回归预测多值
EN

Stack Overflow用户
提问于 2015-12-11 13:37:01
回答 1查看 2.5K关注 0票数 7

我试图预测一个时间序列:给定50个以前的值,我想预测下一个值。

为此,我使用了skflow包(基于TensorFlow),这个问题相对接近于波士顿的例子,提供在Github回购。

我的代码如下:

代码语言:javascript
复制
%matplotlib inline
import pandas as pd

import skflow
from sklearn import cross_validation, metrics
from sklearn import preprocessing

filepath = 'CSV/FILE.csv'
ts = pd.Series.from_csv(filepath)

nprev = 50
deltasuiv = 5

def load_data(data, n_prev = nprev, delta_suiv=deltasuiv):  

    docX, docY = [], []
    for i in range(len(data)-n_prev-delta_suiv):
        docX.append(np.array(data[i:i+n_prev]))
        docY.append(np.array(data[i+n_prev:i+n_prev+delta_suiv]))
    alsX = np.array(docX)
    alsY = np.array(docY)

    return alsX, alsY

X, y = load_data(ts.values) 
# Scale data to 0 mean and unit std dev.
scaler = preprocessing.StandardScaler()
X = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y,
    test_size=0.2, random_state=42)
regressor = skflow.TensorFlowDNNRegressor(hidden_units=[30, 50],
    steps=5000, learning_rate=0.1, batch_size=1)
regressor.fit(X_train, y_train)
score = metrics.mean_squared_error(regressor.predict(X_test), y_test)
print('MSE: {0:f}'.format(score))

这导致:

ValueError: y_true和y_pred有不同的输出数(1!=5)

在训练结束时。

当我试图预测时,我也遇到了同样的问题

代码语言:javascript
复制
ypred = regressor.predict(X_test)
print ypred.shape, y_test.shape

(200,1) (200,5)

因此,我们可以看到,该模型在某种程度上预测的只是一个值,而不是5个希望/希望的值。

如何使用相同的模型来预测多个值的值?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2015-12-17 18:38:15

我刚刚在skflow中添加了对多输出回归的支持,因为这个#e443c734,所以请重新安装这个软件包,再试一次。如果不起作用,请继续追踪吉特布。

我还向示例文件夹添加了一个多输出回归示例

代码语言:javascript
复制
# Create random dataset.
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T

# Fit regression DNN model.
regressor = skflow.TensorFlowDNNRegressor(hidden_units=[5, 5])
regressor.fit(X, y)
score = mean_squared_error(regressor.predict(X), y)
print("Mean Squared Error: {0:f}".format(score))
票数 6
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/34224826

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档