首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >理解Keras预测

理解Keras预测
EN

Stack Overflow用户
提问于 2021-05-05 20:57:01
回答 1查看 47关注 0票数 0

我得到了以下代码:

代码语言:javascript
复制
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping
import numpy as np
from numpy.random import seed
from tensorflow import random

seed(42)
random.set_seed(43)

X = [
    'may it all be fine in the world',
    'this is not for me',
    'pffff ugly bike',
    'dropping by to say leave me alone',
    'getting sarcastic by now'
    'how would one satisfy his or her needs when the earth is boiling'
]

y = [1,2,4,5,3]

tokenizer = Tokenizer(num_words = 13)
tokenizer.fit_on_texts(X)
X_train_seq = tokenizer.texts_to_sequences(X)


X_train_seq_padded = pad_sequences(X_train_seq, maxlen = 15)

model = Sequential()
model.add(Dense(16, input_dim = 15, activation = 'relu', name = 'hidden-1'))
model.add(Dense(16, activation = 'relu', name = 'hidden-2'))
model.add(Dense(16, activation = 'relu', name = 'hidden-3'))
model.add(Dense(5, activation='softmax', name = 'output_layer'))

model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics=['accuracy'])

class CustomCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print('finished an epoch')
        zin = 'dropping by to say leave her alone'
        zin = tokenizer.texts_to_sequences(zin)
        zin = pad_sequences(zin, maxlen = 15)
        print(model.predict(zin))
        print(np.argmax(model.predict(zin), axis=-1))
callbacks = [EarlyStopping(monitor = 'accuracy', patience = 5, mode = 'max'), CustomCallback()]

from sklearn.preprocessing import LabelBinarizer
encoder = LabelBinarizer()
y = encoder.fit_transform(y)

history = model.fit(X_train_seq_padded, y, epochs = 100, batch_size = 100, callbacks = callbacks)

我预计在回调model.predict()内部会产生类似这样的结果(因为有5个可能的类):

代码语言:javascript
复制
[0.4534534, 0.5634246, 0.0045623, 0.0004536, 0.0000056]

和单个数字1、2、3、4或5中的np.argmax(model.predict(zin), axis=-1)

然而,我收到的输出(显示一个时期)是:

我必须如何解释这一点,以及如何过滤出模型将预测句子所属的实际类?

EN

回答 1

Stack Overflow用户

发布于 2021-05-05 21:15:19

代码语言:javascript
复制
print(model.predict(zin)[0])
print(np.argmax(model.predict(zin)[0], axis=-1))

这将为您提供正确的值。

tf模型被设计为以批处理方式工作,而不是一次性使用,因此它为您提供了一个输出列表,但因为您的输入是单个项目,所以它只会将该项目通过NN推送n次,从而产生相同的输出列表

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67401813

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档