请查看以下代码:
#include<stdio.h>
#include<stdlib.h>
int **divide(int **Matrix,int n,int position)
{
int i,j;
int **Partition=malloc(sizeof(*Partition)*n);
for(i=0;i<n;i++)
{
Partition[i]=calloc(n,sizeof(*Partition[i]));
}
if(position==1)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i][j];
}
}
}
else if(position==2)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i][j+n/2];
}
}
}
else if(position==3)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i+n/2][j];
}
}
}
else if(position==4)
{
for(i=0;i<n/2;i++)
{
for(j=0;j<n/2;j++)
{
Partition[i][j]=Matrix[i+n/2][j+n/2];
}
}
}
return Partition;
}
int **allocate(int n)
{
int **newmatrix=malloc(sizeof(*newmatrix)*n);
for(int i=0;i<n;i++)
{
newmatrix[i]=calloc(n, sizeof(*newmatrix[i]));
}
return newmatrix;
}
void mfree(int **matrix,int n) {
for (int i=0;i<n;i++) {
free(matrix[i]);
}
free(matrix);
}
int **add(int **a,int **b,int n)
{
int **c=allocate(n);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
c[i][j]=a[i][j]+b[i][j];
}
}
return c;
}
int **subtract(int **a,int **b,int n)
{
int **c=allocate(n);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
c[i][j]=a[i][j]-b[i][j];
}
}
return c;
}
void print(int **Matrix,int n)
{
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
printf("%d ",Matrix[i][j]);
}
printf("\n");
}
}
int **Strassens(int **A,int **B,int n)
{
int **C=allocate(n);
if(n==1)
{
C[0][0]=A[0][0]*B[0][0];
}
else
{ //Allocate the submatrices
int **a11=allocate(n/2);
int **a12=allocate(n/2);
int **a21=allocate(n/2);
int **a22=allocate(n/2);
int **b11=allocate(n/2);
int **b12=allocate(n/2);
int **b21=allocate(n/2);
int **b22=allocate(n/2);
a11=divide(A,n,1);
a12=divide(A,n,2);
a21=divide(A,n,3);
a22=divide(A,n,4);
b11=divide(B,n,1);
b12=divide(B,n,2);
b21=divide(B,n,3);
b22=divide(B,n,4);
int **s1=subtract(b12,b22,n/2);
int **s2=add(a11,a12,n/2);
int **s3=add(a21,a22,n/2);
int **s4=subtract(b21,b11,n/2);
int **s5=add(a11,a22,n/2);
int **s6=add(b11,b22,n/2);
int **s7=subtract(a12,a22,n/2);
int **s8=add(b21,b22,n/2);
int **s9=subtract(a11,a21,n/2);
int **s10=add(b11,a12,n/2);
int **p1=Strassens(a11,s1,n/2);
int **p2=Strassens(s2,b22,n/2);
int **p3=Strassens(s3,b11,n/2);
int **p4=Strassens(a22,s4,n/2);
int **p5=Strassens(s5,s6,n/2);
int **p6=Strassens(s7,s8,n/2);
int **p7=Strassens(s9,s10,n/2);
int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
int **c12=add(p1,p2,n/2);
int **c21=add(p3,p4,n/2);
int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);
for(int i=0;i<n/2;i++)
{
for(int j=0;j<n/2;j++)
{
C[i][j]=c11[i][j];
C[i][j+n/2]=c12[i][j];
C[i+n/2][j]=c21[i][j];
C[i+n/2][j+n/2]=c22[i][j];
}
}
}
return C;
}
int main()
{
int n=8; //Dimension of the square matrix, n*n;
int **A=allocate(n);
int **B=allocate(n);
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
A[i][j]=j+1;
B[i][j]=j+1;
}
}
printf("Matrix A:\n");
print(A,n);
printf("Matrix B: \n");
print(B,n);
printf("\n...Performing Multiplication with Strassen's...\nMatrix A*B:\n");
int **C = Strassens(A,B,n);
print(C,n);
mfree(C,n);
}我知道这是个很愚蠢的问题,数学有问题。但我没办法搞清楚我哪里出了问题。问题是,当我用相等的值相乘两个矩阵时,我得到了想要的结果,但这不适用于不同值的矩阵。例如,查看输出:
Matrix A:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
...Performing Multiplication with Strassen's...
Matrix A*B:
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288和
Matrix A:
1 2 3 4 5 6 7 8
2 3 4 5 6 7 8 9
3 4 5 6 7 8 9 10
4 5 6 7 8 9 10 11
5 6 7 8 9 10 11 12
6 7 8 9 10 11 12 13
7 8 9 10 11 12 13 14
8 9 10 11 12 13 14 15
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
...Performing Multiplication with Strassen's...
Matrix A*B:
316 424 484 528 460 440 372 288
300 398 452 426 412 366 308 154
268 360 414 446 348 312 246 134
252 254 382 424 300 126 182 112
156 232 260 272 404 352 252 136
140 150 228 34 356 334 188 138
108 168 70 54 292 224 246 118
92 -122 38 24 244 222 182 104发布于 2021-06-24 13:36:46
-对不起。在这部分中有一个轻微的数学错误:
int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
int **c12=add(p1,p2,n/2);
int **c21=add(p3,p4,n/2);
int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);将c11和c22替换为
int **c11=subtract(add(add(p5,p4,n/2),p6,n/2),p2,n/2);
...
int **c22=subtract(subtract(add(p5,p1,n/2),p3,n/2),p7,n/2);纠正数学错误。
https://stackoverflow.com/questions/68116559
复制相似问题