我正在尝试复制ConvNet + LSTM方法,使用pytorch实现在本文中。但我很难在我的模型中找到将CNN和LSTM结合起来的正确方法。以下是我的尝试:
class VideoRNN(nn.Module):
def __init__(self, hidden_size, n_classes):
super(VideoRNN, self).__init__()
self.hidden_size = hidden_size
vgg = models.vgg16(pretrained=True)
embed = nn.Sequential(*list(vgg.classifier.children())[:-1])
vgg.classifier = embed
for param in vgg.parameters():
param.requires_grad = False
self.embedding = vgg
self.GRU = nn.GRU(4096, hidden_size)
def forward(self, input, hidden=None):
embedded = self.embedding(input)
output, hidden = self.gru(output, hidden)
output = self.classifier(output.view(-1, 4096))
return output, hidden因为我的视频有可变的长度,所以我提供了一个PackedSequence作为输入。它是由形状为(M,B,C,H,W)的张量产生的,其中M是最大序列长度,B是批大小。C,H,W是每一帧的通道、高度和宽度。
我希望预先训练的CNN成为模型的一部分,因为我以后可能会解冻一些层以完善CNN以完成我的任务。这就是为什么我没有分别计算图像的嵌入。
我的问题如下:
发布于 2020-09-14 08:28:39
我终于找到了解决办法,让它发挥作用。下面是一个简单但完整的示例,说明我是如何创建一个能够使用VideoRNN作为输入的packedSequence的:
class VideoRNN(nn.Module):
def __init__(self, n_classes, batch_size, device):
super(VideoRNN, self).__init__()
self.batch = batch_size
self.device = device
# Loading a VGG16
vgg = models.vgg16(pretrained=True)
# Removing last layer of vgg 16
embed = nn.Sequential(*list(vgg.classifier.children())[:-1])
vgg.classifier = embed
# Freezing the model 3 last layers
for param in vgg.parameters():
param.requires_grad = False
self.embedding = vgg
self.gru = nn.LSTM(4096, 2048, bidirectional=True)
# Classification layer (*2 because bidirectionnal)
self.classifier = nn.Sequential(
nn.Linear(2048 * 2, 256),
nn.ReLU(),
nn.Linear(256, n_classes),
)
def forward(self, input):
hidden = torch.zeros(2, self.batch , 2048).to(
self.device
)
c_0 = torch.zeros(self.num_layer * 2, self.batch, 2048).to(
self.device
)
embedded = self.simple_elementwise_apply(self.embedding, input)
output, hidden = self.gru(embedded, (hidden, c_0))
hidden = hidden[0].view(-1, 2048 * 2)
output = self.classifier(hidden)
return output
def simple_elementwise_apply(self, fn, packed_sequence):
return torch.nn.utils.rnn.PackedSequence(
fn(packed_sequence.data), packed_sequence.batch_sizes
)关键是simple_elementwise_apply方法,允许在CNN网络中为PackedSequence提供数据,并检索由嵌入作为输出的新PackedSequence。
我希望你会发现它有用。
https://stackoverflow.com/questions/63567352
复制相似问题