Strassen’s Algorithm for Matrix multiplication

Strassen’s Algorithm for Matrix multiplication is a recursive algorithm for multiplying n x n matrices in O(n^{lg 7}) time, i.e. about O(n^{2.81}), since lg 7 = log_2 7 ≈ 2.807. Asymptotically, it outperforms the naive O(n^3) matrix multiplication algorithm.

Naive Matrix-Multiplication (A, B) Pseudocode

n = Length[A]
let C be a new n x n matrix
for i = 1 to n
    for j = 1 to n
        C[i][j] = 0
        for k = 1 to n
            C[i][j] = C[i][j] + A[i][k] * B[k][j]
return C

Code Implementation

//
//  main.cpp
//  Matrix multiplication
//
//  Created by Himanshu on 17/09/21.
//

#include <iostream>
using namespace std;
const int N = 2;  // A, B and C are all N x N matrices

// Computes C = A x B with the naive triple loop and prints C
void multiplyMatrices (int A[N][N], int B[N][N]) {
    int C[N][N];
    
    for (int i=0; i<N; i++) {
        for (int j=0; j<N; j++) {
            // C[i][j] is the dot product of row i of A and column j of B
            C[i][j] = 0;
            for (int k=0; k<N; k++) {
                C[i][j] += A[i][k]*B[k][j];
            }
        }
    }
    
    cout<<"Multiplication of Matrices A and B:"<<endl;
    
    for (int i=0; i<N; i++) {
        for (int j=0; j<N; j++) {
            cout<<C[i][j]<<" ";
        }
        cout<<endl;
    }
}
 
int main() {
    int A[N][N] = {{1, 2}, {3, 4}};
    int B[N][N] = {{5, 6}, {7, 8}};
    
    multiplyMatrices(A, B);
    return 0;
}

Output:

Multiplication of Matrices A and B:
19 22 
43 50 


Strassen’s Algorithm for Matrix multiplication

Strassen’s algorithm is based on a familiar design technique: divide and conquer. Suppose we wish to compute the product C = AB, where A, B and C are n x n matrices and n is a power of 2. We divide each of A, B and C into four n/2 x n/2 submatrices, so that each becomes a 2 x 2 block matrix (blocks a, b, c, d of A; e, f, g, h of B; r, s, t, u of C), and rewrite C = A x B as follows:

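C = A x B:
\begin{pmatrix} r & s \\ t & u \end{pmatrix} = \begin{pmatrix} a & b \\ c & d \end{pmatrix} \begin{pmatrix} e & f \\ g & h \end{pmatrix}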

This equation corresponds to the following 4 equations:

r = ae + bg
s = af + bh
t = ce + dg
u = cf + dh

Each of these 4 equations specifies two multiplications of n/2 x n/2 matrices and one addition of their n/2 x n/2 products. Using the divide-and-conquer strategy, we obtain the following recurrence relation:

T(n) = 8T(n/2) + O(n^2)

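Here we perform 8 recursive multiplications of n/2 x n/2 matrices (in the 2 x 2 case these are just 1 x 1 scalar products) and 4 additions of n/2 x n/2 matrices, which together cost O(n^2) time; hence the recurrence above. Since n^{lg 8} = n^3 grows faster than the O(n^2) term, the master theorem gives the solution T(n) = O(n^3).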

Strassen discovered a different recursive approach that requires only 7 recursive multiplications of n/2 x n/2 matrices and O(n^2) scalar additions and subtractions, resulting in:

T(n) = 7T(n/2) + O(n^{2})
T(n) = O(n^{lg7})
T(n) = O(n^{2.81})

which is a significant asymptotic improvement over the naive O(n^3) method for large n.
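The step from the recurrence to O(n^{lg 7}) is case 1 of the master theorem. A short worked derivation, written as a LaTeX sketch and using the same O(n^2) term as the recurrence above:

```latex
\begin{aligned}
T(n) &= 7\,T(n/2) + O(n^2), \qquad a = 7,\ b = 2 \\
n^{\log_b a} &= n^{\lg 7} \approx n^{2.807} \\
n^2 &= O\!\left(n^{\lg 7 - \epsilon}\right) \quad \text{with } \epsilon = \lg 7 - 2 \approx 0.807 > 0 \\
\Longrightarrow\ T(n) &= O\!\left(n^{\lg 7}\right) = O\!\left(n^{2.81}\right)
\end{aligned}
```

Applying the same steps with a = 8 gives n^{lg 8} = n^3, which is why the 8-multiplication recurrence yields O(n^3).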

Strassen’s Matrices
P1 = A1B1 = (a.(f - h)) = af - ah
P2 = A2B2 = ((a + b).h) = ah + bh
P3 = A3B3 = ((c + d).e) = ce + de
P4 = A4B4 = (d.(g - e)) = dg - de
P5 = A5B5 = ((a + d).(e + h)) = ae + ah + de + dh
P6 = A6B6 = ((b - d).(g + h)) = bg + bh - dg - dh
P7 = A7B7 = ((a - c).(e + f)) = ae + af - ce - cf

These 7 matrices (P1, P2, P3, P4, P5, P6, P7) can now be used to calculate r, s, t and u, the quadrants of C = A x B, as follows:

r = P5 + P4 - P2 + P6
  = ae + bg
s = P1 + P2
  = af + bh
t = P3 + P4
  = ce + dg
u = P5 + P1 - P3 - P7
  = cf + dh

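Thus, we computed C = A x B using only 7 multiplications of n/2 x n/2 matrices, plus 18 additions and subtractions of n/2 x n/2 matrices (10 to form the operands of P1 to P7 and 8 to combine them), which cost O(n^2) time.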

However, Strassen’s algorithm is often not the method of choice for practical applications, for 4 reasons:

  1. The constant factor hidden in the running time of Strassen’s algorithm is larger than the constant factor in the naive O(n^3) method.
  2. When the matrices are sparse, methods tailored for sparse matrices are faster.
  3. Strassen’s algorithm is not as numerically stable as the naive method.
  4. The submatrices formed at each level of recursion consume extra space.

Reference:
Introduction to Algorithms – (CLRS book)
