
TensorFlow feed-forward network session does not stop

Stack Overflow user
Asked on 2017-02-06 03:38:35
1 answer · 433 views · 0 followers · 0 votes

I am trying to build a simple feed-forward neural network using TensorFlow and its TFRecord format. I have been using TensorFlow's tutorials and examples as a reference: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/how_tos/reading_data

Given a "food" float value, I want to predict the "happiness" float value it produces.

food_test.json is a JSON file containing "food" values and their associated "happiness" values. This is the format the data is stored in.

food_to_record.py is based on TensorFlow's convert_to_records.py. It reads food_test.json and converts it into a food_record.tfrecords file.

food_reader.py is based on TensorFlow's fully_connected_reader.py. It reads the food_record.tfrecords file and runs the data through the neural network.

I run the programs in the following order:

1. food_to_record.py
2. food_reader.py

When food_reader.py runs, it starts a TensorFlow session, but the session never terminates. Does anyone know what could be causing this?

food_test.json:

[
  {
    "food": 1.0,
    "happiness": 2.0
  },
  {
    "food": 1.4,
    "happiness": 5.4
  }
]

food_to_record.py:

#based off of tensorflow's convert_to_records.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import json

import tensorflow as tf


FLAGS = None


# Feature wrapper for integers.
def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Feature wrapper for floats.
def _float_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

# Feature wrapper for strings and other bytes.
def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def main(unused_argv):
    print("food_to_record:main")
    script_dir = os.path.dirname(__file__)
    file_path = os.path.join(script_dir, 'food_test.json')
    with open(file_path) as data_file:    
        data = json.load(data_file)

    print(data)
    num_examples = len(data)

    name = 'food_record'
    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing', filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        example = tf.train.Example(features=tf.train.Features(feature={
            'food': _float_feature(data[index]['food']),
            'happiness': _float_feature(data[index]['happiness'])
            }))
        writer.write(example.SerializeToString())
    writer.close()

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--directory',
      type=str,
      default='.',
      help='Directory to download data files and write the converted result'
  )
  parser.add_argument(
      '--validation_size',
      type=int,
      default=5000,
      help="""\
      Number of examples to separate from the training data for the validation
      set.\
      """
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
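
One quick way to sanity-check the conversion (a hypothetical snippet, not part of the original post) is to iterate over the written file and parse each record back:

import tensorflow as tf

# Read food_record.tfrecords back and print the stored feature values.
for serialized in tf.python_io.tf_record_iterator('food_record.tfrecords'):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    print(example.features.feature['food'].float_list.value[0],
          example.features.feature['happiness'].float_list.value[0])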

food_reader.py:

#based off of tensorflow's fully_connected_reader

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys
import time

import tensorflow as tf


# Basic model parameters as external flags.
FLAGS = None

# Constants used for dealing with the files
TRAIN_FILE = 'food_record.tfrecords'
# For simple testing purposes, use training file for validation 
VALIDATION_FILE = 'food_record.tfrecords'


def read_and_decode(filename_queue):
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      # Defaults are not specified since both keys are required.
      features={
          'food': tf.FixedLenFeature([], tf.float32),
          'happiness': tf.FixedLenFeature([], tf.float32)
      })


  food = tf.cast(features['food'], tf.float32)
  happiness = tf.cast(features['happiness'], tf.float32)


  food = tf.expand_dims(food, -1)

  print("food shape: ", tf.shape(food))
  print("happiness shape: ", tf.shape(happiness))

  return food, happiness


def inputs(train, batch_size, num_epochs):
  """Reads input data num_epochs times.

  Args:
    train: Selects between the training (True) and validation (False) data.
    batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
       train forever.

  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).
    Note that an tf.train.QueueRunner is added to the graph, which
    must be run using e.g. tf.train.start_queue_runners().
  """
  if not num_epochs: num_epochs = None
  filename = os.path.join(FLAGS.train_dir,
                          TRAIN_FILE if train else VALIDATION_FILE)

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs)

    # Even when reading in multiple threads, share the filename
    # queue.
    food, happiness = read_and_decode(filename_queue)

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in two threads to avoid being a bottleneck.
    foods, happinesses = tf.train.shuffle_batch(
        [food, happiness], batch_size=batch_size, num_threads=2,
        capacity=1000 + 3 * batch_size,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=1000)

    return foods, happinesses





def main(_):
  with tf.Graph().as_default():
    # Input images and labels.
    foods, happinesses = inputs(train=True, batch_size=FLAGS.batch_size,
                            num_epochs=FLAGS.num_epochs)

    HIDDEN_UNITS = 4 

    INPUTS = 1
    OUTPUTS = 1


    weights_1 = tf.Variable(tf.truncated_normal([INPUTS, HIDDEN_UNITS]))
    biases_1 = tf.Variable(tf.zeros([HIDDEN_UNITS]))

    layer_1_outputs = tf.nn.sigmoid(tf.matmul(foods, weights_1) + biases_1)

    weights_2 = tf.Variable(tf.truncated_normal([HIDDEN_UNITS, OUTPUTS]))
    biases_2 = tf.Variable(tf.zeros([OUTPUTS]))

    logits = tf.nn.sigmoid(tf.matmul(layer_1_outputs, weights_2) + biases_2)

    # Earlier loss experiments, left commented out:
    #loss = tf.reduce_mean(logits)
    #labels = tf.to_int64(happinesses)
    #cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    #    labels=labels, logits=logits, name='xentropy')
    #loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')

    # Note: logits is [batch_size, 1] and happinesses is [batch_size], so the
    # subtraction broadcasts; tf.sub was renamed tf.subtract in TF 1.0.
    loss = tf.reduce_sum(tf.sub(logits, happinesses))

    learning_rate = 0.01
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.minimize(loss)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    sess = tf.Session()
    sess.run(init_op)

    print('starting iteration', 0)
    _, loss_value = sess.run([train_op, loss])
    print(loss_value)

    sess.close()





if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='Initial learning rate.'
  )
  parser.add_argument(
      '--num_epochs',
      type=int,
      default=2,
      help='Number of epochs to run trainer.'
  )
  parser.add_argument(
      '--hidden1',
      type=int,
      default=128,
      help='Number of units in hidden layer 1.'
  )
  parser.add_argument(
      '--hidden2',
      type=int,
      default=32,
      help='Number of units in hidden layer 2.'
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Batch size.'
  )
  parser.add_argument(
      '--train_dir',
      type=str,
      default='.',
      help='Directory with the training data.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

1 answer

Stack Overflow user

Answered on 2017-02-06 05:25:43

You must call tf.train.start_queue_runners to populate the queue before calling run or eval to execute the read; otherwise the read will block while waiting for filenames in the queue. See the run_training method in the original example, or TensorFlow's documentation on how_tos/reading_data.
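
Concretely, here is a minimal sketch of how the session block in food_reader.py's main() could look. It follows the coordinator pattern from the original fully_connected_reader example; it is not code from the original answer:

sess = tf.Session()
sess.run(init_op)

# Start the threads that feed the filename queue and the shuffle queue;
# without this call, sess.run() blocks forever waiting for input.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    step = 0
    while not coord.should_stop():
        _, loss_value = sess.run([train_op, loss])
        print('step', step, 'loss', loss_value)
        step += 1
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # Stop the queue-runner threads and wait for them before closing.
    coord.request_stop()
    coord.join(threads)
    sess.close()

Note also that with only two records and num_epochs=2, the defaults batch_size=100 and min_after_dequeue=1000 can never be satisfied, so those values would need to be lowered (or the dataset enlarged) before a full batch can be produced.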

0 votes
Original question: https://stackoverflow.com/questions/42056396
