首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用categorical_crossentropy时出错

使用categorical_crossentropy时出错
EN

Stack Overflow用户
提问于 2020-08-02 00:08:32
回答 2查看 2.9K关注 0票数 2

我正在用tensorflow学习深度学习。我做了一个简单的NLP代码,预测给定句子的下一个单词。

代码语言:javascript
复制
model = tf.keras.Sequential()
model.add(Embedding(num,64,input_length = max_len-1))   # we subtract 1 coz we cropped the laste word from X in out data
model.add(Bidirectional(LSTM(32)))
model.add(Dense(num,activation = 'softmax'))


model.compile(optimizer = 'adam',loss = 'categorical_crossentropy',metrics = ['accuracy'])

history = model.fit(X,Y,epochs = 500)

但是,使用categorical_crossentropy会给出以下错误

代码语言:javascript
复制
ValueError: You are passing a target array of shape (453, 1) while using as loss `categorical_crossentropy`. `categorical_crossentropy` expects targets to be binary matrices (1s and 0s) of shape (samples, classes). If your targets are integer classes, you can convert them to the expected format via:

从keras.utils导入to_categorical

y_binary = to_categorical(y_int)

代码语言:javascript
复制
Alternatively, you can use the loss function `sparse_categorical_crossentropy` instead, which does expect integer targets.

有人能解释一下这是什么意思吗?为什么我不能使用绝对的交叉熵损失函数?非常感谢!任何帮助都将不胜感激!

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-08-02 01:11:51

分类交叉熵用于多类分类问题.当您使用"softmax“作为激活时,将在输出层中为每个类提供一个节点。对于每个示例,对应于示例类的节点应该接近一个,其余节点应该接近。因此,真正的类标签Y需要是一个单热点编码向量.

假设Y中的类标签是0,1,2,……这样的整数。请试试下面的代码。

代码语言:javascript
复制
from keras.utils import to_categorical

model = tf.keras.Sequential()
model.add(Embedding(num,64,input_length = max_len-1))   # we subtract 1 coz we cropped the laste word from X in out data
model.add(Bidirectional(LSTM(32)))
model.add(Dense(num,activation = 'softmax'))


model.compile(optimizer = 'adam',loss = 'categorical_crossentropy',metrics = ['accuracy'])

Y_one_hot=to_categorical(Y) # convert Y into an one-hot vector
history = model.fit(X,Y_one_hot,epochs = 500)  # use Y_one_hot encoding instead of Y
票数 2
EN

Stack Overflow用户

发布于 2022-03-23 07:05:46

对于提供的答案(由Roohollah提供),您必须导入to_categorical,如下所示:

from keras.utils.np_utils import to_categorical

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63211181

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档