首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow Estimators :单独训练图像网格的正确方法

Tensorflow Estimators :单独训练图像网格的正确方法
EN

Stack Overflow用户
提问于 2018-08-02 09:21:06
回答 1查看 48关注 0票数 2

我正在尝试训练对象检测模型,如此paper中所述

有3个完全连接的层,有512,512,25个神经元。来自最后一个卷积层的16x55x55特征映射被馈送到完全连接的层中,以检索适当的类。在这个阶段,(16x1x1)描述的每个网格都被馈送到完全连接的层中,以将网格分类为属于25个类中的一个。结构可以在下面的图中看到。

fully connected layers

我正在尝试改编TF MNIST分类教程中的代码,我想知道是否可以像下面的代码片段那样对每个网格的损失进行求和,并使用它来训练模型权重。

代码语言:javascript
复制
flat_fmap = tf.reshape(last_conv_layer, [-1, 16*55*55])

total_loss = 0

for grid of flat_fmap:
  dense1 = tf.layers.dense(inputs=grid, units=512, activation=tf.nn.relu)

  dense2 = tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu)

  logits = tf.layers.dense(inputs=dense2, units=25)

  total_loss += tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(
  loss=total_loss,
  global_step=tf.train.get_global_step())


return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.TRAIN, loss=total_loss, train_op=train_op)

在上面的代码中,我认为在每次迭代中都会创建3个新的层。但是,我希望在对一个网格和另一个网格进行分类时保留权重。

EN

回答 1

Stack Overflow用户

发布于 2018-08-04 03:10:49

添加到total_loss应该是可以的。

tf.losses.sparse_softmax_cross_entropy也在一起增加亏损。

它使用logit计算sparse_softmax,然后使用math_ops.reduce_sum通过求和来减少结果数组。因此,您可以通过某种方式将它们添加到一起。

As you can see in its source

网络声明上的for循环似乎不太寻常,在运行时这样做并通过feed_dict传递每个网格可能更有意义。

代码语言:javascript
复制
dense1 = tf.layers.dense(inputs=X, units=512, activation=tf.nn.relu)
dense2 = tf.layers.dense(inputs=dense1, units=512, activation=tf.nn.relu)
logits = tf.layers.dense(inputs=dense2, units=25)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss)
total_loss = 0


with tf.session as sess:
   sess.run(init) 
   for grid in flat_fmap:
       _, l = sess.run([optimizer,loss], feed_dict{X: grid, labels=labels})

       total_loss += l
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51644257

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档