我是keras的新手
我尝试使用我的数据集,按照多层感知器( Multilayer Perceptron,MLP)的Keras教程进行多类softmax分类。我的数据有3个类别,只有一个特征,但我不明白为什么结果总是显示出0,3的准确率,而模型预测所有训练数据都是第一类。那么混淆矩阵是这样的。
下面是代码:
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import SGD
import pandas as pd
import numpy as np
# Importing the dataset
dataset = pd.read_csv('StatusAll.csv')
X = dataset.iloc[:, 1:].values
y = dataset.iloc[:, 0:1].values
# Splitting the dataset into the Training set and Test set
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0)
from keras.utils import to_categorical
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
model = Sequential()
# Dense(64) is a fully-connected layer with 64 hidden units.
# in the first layer, you must specify the expected input data shape:
# here, 20-dimensional vectors.
model.add(Dense(64, activation='tanh', input_dim=1))
model.add(Dropout(0.5))
model.add(Dense(64, activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(4, activation='softmax'))
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
optimizer=sgd,
metrics=['accuracy'])
history = model.fit(x_train, y_train,
epochs=100,
batch_size=128)
score = model.evaluate(x_test, y_test, batch_size=128)
print('Test score:', score[0])
print('Test accuracy:', score[1])
from sklearn import metrics
prediction = model.predict(x_test)
prediction = np.around(prediction)
y_test_non_category = [ np.argmax(t) for t in y_test ]
y_predict_non_category = [ np.argmax(t) for t in prediction ]
from sklearn.metrics import confusion_matrix
conf_mat = confusion_matrix(y_test_non_category, y_predict_non_category)
print (conf_mat) 我希望我能得到一些建议,谢谢。
x_train示例x_train
发布于 2018-05-10 17:21:42
你的最终密集层有4个输出,看起来你的分类是4而不是3。
model.add(Dense(3, activation='softmax')) # Number of classes 3查看来自x_train和y_train的样本数据将有助于确保预处理是正确的。因为您只有1个功能,所以MLP可能会被夸大。除非您想尝试MLP,否则decision tree会更简单。
https://stackoverflow.com/questions/50268935
复制相似问题