首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何对python中的tensor2tensor模型进行推断(没有解码二进制和TensorFlow服务)

如何对python中的tensor2tensor模型进行推断(没有解码二进制和TensorFlow服务)
EN

Stack Overflow用户
提问于 2019-12-07 00:53:30
回答 1查看 560关注 0票数 0

对于如何在不使用解码二进制文件和tensor2tensor服务的情况下对tensor2tensor模型进行推断,我有点困惑。下面两个代码示例似乎是最接近它的,但我得到了一个“不能转换符号张量”错误。

模型输出在通过解码器并抛出错误之前给了我这个结果:

代码语言:javascript
复制
<tf.Tensor 'transformer/strided_slice:0' shape=(?, ?) dtype=int32>

https://gist.github.com/alexwolf22/7b24636c99a6f56da13c27a1ce573b8a#file-using_t2t_models-py https://gist.github.com/alexwolf22/e0ae60d8908c2772f2d3aedacf0ea618#file-decode_t2t_funcs-py

,我可能做错了什么?是否有一个更好的例子来说明如何在tensor2tensor中使用代码进行推理?--我在tensor2tensor回购中也通过了decoding.py --但是仍然没有成功。下面是我跟随的博客以及我的推理代码。

https://medium.com/data-from-the-trenches/training-cutting-edge-neural-networks-with-tensor2tensor-and-10-lines-of-code-10973c030b8

代码语言:javascript
复制
import tensorflow as tf
import numpy as np
from tensor2tensor.utils.trainer_lib import create_hparams, registry
from tensor2tensor import problems
from tensor2tensor.layers import common_hparams
from tensor2tensor.models.transformer import Transformer

def encode(input_txt, encoders):
    """List of Strings to features dict, ready for inference"""
    encoded_inputs = [encoders["inputs"].encode(x) + [1] for x in input_txt]

    # pad each input so is they are the same length
    biggest_seq = len(max(encoded_inputs, key=len))
    for i, text_input in enumerate(encoded_inputs):
        encoded_inputs[i] = text_input + [0 for x in range(biggest_seq - len(text_input))]

    # Format Input Data For Model
    batched_inputs = tf.reshape(encoded_inputs, [len(encoded_inputs), -1, 1])
    return {"inputs": batched_inputs}


def decode(integers, encoders):
    """Decode list of ints to list of strings""" 

    # Turn to list to remove EOF mark
    to_decode = list(np.squeeze(integers))
    if isinstance(to_decode[0], np.ndarray):
        to_decode = map(lambda x: list(np.squeeze(x)), to_decode)
    else:
        to_decode = [to_decode]

    # remove <EOF> Tag before decoding
    to_decode = map(lambda x: x[:x.index(1)], filter(lambda x: 1 in x, to_decode))

    # Decode and return Translated text
    return [encoders["inputs"].decode(np.squeeze(x)) for x in to_decode]

INPUT_TEXT_TO_TRANSLATE = 'Translate this sentence into French'

# Set Tensor2Tensor Arguments
MODEL_DIR_PATH = 'data'
MODEL = 'transformer'
HPARAMS = 'transformer_base'
T2T_PROBLEM = 'translate_enfr_wmt_small8k_rev'

hparams = create_hparams(HPARAMS, data_dir=MODEL_DIR_PATH, problem_name=T2T_PROBLEM)

# Make any changes to default Hparams for model architechture used during training
hparams.batch_size = 1024
hparams.hidden_size = 7*80
hparams.filter_size = 7*80*4
hparams.num_heads = 8

# Load model into Memory
T2T_MODEL = registry.model(MODEL)(hparams, tf.estimator.ModeKeys.PREDICT)

# Init T2T Token Encoder/ Decoders
DATA_ENCODERS = problems.problem(T2T_PROBLEM).feature_encoders(MODEL_DIR_PATH)

### START USING MODELS
encoded_inputs= encode(INPUT_TEXT_TO_TRANSLATE, DATA_ENCODERS)
model_output = T2T_MODEL.infer(encoded_inputs, beam_size=2)["outputs"]
translated_text_in_french =  decode(model_output, DATA_ENCODERS)

print(translated_text_in_french)
EN

回答 1

Stack Overflow用户

发布于 2020-07-11 22:14:17

hparams.batch_size = 1024

您在批处理模式下进行预测。您可以更新您的解码方法。

代码语言:javascript
复制
# Setup helper functions for encoding and decoding
def encode(input_str, output_str=None):
    """Input str to features dict, ready for inference"""
    inputs = encoders["inputs"].encode(input_str) + [1]  # add EOS id
    batch_inputs = tf.reshape(inputs, [1, -1, 1])  # Make it 3D.
    return {"inputs": batch_inputs}

def decode(integers):
    """List of ints to str"""
    integers = list(np.squeeze(integers))
    if 1 in integers:
      integers = integers[:integers.index(1)]
    return encoders["inputs"].decode(np.squeeze(integers))

请按照下面的url中的示例操作。

https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb#scrollTo=EB4MP7_y_SuQ

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59222088

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档