首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Numpy ndarray的动态轴标度

Numpy ndarray的动态轴标度
EN

Stack Overflow用户
提问于 2015-06-27 23:35:44
回答 3查看 1.7K关注 0票数 10

我希望在三维数组的给定方向上获得2D切片,其中direction (或提取切片的轴)由另一个变量给出。

假设idx是3D数组中2D切片的索引,并假定direction是获取该2D切片的轴,则最初的方法是:

代码语言:javascript
复制
if direction == 0:
    return A[idx, :, :]
elif direction == 1:
    return A[:, idx, :]
else:
    return A[:, :, idx]

我很确定一定有一种方法可以做到这一点,而不需要附加条件,至少在原始python中也是如此。numpy有捷径吗?

到目前为止,我找到的更好的解决方案(用于动态执行)依赖于transpose操作符:

代码语言:javascript
复制
# for 3 dimensions [0,1,2] and direction == 1 --> [1, 0, 2]
tr = [direction] + range(A.ndim)
del tr[direction+1]

return np.transpose(A, tr)[idx]

但是我想知道是否有更好的/更容易的/更快的功能,因为对于3D,转置代码看起来几乎比3 if/elif更糟糕。它推广到ND更好,N越大,代码就越漂亮,但是3D是完全一样的。

EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2015-06-27 23:53:40

转置是便宜的(时间上的)。有一些numpy函数使用它将操作轴(或轴)移动到已知位置--通常是形状列表的前端或末端。tensordot是一个在脑海中浮现的。

其他函数构造索引元组。它们可以从列表或数组开始,以便于操作,然后将其转换为应用程序的元组。例如

代码语言:javascript
复制
I = [slice(None)]*A.ndim
I[axis] = idx
A[tuple(I)]

np.apply_along_axis做了类似的事情。查看这样的函数代码是有指导意义的。

我想,numpy函数的作者最担心的是它是否运行得很好,其次是速度,最后是它看起来是否漂亮。你可以在一个函数中隐藏各种丑陋的代码!

tensordot

代码语言:javascript
复制
at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
res = dot(at, bt)
return res.reshape(olda + oldb)

上面的代码计算newaxes_..newshape...

apply_along_axis构造一个(0...,:,0...)索引元组。

代码语言:javascript
复制
i = zeros(nd, 'O')
i[axis] = slice(None, None)
i.put(indlist, ind)
....arr[tuple(i.tolist())]
票数 8
EN

Stack Overflow用户

发布于 2022-08-31 10:52:18

这是蟒蛇。您可以这样简单地使用eval()

代码语言:javascript
复制
def get_by_axis(a, idx, axis):
    indexing_list = a.ndim*[':']
    indexing_list[axis] = str(idx) 
    expression = f"a[{', '.join(indexing_list)}]"
    return eval(expression)

显然是,在这种情况下,您不接受来自不受信任用户的输入。

票数 0
EN

Stack Overflow用户

发布于 2018-02-06 12:02:17

要动态索引维度,可以使用交换轴,如下所示:

代码语言:javascript
复制
a = np.arange(7 * 8 * 9).reshape((7, 8, 9))

axis = 1
idx = 2

np.swapaxes(a, 0, axis)[idx]

运行时比较

自然方法(非动态):

代码语言:javascript
复制
%timeit a[:, idx, :]
300 ns ± 1.58 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

交换轴:

代码语言:javascript
复制
%timeit np.swapaxes(a, 0, axis)[idx]
752 ns ± 4.54 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

具有列表理解的索引:

代码语言:javascript
复制
%timeit a[[idx if i==axis else slice(None) for i in range(a.ndim)]]
票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/31094641

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档