Template Kernels
Author: Hongzheng Chen (hzchen@cs.cornell.edu)
This document explains how to write a template kernel in Allo. Template kernels are useful when we need to reuse a kernel with different data types or when certain computation patterns depend on specific constants. By leveraging template kernels, we can achieve greater flexibility and reusability in the code.
import allo
from allo.ir.types import int32, float32
We follow Python's convention of using type variables to define a template kernel. Specifically, the type variables are declared in square brackets after the function name, as in def kernel[T](...) (the type parameter syntax introduced in Python 3.12), and they can then be used in the function signature and body. Importantly, because the native Python interpreter does not support Allo's type declarations (i.e., base type + shape), we need to use string annotations like "T[10]" to specify the types in the function signature. Otherwise, the interpreter raises a type error.
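The need for string annotations can be reproduced without Allo. The following is a minimal plain-Python sketch, not Allo code: it uses typing.TypeVar as a stand-in for the kernel's type variable, and the function names are hypothetical. It shows that subscripting a type variable fails in the interpreter, that a quoted annotation is stored as an unevaluated string, and that an annotation on a local variable inside a function body is never evaluated, which is consistent with B: T[10] appearing unquoted in the kernels of this tutorial.

```python
from typing import TypeVar

T = TypeVar("T")  # stand-in for the kernel's type variable


def subscript_raises() -> bool:
    """Return True if evaluating T[10] in plain Python raises TypeError."""
    try:
        T[10]
    except TypeError:
        return True
    return False


def signature_sketch(A: "T[10]") -> "T[10]":
    # Annotations on local variables are not evaluated by the interpreter,
    # so this line does not fail even though T[10] is not valid at runtime.
    B: T[10]
    return A


# Quoted annotations are kept as plain strings for a tool to parse later.
stored = signature_sketch.__annotations__
```

Here stored maps both "A" and "return" to the string "T[10]", which a compiler front end can parse on its own terms.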
In the following, we define a simple kernel that adds 1 to each element of the input array. To invoke the kernel with a specific data type, we can use the instantiate argument of the allo.customize function.
def kernel[T](A: "T[10]") -> "T[10]":
    B: T[10]
    for i in range(10):
        B[i] = A[i] + 1
    return B
s = allo.customize(kernel, instantiate=[int32])
print(s.module)
module {
  func.func @kernel(%arg0: memref<10xi32>) -> memref<10xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "B"} : memref<10xi32>
    affine.for %arg1 = 0 to 10 {
      %0 = affine.load %arg0[%arg1] {from = "A"} : memref<10xi32>
      %1 = arith.extsi %0 : i32 to i33
      %c1_i32 = arith.constant 1 : i32
      %2 = arith.extsi %c1_i32 : i32 to i33
      %3 = arith.addi %1, %2 : i33
      %4 = arith.trunci %3 : i33 to i32
      affine.store %4, %alloc[%arg1] {to = "B"} : memref<10xi32>
    } {loop_name = "i", op_name = "S_i_0"}
    return %alloc : memref<10xi32>
  }
}
We can see that the kernel is specialized with the given int32 data type.
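In the integer version, the generated code sign-extends both operands to i33, adds them, and truncates the sum back to i32, so a result that does not fit in 32 bits wraps around. The following plain-Python sketch models that arithmetic to make the behavior concrete. The helper names are hypothetical, and this is only a reference model of the printed operations, not how Allo executes the kernel.

```python
def trunc_to_i32(value: int) -> int:
    """Keep the low 32 bits and reinterpret them as a signed 32-bit integer."""
    low = value & 0xFFFFFFFF
    return low - (1 << 32) if low >= (1 << 31) else low


def add_one_i32(a: int) -> int:
    # Python integers are unbounded, so this sum stands in for the i33 result.
    wide = a + 1
    return trunc_to_i32(wide)


def kernel_reference(A: list) -> list:
    """Element-wise model of the int32 kernel: B[i] = A[i] + 1."""
    return [add_one_i32(a) for a in A]
```

For ordinary inputs this is simply A[i] + 1; only the largest int32 value wraps to the smallest one.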
Similarly, we can directly instantiate a new kernel by specifying float32 as the data type.
s = allo.customize(kernel, instantiate=[float32])
print(s.module)
module {
  func.func @kernel(%arg0: memref<10xf32>) -> memref<10xf32> attributes {itypes = "_", otypes = "_"} {
    %alloc = memref.alloc() {name = "B"} : memref<10xf32>
    affine.for %arg1 = 0 to 10 {
      %0 = affine.load %arg0[%arg1] {from = "A"} : memref<10xf32>
      %c1_i32 = arith.constant 1 : i32
      %1 = arith.sitofp %c1_i32 : i32 to f32
      %2 = arith.addf %0, %1 : f32
      affine.store %2, %alloc[%arg1] {to = "B"} : memref<10xf32>
    } {loop_name = "i", op_name = "S_i_0"}
    return %alloc : memref<10xf32>
  }
}
If we want to specialize not only the data type but also the shape of the array, we can introduce another type variable and pass its value through the instantiate argument. Note that here we also use the <type_var>: base_type notation to constrain the type of the type variable. In this example, the type variable M is constrained to be an integer.
def kernel2[T, M: int32](A: "T[M]") -> "T[M]":
    B: T[M]
    for i in range(M):
        B[i] = A[i] + 1
    return B
s = allo.customize(kernel2, instantiate=[int32, 20])
print(s.module)
module {
  func.func @kernel2(%arg0: memref<20xi32>) -> memref<20xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "B"} : memref<20xi32>
    affine.for %arg1 = 0 to 20 {
      %0 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      %1 = arith.extsi %0 : i32 to i33
      %c1_i32 = arith.constant 1 : i32
      %2 = arith.extsi %c1_i32 : i32 to i33
      %3 = arith.addi %1, %2 : i33
      %4 = arith.trunci %3 : i33 to i32
      affine.store %4, %alloc[%arg1] {to = "B"} : memref<20xi32>
    } {loop_name = "i", op_name = "S_i_0"}
    return %alloc : memref<20xi32>
  }
}
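In the output above, the value 20 passed for M appears as a constant in both the memref shape and the loop bound. As a rough analogy only, the same idea can be sketched in plain Python with a closure that fixes the size when the kernel is created. The factory name make_add_one is hypothetical, and this sketch does not reflect how Allo implements instantiation.

```python
def make_add_one(M: int):
    """Return an add-one function specialized for inputs of exactly M elements."""
    if not isinstance(M, int) or M <= 0:
        raise ValueError("M must be a positive integer")

    def kernel(A):
        if len(A) != M:
            raise ValueError(f"expected {M} elements, got {len(A)}")
        B = [0] * M
        for i in range(M):  # M is a fixed constant inside this closure
            B[i] = A[i] + 1
        return B

    return kernel


# Fix the size once, then reuse the specialized function.
kernel_20 = make_add_one(20)
result = kernel_20(list(range(20)))
```

Each call to make_add_one yields an independent specialization, much as each allo.customize call with a different instantiate list yields a separate module.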
Furthermore, Allo's templates enable metaprogramming, in which conditions on type variables are evaluated at compile time. Specifically, we can use allo.meta_if, allo.meta_elif, and allo.meta_else to conditionally generate code based on the type variables. Just make sure that the conditions can be determined at compile time.
def kernel3[T, M: int32](A: "T[M]") -> "T[M]":
    B: T[M]
    for i in range(M):
        with allo.meta_if(T == int32):
            B[i] = A[i] + 1
        with allo.meta_else():
            B[i] = A[i] - 1
    return B
In the final generated code, we can see that only a single branch is generated, depending on the given data type.
s = allo.customize(kernel3, instantiate=[int32, 20])
print(s.module)
s = allo.customize(kernel3, instantiate=[float32, 20])
print(s.module)
module {
  func.func @kernel3(%arg0: memref<20xi32>) -> memref<20xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "B"} : memref<20xi32>
    affine.for %arg1 = 0 to 20 {
      %0 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      %1 = arith.extsi %0 : i32 to i33
      %c1_i32 = arith.constant 1 : i32
      %2 = arith.extsi %c1_i32 : i32 to i33
      %3 = arith.addi %1, %2 : i33
      %4 = arith.trunci %3 : i33 to i32
      affine.store %4, %alloc[%arg1] {to = "B"} : memref<20xi32>
    } {loop_name = "i", op_name = "S_i_0"}
    return %alloc : memref<20xi32>
  }
}
module {
  func.func @kernel3(%arg0: memref<20xf32>) -> memref<20xf32> attributes {itypes = "_", otypes = "_"} {
    %alloc = memref.alloc() {name = "B"} : memref<20xf32>
    affine.for %arg1 = 0 to 20 {
      %0 = affine.load %arg0[%arg1] {from = "A"} : memref<20xf32>
      %c1_i32 = arith.constant 1 : i32
      %1 = arith.sitofp %c1_i32 : i32 to f32
      %2 = arith.subf %0, %1 : f32
      affine.store %2, %alloc[%arg1] {to = "B"} : memref<20xf32>
    } {loop_name = "i", op_name = "S_i_0"}
    return %alloc : memref<20xf32>
  }
}
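The two modules differ only in the arithmetic operation: the int32 instantiation keeps the addition, and the float32 instantiation keeps the subtraction. As a loose plain-Python analogy, the branch can be chosen once when the kernel is built instead of on every loop iteration. The function specialize_kernel3 and the string type tags are hypothetical, and this sketch is not how allo.meta_if is implemented.

```python
def specialize_kernel3(dtype: str, M: int):
    """Build a kernel whose per-element update is fixed at build time."""
    # The condition is evaluated once, here, while the kernel is being built.
    if dtype == "int32":
        def update(a):
            return a + 1
    else:
        def update(a):
            return a - 1

    def kernel(A):
        # No type check remains in the loop; only the selected branch runs.
        return [update(A[i]) for i in range(M)]

    return kernel


int_kernel = specialize_kernel3("int32", 4)
float_kernel = specialize_kernel3("float32", 4)
```

With these definitions, int_kernel adds 1 to every element and float_kernel subtracts 1, mirroring the arith.addi and arith.subf operations in the two modules.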