Kernel Composition

Author: Hongzheng Chen (hzchen@cs.cornell.edu)

This document discusses kernel composition. The previous tutorials showed how to write a single simple kernel; real applications, however, usually require composing multiple kernels together.

In the following example, we define a matrix_add and a gemm kernel, and wrap them into a top-level function.

import allo
from allo.ir.types import int32, float32

M, K, N = 32, 32, 32


def matrix_add(A: int32[M, N]) -> int32[M, N]:
    # Element-wise increment: B[i, j] = A[i, j] + 1.
    # The loop/variable names ("i", "j", "B") are kept as-is because the
    # printed MLIR derives loop_name and buffer name attributes from them.
    B: int32[M, N] = 0
    for i, j in allo.grid(M, N):
        B[i, j] = A[i, j] + 1
    return B


def gemm(A: int32[M, K], B: int32[K, N]) -> int32[M, N]:
    # Dense matrix multiplication: C = A @ B.
    # "k" is declared as a reduction axis via allo.reduction, which marks
    # the innermost loop with the {reduction} attribute in the IR.
    C: int32[M, N] = 0
    for i, j in allo.grid(M, N):
        for k in allo.reduction(K):
            C[i, j] += A[i, k] * B[k, j]
    return C


def top(A: int32[M, K], B: int32[K, N]) -> int32[M, N]:
    # Top-level function chaining the two kernels: D = (A @ B) + 1.
    C = gemm(A, B)
    D = matrix_add(C)
    return D

Different teams or people can then work on different parts of the code and optimize each kernel. We first create a schedule for the matrix_add kernel, and add several optimizations.

# Create a schedule for matrix_add and pipeline the inner loop "j"
# (shown as pipeline_ii = 1 on that loop in the printed IR).
s1 = allo.customize(matrix_add)
s1.pipeline("j")
print(s1.module)
module {
  func.func @matrix_add(%arg0: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "B"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg1 = 0 to 32 {
      affine.for %arg2 = 0 to 32 {
        %0 = affine.load %arg0[%arg1, %arg2] {from = "A"} : memref<32x32xi32>
        %1 = arith.extsi %0 : i32 to i33
        %c1_i32 = arith.constant 1 : i32
        %2 = arith.extsi %c1_i32 : i32 to i33
        %3 = arith.addi %1, %2 : i33
        %4 = arith.trunci %3 : i33 to i32
        affine.store %4, %alloc[%arg1, %arg2] {to = "B"} : memref<32x32xi32>
      } {loop_name = "j", pipeline_ii = 1 : ui32}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xi32>
  }
}

Then we create a schedule for the gemm kernel and optimize it.

# Create a schedule for gemm and apply three optimizations.
s2 = allo.customize(gemm)
s2.reorder("k", "j")  # interchange so "k" becomes the outer of the two inner loops
s2.buffer_at(s2.C, axis="i")  # buffer one row of C at loop "i" (adds the j_init/j_back loops)
s2.pipeline("j")  # pipeline the innermost "j" loop
print(s2.module)
module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 32 {
      %alloc_0 = memref.alloc() : memref<32xi32>
      affine.for %arg3 = 0 to 32 {
        affine.store %c0_i32, %alloc_0[%arg3] : memref<32xi32>
      } {buffer, loop_name = "j_init", pipeline_ii = 1 : i32}
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %0 = affine.load %arg0[%arg2, %arg3] {from = "A"} : memref<32x32xi32>
          %1 = affine.load %arg1[%arg3, %arg4] {from = "B"} : memref<32x32xi32>
          %2 = arith.extsi %0 : i32 to i64
          %3 = arith.extsi %1 : i32 to i64
          %4 = arith.muli %2, %3 : i64
          %5 = affine.load %alloc_0[%arg4] : memref<32xi32>
          %6 = arith.trunci %4 : i64 to i32
          %7 = arith.addi %5, %6 : i32
          affine.store %7, %alloc_0[%arg4] : memref<32xi32>
        } {loop_name = "j", pipeline_ii = 1 : ui32}
      } {loop_name = "k", op_name = "S_k_0", reduction}
      affine.for %arg3 = 0 to 32 {
        %0 = affine.load %alloc_0[%arg3] : memref<32xi32>
        affine.store %0, %alloc[%arg2, %arg3] : memref<32x32xi32>
      } {buffer, loop_name = "j_back", pipeline_ii = 1 : i32}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xi32>
  }
}

Notice that so far we have only optimized the separate kernels; the optimizations have not yet been incorporated into the top-level function, as the following printed module shows.

# Customize the top-level function. At this point its callees are still the
# unoptimized gemm and matrix_add.
s = allo.customize(top)
print(s.module)
module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 32 {
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %0 = affine.load %arg0[%arg2, %arg4] {from = "A"} : memref<32x32xi32>
          %1 = affine.load %arg1[%arg4, %arg3] {from = "B"} : memref<32x32xi32>
          %2 = arith.extsi %0 : i32 to i64
          %3 = arith.extsi %1 : i32 to i64
          %4 = arith.muli %2, %3 : i64
          %5 = affine.load %alloc[%arg2, %arg3] {from = "C"} : memref<32x32xi32>
          %6 = arith.trunci %4 : i64 to i32
          %7 = arith.addi %5, %6 : i32
          affine.store %7, %alloc[%arg2, %arg3] {to = "C"} : memref<32x32xi32>
        } {loop_name = "k", op_name = "S_k_0", reduction}
      } {loop_name = "j"}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xi32>
  }
  func.func @matrix_add(%arg0: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "B"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg1 = 0 to 32 {
      affine.for %arg2 = 0 to 32 {
        %0 = affine.load %arg0[%arg1, %arg2] {from = "A"} : memref<32x32xi32>
        %1 = arith.extsi %0 : i32 to i33
        %c1_i32 = arith.constant 1 : i32
        %2 = arith.extsi %c1_i32 : i32 to i33
        %3 = arith.addi %1, %2 : i33
        %4 = arith.trunci %3 : i33 to i32
        affine.store %4, %alloc[%arg1, %arg2] {to = "B"} : memref<32x32xi32>
      } {loop_name = "j"}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xi32>
  }
  func.func @top(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %0 = call @gemm(%arg0, %arg1) {name = "C"} : (memref<32x32xi32>, memref<32x32xi32>) -> memref<32x32xi32>
    %1 = call @matrix_add(%0) {name = "D"} : (memref<32x32xi32>) -> memref<32x32xi32>
    return %1 : memref<32x32xi32>
  }
}

Therefore, after each part has been optimized, we need to explicitly compose them together. In Allo, we can use the .compose() primitive to compose the schedules together into the parent function.

# Compose the optimized sub-schedules into the top-level schedule, replacing
# the callee functions with their optimized counterparts.
s.compose([s1, s2])
print(s.module)
module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 32 {
      %alloc_0 = memref.alloc() : memref<32xi32>
      affine.for %arg3 = 0 to 32 {
        affine.store %c0_i32, %alloc_0[%arg3] : memref<32xi32>
      } {buffer, loop_name = "j_init", pipeline_ii = 1 : i32}
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %0 = affine.load %arg0[%arg2, %arg3] {from = "A"} : memref<32x32xi32>
          %1 = affine.load %arg1[%arg3, %arg4] {from = "B"} : memref<32x32xi32>
          %2 = arith.extsi %0 : i32 to i64
          %3 = arith.extsi %1 : i32 to i64
          %4 = arith.muli %2, %3 : i64
          %5 = affine.load %alloc_0[%arg4] : memref<32xi32>
          %6 = arith.trunci %4 : i64 to i32
          %7 = arith.addi %5, %6 : i32
          affine.store %7, %alloc_0[%arg4] : memref<32xi32>
        } {loop_name = "j", pipeline_ii = 1 : ui32}
      } {loop_name = "k", op_name = "S_k_0", reduction}
      affine.for %arg3 = 0 to 32 {
        %0 = affine.load %alloc_0[%arg3] : memref<32xi32>
        affine.store %0, %alloc[%arg2, %arg3] : memref<32x32xi32>
      } {buffer, loop_name = "j_back", pipeline_ii = 1 : i32}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xi32>
  }
  func.func @matrix_add(%arg0: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "B"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg1 = 0 to 32 {
      affine.for %arg2 = 0 to 32 {
        %0 = affine.load %arg0[%arg1, %arg2] {from = "A"} : memref<32x32xi32>
        %1 = arith.extsi %0 : i32 to i33
        %c1_i32 = arith.constant 1 : i32
        %2 = arith.extsi %c1_i32 : i32 to i33
        %3 = arith.addi %1, %2 : i33
        %4 = arith.trunci %3 : i33 to i32
        affine.store %4, %alloc[%arg1, %arg2] {to = "B"} : memref<32x32xi32>
      } {loop_name = "j", pipeline_ii = 1 : ui32}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xi32>
  }
  func.func @top(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %0 = call @gemm(%arg0, %arg1) {name = "C"} : (memref<32x32xi32>, memref<32x32xi32>) -> memref<32x32xi32>
    %1 = call @matrix_add(%0) {name = "D"} : (memref<32x32xi32>) -> memref<32x32xi32>
    return %1 : memref<32x32xi32>
  }
}

We can see that the schedules for the matrix_add and gemm kernels are both correctly optimized in the top-level function.

Template Composition

Sometimes we may define template kernels and invoke the kernel with different template arguments. Allo provides an id option to specify the exact kernel to be composed.

def kernel[T_in, T_out, S](A: "T_in[S]") -> "T_out[S]":
    # Template kernel parameterized by input/output element types (T_in, T_out)
    # and vector size S. The meta_if/meta_else branch is resolved at compile
    # time based on the instantiated T_out, so each instance contains only one
    # of the two bodies.
    B: T_out[S] = 0
    for i in range(S):
        with allo.meta_if(T_out == int32):
            B[i] = A[i] + 1
        with allo.meta_else():
            B[i] = A[i] * 2
    return B


def top2(A: int32[M]) -> float32[M]:
    # Instantiate the template twice with different type arguments and distinct
    # ids ("K1", "K2"), so each instance can later be composed by id.
    C = kernel[int32, int32, M, "K1"](A)
    D = kernel[int32, float32, M, "K2"](C)
    return D

Specifically, the last template argument of the kernel is its ID, which we later use to distinguish between the instances during composition. We also customize the two template instantiations with different optimizations first.

# Instantiate and customize the int32 -> int32 instance: unroll loop "i" by 4.
s1 = allo.customize(kernel, instantiate=[int32, int32, M])
s1.unroll("i", factor=4)
print(s1.module)

# Instantiate and customize the int32 -> float32 instance: pipeline loop "i".
s2 = allo.customize(kernel, instantiate=[int32, float32, M])
s2.pipeline("i")
print(s2.module)
module {
  func.func @kernel(%arg0: memref<32xi32>) -> memref<32xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "B"} : memref<32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32xi32>)
    affine.for %arg1 = 0 to 32 {
      %0 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
      %1 = arith.extsi %0 : i32 to i33
      %c1_i32 = arith.constant 1 : i32
      %2 = arith.extsi %c1_i32 : i32 to i33
      %3 = arith.addi %1, %2 : i33
      %4 = arith.trunci %3 : i33 to i32
      affine.store %4, %alloc[%arg1] {to = "B"} : memref<32xi32>
    } {loop_name = "i", op_name = "S_i_0", unroll = 4 : i32}
    return %alloc : memref<32xi32>
  }
}

module {
  func.func @kernel(%arg0: memref<32xi32>) -> memref<32xf32> attributes {itypes = "s", otypes = "_"} {
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.sitofp %c0_i32 : i32 to f32
    %alloc = memref.alloc() {name = "B"} : memref<32xf32>
    linalg.fill ins(%0 : f32) outs(%alloc : memref<32xf32>)
    affine.for %arg1 = 0 to 32 {
      %1 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
      %2 = arith.extsi %1 : i32 to i64
      %c2_i32 = arith.constant 2 : i32
      %3 = arith.extsi %c2_i32 : i32 to i64
      %4 = arith.muli %2, %3 : i64
      %5 = arith.sitofp %4 : i64 to f32
      affine.store %5, %alloc[%arg1] {to = "B"} : memref<32xf32>
    } {loop_name = "i", op_name = "S_i_0", pipeline_ii = 1 : ui32}
    return %alloc : memref<32xf32>
  }
}

Finally, we compose the two template kernels into the top-level function with the ID specified.

# Compose each customized template instance into top2, using the id to select
# which instantiation ("K1" or "K2") each schedule applies to.
s = allo.customize(top2)
s.compose(s1, id="K1")
s.compose(s2, id="K2")
print(s.module)
module {
  func.func @kernel_K1(%arg0: memref<32xi32>) -> memref<32xi32> attributes {itypes = "s", otypes = "s"} {
    %alloc = memref.alloc() {name = "B"} : memref<32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32xi32>)
    affine.for %arg1 = 0 to 32 {
      %0 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
      %1 = arith.extsi %0 : i32 to i33
      %c1_i32 = arith.constant 1 : i32
      %2 = arith.extsi %c1_i32 : i32 to i33
      %3 = arith.addi %1, %2 : i33
      %4 = arith.trunci %3 : i33 to i32
      affine.store %4, %alloc[%arg1] {to = "B"} : memref<32xi32>
    } {loop_name = "i", op_name = "S_i_0", unroll = 4 : i32}
    return %alloc : memref<32xi32>
  }
  func.func @kernel_K2(%arg0: memref<32xi32>) -> memref<32xf32> attributes {itypes = "s", otypes = "_"} {
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.sitofp %c0_i32 : i32 to f32
    %alloc = memref.alloc() {name = "B"} : memref<32xf32>
    linalg.fill ins(%0 : f32) outs(%alloc : memref<32xf32>)
    affine.for %arg1 = 0 to 32 {
      %1 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
      %2 = arith.extsi %1 : i32 to i64
      %c2_i32 = arith.constant 2 : i32
      %3 = arith.extsi %c2_i32 : i32 to i64
      %4 = arith.muli %2, %3 : i64
      %5 = arith.sitofp %4 : i64 to f32
      affine.store %5, %alloc[%arg1] {to = "B"} : memref<32xf32>
    } {loop_name = "i", op_name = "S_i_0", pipeline_ii = 1 : ui32}
    return %alloc : memref<32xf32>
  }
  func.func @top2(%arg0: memref<32xi32>) -> memref<32xf32> attributes {itypes = "s", otypes = "_"} {
    %0 = call @kernel_K1(%arg0) {name = "C"} : (memref<32xi32>) -> memref<32xi32>
    %1 = call @kernel_K2(%0) {name = "D"} : (memref<32xi32>) -> memref<32xf32>
    return %1 : memref<32xf32>
  }
}

We can see from the printed module that the loop in the first kernel is unrolled by a factor of 4, and the loop in the second kernel is pipelined.

Total running time of the script: (0 minutes 0.804 seconds)

Gallery generated by Sphinx-Gallery