首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >ValueError:无法将操作数与获取BallTree邻居的平均距离的形状一起广播

ValueError:无法将操作数与获取BallTree邻居的平均距离的形状一起广播
EN

Stack Overflow用户
提问于 2020-11-12 19:37:28
回答 1查看 47关注 0票数 1

当我尝试获取50个点的邻域的平均距离时,我收到错误"ValueError: operands be broadcast with shapes (5,) (4,)“。邻域的距离是使用sklearn的函数BallTree计算的。此函数的查询返回一个numpy双精度数组,其中包含中心点和邻域之间的距离,我正在搜索以获得中心点和邻域之间的平均距离。

代码如下:

代码语言:javascript
复制
import numpy as np
import sklearn.neighbors

rng = np.random.RandomState(0)
X = rng.random_sample((50, 3))

treeBall_Neighbors = sklearn.neighbors.BallTree(X, leaf_size=2)
indices_Neighbors,distance_Neighbors=treeBall_Neighbors.query_radius(X[:], r=0.2,count_only=False,return_distance=True)

print(distance_Neighbors.mean())

这很奇怪,因为如果我尝试一个接一个地获取平均距离,我不会得到错误:

代码语言:javascript
复制
print(distance_Neighbors[0].mean())
print(distance_Neighbors[1].mean())
...

你能帮助我在不使用for循环的情况下获得一个包含邻居平均值的numpy数组吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-11-12 20:20:28

不幸的是,我觉得你有点不走运。让我们来看看distance_Neighbors最终是什么样子:

代码语言:javascript
复制
distance_Neighbors

array([array([0.19662693, 0.        , 0.12978415, 0.15542077, 0.19196227]),
       array([0.        , 0.10668909, 0.1864282 , 0.15770229]),
       array([0.]), array([0.14046915, 0.        , 0.19662693]),
       ...
       array([0.15770229, 0.15542077, 0.166294  , 0.08416146, 0.        ]),
       array([0.16614007, 0.        , 0.18970757, 0.19556229, 0.11739919]),
       array([0.]), array([0.]), array([0.])], dtype=object)

这是一个由dtype = object组成的参差不齐的数组-在numpy中,它不是一个有用的数据类型(当您尝试时,它们往往是奇怪错误的来源,比如您发现的错误)。这不是你的错,这就是sklearn的输出,但是你就像一个列表一样被卡住了,因为一个乱七八糟的dtype = object数组从一开始就不比一个列表好。

作为第二个问题,所有这些数组中都有一个0.0,这将扰乱您的平均值。如果您不介意没有邻居为nan的结果,您可以这样做:

代码语言:javascript
复制
[a[a>0].mean() for a in distance_Neighbors]
Out[13]: 
[0.1684485306859162,
 0.1502731942636554,
 nan,
 0.1685480404455817,
 0.14716664746268726,
 ...
 0.16064019067059504,
 0.17164368138912153,
 0.14089463228303245,
 0.16720227805614482,
 nan,
 nan,
 nan]

如果你想要一些其他的方法来处理它们,你需要创建一个真正的for循环。如果您根本不想处理它们,只需使用[a.mean() for a in distance_Neighbors]

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64803198

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档