首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch向量化存储桶求和查找数量

PyTorch向量化存储桶求和查找数量
EN

Stack Overflow用户
提问于 2019-12-07 04:03:33
回答 1查看 154关注 0票数 1

使用PyTorch,我已经弄清楚了以下代码,用于通过一些“存储桶索引”计算项目属性的总和:

代码语言:javascript
复制
DATASET_SIZE = 10
NUM_BUCKETS = 4
bucket_assignment = torch.tensor([0,1,2,3,0,1,2,3,0,1], dtype = torch.long)
values_to_add = torch.tensor([1,2,3,4,5,6,7,8,9,10], dtype = torch.float)
buckets = torch.zeros(NUM_BUCKETS, dtype = torch.float)
buckets.index_add_(0, bucket_assignment, values_to_add)

# Buckets is now tensor([15., 18., 10., 12.])

在我的例子中,这是专门检查问题的分配界限,然后代码检查没有存储桶分配不足或过度分配。

我想一次检查多个不同的可能的赋值(然后选择一个最好的选项,代码没有显示)。我想我可以通过将另一个维度添加到bucket_assignment plus和buckets来实现这一点,并让每一行都是一组不同的赋值。然而,这并不能像预期的那样工作,因为index_add_的第二个参数必须是一个简单的向量,我不能传入任何更高等级的张量。

例如。

代码语言:javascript
复制
BATCH_SIZE = 2
DATASET_SIZE = 5
NUM_BUCKETS = 3
bucket_assignment = torch.tensor([[0,1,2,0,1], [1,1,1,2,1]], dtype = torch.long)
values_to_add = torch.tensor([1,2,3,4,5], dtype = torch.float)
buckets = torch.zeros(BATCH_SIZE, NUM_BUCKETS, dtype = torch.float)
buckets.index_add_(0, bucket_assignment, values_to_add)

我想得到这样的结果:

代码语言:javascript
复制
tensor([[5., 7., 3.], [ 0., 11.,  4.]])

相反,我得到了一个错误:

代码语言:javascript
复制
RuntimeError: invalid argument 3: Index is supposed to be a vector at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:733

由于.index_add的局限性,这并不出人意料,但我不知道该如何进行。

我不确定有什么其他方法可以让我在PyTorch中解决这个问题--有没有其他torch方法可以让我实现同样的目的。这里的主要目标是矢量化和避免Python中的循环,因为实际上批处理大小很大,我将利用GPU加速。

EN

回答 1

Stack Overflow用户

发布于 2019-12-07 18:53:03

如果批处理大小是问题所在,您可以使用torch.masked_select获取每个存储桶torch.masked_select(values_to_add, bucket_assignment == bucket_num)的相加的值,其中PyTorch将广播values_to_add,然后只迭代普通python中的存储桶,如下所示:

代码语言:javascript
复制
def bucket_sizes(bucket_num):
    mask = bucket_assignment == bucket_num
    buckets = torch.masked_select(values_to_add, mask)
    buckets = torch.split(buckets, list(mask.sum(dim=1)))
    return [bucket.sum() for bucket in buckets]

torch.tensor([bucket_sizes(i) for i in range(NUM_BUCKETS)]).T
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59219635

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档