首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将TensorFlow sparse_tensor_dense_matmul转换为embedding_lookup_sparse格式

将TensorFlow sparse_tensor_dense_matmul转换为embedding_lookup_sparse格式
EN

Stack Overflow用户
提问于 2017-09-08 19:18:01
回答 1查看 176关注 0票数 1

TensorFlow 文档提到sparse_tensor_dense_matmul所期望的SparseTensor格式是:sp_a (indices, values)

代码语言:javascript
复制
[0, 1]: a
[1, 0]: b
[1, 4]: c
[2, 2]: d

SparseTensor格式embedding_lookup_sparsesp_ids sp_weights

代码语言:javascript
复制
[0, 0]: 1                [0, 0]: a
[1, 0]: 0                [1, 0]: b
[1, 1]: 4                [1, 1]: c
[2, 0]: 2                [2, 0]: d

如何将sp_a转换为sp_ids,将sp_weights转换为TensorFlow中的第二个?如果不可能的话,我该怎么做呢?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-09-14 09:06:44

我忽略了Tensorflow API中是否存在用于此转换的函数,但这就是我如何将sp_a转换为sp_idssp_weights的方法。

代码语言:javascript
复制
import tensorflow as tf


indices = tf.constant([[0, 1],
                       [1, 0],
                       [1, 4],
                       [2, 2]], dtype=tf.int64)
values = tf.constant([1, 2, 3, 4])  # a, b, c, d
dense_shape = tf.constant([3, 5], dtype=tf.int64)
sp_a = tf.SparseTensor(indices=indices,
                       values=values,
                       dense_shape=dense_shape)

# transform sp_a into sp_ids and sp_weights
# Get sp_ids values
sp_ids_values = tf.slice(sp_a.indices,
                         begin=[0, 1],
                         size=[-1, 1])
sp_ids_values = tf.squeeze(sp_ids_values)

# Get the indices for sp_ids and sp_weights
d1 = tf.slice(sp_a.indices,
              begin=[0, 0],
              size=[-1, 1])
d2 = tf.expand_dims(scan_accum(tf.squeeze(d1)),
                    axis=1)
indices_ = tf.concat([d1, d2],
                     axis=1)

# Build sp_ids and sp_weights
sp_ids = tf.SparseTensor(indices=indices_,
                         values=sp_ids_values,
                         dense_shape=sp_a.dense_shape)
sp_weights = tf.SparseTensor(indices=indices_,
                             values=sp_a.values,
                             dense_shape=sp_a.dense_shape)

with tf.Session() as sess:
    print(sess.run(sp_ids))
    print(sess.run(sp_weights))

我定义了scan_accum 这里

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46123241

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档