TensorFlow Addons seq2seq: output of the BasicDecoder call (tfa.seq2seq)

Stack Overflow user
Asked 2022-02-28 08:42:08

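I am building a seq2seq model with tfa.seq2seq, essentially the same as https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt#train_the_model. I am trying to understand what the output is when BasicDecoder is called. I create a decoder instance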

  decoder_instance = tfa.seq2seq.BasicDecoder(
      cell=decoder.rnn_cell, sampler=greedy_sampler, output_layer=decoder.fc)

and then call it

  outputs, _, _ = decoder_instance(
      decoder_embedding_matrix,
      start_tokens=start_tokens,
      end_token=end_token,
      initial_state=decoder_initial_state)

What is outputs here: predicted probabilities?

Next, I would like to do something like this

  predicted_logits = predicted_logits[:, -1, :]
  predicted_logits = predicted_logits/temperature
  
  # Sample the output logits to generate token IDs.
  predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
  predicted_ids = tf.squeeze(predicted_ids, axis=-1)
    
  # Convert from token ids to characters
  predicted_chars = chars_from_ids(predicted_ids)

Edit

In my test, outputs looks like this

BasicDecoderOutput(rnn_output=<tf.Tensor: shape=(1, 1, 106), dtype=float32, numpy=
array([[[-1.7647576 ,  1.2142688 ,  2.3475904 ,  0.35890207,
          0.72230023, -0.3587367 , -0.02984604, -1.9962349 ,
          0.510706  , -1.4457364 , -0.43458703, -0.55248725,
         -0.9126631 , -0.5542034 , -1.2392808 , -1.0972862 ,
         -0.7256295 ,  0.02101   , -1.0858598 ,  0.9452345 ,
          0.56474745,  0.2157154 ,  1.6094822 ,  0.6396736 ,
          1.5741622 ,  1.4455014 ,  0.9529134 ,  0.37970737,
         -0.60284877,  0.73455685,  1.0571934 ,  1.3716137 ,
         -1.0882497 ,  1.7738185 ,  1.1919689 ,  0.8144775 ,
          0.84732264,  1.6677057 ,  1.8040668 ,  0.86257285,
          2.0206916 ,  1.3602887 ,  1.2091455 ,  1.318665  ,
         -0.6775206 , -0.9906771 , -0.39923188, -1.0290842 ,
         -1.3546644 , -1.5678416 ,  0.624691  , -1.0316744 ,
          1.2098004 ,  1.4669724 ,  0.9996722 ,  0.12806134,
         -0.42086226, -0.11248919, -0.8277442 ,  0.622267  ,
         -1.6404072 ,  0.2762841 , -0.54035664, -0.6325757 ,
         -0.16794772,  0.8435169 ,  1.1214966 , -1.5629222 ,
          0.27472585,  0.8861834 , -1.7886144 ,  0.56741697,
         -1.9197755 , -1.8073375 , -1.5050163 , -1.7794812 ,
         -0.11308812,  1.3161705 ,  1.027235  ,  1.3830551 ,
         -1.374056  , -1.4779223 ,  0.19962706, -1.6843308 ,
          0.370475  ,  0.8292502 , -1.2990475 , -1.8491654 ,
         -3.4606798 , -0.9822829 , -2.391135  , -3.6944065 ,
         -3.5912528 , -2.4165688 , -2.640759  , -4.0524964 ,
         -3.0878603 , -1.6555822 , -1.2015637 , -1.7716323 ,
          1.7384199 , -2.4340994 , -0.7337967 , -0.88279086,
         -0.85630864, -0.8148002 ]]], dtype=float32)>, sample_id=<tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[2]], dtype=int32)>)

1 Answer

Stack Overflow user

Accepted answer

Answered 2022-03-01 11:58:11

For inference, use class GreedyEmbeddingSampler(Sampler): https://github.com/tensorflow/addons/blob/v0.15.0/tensorflow_addons/seq2seq/sampler.py#L559-L650

def sample(self, time, outputs, state):
    """sample for GreedyEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, use argmax to get the most probable id
    if not isinstance(outputs, tf.Tensor):
        raise TypeError(
            "Expected outputs to be a single Tensor, got: %s" % type(outputs)
        )
    sample_ids = tf.argmax(outputs, axis=-1, output_type=tf.int32)
    return sample_ids

So, as the comment in the source says: # Outputs are logits, use argmax to get the most probable id
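The argmax step can be checked by hand against the values printed in the question. The following is a minimal framework-free sketch, not the library's code. It assumes, for illustration only, that the first eight of the 106 printed logits are enough to show the idea; the helper name greedy_id is hypothetical.

```python
def greedy_id(logits):
    """Return the index of the largest logit, as an argmax over the last axis would."""
    return max(range(len(logits)), key=lambda i: logits[i])


# First eight values of rnn_output from the question (illustrative slice only).
first_logits = [-1.7647576, 1.2142688, 2.3475904, 0.35890207,
                0.72230023, -0.3587367, -0.02984604, -1.9962349]

print(greedy_id(first_logits))  # 2
```

The largest value, 2.3475904, sits at index 2, which agrees with the sample_id of [[2]] shown in the question's output.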

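BasicDecoder returns outputs = BasicDecoderOutput(cell_outputs, sample_ids). The first field, rnn_output, is the output of the RNN cell, or of the final dense layer when output_layer is set. The second field, sample_id, holds the ids obtained by taking the argmax of those logits. With output_layer=decoder.fc, as in the question, rnn_output therefore contains logits (unnormalized scores, which is why the printed values include negatives), not probabilities.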
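Since rnn_output holds logits, probabilities would have to be derived from it, and the temperature sampling sketched in the question can work on the logits directly. Below is a small framework-free illustration of both steps, assuming a four-value slice of the printed logits; the function names softmax and sample_token_id are hypothetical helpers, not part of tfa.seq2seq. In TensorFlow the corresponding calls would be tf.nn.softmax and the tf.random.categorical call already used in the question, which accepts logits rather than probabilities.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities, after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token_id(logits, temperature=1.0, rng=random):
    """Draw one token id at random, weighted by the softmax probabilities."""
    probs = softmax(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Illustrative slice of the rnn_output values printed in the question.
logits = [-1.7647576, 1.2142688, 2.3475904, 0.35890207]

probs = softmax(logits)                       # sums to 1, largest at index 2
sharp = softmax(logits, temperature=0.1)      # low temperature: close to argmax
token = sample_token_id(logits, temperature=0.8, rng=random.Random(0))
```

A lower temperature concentrates the probability mass on the highest logit, so sampling approaches the greedy sample_id; a higher temperature flattens the distribution and makes other ids more likely.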

Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/71292414
