
TensorFlow Dataset API doesn't work stably when batch size is greater than 1

Stack Overflow user
Asked on 2017-09-30 20:14:21
1 answer · 1.3K views · 0 followers · score 1

I put a set of fixed-length and variable-length features into one tf.train.SequenceExample:

context_features
    length,            scalar,                    tf.int64
    site_code_raw,     scalar,                    tf.string
    Date_Local_raw,    scalar,                    tf.string
    Time_Local_raw,    scalar,                    tf.string
Sequence_features
    Orig_RefPts,       [#batch, #RefPoints, 4]    tf.float32
    tgt_location,      [#batch, 3]                tf.float32
    tgt_val            [#batch, 1]                tf.float32

The value of #RefPoints varies across sequence examples. I store its value in the length feature of context_features. The other features have fixed sizes.

Here is the code I use to read and parse the data:

def read_batch_DatasetAPI(
    filenames, 
    batch_size = 20, 
    num_epochs = None, 
    buffer_size = 5000):

    dataset = tf.contrib.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_SeqExample1)
    if (buffer_size is not None):
        dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # next_element contains a tuple of following tensors
    # length,            scalar,                    tf.int64
    # site_code_raw,     scalar,                    tf.string
    # Date_Local_raw,    scalar,                    tf.string
    # Time_Local_raw,    scalar,                    tf.string
    # Orig_RefPts,       [#batch, #RefPoints, 4]    tf.float32
    # tgt_location,      [#batch, 3]                tf.float32
    # tgt_val            [#batch, 1]                tf.float32

    return iterator, next_element

def _parse_SeqExample1(in_SeqEx_proto):

    # Define how to parse the example
    context_features = {
        'length': tf.FixedLenFeature([], dtype=tf.int64),
        'site_code': tf.FixedLenFeature([], dtype=tf.string),
        'Date_Local': tf.FixedLenFeature([], dtype=tf.string),
        'Time_Local': tf.FixedLenFeature([], dtype=tf.string) #,
    }

    sequence_features = {
        "input_features": tf.VarLenFeature(dtype=tf.float32),
        'tgt_location_features': tf.FixedLenSequenceFeature([3], dtype=tf.float32),
        'tgt_val_feature': tf.FixedLenSequenceFeature([1], dtype=tf.float32)   
    }                                                        

    context, sequence = tf.parse_single_sequence_example(
      in_SeqEx_proto, 
      context_features=context_features,
      sequence_features=sequence_features)

    # distribute the fetched context and sequence features into tensors
    length = context['length']
    site_code_raw = context['site_code']
    Date_Local_raw = context['Date_Local']
    Time_Local_raw = context['Time_Local']

    # reshape the tensors according to the dimension definition above
    Orig_RefPts = sequence['input_features'].values
    Orig_RefPts = tf.reshape(Orig_RefPts, [-1, 4])
    tgt_location = sequence['tgt_location_features']
    tgt_location = tf.reshape(tgt_location, [-1])
    tgt_val = sequence['tgt_val_feature']
    tgt_val = tf.reshape(tgt_val, [-1])

    return length, site_code_raw, Date_Local_raw, Time_Local_raw, \
        Orig_RefPts, tgt_location, tgt_val

When I call read_batch_DatasetAPI with batch_size = 1 (see the code below), it can go through all (~200,000) sequence examples one by one without any problem. However, if I change batch_size to any number greater than 1, it simply stops after fetching 320 to 700 sequence examples, with no error message. I don't know how to solve this problem. Any help is much appreciated!

# the iterator to get the next_element for one sample (in sequence)
iterator, next_element = read_batch_DatasetAPI(
    in_tf_FWN,  # the file name of the tfrecords containing ~200,000 Sequence Examples
    batch_size = 1,  # works when it is 1, doesn't work if > 1
    num_epochs = 1,
    buffer_size = None)

# tf session initialization
sess = tf.Session()
sess.run(tf.global_variables_initializer())

## reset the iterator to the beginning
sess.run(iterator.initializer)

try:
    step = 0

    while (True):

        # get the next batch data
        length, site_code_raw, Date_Local_raw, Time_Local_raw, \
        Orig_RefPts, tgt_location, tgt_vale = sess.run(next_element)

        step = step + 1

except tf.errors.OutOfRangeError:
    # Task Done (all SeqExs have been visited)
    print("closing ", in_tf_FWN)

except ValueError as err:
    print("Error: {}".format(err.args))

except Exception as err:
    print("Error: {}".format(err.args))

1 Answer

Stack Overflow user

Answered on 2017-10-02 18:04:14

I saw some posts (example 1, example 2) mentioning the new dataset function from_generator. I don't yet know how to use it to solve my problem. If anyone knows how, please post it as a new answer. Thanks!

Here is my current diagnosis of, and workaround for, the problem:

The variation in sequence length (#RefPoints) caused the problem. dataset.batch(batch_size) only works when all the examples in a batch happen to have the same #RefPoints. That is why it always works when batch_size is 1, but fails at some point when batch_size is greater than 1.
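This failure mode can be illustrated without TensorFlow: batching is essentially stacking, and stacking requires identical shapes. Below is a minimal pure-Python sketch; the stack helper is hypothetical, imitating the shape check that dataset.batch performs internally.

```python
# Each parsed example carries Orig_RefPts of shape [#RefPoints, 4],
# and #RefPoints varies from example to example.
seq_a = [[1.0, 2.0, 3.0, 4.0] for _ in range(3)]  # 3 reference points
seq_b = [[5.0, 6.0, 7.0, 8.0] for _ in range(5)]  # 5 reference points

def stack(examples):
    """Mimic dataset.batch: every element must have the same shape."""
    lengths = {len(seq) for seq in examples}
    if len(lengths) > 1:
        raise ValueError(
            "cannot batch tensors with different shapes: %s" % sorted(lengths))
    return examples

stack([seq_a, seq_a])        # same #RefPoints: batching succeeds
try:
    stack([seq_a, seq_b])    # different #RefPoints: batching fails
except ValueError as err:
    print("batch failed:", err)
```

This matches the observed behavior: a batch succeeds or fails depending on which sequence lengths happen to be drawn together.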

I found that the Dataset API has a padded_batch function, which pads variable-length tensors up to the maximum length in the batch. I made the following modifications to solve my problem for now (I guess from_generator would be the real solution):

  1. In the _parse_SeqExample1 function, the return statement was changed to

         return tf.tuple([length, site_code_raw, Date_Local_raw, Time_Local_raw, \
             Orig_RefPts, tgt_location, tgt_val])

  2. In the read_batch_DatasetAPI function, the statement

         dataset = dataset.batch(batch_size)

     was changed to

         dataset = dataset.padded_batch(batch_size, padded_shapes=(
             tf.TensorShape([]),
             tf.TensorShape([]),
             tf.TensorShape([]),
             tf.TensorShape([]),
             tf.TensorShape([None, 4]),
             tf.TensorShape([3]),
             tf.TensorShape([1])))

  3. Finally, the fetch statement was changed from

         length, site_code_raw, Date_Local_raw, Time_Local_raw, \
             Orig_RefPts, tgt_location, tgt_vale = sess.run(next_element)

     to

         [length, site_code_raw, Date_Local_raw, Time_Local_raw, \
             Orig_RefPts_val, tgt_location, tgt_vale] = sess.run(next_element)
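The effect of padded_batch in step 2 can be sketched in plain Python: each variable-length [#RefPoints, 4] sequence is padded with zero rows up to the longest sequence in the batch, so the batch becomes a regular [batch, max_len, 4] block. The pad_batch helper below is a hypothetical stand-in for the real TensorFlow op, shown only to illustrate the semantics.

```python
def pad_batch(sequences, pad_value=0.0):
    """Pad each [#RefPoints, 4] sequence to the batch maximum,
    mimicking dataset.padded_batch with padded_shapes=[None, 4]."""
    max_len = max(len(seq) for seq in sequences)
    width = len(sequences[0][0])
    padded = []
    for seq in sequences:
        pad_rows = [[pad_value] * width for _ in range(max_len - len(seq))]
        padded.append([row[:] for row in seq] + pad_rows)
    return padded  # regular shape: [batch, max_len, width]

seq_a = [[1.0, 2.0, 3.0, 4.0] for _ in range(3)]  # 3 reference points
seq_b = [[5.0, 6.0, 7.0, 8.0] for _ in range(5)]  # 5 reference points
batch = pad_batch([seq_a, seq_b])
# Every sequence now has length 5; seq_a got two all-zero rows appended.
```

Because every batch now has a uniform shape, the iterator no longer depends on which sequence lengths happen to be drawn together. Downstream code can use the length feature from context_features to mask out the padded rows.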

Note: I don't know why, but this only works with the current tf-nightly-gpu build, not with tensorflow-gpu v1.3.

Score: 1
Original content provided by Stack Overflow.
Original link: https://stackoverflow.com/questions/46506658
