首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow CIFAR10多GPU -为什么合并损失?

Tensorflow CIFAR10多GPU -为什么合并损失?
EN

Stack Overflow用户
提问于 2017-02-13 02:48:31
回答 2查看 845关注 0票数 4

在经过多个GPU训练的TensorFlow CIFAR10实例中,每个“塔”的损耗似乎是合并的,而梯度是从这个合并损耗中计算出来的。

代码语言:javascript
复制
    # Build the portion of the Graph calculating the losses. Note that we will
    # assemble the total_loss using a custom function below.
    _ = cifar10.loss(logits, labels)

    # Assemble all of the losses for the current tower only.
    losses = tf.get_collection('losses', scope)

    # Calculate the total loss for the current tower.
    total_loss = tf.add_n(losses, name='total_loss')

    # Attach a scalar summary to all individual losses and the total loss; do the
    # same for the averaged version of the losses.
    for l in losses + [total_loss]:
        # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
        # session. This helps the clarity of presentation on tensorboard.
        loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name)
        tf.contrib.deprecated.scalar_summary(loss_name, l)

    return total_loss

我是TensorFlow新手,但据我所知,每次调用cifar10.loss时,都会运行tf.add_to_collection('losses', cross_entropy_mean),并将当前批处理的损失存储在集合中。

然后调用losses = tf.get_collection('losses', scope),并从集合中检索所有损失。然后,tf.add_n op将从这个“塔”中检索到的所有损失张量加到一起。

我希望损失只是目前的培训步骤/批,而不是所有的批次。

我是不是误会了什么?还是有理由把损失合并在一起?

EN

回答 2

Stack Overflow用户

发布于 2017-05-03 19:00:07

如果启用了重量衰减,它还会将其添加到损失收集中。因此,对于每个塔(范围),它将add_n所有的损失: cross_entropy_mean和weight_decay。

然后计算每个塔(范围)的梯度。最后,不同塔(范围)的所有梯度将在average_gradients中平均。

票数 1
EN

Stack Overflow用户

发布于 2017-08-16 17:03:35

为什么合并损失

您所提到的示例是多个gpus上的数据并行性示例。数据并行有助于使用更大的batch_size来训练更深层次的模型。在此设置中,您需要组合来自gpus的损失,因为每个gpus都持有输入批的一部分(丢失和与输入部分对应的渐变)。在tensorflow数据并行示例的以下示例中提供了一个示例。

注意:在模型并行的情况下,主处理器收集运行在不同gpus和中间输出上的模型的不同子图。

示例

如果您想使用256的批处理大小来训练模型,那么对于可能不适合于单个gpu (例如8GB内存)的更深层次的模型(例如,resnet/inception),可以将批处理分成两批大小为128的批,并在不同的gpu上使用这两个批处理对模型进行前向传递,并计算损失和梯度。计算的(损失))从每个gpus中收集并进行平均处理。利用平均梯度对模型参数进行更新。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/42195922

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档