Note
Go to the end to download the full example code.
Other Features¶
Author: Hongzheng Chen (hzchen@cs.cornell.edu)
This document will discuss other features that are not covered in the previous tutorials.
Dynamic Shapes¶
In some cases, the shape of the tensor is not known at compile time, so we can use [...] to represent the dynamic shape.
From the generated MLIR module, we can see it has a "?" in the shape of the tensor, which means the shape is not predefined,
but we can still run the LLVM module with arbitrary shapes of NumPy arrays.
import allo
from allo.ir.types import int32, float32
import numpy as np
# Copy the first `size` elements of A into B. Annotating the arguments as
# float32[...] leaves their shapes dynamic, so the element count is passed in
# explicitly through `size`.
def kernel(A: float32[...], B: float32[...], size: int32):
    for i in range(size):
        B[i] = A[i]
# Customize the kernel and print its MLIR module (shown below).
s = allo.customize(kernel)
print(s.module)
# Random float32 input and a zero-filled float32 output buffer, both of length 256.
np_A = np.random.random((256,)).astype(np.float32)
allo_A = np.zeros((256,)).astype(np.float32)
# Build the module and run it; the array length is passed as the `size` argument.
mod = s.build()
mod(np_A, allo_A, 256)
# The kernel copies A into B, so the output buffer must now equal the input.
np.testing.assert_allclose(np_A, allo_A)
module {
func.func @kernel(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: i32) attributes {itypes = "__s", otypes = ""} {
%c0_i32 = arith.constant 0 : i32
%c0_i32_0 = arith.constant 0 : i32
%0 = arith.index_cast %c0_i32_0 : i32 to index
%1 = arith.index_cast %arg2 : i32 to index
%c1_i32 = arith.constant 1 : i32
%c1_i32_1 = arith.constant 1 : i32
%2 = arith.index_cast %c1_i32_1 : i32 to index
scf.for %arg3 = %0 to %1 step %2 {
%3 = memref.load %arg0[%arg3] {from = "A"} : memref<?xf32>
memref.store %3, %arg1[%arg3] {to = "B"} : memref<?xf32>
} {loop_name = "i", op_name = "S_i_0"}
return
}
}
We can also inspect the generated HLS code and see that the arguments are declared as pointers.
# Build with the "vhls" target and print the result, which is the generated
# HLS C++ source shown below.
code = s.build(target="vhls")
print(code)
//===------------------------------------------------------------*- C++ -*-===//
//
// Automatically generated file for High-level Synthesis (HLS).
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <ap_axi_sdata.h>
#include <ap_fixed.h>
#include <ap_int.h>
#include <hls_math.h>
#include <hls_stream.h>
#include <hls_streamofblocks.h>
#include <math.h>
#include <stdint.h>
using namespace std;
/// This is top function.
void kernel(
float *v0,
float *v1,
int32_t v2
) { // L2
int v3 = v2; // L6
for (int v4 = 0; v4 < v3; v4 += 1) { // L10
float v5 = *v0[v4]; // L11
*v1[v4] = v5; // L12
}
}
Tuple Return¶
Another feature is tuple support. Just as a Python function can return multiple values, an Allo function can do the same when its return type is explicitly specified as a tuple.
# Return both the sum and the difference of two scalars. The tuple return
# annotation declares the two results.
def callee(a: float32, b: float32) -> (float32, float32):
    c: float32 = a + b
    d: float32 = a - b
    return c, d
# Apply `callee` element-wise to A and B and return both result arrays as a tuple.
def kernel(A: float32[10], B: float32[10]) -> (float32[10], float32[10]):
    # Output buffers, initialized to zero.
    C: float32[10] = 0
    D: float32[10] = 0
    for i in range(10):
        # Unpack the callee's two results into the matching output elements.
        C[i], D[i] = callee(A[i], B[i])
    return C, D
# Customize the kernel, print its MLIR module (shown below), and build it.
s = allo.customize(kernel)
print(s.module)
mod = s.build()
# Random float32 inputs of length 10, matching the kernel's argument types.
np_A = np.random.random((10,)).astype(np.float32)
np_B = np.random.random((10,)).astype(np.float32)
# The built module returns the kernel's two outputs as a tuple.
np_C, np_D = mod(np_A, np_B)
# Compute reference results by calling `callee` directly as a plain Python function.
np_C_ref = np.zeros((10,), dtype=np.float32)
np_D_ref = np.zeros((10,), dtype=np.float32)
for i in range(10):
    np_C_ref[i], np_D_ref[i] = callee(np_A[i], np_B[i])
np.testing.assert_allclose(np_C, np_C_ref)
np.testing.assert_allclose(np_D, np_D_ref)
module {
func.func @callee(%arg0: f32, %arg1: f32) -> (f32, f32) attributes {itypes = "__", otypes = "__"} {
%0 = arith.addf %arg0, %arg1 : f32
%alloc = memref.alloc() {name = "c"} : memref<f32>
affine.store %0, %alloc[] {to = "c"} : memref<f32>
%1 = arith.subf %arg0, %arg1 : f32
%alloc_0 = memref.alloc() {name = "d"} : memref<f32>
affine.store %1, %alloc_0[] {to = "d"} : memref<f32>
%2 = affine.load %alloc[] {from = "c"} : memref<f32>
%3 = affine.load %alloc_0[] {from = "d"} : memref<f32>
return %2, %3 : f32, f32
}
func.func @kernel(%arg0: memref<10xf32>, %arg1: memref<10xf32>) -> (memref<10xf32>, memref<10xf32>) attributes {itypes = "__", otypes = "__"} {
%c0_i32 = arith.constant 0 : i32
%c0_i32_0 = arith.constant 0 : i32
%0 = arith.sitofp %c0_i32_0 : i32 to f32
%alloc = memref.alloc() {name = "C"} : memref<10xf32>
linalg.fill ins(%0 : f32) outs(%alloc : memref<10xf32>)
%c0_i32_1 = arith.constant 0 : i32
%c0_i32_2 = arith.constant 0 : i32
%1 = arith.sitofp %c0_i32_2 : i32 to f32
%alloc_3 = memref.alloc() {name = "D"} : memref<10xf32>
linalg.fill ins(%1 : f32) outs(%alloc_3 : memref<10xf32>)
affine.for %arg2 = 0 to 10 {
%2 = affine.load %arg0[%arg2] {from = "A"} : memref<10xf32>
%3 = affine.load %arg1[%arg2] {from = "B"} : memref<10xf32>
%4:2 = func.call @callee(%2, %3) : (f32, f32) -> (f32, f32)
affine.store %4#0, %alloc[%arg2] {to = "C"} : memref<10xf32>
affine.store %4#1, %alloc_3[%arg2] {to = "D"} : memref<10xf32>
} {loop_name = "i", op_name = "S_i_0"}
return %alloc, %alloc_3 : memref<10xf32>, memref<10xf32>
}
}
Compile-Time Constant Expressions (ConstExpr)¶
ConstExpr allows you to declare variables that are evaluated at Python
level during compilation. This enables using Python helper functions for
compile-time computations like computing coefficients or lookup indices.
from allo.ir.types import ConstExpr
# Python helper functions - evaluated at compile time
def compute_coefficient(i):
    """Compute a coefficient based on index.

    Returns ``cos(2 * pi * i / 8)`` as a Python float.
    """
    import math

    return math.cos(2.0 * math.pi * i / 8)
def compute_index(i, offset):
    """Compute a transformed index.

    Returns ``i`` shifted by ``offset``, wrapped modulo 8.
    """
    return (i + offset) % 8
# Scale a rotated copy of A by per-index coefficients, with both the source
# index and the coefficient computed by Python helpers at compile time.
def kernel_with_constexpr(A: float32[8], B: float32[8]):
    with allo.meta_for(8) as i:
        # ConstExpr values are computed at Python level during compilation
        coef: ConstExpr[float32] = compute_coefficient(i)
        idx: ConstExpr[int32] = compute_index(i, 3)
        # Use the compile-time constants in expressions
        B[i] = A[idx] * coef
# Customize the kernel and print its MLIR module (shown below), where the
# coefficients and indices appear as constants.
s = allo.customize(kernel_with_constexpr)
print(s.module)
module {
func.func @kernel_with_constexpr(%arg0: memref<8xf32>, %arg1: memref<8xf32>) attributes {itypes = "__", otypes = ""} {
%0 = affine.load %arg0[3] {from = "A"} : memref<8xf32>
%cst = arith.constant 1.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
%1 = arith.mulf %0, %cst_0 : f32
affine.store %1, %arg1[0] {to = "B"} : memref<8xf32>
%2 = affine.load %arg0[4] {from = "A"} : memref<8xf32>
%cst_1 = arith.constant 0.707106769 : f32
%cst_2 = arith.constant 0.707106769 : f32
%3 = arith.mulf %2, %cst_2 : f32
affine.store %3, %arg1[1] {to = "B"} : memref<8xf32>
%4 = affine.load %arg0[5] {from = "A"} : memref<8xf32>
%cst_3 = arith.constant 6.12323426E-17 : f32
%cst_4 = arith.constant 6.12323426E-17 : f32
%5 = arith.mulf %4, %cst_4 : f32
affine.store %5, %arg1[2] {to = "B"} : memref<8xf32>
%6 = affine.load %arg0[6] {from = "A"} : memref<8xf32>
%cst_5 = arith.constant -0.707106769 : f32
%cst_6 = arith.constant -0.707106769 : f32
%7 = arith.mulf %6, %cst_6 : f32
affine.store %7, %arg1[3] {to = "B"} : memref<8xf32>
%8 = affine.load %arg0[7] {from = "A"} : memref<8xf32>
%cst_7 = arith.constant -1.000000e+00 : f32
%cst_8 = arith.constant -1.000000e+00 : f32
%9 = arith.mulf %8, %cst_8 : f32
affine.store %9, %arg1[4] {to = "B"} : memref<8xf32>
%10 = affine.load %arg0[0] {from = "A"} : memref<8xf32>
%cst_9 = arith.constant -0.707106769 : f32
%cst_10 = arith.constant -0.707106769 : f32
%11 = arith.mulf %10, %cst_10 : f32
affine.store %11, %arg1[5] {to = "B"} : memref<8xf32>
%12 = affine.load %arg0[1] {from = "A"} : memref<8xf32>
%cst_11 = arith.constant -1.83697015E-16 : f32
%cst_12 = arith.constant -1.83697015E-16 : f32
%13 = arith.mulf %12, %cst_12 : f32
affine.store %13, %arg1[6] {to = "B"} : memref<8xf32>
%14 = affine.load %arg0[2] {from = "A"} : memref<8xf32>
%cst_13 = arith.constant 0.707106769 : f32
%cst_14 = arith.constant 0.707106769 : f32
%15 = arith.mulf %14, %cst_14 : f32
affine.store %15, %arg1[7] {to = "B"} : memref<8xf32>
return
}
}
From the generated MLIR, you can see that the coef and idx values
are embedded as constants in the code, not computed at runtime.
Key points about ConstExpr:
The RHS is evaluated at Python level, not compiled as Allo code
You can use arbitrary Python functions (math, etc.)
Use ConstExpr[int32] or ConstExpr[float32] for the type annotation
Total running time of the script: (0 minutes 0.238 seconds)