Strassen算法的C语言实现
Strassen算法是一种高效的矩阵乘法算法,其时间复杂度约为O(n^{log2 7})≈O(n^2.807),在渐近意义上比传统的O(n^3)矩阵乘法更快。以下是Strassen算法的C语言实现:

(图片来源网络,侵删)
#include <stdio.h>
#include <stdlib.h>
/**
 * Allocate an n x n matrix of int.
 *
 * The matrix is an array of n row pointers, each pointing to its own
 * separately allocated row of n ints. All elements start at zero.
 *
 * @param n  Number of rows and columns.
 * @return   The new matrix, or NULL when n <= 0 or an allocation fails.
 *           On failure every row allocated so far is released, so nothing
 *           leaks. The caller owns the result and releases it by freeing
 *           each row and then the row-pointer array.
 */
int** allocate_matrix(int n) {
    if (n <= 0) {
        return NULL;
    }

    /* calloc checks the count * size product for overflow and leaves every
     * row pointer NULL, which keeps the failure path below simple. */
    int **matrix = calloc((size_t)n, sizeof *matrix);
    if (matrix == NULL) {
        return NULL;
    }

    for (int i = 0; i < n; i++) {
        matrix[i] = calloc((size_t)n, sizeof *matrix[i]);
        if (matrix[i] == NULL) {
            /* Roll back the rows that were already allocated. */
            for (int k = 0; k < i; k++) {
                free(matrix[k]);
            }
            free(matrix);
            return NULL;
        }
    }
    return matrix;
}
/**
 * Release a matrix made of n separately allocated rows.
 *
 * Frees each of the first n row pointers and then the row-pointer array.
 * Passing NULL is a no-op, mirroring free(NULL).
 *
 * @param matrix  Row-pointer array to release; may be NULL.
 * @param n       Number of rows stored in the array.
 */
void free_matrix(int** matrix, int n) {
    if (matrix == NULL) {
        return;
    }
    for (int i = 0; i < n; i++) {
        free(matrix[i]); /* free(NULL) is harmless for rows never allocated */
    }
    free(matrix);
}
/**
 * Copy the n x n matrix src into dest, element by element.
 *
 * @param src   Matrix to read from.
 * @param dest  Matrix to write to; must already have n rows of n ints.
 * @param n     Number of rows and columns.
 */
void copy_matrix(int** src, int** dest, int n) {
    for (int row = 0; row < n; row++) {
        const int *from = src[row];
        int *to = dest[row];
        int col = 0;
        while (col < n) {
            to[col] = from[col];
            col++;
        }
    }
}
/**
 * Element-wise sum of two n x n matrices: C = A + B.
 *
 * @param A  Left operand.
 * @param B  Right operand.
 * @param C  Destination; receives A[i][j] + B[i][j] at every position.
 * @param n  Number of rows and columns.
 */
void add_matrices(int** A, int** B, int** C, int n) {
    for (int row = 0; row < n; row++) {
        const int *lhs = A[row];
        const int *rhs = B[row];
        int *out = C[row];
        for (int col = 0; col < n; col++) {
            out[col] = lhs[col] + rhs[col];
        }
    }
}
/**
 * Element-wise difference of two n x n matrices: C = A - B.
 *
 * @param A  Minuend.
 * @param B  Subtrahend.
 * @param C  Destination; receives A[i][j] - B[i][j] at every position.
 * @param n  Number of rows and columns.
 */
void subtract_matrices(int** A, int** B, int** C, int n) {
    for (int row = 0; row < n; row++) {
        const int *lhs = A[row];
        const int *rhs = B[row];
        int *out = C[row];
        for (int col = 0; col < n; col++) {
            out[col] = lhs[col] - rhs[col];
        }
    }
}
/**
 * Classical triple-loop matrix product of two n x n matrices: C = A * B.
 * strassen() uses it as the base case for small matrices.
 *
 * Each C[i][j] is zeroed and then accumulated in place.
 *
 * @param A  Left operand.
 * @param B  Right operand.
 * @param C  Destination matrix.
 * @param n  Number of rows and columns.
 */
void multiply_matrices(int** A, int** B, int** C, int n) {
    for (int row = 0; row < n; row++) {
        const int *a_row = A[row];
        int *c_row = C[row];
        for (int col = 0; col < n; col++) {
            c_row[col] = 0;
            int k = 0;
            while (k < n) {
                c_row[col] += a_row[k] * B[k][col];
                k++;
            }
        }
    }
}
/**
 * Extract one (n/2) x (n/2) block of src into dest.
 *
 * The block starts at src[row][col]; dest is filled from its own [0][0].
 *
 * @param src   Matrix to read from.
 * @param dest  Receives the (n/2) x (n/2) block.
 * @param n     Dimension the block size is derived from; n/2 rows and
 *              columns are copied.
 * @param row   Row offset of the block inside src.
 * @param col   Column offset of the block inside src.
 */
void split_matrix(int** src, int** dest, int n, int row, int col) {
    const int half = n / 2;
    for (int r = 0; r < half; r++) {
        const int *from = src[r + row];
        int *to = dest[r];
        for (int c = 0; c < half; c++) {
            to[c] = from[c + col];
        }
    }
}
/**
 * Write one (n/2) x (n/2) block from src into dest.
 *
 * The block is read from src starting at [0][0] and placed in dest
 * starting at dest[row][col]. Note that n is the dimension the block size
 * is derived from: exactly n/2 rows and n/2 columns are copied.
 *
 * @param src   Block to read, at least (n/2) x (n/2).
 * @param dest  Matrix that receives the block.
 * @param n     Dimension the block size is derived from.
 * @param row   Row offset of the block inside dest.
 * @param col   Column offset of the block inside dest.
 */
void merge_matrices(int** src, int** dest, int n, int row, int col) {
    const int half = n / 2;
    for (int r = 0; r < half; r++) {
        const int *from = src[r];
        int *to = dest[r + row];
        for (int c = 0; c < half; c++) {
            to[c + col] = from[c];
        }
    }
}
/**
 * Multiply two n x n integer matrices with Strassen's algorithm: C = A * B.
 *
 * Matrices at or below the cutoff size, and matrices with an odd dimension,
 * are multiplied with multiply_matrices(). Larger even-sized matrices are
 * split into four quadrants, seven half-size products P1..P7 are computed
 * recursively, and the result quadrants are assembled from them.
 *
 * @param A  Left operand, n x n.
 * @param B  Right operand, n x n.
 * @param C  Destination, n x n; every element is overwritten.
 * @param n  Number of rows and columns.
 */
void strassen(int** A, int** B, int** C, int n) {
    /* At or below this size the classical triple loop is used. */
    enum { STRASSEN_CUTOFF = 64 };

    /* An odd n cannot be cut into four equal quadrants: n / 2 would silently
     * drop the last row and column. Use the classical product instead. */
    if (n <= STRASSEN_CUTOFF || n % 2 != 0) {
        multiply_matrices(A, B, C, n);
        return;
    }

    int new_size = n / 2;

    /* Quadrants of A and B, quadrants of the result, the seven products
     * and two scratch matrices for the intermediate sums/differences. */
    int** A11 = allocate_matrix(new_size);
    int** A12 = allocate_matrix(new_size);
    int** A21 = allocate_matrix(new_size);
    int** A22 = allocate_matrix(new_size);
    int** B11 = allocate_matrix(new_size);
    int** B12 = allocate_matrix(new_size);
    int** B21 = allocate_matrix(new_size);
    int** B22 = allocate_matrix(new_size);
    int** C11 = allocate_matrix(new_size);
    int** C12 = allocate_matrix(new_size);
    int** C21 = allocate_matrix(new_size);
    int** C22 = allocate_matrix(new_size);
    int** P1 = allocate_matrix(new_size);
    int** P2 = allocate_matrix(new_size);
    int** P3 = allocate_matrix(new_size);
    int** P4 = allocate_matrix(new_size);
    int** P5 = allocate_matrix(new_size);
    int** P6 = allocate_matrix(new_size);
    int** P7 = allocate_matrix(new_size);
    int** temp1 = allocate_matrix(new_size);
    int** temp2 = allocate_matrix(new_size);

    /* Split A and B; split_matrix copies an (n/2) x (n/2) block. */
    split_matrix(A, A11, n, 0, 0);
    split_matrix(A, A12, n, 0, new_size);
    split_matrix(A, A21, n, new_size, 0);
    split_matrix(A, A22, n, new_size, new_size);
    split_matrix(B, B11, n, 0, 0);
    split_matrix(B, B12, n, 0, new_size);
    split_matrix(B, B21, n, new_size, 0);
    split_matrix(B, B22, n, new_size, new_size);

    /* The seven recursive products. */
    subtract_matrices(B12, B22, temp1, new_size); /* B12 - B22 */
    strassen(A11, temp1, P1, new_size);           /* P1 = A11 * (B12 - B22) */
    add_matrices(A11, A12, temp1, new_size);      /* A11 + A12 */
    strassen(temp1, B22, P2, new_size);           /* P2 = (A11 + A12) * B22 */
    add_matrices(A21, A22, temp1, new_size);      /* A21 + A22 */
    strassen(temp1, B11, P3, new_size);           /* P3 = (A21 + A22) * B11 */
    subtract_matrices(B21, B11, temp1, new_size); /* B21 - B11 */
    strassen(A22, temp1, P4, new_size);           /* P4 = A22 * (B21 - B11) */
    add_matrices(A11, A22, temp1, new_size);      /* A11 + A22 */
    add_matrices(B11, B22, temp2, new_size);      /* B11 + B22 */
    strassen(temp1, temp2, P5, new_size);         /* P5 = (A11 + A22) * (B11 + B22) */
    subtract_matrices(A12, A22, temp1, new_size); /* A12 - A22 */
    add_matrices(B21, B22, temp2, new_size);      /* B21 + B22 */
    strassen(temp1, temp2, P6, new_size);         /* P6 = (A12 - A22) * (B21 + B22) */
    subtract_matrices(A11, A21, temp1, new_size); /* A11 - A21 */
    add_matrices(B11, B12, temp2, new_size);      /* B11 + B12 */
    strassen(temp1, temp2, P7, new_size);         /* P7 = (A11 - A21) * (B11 + B12) */

    /* Assemble the result quadrants. */
    add_matrices(P5, P4, temp1, new_size);         /* P5 + P4 */
    subtract_matrices(temp1, P2, temp2, new_size); /* (P5 + P4) - P2 */
    add_matrices(temp2, P6, C11, new_size);        /* C11 = P5 + P4 - P2 + P6 */
    add_matrices(P1, P2, C12, new_size);           /* C12 = P1 + P2 */
    add_matrices(P3, P4, C21, new_size);           /* C21 = P3 + P4 */
    add_matrices(P1, P5, temp1, new_size);         /* P1 + P5 */
    subtract_matrices(temp1, P3, temp2, new_size); /* (P1 + P5) - P3 */
    subtract_matrices(temp2, P7, C22, new_size);   /* C22 = P1 + P5 - P3 - P7 */

    /* Write the quadrants into C. merge_matrices copies an (n/2) x (n/2)
     * block, so it must be given the full size n; passing new_size would
     * copy only a quarter of each quadrant and leave the rest of C unset. */
    merge_matrices(C11, C, n, 0, 0);
    merge_matrices(C12, C, n, 0, new_size);
    merge_matrices(C21, C, n, new_size, 0);
    merge_matrices(C22, C, n, new_size, new_size);

    /* Release every temporary allocated above. */
    free_matrix(A11, new_size);
    free_matrix(A12, new_size);
    free_matrix(A21, new_size);
    free_matrix(A22, new_size);
    free_matrix(B11, new_size);
    free_matrix(B12, new_size);
    free_matrix(B21, new_size);
    free_matrix(B22, new_size);
    free_matrix(C11, new_size);
    free_matrix(C12, new_size);
    free_matrix(C21, new_size);
    free_matrix(C22, new_size);
    free_matrix(P1, new_size);
    free_matrix(P2, new_size);
    free_matrix(P3, new_size);
    free_matrix(P4, new_size);
    free_matrix(P5, new_size);
    free_matrix(P6, new_size);
    free_matrix(P7, new_size);
    free_matrix(temp1, new_size);
    free_matrix(temp2, new_size);
}
/**
 * Print an n x n matrix to standard output.
 *
 * Each element is written as a decimal integer followed by one space, and
 * each row ends with a newline.
 *
 * @param matrix  Matrix to print.
 * @param n       Number of rows and columns.
 */
void print_matrix(int** matrix, int n) {
    for (int row = 0; row < n; row++) {
        const int *cells = matrix[row];
        int col = 0;
        while (col < n) {
            printf("%d ", cells[col]);
            col++;
        }
        printf("\n");
    }
}
/**
 * Demo driver: builds two 8 x 8 matrices, multiplies them with strassen()
 * and prints both operands and the product.
 */
int main() {
    /* Matrix dimension; the notes accompanying this code ask for a power of
     * two. strassen() sends sizes up to 64 straight to multiply_matrices(),
     * so this demo runs the classical path. */
    const int size = 8;

    int** A = allocate_matrix(size);
    int** B = allocate_matrix(size);
    int** C = allocate_matrix(size);

    /* Fill the operands with a deterministic pattern (not random values):
     * A[r][c] = r + c and B[r][c] = r * c. */
    for (int r = 0; r < size; r++) {
        int *a_row = A[r];
        int *b_row = B[r];
        for (int c = 0; c < size; c++) {
            a_row[c] = r + c;
            b_row[c] = r * c;
        }
    }

    printf("Matrix A:\n");
    print_matrix(A, size);
    printf("\nMatrix B:\n");
    print_matrix(B, size);

    /* C = A * B */
    strassen(A, B, C, size);

    printf("\nResult Matrix C (A * B):\n");
    print_matrix(C, size);

    /* Release the three matrices. */
    free_matrix(A, size);
    free_matrix(B, size);
    free_matrix(C, size);
    return 0;
}
代码说明
- 矩阵操作函数:
  - allocate_matrix:分配矩阵内存
  - free_matrix:释放矩阵内存
  - copy_matrix:复制矩阵
  - add_matrices:矩阵加法
  - subtract_matrices:矩阵减法
  - multiply_matrices:普通矩阵乘法(用于小矩阵)
- Strassen算法核心:
  - split_matrix:将矩阵分成四个子矩阵
  - merge_matrices:将四个子矩阵合并成一个矩阵
  - strassen:Strassen算法的主函数,递归计算矩阵乘法
- 优化:
- 当矩阵大小小于等于64时,使用普通乘法(阈值可根据实际情况调整)
- 递归地将大矩阵分解为小矩阵,直到达到阈值
- 主函数:
(图片来源网络,侵删)
- 初始化两个矩阵A和B
- 调用Strassen算法计算矩阵乘法
- 打印结果
注意事项
- 输入矩阵的大小必须是2的幂(如2, 4, 8, 16, ...);如果不是,需要先用零填充到不小于原尺寸的最近的2的幂。
- 对于非常大的矩阵,每次递归调用都会在堆上临时分配21个子矩阵,内存开销和分配次数都比较大;递归深度只有O(log n),一般不会导致栈溢出。可以考虑复用缓冲区来减少内存分配。
- 实际应用中,可以根据硬件性能调整阈值(如代码中的64),以获得最佳性能。
- 这个实现使用int类型存储矩阵元素,乘加运算的结果可能超出int的范围;如果需要处理较大的数值,可以改为long long或其他类型。
(图片来源网络,侵删)
