首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >C中的Strassen乘法

C中的Strassen乘法
EN

Stack Overflow用户
提问于 2021-06-24 13:18:37
回答 1查看 90关注 0票数 1

请查看以下代码:

代码语言:javascript
复制
#include<stdio.h>
#include<stdlib.h>

int **divide(int **Matrix,int n,int position)
{
    int i,j;
    int **Partition=malloc(sizeof(*Partition)*n);
    for(i=0;i<n;i++)
    {
        Partition[i]=calloc(n,sizeof(*Partition[i]));
    }
    if(position==1)
    {
        for(i=0;i<n/2;i++)
        {
            for(j=0;j<n/2;j++)
            {
                Partition[i][j]=Matrix[i][j];
            }
        }
    }
    else if(position==2)
    {
        for(i=0;i<n/2;i++)
        {
            for(j=0;j<n/2;j++)
            {
                Partition[i][j]=Matrix[i][j+n/2];
            }
        }
    }
    else if(position==3)
    {
        for(i=0;i<n/2;i++)
        {
            for(j=0;j<n/2;j++)
            {
                Partition[i][j]=Matrix[i+n/2][j];
            }
        }
    }
    else if(position==4)
    {
        for(i=0;i<n/2;i++)
        {
            for(j=0;j<n/2;j++)
            {
                Partition[i][j]=Matrix[i+n/2][j+n/2];
            }
        }
    }
    return Partition;
}


int **allocate(int n)
{
    int **newmatrix=malloc(sizeof(*newmatrix)*n);
    for(int i=0;i<n;i++)
    {
        newmatrix[i]=calloc(n, sizeof(*newmatrix[i]));
    }
    return newmatrix;
}
void mfree(int **matrix,int n) {
    for (int i=0;i<n;i++) {
        free(matrix[i]);
    }
    free(matrix);
}
int **add(int **a,int **b,int n)
{
    int **c=allocate(n);
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            c[i][j]=a[i][j]+b[i][j];
        }
    }
    return c;
}
int **subtract(int **a,int **b,int n)
{
    int **c=allocate(n);
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            c[i][j]=a[i][j]-b[i][j];
        }
    }
    return c;
}
void print(int **Matrix,int n)
{
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            printf("%d ",Matrix[i][j]);
        }
        printf("\n");
    }
}

int **Strassens(int **A,int **B,int n)
{
    int **C=allocate(n);
    if(n==1)
    {
        C[0][0]=A[0][0]*B[0][0];
    }
    else
    {   //Allocate the submatrices
        int **a11=allocate(n/2);
        int **a12=allocate(n/2);
        int **a21=allocate(n/2);
        int **a22=allocate(n/2);

        int **b11=allocate(n/2);
        int **b12=allocate(n/2);
        int **b21=allocate(n/2);
        int **b22=allocate(n/2);


        a11=divide(A,n,1);
        a12=divide(A,n,2);
        a21=divide(A,n,3);
        a22=divide(A,n,4);

        b11=divide(B,n,1);
        b12=divide(B,n,2);
        b21=divide(B,n,3);
        b22=divide(B,n,4);

        
        int **s1=subtract(b12,b22,n/2);
        int **s2=add(a11,a12,n/2);
        int **s3=add(a21,a22,n/2);
        int **s4=subtract(b21,b11,n/2);
        int **s5=add(a11,a22,n/2);
        int **s6=add(b11,b22,n/2);
        int **s7=subtract(a12,a22,n/2);
        int **s8=add(b21,b22,n/2);
        int **s9=subtract(a11,a21,n/2);
        int **s10=add(b11,a12,n/2);

        int **p1=Strassens(a11,s1,n/2);
        int **p2=Strassens(s2,b22,n/2);
        int **p3=Strassens(s3,b11,n/2);
        int **p4=Strassens(a22,s4,n/2);
        int **p5=Strassens(s5,s6,n/2);
        int **p6=Strassens(s7,s8,n/2);
        int **p7=Strassens(s9,s10,n/2);


        int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
        int **c12=add(p1,p2,n/2);
        int **c21=add(p3,p4,n/2);
        int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);

        for(int i=0;i<n/2;i++)
        {
            for(int j=0;j<n/2;j++)
            {
                C[i][j]=c11[i][j];
                C[i][j+n/2]=c12[i][j];
                C[i+n/2][j]=c21[i][j];
                C[i+n/2][j+n/2]=c22[i][j];
            }
        }
    }
    return C;
}

int main()
{
    int n=8;  //Dimension of the square matrix,  n*n;
    int **A=allocate(n);
    int **B=allocate(n);
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            A[i][j]=j+1;
            B[i][j]=j+1;
        }
    }
    printf("Matrix A:\n");
    print(A,n);
    printf("Matrix B: \n");
    print(B,n);
    printf("\n...Performing Multiplication with Strassen's...\nMatrix A*B:\n");
    int **C = Strassens(A,B,n);
    print(C,n);
    mfree(C,n);
}

我知道这是个很愚蠢的问题,数学有问题。但我没办法搞清楚我哪里出了问题。问题是,当我用相等的值相乘两个矩阵时,我得到了想要的结果,但这不适用于不同值的矩阵。例如,查看输出:

代码语言:javascript
复制
Matrix A:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8

...Performing Multiplication with Strassen's...
Matrix A*B:
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288

代码语言:javascript
复制
Matrix A:
1 2 3 4 5 6 7 8
2 3 4 5 6 7 8 9
3 4 5 6 7 8 9 10
4 5 6 7 8 9 10 11
5 6 7 8 9 10 11 12
6 7 8 9 10 11 12 13
7 8 9 10 11 12 13 14
8 9 10 11 12 13 14 15
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8

...Performing Multiplication with Strassen's...
Matrix A*B:
316 424 484 528 460 440 372 288
300 398 452 426 412 366 308 154
268 360 414 446 348 312 246 134
252 254 382 424 300 126 182 112
156 232 260 272 404 352 252 136
140 150 228 34 356 334 188 138
108 168 70 54 292 224 246 118
92 -122 38 24 244 222 182 104
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-06-24 13:36:46

-对不起。在这部分中有一个轻微的数学错误:

代码语言:javascript
复制
int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
int **c12=add(p1,p2,n/2);
int **c21=add(p3,p4,n/2);
int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);

将c11和c22替换为

代码语言:javascript
复制
int **c11=subtract(add(add(p5,p4,n/2),p6,n/2),p2,n/2);
...
int **c22=subtract(subtract(add(p5,p1,n/2),p3,n/2),p7,n/2);

纠正数学错误。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68116559

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档