首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >值错误-检查目标时出错- LSTM

值错误-检查目标时出错- LSTM
EN

Stack Overflow用户
提问于 2020-05-03 19:16:02
回答 1查看 41关注 0票数 0

关于数据集

以下路透社数据集包含11228个文本,对应于46个类别的新闻。从每个单词对应一个整数的意义上来说,文本是加密的。我指定我们希望使用2000个单词。

代码语言:javascript
复制
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

num_words = 2000
(reuters_train_x, reuters_train_y), (reuters_test_x, reuters_test_y) = tf.keras.datasets.reuters.load_data(num_words=num_words)

n_labels = np.unique(reuters_train_y).shape[0]
print("labels: {}".format(n_labels))

# This is the first new
print(reuters_train_x[0])

实现LSTM

我需要实现一个具有10个单元的单个LSTM的网络。在进入LSTM单元格之前,输入需要嵌入10个维度。最后,需要添加密集层,以根据类别的数量调整输出的数量。

代码语言:javascript
复制
from keras.models import Sequential
from keras.layers import LSTM, Dense, Embedding
from from tensorflow.keras.utils import to_categorical

reuters_train_y = to_categorical(reuters_train_y, 46)
reuters_test_y = to_categorical(reuters_test_y, 46)

model = Sequential()
model.add(Embedding(input_dim = num_words, 10))
model.add(LSTM(10))
model.add(Dense(46,activation='softmax'))

培训

代码语言:javascript
复制
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
history = model.fit(reuters_train_x,reuters_train_y,epochs=20,validation_data=(reuters_test_x,reuters_test_y))

我得到的错误消息是:

代码语言:javascript
复制
ValueError: Error when checking target: expected dense_2 to have shape (46,) but got array with shape (1,)
EN

回答 1

Stack Overflow用户

发布于 2020-05-03 19:29:59

你需要对你的y标签进行一次热编码。

代码语言:javascript
复制
from tensorflow.keras.utils import to_categorical

reuters_train_y = to_categorical(reuters_train_y, 46)

reuters_test_y = to_categorical(reuters_test_y, 46)

我在fit函数中看到的另一个错误是,您传递的是validation_data=(reuters_test_x,reuters_train_y),但它应该是validation_data=(reuters_test_x,reuters_test_y)

你的x是一个不同长度列表的数值数组。您需要填充序列以获得固定形状的数值数组。

代码语言:javascript
复制
reuters_train_x = tf.keras.preprocessing.sequence.pad_sequences(
    reuters_train_x, maxlen=50
)

reuters_test_x = tf.keras.preprocessing.sequence.pad_sequences(
    reuters_test_x, maxlen=50
)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61573517

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档