首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow Python读取2个文件

Tensorflow Python读取2个文件
EN

Stack Overflow用户
提问于 2018-03-28 02:18:13
回答 1查看 2K关注 0票数 2

我正在运行以下(缩短)代码:

代码语言:javascript
复制
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
   while not coord.should_stop():      

      # Run some code.... (Reading some data from file 1)

      coord_dev = tf.train.Coordinator()
      threads_dev = tf.train.start_queue_runners(sess=sess, coord=coord_dev)

      try:
        while not coord_dev.should_stop():

           # Run some other code.... (Reading data from file 2)

      except tf.errors.OutOfRangeError:
        print('Reached end of file 2')
      finally:
        coord_dev.request_stop()
        coord_dev.join(threads_dev) 

except tf.errors.OutOfRangeError:
   print('Reached end of file 1')
finally:
   coord.request_stop()
   coord.join(threads)

上面应该发生的是:

  • 文件1是一个csv文件,包括我的神经网络的训练数据。
  • 文件2包括开发集数据。

在训练期间迭代文件1时,我偶尔也想要计算开发集数据(来自文件2)的准确性。但是当内部循环完成文件2的读取时,它显然会触发异常

"tf.errors.OutOfRangeError“

这将导致我的代码也离开外部循环。内环的异常也简单地作为外循环的例外来处理。但是在阅读完文件2之后,我希望我的代码在外部循环中继续在File 1上进行培训。

(为了简化代码的可实现性,我删除了一些细节,如num_epochs、训练等)。

有谁对如何解决这个问题有什么建议吗?我在这方面有点新。

提前谢谢你!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-03-29 04:46:47

解决了。

显然,使用queue_runners不是正确的方法。Tensorflow文档表明应该使用dataset api,这需要花费时间来理解。下面的代码完成了我之前试图做的工作。分享这里,以防其他人可能也需要它。

我在www.github.com/loheden/tf_样例/dataset api下添加了一些额外的培训代码。我挣扎了一下才找到完整的例子。

代码语言:javascript
复制
# READING DATA FROM train and validation (dev set) CSV FILES by using INITIALIZABLE ITERATORS

# All csv files have same # columns. First column is assumed to be train example ID, the next 5 columns are feature
# columns, and the last column is the label column

# ASSUMPTIONS: (Otherwise, decode_csv function needs update)
# 1) The first column is NOT a feature. (It is most probably a training example ID or similar)
# 2) The last column is always the label. And there is ONLY 1 column that represents the label.
#    If more than 1 column represents the label, see the next example down below

feature_names = ['f1','f2','f3','f4','f5']
record_defaults = [[""], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]


def decode_csv(line):
   parsed_line = tf.decode_csv(line, record_defaults)
   label =  parsed_line[-1]      # label is the last element of the list
   del parsed_line[-1]           # delete the last element from the list
   del parsed_line[0]            # even delete the first element bcz it is assumed NOT to be a feature
   features = tf.stack(parsed_line)  # Stack features so that you can later vectorize forward prop., etc.
   #label = tf.stack(label)          #NOT needed. Only if more than 1 column makes the label...
   batch_to_return = features, label
   return batch_to_return

filenames = tf.placeholder(tf.string, shape=[None])
dataset5 = tf.data.Dataset.from_tensor_slices(filenames)
dataset5 = dataset5.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1).map(decode_csv))
dataset5 = dataset5.shuffle(buffer_size=1000)
dataset5 = dataset5.batch(7)
iterator5 = dataset5.make_initializable_iterator()
next_element5 = iterator5.get_next()

# Initialize `iterator` with training data.
training_filenames = ["train_data1.csv", 
                      "train_data2.csv"]

# Initialize `iterator` with validation data.
validation_filenames = ["dev_data1.csv"]

with tf.Session() as sess:
    # Train 2 epochs. Then validate train set. Then validate dev set.
    for _ in range(2):     
        sess.run(iterator5.initializer, feed_dict={filenames: training_filenames})
        while True:
            try:
              features, labels = sess.run(next_element5)
              # Train...
              print("(train) features: ")
              print(features)
              print("(train) labels: ")
              print(labels)  
            except tf.errors.OutOfRangeError:
              print("Out of range error triggered (looped through training set 1 time)")
              break

    # Validate (cost, accuracy) on train set
    print("\nDone with the first iterator\n")

    sess.run(iterator5.initializer, feed_dict={filenames: validation_filenames})
    while True:
        try:
          features, labels = sess.run(next_element5)
          # Validate (cost, accuracy) on dev set
          print("(dev) features: ")
          print(features)
          print("(dev) labels: ")
          print(labels)
        except tf.errors.OutOfRangeError:
          print("Out of range error triggered (looped through dev set 1 time only)")
          break  
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49525056

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档