Getting Started

Author: Hongzheng Chen (hzchen@cs.cornell.edu)

In this tutorial, we demonstrate the basic usage of Allo.

Import Allo

First we import the necessary packages.

import allo

Algorithm Definition

Allo follows a paradigm that decouples the algorithm from its optimization. Users first define the algorithm in a high-level language and then optimize the program with hardware customization techniques, which are expressed as schedule primitives. Here we show how to define a general matrix multiplication (GEMM) in the Allo DSL.

We first import the necessary data types from Allo. In this example, we use int32 as the data type for all the variables.

from allo.ir.types import int32

We then define a function that takes two 32x32 matrices as inputs and returns a 32x32 matrix. Variables are declared as <name>: <type>[<shape>]. Unlike ordinary Python code, Allo kernels require strict type annotations.

Inside the kernel, Allo provides a shorthand for nested loops. For example, for i, j, k in allo.grid(32, 32, 32) is equivalent to the following nested for-loop:

for i in range(32):
    for j in range(32):
        for k in range(32):
            pass  # loop body

The allo.grid API defines the iteration space of the loop nest: its arguments are the upper bounds of the loop iterators, listed from the outermost loop to the innermost. Plain range-based loops like the one above are also supported in Allo, which gives users more flexibility in defining the loop structure.
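To make the equivalence concrete without Allo, the plain-Python sketch below uses itertools.product as a stand-in for allo.grid and checks that it visits the same (i, j, k) tuples, in the same order, as three nested range loops. This is only an analogy for the iteration order; it is not how Allo lowers the loops (Allo turns them into MLIR affine.for operations).

```python
# Plain-Python analogy for allo.grid (not Allo's implementation):
# itertools.product enumerates the same (i, j, k) tuples, in the same
# order, as three nested range loops with k innermost.
from itertools import product


def nested_indices(n0, n1, n2):
    """Collect the indices visited by explicit nested range loops."""
    out = []
    for i in range(n0):
        for j in range(n1):
            for k in range(n2):
                out.append((i, j, k))
    return out


def grid_indices(n0, n1, n2):
    """Collect the indices visited by a single loop over the whole grid."""
    return list(product(range(n0), range(n1), range(n2)))


assert nested_indices(32, 32, 32) == grid_indices(32, 32, 32)
print(len(grid_indices(32, 32, 32)))  # 32768 iterations in total
```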

def gemm(A: int32[32, 32], B: int32[32, 32]) -> int32[32, 32]:
    C: int32[32, 32] = 0  # declare the output matrix and initialize it to zero
    for i, j, k in allo.grid(32, 32, 32):
        C[i, j] += A[i, k] * B[k, j]  # multiply-accumulate over k
    return C
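When experimenting, it can help to have a golden model that mirrors the kernel line by line. The sketch below is a hypothetical pure-Python reference (gemm_ref is our own name, not part of Allo) with the same loop structure as gemm. Note that Python integers never overflow, so this model does not reproduce int32 wraparound.

```python
# Hypothetical pure-Python reference model of the gemm kernel above
# (not Allo code). Python ints are unbounded, so int32 overflow is not modeled.
def gemm_ref(A, B, n=32):
    C = [[0] * n for _ in range(n)]  # mirrors `C: int32[32, 32] = 0`
    for i in range(n):
        for j in range(n):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]  # same multiply-accumulate as the kernel
    return C


# Tiny 2x2 usage example: [[1, 2], [3, 4]] @ [[5, 6], [7, 8]]
print(gemm_ref([[1, 2], [3, 4]], [[5, 6], [7, 8]], n=2))  # [[19, 22], [43, 50]]
```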

Create the Schedule

After defining the algorithm, we can apply transformations to the kernel to achieve high performance. Calling allo.customize creates a schedule for the kernel, where a schedule is the set of transformations applied to it.

s = allo.customize(gemm)

Inspect the Intermediate Representation (IR)

Allo leverages the MLIR infrastructure to represent the program, and we can print the IR directly via s.module.

print(s.module)
module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 32 {
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %0 = affine.load %arg0[%arg2, %arg4] {from = "A"} : memref<32x32xi32>
          %1 = affine.load %arg1[%arg4, %arg3] {from = "B"} : memref<32x32xi32>
          %2 = arith.extsi %0 : i32 to i64
          %3 = arith.extsi %1 : i32 to i64
          %4 = arith.muli %2, %3 : i64
          %5 = affine.load %alloc[%arg2, %arg3] {from = "C"} : memref<32x32xi32>
          %6 = arith.trunci %4 : i64 to i32
          %7 = arith.addi %5, %6 : i32
          affine.store %7, %alloc[%arg2, %arg3] {to = "C"} : memref<32x32xi32>
        } {loop_name = "k"}
      } {loop_name = "j"}
    } {loop_name = "i", op_name = "S_i_j_k_0"}
    return %alloc : memref<32x32xi32>
  }
}

Let’s take a closer look at the generated IR. An MLIR program is a set of operations from different dialects, and each operation is written as <dialect>.<op>. The generated IR contains the following dialects:

  • func: Used to define the function signature and the return of the function.

  • memref: Used to define the shape and memory layout of the tensors.

  • affine: Used to define the loop structure and the memory accesses inside it (affine.load, affine.store).

  • arith: Used to conduct actual arithmetic operations.

  • linalg: Currently only used to initialize the tensors.

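The innermost multiply-accumulate is represented explicitly as a sequence of load/store and arithmetic operations: A[i, k] and B[k, j] are sign-extended to i64 (arith.extsi), multiplied (arith.muli), truncated back to i32 (arith.trunci), and added to C[i, j] (arith.addi). Allo also attaches attributes to the operations, such as tensor names, loop names, and operation names, which the schedule primitives later use to refer to loops and operations.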
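The extend-multiply-truncate-add sequence determines how overflow behaves. The plain-Python sketch below is our own emulation of that sequence (the helper names to_signed and mac_i32 are hypothetical, not Allo or MLIR APIs); it shows that keeping only the low 32 bits yields two's-complement wraparound, just as a 32-bit integer multiply-accumulate would.

```python
# Plain-Python emulation (not Allo code) of the i32 multiply-accumulate in the IR:
# arith.extsi (sign-extend to i64), arith.muli (i64 multiply),
# arith.trunci (keep the low 32 bits), arith.addi (i32 add, wraps on overflow).


def to_signed(value, bits):
    """Interpret the low `bits` bits of `value` as a two's-complement integer."""
    mask = (1 << bits) - 1
    value &= mask
    if value >= 1 << (bits - 1):
        value -= 1 << bits
    return value


def mac_i32(c, a, b):
    a64, b64 = to_signed(a, 64), to_signed(b, 64)  # arith.extsi (no-op for i32 values)
    prod64 = to_signed(a64 * b64, 64)              # arith.muli on i64
    prod32 = to_signed(prod64, 32)                 # arith.trunci i64 -> i32
    return to_signed(c + prod32, 32)               # arith.addi on i32


print(mac_i32(10, 3, 4))         # 22
print(mac_i32(0, 65536, 65536))  # 0, since 2**32 truncated to 32 bits is 0
print(mac_i32(2**31 - 1, 1, 1))  # -2147483648, since i32 addition wraps
```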

Apply Transformations

Next, we transform the program with schedule primitives, referring to loops by their names (the loop_name attributes in the IR). For example, to split the outermost loop i into an outer loop and an inner loop of 8 iterations, we call the .split() primitive as follows:

s.split("i", factor=8)

We can print out the IR again to see the effect of the transformation.

Note

In the Allo DSL, transformations are applied immediately, so their effect is visible in the IR as soon as each primitive is called.

print(s.module)
#map = affine_map<(d0, d1) -> (d0 + d1 * 8)>
module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 4 {
      affine.for %arg3 = 0 to 8 {
        affine.for %arg4 = 0 to 32 {
          affine.for %arg5 = 0 to 32 {
            %1 = affine.apply #map(%arg3, %arg2)
            %2 = affine.load %arg0[%1, %arg5] {from = "A"} : memref<32x32xi32>
            %3 = affine.load %arg1[%arg5, %arg4] {from = "B"} : memref<32x32xi32>
            %4 = arith.extsi %2 : i32 to i64
            %5 = arith.extsi %3 : i32 to i64
            %6 = arith.muli %4, %5 : i64
            %7 = affine.load %alloc[%1, %arg4] {from = "C"} : memref<32x32xi32>
            %8 = arith.trunci %6 : i64 to i32
            %9 = arith.addi %7, %8 : i32
            affine.store %9, %alloc[%1, %arg4] {to = "C"} : memref<32x32xi32>
          } {loop_name = "k"}
        } {loop_name = "j"}
      } {loop_name = "i.inner"}
    } {loop_name = "i.outer", op_name = "S_i_j_k_0"}
    %0 = allo.create_op_handle "S_i_j_k_0"
    return %alloc : memref<32x32xi32>
  }
}

The outermost loop i has been replaced by two nested loops, i.outer (4 iterations) and i.inner (8 iterations). The original index is recovered through the affine map #map as i = i.inner + i.outer * 8.
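As a quick sanity check on that index mapping, the plain-Python sketch below (not Allo code) rebuilds the split loop nest and confirms that it visits indices 0 to 31 in the original order. It assumes, as in this example, that the split factor divides the loop bound evenly.

```python
# Plain-Python check (not Allo code) that splitting a 32-iteration loop
# with factor 8 visits the same indices, in the same order, as the original.
# Assumes FACTOR divides N evenly, as it does in this tutorial.
FACTOR = 8
N = 32

original = list(range(N))

split = []
for i_outer in range(N // FACTOR):                # affine.for %arg2 = 0 to 4
    for i_inner in range(FACTOR):                 # affine.for %arg3 = 0 to 8
        split.append(i_inner + i_outer * FACTOR)  # #map: (d0, d1) -> (d0 + d1 * 8)

assert split == original
```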

Similarly, we can split the j loop:

s.split("j", factor=8)
print(s.module)
#map = affine_map<(d0, d1) -> (d0 + d1 * 8)>
module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 4 {
      affine.for %arg3 = 0 to 8 {
        affine.for %arg4 = 0 to 4 {
          affine.for %arg5 = 0 to 8 {
            affine.for %arg6 = 0 to 32 {
              %1 = affine.apply #map(%arg5, %arg4)
              %2 = affine.apply #map(%arg3, %arg2)
              %3 = affine.load %arg0[%2, %arg6] {from = "A"} : memref<32x32xi32>
              %4 = affine.load %arg1[%arg6, %1] {from = "B"} : memref<32x32xi32>
              %5 = arith.extsi %3 : i32 to i64
              %6 = arith.extsi %4 : i32 to i64
              %7 = arith.muli %5, %6 : i64
              %8 = affine.load %alloc[%2, %1] {from = "C"} : memref<32x32xi32>
              %9 = arith.trunci %7 : i64 to i32
              %10 = arith.addi %8, %9 : i32
              affine.store %10, %alloc[%2, %1] {to = "C"} : memref<32x32xi32>
            } {loop_name = "k"}
          } {loop_name = "j.inner"}
        } {loop_name = "j.outer"}
      } {loop_name = "i.inner"}
    } {loop_name = "i.outer", op_name = "S_i_j_k_0"}
    %0 = allo.create_op_handle "S_i_j_k_0"
    return %alloc : memref<32x32xi32>
  }
}

We can further reorder the loops with .reorder(). For example, we can place the two outer loops produced by splitting next to each other, followed by the two inner loops.

s.reorder("i.outer", "j.outer", "i.inner", "j.inner")
print(s.module)
#map = affine_map<(d0, d1) -> (d0 + d1 * 8)>
module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    %c0_i32 = arith.constant 0 : i32
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 4 {
      affine.for %arg3 = 0 to 4 {
        affine.for %arg4 = 0 to 8 {
          affine.for %arg5 = 0 to 8 {
            affine.for %arg6 = 0 to 32 {
              %0 = affine.apply #map(%arg5, %arg3)
              %1 = affine.apply #map(%arg4, %arg2)
              %2 = affine.load %arg0[%1, %arg6] {from = "A"} : memref<32x32xi32>
              %3 = affine.load %arg1[%arg6, %0] {from = "B"} : memref<32x32xi32>
              %4 = arith.extsi %2 : i32 to i64
              %5 = arith.extsi %3 : i32 to i64
              %6 = arith.muli %4, %5 : i64
              %7 = affine.load %alloc[%1, %0] {from = "C"} : memref<32x32xi32>
              %8 = arith.trunci %6 : i64 to i32
              %9 = arith.addi %7, %8 : i32
              affine.store %9, %alloc[%1, %0] {to = "C"} : memref<32x32xi32>
            } {loop_name = "k"}
          } {loop_name = "j.inner"}
        } {loop_name = "i.inner"}
      } {loop_name = "j.outer"}
    } {loop_name = "i.outer", op_name = "S_i_j_k_0"}
    return %alloc : memref<32x32xi32>
  }
}

The new loop order, from outermost to innermost i.outer, j.outer, i.inner, j.inner, k, can be read off the loop_name attributes in the generated IR.
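After the split and reorder, each (i.outer, j.outer) pair processes one 8x8 tile of C. The plain-Python model below (not Allo code; gemm_tiled and gemm_naive are hypothetical names) reproduces that loop nest, rebuilding i and j with the same affine map as the IR, and checks that the tiled order computes the same product as the untransformed loops.

```python
# Plain-Python model (not Allo code) of the loop nest after split + reorder:
# i.outer, j.outer, i.inner, j.inner, k. Indices are rebuilt with the same
# affine map the IR uses: index = inner + outer * FACTOR.
import random

N, FACTOR = 32, 8


def gemm_tiled(A, B):
    C = [[0] * N for _ in range(N)]
    for i_outer in range(N // FACTOR):
        for j_outer in range(N // FACTOR):
            for i_inner in range(FACTOR):
                for j_inner in range(FACTOR):
                    i = i_inner + i_outer * FACTOR
                    j = j_inner + j_outer * FACTOR
                    for k in range(N):
                        C[i][j] += A[i][k] * B[k][j]
    return C


def gemm_naive(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(N)) for j in range(N)]
            for i in range(N)]


A = [[random.randint(0, 99) for _ in range(N)] for _ in range(N)]
B = [[random.randint(0, 99) for _ in range(N)] for _ in range(N)]
assert gemm_tiled(A, B) == gemm_naive(A, B)  # reordering preserves the result
```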

Create the Executable

The next step is to generate an executable from the schedule. We call .build() on the schedule and specify the target as llvm, which makes Allo generate an LLVM program that runs on the CPU. Alternatively, specifying the target as vhls generates a Vivado HLS program that can be synthesized into an FPGA accelerator.

mod = s.build(target="llvm")

Note

s.build(target="llvm") is equivalent to s.build().

Prepare the Inputs/Outputs for the Executable

To run the executable, we generate random NumPy arrays as input data and feed them directly into it. Allo automatically handles the input data and generates the internal wrappers LLVM needs to execute the kernel, but we still need to make sure the data types are consistent. np.random.randint returns the platform’s default integer type (typically np.int64 on 64-bit systems), whereas our kernel uses int32, so we explicitly cast the arrays to np.int32.

import numpy as np

np_A = np.random.randint(0, 100, (32, 32)).astype(np.int32)
np_B = np.random.randint(0, 100, (32, 32)).astype(np.int32)
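Since the kernel accumulates in 32 bits, it is worth confirming that int32 can hold every possible entry of C for these inputs. The worked bound below is plain Python arithmetic and assumes non-negative inputs, as produced by randint(0, 100).

```python
# Worked bound (plain Python): can C = A @ B overflow int32 for these inputs?
# Assumes non-negative inputs, as produced by np.random.randint(0, 100).
from math import isqrt

N = 32
MAX_INPUT = 99        # np.random.randint(0, 100) excludes the upper bound
INT32_MAX = 2**31 - 1

max_c = N * MAX_INPUT * MAX_INPUT  # every product at its maximum, summed over k
print(max_c)                       # 313632
assert max_c <= INT32_MAX          # so int32 accumulation cannot overflow here

# Largest input value m for which N * m * m still fits in int32.
safe_max = isqrt(INT32_MAX // N)
print(safe_max)                    # 8191
```

In other words, these inputs are far from the overflow limit; the same kernel would only risk wraparound if input values exceeded 8191.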

Run the Executable

With the inputs prepared, we can feed them to the executable. Because the kernel returns C, the module returns a new array, which we can assign directly to a variable.

np_C = mod(np_A, np_B)

Finally, we can do a sanity check to see if the results are correct.

golden_C = np.matmul(np_A, np_B)
np.testing.assert_array_equal(np_C, golden_C)  # integer results, so compare exactly
print("Results are correct!")
Results are correct!
