Strassen矩阵乘法(C++)
时间:2022-11-03 01:00:00
思路
两个矩阵A,B相乘时.有三种方法
暴力计算法. 三个for循环, 时间复杂度为O(n^3).因为Cij=∑(k=1->n)Aik*Bkj,需要循环, 且C中有n^2个元素, 因此,时间复杂度为O(n^3)
分治法. 首先将A,B,C分为相等大小的方块矩阵.
所以C11=A11*B11 A12*B21, C12=A11*B12 A12*B22,
C21=A21*B11 A22*B21, C22=A21*B12 A22*B22
用T(n)表示n*n矩阵乘法, 所以有T(n)=8T(n/2) Θ(n^2). 其中, 8T(n/2)表示8次子矩阵乘法, 子矩阵的规模为n/2 * n/2. θ(n^2)表示4次矩阵加法的时间复杂性和合并C矩阵的时间复杂性.最后结果是Θ(n^3)与暴力计算时间的复杂性相同.
Strassen算法,可以优化时间复杂度O(n^log7).
现在重新定义7个新矩阵
M1=(A11 A22)*(B11 B22)
M2=(A21 A22)*B11
M3=A11*(B12-B22)
M4=A22*(B21-B11)
M5=(A11 A12)*B22
M6=(A21-A11)*(B11 B12)
M7=(A12-A22)*(B21 B22)
结果矩阵C可以组合上述矩阵,如下
C11=M1 M4-M5 M7
C12=M3 M5
C21=M2 M4
C22=M1-M2 M3 M6
此时共有7次乘法,18次加减法. 写递推公式T(n)=7T(n/2) Θ(n^2). 最终结果是O(n^log7)=O(n^2.807).
代码如下:
#include using namespace std; // 暴力求解矩阵相乘 void MUL(int** MatrixA,int** MatrixB,int** MatrixResult,int Msize){ for(int i=0;i> MSize;
// 定义三个矩阵
int** MatrixA;
int** MatrixB;
int** MatrixC;
// 初始化三个矩阵
MatrixA=new int*[MSize];
MatrixB=new int*[MSize];
MatrixC=new int*[MSize];
for(int i=0;i> MatrixA[i][j];
}
}
for(int i=0;i> MatrixB[i][j];
}
}
Strassen(MSize,MatrixA,MatrixB,MatrixC);
// 打印输出结果矩阵
for(int i=0;i