我正在学习深入的学习,并试图理解下面给出的pytorch代码。我很难理解概率计算是如何工作的。可以用门外汉的术语来分解它。谢谢一吨。
ps = model.forward(images0,:)
# Hyperparameters for our network
input_size = 784
hidden_sizes = [128, 64]
output_size = 10
# Build a feed-forward network
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
nn.ReLU(),
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
nn.ReLU(),
nn.Linear(hidden_sizes[1], output_size),
nn.Softmax(dim=1))
print(model)
# Forward pass through the network and display output
images, labels = next(iter(trainloader))
images.resize_(images.shape[0], 1, 784)
print(images.shape)
ps = model.forward(images[0,:])发布于 2019-01-17 20:05:25
我是个外行,所以我会帮你处理外行的条件:)
input_size = 784
hidden_sizes = [128, 64]
output_size = 10这些是网络中各层的参数。每个神经网络由layers组成,每个layer都有一个输入和一个输出形状。
具体来说,input_size处理第一层的输入形状。这是整个网络的input_size。每个输入到网络中的样本都是长度为784的一维向量(数组长为784 )。
hidden_size处理网络中的形状。我们稍后再讨论这个问题。
output_size处理最后一层的输出形状。这意味着我们的网络将输出一个长度为10的一维向量。
现在,要逐行分解模型定义:
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),nn.Sequential部分只定义一个网络,输入的每个参数按照该顺序定义该网络中的一个新层。
nn.Linear(input_size, hidden_sizes[0])就是这样一个层的例子。它是我们网络的第一层,接收大小为input_size的输入,并输出大小hidden_sizes[0]的向量。输出的大小被认为是“隐藏的”,因为它不是整个网络的输入或输出。它是“隐藏的”,因为它位于网络内部,远离您实际使用时与之交互的网络的输入和输出端。
这被称为Linear,因为它通过将输入乘以其权值矩阵并将其偏差矩阵添加到结果中来应用线性变换。(Y = Ax + b,Y=输出,x=输入,A=权重,b=偏差)。
nn.ReLU(),ReLU是激活函数的一个例子。此函数所做的是将某种转换应用于最后一层(上面讨论的层)的输出,并输出该转换的结果。在这种情况下,所使用的函数是ReLU函数,它被定义为ReLU(x) = max(x, 0)。激活函数被用于神经网络,因为它们会产生非线性.这允许您的模型建模非线性关系。
nn.Linear(hidden_sizes[0], hidden_sizes[1]),根据我们前面讨论的内容,这是layer的另一个示例。它采用hidden_sizes[0] (与最后一层输出的形状相同)的输入,并输出长度为hidden_sizes[1]的一维矢量。
nn.ReLU(),再次苹果ReLU功能。
nn.Linear(hidden_sizes[1], output_size)与上述两层相同,但这次我们的输出形状是output_size。
nn.Softmax(dim=1))另一个激活函数。这个激活函数将nn.Linear输出的逻辑转换为实际的概率分布。这样,模型就可以输出每个类的概率。在这一点上,我们的模型建立。
# Forward pass through the network and display output
images, labels = next(iter(trainloader))
images.resize_(images.shape[0], 1, 784)
print(images.shape)这些只是对培训数据的预处理,并将其转换成正确的格式。
ps = model.forward(images[0,:])它通过模型(前向传递)传递图像,并应用先前在层中讨论的操作。你得到了结果输出。
https://stackoverflow.com/questions/54239125
复制相似问题