首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tfrecord VarLenFeature读取错误

tfrecord VarLenFeature读取错误
EN

Stack Overflow用户
提问于 2018-01-21 01:59:22
回答 1查看 896关注 0票数 0

我测试了如何将动态数量的变量写入tfrecord。但是VarLenFeature无法正确读取它们。

我写的代码是

代码语言:javascript
复制
def test_write():
  writer = tf.python_io.TFRecordWriter('test.tfrecord')

  for i in range(3):
    val_list = []
    for j in range(i+1):
      val_list.append(i+j)
    feature_dict = {
      'val': tf.train.Feature(int64_list=tf.train.Int64List(value=val_list)),
    }

    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    writer.write(example.SerializeToString())

  writer.close()

读取代码为

代码语言:javascript
复制
def parse_test(example):
  features = {
    'val': tf.VarLenFeature(dtype=tf.int64)
  }
  parsed_features = tf.parse_single_example(example, features)

  return parsed_features

def test_read():
  dataset = tf.data.TFRecordDataset(['test.tfrecord'])
  dataset = dataset.map(parse_test)
  dataset = dataset.batch(1)

  iterator = dataset.make_one_shot_iterator()
  feature_dict =  iterator.get_next()

  with tf.Session() as sess:
    for _ in range(3):
      curr_dict = sess.run(feature_dict)
      print([curr_dict['val']])

错误消息为:

代码语言:javascript
复制
TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor. Contents: SparseTensor(indices=Tensor("ParseSingleExample/Slice_Indices_val:0", shape=(?, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExample:1", shape=(?,), dtype=int64), dense_shape=Tensor("ParseSingleExample/Squeeze_Shape_val:0", shape=(1,), dtype=int64)). Consider casting elements to a supported type.

但是,如果我不使用dataset,而只使用tf.python_io.tf_record_iterator。程序运行正常,没有任何问题。此代码如下所示

代码语言:javascript
复制
def test_read2():
  with tf.Session() as sess:
    for serialized_example in tf.python_io.tf_record_iterator('test.tfrecord'):
      features = tf.parse_single_example(serialized_example,
        features={
          'val': tf.VarLenFeature(dtype=tf.int64),
        }
      )

      temp = features['val']

      values = sess.run(temp)
      print(values)

此代码已成功打印出来

代码语言:javascript
复制
SparseTensorValue(indices=array([[0]], dtype=int64), values=array([0], dtype=int64), dense_shape=array([1], dtype=int64))
SparseTensorValue(indices=array([[0],
       [1]], dtype=int64), values=array([1, 2], dtype=int64), dense_shape=array([2], dtype=int64))
SparseTensorValue(indices=array([[0],
       [1],
       [2]], dtype=int64), values=array([2, 3, 4], dtype=int64), dense_shape=array([3], dtype=int64))

但是,我仍然希望使用dataset结构来处理VarLenFeature。我的阅读代码有什么问题吗?谢谢。

EN

回答 1

Stack Overflow用户

发布于 2018-02-09 15:39:57

也许您需要在parse_test()函数中执行此操作

代码语言:javascript
复制
def parse_test(example):
  features = {
    'val': tf.VarLenFeature(dtype=tf.int64)
  }
  parsed_dict = tf.parse_example(example, features)
  parsed_features = {"val": tf.sparse_tensor_to_dense(parsed_dict ["val"], default_value=0)}

  return parsed_features
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48359385

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档