我试图使用tf.cond()根据一个条件创建两个不同的图。关于两个图,我们想要有重量正则化损失,因此我们使用tf.losses.get_regularization_loss()。下面是我们项目的伪代码
def net_1(x,y):
statement 1 (has trainable params)
statement 2 (has trainable params)
returndef net_2(x,y):
statement 1 (has trainable params)
statement 2 (has trainable params)
statement 3 (has trainable params)
returnstep = tf.get_or_create_global_step()
tf.cond(tf.greater(step, 100), net_1, net_2)
loss = 0.0
loss += tf.losses.get_regularization_loss()如果保留tf.losses.get_regularization_loss(),就会得到错误:
Retval没有价值
否则,就没有错误。
如果我们要强制实施tf.cond(),是否需要特别注意tf.losses.get_regularization_loss()。
发布于 2020-08-04 21:30:28
同样的问题,我用两个类似的函数替换了tf.cond (这与正则化有关).现在找不到更好的解决方案了。
https://stackoverflow.com/questions/56762468
复制相似问题