首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >cudnnGRU is_training占位符

cudnnGRU is_training占位符
EN

Stack Overflow用户
提问于 2017-10-23 19:55:00
回答 1查看 337关注 0票数 0

在创建批量规范化模型时,我可以为is_training参数提供一个位置持有者,如下所示:

代码语言:javascript
复制
training = tf.placeholder(tf.bool)  
sym = create_symbol(training)
# ....
# Training: sess.run(model, feed_dict={X: data, y: label, training: True})
# Inference: sess.run(pred, feed_dict={X: data, training: False})

但是,当我对包含cudnnGRU (或cudnnLSTM)的符号执行此操作时,它不喜欢位置持有者:

代码语言:javascript
复制
cudnn_cell = tf.contrib.cudnn_rnn.CudnnGRU(num_layers=1, 
                                           num_units=NUMHIDDEN, 
                                           input_size=EMBEDSIZE)    # Set params
params_size_t = cudnn_cell.params_size()
params = tf.Variable(tf.random_uniform([params_size_t]), validate_shape=False)   
input_h = tf.Variable(tf.zeros([1, BATCHSIZE, NUMHIDDEN]))
outputs, states = cudnn_cell(is_training=training ,
                             input_data=word_list,
                             input_h=input_h,
                             params=params)

错误消息:

TypeError:为争论“is_training”而不是dtype=bool>而埋怨。

EN

回答 1

Stack Overflow用户

发布于 2017-10-23 20:16:38

这是因为tf.layers.batch_normalization支持“TensorFlow布尔值,或TensorFlow布尔标量张量(例如占位符)。”(见文件)。

但是tf.contrib.cudnn_rnn.CudnnGRU只支持布尔值。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46897471

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档