3.1. Multi-threading
If we have m pieces of work to do, each of which takes one minute to run, it would take m minutes to complete all of them if we run them sequentially. If we run them on m parallel threads (with at least m cores available to run them), ideally it would take only one minute to complete all the work. This is the key idea behind using multi-threading to optimize matmul.
A thread is much like any program you know: it executes code in sequential order. However, many threads can run at the same time on different cores of a CPU or GPU, which can significantly reduce the total execution time compared to a sequential run.
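The timing argument above can be sketched with Python's standard `concurrent.futures` module. This is a minimal illustration under stated assumptions: the hypothetical `work` function sleeps for 0.2 s to stand in for "one minute" of work, and m is 4. Because `time.sleep` releases CPython's global interpreter lock (GIL), the sleeping threads genuinely overlap, so the parallel run takes roughly one task's duration instead of m. CPU-bound pure-Python code would not speed up this way under the GIL.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def work(task_id: int, duration: float = 0.2) -> int:
    # Hypothetical stand-in for "one minute" of work; sleeping blocks the
    # thread without using the CPU and releases the GIL.
    time.sleep(duration)
    return task_id


def run_sequential(m: int) -> float:
    start = time.perf_counter()
    for t in range(m):
        work(t)  # each task waits for the previous one to finish
    return time.perf_counter() - start


def run_parallel(m: int) -> float:
    start = time.perf_counter()
    # One thread per task, so all m tasks can be in progress at once.
    with ThreadPoolExecutor(max_workers=m) as pool:
        list(pool.map(work, range(m)))
    return time.perf_counter() - start


m = 4
seq = run_sequential(m)
par = run_parallel(m)
print(f"sequential: {seq:.2f}s, parallel: {par:.2f}s")
```

The sequential run should take about m * 0.2 s, while the threaded run should take close to 0.2 s plus a small thread-management overhead.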
The naive matmul implementation in Code 3 (introduced in Matmul (a naive implementation)) uses three nested loops. We repeat the code snippet here for convenience.
```python
import numpy as np

def matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    assert a.shape[1] == b.shape[0], "Incompatible shapes."
    m, k = a.shape
    _, n = b.shape
    output = np.zeros((m, n))
    for i in range(m):
        for j in range(n):
            for l in range(k):
                output[i, j] += a[i, l] * b[l, j]
    return output
```
In performance engineering, loops are a natural target for parallelization. The key idea is to divide the workload into pieces and run those pieces on different cores in parallel, which can be much faster than running all of them sequentially on a single core.
In total, the += operation runs m * n * k times. A straightforward approach is to parallelize the outermost loop (i): we can use m threads, each of which computes one row of the output with n * k operations. Ideally, this would reduce the running time by a factor of m.
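Parallelizing the outermost loop can be sketched with Python's standard `threading` module. The function name `matmul_row_threads` is hypothetical, and this is an illustration of the work decomposition only: in CPython the inner loops are pure Python, so the GIL serializes them and no real speedup should be expected. Each thread owns one output row, so no two threads ever write the same element.

```python
import threading

import numpy as np


def matmul_row_threads(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Hypothetical sketch: parallelize loop i with one thread per output row.
    # Note: under CPython's GIL these pure-Python loops do not truly run in parallel.
    assert a.shape[1] == b.shape[0], "Incompatible shapes."
    m, k = a.shape
    _, n = b.shape
    output = np.zeros((m, n))

    def compute_row(i: int) -> None:
        # n * k += operations, all writing only to row i.
        for j in range(n):
            for l in range(k):
                output[i, j] += a[i, l] * b[l, j]

    threads = [threading.Thread(target=compute_row, args=(i,)) for i in range(m)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()  # wait until every row has been computed
    return output


a = np.arange(6, dtype=float).reshape(2, 3)
b = np.arange(12, dtype=float).reshape(3, 4)
assert np.allclose(matmul_row_threads(a, b), a @ b)
```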
We are very flexible in the number of threads to use. For example, we can use 3 threads, each of which takes care of 1/3 of the rows. At the other extreme, we can use m * n * k threads in total, where each thread handles a single += operation.
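Making the thread count a parameter amounts to splitting the m rows into contiguous blocks, one per thread. In the sketch below, `split_rows` and `matmul_threaded` are hypothetical names, and the same CPython GIL caveat applies: this shows how the rows are divided, not a real speedup. The m * n * k extreme is not shown, because there several threads would update the same output[i, j].

```python
import threading

import numpy as np


def split_rows(m: int, num_threads: int) -> list[tuple[int, int]]:
    # Divide rows 0..m-1 into num_threads contiguous [start, stop) ranges whose
    # sizes differ by at most one; if num_threads > m, some ranges are empty.
    base, extra = divmod(m, num_threads)
    ranges, start = [], 0
    for t in range(num_threads):
        stop = start + base + (1 if t < extra else 0)
        ranges.append((start, stop))
        start = stop
    return ranges


def matmul_threaded(a: np.ndarray, b: np.ndarray, num_threads: int = 3) -> np.ndarray:
    assert a.shape[1] == b.shape[0], "Incompatible shapes."
    m, k = a.shape
    _, n = b.shape
    output = np.zeros((m, n))

    def compute_rows(start: int, stop: int) -> None:
        # Each thread owns a disjoint block of output rows.
        for i in range(start, stop):
            for j in range(n):
                for l in range(k):
                    output[i, j] += a[i, l] * b[l, j]

    threads = [threading.Thread(target=compute_rows, args=r)
               for r in split_rows(m, num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return output


print(split_rows(10, 3))  # [(0, 4), (4, 7), (7, 10)]
```

Keeping each thread's rows contiguous and disjoint means threads never share an output element, so no synchronization is needed in this decomposition.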
How do we choose the right number of threads? Finding the right balance is crucial:
- Too few threads: If the CPU has 8 cores but only 3 threads run in parallel, the remaining 5 cores sit idle. In such cases, we fail to fully utilize the hardware.
- Too many threads: If you only need to do a small matmul but use 400 threads to parallelize it, the system may spend most of its time creating, managing, and terminating threads, while the actual computation takes only a small fraction of the total time. This high overhead can even make end-to-end execution slower than with fewer threads.
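Both failure modes can be folded into a simple starting point for the thread count. The rule below is one possible heuristic and is not from the original text: use as many threads as cores, but never so many that each thread gets less than a minimum amount of work, and never more than there are rows to split. The function `choose_num_threads` and the `min_ops_per_thread` threshold are hypothetical placeholders; the real value has to be found by measuring on the target hardware.

```python
import os


def choose_num_threads(m: int, n: int, k: int, min_ops_per_thread: int = 1_000_000) -> int:
    # Hypothetical heuristic; min_ops_per_thread is an assumed threshold to tune.
    # Cap 1: use up to one thread per core so cores are not left idle,
    # but no more than the cores can run at once.
    cores = os.cpu_count() or 1
    # Cap 2: give each thread at least min_ops_per_thread += operations so that
    # thread creation and management overhead stays small relative to real work.
    by_work = (m * n * k) // min_ops_per_thread
    # Cap 3: when splitting by rows, there are only m rows to hand out.
    return max(1, min(cores, by_work, m))


print(choose_num_threads(4, 4, 4))           # small matmul: falls back to 1 thread
print(choose_num_threads(4096, 4096, 4096))  # large matmul: limited by core count
```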
The number of threads is effectively a "performance hyperparameter" that must be tuned to achieve the best results for a given hardware architecture. However, as we will see in the next section, using multiple threads also introduces new risks that must be handled carefully.