首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将SparseTensor写入tfrecord文件和从tfrecord文件读取tfrecord

将SparseTensor写入tfrecord文件和从tfrecord文件读取tfrecord
EN

Stack Overflow用户
提问于 2017-10-19 13:44:07
回答 2查看 3.1K关注 0票数 4

有没有可能做到这一点呢?

现在我唯一能想到的就是将SparseTensor的索引(tf.int64),值(tf.float32)和形状(tf.int64)保存在3个独立的特征中(前两个是VarLenFeature,最后一个是FixedLenFeature)。这看起来真的很麻烦。

任何建议都是非常感谢的!

更新1

我下面的答案不适合构建计算图(b/c稀疏张量的内容必须通过sess.run()提取,如果重复调用,这将耗费大量时间。)

mrry's answer的启发,我认为也许我们可以获得tf.serialize_sparse生成的字节,这样以后我们就可以使用tf.deserialize_many_sparse恢复SparseTensor。但是tf.serialize_sparse不是用纯python实现的(它调用外部函数SerializeSparse),这意味着我们仍然需要使用sess.run()来获取字节数。如何获得纯python版本的SerializeSparse?谢谢。

EN

回答 2

Stack Overflow用户

发布于 2019-04-01 14:05:45

我遇到了在TFRecord文件中写入和读取稀疏张量的问题,我在网上找到的有关这方面的信息很少。

正如您所建议的,一种解决方案是将SparseTensor的索引、值和形状存储在3个单独的特征中,这将在here中讨论。这看起来既不高效也不优雅。

我有一个有效的例子(使用TensorFlow 2.0.0.alpha0)。也许不是最优雅的,但看起来很管用。

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

# Example data
st_1 = tf.SparseTensor(indices=[[0,0],[1,2]], values=[1,2], dense_shape=[3,4])
st_2 = tf.SparseTensor(indices=[[0,1],[2,0],[3,3]], values=[3,9,5], dense_shape=[4, 4])
sparse_tensors = [st_1, st_2]

# Serialize sparse tensors to an array of byte strings
serialized_sparse_tensors = [tf.io.serialize_sparse(st).numpy() for st in sparse_tensors]

# Write to TFRecord
with tf.io.TFRecordWriter('sparse_example.tfrecord') as tfwriter:
    for sst in serialized_sparse_tensors:
        sparse_example = tf.train.Example(features = 
                     tf.train.Features(feature=
                         {'sparse_tensor': 
                               tf.train.Feature(bytes_list=tf.train.BytesList(value=sst))
                         }))
        # Append each example into tfrecord
        tfwriter.write(sparse_example.SerializeToString())

def parse_fn(data_element):
    features = {'sparse_tensor': tf.io.FixedLenFeature([3], tf.string)}
    parsed = tf.io.parse_single_example(data_element, features=features)

    # tf.io.deserialize_many_sparse() requires the dimensions to be [N,3] so we add one dimension with expand_dims
    parsed['sparse_tensor'] = tf.expand_dims(parsed['sparse_tensor'], axis=0)
    # deserialize sparse tensor
    parsed['sparse_tensor'] = tf.io.deserialize_many_sparse(parsed['sparse_tensor'], dtype=tf.int32)
    # convert from sparse to dense
    parsed['sparse_tensor'] = tf.sparse.to_dense(parsed['sparse_tensor'])
    # remove extra dimenson [1, 3] -> [3]
    parsed['sparse_tensor'] = tf.squeeze(parsed['sparse_tensor'])
    return parsed

# Read from TFRecord
dataset = tf.data.TFRecordDataset(['sparse_example.tfrecord'])
dataset = dataset.map(parse_fn)
# Pad and batch dataset
dataset = dataset.padded_batch(2, padded_shapes={'sparse_tensor':[None,None]})

dataset.__iter__().get_next()

这将输出以下内容:

代码语言:javascript
复制
{'sparse_tensor': <tf.Tensor: id=295, shape=(2, 4, 4), dtype=int32, numpy=
     array([[[1, 0, 0, 0],
             [0, 0, 2, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 3, 0, 0],
             [0, 0, 0, 0],
             [9, 0, 0, 0],
             [0, 0, 0, 5]]], dtype=int32)>}
票数 4
EN

Stack Overflow用户

发布于 2017-10-22 08:31:26

由于Tensorflow目前在SparseTensor中只支持3种类型: Float、Int64和Bytes,而Tensorflow通常有不止1种类型,因此我的解决方案是用Pickle将tfrecord转换为Bytes。

下面是一个示例代码:

代码语言:javascript
复制
import tensorflow as tf
import pickle
import numpy as np
from scipy.sparse import csr_matrix

#---------------------------------#
# Write to a tfrecord file

# create two sparse matrices (simulate the values from .eval() of SparseTensor)
a = csr_matrix(np.arange(12).reshape((4,3)))
b = csr_matrix(np.random.rand(20).reshape((5,4)))

# convert them to pickle bytes
p_a = pickle.dumps(a)
p_b = pickle.dumps(b)

# put the bytes in context_list and feature_list
## save p_a in context_lists 
context_lists = tf.train.Features(feature={
    'context_a': tf.train.Feature(bytes_list=tf.train.BytesList(value=[p_a]))
    })
## save p_b as a one element sequence in feature_lists
p_b_features = [tf.train.Feature(bytes_list=tf.train.BytesList(value=[p_b]))]
feature_lists = tf.train.FeatureLists(feature_list={
    'features_b': tf.train.FeatureList(feature=p_b_features)
    })

# create the SequenceExample
SeqEx = tf.train.SequenceExample(
    context = context_lists,
    feature_lists = feature_lists
    )
SeqEx_serialized = SeqEx.SerializeToString()

# write to a tfrecord file
tf_FWN = 'test_pickle1.tfrecord'
tf_writer1 = tf.python_io.TFRecordWriter(tf_FWN)
tf_writer1.write(SeqEx_serialized)
tf_writer1.close()

#---------------------------------#
# Read from the tfrecord file

# first, define the parse function
def _parse_SE_test_pickle1(in_example_proto):
    context_features = {
        'context_a': tf.FixedLenFeature([], dtype=tf.string)
        }
    sequence_features = {
        'features_b': tf.FixedLenSequenceFeature([1], dtype=tf.string)
        }
    context, sequence = tf.parse_single_sequence_example(
      in_example_proto, 
      context_features=context_features,
      sequence_features=sequence_features
      )
    p_a_tf = context['context_a']
    p_b_tf = sequence['features_b']

    return tf.tuple([p_a_tf, p_b_tf])

# use the Dataset API to read
dataset = tf.data.TFRecordDataset(tf_FWN)
dataset = dataset.map(_parse_SE_test_pickle1)
dataset = dataset.batch(1)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(iterator.initializer)

[p_a_bat, p_b_bat] = sess.run(next_element)

# 1st index refers to batch, 2nd and 3rd indices refers to the sequence position (only for b)
rec_a = pickle.loads(p_a_bat[0])
rec_b = pickle.loads(p_b_bat[0][0][0])

# check whether the recovered the same as the original ones.
assert((rec_a - a).nnz == 0)
assert((rec_b - b).nnz == 0)

# print the contents
print("\n------ a -------")
print(a.todense())
print("\n------ rec_a -------")
print(rec_a.todense())
print("\n------ b -------")
print(b.todense())
print("\n------ rec_b -------")
print(rec_b.todense())

下面是我得到的信息:

代码语言:javascript
复制
------ a -------
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]

------ rec_a -------
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]

------ b -------
[[ 0.88612402  0.51438017  0.20077887  0.20969243]
 [ 0.41762425  0.47394715  0.35596051  0.96074408]
 [ 0.35491739  0.0761953   0.86217511  0.45796474]
 [ 0.81253723  0.57032448  0.94959189  0.10139615]
 [ 0.92177499  0.83519464  0.96679833  0.41397829]]

------ rec_b -------
[[ 0.88612402  0.51438017  0.20077887  0.20969243]
 [ 0.41762425  0.47394715  0.35596051  0.96074408]
 [ 0.35491739  0.0761953   0.86217511  0.45796474]
 [ 0.81253723  0.57032448  0.94959189  0.10139615]
 [ 0.92177499  0.83519464  0.96679833  0.41397829]]
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46823440

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档