首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >从bidirectional_rnn切换到bidirectional_dynamic_rnn后,结果变得更糟

从bidirectional_rnn切换到bidirectional_dynamic_rnn后,结果变得更糟
EN

Stack Overflow用户
提问于 2017-02-04 19:03:17
回答 1查看 302关注 0票数 1

基本上是尝试用bidirectional_dynamic_rnn (重塑输入)来替代bidirectional_rnn,但在分类任务上得到的结果要差得多。我做错了什么吗?重塑?

bidirectional_rnn版本(代码摘录):

代码语言:javascript
复制
encoder_embedded_inputs = [embedding_ops.embedding_lookup(
                    W, encoder_input) for encoder_input in encoder_inputs]

encoder_outputs, encoder_state_fw, encoder_state_bw  = rnn.bidirectional_rnn(
                    encoder_cell_fw,
                    encoder_cell_bw,
                    encoder_embedded_inputs,
                    sequence_length=sequence_length,
                    dtype=dtype)

encoder_state = array_ops.concat(1, [array_ops.concat(
                1, encoder_state_fw), array_ops.concat(1, encoder_state_bw)])
top_states = [array_ops.reshape(e, [-1, 1, cell.output_size * 2])
                              for e in encoder_outputs]
attention_states = array_ops.concat(1, top_states)

分类准确率: 95%

bidirectional_dynamic_rnn版本(代码摘录):

代码语言:javascript
复制
encoder_embedded_inputs = [embedding_ops.embedding_lookup(
                    W, encoder_input) for encoder_input in encoder_inputs]
emb_size = int(encoder_embedded_inputs[0].get_shape()[1])
enc_size = len(encoder_embedded_inputs)
birnn_inputs = tf.reshape(encoder_embedded_inputs, [-1,enc_size,emb_size])

encoder_outputs, encoder_states  = rnn.bidirectional_dynamic_rnn(
                    encoder_cell_fw,
                    encoder_cell_bw,
                    birnn_inputs,
                    sequence_length=sequence_length,
                    dtype=dtype)
encoder_state_fw, encoder_state_bw = encoder_states
encoder_state = array_ops.concat(1, [array_ops.concat(
                1, encoder_state_fw), array_ops.concat(1, encoder_state_bw)])

attention_states = tf.concat(2, encoder_outputs)

分类准确率: 70%

EN

回答 1

Stack Overflow用户

发布于 2017-02-04 20:21:26

好吧,我发现tf.reshape不适合这个任务,我应该用tf.stack和tf.transpose来代替。因此,它基本上是在使用混乱的输入,并且不再能够学习。

错误:

代码语言:javascript
复制
birnn_inputs = tf.reshape(encoder_embedded_inputs, [-1,enc_size,emb_size])

右图:

代码语言:javascript
复制
birnn_inputs = tf.stack(encoder_embedded_inputs)
birnn_inputs = tf.transpose(birnn_inputs, [1,0,2])

所以现在它运行得很好。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/42039583

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档