首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Matlab“外积”的Python实现

Matlab“外积”的Python实现
EN

Stack Overflow用户
提问于 2021-09-03 15:02:27
回答 1查看 85关注 0票数 0

我正在尝试将下面这段关于矩阵的外积的Matlab代码重写成python代码,

代码语言:javascript
复制
function Y = matlab_outer_product(X,x)
A = reshape(X, [size(X) ones(1,ndims(x))]);
B = reshape(x, [ones(1,ndims(X)) size(x)]);
Y = squeeze(bsxfun(@times,A,B));
end

我的python代码的一对一翻译如下(考虑到numpy数组和matlab矩阵的形状是如何排列的),

代码语言:javascript
复制
def python_outer_product(X, x):
    X_shape = list(X.shape)
    x_shape = list(x.shape)
    A = X.reshape(*list(np.ones(np.ndim(x),dtype=int)),*X_shape)
    B = x.reshape(*x_shape,*list(np.ones(np.ndim(X),dtype=int)))
    Y = A*B
    return Y.squeeze()

然后尝试输入,例如,

代码语言:javascript
复制
matlab_outer_product([1,2],[[3,4];[5,6]])
python_out_product(np.array([[1,2]], np.array([[3,4],[5,6]])))

输出结果不太匹配。在matlab中,它输出

代码语言:javascript
复制
output(:,:,1) = [[3,5];[6,10]]
output(:,:,2) = [[4,6];[8,12]]

在python中,它输出

代码语言:javascript
复制
output = array([
       [[ 3,  6],
        [ 4,  8]],

       [[ 5, 10],
        [ 6, 12]]
])

它们几乎完全相同,但并不完全相同。我想知道代码出了什么问题,以及如何更改python代码以与matlab输出匹配?

EN

回答 1

Stack Overflow用户

发布于 2021-09-03 15:43:12

完全血淋淋的细节(因为我的MATLAB内存很旧):

八度音阶

代码语言:javascript
复制
>> X = [1,2];
>> x = [[3,4];[5,6]];
>> A = reshape(X, [size(X) ones(1,ndims(x))]);
>> B = reshape(x, [ones(1,ndims(X)) size(x)]);
>> A
A =

   1   2

>> B
B =

ans(:,:,1,1) =  3
ans(:,:,2,1) =  5
ans(:,:,1,2) =  4
ans(:,:,2,2) =  6

>> bsxfun(@times,A,B)
ans =

ans(:,:,1,1) =

   3   6

ans(:,:,2,1) =

    5   10

ans(:,:,1,2) =

   4   8

ans(:,:,2,2) =

    6   12

>> squeeze(bsxfun(@times,A,B))
ans =

ans(:,:,1) =

    3    5
    6   10

ans(:,:,2) =

    4    6
    8   12

从(1,2)和(2,2)开始,将第二个扩展为(1,1,2,2)。bsxfun产生一个(1,2,2,2),它被压缩为(2,2,2)。

AX重塑为[1 2 1 1],但两个外部大小为1的维度被挤出,因此不会发生任何更改。

这个MATLAB输出器有点卷积,使用bsxfun执行(1,2,1,1)与(1,1,1,2)的元素乘法。至少在八度音阶中是一样的

代码语言:javascript
复制
A.*B

在numpy中

代码语言:javascript
复制
In [77]: X
Out[77]: array([[1, 2]])    # (1,2)
In [78]: x
Out[78]: 
array([[3, 4],              # (2,2)
       [5, 6]])

请注意,展平后的MATLAB/Octave x具有元素(3,5,4,6),而数值为3,4,5,6。

在numpy中,我可以简单地这样做:

代码语言:javascript
复制
In [79]: X[:,:,None,None]*x
Out[79]: 
array([[[[ 3,  4],          (1,2,2,2)
         [ 5,  6]],

        [[ 6,  8],
         [10, 12]]]])

或者没有额外的1维尺寸的X

代码语言:javascript
复制
In [84]: (X[0,:,None,None]*x)
Out[84]: 
array([[[ 3,  4],
        [ 5,  6]],

       [[ 6,  8],
        [10, 12]]])

In [85]: (X[0,:,None,None]*x).ravel()
Out[85]: array([ 3,  4,  5,  6,  6,  8, 10, 12])

与之相比,八度拉威尔

代码语言:javascript
复制
>> squeeze(bsxfun(@times,A,B))(:)'
ans =

    3    6    5   10    4    8    6   12

我们可以将转置加到numpy

代码语言:javascript
复制
In [96]: (X[0,:,None,None]*x).transpose(2,1,0).ravel()
Out[96]: array([ 3,  6,  5, 10,  4,  8,  6, 12])
In [97]: (X[0,:,None,None]*x).transpose(2,1,0)
Out[97]: 
array([[[ 3,  6],
        [ 5, 10]],

       [[ 4,  8],
        [ 6, 12]]])

至少在numpy中,我们可以在很多方面调整维度顺序,所以我不会尝试建议一个最优的。我仍然认为,编写对numpy来说“自然”的代码比盲目地按照MATLAB的顺序编写代码要好得多。

再试一次

我在上面意识到,MATLAB只是对(1,2,1,1)数组(1,1,1,2)进行广播,其中额外的1被添加到“A*.B”中。

使用转置到最外面的相同维度(在numpy中领先)

代码语言:javascript
复制
In [5]: X = X.T; x = x.T
In [6]: X.shape
Out[6]: (2, 1)
In [7]: x.shape
Out[7]: (2, 2)
In [8]: x
Out[8]: 
array([[3, 5],
       [4, 6]])
In [9]: x.ravel()
Out[9]: array([3, 5, 4, 6])   # compare with MATLAB (:)'

具有相同维数扩展的元素乘法:

代码语言:javascript
复制
In [10]: X[None,None,:,:]*x[:,:,None,None]
Out[10]: 
array([[[[ 3],
         [ 6]],

        [[ 5],
         [10]]],


       [[[ 4],
         [ 8]],

        [[ 6],
         [12]]]])
In [11]: _.shape
Out[11]: (2, 2, 2, 1)         # compare with octave (1,2,2,2)
In [12]: __.squeeze()
Out[12]: 
array([[[ 3,  6],
        [ 5, 10]],

       [[ 4,  8],
        [ 6, 12]]])

ravel与Octave相同:

代码语言:javascript
复制
In [13]: ___.ravel()
Out[13]: array([ 3,  6,  5, 10,  4,  8,  6, 12])

可以使用expand_dims来代替索引。在内部,它使用reshape

代码语言:javascript
复制
In [15]: np.expand_dims(X,(0,1)).shape
Out[15]: (1, 1, 2, 1)
In [16]: np.expand_dims(x,(2,3)).shape
Out[16]: (2, 2, 1, 1)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69047027

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档