首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >具有3个输出的Keras回归模型只对一个输出给出了准确的结果

具有3个输出的Keras回归模型只对一个输出给出了准确的结果
EN

Stack Overflow用户
提问于 2019-04-09 23:28:54
回答 1查看 287关注 0票数 2

我正在尝试运行一个神经网络,使用python中的keras,它有2个值作为输入,3个值作为输出。输入表示固有频率,而输出表示等效冰负载。问题是在模型完成训练后,它似乎只被训练为预测一个输入,而不是所有三个输入。这个模型是回归的,而不是classification.Here,我给出了我的代码

代码语言:javascript
复制
seed = 9
np.random.seed(seed)
# import dataset
dataset=np.loadtxt("Final_test_matrix_new_3_digits.csv", delimiter=",")
# Define dataset
Y=dataset[:, 0:3]
X=dataset[:, 3:5]
#Categorize data
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.3, 
random_state = seed)
# create model
model = Sequential()
model.add(Dense(12, input_dim=2,activation='relu'))
model.add(Dense(8, init='uniform', activation='relu'))
model.add(Dense(8, init='uniform', activation='relu'))
model.add(Dense(3, init='uniform', activation='relu'))
# compile the model
model.compile(loss='mean_squared_logarithmic_error', optimizer='adam', 
metrics=['accuracy'])
# checkpoint
filepath="weights.best_12_8_8_neurons.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, 
save_best_only=True, mode='max')
callbacks_list = [checkpoint]
# fit the model
history=model.fit(X_train, Y_train, validation_split=0.1, epochs=100000, 
batch_size=10,callbacks=callbacks_list)
# evaluate the model
scores = model.evaluate(X_test, Y_test)
print ("Accuracy: %.2f%%" %(scores[1]*100))

根据python的说法,该模型的准确性为65%,但这一事实不会影响输出的准确性,因为第二个输出的精度低于第一个输出的精度,而第三个输出的精度几乎为0。代码的主要目标是创建一个回归模型,其中所有输出将具有相同的accuracy.In如下所示分别显示模型精度、模型损失和每个输出的预测:

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-04-10 02:53:32

看起来随机森林更适合你的情况。你应该试一试,特别是如果你有不平衡的类。

作为一种解决办法,您可以增加Dense(8)层中的节点数,这取决于数据的变化。

然后,您必须检查少数类并采用以下代码(合成少数过采样技术):

代码语言:javascript
复制
from imblearn.over_sampling import SMOTE

sm = SMOTE()
x_train2, y_train2 = sm.fit_sample(X_train, Y_train)

请注意,此代码仅适用于二进制输出,因此您应该对3个类进行一次热编码,然后应用类0和类1,然后应用类0和类2,从类0中删除双倍过采样。然后运行神经网络模型,将validation_split增加到0.2。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55596342

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档