Basically, I am trying to replace bidirectional_rnn with bidirectional_dynamic_rnn (reshaping the inputs accordingly), but the results on a classification task are much worse. Am I doing something wrong, perhaps in the reshape?
bidirectional_rnn version (code excerpt):
encoder_embedded_inputs = [embedding_ops.embedding_lookup(W, encoder_input)
                           for encoder_input in encoder_inputs]
encoder_outputs, encoder_state_fw, encoder_state_bw = rnn.bidirectional_rnn(
    encoder_cell_fw,
    encoder_cell_bw,
    encoder_embedded_inputs,
    sequence_length=sequence_length,
    dtype=dtype)
encoder_state = array_ops.concat(
    1, [array_ops.concat(1, encoder_state_fw),
        array_ops.concat(1, encoder_state_bw)])
top_states = [array_ops.reshape(e, [-1, 1, cell.output_size * 2])
              for e in encoder_outputs]
attention_states = array_ops.concat(1, top_states)

Classification accuracy: 95%
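For reference, the shape bookkeeping in this version can be sketched with NumPy (toy sizes; the names below are illustrative stand-ins, not the original variables):

```python
import numpy as np

# Stand-ins for the T per-timestep outputs of bidirectional_rnn, each of
# shape [batch, 2 * cell_size] (forward and backward outputs concatenated).
batch, T, out2 = 2, 3, 8
encoder_outputs = [np.random.rand(batch, out2) for _ in range(T)]

# Mirror of the excerpt: add a time axis of length 1, then concat along it.
top_states = [np.reshape(e, (-1, 1, out2)) for e in encoder_outputs]
attention_states = np.concatenate(top_states, axis=1)
print(attention_states.shape)  # (2, 3, 8), i.e. [batch, time, 2 * cell_size]
```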
bidirectional_dynamic_rnn version (code excerpt):
encoder_embedded_inputs = [embedding_ops.embedding_lookup(W, encoder_input)
                           for encoder_input in encoder_inputs]
emb_size = int(encoder_embedded_inputs[0].get_shape()[1])
enc_size = len(encoder_embedded_inputs)
birnn_inputs = tf.reshape(encoder_embedded_inputs, [-1, enc_size, emb_size])
encoder_outputs, encoder_states = rnn.bidirectional_dynamic_rnn(
    encoder_cell_fw,
    encoder_cell_bw,
    birnn_inputs,
    sequence_length=sequence_length,
    dtype=dtype)
encoder_state_fw, encoder_state_bw = encoder_states
encoder_state = array_ops.concat(
    1, [array_ops.concat(1, encoder_state_fw),
        array_ops.concat(1, encoder_state_bw)])
attention_states = tf.concat(2, encoder_outputs)

Classification accuracy: 70%
Posted on 2017-02-04 20:21:26
OK, I figured it out: tf.reshape is the wrong tool for this job, and I should use tf.stack and tf.transpose instead. The network was essentially being fed scrambled inputs and could no longer learn.
Wrong:
birnn_inputs = tf.reshape(encoder_embedded_inputs, [-1, enc_size, emb_size])
Right:
birnn_inputs = tf.stack(encoder_embedded_inputs)
birnn_inputs = tf.transpose(birnn_inputs, [1, 0, 2])
With this change it now works fine.
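The difference between the two layouts can be checked with a small NumPy sketch (illustrative toy sizes, not the original code): reshaping the time-major list flattens it in the wrong order and interleaves rows from different examples, while stack plus transpose produces the [batch, time, emb] tensor that bidirectional_dynamic_rnn expects.

```python
import numpy as np

# A length-T list of per-timestep inputs, each of shape [batch, emb],
# mimicking what bidirectional_rnn consumes. Values are made distinct
# per timestep so mix-ups are visible.
batch, T, emb = 2, 3, 4
inputs = [np.arange(batch * emb).reshape(batch, emb) + 100 * t
          for t in range(T)]

# Wrong: reshape flattens in time-major order, so rows belonging to
# different examples end up inside the same "example" slice.
wrong = np.reshape(inputs, (-1, T, emb))

# Right: stack to [T, batch, emb], then swap axes to [batch, T, emb],
# so right[b, t] == inputs[t][b].
right = np.transpose(np.stack(inputs), (1, 0, 2))

print(np.array_equal(wrong, right))  # → False: the two layouts differ
```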
https://stackoverflow.com/questions/42039583