Bringing support for Nvidia and Apple GPUs to Luminal through compilers
Image Credit: https://exxactcorp.com/
The kernel takes two arguments, `a` and `b`, which are float arrays. We get the index of the current thread using `blockIdx`, `blockDim`, and `threadIdx`, and use that to index into the input and output arrays.
This kernel gets run for each element of the input arrays, all in parallel.
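To make that concrete, here is a minimal sketch of the kind of elementwise kernel described above. This is illustrative only, not Luminal's generated code; the kernel name and the specific operation (`expf`) are assumptions:

```cuda
// Hypothetical elementwise kernel: one thread handles one element.
__global__ void elementwise_exp(const float *a, float *b, int n) {
    // Global thread index from block and thread coordinates.
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {           // guard: the last block may be partially full
        b[i] = expf(a[i]); // read input, apply the op, write output
    }
}
```

A host would launch this with enough blocks to cover all `n` elements, e.g. `elementwise_exp<<<(n + 255) / 256, 256>>>(a, b, n);`.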
The Metal primitive ops are defined in `crates/luminal_metal/src/prim.rs`. The compiler simply loops through all ops in the graph and swaps each one out with its Metal variant.
These primitive operations are very simple. Here’s the MetalExp2 op, slightly simplified for clarity:
Say you have a tensor `a` and you want `b = a.cos().exp()`. Naively, we would allocate an intermediate buffer, launch a Cos kernel to compute `cos(a)`, write the output to the intermediate buffer, then launch an Exp kernel to compute `exp(intermediate)` and write the result to an output buffer.
Elementwise fusion does away with that and generates a single kernel that computes `out = exp(cos(a))`, so no intermediate reads and writes are needed, and only one kernel needs to be launched.
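As a sketch of what this means at the kernel level (illustrative CUDA, not Luminal's actual generated code; all names are made up), compare the naive two-kernel version with the fused one:

```cuda
// Naive: two launches plus an intermediate buffer in global memory.
__global__ void cos_kernel(const float *a, float *tmp, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) tmp[i] = cosf(a[i]);
}
__global__ void exp_kernel(const float *tmp, float *b, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) b[i] = expf(tmp[i]);
}

// Fused: one launch, no intermediate global-memory traffic.
__global__ void cos_exp_kernel(const float *a, float *b, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) b[i] = expf(cosf(a[i]));
}
```

The fused version halves the kernel launches and eliminates a full round trip through global memory, which is usually the bottleneck for elementwise work.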
This idea is actually taken much further: fusing unary operations, binary operations, and fusing across reshapes, permutes, expands, and more. It turns out we can get very far with this idea!
Here’s an example of how many ops fusion can merge together. On the left is the unfused graph, on the right is the functionally identical fused graph: