首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何将张量数据类型float32_ref转换为数据类型float32?

如何将张量数据类型float32_ref转换为数据类型float32?
EN

Stack Overflow用户
提问于 2019-01-10 09:26:06
回答 1查看 1.5K关注 0票数 0
代码语言:javascript
复制
word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings,ids=self.word_ids)
word_embeddings_modify = tf.scatter_nd_update(word_embeddings, self.error_word, sum_all)

Error:
Tensor conversion requested dtype float32_ref for Tensor with dtype float32

从错误中看,函数scatter_nd_update中的word_embeddings实际的dtypetf.float_32,但scatter_nd_update应该接受word_embeddings dtype tf.float_32_ref

如何在使用tf.float_32之前将word_embeddings's dtypetf.float_32更改为tf.float_32_ref

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-01-10 15:37:21

您可以使用tf.Variable()直接转换数据类型。举个例子:

代码语言:javascript
复制
import tensorflow as tf

_word_embeddings = tf.get_variable(name='embedding',shape=[30, 5])

word_ids = [3,6,23]
word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings,ids=word_ids)

error_word = [[1]]
sum_all = [[0,0,0,0,0]]
word_embeddings_modify = tf.scatter_nd_update(tf.Variable(word_embeddings), error_word, sum_all)

with tf.Session()as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(word_embeddings))
    print(sess.run(word_embeddings_modify))
[[ 0.08698401 -0.15590087  0.00285593 -0.13804913 -0.12418613]
 [-0.25748074  0.32121882 -0.390212    0.24590132  0.3976703 ]
 [-0.3023583   0.00366881 -0.05178839 -0.20865369  0.2887713 ]]
[[ 0.08698401 -0.15590087  0.00285593 -0.13804913 -0.12418613]
 [ 0.          0.          0.          0.          0.        ]
 [-0.3023583   0.00366881 -0.05178839 -0.20865369  0.2887713 ]]

奇怪的是,你为什么要更新单词嵌入结果。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54120792

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档