IR Builder Walkthrough
Author: Hongzheng Chen (hzchen@cs.cornell.edu)
This guide walks you through the process of translating a Python-based Allo program into Allo's internal MLIR representation. We use a matrix addition example to demonstrate the process.
import allo
from allo.ir.types import int32
Algorithm Definition
We define a matrix_add function as follows. The new frontend parses the Python source code and translates it into an MLIR program, so the first step is to parse the Python code into its Abstract Syntax Tree (AST) representation.
M, N = 1024, 1024

def matrix_add(A: int32[M, N]) -> int32[M, N]:
    B: int32[M, N] = 0
    for i, j in allo.grid(M, N):
        B[i, j] = A[i, j] + 1
    return B
Python has a rich set of tools that support reflection. One of the most useful is the inspect module, which provides an API to access the source code of a Python function. We can call inspect.getsource to get the source code of matrix_add.
import inspect
src = inspect.getsource(matrix_add)
print(src)
def matrix_add(A: int32[M, N]) -> int32[M, N]:
    B: int32[M, N] = 0
    for i, j in allo.grid(M, N):
        B[i, j] = A[i, j] + 1
    return B
After obtaining the source code as a string, we can use the ast module to parse it into an AST. The third-party astpretty module prints the AST in a human-readable format, but it must be installed separately through pip. Otherwise, you can use the built-in ast.dump to print the AST in a raw format.
import ast, astpretty
tree = ast.parse(src)
astpretty.pprint(tree, indent=2, show_offsets=False)
Module(
  body=[
    FunctionDef(
      name='matrix_add',
      args=arguments(
        posonlyargs=[],
        args=[
          arg(
            arg='A',
            annotation=Subscript(
              value=Name(id='int32', ctx=Load()),
              slice=Tuple(
                elts=[
                  Name(id='M', ctx=Load()),
                  Name(id='N', ctx=Load()),
                ],
                ctx=Load(),
              ),
              ctx=Load(),
            ),
            type_comment=None,
          ),
        ],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
      ),
      body=[
        AnnAssign(
          target=Name(id='B', ctx=Store()),
          annotation=Subscript(
            value=Name(id='int32', ctx=Load()),
            slice=Tuple(
              elts=[
                Name(id='M', ctx=Load()),
                Name(id='N', ctx=Load()),
              ],
              ctx=Load(),
            ),
            ctx=Load(),
          ),
          value=Constant(value=0, kind=None),
          simple=1,
        ),
        For(
          target=Tuple(
            elts=[
              Name(id='i', ctx=Store()),
              Name(id='j', ctx=Store()),
            ],
            ctx=Store(),
          ),
          iter=Call(
            func=Attribute(
              value=Name(id='allo', ctx=Load()),
              attr='grid',
              ctx=Load(),
            ),
            args=[
              Name(id='M', ctx=Load()),
              Name(id='N', ctx=Load()),
            ],
            keywords=[],
          ),
          body=[
            Assign(
              targets=[
                Subscript(
                  value=Name(id='B', ctx=Load()),
                  slice=Tuple(
                    elts=[
                      Name(id='i', ctx=Load()),
                      Name(id='j', ctx=Load()),
                    ],
                    ctx=Load(),
                  ),
                  ctx=Store(),
                ),
              ],
              value=BinOp(
                left=Subscript(
                  value=Name(id='A', ctx=Load()),
                  slice=Tuple(
                    elts=[
                      Name(id='i', ctx=Load()),
                      Name(id='j', ctx=Load()),
                    ],
                    ctx=Load(),
                  ),
                  ctx=Load(),
                ),
                op=Add(),
                right=Constant(value=1, kind=None),
              ),
              type_comment=None,
            ),
          ],
          orelse=[],
          type_comment=None,
        ),
        Return(
          value=Name(id='B', ctx=Load()),
        ),
      ],
      decorator_list=[],
      returns=Subscript(
        value=Name(id='int32', ctx=Load()),
        slice=Tuple(
          elts=[
            Name(id='M', ctx=Load()),
            Name(id='N', ctx=Load()),
          ],
          ctx=Load(),
        ),
        ctx=Load(),
      ),
      type_comment=None,
      type_params=[],
    ),
  ],
  type_ignores=[],
)
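If astpretty is not installed, the built-in ast.dump can also produce indented output on Python 3.9 and later through its indent parameter. The minimal sketch below parses a single statement from the loop body instead of the whole function, so it runs without inspect or Allo:

```python
import ast

# Parse a single statement taken from the loop body of matrix_add.
tree = ast.parse("B[i, j] = A[i, j] + 1")

# The indent argument (Python 3.9+) pretty-prints nested nodes, similar to astpretty.
dumped = ast.dump(tree, indent=2)
print(dumped)
```

The output has the same nested shape as the astpretty dump above, with the Assign node containing a BinOp whose op is Add().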
The AST is a tree that represents the syntactic structure of the source code. Each node represents a syntactic construct, such as a statement, an expression, or a type annotation. For example, the FunctionDef node represents a function definition, and the AnnAssign node represents an annotated assignment statement.
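To see which constructs make up matrix_add, you can walk the tree and count the node types. This sketch uses only the standard library. ast.parse checks syntax only, so the undefined names int32, M, N, and allo do not matter here:

```python
import ast
from collections import Counter

src = '''
def matrix_add(A: int32[M, N]) -> int32[M, N]:
    B: int32[M, N] = 0
    for i, j in allo.grid(M, N):
        B[i, j] = A[i, j] + 1
    return B
'''

tree = ast.parse(src)
# ast.walk yields every node in the tree (in no specified order).
counts = Counter(type(node).__name__ for node in ast.walk(tree))
for name in ("FunctionDef", "AnnAssign", "For", "Assign", "Return", "Subscript"):
    print(name, counts[name])
```

There are five Subscript nodes: three type annotations (argument, return, and B) and two tensor accesses (B[i, j] and A[i, j]).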
Note
The steps above are also wrapped in allo.customize, so you can directly call s = allo.customize(matrix_add, verbose=True) to obtain the AST of the function. The entry point of the customize function is located in allo/customize.py.
Traverse the AST
After obtaining the AST, we can traverse the tree node by node to generate the IR. The IR builder is located in allo/ir/builder.py. Essentially, the builder is a dispatcher that maps each AST node to the corresponding IR builder function. For example, the FunctionDef node is mapped to ASTTransformer.build_FunctionDef.
All the builder functions are staticmethods that take two arguments: an AST context and an AST node.
The AST context stores the information needed to build the IR, including:
- ip_stack: The stack of insertion points. An insertion point denotes the current position of the IR builder. For example, when we are building the body of a function, the insertion point is the function body.
- buffers: The dictionary that stores all the tensors in the program.
- induction_vars: The list of loop iterators, e.g., i, j, k.
- global_vars: The global variables defined outside the user-defined function.
- top_func: The top-level function of the current program.
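These fields can be pictured as a small Python class. The sketch below is hypothetical: the class name ToyASTContext and its methods are illustrative, not Allo's actual implementation. Plain strings stand in for MLIR insertion points.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyASTContext:
    # Hypothetical stand-in for Allo's AST context; not the real class.
    global_vars: dict                                   # names defined outside the function, e.g. M, N
    ip_stack: list = field(default_factory=list)        # insertion points, innermost last
    buffers: dict = field(default_factory=dict)         # tensor name -> IR value
    induction_vars: list = field(default_factory=list)  # loop iterators such as i, j, k
    top_func: Any = None                                 # top-level function operation

    def get_ip(self):
        # The current insertion point is the top of the stack.
        return self.ip_stack[-1]

    def push_ip(self, ip):
        self.ip_stack.append(ip)

    def pop_ip(self):
        return self.ip_stack.pop()


ctx = ToyASTContext(global_vars={"M": 1024, "N": 1024})
ctx.push_ip("module body")
ctx.push_ip("matrix_add entry block")
print(ctx.get_ip())  # matrix_add entry block
ctx.pop_ip()
print(ctx.get_ip())  # module body
```

A stack fits this job because nested constructs (a function containing loops containing statements) enter and leave insertion points in last-in, first-out order.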
The first node to traverse is the Module node, which is the root of the AST. The build_Module function does only one thing: it traverses the statements inside the body of the module and recursively calls build_stmt on each of them.
@staticmethod
def build_Module(ctx, node):
    for stmt in node.body:
        build_stmt(ctx, stmt)
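The dispatch itself can be sketched by looking up a method named after the node's class. The names ToyBuilder and its trace list are illustrative assumptions. For brevity the sketch uses instance methods instead of static methods that take a context, and it records strings instead of emitting MLIR operations. The real ASTTransformer in allo/ir/builder.py handles many more node types.

```python
import ast


class ToyBuilder:
    # Hypothetical builder: records which handler each node was dispatched to.
    def __init__(self):
        self.trace = []

    def build_stmt(self, node):
        # Map e.g. ast.FunctionDef -> self.build_FunctionDef.
        handler = getattr(self, "build_" + type(node).__name__, None)
        if handler is None:
            raise NotImplementedError(f"no builder for {type(node).__name__}")
        return handler(node)

    def build_Module(self, node):
        for stmt in node.body:
            self.build_stmt(stmt)

    def build_FunctionDef(self, node):
        self.trace.append(f"FunctionDef {node.name}")
        for stmt in node.body:
            self.build_stmt(stmt)

    def build_Return(self, node):
        self.trace.append("Return")


builder = ToyBuilder()
builder.build_stmt(ast.parse("def f(x):\n    return x"))
print(builder.trace)  # ['FunctionDef f', 'Return']
```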
FunctionDef Node
Next comes the FunctionDef node, which is the function definition. The build_FunctionDef function first creates the input and output data types based on the user's type annotations. Then, it creates a new MLIR function operation by calling
func_op = func_d.FuncOp(name=node.name, type=func_type, ip=ip, loc=loc)
Here, func_d refers to the func dialect of MLIR, and FuncOp is the operation that represents a function in MLIR. Its arguments are explained below:
- name is the name of the function. We directly use the name of the AST FunctionDef node, matrix_add, as the operation name.
- type is the FunctionType that defines the input and output types of the function.
- ip is the insertion point of the function, which is the current insertion point of the AST context. We can obtain it directly by calling ctx.get_ip().
- loc is the source location of the function (e.g., its line number), which can usually be omitted.
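To make the type argument more concrete, the sketch below derives a textual MLIR function type from the annotations of matrix_add. It is only an illustration. It produces a string instead of calling the MLIR Python bindings, and the helper name annotation_to_memref and the DTYPES mapping table are assumptions. It also assumes the Python 3.9+ AST layout, in which the slice of int32[M, N] is a Tuple.

```python
import ast

SRC = "def matrix_add(A: int32[M, N]) -> int32[M, N]:\n    return A\n"
GLOBAL_VARS = {"M": 1024, "N": 1024}
DTYPES = {"int32": "i32", "float32": "f32"}  # assumed mapping for this sketch


def annotation_to_memref(ann, global_vars):
    # ann is a Subscript such as int32[M, N] (Python 3.9+ AST layout assumed).
    dtype = DTYPES[ann.value.id]
    dims = ann.slice.elts if isinstance(ann.slice, ast.Tuple) else [ann.slice]
    # Names such as M are resolved through global_vars; literals are used directly.
    shape = [global_vars[d.id] if isinstance(d, ast.Name) else d.value for d in dims]
    return "memref<" + "x".join(map(str, shape)) + "x" + dtype + ">"


func = ast.parse(SRC).body[0]
inputs = [annotation_to_memref(a.annotation, GLOBAL_VARS) for a in func.args.args]
output = annotation_to_memref(func.returns, GLOBAL_VARS)
func_type = f"({', '.join(inputs)}) -> {output}"
print(func_type)  # (memref<1024x1024xi32>) -> memref<1024x1024xi32>
```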
After creating the function operation, we need to build the function body. We first move the insertion point into the function body by calling ctx.push_ip(func_op.entry_block). Then, we traverse the function body and recursively call build_stmt. The function arguments are added to buffers so that later statements can refer to them.
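This push, build, and pop sequence can be sketched as follows. ToyCtx, its insertion_point context manager, and the string entry block are hypothetical stand-ins. A context manager is just one way to guarantee that every push is matched by a pop; the sketch does not claim Allo is written this way.

```python
import ast
from contextlib import contextmanager


class ToyCtx:
    def __init__(self):
        self.ip_stack = ["module"]
        self.buffers = {}

    @contextmanager
    def insertion_point(self, ip):
        # Hypothetical helper: push on entry, always pop on exit.
        self.ip_stack.append(ip)
        try:
            yield ip
        finally:
            self.ip_stack.pop()


def build_function_body(ctx, node, log):
    entry_block = f"{node.name}.entry"  # stands in for func_op.entry_block
    with ctx.insertion_point(entry_block):
        # Register each argument so later statements can look it up by name.
        for arg in node.args.args:
            ctx.buffers[arg.arg] = f"%{arg.arg}"
        # Record where each body statement would be inserted.
        for stmt in node.body:
            log.append((type(stmt).__name__, ctx.ip_stack[-1]))


ctx, log = ToyCtx(), []
fn = ast.parse("def matrix_add(A):\n    B = A\n    return B").body[0]
build_function_body(ctx, fn, log)
print(ctx.buffers)   # {'A': '%A'}
print(log)           # [('Assign', 'matrix_add.entry'), ('Return', 'matrix_add.entry')]
print(ctx.ip_stack)  # ['module']
```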
Note
You may notice the MockArg class. This mock class stores the function arguments, which are BlockArguments in MLIR. Unlike operations, which inherently have a result attribute, a BlockArgument has none. Therefore, we wrap the BlockArgument in MockArg and provide a result property method, which makes it consistent with other operations.
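A minimal sketch of this idea follows. FakeBlockArgument and FakeOp are hypothetical stand-ins for an MLIR BlockArgument and an ordinary operation, and this MockArg is illustrative rather than Allo's actual code. Once the argument is wrapped, every entry in buffers can be read the same way through .result.

```python
class FakeBlockArgument:
    # Stand-in for an MLIR BlockArgument: a value with no .result attribute.
    def __init__(self, name):
        self.name = name


class FakeOp:
    # Stand-in for an ordinary operation, which exposes its value as .result.
    def __init__(self, name):
        self.result = name


class MockArg:
    # Illustrative wrapper in the spirit of Allo's MockArg (not its actual code).
    def __init__(self, arg):
        self.arg = arg

    @property
    def result(self):
        return self.arg


buffers = {"A": MockArg(FakeBlockArgument("%arg0")), "B": FakeOp("%alloc")}
# Arguments and operations are now accessed uniformly via .result.
values = {name: buf.result for name, buf in buffers.items()}
print(values["A"].name, values["B"])  # %arg0 %alloc
```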
AnnAssign Node
Next, let's visit the AnnAssign node, which is the annotated assignment statement. The build_AnnAssign function first evaluates the right-hand side of the assignment by calling rhs = build_stmt(ctx, node.value). Then, it reads the user-defined type annotation to generate the correct data type for the tensor. Refer to the MLIR memref dialect for more details. It then calls memref_d.AllocOp to create a new memory allocation, which appears as a memref.alloc operation in the generated MLIR code.
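The sketch below shows the textual form of the allocation. It turns a resolved shape and element type into the kind of memref.alloc line you would see in the generated MLIR. It only formats strings, and the helper names and the SSA name %B are arbitrary choices. In Allo the operation is created through memref_d.AllocOp instead.

```python
def memref_type(shape, dtype):
    # e.g. (1024, 1024), "i32" -> "memref<1024x1024xi32>"; () -> "memref<i32>"
    return "memref<" + "".join(f"{d}x" for d in shape) + dtype + ">"


def emit_alloc(name, shape, dtype):
    # Textual form of memref.alloc for a statically shaped buffer.
    return f"%{name} = memref.alloc() : {memref_type(shape, dtype)}"


line = emit_alloc("B", (1024, 1024), "i32")
print(line)  # %B = memref.alloc() : memref<1024x1024xi32>
```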
Note that identifiers inside the AST are just strings. To get the actual value of a name such as M, we need to look it up in the ctx.global_vars dictionary.
For example, on Python 3.8 and earlier, int32[M, N] generates the following annotation slice:
slice=Index(
  value=Tuple(
    elts=[
      Name(id='M', ctx=Load()),
      Name(id='N', ctx=Load()),
    ],
    ctx=Load(),
  ),
)
Starting from Python 3.9, the Index wrapper is no longer generated, and slice is the Tuple itself, as in the AST dump above.
We can see that M and N are just Name nodes, so we need to retrieve their actual values from the ctx.global_vars dictionary by calling something like ctx.global_vars[node.slice.value.elts[0].id] (ctx.global_vars[node.slice.elts[0].id] on Python 3.9 and later), where node is the Subscript annotation.
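The lookup can be wrapped in a small helper. The sketch below assumes the global variables are available as a plain dict, and resolve_shape is a hypothetical name. It accepts both the Index-wrapped layout of Python 3.8 and earlier and the Python 3.9+ layout, and it also handles literal dimensions such as int32[16].

```python
import ast


def resolve_shape(annotation, global_vars):
    # annotation is the Subscript node for e.g. int32[M, N].
    slc = annotation.slice
    # Python 3.8 and earlier wrap the index in an Index node; unwrap it.
    if type(slc).__name__ == "Index":
        slc = slc.value
    dims = slc.elts if isinstance(slc, ast.Tuple) else [slc]
    shape = []
    for dim in dims:
        if isinstance(dim, ast.Name):
            shape.append(global_vars[dim.id])  # look the string name up
        elif isinstance(dim, ast.Constant):
            shape.append(dim.value)            # literal such as int32[16]
        else:
            raise TypeError(f"unsupported dimension: {ast.dump(dim)}")
    return tuple(shape)


stmt = ast.parse("B: int32[M, N] = 0").body[0]
print(resolve_shape(stmt.annotation, {"M": 1024, "N": 1024}))  # (1024, 1024)
```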
For Node
The next statement is the For node, which represents a for loop. Allo provides different APIs for different loop structures, so the For node is further dispatched to the corresponding builder function. For example, here we use allo.grid, so the node is dispatched to build_grid_for.
Some helper functions in allo/ir/transform.py make IR creation easier. In this case, we can simply call build_for_loops and pass in the bounds and the names of the loops to create a loop nest.
Before building the loop body, we need to update the insertion point so that new operations are placed inside the innermost loop:
ctx.set_ip(for_loops[-1].body.operations[0])
After calling build_stmts(ctx, node.body), we also need to restore the previous insertion point:
ctx.pop_ip()
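The nesting that a loop-nest helper produces can be sketched as text. The helper name toy_build_for_loops, the affine.for syntax, and the textual output are assumptions for illustration. They do not describe the signature or output of the helper in allo/ir/transform.py. The first bound becomes the outermost loop, so the last loop created is the innermost one, which is why for_loops[-1] is where the body goes.

```python
def toy_build_for_loops(bounds, names):
    # Returns MLIR-like text for a perfect loop nest; the innermost body is a placeholder.
    lines = []
    for depth, (ub, name) in enumerate(zip(bounds, names)):
        lines.append("  " * depth + f"affine.for %{name} = 0 to {ub} {{")
    lines.append("  " * len(bounds) + "// loop body is inserted here")
    for depth in reversed(range(len(bounds))):
        lines.append("  " * depth + "}")
    return "\n".join(lines)


text = toy_build_for_loops([1024, 1024], ["i", "j"])
print(text)
```

For matrix_add this prints two nested loops over %i and %j, each running from 0 to 1024, with the placeholder body inside the inner loop.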
Other Nodes
The build process is similar for the other nodes, so we do not go through them one by one. Please refer to the source code for more details.
After building the IR, you can print s.module to inspect the generated MLIR code.
Most MLIR operations are described in the official MLIR dialect documentation. You can follow their definitions there to add new features to the new Allo compiler!