Strassen’s Algorithm for Matrix multiplication

Strassen’s algorithm for matrix multiplication is a recursive algorithm for multiplying n x n matrices in O(n^(log2 7)), i.e. O(n^2.81), time. It asymptotically outperforms the naive O(n^3) matrix multiplication algorithm.

Naive Matrix-Multiplication (A, B) Pseudocode
n = Length[A]
let C be a new n x n matrix
for i = 1 to n
  do for j = 1 to n
    do C[i][j] = 0
      for k = 1 to n
        do C[i][j] = C[i][j] + A[i][k] * B[k][j]
return C

Code Implementation

//
//  main.cpp
//  Matrix multiplication
//
//  Created by Himanshu on 17/09/21.
//

#include <iostream>
using namespace std;
const int N = 2, M = 2;  // the code below assumes square matrices, so N must equal M

void multiplyMatrices (int A[N][M], int B[N][M]) {
    int C[N][M];
    
    for (int i=0; i<N; i++) {
        for (int j=0; j<M; j++) {
            C[i][j] = 0;
            for (int k=0; k<N; k++) {
                C[i][j] += A[i][k]*B[k][j];
            }
        }
    }
    
    cout<<"Multiplication of Matrices A and B:"<<endl;
    
    for (int i=0; i<N; i++) {
        for (int j=0; j<M; j++) {
            cout<<C[i][j]<<" ";
        }
        cout<<endl;
    }
    
    
}
 
int main() {
    int A[N][M] = {{1, 2}, {3, 4}};
    int B[N][M] = {{5, 6}, {7, 8}};
    
    multiplyMatrices(A, B);
}

Output:

Multiplication of Matrices A and B:
19 22 
43 50 


Strassen’s Algorithm for Matrix multiplication

Strassen’s algorithm is based on a familiar design technique: divide and conquer. Suppose we wish to compute the product C = AB, where A, B and C are each n x n matrices (n = 2 in the simplest case). Assuming n is an exact power of 2, we divide each of A, B and C into four n/2 x n/2 matrices and rewrite C = A x B as follows:

C = A x B

| r  s |   | a  b |   | e  f |
| t  u | = | c  d | x | g  h |

This equation corresponds to the following 4 equations:

r = ae + bg
s = af + bh
t = ce + dg
u = cf + dh

Each of these 4 equations specifies multiplication of n/2 x n/2 matrices and addition of their n/2 x n/2 products. Using divide-and-conquer strategy, we derive the following recurrence relation:

T(n) = 8T(n/2) + O(n^2)
T(n) = 2^3 * T(n/2) + O(n^2)
T(n) = O(n^3) (using Master theorem)

Here we perform 8 multiplications of n/2 x n/2 matrices (which reduce to 1 x 1 scalars when n = 2) and 4 additions of n/2 x n/2 matrices, which together cost O(n^2) scalar additions. Hence the recurrence above, whose solution T(n) = O(n^3) is no better than the naive method.
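The four block equations can be sketched directly as a recursive routine. This is only an illustrative sketch, not code from the original listing: the names Matrix, block, add and multiplyRecursive are made up for the example, and it assumes square matrices whose size is a power of two.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Copies the size x size block of m whose top-left corner is (row, col).
Matrix block(const Matrix& m, std::size_t row, std::size_t col, std::size_t size) {
    Matrix out(size, std::vector<long long>(size));
    for (std::size_t i = 0; i < size; i++)
        for (std::size_t j = 0; j < size; j++)
            out[i][j] = m[row + i][col + j];
    return out;
}

// Element-wise sum of two square matrices of equal size.
Matrix add(const Matrix& x, const Matrix& y) {
    std::size_t n = x.size();
    Matrix out(n, std::vector<long long>(n));
    for (std::size_t i = 0; i < n; i++)
        for (std::size_t j = 0; j < n; j++)
            out[i][j] = x[i][j] + y[i][j];
    return out;
}

// Hypothetical helper: divide-and-conquer product with 8 recursive
// half-size multiplications. Assumes n is a power of two.
Matrix multiplyRecursive(const Matrix& A, const Matrix& B) {
    std::size_t n = A.size();
    assert(n > 0 && (n & (n - 1)) == 0 && B.size() == n);
    if (n == 1)
        return Matrix(1, std::vector<long long>(1, A[0][0] * B[0][0]));

    std::size_t half = n / 2;
    Matrix a = block(A, 0, 0, half), b = block(A, 0, half, half);
    Matrix c = block(A, half, 0, half), d = block(A, half, half, half);
    Matrix e = block(B, 0, 0, half), f = block(B, 0, half, half);
    Matrix g = block(B, half, 0, half), h = block(B, half, half, half);

    // r = ae + bg, s = af + bh, t = ce + dg, u = cf + dh
    Matrix r = add(multiplyRecursive(a, e), multiplyRecursive(b, g));
    Matrix s = add(multiplyRecursive(a, f), multiplyRecursive(b, h));
    Matrix t = add(multiplyRecursive(c, e), multiplyRecursive(d, g));
    Matrix u = add(multiplyRecursive(c, f), multiplyRecursive(d, h));

    // Stitch the four blocks back into the n x n result.
    Matrix C(n, std::vector<long long>(n));
    for (std::size_t i = 0; i < half; i++)
        for (std::size_t j = 0; j < half; j++) {
            C[i][j] = r[i][j];
            C[i][j + half] = s[i][j];
            C[i + half][j] = t[i][j];
            C[i + half][j + half] = u[i][j];
        }
    return C;
}
```

Calling multiplyRecursive on the matrices {{1, 2}, {3, 4}} and {{5, 6}, {7, 8}} should give the same product as the naive loop. The eight recursive calls and four add calls per level are exactly what the recurrence T(n) = 8T(n/2) + O(n^2) counts.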

Strassen discovered a different recursive approach that requires only 7 recursive multiplications of n/2 x n/2 matrices and O(n^2) scalar additions and subtractions, resulting in:

T(n) = 7T(n/2) + O(n^2)
T(n) = 2^(lg 7) * T(n/2) + O(n^2)
T(n) = O(n^(lg 7))
T(n) = O(n^2.81)
(using Master theorem, where lg denotes the base-2 logarithm)

which is a significant improvement over the naive O(n^3) for large n.
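To get a feel for the gap between the two recurrences, one can count only the scalar multiplications each scheme performs when recursing down to 1 x 1 blocks. The small function below is a hypothetical helper written for this illustration (the name scalarMultiplications is not from the original), it assumes n is a power of two, and it deliberately ignores the O(n^2) additions at each level.

```cpp
#include <cassert>
#include <cstdint>

// Number of scalar multiplications when every level of recursion spawns
// `branches` products on half-size matrices, down to 1 x 1 blocks.
// Assumes n is a power of two.
std::uint64_t scalarMultiplications(std::uint64_t n, std::uint64_t branches) {
    assert(n > 0 && (n & (n - 1)) == 0);
    std::uint64_t count = 1;  // a 1 x 1 product is a single multiplication
    for (; n > 1; n /= 2)
        count *= branches;    // one factor per level of recursion
    return count;
}
```

For n = 1024 there are 10 levels of recursion, so 8 branches give 8^10 = 1,073,741,824 multiplications (which equals n^3), while 7 branches give 7^10 = 282,475,249, roughly 3.8 times fewer. The ratio keeps growing with n, which is where the asymptotic advantage comes from.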

Strassen’s Matrices
Each product Pi = Ai.Bi multiplies a sum or difference of blocks of A by a sum or difference of blocks of B:
P1 = A1.B1 = a.(f - h) = af - ah
P2 = A2.B2 = (a + b).h = ah + bh
P3 = A3.B3 = (c + d).e = ce + de
P4 = A4.B4 = d.(g - e) = dg - de
P5 = A5.B5 = (a + d).(e + h) = ae + ah + de + dh
P6 = A6.B6 = (b - d).(g + h) = bg + bh - dg - dh
P7 = A7.B7 = (a - c).(e + f) = ae + af - ce - cf

These 7 matrices (P1 to P7) can then be combined, using only additions and subtractions, to obtain r, s, t and u (the four blocks of C = A x B):

r = P5 + P4 - P2 + P6
  = ae + bg
s = P1 + P2
  = af + bh
t = P3 + P4
  = ce + dg
u = P5 + P1 - P3 - P7
  = cf + dh

Thus, we calculated C = A x B using only 7 matrix multiplications.
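Putting the seven products and the four combinations together gives the following sketch. It is a minimal illustration under stated assumptions rather than a tuned implementation: the names Matrix, block, combine and strassen are invented for this example, the matrices are assumed square with a power-of-two size, and the recursion goes all the way down to 1 x 1 blocks.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Returns x + sign * y for two square matrices of equal size
// (sign = +1 adds, sign = -1 subtracts).
Matrix combine(const Matrix& x, const Matrix& y, long long sign) {
    std::size_t n = x.size();
    Matrix out(n, std::vector<long long>(n));
    for (std::size_t i = 0; i < n; i++)
        for (std::size_t j = 0; j < n; j++)
            out[i][j] = x[i][j] + sign * y[i][j];
    return out;
}

// Copies the size x size block of m whose top-left corner is (row, col).
Matrix block(const Matrix& m, std::size_t row, std::size_t col, std::size_t size) {
    Matrix out(size, std::vector<long long>(size));
    for (std::size_t i = 0; i < size; i++)
        for (std::size_t j = 0; j < size; j++)
            out[i][j] = m[row + i][col + j];
    return out;
}

// Hypothetical helper: Strassen's scheme with 7 recursive products.
// Assumes n is a power of two.
Matrix strassen(const Matrix& A, const Matrix& B) {
    std::size_t n = A.size();
    assert(n > 0 && (n & (n - 1)) == 0 && B.size() == n);
    if (n == 1)
        return Matrix(1, std::vector<long long>(1, A[0][0] * B[0][0]));

    std::size_t half = n / 2;
    Matrix a = block(A, 0, 0, half), b = block(A, 0, half, half);
    Matrix c = block(A, half, 0, half), d = block(A, half, half, half);
    Matrix e = block(B, 0, 0, half), f = block(B, 0, half, half);
    Matrix g = block(B, half, 0, half), h = block(B, half, half, half);

    // The seven products P1..P7, in the order listed above.
    Matrix p1 = strassen(a, combine(f, h, -1));
    Matrix p2 = strassen(combine(a, b, +1), h);
    Matrix p3 = strassen(combine(c, d, +1), e);
    Matrix p4 = strassen(d, combine(g, e, -1));
    Matrix p5 = strassen(combine(a, d, +1), combine(e, h, +1));
    Matrix p6 = strassen(combine(b, d, -1), combine(g, h, +1));
    Matrix p7 = strassen(combine(a, c, -1), combine(e, f, +1));

    // r = P5 + P4 - P2 + P6, s = P1 + P2, t = P3 + P4, u = P5 + P1 - P3 - P7
    Matrix r = combine(combine(combine(p5, p4, +1), p2, -1), p6, +1);
    Matrix s = combine(p1, p2, +1);
    Matrix t = combine(p3, p4, +1);
    Matrix u = combine(combine(combine(p5, p1, +1), p3, -1), p7, -1);

    // Stitch the four blocks back into the n x n result.
    Matrix C(n, std::vector<long long>(n));
    for (std::size_t i = 0; i < half; i++)
        for (std::size_t j = 0; j < half; j++) {
            C[i][j] = r[i][j];
            C[i][j + half] = s[i][j];
            C[i + half][j] = t[i][j];
            C[i + half][j + half] = u[i][j];
        }
    return C;
}
```

For the 2 x 2 matrices used earlier, strassen({{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}) should produce {{19, 22}, {43, 50}}, matching the naive result. The sketch copies every sub-matrix for readability, so it allocates far more memory than necessary. Practical implementations usually stop recursing below some cutoff size and fall back to the naive loop.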

However, Strassen’s algorithm is often not the method of choice for practical applications, for 4 reasons:

  1. The constant factor hidden in the running time of Strassen’s algorithm is larger than the constant factor in the naive O(n^3) method.
  2. When the matrices are sparse, methods tailored specifically to sparse matrices are faster.
  3. Strassen’s algorithm is not quite as numerically stable as the naive method.
  4. The sub-matrices formed at each level of recursion consume extra space.

Reference:
Introduction to Algorithms – (CLRS book)
