首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无整形层的再训练初始v3模型

无整形层的再训练初始v3模型
EN

Stack Overflow用户
提问于 2017-08-21 17:12:33
回答 2查看 477关注 0票数 1

我已经为自定义数据集重新训练了inception v3模型。但是在重新训练之后,当我查看TenosorGraph时,我发现添加了一个名为reshape的层,后面跟着一个完全连接的层。我不得不在使用snapdragonneural神经处理引擎(SNPE)的嵌入式设备上运行该模型,但它目前还不支持在DSP上运行重形层。

有没有可能在不增加重塑层的情况下重新训练初始v3。下面是重新训练代码,其中添加了重塑图层。

代码语言:javascript
复制
enter code here
              def create_model_info(architecture):
  """Given the name of a model architecture, returns information about it.

  There are different base image recognition pretrained models that can be
  retrained using transfer learning, and this function translates from the name
  of a model to the attributes that are needed to download and train with it.

  Args:
    architecture: Name of a model architecture.

  Returns:
    Dictionary of information about the model, or None if the name isn't
    recognized

  Raises:
    ValueError: If architecture name is unknown.
  """
  architecture = architecture.lower()
  if architecture == 'inception_v3':
    # pylint: disable=line-too-long
    data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
    # pylint: enable=line-too-long
    bottleneck_tensor_name = 'pool_3/_reshape:0'
    bottleneck_tensor_size = 2048
    input_width = 299
    input_height = 299
    input_depth = 3
    resized_input_tensor_name = 'Mul:0'
    model_file_name = 'classify_image_graph_def.pb'
    input_mean = 128
    input_std = 128
      elif architecture.startswith('mobilenet_'):
        parts = architecture.split('_')
        if len(parts) != 3 and len(parts) != 4:
          tf.logging.error("Couldn't understand architecture name '%s'",
                           architecture)
          return None
        version_string = parts[1]
        if (version_string != '1.0' and version_string != '0.75' and
            version_string != '0.50' and version_string != '0.25'):
          tf.logging.error(
              """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25',
      but found '%s' for architecture '%s'""",
              version_string, architecture)
          return None
        size_string = parts[2]
        if (size_string != '224' and size_string != '192' and
            size_string != '160' and size_string != '128'):
          tf.logging.error(
              """The Mobilenet input size should be '224', '192', '160', or '128',
     but found '%s' for architecture '%s'""",
              size_string, architecture)
          return None
        if len(parts) == 3:
          is_quantized = False
        else:
          if parts[3] != 'quantized':
            tf.logging.error(
                "Couldn't understand architecture suffix '%s' for '%s'", parts[3],
                architecture)
            return None
          is_quantized = True
        data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
        data_url += version_string + '_' + size_string + '_frozen.tgz'
        bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
        bottleneck_tensor_size = 1001
        input_width = int(size_string)
        input_height = int(size_string)
        input_depth = 3
        resized_input_tensor_name = 'input:0'
        if is_quantized:
          model_base_name = 'quantized_graph.pb'
        else:
          model_base_name = 'frozen_graph.pb'
        model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
        model_file_name = os.path.join(model_dir_name, model_base_name)
        input_mean = 127.5
        input_std = 127.5
      else:
        tf.logging.error("Couldn't understand architecture name '%s'", architecture)
        raise ValueError('Unknown architecture', architecture)

      return {
          'data_url': data_url,
          'bottleneck_tensor_name': bottleneck_tensor_name,
          'bottleneck_tensor_size': bottleneck_tensor_size,
          'input_width': input_width,
          'input_height': input_height,
          'input_depth': input_depth,
          'resized_input_tensor_name': resized_input_tensor_name,
          'model_file_name': model_file_name,
          'input_mean': input_mean,
          'input_std': input_std,
      }

此处提供了compelete代码:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py

EN

回答 2

Stack Overflow用户

发布于 2017-12-06 16:53:24

从SNPE SDK v1.8.0开始支持TensorFlow的reshape层。

票数 1
EN

Stack Overflow用户

发布于 2017-09-20 16:54:18

他们不是在添加重塑图层,而是从训练好的模型中选择重塑图层。然后,他们将在重塑图层的输出上添加自己的图层。

如果您想选择一个更高的层,用您想要的层的名称替换'pool_3/_reshape:0‘。您应该能够从模型代码推导出名称:https://github.com/tensorflow/models/blob/master/slim/nets/inception_v3.py

或者更简单,打印graph_def中所有节点的名称并选择所需的节点:

代码语言:javascript
复制
    for node in graph_def.node:
        print(node.name)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45793329

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档