label = tf.constant([0,1,2,3,4,4,5,5])我有张量,例如,上面的张量。我要过滤元素为4的张量,输出张量应该是4,4,如何实现呢?谢谢。
发布于 2017-09-26 06:54:34
只需使用tf.where获取条件为真的索引,使用tf.gather收集指定的值即可。
import tensorflow as tf
label = tf.constant([0,1,2,3,4,4,5,5])
filtered = tf.gather(label, tf.where(tf.equal(label, 4)))
sess = tf.Session()
print(sess.run(filtered))[4]
https://stackoverflow.com/questions/46419257
复制相似问题