首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在翻译任务的GPT2训练中增加批量?

如何在翻译任务的GPT2训练中增加批量?
EN

Stack Overflow用户
提问于 2021-05-08 14:12:06
回答 1查看 267关注 0票数 2

我正在开发一个代码,以便使用预先训练好的GPT2模型来完成机器翻译任务。我的数据的word-to-id长度是91,我为我的模型开发了以下代码:

代码语言:javascript
复制
import torch
from torch.utils.data import DataLoader
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

# data preparation code

def batch_sequences(x, y, env):
    """
    Take as input a list of n sequences (torch.LongTensor vectors) and return
    a tensor of size (slen, n) where slen is the length of the longest
    sentence, and a vector lengths containing the length of each sentence.
    """
    lengths_x = torch.LongTensor([len(s) + 2 for s in x])
    lengths_y = torch.LongTensor([len(s) + 2 for s in y])
    max_length = max(lengths_x.max().item(), lengths_y.max().item())
    sent_x = torch.LongTensor(
        max_length, lengths_x.size(0)).fill_(env.pad_index)
    sent_y = torch.LongTensor(
        max_length, lengths_y.size(0)).fill_(env.pad_index)
    assert lengths_x.min().item() > 2
    assert lengths_y.min().item() > 2

    sent_x[0] = env.eos_index
    for i, s in enumerate(x):
        sent_x[1:lengths_x[i] - 1, i].copy_(s)
        sent_x[lengths_x[i] - 1, i] = env.eos_index

    sent_y[0] = env.eos_index
    for i, s in enumerate(y):
        sent_y[1:lengths_y[i] - 1, i].copy_(s)
        sent_y[lengths_y[i] - 1, i] = env.eos_index

    return sent_x, sent_y, max_length

def collate_fn(elements):
    """
    Collate samples into a batch.
    """
    x, y = zip(*elements)
    x = [torch.LongTensor([env.word2id[w]
                          for w in seq if w in env.word2id]) for seq in x]
    y = [torch.LongTensor([env.word2id[w]
                          for w in seq if w in env.word2id]) for seq in y]
    x, y, length = batch_sequences(x, y, env)
    return (x, length), (y, length), torch.LongTensor(nb_ops)

loader = DataLoader(data, batch_size=1, shuffle=False, collate_fn=collate_fn)
gpt2 = GPT2Model.from_pretrained('gpt2')
in_layer = nn.Embedding(len(env.word2id), 768)
out_layer = nn.Linear(768, len(env.word2id))

parameters = list(gpt2.parameters()) + list(in_layer.parameters()) + list(out_layer.parameters())
optimizer = torch.optim.Adam(parameters)
loss_fn = nn.CrossEntropyLoss()
for layer in (gpt2, in_layer, out_layer):
    layer.train()

accuracies = list()
n_epochs = 5
for i in range(n_epochs):
    for (x, x_len), (y, y_len) in loader:

        x = x.to(device=device)
        y = y.to(device=device)

        embeddings = in_layer(x.reshape(1, -1))
        hidden_state = gpt2(inputs_embeds=embeddings).last_hidden_state[:, :]
        logits = out_layer(hidden_state)[0]
        loss = loss_fn(logits, y.reshape(-1))
        accuracies.append(
            (logits.argmax(dim=-1) == y.reshape(-1)).float().mean().item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if len(accuracies) % 500 == 0:
            accuracy = sum(accuracies[-50:]) / len(accuracies[-50:])
            print(f'Samples: {len(accuracies)}, Accuracy: {accuracy}')

当批处理大小为1时,这段代码工作得很好,但它太慢了。我想将批处理大小从1增加到32,但我遇到了一些尺寸兼容性问题。如何才能不出错地增加批处理大小?

我的数据由两个句子组成,第一个是第一种语言的句子,第二种是它的第二种语言的翻译。

例如,假设x.shape是(batch_size,12) (意味着我们有长度为12的'batch_size‘句子作为输入,y.shape也是(batch_size,12) (翻译)。我们还有一个长度为90的word-to-id字典,它将句子中的每个单词与其索引进行匹配)

EN

回答 1

Stack Overflow用户

发布于 2021-05-14 01:09:35

这个问题可以通过填充来解决。我们需要两个特殊的符号:

translated.

  • code

  • in outputs (y)表示不应参与损失计算的“空白”令牌,
  • 0 in inputs (x)将表示不应参与损失计算的“空白”令牌。忽略此值(通过参数ignore_index).

)的nn.CrossEntropyLoss()programmed

大小为3的批次可能如下所示:

代码语言:javascript
复制
x:
[[1, 2, 3, 0, 0],
[ 4, 5, 6, 7, 8],
[ 9, 8, 0, 0, 0]]
y:
[[1, 2,    3, -100, -100],
[ 4, 5,    6,    7,    8],
[ 9, 8, -100, -100, -100]]

您可以使用如下代码生成它:

代码语言:javascript
复制
def pad_sequences(batch, pad_value=0):
    n = max(len(v) for v in batch)
    return torch.tensor([v + [pad_value] * (n - len(v)) for v in batch])

然而,我觉得你的问题陈述有一个问题。如果您执行机器翻译,那么您的输入和输出可以具有不同的长度,但是您的体系结构只允许xy具有相同的长度。如果你想支持不同长度的xy,我建议使用seq2seq架构,比如T5。

另一个问题是GPT是自回归的,所以如果yx完全对齐,那么我们在生成y的左侧部分时就不能使用x的后缀。因此,如果您希望xy完全一致,但仍然希望在生成y时使用有关x的完整信息,我建议您使用双向编码器,如BERT。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67444616

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档