当我尝试获取50个点的邻域的平均距离时,我收到错误"ValueError: operands be broadcast with shapes (5,) (4,)“。邻域的距离是使用sklearn的函数BallTree计算的。此函数的查询返回一个numpy双精度数组,其中包含中心点和邻域之间的距离,我正在搜索以获得中心点和邻域之间的平均距离。
代码如下:
import numpy as np
import sklearn.neighbors
rng = np.random.RandomState(0)
X = rng.random_sample((50, 3))
treeBall_Neighbors = sklearn.neighbors.BallTree(X, leaf_size=2)
indices_Neighbors,distance_Neighbors=treeBall_Neighbors.query_radius(X[:], r=0.2,count_only=False,return_distance=True)
print(distance_Neighbors.mean())这很奇怪,因为如果我尝试一个接一个地获取平均距离,我不会得到错误:
print(distance_Neighbors[0].mean())
print(distance_Neighbors[1].mean())
...你能帮助我在不使用for循环的情况下获得一个包含邻居平均值的numpy数组吗?
发布于 2020-11-12 20:20:28
不幸的是,我觉得你有点不走运。让我们来看看distance_Neighbors最终是什么样子:
distance_Neighbors
array([array([0.19662693, 0. , 0.12978415, 0.15542077, 0.19196227]),
array([0. , 0.10668909, 0.1864282 , 0.15770229]),
array([0.]), array([0.14046915, 0. , 0.19662693]),
...
array([0.15770229, 0.15542077, 0.166294 , 0.08416146, 0. ]),
array([0.16614007, 0. , 0.18970757, 0.19556229, 0.11739919]),
array([0.]), array([0.]), array([0.])], dtype=object)这是一个由dtype = object组成的参差不齐的数组-在numpy中,它不是一个有用的数据类型(当您尝试时,它们往往是奇怪错误的来源,比如您发现的错误)。这不是你的错,这就是sklearn的输出,但是你就像一个列表一样被卡住了,因为一个乱七八糟的dtype = object数组从一开始就不比一个列表好。
作为第二个问题,所有这些数组中都有一个0.0,这将扰乱您的平均值。如果您不介意没有邻居为nan的结果,您可以这样做:
[a[a>0].mean() for a in distance_Neighbors]
Out[13]:
[0.1684485306859162,
0.1502731942636554,
nan,
0.1685480404455817,
0.14716664746268726,
...
0.16064019067059504,
0.17164368138912153,
0.14089463228303245,
0.16720227805614482,
nan,
nan,
nan]如果你想要一些其他的方法来处理它们,你需要创建一个真正的for循环。如果您根本不想处理它们,只需使用[a.mean() for a in distance_Neighbors]
https://stackoverflow.com/questions/64803198
复制相似问题