MLSYS ENGINEERING

2.4. Matmul (a naive implementation)

Now that we know how matmul is defined mathematically, how can we implement it in Python?

The most straightforward way to implement matmul is with three nested for loops, as shown in the following code snippet. For A of shape m × k and B of shape k × n, the loops iterate over the m rows of A, the n columns of B, and the k pairs of elements whose products are summed to form each inner product.

Code 3. Naive matmul implementation.
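import numpy as np

def matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # A is (m, k) and B is (k, n); the inner dimensions must match.
    assert a.ndim == 2 and b.ndim == 2, "Expected 2-D inputs."
    assert a.shape[1] == b.shape[0], "Incompatible shapes."
    m, k = a.shape
    _, n = b.shape
    # Use the inputs' promoted dtype (np.zeros would otherwise default to float64).
    output = np.zeros((m, n), dtype=np.result_type(a, b))

    for i in range(m):          # each row of A
        for j in range(n):      # each column of B
            for p in range(k):  # each pair (a[i, p], b[p, j]) of the inner product
                output[i, j] += a[i, p] * b[p, j]

    return output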
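To sanity-check the loops, one can compare the result against NumPy's built-in `@` operator and count how many times the innermost statement runs. The sketch below is illustrative only: `naive_matmul` is a hypothetical variant of `matmul` that also returns a multiply-add counter, and the random 4-by-3 and 3-by-5 inputs are arbitrary choices.

```python
import numpy as np

def naive_matmul(a: np.ndarray, b: np.ndarray) -> tuple[np.ndarray, int]:
    """Hypothetical helper: triple-loop matmul that also counts multiply-adds."""
    assert a.ndim == 2 and b.ndim == 2 and a.shape[1] == b.shape[0]
    m, k = a.shape
    _, n = b.shape
    output = np.zeros((m, n), dtype=np.result_type(a, b))
    mac_count = 0  # number of scalar multiply-adds performed
    for i in range(m):
        for j in range(n):
            for p in range(k):
                output[i, j] += a[i, p] * b[p, j]
                mac_count += 1
    return output, mac_count

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 3))  # m = 4, k = 3
b = rng.standard_normal((3, 5))  # k = 3, n = 5

c, macs = naive_matmul(a, b)
print(np.allclose(c, a @ b))  # True
print(macs)                   # 60, i.e. 4 * 5 * 3
```

Because the innermost statement executes once per (i, j, p) triple, the naive loop performs exactly m·n·k multiply-adds, or 2mnk floating-point operations.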

In this chapter, we have covered the basic concepts of tensors, ops, and matmul. Starting with the next chapter, we will progressively optimize the performance of the matmul op, from the software level down to the hardware, while covering the foundational knowledge of MLSys.