首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在火炬中进行批量学习?

如何在火炬中进行批量学习?
EN

Stack Overflow用户
提问于 2019-06-19 01:30:33
回答 1查看 3.4K关注 0票数 7

当您查看pytorch代码中是如何构建网络体系结构时,我们需要扩展torch.nn.Module__init__,我们定义了网络模块,pytorch将跟踪这些模块的参数梯度。然后,在forward函数中,我们定义了如何为我们的网络进行前向传递。

我不明白的是,批量学习是如何发生的。在包括forward函数在内的上述任何定义中,我们都不关心网络输入的批处理的维度。要执行批学习,唯一需要设置的是在输入中添加一个额外的维度,该维度对应于批处理大小,但是如果我们使用批处理学习,则网络定义中的任何内容都不会改变。至少,这是我在代码这里中看到的东西。

那么,如果到目前为止我解释的所有内容都是正确的(如果您让我知道我是否误解了什么),那么如果我们网络类(继承torch.nn.Module的类)的定义中没有任何关于批处理大小的声明,那么如何执行批处理学习?具体来说,我很想知道如何在pytorch中实现批处理梯度下降算法,当我们设置批处理维数时。

EN

回答 1

Stack Overflow用户

发布于 2019-06-23 20:16:24

检查一下这个:

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()         

    def forward(self, x):
        print("Hi ma")        
        print(x)
        x = F.relu(x)
        return x

n = Net()
r = n(torch.tensor(-1))
print(r)
r = n.forward(torch.tensor(1)) #not planned to call directly
print(r)

退出:

代码语言:javascript
复制
Hi ma
tensor(-1)
tensor(0)
Hi ma
tensor(1)
tensor(1)

需要记住的是,不应该直接调用forward。PyTorch使这个模块对象n可调用。它们实现了可调用性,例如:

代码语言:javascript
复制
 def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        hook(self, input)
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            raise RuntimeError(
                "forward hooks should never return any values, but '{}'"
                "didn't return None".format(hook))
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

只有n()会自动调用forward

通常,__init__定义模块结构,forward()定义单个批处理上的操作。

如果需要的话,这个操作可能会对某些结构元素重复,或者像我们对x = F.relu(x)那样直接调用张量上的函数。

您已经做到了这一点,PyTorch中的所有功能都可以成批完成(迷你批处理),因为PyTorch是通过这种方式进行优化的。

这意味着当您读取图像时,您将不会读取单个图像,而是读取一个bs批次的图像。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56658935

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档