tensorflow.contrib.seq2seq的dynamic_decode返回三个值,第一个值是带有指定字段( 'rnn_output'和sample_id )的二元组。我试图理解sample_id是什么,但我找不到任何示例或文档,TensorFlow开发人员峰会上的示例没有添加太多信息。有人能解释一下吗?
发布于 2018-02-10 20:15:58
sample_id是rnn输出的argmax。
发布于 2018-12-27 08:12:44
rnn_output=[batch_size, max length of a sentence, probability of each word in a vocabulary]
sample_id = [batch_size, max length of a sentence]例如:
batch_size is 99
max length of a sentence is 15
Vocabulary size is 233rnn_output = [99,15,233]
sample_id = [99,15]如前所述,sample_id第二维度包含rnn_output第三维度的argmax值。
在一种更简单的语言中,sample_id第二维度将具有rnn_output第三层dimension->max value->index。
https://stackoverflow.com/questions/44686878
复制相似问题