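This is on TensorFlow 1.11.0. The documentation for tft.apply_buckets is not very descriptive. Specifically, it says: "bucket_boundaries: the bucket boundaries represented as a rank 2 Tensor."
I assume this has to be the bucket indices and the bucket boundaries?
When I try the toy example below: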
import tensorflow as tf
import tensorflow_transform as tft
import numpy as np
tf.enable_eager_execution()
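x = np.array([-1, 9, 19, 29, 39])
xt = tf.cast(
    tf.convert_to_tensor(x),
    tf.float32
)
# Intended as (index, boundary) pairs: after the transpose the shape is (4, 2),
# with rows [0, 10], [1, 20], [2, 30], [3, 40]
boundaries = tf.cast(
    tf.transpose(
        tf.convert_to_tensor([[0, 1, 2, 3], [10, 20, 30, 40]])
    ),
    tf.float32
)
buckets = tft.apply_buckets(xt, boundaries)
I get: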
InvalidArgumentError: Expected sorted boundaries [Op:BucketizeWithInputBoundaries] name: assign_buckets
Note that in this case, the x and bucket_boundaries arguments are:
tf.Tensor([-1. 9. 19. 29. 39.], shape=(5,), dtype=float32)
tf.Tensor(
[[ 0. 10.]
[ 1. 20.]
[ 2. 30.]
[ 3. 40.]], shape=(4, 2), dtype=float32)
So it seems bucket_boundaries is not meant to hold indices and boundaries. Does anyone know how to use this method correctly?
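A plain-Python sketch of what seems to go wrong. It assumes, as the outputs in the answers below suggest, that the op reads the rank-2 boundaries row by row as one flat list; under that assumption the question's transposed tensor flattens to a list that is not in ascending order, which would trigger "Expected sorted boundaries". The helpers flatten_row_major and is_sorted_ascending are hypothetical, for illustration only.

```python
# Plain-Python sketch, not TFT itself. Assumption: the rank-2 bucket_boundaries
# tensor is consumed row by row (row-major) as one flat list of boundaries.
from typing import List, Sequence


def flatten_row_major(rows: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical helper: concatenate the rows of a 2-D nested list."""
    return [value for row in rows for value in row]


def is_sorted_ascending(values: Sequence[float]) -> bool:
    """Hypothetical helper: True if each value is <= the value after it."""
    return all(a <= b for a, b in zip(values, values[1:]))


# The question's boundaries after tf.transpose, shape (4, 2).
question_boundaries = [[0, 10], [1, 20], [2, 30], [3, 40]]
flat = flatten_row_major(question_boundaries)
print(flat)                       # [0, 10, 1, 20, 2, 30, 3, 40]
print(is_sorted_ascending(flat))  # False: 10 is followed by 1
```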
Posted on 2019-07-18 01:50:37
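After some experimenting, I found that bucket_boundaries should be a 2-D array whose entries are the bucket boundaries, wrapped so that it has two columns, as in the example below: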
import tensorflow as tf
import tensorflow_transform as tft
import numpy as np
tf.enable_eager_execution()
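x = np.array([-1, 9, 19, 29, 39])
xt = tf.cast(
    tf.convert_to_tensor(x),
    tf.float32
)
# After the transpose the shape is (4, 2), with rows [0, 10], [20, 30],
# [40, 50], [60, 70]; read row by row, the values are in ascending order
boundaries = tf.cast(
    tf.transpose(
        tf.convert_to_tensor([[0, 20, 40, 60], [10, 30, 50, 70]])
    ),
    tf.float32
)
buckets = tft.apply_buckets(xt, boundaries)
So the expected inputs, and the resulting buckets, are: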
print(xt)
print(buckets)
print(boundaries)
tf.Tensor([-1. 9. 19. 29. 39.], shape=(5,), dtype=float32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
tf.Tensor(
[[ 0. 10.]
[20. 30.]
[40. 50.]
[60. 70.]], shape=(4, 2), dtype=float32)
Posted on 2020-02-24 22:15:25
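I wanted to add to this, since this is the only Google result for "tft.apply_buckets" :)
For me, this example does not work in the latest version of TFT. The code below works for me.
Note that the boundaries are specified as a rank 2 tensor, but with only one element in the inner dimension.
(I'm probably using the wrong terminology, but hopefully my example below makes it clear.)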
import tensorflow as tf
import tensorflow_transform as tft
import numpy as np
tf.enable_eager_execution()
xt = tf.cast(tf.convert_to_tensor(np.array([-1, 9, 19, 29, 39])), tf.float32)
# Shape (5, 1): every boundary sits in its own row
bds = [[0], [10], [20], [30], [40]]
boundaries = tf.cast(tf.convert_to_tensor(bds), tf.float32)
buckets = tft.apply_buckets(xt, boundaries)
Thanks for your help, this answer got me most of the way there!
I found the rest in the TFT source code: https://github.com/tensorflow/transform/blob/deb198d59f09624984622f7249944cdd8c3b733f/tensorflow_transform/mappers.py#L1697-L1698
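To compare the two layouts from the answers, here is a plain-Python emulation (not TFT itself). It assumes the boundaries are read row by row as a single ascending list and, by using bisect_right, that a value equal to a boundary falls into the upper bucket; both rules are inferred from the outputs shown above rather than taken from documentation. emulate_apply_buckets is a hypothetical helper.

```python
# A plain-Python emulation (not TFT itself) of how tft.apply_buckets appears to
# assign buckets. Assumptions: the rank-2 boundaries are read row by row as one
# ascending list, and a value equal to a boundary goes to the upper bucket.
from bisect import bisect_right
from typing import List, Sequence


def emulate_apply_buckets(
    x: Sequence[float], boundaries: Sequence[Sequence[float]]
) -> List[int]:
    """Hypothetical helper: return a bucket index for each value in x."""
    flat = [b for row in boundaries for b in row]  # row-major flatten
    if any(a > b for a, b in zip(flat, flat[1:])):
        raise ValueError("Expected sorted boundaries")
    # bisect_right counts how many boundaries are <= v, i.e. the bucket index.
    return [bisect_right(flat, v) for v in x]


x = [-1, 9, 19, 29, 39]
two_columns = [[0, 10], [20, 30], [40, 50], [60, 70]]  # first answer, shape (4, 2)
one_column = [[0], [10], [20], [30], [40]]             # second answer, shape (5, 1)

print(emulate_apply_buckets(x, two_columns))  # [0, 1, 2, 3, 4]
print(emulate_apply_buckets(x, one_column))   # [0, 1, 2, 3, 4]
```

Under these assumptions only the order of the values matters, not the number of columns, which would explain why both layouts give the same buckets for this input, and why the question's (index, boundary) pairs fail the sortedness check.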
https://stackoverflow.com/questions/57081644