Batched beam search in PyTorch

Stack Overflow user
Asked 2020-10-14 23:42:58
2 answers, 2.9K views, 0 followers, 2 votes

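I'm trying to implement a beam search decoding strategy in a text generation model. This is the function I use to decode the output probabilities: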

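import torch


def beam_search_decoder(data, k):
    # data: (max_len, vocab_size) probabilities for ONE sequence
    # each beam is [token_list, score]; score is the negative log-likelihood
    sequences = [[list(), 0.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand every current beam with every token in the vocabulary
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # sort candidates by score (lower negative log-likelihood is better)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # keep only the k best beams
        sequences = ordered[:k]
    return sequences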
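
The two inner Python loops can be replaced by tensor operations: broadcast the current beam scores against the log-probabilities of the next step, then let `topk` pick the k best candidates. A minimal sketch, assuming `data` is a `(max_len, vocab_size)` tensor of probabilities for a single sequence (`beam_search_decoder_vectorized` is a hypothetical name, not from the question):

```python
import torch


def beam_search_decoder_vectorized(data, k):
    """Hypothetical vectorized variant of the single-sequence loop decoder.

    data: (max_len, vocab_size) tensor of probabilities.
    Returns [token_list, score] pairs, best first; score is the negative
    log-likelihood as a Python float (lower is better).
    """
    log_data = torch.log(data)
    vocab_size = data.size(1)
    scores = torch.zeros(1)                        # one empty beam, score 0
    history = torch.zeros(1, 0, dtype=torch.long)  # (n_beams, steps so far)
    for row in log_data:
        # negative log-likelihood of every (beam, token) pair: (n_beams, vocab)
        cand = scores.unsqueeze(1) - row.unsqueeze(0)
        n_keep = min(k, cand.numel())
        # smallest negative log-likelihood = most probable, sorted ascending
        scores, flat = cand.flatten().topk(n_keep, largest=False)
        beam = torch.div(flat, vocab_size, rounding_mode="floor")  # parent beam
        token = flat % vocab_size                                  # new token
        # copy each survivor's parent history and append its token
        history = torch.cat([history[beam], token.unsqueeze(1)], dim=1)
    return [[h.tolist(), s.item()] for h, s in zip(history, scores)]
```

It returns the same `[token_list, score]` structure as the loop version. Candidates with exactly equal scores may come back in a different order than Python's `sorted` would give.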

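As you can see, this function is implemented for a batch size of 1. Adding another loop over the batch would make the algorithm O(n^4), and it is already slow as it is. Is there any way to speed this function up? My model output is usually of size (32, 150, 9907), in the format (batch_size, max_len, vocab_size).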
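
To put "slow" in numbers: the loop version creates one Python list per (beam, token) pair at every step and then sorts them. A back-of-the-envelope count for the stated output size, assuming a beam width of k = 3 (the question does not say which k is used):

```python
k = 3  # hypothetical beam width; not given in the question
batch_size, max_len, vocab_size = 32, 150, 9907

# Step 1 expands the single empty beam; every later step expands k beams.
per_sequence = vocab_size + (max_len - 1) * k * vocab_size
per_batch = per_sequence * batch_size
print(per_sequence, per_batch)  # 4438336 142026752
```

A loop over the batch multiplies this work by `batch_size`; the real cost is that each of these roughly 142 million candidates is a separate Python object that is built and sorted one at a time.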

2 answers

Stack Overflow user

Answered 2021-02-18 18:20:28

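Below is my implementation. It may be somewhat faster than the for-loop version because it processes the whole batch and vocabulary with tensor operations, leaving only a loop over time steps.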

import torch


def beam_search_decoder(post, k):
    """Beam Search Decoder

    Parameters:

        post(Tensor) – the posterior (per-step probabilities) of the network.
        k(int) – beam size of the decoder.

    Outputs:

        indices(Tensor) – a beam of token-index sequences.
        log_prob(Tensor) – the log-likelihood of each sequence in the beam.

    Shape:

        post: (batch_size, seq_length, vocab_size).
        indices: (batch_size, beam_size, seq_length).
        log_prob: (batch_size, beam_size).

    Examples:

        >>> post = torch.softmax(torch.randn([32, 20, 1000]), -1)
        >>> indices, log_prob = beam_search_decoder(post, 3)

    """

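    batch_size, seq_length, vocab_size = post.shape
    log_post = post.log()
    # Start the k beams from the k most likely first tokens: (batch, k)
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)  # (batch, k, 1)
    for i in range(1, seq_length):
        # Score every (beam, next token) pair by broadcasting: (batch, k, vocab)
        scores = log_prob.unsqueeze(-1) + log_post[:, i, :].unsqueeze(1)
        # Keep the k best pairs among the k * vocab flattened candidates.
        log_prob, flat = scores.reshape(batch_size, -1).topk(k, sorted=True)
        # Recover which beam each survivor extends and which token it appends.
        beam = torch.div(flat, vocab_size, rounding_mode="floor")
        token = flat % vocab_size
        # Reorder the histories to follow their parent beams, then append.
        parents = beam.unsqueeze(-1).expand(-1, -1, indices.size(-1))
        indices = torch.cat([indices.gather(1, parents), token.unsqueeze(-1)], dim=-1)
    return indices, log_prob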
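
The step a batched beam search has to get right is mapping each index returned by `topk` over the flattened `k * vocab_size` candidates back to its parent beam (`index // vocab_size`) and its token (`index % vocab_size`), then reordering the stored histories so they follow their parents. A toy illustration with made-up numbers (batch of 1, k = 2, vocabulary of 3 tokens):

```python
import torch

k, vocab_size = 2, 3
# Made-up state after two steps: (batch, k, steps so far)
history = torch.tensor([[[0, 2],     # beam 0 so far: tokens 0, 2
                         [1, 1]]])   # beam 1 so far: tokens 1, 1
log_prob = torch.tensor([[-0.5, -3.0]])                     # (batch, k)
step_log_post = torch.log(torch.tensor([[0.1, 0.4, 0.5]]))  # (batch, vocab)

# Score every (beam, token) pair and flatten to (batch, k * vocab).
scores = (log_prob.unsqueeze(-1) + step_log_post.unsqueeze(1)).reshape(1, -1)
new_log_prob, flat = scores.topk(k, dim=-1)

# Decompose each flat index into its parent beam and the appended token.
beam = torch.div(flat, vocab_size, rounding_mode="floor")
token = flat % vocab_size

# Reorder histories to follow their parents, then append the new tokens.
parents = beam.unsqueeze(-1).expand(-1, -1, history.size(-1))
new_history = torch.cat([history.gather(1, parents), token.unsqueeze(-1)], dim=-1)
print(new_history.tolist())  # [[[0, 2, 2], [0, 2, 1]]]
```

Both survivors extend beam 0, so beam 0's history is duplicated and beam 1's is dropped. Appending the flat index without this reordering would keep `[1, 1]` as a prefix even though no surviving candidate extends it.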
Votes: 3

Stack Overflow user

Answered 2021-10-18 17:12:35

You can use this library:

https://pypi.org/project/pytorch-beam-search/

It implements beam search, greedy search, and sampling for PyTorch sequence models.
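
For comparison, greedy search and sampling are short on a precomputed probability tensor like the one in the question. A minimal sketch in plain PyTorch (not this library's API), using a random, hypothetical `post` of shape `(batch_size, max_len, vocab_size)`:

```python
import torch

torch.manual_seed(0)
# Hypothetical model output: (batch_size, max_len, vocab_size) probabilities.
post = torch.softmax(torch.randn(4, 6, 10), dim=-1)

# Greedy search: take the most probable token at every step.
greedy = post.argmax(dim=-1)                                 # (batch, max_len)

# Sampling: draw one token per step from that step's distribution.
flat = post.reshape(-1, post.size(-1))                       # (batch * max_len, vocab)
sampled = torch.multinomial(flat, 1).view(post.shape[:-1])   # (batch, max_len)
```

When every step's probabilities are fixed in advance, as in a precomputed `(batch_size, max_len, vocab_size)` tensor, the greedy sequence is already the single most likely sequence, so beam search only adds the 2nd to k-th candidates. Beam search matters when each step's distribution depends on the tokens chosen earlier, which means calling the model once per decoding step.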

The following snippet builds and trains a Transformer seq2seq model, then uses it to generate predictions.

#pip install pytorch-beam-search
from pytorch_beam_search import seq2seq

# Create vocabularies
# Tokenize the way you need
source = [list("abcdefghijkl"), list("mnopqrstwxyz")]
target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]
# An Index object represents a mapping from the vocabulary
# to integers (indices) to feed into the models
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)

# Create tensors
X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)
# X.shape == (n_source_examples, len_source_examples) == (2, 11)
# Y.shape == (n_target_examples, len_target_examples) == (2, 12)

# Create and train the model
model = seq2seq.Transformer(source_index, target_index)    # just a PyTorch model
model.fit(X, Y, epochs = 100)    # basic method included

# Generate new predictions
new_source = [list("new first in"), list("new second in")]
new_target = [list("new first out"), list("new second out")]
X_new = source_index.text2tensor(new_source)
Y_new = target_index.text2tensor(new_target)
loss, error_rate = model.evaluate(X_new, Y_new)    # basic method included
predictions, log_probabilities = seq2seq.beam_search(model, X_new) 
output = [target_index.tensor2text(p) for p in predictions]
print(output)
Votes: -2
Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/64356953
