首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow输入管道:读取示例不止一次

tensorflow输入管道:读取示例不止一次
EN

Stack Overflow用户
提问于 2016-12-16 16:40:15
回答 1查看 219关注 0票数 0

我正在尝试实现一个从TFRecords二进制文件读取的模型的输入管道;每个二进制文件包含一个示例(图像、标签、其他我需要的东西)。

我有一个带有文件路径列表的文本文件,然后:

  1. 我将文本文件读取为一个列表,并将其提供给string_input_producer()以生成队列;
  2. 我将队列提供给读取序列化示例的TFRecordReader,并解码二进制数据。
  3. 我使用shuffle_batch()将示例分成几个批次
  4. 我用批次来评估我的模型

问题在于,同一个示例可以多次读取,有些示例可能根本无法访问;我将步骤数设置为图像总数除以批处理大小;因此,我希望在最后一步结束时访问所有输入示例,但事实并非如此;相反,有些示例被多次访问,有些从未(随机)访问过;这使得我的测试评估完全无法实现。

如果有人知道我做错了什么,请告诉我。

下面是我的模型测试代码的简化版本,谢谢!

代码语言:javascript
复制
def my_input(file_list, batch_size)

    filename = []
    f = open(file_list, 'r')
    for line in f:
        filename.append(params.TEST_RECORDS_DATA_DIR + line[:-1])

    filename_queue = tf.train.string_input_producer(filename)

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label_raw': tf.FixedLenFeature([], tf.string),
            'name': tf.FixedLenFeature([], tf.string)
            })

    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image.set_shape(params.IMAGE_HEIGHT*params.IMAGE_WIDTH*3)
    image = tf.reshape(image, (params.IMAGE_HEIGHT,params.IMAGE_WIDTH,3))
    image = tf.cast(image, tf.float32)/255.0
    image = preprocess(image)

    label = tf.decode_raw(features['label_raw'], tf.uint8)
    label.set_shape(params.NUM_CLASSES)

    name = features['name']

    images, labels, image_names = tf.train.batch([image, label, name],
            batch_size=batch_size, num_threads=2,
            capacity=1000 + 3 * batch_size, min_after_dequeue=1000)

    return images, labels, image_names


def main()

    with tf.Graph().as_default():

        # call input operations
        images, labels, image_names = my_input(file_list=params.TEST_FILE_LIST, batch_size=params.BATCH_SIZE)

        # load a trained model and make predictions     
        prediction = infer(images, labels, image_names)

        with tf.Session() as sess:

            for step in range(params.N_STEPS):
                prediction_values = sess.run([prediction])
                # process output

    return
EN

回答 1

Stack Overflow用户

发布于 2016-12-16 16:47:44

我的猜测是,tf.train.string_input_producer(filename)被设置为无限期地生成文件名,如果在多个(2)线程中对示例进行批处理,则可能是一个线程第二次开始处理该文件,而另一个线程尚未完成第一轮。要正确读取每个示例,请使用:

代码语言:javascript
复制
tf.train.string_input_producer(filename, num_epochs=1)

并在会话开始时初始化局部变量:

代码语言:javascript
复制
sess.run(tf.initialize_local_variables())
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41188816

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档