Speeding up sub-matrix computation on a 4-D tensor

Stack Overflow user
Asked 2022-06-01 12:36:39
2 answers · 129 views · 0 followers · 0 votes

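I have an input array of shape 1028 × 24 × 24 × 16. The code below runs very slowly on it. How can I speed it up? Thanks.

(I want to extract 3 × 3 sub-matrices from the large array.)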

import time
import numpy as np

start = time.time()


def ls(x):
    # Normalize one window by its sum and return the largest entry.
    x_p = x / np.sum(x)
    return np.max(x_p)


inputs = np.random.rand(1028, 24, 24, 16)
b, r, c, ch = inputs.shape

# Move channels next to the batch axis, then merge the two: (b*ch, r, c).
inputs = np.transpose(inputs, (0, 3, 1, 2))
inputs = np.reshape(inputs, (b * ch, r, c))

s = 2  # stride between windows
r = 3  # window size (rebinds r, which held the row count above)
num_r = 9
num_c = 9
ke = np.zeros((b * ch, num_r, num_c), dtype=np.float32)

for i in range(num_r):
    for j in range(num_c):
        # Slow part: ls is called once per 2-D slice through a Python-level map.
        outs = np.array(list(map(ls, inputs[:, i*s:i*s+r, j*s:j*s+r])))
        ke[:, i, j] = outs

ke = np.reshape(ke, (b, ch, num_r, num_c))
ke = np.transpose(ke, (0, 2, 3, 1))

print(time.time() - start)

2 Answers

Stack Overflow user

Accepted answer

Posted 2022-06-01 16:26:31

The original code can be wrapped (with a few improvements) in the following function:

import numpy as np


def foo0(arr, s=2, r=3, nr=9, nc=9):
    forward_axes = (0, 3, 1, 2)
    backward_axes = (0, 2, 3, 1)
    # NOTE: this unpacking rebinds r to arr.shape[1], overriding the r argument.
    b, r, c, ch = arr.shape
    # (b, r, c, ch) -> (b, ch, r, c) -> (b * ch, r, c)
    arr = np.transpose(arr, forward_axes)
    arr = np.reshape(arr, (b * ch, r, c))
    result = np.zeros((b * ch, nr, nc), dtype=np.float32)
    for i in range(nr):
        for j in range(nc):
            # One Python-level call per 2-D slice: this is the slow part.
            result[:, i, j] = np.fromiter(
                map(
                    lambda x: np.max(x / np.sum(x)),
                    arr[:, i * s:i * s + r, j * s:j * s + r]),
                dtype=np.float32)
    # Restore the (b, nr, nc, ch) layout.
    result = np.reshape(result, (b, ch, nr, nc))
    result = np.transpose(result, backward_axes)
    return result

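Although the explicit loops suggest that Numba could be applied here to pick some low-hanging fruit, the function unfortunately cannot be decorated easily because it interacts with Python objects, which would largely cancel the speed-up. Fortunately, the core computation is easy to vectorize, and as long as nr and nc are small enough, this optimization is sufficient: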

def foo1(arr, s=2, r=3, nr=9, nc=9):
    forward_axes = (0, 3, 1, 2)
    backward_axes = (0, 2, 3, 1)
    # NOTE: this unpacking rebinds r to arr.shape[1], overriding the r argument.
    b, r, c, ch = arr.shape
    arr = np.transpose(arr, forward_axes)
    arr = np.reshape(arr, (b * ch, r, c))
    result = np.zeros((b * ch, nr, nc), dtype=np.float32)
    for i in range(nr):
        for j in range(nc):
            x = arr[:, i * s:i * s + r, j * s:j * s + r]
            # Reduce over the two window axes for all b * ch slices at once.
            result[:, i, j] = np.max(x, (-1, -2)) / np.sum(x, (-1, -2))
    result = np.reshape(result, (b, ch, nr, nc))
    result = np.transpose(result, backward_axes)
    return result

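(The foo1() above is essentially equivalent to @ymmx's answer, with a few extra optimizations.)

Note that max(x / k) is the same as max(x) / k for positive k, but the second form performs far fewer divisions: one per window instead of one per element.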
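The identity can be checked numerically. The following is a minimal sketch on made-up data (five random 3 × 3 windows, which are an assumption for illustration and not taken from the benchmark):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((5, 3, 3))        # five hypothetical 3x3 windows
k = x.sum(axis=(-1, -2))         # one positive sum per window

# Divide every element first (9 divisions per window), then reduce.
elementwise = (x / k[:, None, None]).max(axis=(-1, -2))

# Reduce first, then divide once per window.
reduced_first = x.max(axis=(-1, -2)) / k

print(np.allclose(elementwise, reduced_first))
```

The identity relies on k being positive, which holds here because the inputs come from a uniform distribution on [0, 1).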

In fact, the transpose and reshape, although they can speed up the computation, are not strictly necessary:

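def foo2(arr, s=2, r=3, nr=9, nc=9):
    # NOTE: this unpacking rebinds r to arr.shape[1], overriding the r argument.
    b, r, c, ch = arr.shape
    result = np.zeros((b, nr, nc, ch), dtype=np.float32)
    for i in range(nr):
        for j in range(nc):
            x = arr[:, i * s:i * s + r, j * s:j * s + r, :]
            # Reduce over the two spatial axes, keeping batch and channel.
            result[:, i, j, :] = np.max(x, (1, 2)) / np.sum(x, (1, 2))
    return result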

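Converting the above to Numba is fairly straightforward, but for small nr / nc the speed gain is small (compared with the partially vectorized approach):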

import numba as nb


@nb.njit
def sum_nb(arr):
    # Explicit-loop sum over a 1-D array.
    result = 0
    for x in arr:
        result += x
    return result


@nb.njit
def max_nb(arr):
    # Explicit-loop maximum over a non-empty 1-D array.
    result = arr[0]
    for x in arr[1:]:
        if x > result:
            result = x
    return result


@nb.njit
def _sum_max(arr):
    # max / sum of each (batch, channel) window of a 4-D block.
    b, r, c, ch = arr.shape
    res = np.empty((b, ch), dtype=arr.dtype)
    for i in range(b):
        for j in range(ch):
            x = arr[i, :, :, j].ravel()
            res[i, j] = max_nb(x) / sum_nb(x)
    return res


@nb.njit
def foo3(arr, s=2, r=3, nr=9, nc=9):
    # NOTE: this unpacking rebinds r to arr.shape[1], overriding the r argument.
    b, r, c, ch = arr.shape
    result = np.zeros((b, nr, nc, ch), dtype=np.float32)
    for i in range(nr):
        for j in range(nc):
            result[:, i, j, :] = _sum_max(arr[:, i * s:i * s + r, j * s:j * s + r, :])
    return result

Another option is to keep the Numba-incompatible code outside the main loop:

@nb.njit(fastmath=True)
def _foo4(arr, result, s, r, nr, nc):
    # Compiled main loop; result is preallocated by the caller.
    bch, nr, nc = result.shape
    for i in range(nr):
        for j in range(nc):
            for k in range(bch):
                x = arr[k, i * s:i * s + r, j * s:j * s + r].ravel()
                result[k, i, j] = max_nb(x) / sum_nb(x)
    return result


def foo4(arr, s=2, r=3, nr=9, nc=9):
    forward_axes = (0, 3, 1, 2)
    backward_axes = (0, 2, 3, 1)
    # NOTE: this unpacking rebinds r to arr.shape[1], overriding the r argument.
    b, r, c, ch = arr.shape
    # The transpose and reshape stay in plain Python, outside the compiled loop.
    arr = np.transpose(arr, forward_axes)
    arr = np.reshape(arr, (b * ch, r, c))
    result = np.empty((b * ch, nr, nc))
    result = _foo4(arr, result, s, r, nr, nc)
    result = np.reshape(result, (b, ch, nr, nc))
    result = np.transpose(result, backward_axes)
    return result

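But again, the speed gain will be marginal.

Note that a fully vectorized approach is unlikely to be efficient, because the objects in the main loop are jagged.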
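The slices in the loops above differ in size when they run past the array edge, which happens once r is rebound to arr.shape[1]. If every window has the same size, as with the 3 × 3 windows and stride 2 in the question, the windows can instead be taken as a strided view and reduced in one step. The following is a minimal sketch under that assumption, using sliding_window_view (available in NumPy 1.20 and later). The name max_over_sum_windows and its w parameter are hypothetical and not part of the answer; the sketch is not expected to reproduce the output of foo0 to foo4 as written, and it has not been timed here.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def max_over_sum_windows(arr, s=2, w=3, nr=9, nc=9):
    """Hypothetical helper: max / sum over w x w windows taken with stride s.

    arr has shape (b, rows, cols, ch); the result has shape (b, nr, nc, ch).
    """
    # Every w x w window along the two spatial axes, as a read-only view:
    # shape (b, rows - w + 1, cols - w + 1, ch, w, w).
    windows = sliding_window_view(arr, (w, w), axis=(1, 2))
    # Keep every s-th window, then only the first nr x nc of them.
    windows = windows[:, ::s, ::s][:, :nr, :nc]
    # Reduce over the two trailing window axes.
    return windows.max(axis=(-1, -2)) / windows.sum(axis=(-1, -2))


arr = np.random.default_rng(0).random((4, 24, 24, 2))
out = max_over_sum_windows(arr)
print(out.shape)
```

The window at output position (i, j) starts at row i * s and column j * s, which matches the slicing arr[:, i*s:i*s+3, j*s:j*s+3] used in the question.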

To get an idea of the relative speeds:

# Run in IPython/Jupyter: %timeit is an IPython magic, not plain Python.
funcs = foo0, foo1, foo2, foo3, foo4

timeds_n = {}
for p in range(1):
    n = 10 ** p
    arr = np.random.rand(100, 24, 24, 16)
    print(f"N = {arr.size}")
    base = funcs[0](arr)  # reference result from the unoptimized version
    timeds_n[n] = []
    for func in funcs:
        res = func(arr)  # also triggers Numba compilation before timing
        timed = %timeit -r 1 -n 1 -q -o func(arr)
        timeds_n[n].append(timed.best)
        print(f"{func.__name__:>24}  {np.allclose(base, res)}  {timed.best:.9f}")

Output (times in seconds):

N = 921600
                    foo0  True  1.757508748
                    foo1  True  0.095540081
                    foo2  True  0.179208341
                    foo3  True  0.160671403
                    foo4  True  0.155691721
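The timing loop above relies on the %timeit magic, so it only runs inside IPython or Jupyter. Outside of those, the standard-library timeit module can take similar measurements. The following is a minimal sketch: best_time is a hypothetical helper and demo is a stand-in workload, neither of which comes from the answer.

```python
import timeit

import numpy as np


def best_time(func, *args, repeat=3, number=1):
    """Hypothetical helper: best time per call, in seconds."""
    timings = timeit.repeat(lambda: func(*args), repeat=repeat, number=number)
    return min(timings) / number


def demo(arr):
    # Stand-in workload; replace with foo0, foo1, ... once they are defined.
    return arr.max(axis=(1, 2)) / arr.sum(axis=(1, 2))


arr = np.random.rand(100, 24, 24, 16)
elapsed = best_time(demo, arr)
print(f"{demo.__name__:>24}  {elapsed:.9f}")
```

For the Numba versions, calling the function once before timing keeps the compilation cost out of the measurement, as the loop above already does with res = func(arr).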
Votes: 3

Stack Overflow user

Posted 2022-06-01 12:47:02

I think the problem is mainly the function ls: it should be vectorized. The list/map calls are what cost you the time.

import time
import numpy as np

start = time.time()


def ls(x):
    # x has shape (n, 3, 3): normalize each window by its own sum,
    # then take the per-window maximum, without a Python-level loop.
    x_p = x / np.sum(np.sum(x, axis=1), axis=1)[:, None, None]
    return np.max(np.max(x_p, axis=1), axis=1)


inputs = np.random.rand(1028, 24, 24, 16)
b, r, c, ch = inputs.shape

inputs = np.transpose(inputs, (0, 3, 1, 2))
inputs = np.reshape(inputs, (b * ch, r, c))

s = 2  # stride between windows
r = 3  # window size (rebinds r, which held the row count above)
num_r = 9
num_c = 9
ke = np.zeros((b * ch, num_r, num_c), dtype=np.float32)

for i in range(num_r):
    print(i)  # progress indicator
    for j in range(num_c):
        # outs = np.array(list(map(ls, inputs[:, i*s:i*s+r, j*s:j*s+r])))
        outs = ls(inputs[:, i*s:i*s+r, j*s:j*s+r])
        ke[:, i, j] = outs

ke = np.reshape(ke, (b, ch, num_r, num_c))
ke = np.transpose(ke, (0, 2, 3, 1))

print(time.time() - start)
Votes: 2

Original content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/72462021
