首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用Keras预测函数/表?

如何使用Keras预测函数/表?
EN

Stack Overflow用户
提问于 2018-01-28 12:57:06
回答 2查看 821关注 0票数 3

我目前正在学习角蛋白。我的目标是创建一个简单的模型,用来预测函数的值。首先,我创建两个数组,一个用于X值,另一个用于对应的Y值.

代码语言:javascript
复制
# declare and init arrays for training-data
X = np.arange(0.0, 10.0, 0.05)
Y = np.empty(shape=0, dtype=float)

# Calculate Y-Values
for x in X:
    Y = np.append(Y, float(0.05*(15.72807*x - 7.273893*x**2 + 1.4912*x**3 - 0.1384615*x**4 + 0.00474359*x**5)))

然后我创建并训练模型

代码语言:javascript
复制
# model architecture
model = Sequential()
model.add(Dense(1, input_shape=(1,)))
model.add(Dense(5))
model.add(Dense(1, activation='linear'))

# compile model
model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['accuracy'])

# train model
model.fit(X, Y, epochs=150, batch_size=10)

并使用该模型预测值。

代码语言:javascript
复制
# declare and init arrays for prediction
YPredict = np.empty(shape=0, dtype=float)

# Predict Y
YPredict = model.predict(X)

# plot training-data and prediction
plt.plot(X, Y, 'C0')
plt.plot(X, YPredict, 'C1')

# show graph
plt.show()

我得到了这个输出(蓝色是训练-数据,橙色是预测):

我做错什么了?我想这是网络体系结构的一个基本问题,对吧?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-01-28 13:18:07

问题确实在于您的网络架构。具体来说,您在所有层中都使用线性激活:这意味着网络只能适应线性函数。您应该在输出层中保持线性激活,但是应该在隐藏层中使用ReLU激活:

代码语言:javascript
复制
model.add(Dense(1, input_shape=(1,)))
model.add(Dense(5, activation='relu'))
model.add(Dense(1, activation='linear'))

然后,使用隐藏层的数量/大小;我建议您多使用几层。

票数 3
EN

Stack Overflow用户

发布于 2018-01-28 17:14:31

在BlackBear提供的答案的基础上:

  • 在将输入X和输出Y输入到您的神经网络之前,您应该规范化输入X和输出Y:从sklearn.preprocessing导入StandardScaler sc_X = StandardScaler() X_train = sc_X.fit_transform(X) sc_Y = StandardScaler() Y_train = sc_Y.fit_transform(Y) #.model.fit(X_train,Y_train,.)在与您的回归设置非常相似的情况下,查看this answer以查看如果不这样做会发生什么。请记住,您应该使用sc_X对任何测试数据进行类似的缩放;此外,如果以后需要将模型生成的任何predictions缩放回Y的原始比例,则应该使用 Sc_Y.inverse_transform(预测)
  • 在像您这样的回归设置中,准确性没有任何意义;您应该从模型编译中删除metrics=['accuracy'] (损失本身就足够作为一个度量)。
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48486598

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档