首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >计算一个张量与另一个张量所有滚动之间成对矩阵乘积的有效方法

计算一个张量与另一个张量所有滚动之间成对矩阵乘积的有效方法
EN

Stack Overflow用户
提问于 2019-12-14 11:45:40
回答 1查看 78关注 0票数 0

假设我们有两个张量:

形状为(d,m,n)的张量A

形状为(d,n,l)的张量B。

如果我们想得到A和B的最右矩阵的成对矩阵乘积,我认为可以使用np.einsum('dmn,……nl->d.ml‘,A,B),其大小为(d,d,m,l)。不过,我想得到的配对产品,不是所有的配对。

导入参数k,1<=k<=d,我想得到以下成对矩阵产品:

从…

A(0,.)@B(0,…)

A(0,.)@B(k-1,.);

从…

A(1,.)@B(1,.)

A(1,.)@B(k,.);

……;

从…

A(d-2,.)@B(d-2,.),

A(d-2,.)@B(d-1,.)

至A(d-2,.)@B(k-3,.);

从…

A(d-1,.)@B(d-1,.)

A(d-1,.)@B(k-2,.)

注意,我们使用滚动的方法来处理张量B (如numpy.roll)。

最后,我们得到一个张量,它的形状是(d,k,m,l)。

最有效的方法是什么。

我知道几种方法,比如:

  1. 首先得到np.einsum('dmn,nl->d.ml‘,A,B),然后使用掩膜提取(d,k)对,
  2. 平铺B,然后以某种方式使用einsum。

但我认为还有更好的方法。

EN

回答 1

Stack Overflow用户

发布于 2019-12-14 22:09:59

我怀疑你能做得比for循环好得多。例如,与双for循环相比,下面是使用einsum和stride_tricks的矢量化版本:

代码:

代码语言:javascript
复制
from simple_benchmark import BenchmarkBuilder, MultiArgument
import numpy as np
from numpy.lib.stride_tricks import as_strided
B = BenchmarkBuilder()

@B.add_function()
def loopy(A,B,k): 
    d,m,n = A.shape                                   
    l = B.shape[-1]                     
    out = np.empty((d,k,m,l),int)                      
    for i in range(d):                         
        for j in range(k):                     
            out[i,j] = A[i]@B[(i+j)%d]                      
    return out                     

@B.add_function()
def vectory(A,B,k):                                            
    d,m,n = A.shape                                            
    l = B.shape[-1]                                            
    BB = np.concatenate([B,B[:k-1]],0)                         
    BB = as_strided(BB,(d,k,n,l),np.repeat(BB.strides,(2,1,1)))
    return np.einsum("ikl,ijln->ijkn",A,BB)                    

@B.add_arguments('d x k x m x n x l')
def argument_provider():
    for exp in range(10):
        d,k,m,n,l = (np.r_[1.6,1.5,1.5,1.5,1.5]**exp*(4,2,2,2,2)).astype(int)
        print(d,k,m,n,l)
        A = np.random.randint(0,10,(d,m,n))                            
        B = np.random.randint(0,10,(d,n,l))
        yield k*d*m*n*l,MultiArgument([A,B,k])

r = B.run()
r.plot()

import pylab
pylab.savefig('diagwa.png')
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59334818

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档