首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何创建具有动态"zero_state“的zero_state(推理失败)

如何创建具有动态"zero_state“的zero_state(推理失败)
EN

Stack Overflow用户
提问于 2017-01-16 02:53:15
回答 1查看 949关注 0票数 2

我一直在与"dynamic_rnn“一起创建一个模型。

该模型基于80个时间段信号,我希望在每次运行之前使"initial_state“为零,因此我已经设置了以下代码片段来完成这一任务:

代码语言:javascript
复制
state = cell_L1.zero_state(self.BatchSize,Xinputs.dtype)
outputs, outState = rnn.dynamic_rnn(cell_L1,Xinputs,initial_state=state,  dtype=tf.float32)

这对训练过程非常有用。问题是,一旦我进入推理,其中我的BatchSize = 1,我得到一个错误,因为rnn“状态”不匹配新的新的鑫get形状。因此,我认为我需要根据输入批次大小来生成"self.BatchSize“,而不是硬编码。我尝试了许多不同的方法,但都没有奏效。我不想通过feed_dict传递一堆零,因为它是一个基于批处理大小的常量。

以下是我的一些尝试。它们通常都会失败,因为在构建图形时输入大小是未知的:

代码语言:javascript
复制
state = cell_L1.zero_state(Xinputs.get_shape()[0],Xinputs.dtype)

.

代码语言:javascript
复制
state = tf.zeros([Xinputs.get_shape()[0], self.state_size], Xinputs.dtype, name="RnnInitializer")

另一种方法认为初始化程序可能在运行时才被调用,但在图形生成时仍然失败:

代码语言:javascript
复制
init = lambda shape, dtype: np.zeros(*shape)
state = tf.get_variable("state", shape=[Xinputs.get_shape()[0], self.state_size],initializer=init)

是否有一种方法可以动态地创建这个常量初始状态,还是需要使用张量服务代码通过feed_dict重新设置它?是否有一种聪明的方法在图中只执行一次,也许是使用tf.Variable.assign?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-01-25 20:53:39

这个问题的解决方案是如何获得"batch_size“,以便变量不是硬编码的。

这是给出的例子中正确的方法:

代码语言:javascript
复制
Xinputs = tf.placeholder(tf.int32, (None, self.sequence_size, self.num_params), name="input")
state = cell_L1.zero_state(Xinputs.get_shape()[0],Xinputs.dtype)

问题在于"get_shape()“的使用,它返回张量的”形状“,并将batch_size值取为。文档似乎没有那么清晰,但这似乎是一个常量值,所以当您将图形加载到推理中时,这个值仍然是硬编码的(可能只在图形创建时计算?)。

使用"tf.shape()“函数似乎可以做到这一点。这不是返回形状,而是张量。因此,这似乎是在运行时更新更多。使用此代码片段解决了训练批128的问题,然后将图形加载到TensorFlow-Service推理中,只处理1批。

代码语言:javascript
复制
Xinputs = tf.placeholder(tf.int32, (None, self.sequence_size, self.num_params), name="input")
batch_size = tf.shape(Xinputs)[0]
state = self.cell_L1.zero_state(batch_size,Xinputs.dtype)

这里有一个指向TensorFlow FAQ的很好的链接,它描述了这种方法‘如何构建一个处理可变批大小的图形?’:https://www.tensorflow.org/resources/faq

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41668786

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档