我用simpletransformers库训练了一个T5转换器。
下面是获取预测结果的代码:
pred_values = model.predict(input_values)然而,它只返回排名靠前或贪婪的预测,我如何才能获得排名前10的结果?
发布于 2021-03-13 01:46:55
必需的参数是num_return_sequences,它显示要生成的样本数。但是,如果要使用光束搜索算法,还应为光束搜索设置一个数字。
model_args = T5Args()
model_args.num_beams = 5
model_args.num_return_sequences = 2或者,您可以使用top_k或top_p生成顶级示例并从中进行选择,在这些情况下,您必须将do_sample设置为True。有关参数的更多信息,请参考1和2,这是详细的解释。
model_args = T5Args()
model_args.do_sample = True
model_args.top_p = 0.9
model_args.num_return_sequences = 2https://stackoverflow.com/questions/66597384
复制相似问题