Implementing do_sampling for a custom GPT-NEO model

Asked by a Stack Overflow user on 2021-11-08
import numpy as np
import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
import coremltools as ct

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

sentence_fragment = "The Oceans are"

class NEO(torch.nn.Module):
    def __init__(self, model):
        super(NEO, self).__init__()
        self.next_token_predictor = model
    
    def forward(self, x):
        sentence = x
        predictions, _ = self.next_token_predictor(sentence)
        token = torch.argmax(predictions[-1, :], dim=0, keepdim=True)
        sentence = torch.cat((sentence, token), 0)
        return sentence

token_predictor = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M", torchscript=True).eval()

context = torch.tensor(tokenizer.encode(sentence_fragment))
random_tokens = torch.randint(10000, (5,))
traced_token_predictor = torch.jit.trace(token_predictor, random_tokens)

model = NEO(model=traced_token_predictor)
scripted_model = torch.jit.script(model)

# Custom model

sentence_fragment = "The Oceans are"

for i in range(10):
    context = torch.tensor(tokenizer.encode(sentence_fragment))
    torch_out = scripted_model(context)
    sentence_fragment = tokenizer.decode(torch_out)
print("Custom model: {}".format(sentence_fragment))

# Stock model

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M", torchscript=True).eval()

sentence_fragment = "The Oceans are"

input_ids = tokenizer(sentence_fragment, return_tensors="pt").input_ids
gen_tokens = model.generate(input_ids, do_sample=True, max_length=20)
gen_text = tokenizer.batch_decode(gen_tokens)[0]
print("Stock model: "+gen_text)

Run 1 output:

Custom model: The Oceans are the most important source of water for the entire world
Stock model: The Oceans are on the rise. The American Southwest is thriving, but the southern United States still

Run 2 output:

Custom model: The Oceans are the most important source of water for the entire world. 
Stock model: The Oceans are the land of man

This is a short video of the Australian government

The custom model always returns the same output. The stock model's generate, however, returns a different result on every call when do_sample=True. I have spent a lot of time trying to figure out how do_sample works in transformers, so I would appreciate your help.

How can I write the custom model so that it produces a different result on each call?

Thanks!
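The core of the issue is that the custom model's forward pass uses torch.argmax, which is deterministic, while do_sample=True draws the next token from the model's probability distribution. A minimal sketch contrasting the two, using made-up logits rather than real model output:

```python
import torch

# Hypothetical logits over a 5-token vocabulary (not from a real model)
logits = torch.tensor([2.0, 1.5, 0.5, 0.1, -1.0])

# Greedy decoding: argmax always picks the same token for the same logits
greedy_token = torch.argmax(logits).item()  # always index 0 here

# Sampling: draw from the softmax distribution; the result varies across calls
probs = torch.softmax(logits, dim=0)
sampled_token = torch.multinomial(probs, num_samples=1).item()

print("greedy:", greedy_token)
print("sampled:", sampled_token)
```

Running the sampling line repeatedly yields different tokens, weighted by their probabilities, which is why model.generate with do_sample=True gives a different continuation each time.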

1 Answer

Answered by a Stack Overflow user on 2021-11-09

So, the answer is to implement sampling :D

class NEO(torch.nn.Module):
    def __init__(self, model):
        super(NEO, self).__init__()
        self.next_token_predictor = model

    def forward(self, x):
        sentence = x
        predictions, _ = self.next_token_predictor(sentence)
        # get the indices of the top-k (k=2) most probable tokens;
        # 2 indices are enough, since that already gives 2^N variations
        _, topK = torch.topk(predictions[-1, :], 2, dim=0)
        # pick one of those two indices at random and append it to the sentence
        perm = torch.randperm(topK.size(0))
        idx = perm[:1]
        token = topK[idx.long()]
        sentence = torch.cat((sentence, token), 0)
        return sentence
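The top-2 trick above works, but it picks between the two best tokens with equal probability. A sketch of an alternative (a hypothetical SamplingNEO class, not part of the original answer) that samples from the full temperature-scaled softmax distribution, which is closer to what do_sample=True actually does:

```python
import torch

class SamplingNEO(torch.nn.Module):
    # Hypothetical variant: sample the next token from the full softmax
    # distribution with a temperature, instead of choosing uniformly
    # between the top 2 tokens.
    def __init__(self, model, temperature=0.8):
        super().__init__()
        self.next_token_predictor = model
        self.temperature = temperature

    def forward(self, x):
        # predictions has shape (sequence_length, vocab_size)
        predictions, _ = self.next_token_predictor(x)
        # temperature < 1 sharpens the distribution, > 1 flattens it
        probs = torch.softmax(predictions[-1, :] / self.temperature, dim=0)
        # draw one token index according to its probability
        token = torch.multinomial(probs, num_samples=1)
        return torch.cat((x, token), 0)
```

Lower temperatures concentrate probability mass on the most likely tokens and approach greedy argmax decoding; higher temperatures make the output more varied but less coherent.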
Original content provided by Stack Overflow: https://stackoverflow.com/questions/69889395