word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings,ids=self.word_ids)
word_embeddings_modify = tf.scatter_nd_update(word_embeddings, self.error_word, sum_all)
Error:
Tensor conversion requested dtype float32_ref for Tensor with dtype float32
Judging from the error, the word_embeddings passed to scatter_nd_update actually has dtype tf.float32, but scatter_nd_update expects word_embeddings to have dtype tf.float32_ref.
How can I change the dtype of word_embeddings from tf.float32 to tf.float32_ref before using it in tf.scatter_nd_update?
Posted on 2019-01-10 15:37:21
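tf.scatter_nd_update requires a mutable variable (a _ref dtype), while the output of tf.nn.embedding_lookup is an ordinary tensor. You can wrap the result in tf.Variable() to get the required type. For example: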
import tensorflow as tf

# TensorFlow 1.x graph-mode API (tf.get_variable, tf.Session).
_word_embeddings = tf.get_variable(name='embedding', shape=[30, 5])
word_ids = [3, 6, 23]
word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings, ids=word_ids)
error_word = [[1]]             # index of the row to overwrite
sum_all = [[0, 0, 0, 0, 0]]    # replacement values for that row
# Wrap the lookup result in a Variable so that scatter_nd_update gets a float32_ref.
word_embeddings_modify = tf.scatter_nd_update(tf.Variable(word_embeddings), error_word, sum_all)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(word_embeddings))
    print(sess.run(word_embeddings_modify))
[[ 0.08698401 -0.15590087 0.00285593 -0.13804913 -0.12418613]
[-0.25748074 0.32121882 -0.390212 0.24590132 0.3976703 ]
[-0.3023583 0.00366881 -0.05178839 -0.20865369 0.2887713 ]]
[[ 0.08698401 -0.15590087 0.00285593 -0.13804913 -0.12418613]
[ 0. 0. 0. 0. 0. ]
[-0.3023583 0.00366881 -0.05178839 -0.20865369 0.2887713 ]]
That said, it is unclear why you would want to update the result of the embedding lookup.
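To make the second printed matrix easier to read, here is a minimal plain-Python sketch of what a row-wise scatter update with indices [[1]] does. It is only an illustration of the semantics, not TensorFlow code: the helper scatter_rows and the sample numbers are hypothetical and do not come from the answer.

```python
import copy


def scatter_rows(rows, indices, updates):
    """Return a copy of `rows` where row indices[k][0] is replaced by updates[k].

    Hypothetical helper that mimics a row-wise scatter update on nested lists.
    """
    result = copy.deepcopy(rows)  # leave the input untouched
    for (row_index,), new_row in zip(indices, updates):
        if len(new_row) != len(result[row_index]):
            raise ValueError("update row has the wrong length")
        result[row_index] = list(new_row)
    return result


# Made-up stand-in for the three looked-up embedding rows.
looked_up = [
    [0.1, 0.2, 0.3, 0.4, 0.5],
    [1.1, 1.2, 1.3, 1.4, 1.5],
    [2.1, 2.2, 2.3, 2.4, 2.5],
]
error_word = [[1]]            # overwrite row 1 (the second row)
sum_all = [[0, 0, 0, 0, 0]]   # replacement values

modified = scatter_rows(looked_up, error_word, sum_all)
print(modified)
```

Only the second row changes, which matches the pattern in the output above. If your TensorFlow release provides tf.tensor_scatter_nd_update, it may be a more direct fit, since it takes an ordinary tensor and returns a new one without a wrapper variable; the answer above does not use it.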
https://stackoverflow.com/questions/54120792