Hello, I'm trying to create a BasicDecoder with a GreedyEmbeddingHelper, but it raises an error:

TypeError: helper must be a Helper, received: <class 'helper.GreedyEmbeddingHelper'>

Here is a simplified version of my code:
elif self.mode == 'decode':
    # start_tokens: [batch_size,] `int32` vector
    # NOTE: GreedyEmbeddingHelper documents start_tokens as an int32 vector of shape
    # [batch_size] and end_token as an int32 scalar; the float values below do not match that.
    start_tokens = tf.ones([self.batch_size, self.dimension], tf.float32) * 0.1337
    end_token = 0.1337

    def project_inputs(inputs):
        print(inputs.shape)
        return input_layer(inputs)

    if not self.use_beamsearch_decode:
        # Helper to feed inputs for greedy decoding: uses the argmax of the output
        decoding_helper = helper.GreedyEmbeddingHelper(start_tokens=start_tokens,
                                                       end_token=end_token,
                                                       embedding=project_inputs)
        # Basic decoder performs greedy decoding at each time step
        print("building greedy decoder..")
        inference_decoder = seq2seq.BasicDecoder(cell=self.decoder_cell,
                                                 helper=decoding_helper,
                                                 initial_state=self.decoder_initial_state,
                                                 output_layer=output_layer)
    else:
        # Beam search is used to approximately find the most likely translation
        print("building beamsearch decoder..")
        inference_decoder = beam_search_decoder.BeamSearchDecoder(cell=self.decoder_cell,
                                                                  embedding=project_inputs,
                                                                  start_tokens=start_tokens,
                                                                  end_token=end_token,
                                                                  initial_state=self.decoder_initial_state,
                                                                  beam_width=self.beam_width,
                                                                  output_layer=output_layer)

I don't know how to fix this: Helper is an abstract class, so it seems impossible to pass a Helper instance directly.
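For context, the greedy strategy described by the comment in the code (take the argmax of the output, embed that id, and feed it back in as the next input) can be sketched in plain Python. This is a simplified, single-sequence illustration, not TensorFlow's implementation: `greedy_decode`, `step`, `embed`, `one_hot` and `shift` are hypothetical names, `step` stands in for the decoder cell plus output layer, and cell state is omitted. It also shows why the token arguments are integer ids, matching the TF 1.x documentation of `start_tokens` (int32 vector of shape [batch_size]) and `end_token` (int32 scalar).

```python
from typing import Callable, List, Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def greedy_decode(
    step: Callable[[List[float]], Sequence[float]],
    embed: Callable[[int], List[float]],
    start_token: int,
    end_token: int,
    max_steps: int,
) -> List[int]:
    """Greedy decoding for one sequence.

    step:  hypothetical stand-in for cell + output layer; maps an input
           vector to logits over the vocabulary (state omitted).
    embed: maps an integer token id to the next input vector.
    """
    outputs: List[int] = []
    next_input = embed(start_token)  # the first input is the embedded start id
    for _ in range(max_steps):
        logits = step(next_input)
        token = argmax(logits)       # greedy choice: the most likely id
        outputs.append(token)
        if token == end_token:       # stop once the end id is emitted
            break
        next_input = embed(token)    # feed the chosen id back in
    return outputs


# Toy setup (hypothetical): 4-token vocabulary, one-hot "embedding",
# and a "model" that always predicts (previous id + 1) mod 4.
VOCAB = 4


def one_hot(i: int) -> List[float]:
    return [1.0 if j == i else 0.0 for j in range(VOCAB)]


def shift(x: List[float]) -> List[float]:
    return one_hot((argmax(x) + 1) % VOCAB)


print(greedy_decode(shift, one_hot, start_token=1, end_token=0, max_steps=10))  # [2, 3, 0]
```

Beam search (the `else` branch) differs by keeping the `beam_width` best partial sequences at each step instead of only the single argmax.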
Posted on 2017-06-25 11:12:21
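GreedyEmbeddingHelper is defined in tf.contrib.seq2seq, so use tf.contrib.seq2seq.GreedyEmbeddingHelper instead of helper.GreedyEmbeddingHelper. The module name in the error message (helper.GreedyEmbeddingHelper) suggests the class was loaded from a separate helper module rather than from TensorFlow, so it is not a subclass of the Helper class that BasicDecoder checks against, even though the names are identical.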
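The failure can be reproduced without TensorFlow. The sketch below uses hypothetical stand-in classes (not TensorFlow's own) to show that a class defined a second time, as happens with a copied `helper.py`, is a different class object even when its name is identical, so an `isinstance` check against the original base class fails. `make_local_copy` and `check_helper` are illustrative names; `check_helper` only mimics the kind of check that produces the "helper must be a Helper" message.

```python
import abc


class Helper(abc.ABC):
    """Hypothetical stand-in for the Helper base class the decoder checks against."""

    @abc.abstractmethod
    def sample(self, outputs):
        """Pick the next token id from the outputs."""


class GreedyEmbeddingHelper(Helper):
    """Concrete subclass of the *same* Helper, like the class shipped in tf.contrib.seq2seq."""

    def sample(self, outputs):
        # Greedy choice: index of the largest output
        return max(range(len(outputs)), key=outputs.__getitem__)


def make_local_copy():
    """Simulate a copied helper.py: same class names, but brand-new class objects."""

    class Helper(abc.ABC):
        @abc.abstractmethod
        def sample(self, outputs):
            """Pick the next token id from the outputs."""

    class GreedyEmbeddingHelper(Helper):
        def sample(self, outputs):
            return max(range(len(outputs)), key=outputs.__getitem__)

    return GreedyEmbeddingHelper


def check_helper(helper):
    """Mimic the isinstance check that raises 'helper must be a Helper'."""
    if not isinstance(helper, Helper):
        raise TypeError("helper must be a Helper, received: %s" % type(helper))
    return helper


LocalGreedy = make_local_copy()
print(isinstance(GreedyEmbeddingHelper(), Helper))  # True
print(isinstance(LocalGreedy(), Helper))            # False: same name, different class
```

The abstract base is never instantiated directly (calling `Helper()` raises `TypeError`); any concrete subclass of the right `Helper` passes the check, which is why building the helper from `tf.contrib.seq2seq` resolves the error.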
https://stackoverflow.com/questions/44621681