Source code for allo.customize

# Copyright Allo authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-name-in-module

import re
import inspect
import textwrap
import copy
from dataclasses import dataclass
from functools import wraps
from types import FunctionType as PyFunctionType
from typing import Union
from collections.abc import Callable

from ._mlir.ir import (
    Context,
    Location,
    InsertionPoint,
    StringAttr,
    UnitAttr,
    IndexType,
    IntegerType,
    IntegerAttr,
    TypeAttr,
    F32Type,
    MemRefType,
    FlatSymbolRefAttr,
    AffineMap,
    AffineMapAttr,
)
from ._mlir.dialects import (
    allo as allo_d,
    memref as memref_d,
    affine as affine_d,
    scf as scf_d,
    arith as arith_d,
    func as func_d,
)
from ._mlir.dialects.affine import (
    AffineExpr,
    AffineDimExpr,
)
from ._mlir.exceptions import (
    AlloValueError,
)

from . import primitives as prim
from .ir.visitor import ASTContext
from .ir.utils import MockArg, MockBuffer, parse_ast, get_global_vars
from .ir.builder import ASTTransformer
from .ir.infer import TypeInferer
from .ir.transform import (
    get_affine_loop_nests,
    find_loop_in_bands,
    find_buffer,
    find_func_in_module,
    LoopWrapper,
)
from .ir.use_def import UseDefChain
from .passes import (
    _mlir_lower_pipeline,
    lower_linalg_and_attach_names,
)
from .backend.llvm import LLVMModule
from .backend.hls import HLSModule
from .library import KERNEL2SCHEDULE


def getsourcefile(obj):
    ret = inspect.getsourcefile(obj)
    if ret is None:
        ret = inspect.getfile(obj)
    return ret


def getsourcelines(obj):
    return inspect.getsourcelines(obj)


def wrapped_apply(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        sch = args[0]
        with sch.module.context, Location.unknown():
            res = fn(*args, **kwargs)
        _mlir_lower_pipeline(sch.module)
        # Remove previous Python-C++ references
        sch.module.context._clear_live_operations()
        # Update top function in the current context
        for op in sch.module.body.operations:
            if isinstance(op, func_d.FuncOp) and op.name.value == sch.top_func_name:
                sch.top_func = op
                break
        else:
            raise RuntimeError("Top function not found")
        # Update insertion point
        sch.ip = InsertionPoint.at_block_terminator(sch.top_func.entry_block)
        # Record primitive sequences
        sch.primitive_sequences.append((fn.__name__, list(args[1:]), kwargs))
        return res

    return wrapper


@dataclass
class Partition:
    Complete = 0
    Block = 1
    Cyclic = 2


class Schedule:
    def __init__(
        self,
        module,
        top_func,
        func_args,
        ip,
        ext_libs=None,
        use_def_chain=None,
        inst_list=None,
    ):
        self.module = module
        self.top_func = top_func
        self.top_func_name = top_func.name.value
        self.func_args = func_args  # only store names here
        self.ip = ip
        self.primitive_sequences = []
        if ext_libs is None:
            ext_libs = []
        self.ext_libs = ext_libs
        self.use_def_chain = use_def_chain
        self.partitioned_arrays = {}
        self.inst_list = inst_list if inst_list is not None else []

    def get_loops(self, func=None):
        if isinstance(func, str):
            func = self._find_function(func)
        if func is None:
            func = self.top_func
        return get_affine_loop_nests(func)

    def _find_band(self, band_name, func=None):
        loops = self.get_loops(func)
        if band_name in loops.loops:
            return loops[band_name]
        raise RuntimeError(f"Band {band_name} not found")

    def _find_function(self, name, error=True):
        for func in self.module.body.operations:
            if isinstance(func, func_d.FuncOp) and func.name.value == name:
                return func
        if error:
            raise RuntimeError(f"Function {name} not found")
        return None

    def _get_func_and_axis(self, axis):
        if isinstance(axis, LoopWrapper):
            func = self._find_function(axis.func)
            return func, axis
        if ":" in axis:
            func_name, axis = axis.split(":")
        else:
            func_name = self.top_func_name
        func = self._find_function(func_name)
        return func, axis
    @wrapped_apply
    def split(self, axis, factor):
        """
        `split` finds the loop with loop index `axis` and tiles it with tile
        size `factor`. The new inner loop is named `axis.inner` and the outer
        loop is named `axis.outer`.

        Parameters
        ----------
        axis: str
            The name of an index in the kernel.
        factor: int
            The size of each tile, i.e. the trip count of the inner nested loop.
        """
        func, axis = self._get_func_and_axis(axis)
        band_name, axis = find_loop_in_bands(func, axis)
        ip = InsertionPoint.at_block_terminator(func.entry_block)
        op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip)
        loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip)
        i32 = IntegerType.get_unsigned(32)
        factor = IntegerAttr.get(i32, factor)
        allo_d.SplitOp(loop_hdl.result, factor, ip=ip)
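    # Illustrative usage of `split` (a sketch; the kernel `gemm` and the loop
    # index "i" are assumptions, not defined in this module):
    #
    #     s = customize(gemm)
    #     s.split("i", factor=8)  # creates loops named "i.outer" and "i.inner"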
    @wrapped_apply
    def reorder(self, *args):
        """
        Reorders nested loops with the indices listed in `args` so that the
        outermost loop is the first index listed, the second outermost is the
        second index listed, and so on. This function is variadic, accepting
        each index as a separate argument.
        """
        func, axis = self._get_func_and_axis(args[0])
        band_name, _ = find_loop_in_bands(func, axis)
        ip = InsertionPoint.at_block_terminator(func.entry_block)
        op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip)
        loop_hdls = []
        for arg in args:
            func, axis = self._get_func_and_axis(arg)
            band_name, axis = find_loop_in_bands(func, axis)
            loop_hdls.append(
                allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip)
            )
        arg_results = [arg.result for arg in loop_hdls]
        allo_d.ReorderOp(arg_results, ip=ip)
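    # Illustrative usage of `reorder`, continuing the hypothetical `gemm`
    # sketch above (loop names are assumptions):
    #
    #     s.split("i", factor=8)
    #     s.reorder("i.outer", "j", "i.inner")  # "i.outer" becomes outermost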
    @wrapped_apply
    def unroll(self, axis, factor=0):
        """
        Unrolls a loop with loop index `axis` by `factor`.

        Parameters
        ----------
        axis: str
            The name of an index in the kernel.
        factor: int
            The factor to unroll by; for example, a factor of 2 duplicates
            the loop body once.
        """
        func, axis = self._get_func_and_axis(axis)
        band_name, axis = find_loop_in_bands(func, axis)
        ip = InsertionPoint.at_block_terminator(func.entry_block)
        op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip)
        loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip)
        i32 = IntegerType.get_unsigned(32)
        factor = IntegerAttr.get(i32, factor)
        allo_d.UnrollOp(loop_hdl.result, factor=factor, ip=ip)
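    # Illustrative usage of `unroll` (assumed loop index "j"):
    #
    #     s.unroll("j", factor=4)  # attaches an unroll factor of 4 to loop "j"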
    @wrapped_apply
    def fuse(self, *args):
        """
        Combines the loops with the indices listed in `args` into a single
        loop over a single index. This function is variadic, accepting each
        index as a separate argument.
        """
        func, axis = self._get_func_and_axis(args[0])
        band_name, _ = find_loop_in_bands(func, axis)
        ip = InsertionPoint.at_block_terminator(func.entry_block)
        op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip)
        loop_hdls = []
        for arg in args:
            func, axis = self._get_func_and_axis(arg)
            band_name, axis = find_loop_in_bands(func, axis)
            loop_hdls.append(
                allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip)
            )
        arg_results = [arg.result for arg in loop_hdls]
        allo_d.FuseOp(arg_results, ip=ip)
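    # Illustrative usage of `fuse` (assumed loop indices "i" and "j"):
    #
    #     s.fuse("i", "j")  # collapses the i/j nest into a single loop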
    @wrapped_apply
    def partition(self, target, partition_type=Partition.Complete, dim=0, factor=0):
        """
        Partitions a given array; for example, if the array is `B`, the target
        would be `<schedule>.B`. There are three partition types:
        `Partition.Complete`, `Partition.Block`, and `Partition.Cyclic`.

        block: The original array is split into `factor` equally sized blocks
        of consecutive elements of the original array.
        cyclic: The original array is split into `factor` equally sized blocks,
        interleaving the elements of the original array.
        complete: The original array is split into its individual elements.
        This corresponds to resolving a memory into registers.

        Parameters
        ----------
        target: allo.ir.utils.MockBuffer
            The array to partition.
        partition_type: allo.customize.Partition
            The type of partition.
        dim: int
            The dimension of `target` to partition. If `dim=0`, all dimensions
            are partitioned.
        factor: int
            The number of arrays created by a block or cyclic partition.
        """
        # TODO: test whether the partition has conflicts for different functions
        if partition_type > 2:
            raise AlloValueError("Invalid partition type")
        if dim < 0:
            raise AlloValueError("Invalid dimension")
        if factor < 0:
            raise AlloValueError("Invalid factor")
        if partition_type == Partition.Complete:
            partition_type = 0
        elif partition_type == Partition.Block:
            partition_type = 1
        elif partition_type == Partition.Cyclic:
            partition_type = 2
        else:
            raise AlloValueError("Not supported partition type")
        # test whether partitioning the same array
        for parray, items in self.partitioned_arrays.items():
            for item in items:
                if (
                    parray.split(":")[0] == target.func
                    and parray.split(":")[1] == target.name
                ):
                    if item[0] == Partition.Complete and item[1] == 0:
                        # this array has been completely partitioned along all the axes
                        return
                    raise AlloValueError(
                        f"Cannot partition the same array twice: {parray}, {item} vs ({partition_type}, {dim}, {factor})"
                    )
        # actual partition
        i32 = IntegerType.get_signless(32)
        ui32 = IntegerType.get_unsigned(32)
        # find all the tensors that need to be partitioned
        visited_target_names = []
        visited_func_calls = []

        def recursive_partition(inner_target):
            name = f"{inner_target.func}:{inner_target.name}"
            if name in visited_target_names:
                return
            visited_target_names.append(name)
            _, _, mlir_target = find_buffer(self.module, inner_target, self.func_args)
            # equivalent users
            for tensor in self.use_def_chain.get_equivalent_tensors(name):
                recursive_partition(MockBuffer(tensor.path, tensor.name))
            # calling the same function
            if isinstance(mlir_target, func_d.CallOp):
                visited_func_calls.append(mlir_target)
                for func in self.module.body.operations:
                    if isinstance(func, func_d.FuncOp):
                        for call_op in func.entry_block.operations:
                            if (
                                isinstance(call_op, func_d.CallOp)
                                and mlir_target.attributes["callee"]
                                == call_op.attributes["callee"]
                                and call_op not in visited_func_calls
                            ):
                                visited_func_calls.append(call_op)
                                buffer = MockBuffer(
                                    func.attributes["sym_name"].value,
                                    call_op.attributes["name"].value,
                                )
                                recursive_partition(buffer)

        recursive_partition(target)
        for inner_target in visited_target_names:
            func, _, mlir_target = find_buffer(
                self.module,
                MockBuffer(inner_target.split(":")[0], inner_target.split(":")[1]),
                self.func_args,
            )
            if inner_target not in self.partitioned_arrays:
                self.partitioned_arrays[inner_target] = [(partition_type, dim, factor)]
            else:
                self.partitioned_arrays[inner_target].append(
                    (partition_type, dim, factor)
                )
            allo_d.PartitionOp(
                mlir_target.result,
                partition_kind=IntegerAttr.get(i32, partition_type),
                dim=IntegerAttr.get(ui32, dim),
                factor=IntegerAttr.get(ui32, factor),
                ip=InsertionPoint.at_block_terminator(func.entry_block),
            )
        # Calculate layout map
        # first N: partition index
        # last N : physical index
        shape = mlir_target.result.type.shape
        partition_idx = []
        address_idx = []
        for i, _ in enumerate(shape):
            if dim == 0 or (dim > 0 and i == dim - 1):
                if partition_type == Partition.Cyclic:
                    partition_idx.append(AffineDimExpr.get(i) % factor)
                    address_idx.append(
                        AffineExpr.get_floor_div(AffineDimExpr.get(i), factor)
                    )
                elif partition_type == Partition.Block:
                    # block factor N means partition into N blocks
                    # each block has shape[dim] / factor elements
                    block_factor = (shape[i] + factor - 1) // factor
                    partition_idx.append(
                        AffineExpr.get_floor_div(AffineDimExpr.get(i), block_factor)
                    )
                    address_idx.append(AffineDimExpr.get(i) % block_factor)
                else:  # Partition.Complete
                    partition_idx.append(AffineDimExpr.get(i))
                    address_idx.append(AffineExpr.get_constant(0))
            else:
                partition_idx.append(AffineExpr.get_constant(0))
                address_idx.append(AffineDimExpr.get(i))
        affine_map = AffineMap.get(
            dim_count=len(shape), symbol_count=0, exprs=partition_idx + address_idx
        )
        affine_attr = AffineMapAttr.get(affine_map)
        only_target_names = [item.split(":")[-1] for item in visited_target_names]
        for op in self.module.body.operations:
            if (
                isinstance(op, memref_d.GlobalOp)
                and op.attributes["sym_name"].value in only_target_names
            ):
                op.attributes["type"] = TypeAttr.get(
                    MemRefType.get(
                        op.attributes["type"].value.shape,
                        op.attributes["type"].value.element_type,
                        affine_attr,
                        op.attributes["type"].value.memory_space,
                    )
                )
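    # Illustrative usage of `partition` (the buffer attributes `s.A` and `s.B`
    # are assumed MockBuffers attached by `customize`):
    #
    #     s.partition(s.A)                                     # complete, all dims
    #     s.partition(s.B, Partition.Cyclic, dim=2, factor=4)  # 4-way cyclic on dim 2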
    @wrapped_apply
    def buffer_at(self, target, axis):
        """
        Creates an on-chip buffer to hold the values of `target` written to in
        the loop with index `axis`, instead of immediately writing them to
        memory.

        Parameters
        ----------
        target: allo.ir.utils.MockBuffer
            An array written to in a loop.
        axis: str
            The loop index whose body contains writes to `target`.
        """
        _, _, target = find_buffer(self.module, target, self.func_args)
        func, axis = self._get_func_and_axis(axis)
        band_name, axis = find_loop_in_bands(func, axis)
        ip = InsertionPoint.at_block_terminator(func.entry_block)
        op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip)
        loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip)
        memref_type = MemRefType.get((1,), F32Type.get())
        allo_d.BufferAtOp(memref_type, target.result, loop_hdl.result, ip=ip)
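    # Illustrative usage of `buffer_at` (assumed buffer `s.C` written inside a
    # loop with index "i"):
    #
    #     s.buffer_at(s.C, axis="i")  # accumulate on-chip, write back after "i"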
    @wrapped_apply
    def reshape(self, target, shape):
        """
        Takes an array `target` in the kernel (e.g. `<schedule>.B` for an
        array `B`) and reshapes it to the tuple `shape`. For example, if the
        desired shape is 32 by 4 by 8, `shape` would be `(32, 4, 8)`.

        Parameters
        ----------
        target: allo.ir.utils.MockBuffer
            The array, represented by a memory, to reshape.
        shape: tuple
            The new shape of the memory.
        """
        _, _, target = find_buffer(self.module, target, self.func_args)
        eletype = MemRefType(target.result.type).element_type
        memref_type = MemRefType.get(shape, eletype)
        allo_d.ReshapeOp(memref_type, target.result, ip=self.ip)
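    # Illustrative usage of `reshape` (assumes `s.B` holds 1024 elements):
    #
    #     s.reshape(s.B, (32, 4, 8))  # view the 1024-element memory as 32x4x8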
    @wrapped_apply
    def pipeline(self, axis, initiation_interval=1, rewind=False):
        """
        Pipelines the loop with index `axis`, starting a new iteration every
        `initiation_interval` cycles.

        Parameters
        ----------
        axis: str
            The index of the loop to pipeline.
        initiation_interval: int
            The initiation interval (II) to be used when pipelining.
        rewind: bool
            If true, rewinding is allowed, enabling continuous loop pipelining.
            This is only effective for perfect loop nests inside a top-level
            function.
        """
        i32 = IntegerType.get_unsigned(32)
        ii = IntegerAttr.get(i32, initiation_interval)
        func, axis = self._get_func_and_axis(axis)
        band_name, axis = find_loop_in_bands(func, axis)
        if rewind:
            self.get_loops(func)[band_name][axis].loop.attributes[
                "rewind"
            ] = UnitAttr.get()
        self.get_loops(func)[band_name][axis].loop.attributes["pipeline_ii"] = ii
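    # Illustrative usage of `pipeline` (assumed loop index "j"):
    #
    #     s.pipeline("j", initiation_interval=1)  # launch a new "j" iteration every cycle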
    @wrapped_apply
    def parallel(self, axis):
        """
        Marks the loop with index `axis` to be computed in parallel with the
        loops it is nested in.

        Parameters
        ----------
        axis: str
            The index of the loop to be computed in parallel.
        """
        func, axis = self._get_func_and_axis(axis)
        band_name, axis = find_loop_in_bands(func, axis)
        ip = InsertionPoint.at_block_terminator(func.entry_block)
        op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip)
        loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip)
        allo_d.ParallelOp(loop_hdl.result, ip=ip)
    @wrapped_apply
    def inline(self, axis=None):
        """
        Inlines the function named `axis`. If `axis` is None, the top-level
        function is inlined.

        Parameters
        ----------
        axis: str
            The name of the function to inline.
        """
        assert axis is None or isinstance(axis, str), "Function name must be a string"
        if axis is None:
            axis = self.top_func_name
        func = self._find_function(axis)
        func.attributes["inline"] = UnitAttr.get()
    @wrapped_apply
    def dataflow(self, axis):
        """
        Applies a "dataflow" attribute to the function or loop `axis`. This
        allows for parallelism if the given function uses streams or the `to`
        schedule.

        Parameters
        ----------
        axis: str | allo.ir.LoopWrapper
            The function or loop to add the attribute to.
        """
        if isinstance(axis, str):
            # function
            func = self._find_function(axis)
            func.attributes["dataflow"] = UnitAttr.get()
            return
        func, _ = self._get_func_and_axis(axis)
        band_name, loop_name = axis.name.split(".", 1)
        band_name = band_name.split(":")[1]
        cnt = 0

        def locate_loop(op):
            nonlocal cnt
            for ope in op.body.operations:
                if isinstance(ope, (scf_d.ForOp, affine_d.AffineForOp)):
                    locate_loop(ope)
            if (
                "loop_name" in op.attributes
                and op.attributes["loop_name"].value == loop_name
            ):
                cnt += 1
                op.attributes["dataflow"] = UnitAttr.get()

        for op in func.entry_block.operations:
            if isinstance(op, (scf_d.ForOp, affine_d.AffineForOp)):
                if (
                    "op_name" in op.attributes
                    and op.attributes["op_name"].value == band_name
                ):
                    locate_loop(op)
        if cnt == 0:
            raise RuntimeError(f"Dataflow loop {band_name}.{loop_name} not found")
    @wrapped_apply
    def compute_at(self, from_loop, target_loop):
        """
        If `from_loop` and `target_loop` iterate over the same range,
        `<schedule>.compute_at(from_loop, target_loop)` merges the two loops,
        taking the body of `from_loop` and appending it to the body of
        `target_loop`.

        Parameters
        ----------
        from_loop: str
            The loop whose body is being moved.
        target_loop: str
            The loop whose body is being appended to.
        """
        from_band, _ = find_loop_in_bands(self.top_func, from_loop)
        target_band, target_axis = find_loop_in_bands(self.top_func, target_loop)
        from_hdl = allo_d.CreateOpHandleOp(from_band, ip=self.ip)
        target_hdl = allo_d.CreateOpHandleOp(target_band, ip=self.ip)
        loop_hdl = allo_d.CreateLoopHandleOp(
            target_hdl.result, StringAttr.get(target_axis), ip=self.ip
        )
        allo_d.ComputeAtOp(
            from_hdl.result, target_hdl.result, loop_hdl.result, ip=self.ip
        )
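    # Illustrative usage of `compute_at` (a sketch; assumes a producer loop
    # "i0" and a consumer loop "i1" over the same range, both names
    # hypothetical):
    #
    #     s.compute_at("i0", "i1")  # move the producer body into the consumer loop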
    @wrapped_apply
    def reuse_at(self, target, axis):
        """
        Takes an array `target` in a kernel (e.g. `<schedule>.B` for an array
        `B`) accessed by the loop index `axis`, and creates a reuse buffer to
        reuse values from `target` that are accessed in a sequentially moving
        window.

        Parameters
        ----------
        target: allo.ir.utils.MockBuffer
            The array being accessed.
        axis: str
            The loop index used to access values in `target`.
        """
        _, _, target = find_buffer(self.module, target, self.func_args)
        func, axis = self._get_func_and_axis(axis)
        band_name, axis = find_loop_in_bands(func, axis)
        ip = InsertionPoint.at_block_terminator(func.entry_block)
        op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip)
        loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip)
        memref_type = MemRefType.get((1,), F32Type.get())

        def find_reuse_buffers(res):
            for func in self.module.body.operations:
                if isinstance(func, func_d.FuncOp):
                    for op in func.entry_block.operations:
                        if (
                            isinstance(op, memref_d.AllocOp)
                            and "name" in op.attributes
                            and band_name + "_reuse"
                            in StringAttr(op.attributes["name"]).value
                        ):
                            res.append(op)

        prev_reuse_buffers = []
        find_reuse_buffers(prev_reuse_buffers)
        allo_d.ReuseAtOp(memref_type, target.result, loop_hdl.result, ip=ip)
        _mlir_lower_pipeline(self.module)
        new_reuse_buffers = []
        find_reuse_buffers(new_reuse_buffers)
        new_reuse_buffers = [
            buf for buf in new_reuse_buffers if buf not in prev_reuse_buffers
        ]
        # `new_reuse_buffers` already excludes the previously found buffers,
        # so exactly one new buffer should remain
        if len(new_reuse_buffers) != 1:
            raise RuntimeError("Reuse buffer not found")
        return MockBuffer(
            self.top_func_name,
            StringAttr(new_reuse_buffers[-1].attributes["name"]).value,
        )
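    # Illustrative usage of `reuse_at` for a sliding-window access pattern
    # (assumed input buffer `s.A` read as A[i, j], A[i, j+1], A[i, j+2]):
    #
    #     window = s.reuse_at(s.A, "j")  # returns a MockBuffer for the reuse buffer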
    @wrapped_apply
    def to(self, target, dst, axis=None, depth=-1):
        """
        Takes an array `target` in the kernel (e.g. `<schedule>.B` for an
        array `B`) and converts it into a stream. `dst` is the name of the
        array any value of `target` is written to. For example, if
        `C[i, j] = B[i, j]`, `dst` would be specified as `"C"`. If values of
        `target` get written to multiple arrays, multiple calls to
        `<schedule>.to(...)` may be needed.

        Parameters
        ----------
        target: allo.ir.utils.MockBuffer
            The array to convert to a stream.
        dst: str
            An array which a value of `target` is written to.
        axis: str
            Move the axis-th loop body to the xcel scope.
        depth: int
            The streaming channel depth.
        """
        return prim.to(
            self.module, target, dst, axis, depth, self.func_args, self.top_func_name
        )
    @wrapped_apply
    def unfold(self, band_name, axes):
        """
        Finds the set of nested loops named `band_name` and, for every `i` in
        the list `axes`, unfolds the `i`-th nested loop into a constant number
        of copies of its loop body.

        Parameters
        ----------
        band_name: str
            The set of nested loops to unfold.
        axes: list[int]
            A list of the axes to unfold.
        """
        assert isinstance(axes, list), "Axes must be a list"
        axes.sort()
        assert axes == list(
            range(axes[0], axes[0] + len(axes))
        ), "Axes must be consecutive"
        # start from the innermost loop
        if ":" in band_name:
            func = self._find_function(band_name.split(":")[0])
            band_name = band_name.split(":")[1]
        else:
            func = self.top_func
        for axis in axes[::-1]:
            # Need to recompute the loop nests due to the MLIR bug:
            # https://reviews.llvm.org/D101422
            # Otherwise, it may hit invalid operations
            band = self._find_band(band_name, func)
            target_outer = band.get_outer_most()
            loops = list(band)
            op_to_remove = []
            _, loop_wrapper = loops[axis]
            loop = loop_wrapper.loop
            lower_bound = loop.attributes["lowerBoundMap"]
            assert str(lower_bound) == "affine_map<() -> (0)>", "Lower bound must be 0"
            upper_bound = loop.attributes["upperBoundMap"]
            upper_bound = int(
                re.findall(r"affine_map<\(\) -> \(([0-9]*)\)>", str(upper_bound))[0]
            )
            if axis > 0:
                ip = InsertionPoint.at_block_terminator(loops[axis - 1][1].loop.body)
            else:
                ip = InsertionPoint(target_outer)
            for op in loop.body.operations:
                if isinstance(op, affine_d.AffineYieldOp):
                    break

            def update_operand(op, old, new):
                if isinstance(op, affine_d.AffineForOp):
                    # pylint: disable=cell-var-from-loop
                    for in_op in op.body.operations:
                        update_operand(in_op, old, new)
                else:
                    op.operation.replace_uses_of_with(old, new)

            # unfold the body `upper_bound` times
            for idx in range(upper_bound):
                # pylint: disable=too-many-function-args
                cst_op = arith_d.ConstantOp(IndexType.get(), idx, ip=ip)
                # Directly duplicate the loop itself
                # (to preserve a scope for replacing the induction variable),
                # and replace the induction variable with the constant
                new_loop = loop.operation.clone(ip)
                for op in new_loop.body.operations:
                    if isinstance(op, affine_d.AffineYieldOp):
                        break
                    update_operand(op, new_loop.induction_variable, cst_op.result)
                    op.move_before(new_loop)
                    if isinstance(op, affine_d.AffineForOp):
                        new_name = (
                            f"{band_name}_{idx}"
                            if "op_name" not in op.attributes
                            else f"{op.attributes['op_name'].value}_{idx}"
                        )
                        op.attributes["op_name"] = StringAttr.get(new_name)
                    if isinstance(op, func_d.CallOp):
                        # Also need to duplicate the function outside the top function
                        old_func = self._find_function(
                            FlatSymbolRefAttr(op.attributes["callee"]).value
                        )
                        dup_func = old_func.operation.clone(InsertionPoint(func))
                        new_name = (
                            f"{FlatSymbolRefAttr(op.attributes['callee']).value}_{idx}"
                        )
                        dup_func.attributes["sym_name"] = StringAttr.get(new_name)
                        op.attributes["callee"] = FlatSymbolRefAttr.get(new_name)
                        if old_func not in op_to_remove:
                            op_to_remove.append(old_func)
                op_to_remove.append(new_loop)
            # need to erase at the end
            for op in op_to_remove:
                op.operation.erase()
            loop.operation.erase()
        # TODO: use a class to wrap the results
        return axes
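    # Illustrative usage of `unfold` (a sketch; assumes a band named "pe" with
    # constant loop bounds, e.g. for materializing a PE array):
    #
    #     s.unfold("pe", [0, 1])  # replicate the body across both loop dimensions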
# pylint: disable=redefined-builtin
    @wrapped_apply
    def compose(self, schs: list, id=None, instantiate=None):
        """
        Uses `schs`, a schedule for a kernel called by this kernel, in this
        kernel. A kernel `<k1>` may call another kernel `<k2>`. This means the
        output of `<k1>.customize()` will contain the MLIR for the compiled
        `<k2>`, `<s2'>`, which has no custom schedule. To use a custom
        schedule `<s2>`, compile `<k2>` with that schedule and insert it into
        the schedule for this kernel through `self.compose(<s2>)`.

        Parameters
        ----------
        schs: allo.customize.Schedule
            The schedule of a kernel used in `self`.
        id: str
            Identifies the schedule to replace contained in `self`. This
            schedule in `self` must be annotated if `id` is specified.
        instantiate: list
            A list of objects used to instantiate the types that `schs` is
            generic over.
        """

        def get_name(arg):
            if isinstance(arg, (LoopWrapper, MockBuffer)):
                arg = copy.copy(arg)
                orig_func_name = arg.func if arg.func is not None else sch.top_func_name
                func_name = (
                    orig_func_name if id is None else orig_func_name + "_" + str(id)
                )
                if self._find_function(func_name, error=False) is None:
                    func_name = orig_func_name + "_0"
                arg.func = func_name
                return arg
            orig_func_name = arg.split(":")[0] if ":" in arg else sch.top_func_name
            arg = arg.split(":")[1] if ":" in arg else arg
            func_name = orig_func_name if id is None else orig_func_name + "_" + str(id)
            if self._find_function(func_name, error=False) is None:
                func_name = orig_func_name + "_0"
            return f"{func_name}:{arg}"

        if not isinstance(schs, list):
            schs = [schs]
        for sch in schs:
            if isinstance(sch, PyFunctionType):
                schedule = customize(sch, instantiate=instantiate)
                if sch not in KERNEL2SCHEDULE:
                    raise RuntimeError(
                        f"Cannot find schedule for kernel {sch.__name__}"
                    )
                sch = KERNEL2SCHEDULE[sch](schedule)
            if not isinstance(sch, Schedule):
                raise TypeError("The first argument must be a Schedule object")
            for primitive in sch.primitive_sequences:
                args, kwargs = primitive[1:]
                # Avoid changing the original schedule
                args = args.copy()
                kwargs = kwargs.copy()
                # Update axes
                if primitive[0] in {"reorder", "fuse"}:
                    args = [get_name(arg) for arg in args]
                elif primitive[0] in {
                    "split",
                    "unroll",
                    "pipeline",
                    "parallel",
                    "dataflow",
                }:
                    if "axis" in kwargs:
                        kwargs["axis"] = get_name(kwargs["axis"])
                    else:
                        args[0] = get_name(args[0])
                elif primitive[0] in {"buffer_at", "reuse_at"}:
                    if "axis" in kwargs:
                        kwargs["axis"] = get_name(kwargs["axis"])
                    else:
                        args[1] = get_name(args[1])
                elif primitive[0] == "unfold":
                    if "band_name" in kwargs:
                        kwargs["band_name"] = get_name(kwargs["band_name"])
                    else:
                        args[0] = get_name(args[0])
                # Update target buffers
                if primitive[0] in {
                    "partition",
                    "to",
                    "buffer_at",
                    "reuse_at",
                    "reshape",
                }:
                    if "target" in kwargs:
                        kwargs["target"] = get_name(kwargs["target"])
                    else:
                        args[0] = get_name(args[0])
                with self.module.context, Location.unknown():
                    primitive_func = getattr(self, primitive[0])
                    # directly apply primitives to new functions
                    primitive_func(*args, **kwargs)
                self.primitive_sequences.append((primitive[0], args, kwargs))
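    # Illustrative usage of `compose` (a sketch; `inner` is an assumed kernel
    # called by the top-level kernel `top`):
    #
    #     s_inner = customize(inner)
    #     s_inner.pipeline("j")
    #     s_top = customize(top)
    #     s_top.compose(s_inner)  # replays s_inner's primitives inside s_top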
    def build(self, target=None, mode=None, project=None, configs=None):
        if target is None or target == "llvm":
            target = "llvm"
            return LLVMModule(
                self.module,
                top_func_name=self.top_func_name,
                ext_libs=self.ext_libs,
            )
        if target in {"vhls", "vivado_hls", "vitis_hls"}:
            return HLSModule(
                self.module,
                top_func_name=self.top_func_name,
                platform="vivado_hls" if target != "vitis_hls" else "vitis_hls",
                mode=mode,
                project=project,
                ext_libs=self.ext_libs,
                configs=configs,
                func_args=self.func_args,
            )
        raise NotImplementedError(f"Target {target} is not supported")
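    # Illustrative usage of `build` (the targets shown are the ones handled
    # above; the `mode` and `project` values are placeholders):
    #
    #     mod = s.build()                       # LLVM JIT module by default
    #     hls = s.build(target="vitis_hls", mode="csim", project="out.prj")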
def customize(
    fn: Union[Callable, str],
    verbose: bool = False,
    enable_tensor: bool = False,
    lower_linalg: bool = False,
    global_vars: dict = None,
    instantiate: list = None,
    context: Context = None,
):
    # Get Python AST
    if isinstance(fn, str):
        src = fn
    else:
        src, _ = getsourcelines(fn)
        src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
        src = textwrap.dedent("\n".join(src))
    tree = parse_ast(src, verbose)
    if instantiate is None:
        instantiate = []
    if global_vars is None:
        global_vars = get_global_vars(fn)
    # Use-def chain analysis
    use_def_chain = UseDefChain(global_vars.copy(), instantiate)
    use_def_chain.visit(tree)
    # Type construction
    ctx_type_inf = ASTContext(
        global_vars=global_vars.copy(),
        mlir_ctx=Context() if context is None else context,
        enable_tensor=enable_tensor,
        verbose=verbose,
    )
    ctx_type_inf.inst = instantiate
    tree = TypeInferer()(ctx_type_inf, tree)
    ctx_type_inf = None
    # Start building IR
    ctx = ASTContext(
        global_vars=global_vars,
        mlir_ctx=Context() if context is None else context,
        enable_tensor=enable_tensor,
        verbose=verbose,
    )
    ctx.inst = instantiate
    module = ASTTransformer()(ctx, tree)
    if lower_linalg:
        lower_linalg_and_attach_names(module)
    ctx.top_func = find_func_in_module(module, fn.__name__)
    sch = Schedule(
        module,
        ctx.top_func,
        ctx.func_args,
        InsertionPoint.at_block_terminator(ctx.top_func.entry_block),
        ext_libs=ctx.ext_libs,
        use_def_chain=use_def_chain,
        inst_list=instantiate,
    )
    # Attach buffers to schedule:
    # The reason why we do not attach buffers to function is that
    # we may have multiple schedules referring to the same function,
    # which will cause conflicts of different buffers in different contexts.
    if isinstance(fn, Callable):
        for name, buffer in ctx.buffers.items():
            if isinstance(buffer, MockArg):
                # Function arguments
                setattr(
                    sch,
                    name,
                    MockBuffer(fn.__name__, name, buffer.idx),
                )
            elif isinstance(
                buffer, (memref_d.AllocOp, func_d.CallOp, memref_d.GetGlobalOp)
            ):
                # Intermediate buffers
                setattr(sch, name, MockBuffer(fn.__name__, name))
    # Check if there are memory leaks
    # All live operations = {top_func} + {top_func_ip}
    buffer = None
    ctx.buffers = None
    global_vars = {}
    # Functions are stored in ctx.global_vars, which should also be removed
    ctx = None
    # assert module.context._get_live_operation_count() == 2, (
    #     "All live operations = 1 (top_func) + 1 (top_func_ip), "
    #     f"expected 2, but got {module.context._get_live_operation_count()}"
    # )
    return sch
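
# End-to-end sketch of the `customize` flow (the kernel below is an assumed
# example, not part of this module; it presumes the usual allo frontend
# imports such as `allo.ir.types.int32` and `allo.grid`):
#
#     from allo.ir.types import int32
#
#     def gemm(A: int32[32, 32], B: int32[32, 32]) -> int32[32, 32]:
#         C: int32[32, 32] = 0
#         for i, j, k in allo.grid(32, 32, 32):
#             C[i, j] += A[i, k] * B[k, j]
#         return C
#
#     s = customize(gemm)
#     s.split("k", 4)
#     s.pipeline("k.inner")
#     mod = s.build(target="llvm")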