首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow:批量培训在sess.run中永远卡住

Tensorflow:批量培训在sess.run中永远卡住
EN

Stack Overflow用户
提问于 2017-06-07 08:50:18
回答 1查看 2K关注 0票数 2

我正试着一批一批地训练我的模型,因为我找不到任何例子来说明如何正确地完成它。这是我所能做到的,我的任务是找到如何在Tensorflow中一批一批地训练一个模型。

代码语言:javascript
复制
queue=tf.FIFOQueue(capacity=50,dtypes=[tf.float32,tf.float32],shapes=[[10],[2]])
enqueue_op=queue.enqueue_many([X,Y])
dequeue_op=queue.dequeue()

qr=tf.train.QueueRunner(queue,[enqueue_op]*2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    X_train_batch,y_train_batch=tf.train.batch(dequeue_op,batch_size=2)
    coord=tf.train.Coordinator()
    enqueue_threads=qr.create_threads(sess,coord,start=True)
    sess.run(tf.local_variables_initializer())
    for epoch in range(100):
        print("inside loop1")
        for iter in range(5):
            print("inside loop2")
            if coord.should_stop():
                break
            batch_x,batch_y=sess.run([X_train_batch,y_train_batch])
            print("after sess.run")
            print(batch_x.shape)
            _=sess.run(optimizer,feed_dict={x_place:batch_x,y_place:batch_y})
        coord.request_stop()
        coord.join(enqueue_threads)

输出,

代码语言:javascript
复制
inside loop1
inside loop2

正如您所看到的,当它运行batch_x,batch_y=sess.run([X_train_batch,y_train_batch])行时,它将永远停留在这个位置。我不知道如何解决这个问题,或者这是一批一批训练模型的正确方法吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-06-07 17:11:36

经过几个小时的搜索,我自己找到了解决方案。所以,我现在回答我自己的问题。队列由后台线程填充,如果不调用此方法,则在调用tf.train.start_queue_runners()时创建这些线程,后台线程将不会启动,队列将保持空,培训操作将无限期地阻塞等待输入。

FIX:在训练循环之前调用tf.train.start_queue_runners(sess)。就像我在下面做的那样:

代码语言:javascript
复制
queue=tf.FIFOQueue(capacity=50,dtypes=[tf.float32,tf.float32],shapes=[[10],[2]])
enqueue_op=queue.enqueue_many([X,Y])
dequeue_op=queue.dequeue()

qr=tf.train.QueueRunner(queue,[enqueue_op]*2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    X_train_batch,y_train_batch=tf.train.batch(dequeue_op,batch_size=2)
    coord=tf.train.Coordinator()
    enqueue_threads=qr.create_threads(sess,coord,start=True)
    tf.train.start_queue_runners(sess)
    for epoch in range(100):
        print("inside loop1")
        for iter in range(5):
            print("inside loop2")
            if coord.should_stop():
                break
            batch_x,batch_y=sess.run([X_train_batch,y_train_batch])
            print("after sess.run")
            print(batch_x.shape)
            _=sess.run(optimizer,feed_dict={x_place:batch_x,y_place:batch_y})
        coord.request_stop()
        coord.join(enqueue_threads)
票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/44407873

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档