Kernel Composition
Author: Hongzheng Chen (hzchen@cs.cornell.edu)
This document will discuss kernel composition. In the previous tutorials, we have seen how to write a simple kernel. However, in real applications, we often need to compose multiple kernels together.
In the following example, we define a matrix_add kernel and a gemm kernel, and wrap them into a top-level function named top.
import allo
from allo.ir.types import int32, float32
M, K, N = 32, 32, 32
def matrix_add(A: int32[M, N]) -> int32[M, N]:
    B: int32[M, N] = 0
    for i, j in allo.grid(M, N):
        B[i, j] = A[i, j] + 1
    return B


def gemm(A: int32[M, K], B: int32[K, N]) -> int32[M, N]:
    C: int32[M, N] = 0
    for i, j in allo.grid(M, N):
        for k in allo.reduction(K):
            C[i, j] += A[i, k] * B[k, j]
    return C


def top(A: int32[M, K], B: int32[K, N]) -> int32[M, N]:
    C = gemm(A, B)
    D = matrix_add(C)
    return D
Different teams or people can then work on different parts of the code and optimize each kernel.
We first create a schedule for the matrix_add kernel and add several optimizations.
s1 = allo.customize(matrix_add)
s1.pipeline("j")
print(s1.module)
module {
func.func @matrix_add(%arg0: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "s", otypes = "s"} {
%alloc = memref.alloc() {name = "B"} : memref<32x32xi32>
%c0_i32 = arith.constant 0 : i32
linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
affine.for %arg1 = 0 to 32 {
affine.for %arg2 = 0 to 32 {
%0 = affine.load %arg0[%arg1, %arg2] {from = "A"} : memref<32x32xi32>
%1 = arith.extsi %0 : i32 to i33
%c1_i32 = arith.constant 1 : i32
%2 = arith.extsi %c1_i32 : i32 to i33
%3 = arith.addi %1, %2 : i33
%4 = arith.trunci %3 : i33 to i32
affine.store %4, %alloc[%arg1, %arg2] {to = "B"} : memref<32x32xi32>
} {loop_name = "j", pipeline_ii = 1 : ui32}
} {loop_name = "i", op_name = "S_i_j_0"}
return %alloc : memref<32x32xi32>
}
}
Then we create a schedule for the gemm kernel and optimize it.
s2 = allo.customize(gemm)
s2.reorder("k", "j")
s2.buffer_at(s2.C, axis="i")
s2.pipeline("j")
print(s2.module)
module {
func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
%alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
%c0_i32 = arith.constant 0 : i32
linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
affine.for %arg2 = 0 to 32 {
%alloc_0 = memref.alloc() : memref<32xi32>
affine.for %arg3 = 0 to 32 {
affine.store %c0_i32, %alloc_0[%arg3] : memref<32xi32>
} {buffer, loop_name = "j_init", pipeline_ii = 1 : i32}
affine.for %arg3 = 0 to 32 {
affine.for %arg4 = 0 to 32 {
%0 = affine.load %arg0[%arg2, %arg3] {from = "A"} : memref<32x32xi32>
%1 = affine.load %arg1[%arg3, %arg4] {from = "B"} : memref<32x32xi32>
%2 = arith.extsi %0 : i32 to i64
%3 = arith.extsi %1 : i32 to i64
%4 = arith.muli %2, %3 : i64
%5 = affine.load %alloc_0[%arg4] : memref<32xi32>
%6 = arith.trunci %4 : i64 to i32
%7 = arith.addi %5, %6 : i32
affine.store %7, %alloc_0[%arg4] : memref<32xi32>
} {loop_name = "j", pipeline_ii = 1 : ui32}
} {loop_name = "k", op_name = "S_k_0", reduction}
affine.for %arg3 = 0 to 32 {
%0 = affine.load %alloc_0[%arg3] : memref<32xi32>
affine.store %0, %alloc[%arg2, %arg3] : memref<32x32xi32>
} {buffer, loop_name = "j_back", pipeline_ii = 1 : i32}
} {loop_name = "i", op_name = "S_i_j_0"}
return %alloc : memref<32x32xi32>
}
}
Notice that so far we have only optimized the individual kernels; the optimizations are not yet reflected in the top-level function, as shown in the printed module below.
s = allo.customize(top)
print(s.module)
module {
func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
%alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
%c0_i32 = arith.constant 0 : i32
linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
affine.for %arg2 = 0 to 32 {
affine.for %arg3 = 0 to 32 {
affine.for %arg4 = 0 to 32 {
%0 = affine.load %arg0[%arg2, %arg4] {from = "A"} : memref<32x32xi32>
%1 = affine.load %arg1[%arg4, %arg3] {from = "B"} : memref<32x32xi32>
%2 = arith.extsi %0 : i32 to i64
%3 = arith.extsi %1 : i32 to i64
%4 = arith.muli %2, %3 : i64
%5 = affine.load %alloc[%arg2, %arg3] {from = "C"} : memref<32x32xi32>
%6 = arith.trunci %4 : i64 to i32
%7 = arith.addi %5, %6 : i32
affine.store %7, %alloc[%arg2, %arg3] {to = "C"} : memref<32x32xi32>
} {loop_name = "k", op_name = "S_k_0", reduction}
} {loop_name = "j"}
} {loop_name = "i", op_name = "S_i_j_0"}
return %alloc : memref<32x32xi32>
}
func.func @matrix_add(%arg0: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "s", otypes = "s"} {
%alloc = memref.alloc() {name = "B"} : memref<32x32xi32>
%c0_i32 = arith.constant 0 : i32
linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
affine.for %arg1 = 0 to 32 {
affine.for %arg2 = 0 to 32 {
%0 = affine.load %arg0[%arg1, %arg2] {from = "A"} : memref<32x32xi32>
%1 = arith.extsi %0 : i32 to i33
%c1_i32 = arith.constant 1 : i32
%2 = arith.extsi %c1_i32 : i32 to i33
%3 = arith.addi %1, %2 : i33
%4 = arith.trunci %3 : i33 to i32
affine.store %4, %alloc[%arg1, %arg2] {to = "B"} : memref<32x32xi32>
} {loop_name = "j"}
} {loop_name = "i", op_name = "S_i_j_0"}
return %alloc : memref<32x32xi32>
}
func.func @top(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
%0 = call @gemm(%arg0, %arg1) {name = "C"} : (memref<32x32xi32>, memref<32x32xi32>) -> memref<32x32xi32>
%1 = call @matrix_add(%0) {name = "D"} : (memref<32x32xi32>) -> memref<32x32xi32>
return %1 : memref<32x32xi32>
}
}
Therefore, after each kernel has been optimized, we need to explicitly compose the schedules together. In Allo, we can use the .compose() primitive to incorporate the sub-schedules into the parent function's schedule.
s.compose([s1, s2])
print(s.module)
module {
func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
%alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
%c0_i32 = arith.constant 0 : i32
linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
affine.for %arg2 = 0 to 32 {
%alloc_0 = memref.alloc() : memref<32xi32>
affine.for %arg3 = 0 to 32 {
affine.store %c0_i32, %alloc_0[%arg3] : memref<32xi32>
} {buffer, loop_name = "j_init", pipeline_ii = 1 : i32}
affine.for %arg3 = 0 to 32 {
affine.for %arg4 = 0 to 32 {
%0 = affine.load %arg0[%arg2, %arg3] {from = "A"} : memref<32x32xi32>
%1 = affine.load %arg1[%arg3, %arg4] {from = "B"} : memref<32x32xi32>
%2 = arith.extsi %0 : i32 to i64
%3 = arith.extsi %1 : i32 to i64
%4 = arith.muli %2, %3 : i64
%5 = affine.load %alloc_0[%arg4] : memref<32xi32>
%6 = arith.trunci %4 : i64 to i32
%7 = arith.addi %5, %6 : i32
affine.store %7, %alloc_0[%arg4] : memref<32xi32>
} {loop_name = "j", pipeline_ii = 1 : ui32}
} {loop_name = "k", op_name = "S_k_0", reduction}
affine.for %arg3 = 0 to 32 {
%0 = affine.load %alloc_0[%arg3] : memref<32xi32>
affine.store %0, %alloc[%arg2, %arg3] : memref<32x32xi32>
} {buffer, loop_name = "j_back", pipeline_ii = 1 : i32}
} {loop_name = "i", op_name = "S_i_j_0"}
return %alloc : memref<32x32xi32>
}
func.func @matrix_add(%arg0: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "s", otypes = "s"} {
%alloc = memref.alloc() {name = "B"} : memref<32x32xi32>
%c0_i32 = arith.constant 0 : i32
linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
affine.for %arg1 = 0 to 32 {
affine.for %arg2 = 0 to 32 {
%0 = affine.load %arg0[%arg1, %arg2] {from = "A"} : memref<32x32xi32>
%1 = arith.extsi %0 : i32 to i33
%c1_i32 = arith.constant 1 : i32
%2 = arith.extsi %c1_i32 : i32 to i33
%3 = arith.addi %1, %2 : i33
%4 = arith.trunci %3 : i33 to i32
affine.store %4, %alloc[%arg1, %arg2] {to = "B"} : memref<32x32xi32>
} {loop_name = "j", pipeline_ii = 1 : ui32}
} {loop_name = "i", op_name = "S_i_j_0"}
return %alloc : memref<32x32xi32>
}
func.func @top(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
%0 = call @gemm(%arg0, %arg1) {name = "C"} : (memref<32x32xi32>, memref<32x32xi32>) -> memref<32x32xi32>
%1 = call @matrix_add(%0) {name = "D"} : (memref<32x32xi32>) -> memref<32x32xi32>
return %1 : memref<32x32xi32>
}
}
We can see that the optimizations from the matrix_add and gemm schedules are both correctly applied to the corresponding kernels in the top-level module.
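To check that the composed design still computes the expected result, we could build the schedule and compare it against a NumPy reference. The sketch below is a minimal example, assuming the default LLVM JIT backend returned by s.build():

import numpy as np

# Build the composed top-level schedule for CPU simulation
# (assumption: the default LLVM backend).
mod = s.build()

# Random int32 inputs matching the 32x32 shapes defined above.
np_A = np.random.randint(0, 10, size=(M, K)).astype(np.int32)
np_B = np.random.randint(0, 10, size=(K, N)).astype(np.int32)

# gemm followed by matrix_add computes D = (A @ B) + 1.
np_D = mod(np_A, np_B)
np.testing.assert_allclose(np_D, np_A @ np_B + 1, rtol=1e-5)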
Template Composition
Sometimes we may define a template kernel and instantiate it with different template arguments. Allo provides an id option to specify the exact kernel instance to be composed.
def kernel[T_in, T_out, S](A: "T_in[S]") -> "T_out[S]":
    B: T_out[S] = 0
    for i in range(S):
        with allo.meta_if(T_out == int32):
            B[i] = A[i] + 1
        with allo.meta_else():
            B[i] = A[i] * 2
    return B


def top2(A: int32[M]) -> float32[M]:
    C = kernel[int32, int32, M, "K1"](A)
    D = kernel[int32, float32, M, "K2"](C)
    return D
Specifically, the last template argument is the ID of the kernel instance; later on, we can use this ID to distinguish the different instantiations during composition. We first customize the two template kernels with different optimizations.
s1 = allo.customize(kernel, instantiate=[int32, int32, M])
s1.unroll("i", factor=4)
print(s1.module)
s2 = allo.customize(kernel, instantiate=[int32, float32, M])
s2.pipeline("i")
print(s2.module)
module {
func.func @kernel(%arg0: memref<32xi32>) -> memref<32xi32> attributes {itypes = "s", otypes = "s"} {
%alloc = memref.alloc() {name = "B"} : memref<32xi32>
%c0_i32 = arith.constant 0 : i32
linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32xi32>)
affine.for %arg1 = 0 to 32 {
%0 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
%1 = arith.extsi %0 : i32 to i33
%c1_i32 = arith.constant 1 : i32
%2 = arith.extsi %c1_i32 : i32 to i33
%3 = arith.addi %1, %2 : i33
%4 = arith.trunci %3 : i33 to i32
affine.store %4, %alloc[%arg1] {to = "B"} : memref<32xi32>
} {loop_name = "i", op_name = "S_i_0", unroll = 4 : i32}
return %alloc : memref<32xi32>
}
}
module {
func.func @kernel(%arg0: memref<32xi32>) -> memref<32xf32> attributes {itypes = "s", otypes = "_"} {
%c0_i32 = arith.constant 0 : i32
%0 = arith.sitofp %c0_i32 : i32 to f32
%alloc = memref.alloc() {name = "B"} : memref<32xf32>
linalg.fill ins(%0 : f32) outs(%alloc : memref<32xf32>)
affine.for %arg1 = 0 to 32 {
%1 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
%2 = arith.extsi %1 : i32 to i64
%c2_i32 = arith.constant 2 : i32
%3 = arith.extsi %c2_i32 : i32 to i64
%4 = arith.muli %2, %3 : i64
%5 = arith.sitofp %4 : i64 to f32
affine.store %5, %alloc[%arg1] {to = "B"} : memref<32xf32>
} {loop_name = "i", op_name = "S_i_0", pipeline_ii = 1 : ui32}
return %alloc : memref<32xf32>
}
}
Finally, we compose the two template kernel instances into the top-level function, specifying their IDs.
s = allo.customize(top2)
s.compose(s1, id="K1")
s.compose(s2, id="K2")
print(s.module)
module {
func.func @kernel_K1(%arg0: memref<32xi32>) -> memref<32xi32> attributes {itypes = "s", otypes = "s"} {
%alloc = memref.alloc() {name = "B"} : memref<32xi32>
%c0_i32 = arith.constant 0 : i32
linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32xi32>)
affine.for %arg1 = 0 to 32 {
%0 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
%1 = arith.extsi %0 : i32 to i33
%c1_i32 = arith.constant 1 : i32
%2 = arith.extsi %c1_i32 : i32 to i33
%3 = arith.addi %1, %2 : i33
%4 = arith.trunci %3 : i33 to i32
affine.store %4, %alloc[%arg1] {to = "B"} : memref<32xi32>
} {loop_name = "i", op_name = "S_i_0", unroll = 4 : i32}
return %alloc : memref<32xi32>
}
func.func @kernel_K2(%arg0: memref<32xi32>) -> memref<32xf32> attributes {itypes = "s", otypes = "_"} {
%c0_i32 = arith.constant 0 : i32
%0 = arith.sitofp %c0_i32 : i32 to f32
%alloc = memref.alloc() {name = "B"} : memref<32xf32>
linalg.fill ins(%0 : f32) outs(%alloc : memref<32xf32>)
affine.for %arg1 = 0 to 32 {
%1 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
%2 = arith.extsi %1 : i32 to i64
%c2_i32 = arith.constant 2 : i32
%3 = arith.extsi %c2_i32 : i32 to i64
%4 = arith.muli %2, %3 : i64
%5 = arith.sitofp %4 : i64 to f32
affine.store %5, %alloc[%arg1] {to = "B"} : memref<32xf32>
} {loop_name = "i", op_name = "S_i_0", pipeline_ii = 1 : ui32}
return %alloc : memref<32xf32>
}
func.func @top2(%arg0: memref<32xi32>) -> memref<32xf32> attributes {itypes = "s", otypes = "_"} {
%0 = call @kernel_K1(%arg0) {name = "C"} : (memref<32xi32>) -> memref<32xi32>
%1 = call @kernel_K2(%0) {name = "D"} : (memref<32xi32>) -> memref<32xf32>
return %1 : memref<32xf32>
}
}
We can see from the printed module that the loop in the first kernel is unrolled by a factor of 4, and the loop in the second kernel is pipelined.
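As a final sanity check, we could again build the composed schedule and verify the end-to-end result, which should be D = (A + 1) * 2 cast to float32. This is a minimal sketch under the same assumption of the default LLVM JIT backend:

import numpy as np

# Build the composed top2 schedule (assumption: default LLVM backend).
mod = s.build()

np_A = np.random.randint(0, 10, size=(M,)).astype(np.int32)

# K1 adds 1 (int32 result); K2 multiplies by 2 and produces float32.
np_D = mod(np_A)
np.testing.assert_allclose(np_D, ((np_A + 1) * 2).astype(np.float32), rtol=1e-5)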
Total running time of the script: (0 minutes 0.804 seconds)