首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >通过加载一个intermediate_output_graphs(.pb)来继续训练(image_retraining/retraining.py)

通过加载一个intermediate_output_graphs(.pb)来继续训练(image_retraining/retraining.py)
EN

Stack Overflow用户
提问于 2017-10-29 00:52:26
回答 1查看 1.3K关注 0票数 0

我使用的是tensorflow存储库的image_retraining文件夹中提供的retrain脚本。

解析器参数/标志之一允许您每X步存储一次中间图

代码语言:javascript
复制
parser.add_argument(
      '--intermediate_output_graphs_dir',
      type=str,
      default='tf_files2/tmp/intermediate_graph/',
      help='Where to save the intermediate graphs.'

但是,这似乎将图形存储为具有.pb扩展名的冻结图形。关于如何正确加载.pb文件以继续训练的信息很少。我找到的大多数信息都使用.meta图形和.ckpts。.pb会被弃用吗?

如果是这样,我是否应该从开始重新训练模型,并使用tf.Saver来获得.meta和ckpt图作为中间检查点?

昨天,我在训练一个模型,由于某种原因,训练冻结了,所以我想加载中间图,然后继续训练。

我正在使用inception模型进行再培训:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py

如果有人能告诉我或向我展示如何正确地加载.pb中间图(一步一步地)并从我停止的地方继续--我将不胜感激。

谢谢。

编辑:

@明兴

所以我假设我应该让retrain.py首先基于默认的初始模型(下面的函数)创建默认的图,然后用加载的图覆盖它?

代码语言:javascript
复制
def create_model_graph(model_info):
  """"Creates a graph from saved GraphDef file and returns a Graph object.

  Args:
    model_info: Dictionary containing information about the model architecture.

  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating.
  """
  with tf.Graph().as_default() as graph:
    model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
    with gfile.FastGFile(model_path, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      bottleneck_tensor, resized_input_tensor = (tf.import_graph_def(
          graph_def,
          name='',
          return_elements=[
              model_info['bottleneck_tensor_name'],
              model_info['resized_input_tensor_name'],
          ]))
  return graph, bottleneck_tensor, resized_input_tensor

EDIT_2:

我得到的一个错误是:

代码语言:javascript
复制
ValueError: Tensor("second_to_final_fC_layer_ops/weights/final_weights_1:0", shape=(2048, 102
4), dtype=float32_ref) must be from the same graph as Tensor("BottleneckInputPlaceholder:0",
shape=(?, 2048), dtype=float32).

我在第一个FC层之后添加了一个额外的FC层。So 2048 -> 1024 ->训练前的类数。

当训练模型时,我没有问题,但现在加载图形时,我似乎遇到了上面的错误。

下面是添加的层的外观:

代码语言:javascript
复制
layer_name = 'second_to_final_fC_layer_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [bottleneck_tensor_size, 1024], stddev=0.001)

      layer_weights = tf.Variable(initial_value, name='weights')

      variable_summaries(layer_weights)
   with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([1024]), name='biases')
      variable_summaries(layer_biases)
   with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)
    with tf.name_scope('Relu_activation'):
      relu_activated =tf.nn.relu(logits, name= 'Relu')
      tf.summary.histogram('final_relu_activation', relu_activated)

然后是最后一层(这是最初的最后一层,但现在的输入是来自最后一层的输出,而不是瓶颈张量):

代码语言:javascript
复制
layer_name = 'final_training_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [1024, class_count], stddev=0.001)

      layer_weights = tf.Variable(initial_value, name='final_weights')

      variable_summaries(layer_weights)
    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)
    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(relu_activated, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
  tf.summary.histogram('activations', final_tensor)

编辑:仍然不知道如何加载权重--加载图形结构似乎很容易,但我不知道如何加载再次使用迁移学习训练的Inception的权重和输入。

一个清晰的例子使用来自image_retraining/retraining.py的权重和变量会非常有帮助。谢谢。

EN

回答 1

Stack Overflow用户

发布于 2017-10-29 01:24:31

您可以使用tf.import_graph_def导入冻结的.pb文件:

代码语言:javascript
复制
# Read the .pb file into graph_def.
with tf.gfile.GFile(FLAGS.graph, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Restore the graph. 
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

# After this, graph is the what you need.

虽然直接使用冻结的.pb文件没有什么问题,但我仍然想指出,推荐的方法是遵循标准的保存/恢复(official doc)。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46992208

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档