首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >对于可变长度字符串,我应该使用什么来代替tf.decode_raw?

对于可变长度字符串,我应该使用什么来代替tf.decode_raw?
EN

Stack Overflow用户
提问于 2018-01-23 15:38:24
回答 1查看 4.1K关注 0票数 3

我有一个功能列,它只是一个字符串:

代码语言:javascript
复制
tf.FixedLenFeature((), tf.string)

我的图使用tf.decode_raw将张量转换为二进制

代码语言:javascript
复制
tf.decode_raw(features['text'], tf.uint8)

这在batch_size = 1时有效,但当字符串长度不同时,batch_size >1则无效。decode_raw抛出DecodeRaw requires input strings to all be the same size

除了tf.decode_raw之外,是否还有返回填充张量和字符串长度的方法?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-01-24 02:06:49

我会用tf.data.Dataset。启用急切执行后:

代码语言:javascript
复制
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()

def _decode_and_length_map(encoded_string):
  decoded = tf.decode_raw(encoded_string, out_type=tf.uint8)
  return decoded, tf.shape(decoded)[0]

inputs = tf.constant(["aaa", "bbbbbbbb", "abcde"], dtype=tf.string)
dataset = (tf.data.Dataset.from_tensor_slices(inputs)
           .map(_decode_and_length_map)
           .padded_batch(batch_size=2, padded_shapes=([None], [])))
iterator = tfe.Iterator(dataset)
print(iterator.next())
print(iterator.next())

印刷品(免责声明:手动重新格式化)

代码语言:javascript
复制
(<tf.Tensor: id=24, shape=(2, 8), dtype=uint8,
     numpy=array([[97, 97, 97,  0,  0,  0,  0,  0],
                  [98, 98, 98, 98, 98, 98, 98, 98]], dtype=uint8)>,
 <tf.Tensor: id=25, shape=(2,), dtype=int32, numpy=array([3, 8], dtype=int32)>)

(<tf.Tensor: id=28, shape=(1, 5), dtype=uint8, 
     numpy=array([[ 97,  98,  99, 100, 101]], dtype=uint8)>,
 <tf.Tensor: id=29, shape=(1,), dtype=int32, numpy=array([5], dtype=int32)>)

当然,您可以混合和匹配数据源,添加随机化,更改填充字符等。

也适用于图形构建:

代码语言:javascript
复制
import tensorflow as tf

def _decode_and_length_map(encoded_string):
  decoded = tf.decode_raw(encoded_string, out_type=tf.uint8)
  return decoded, tf.shape(decoded)[0]

inputs = tf.constant(["aaa", "bbbbbbbb", "abcde"], dtype=tf.string)
dataset = (tf.data.Dataset.from_tensor_slices(inputs)
           .map(_decode_and_length_map)
           .padded_batch(batch_size=2, padded_shapes=([None], [])))
batch_op = dataset.make_one_shot_iterator().get_next()
with tf.Session() as session:
  print(session.run(batch_op))
  print(session.run(batch_op))
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48396442

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档