以seq2seq教程为例,假设我们有(5,5),(10,10)的存储桶,批处理大小为16: model_with_buckets用于构建模型。对于model_with_buckets(它是encoder_inputs)的输入,它是一个存储桶中的批处理,例如。大小为5*16但是,有代码可以对所有存储桶运行此批处理,即使它的大小与存储桶大小不同
# tensorflow/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
# in def model_with_buckets()
# this will run twice: seq2seq(encoder_inputs[:5],...) and seq2seq(encoder_inputs[:10],...)
# but encoder_inputs only belongs to bucket (5,5) and with size 5*16
for j, bucket in enumerate(buckets):
...
bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
decoder_inputs[:bucket[1]])当输出时,只使用encoder_inputs所属的桶的损失。
# models/tutorials/rnn/translate/seq2seq_model.py
# in def step()
output_feed = [self.updates[bucket_id], # Update Op that does SGD.
self.gradient_norms[bucket_id], # Gradient norm.
self.losses[bucket_id]] # Loss for this batch.因此,在我看来,model_with_buckets正在做不必要的工作,将encoder_inputs提供给其他不属于它的存储桶。这样做的目的是什么?
发布于 2017-07-11 18:09:39
代码处于图的阶段,composition.That是重点。当与session关联时,step()函数将选择特定的存储桶。
https://stackoverflow.com/questions/43110815
复制相似问题