Equivalence Checking

Author: Hongzheng Chen (hzchen@cs.cornell.edu)

In this tutorial, we demonstrate how to use Allo’s verifier facility to check that different schedules of the same kernel are equivalent. The verifier ensures that the optimizations applied to an algorithm do not alter its functional behavior.

First, we import the necessary packages:

import allo
from allo.ir.types import float32

Create the Schedules

We define a general matrix multiplication (GEMM) kernel that takes two 32x32 matrices as inputs and produces a 32x32 output matrix. The reduction loop over k accumulates the partial products into the output tensor C.

def gemm(A: float32[32, 32], B: float32[32, 32]) -> float32[32, 32]:
    C: float32[32, 32] = 0
    for i, j in allo.grid(32, 32):
        for k in allo.reduction(32):
            C[i, j] += A[i, k] * B[k, j]
    return C

We create two schedules for the GEMM kernel using different transformations. The first schedule, s1, applies a loop reordering transformation, while the second schedule, s2, applies a buffering transformation on the output tensor.

s1 = allo.customize(gemm)
s1.reorder("gemm:i", "gemm:j")
print(s1.module)
module {
  func.func @gemm(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> attributes {itypes = "__", otypes = "_"} {
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.sitofp %c0_i32 : i32 to f32
    %alloc = memref.alloc() {name = "C"} : memref<32x32xf32>
    linalg.fill ins(%0 : f32) outs(%alloc : memref<32x32xf32>)
    affine.for %arg2 = 0 to 32 {
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %1 = affine.load %arg0[%arg2, %arg4] {from = "A"} : memref<32x32xf32>
          %2 = affine.load %arg1[%arg4, %arg3] {from = "B"} : memref<32x32xf32>
          %3 = arith.mulf %1, %2 : f32
          %4 = affine.load %alloc[%arg2, %arg3] {from = "C"} : memref<32x32xf32>
          %5 = arith.addf %4, %3 : f32
          affine.store %5, %alloc[%arg2, %arg3] {to = "C"} : memref<32x32xf32>
        } {loop_name = "k", op_name = "S_k_0", reduction}
      } {loop_name = "j"}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xf32>
  }
}

In the code above, s1 is customized by reordering the loops corresponding to indices i and j. The printed intermediate representation (IR) shows the effect of this transformation.

s2 = allo.customize(gemm)
s2.buffer_at(s2.C, axis="i")
print(s2.module)
module {
  func.func @gemm(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> attributes {itypes = "__", otypes = "_"} {
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.sitofp %c0_i32 : i32 to f32
    %alloc = memref.alloc() {name = "C"} : memref<32x32xf32>
    linalg.fill ins(%0 : f32) outs(%alloc : memref<32x32xf32>)
    affine.for %arg2 = 0 to 32 {
      %alloc_0 = memref.alloc() : memref<32xf32>
      %cst = arith.constant 0.000000e+00 : f32
      affine.for %arg3 = 0 to 32 {
        affine.store %cst, %alloc_0[%arg3] : memref<32xf32>
      } {buffer, loop_name = "j_init", pipeline_ii = 1 : i32}
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %1 = affine.load %arg0[%arg2, %arg4] {from = "A"} : memref<32x32xf32>
          %2 = affine.load %arg1[%arg4, %arg3] {from = "B"} : memref<32x32xf32>
          %3 = arith.mulf %1, %2 : f32
          %4 = affine.load %alloc_0[%arg3] : memref<32xf32>
          %5 = arith.addf %4, %3 : f32
          affine.store %5, %alloc_0[%arg3] : memref<32xf32>
        } {loop_name = "k", op_name = "S_k_0", reduction}
      } {loop_name = "j"}
      affine.for %arg3 = 0 to 32 {
        %1 = affine.load %alloc_0[%arg3] : memref<32xf32>
        affine.store %1, %alloc[%arg2, %arg3] : memref<32x32xf32>
      } {buffer, loop_name = "j_back", pipeline_ii = 1 : i32}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xf32>
  }
}

Here, buffer_at creates an intermediate buffer for the output tensor C at the i loop: each row of C is accumulated in a 32-element buffer (zero-initialized in the j_init loop) and written back to C in the j_back loop once the row is complete. The IR output confirms that the transformation has been incorporated. Although the two schedules differ in structure, they should implement equivalent functionality.
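Before invoking the verifier, it can be reassuring to compare the two schedules numerically. The sketch below is illustrative rather than part of the verification flow: it assumes the LLVM JIT backend is available and that a built module is a callable taking NumPy arrays, as in the other Allo tutorials. Unlike the verifier, it only exercises the particular random inputs it is given.

import numpy as np

np_A = np.random.rand(32, 32).astype(np.float32)
np_B = np.random.rand(32, 32).astype(np.float32)
mod1 = s1.build()  # JIT-compile each schedule (default LLVM target assumed)
mod2 = s2.build()
# Both schedules should match NumPy and each other on the same inputs.
np.testing.assert_allclose(mod1(np_A, np_B), np_A @ np_B, rtol=1e-5)
np.testing.assert_allclose(mod1(np_A, np_B), mod2(np_A, np_B), rtol=1e-5)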

Verifying Equivalence

Next, we use the verifier facility to check whether the two schedules, s1 and s2, are equivalent. The allo.verify function compares the two schedules and returns an object that evaluates to True if they are functionally equivalent and to False otherwise.

verifier = allo.verify(s1, s2)
assert verifier, "Failed to verify the equivalence of two schedules!"
print("s1 and s2 are equivalent!")
s1 and s2 are equivalent!

The assertion confirms that the transformations applied in s1 and s2 preserve the semantics of the original GEMM algorithm.
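The comparison is not restricted to two transformed schedules; a transformed schedule can also be verified against an untouched baseline. A minimal sketch using only the allo.customize and allo.verify calls shown above:

s_ref = allo.customize(gemm)  # baseline schedule with no transformations applied
assert allo.verify(s1, s_ref), "s1 should be equivalent to the untransformed schedule!"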

Introducing a Non-equivalent Schedule

To illustrate the effectiveness of the verifier, we define an alternative GEMM kernel, gemm_wrong, which incorrectly overwrites each output element on every iteration of the reduction loop instead of accumulating into it. The schedule derived from gemm_wrong (named s3) should therefore not be equivalent to s1.

def gemm_wrong(A: float32[32, 32], B: float32[32, 32]) -> float32[32, 32]:
    C: float32[32, 32] = 0
    for i, j in allo.grid(32, 32):
        for k in allo.reduction(32):
            C[i, j] = A[i, k] * B[k, j]
    return C


s3 = allo.customize(gemm_wrong)
print(s3.module)
verifier = allo.verify(s1, s3)
assert not verifier, "s1 and s3 should not be equivalent!"
print("s1 and s3 are not equivalent!")
module {
  func.func @gemm_wrong(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> attributes {itypes = "__", otypes = "_"} {
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.sitofp %c0_i32 : i32 to f32
    %alloc = memref.alloc() {name = "C"} : memref<32x32xf32>
    linalg.fill ins(%0 : f32) outs(%alloc : memref<32x32xf32>)
    affine.for %arg2 = 0 to 32 {
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %1 = affine.load %arg0[%arg2, %arg4] {from = "A"} : memref<32x32xf32>
          %2 = affine.load %arg1[%arg4, %arg3] {from = "B"} : memref<32x32xf32>
          %3 = arith.mulf %1, %2 : f32
          affine.store %3, %alloc[%arg2, %arg3] {to = "C"} : memref<32x32xf32>
        } {loop_name = "k", op_name = "S_k_0", reduction}
      } {loop_name = "j"}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xf32>
  }
}

Verifier reported non-equivalence between schedules.
Differences between generated programs:
--- Program A (Schedule A)
+++ Program B (Schedule B)
@@ -13,7 +13,7 @@
 #include <math.h>
 #include <stdint.h>
 using namespace std;
-void gemm(
+void gemm_wrong(
   float v0[32][32],
   float v1[32][32],
   float v2[32][32]
@@ -30,9 +30,7 @@
         float v8 = v0[i][k];   // L10
         float v9 = v1[k][j];   // L11
         float v10 = v8 * v9;   // L12
-        float v11 = v2[i][j];  // L13
-        float v12 = v11 + v10; // L14
-        v2[i][j] = v12;        // L15
+        v2[i][j] = v10;        // L13
       }
     }
   }
s1 and s3 are not equivalent!

The verifier correctly detects that s3 does not preserve the intended accumulation and prints a diff of the two generated programs that pinpoints the offending store, confirming that s1 and s3 are not equivalent.
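When a project accumulates many schedules, this check can be folded into a small regression helper. The wrapper below is a hypothetical convenience function built on allo.verify, not part of the Allo API:

def check_all_equivalent(baseline, candidates):
    # Verify every candidate schedule against a trusted baseline schedule.
    for idx, s in enumerate(candidates):
        assert allo.verify(baseline, s), f"schedule #{idx} diverges from the baseline!"

check_all_equivalent(s1, [s2])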

Conclusion

This tutorial has demonstrated how to use Allo’s verifier facility to ensure that different scheduling transformations yield equivalent computational behavior. By verifying the equivalence of various schedules, you can confidently apply optimizations without compromising the functional correctness of your algorithms.
