
本文不会介绍LSTM的原理,具体可看如下两篇文章
在介绍LSTM各种参数含义之前我们还是需要先用一个例子(参考LSTM神经网络输入输出究竟是怎样的?Scofield的回答)来理解LSTM。
Recurrent NNs,一般看的最多的图是这个:

rnn但是这个图对初学者相当不太友好。个人认为,目前所有的关于描述RecurrentNNs的图都画得不好,不够明确,里面的细节丢失了。(事实上里面一个"A"仅仅表示了一层的变换,具体如下图所示。)

非常清楚,这是很多初学者不能理解RecurrentNNs的根本原因,即在于Recurrent NNs是在time_step上的拓展的这一特性。MLP好理解,CNN也好理解,但Recurrent NNs,就是无法搞清楚里面的拓扑结构,跟MLP联系不上。
先看看MLP,很好理解,就是一张网络清楚地显示了张量流向。
general MLP是这样的拓扑:

mlp然后CNN也好理解,跟MLP无差若干,只是权重运算由
变为
。CNN是这样的拓扑:

但RecurrentNNs的拓扑发生了一个很大的改动,即一个MLP会在time_step这个维度上进行延伸,每个时序都会有input。
所以RecurrentNNs的结构图应该这样画,在理解上才会更清晰些,对比MLP,也一目了然。(为了简约,只画了4个time-steps )……

如上图所示,
的输入
,也就是说一次time_step输入一个input tensor。
也就代表了一张MLP的hidden layer的一个cell,可以看到中间黄色圈圈就表示隐藏层.
理解无异,可以看到每个时序的输出节点数是等于隐藏节点数的。注意,红色的箭头指向仅仅表示数据流动方向,并不是表示隐藏层之间相连。
再结合一个操作实例说明。如果我们有一条长文本,我们给句子事先分割好句子,并且进行tokenize, dictionarize,接着再由look up table 查找到embedding,将token由embedding表示,再对应到上图的输入。流程如下:
,每一列代表一个词向量,词向量维度自行确定(假设一个单词由长度为100的向量表示);矩阵列数固定为time_step length。
- sentence2: ...
- ……,则padded sentence length(step5中矩阵列数)固定为
。一次RNNs的run只处理一条sentence。每个sentence的每个token的embedding对应了每个时序 的输入 。一次RNNs的run,连续地将整个sentence处理完。简单理解就是每次传入RNN的句子长度为
,换句话就是RNN横向长度为
的隐状态
;但整体RNN的输出
是在最后一个time_step
时获取,才是完整的最终结果。
,做seq2seq 网络……或者搞创新……
通过源代码中可以看到nn.LSTM继承自nn.RNNBase,其初始化函数定义如下
class RNNBase(Module):
...
def __init__(self, mode, input_size, hidden_size,
num_layers=1, bias=True, batch_first=False,
dropout=0., bidirectional=False):我们需要关注的参数以及其含义解释如下:
下面介绍一下输入数据的维度要求(batch_first=False):
输入数据需要按如下形式传入 input, (h_0,c_0)
torch.nn.utils.rnn.pack_padded_sequence()或者torch.nn.utils.rnn.pack_sequence()来对句子进行填充或者截断。bidirectional决定,如果为False,则等于1;反之等于2。当然,如果你没有传入(h_0, c_0),那么这两个参数会默认设置为0。
,表示第二层LSTM每个time step对应的输出。
- 另外如果前面你对输入数据使用了`torch.nn.utils.rnn.PackedSequence`,那么输出也会做同样的操作编程packed sequence。
- 对于unpacked情况,我们可以对输出做如下处理来对方向作分离`output.view(seq_len, batch, num_directions, hidden_size)`, 其中前向和后向分别用0和1表示Similarly, the directions can be separated in the packed case.
rnn = nn.LSTM(10, 20, 2) # 一个单词向量长度为10,隐藏层节点数为20,LSTM有2层
input = torch.randn(5, 3, 10) # 输入数据由3个句子组成,每个句子由5个单词组成,单词向量长度为10
h0 = torch.randn(2, 3, 20) # 2:LSTM层数*方向 3:batch 20: 隐藏层节点数
c0 = torch.randn(2, 3, 20) # 同上
output, (hn, cn) = rnn(input, (h0, c0))
print(output.shape, hn.shape, cn.shape)
>>> torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])参考:
MARSGGBO♥原创
如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com
2019-12-31 10:41:09