Source code for allo.customize

# Copyright Allo authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-name-in-module, too-many-nested-blocks, too-many-instance-attributes

import re
import inspect
import textwrap
import copy
from dataclasses import dataclass
from functools import wraps
from types import FunctionType as PyFunctionType
from typing import Union
from collections.abc import Callable

from ._mlir.ir import (
    Context,
    Location,
    InsertionPoint,
    StringAttr,
    UnitAttr,
    IndexType,
    IntegerType,
    IntegerAttr,
    TypeAttr,
    F32Type,
    MemRefType,
    FlatSymbolRefAttr,
    FunctionType,
    AffineMap,
    AffineMapAttr,
    BlockArgument,
)
from ._mlir.dialects import (
    allo as allo_d,
    memref as memref_d,
    affine as affine_d,
    scf as scf_d,
    arith as arith_d,
    func as func_d,
)
from ._mlir.dialects.affine import (
    AffineExpr,
    AffineDimExpr,
)
from ._mlir.exceptions import (
    AlloValueError,
)

from . import primitives as prim
from .ir.visitor import ASTContext
from .ir.utils import MockArg, MockBuffer, parse_ast, get_global_vars
from .ir.builder import ASTTransformer
from .ir.infer import TypeInferer
from .ir.transform import (
    get_affine_loop_nests,
    find_loop_in_bands,
    find_buffer,
    find_func_in_module,
    LoopWrapper,
)
from .passes import (
    _mlir_lower_pipeline,
    lower_linalg_and_attach_names,
    analyze_use_def,
)
from .utils import freeze_list
from .backend.llvm import LLVMModule
from .backend.hls import HLSModule
from .backend.xls import XLSCCModule
from .library import KERNEL2SCHEDULE
from .library.systolic import check_systolic, prepare_systolic


def getsourcefile(obj):
    ret = inspect.getsourcefile(obj)
    if ret is None:
        ret = inspect.getfile(obj)
    return ret


def wrapped_apply(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        sch = args[0]
        with sch.module.context, Location.unknown():
            res = fn(*args, **kwargs)
        _mlir_lower_pipeline(sch.module)
        # Update top function in the current context
        for op in sch.module.body.operations:
            if isinstance(op, func_d.FuncOp) and op.name.value == sch.top_func_name:
                sch.top_func = op
                break
        else:
            raise RuntimeError("Top function not found")
        # Update insertion point
        sch.ip = InsertionPoint.at_block_terminator(sch.top_func.entry_block)
        # Record primitive sequences
        if fn.__name__ != "compose":
            sch.primitive_sequences.append((fn.__name__, list(args[1:]), kwargs))
        return res

    return wrapper


@dataclass
class Partition:
    Complete = 0
    Block = 1
    Cyclic = 2


[docs] class Schedule: def __init__( self, module, top_func, func_args, ip, ext_libs=None, inst_list=None, func_instances=None, ): self.module = module self.top_func = top_func self.top_func_name = top_func.name.value # func_args are dtensors self.func_args = func_args self.ip = ip self.primitive_sequences = [] if ext_libs is None: ext_libs = [] self.ext_libs = ext_libs self.partitioned_arrays = {} self.inst_list = inst_list if inst_list is not None else [] if func_args: for func_name, _ in func_args.items(): if func_name not in self.func_args: self.func_args[func_name] = [] self.func_instances = func_instances self.systolic = check_systolic(self) def get_loops(self, func=None): if isinstance(func, str): func = self._find_function(func) if func is None: func = self.top_func return get_affine_loop_nests(func) def _find_band(self, band_name, func=None): loops = self.get_loops(func) if band_name in loops.loops: return loops[band_name] raise RuntimeError(f"Band {band_name} not found") def _find_function(self, name, error=True): for func in self.module.body.operations: if isinstance(func, func_d.FuncOp) and func.name.value == name: return func if error: raise RuntimeError(f"Function {name} not found") return None def _get_func_and_axis(self, axis): if isinstance(axis, LoopWrapper): func = self._find_function(axis.func) return func, axis if ":" in axis: func_name, axis = axis.split(":") else: func_name = self.top_func_name func = self._find_function(func_name) return func, axis
[docs] @wrapped_apply def split(self, axis, factor): """ `split` will find the loop with loop index `axis` and tile it with each tile size `factor` The new inner loop will be named `axis.inner` and the outer loop will be named `axis.outer` Parameters ---------- axis: str The name of an index in the kernel. factor: int The size of each tile, e.g. the size of the inner nested loop. """ func, axis = self._get_func_and_axis(axis) band_name, axis = find_loop_in_bands(func, axis) ip = InsertionPoint.at_block_terminator(func.entry_block) op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip) loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip) i32 = IntegerType.get_unsigned(32) factor = IntegerAttr.get(i32, factor) allo_d.SplitOp(loop_hdl.result, factor, ip=ip)
[docs] @wrapped_apply def reorder(self, *args): """ Reorders nested loops with indices listed in `args` such that the outermost loop is the first index listed in `args`, the second is the second outermost, and so on. This function is vardic, accepting each index as a separate argument. """ func, axis = self._get_func_and_axis(args[0]) band_name, _ = find_loop_in_bands(func, axis) ip = InsertionPoint.at_block_terminator(func.entry_block) op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip) loop_hdls = [] for arg in args: func, axis = self._get_func_and_axis(arg) band_name, axis = find_loop_in_bands(func, axis) loop_hdls.append( allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip) ) arg_results = [arg.result for arg in loop_hdls] allo_d.ReorderOp(arg_results, ip=ip)
[docs] @wrapped_apply def unroll(self, axis, factor=0): """ Unrolls a loop with loop index `axis` by `factor`. Parameters ---------- axis: str The name of an index in the kernel. factor: int The factor to unroll by, for example a factor of 2 will cause the body to be duplicated once. """ func, axis = self._get_func_and_axis(axis) band_name, axis = find_loop_in_bands(func, axis) ip = InsertionPoint.at_block_terminator(func.entry_block) op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip) loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip) i32 = IntegerType.get_unsigned(32) factor = IntegerAttr.get(i32, factor) allo_d.UnrollOp(loop_hdl.result, factor=factor, ip=ip)
[docs] @wrapped_apply def fuse(self, *args): """ Combines loops with indices listed in `args` into a single loop over a single index. This function is vardic, accepting each index as a separate argument. """ func, axis = self._get_func_and_axis(args[0]) band_name, _ = find_loop_in_bands(func, args[0]) ip = InsertionPoint.at_block_terminator(func.entry_block) op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip) loop_hdls = [] axes = [] for arg in args: func, axis = self._get_func_and_axis(args) band_name, axis = find_loop_in_bands(func, arg) loop_hdls.append( allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip) ) axes.append(axis) arg_results = [arg.result for arg in loop_hdls] allo_d.FuseOp(arg_results, ip=ip) if isinstance(args[0], LoopWrapper): name = "_".join(axes) + "_fused" return LoopWrapper(f"{args[0].func}:{band_name}.{name}", None) return LoopWrapper(f"{func.name.value}:{band_name}", None)
[docs] @wrapped_apply def partition(self, target, partition_type=Partition.Complete, dim=0, factor=0): """ Partitions a given array and propagates to all callers and callees. Parameters ---------- target: allo.ir.utils.MockBuffer | str The array to partition. partition_type: allo.customize.Partition Complete, Block, or Cyclic partition type. factor: int The number of arrays created by a block or cyclic partition. dim: int The dimension to partition. If dim=0, all dimensions are partitioned. """ # Validate inputs if partition_type not in ( Partition.Complete, Partition.Block, Partition.Cyclic, ): raise AlloValueError("Invalid partition type") if dim < 0: raise AlloValueError("Invalid dimension") if factor < 0: raise AlloValueError("Invalid factor") # Convert partition type to integer partition_type_int = { Partition.Complete: 0, Partition.Block: 1, Partition.Cyclic: 2, }[partition_type] # Normalize target to MockBuffer if isinstance(target, str): func_name, buf_name = target.split(":") target = MockBuffer(func_name, buf_name) # Check for duplicate partitioning target_key = f"{target.func}:{target.name}" if target_key in self.partitioned_arrays: for item in self.partitioned_arrays[target_key]: if item[0] == Partition.Complete and item[1] == 0: return # Already completely partitioned raise AlloValueError( f"Cannot partition the same array twice: {target_key}" ) # Collect all buffers that need partitioning via propagation buffers_to_partition = self._collect_partition_targets(target) # Apply partitioning to all collected buffers i32 = IntegerType.get_signless(32) ui32 = IntegerType.get_unsigned(32) for buf in buffers_to_partition: buf_key = f"{buf.func}:{buf.name}" func, _, mlir_target = find_buffer(self.module, buf, self.func_args) # Record partition info if buf_key not in self.partitioned_arrays: self.partitioned_arrays[buf_key] = [] self.partitioned_arrays[buf_key].append((partition_type_int, dim, factor)) # Create partition operation allo_d.PartitionOp( mlir_target.result, partition_kind=IntegerAttr.get(i32, partition_type_int), dim=IntegerAttr.get(ui32, dim), factor=IntegerAttr.get(ui32, factor), ip=InsertionPoint.at_block_terminator(func.entry_block), ) # Update types for function calls if isinstance(mlir_target, func_d.CallOp): self._update_call_types(mlir_target, partition_type_int, dim, factor) # If this buffer is returned by a function, update the function's return type # and all call sites to that function self._propagate_return_type_to_callers( func, mlir_target, partition_type_int, dim, factor ) # Update global memory references self._update_global_types(buffers_to_partition, partition_type_int, dim, factor)
def _propagate_return_type_to_callers( self, func, mlir_target, partition_type, dim, factor ): """If a partitioned buffer is returned, update function signature and all call sites.""" if not isinstance(func, func_d.FuncOp): return # Check if this buffer is returned by the function for op in func.entry_block.operations: if isinstance(op, func_d.ReturnOp) and op.operands: returned_value = op.operands[0] # Check if the partitioned buffer is what's being returned if ( hasattr(mlir_target, "result") and returned_value == mlir_target.result ): # Compute new return type with partition layout shape = mlir_target.result.type.shape layout_attr = self._compute_partition_layout( shape, partition_type, dim, factor ) old_type = func.type.results[0] new_return_type = MemRefType.get( old_type.shape, old_type.element_type, layout_attr, old_type.memory_space, ) # Update function signature new_func_type = FunctionType.get( list(func.type.inputs), [new_return_type] ) func.attributes["function_type"] = TypeAttr.get(new_func_type) # Update ALL call sites to this function func_name = func.attributes["sym_name"].value self._update_all_call_sites(func_name, new_return_type) break def _collect_partition_targets(self, target): """Collect all buffers that need partitioning by traversing call graph.""" visited = set() to_partition = [] worklist = [target] while worklist: buf = worklist.pop() buf_key = f"{buf.func}:{buf.name}" if buf_key in visited: continue visited.add(buf_key) to_partition.append(buf) _, _, mlir_target = find_buffer(self.module, buf, self.func_args) # Add equivalent variables (aliases) for equiv in self._get_equivalent_buffers(buf): if f"{equiv.func}:{equiv.name}" not in visited: worklist.append(equiv) # Propagate through function calls (callees) worklist.extend(self._get_callee_buffers(mlir_target)) # Propagate to callers worklist.extend(self._get_caller_buffers(mlir_target)) return to_partition def _get_equivalent_buffers(self, buf): """Get all buffers that are aliases of the given buffer.""" result = [] arg_names = [ dtensor.name if hasattr(dtensor, "name") else dtensor for dtensor in self.func_args.get(buf.func, []) ] # Convert to argument index if it's a function argument if buf.name in arg_names: idx = arg_names.index(buf.name) lookup_key = f"{buf.func}:{idx}" else: lookup_key = f"{buf.func}:{buf.name}" for equiv_key in self.get_equivalent_variables(lookup_key): path, name = equiv_key.split(":") if name.isdigit(): # Convert argument index back to name arg = self.func_args[path][int(name)] name = arg.name if hasattr(arg, "name") else arg result.append(MockBuffer(path, name)) return result def _get_callee_buffers(self, mlir_target): """Get buffers in called functions that correspond to this buffer.""" result = [] # If this is a call result, partition the returned buffer in callee if isinstance(mlir_target, func_d.CallOp): callee_name = FlatSymbolRefAttr(mlir_target.attributes["callee"]).value callee_func = self._find_function(callee_name, error=False) if callee_func: # Find the returned buffer for op in callee_func.entry_block.operations: if isinstance(op, func_d.ReturnOp) and op.operands: returned = op.operands[0] if hasattr(returned, "owner") and returned.owner: owner = returned.owner op_name = getattr( owner, "name", getattr(getattr(owner, "operation", None), "name", ""), ) if op_name == "memref.alloc" and "name" in owner.attributes: buf_name = StringAttr(owner.attributes["name"]).value result.append(MockBuffer(callee_name, buf_name)) break # If value is passed to a call, partition corresponding parameter if hasattr(mlir_target, "result") and mlir_target.result: for use in mlir_target.result.uses: if isinstance(use.owner, func_d.CallOp): call_op = use.owner callee_name = FlatSymbolRefAttr(call_op.attributes["callee"]).value for i, operand in enumerate(call_op.operands): if operand == mlir_target.result: param = self._get_param_buffer(callee_name, i) if param: result.append(param) break # Handle BlockArgument (function parameter) passed to calls if isinstance(mlir_target, BlockArgument): for use in mlir_target.uses: if isinstance(use.owner, func_d.CallOp): call_op = use.owner callee_name = FlatSymbolRefAttr(call_op.attributes["callee"]).value for i, operand in enumerate(call_op.operands): if operand == mlir_target: param = self._get_param_buffer(callee_name, i) if param: result.append(param) break return result def _get_caller_buffers(self, mlir_target): """Get buffers in calling functions that correspond to this buffer.""" result = [] # Find calls to the same function and propagate if isinstance(mlir_target, func_d.CallOp): callee_attr = mlir_target.attributes["callee"] for func in self.module.body.operations: if isinstance(func, func_d.FuncOp): for op in func.entry_block.operations: if ( isinstance(op, func_d.CallOp) and op.attributes["callee"] == callee_attr and op != mlir_target and "name" in op.attributes ): func_name = func.attributes["sym_name"].value call_name = op.attributes["name"].value result.append(MockBuffer(func_name, call_name)) return result def _get_param_buffer(self, func_name, param_idx): """Get MockBuffer for a function parameter by index.""" args = self.func_args.get(func_name, []) if param_idx < len(args): arg = args[param_idx] name = arg.name if hasattr(arg, "name") else arg return MockBuffer(func_name, name) return None def _compute_partition_layout(self, shape, partition_type, dim, factor): """Compute the affine map for partitioned memory layout.""" partition_idx = [] address_idx = [] for i, size in enumerate(shape): applies_to_dim = (dim == 0) or (i == dim - 1) if applies_to_dim: if partition_type == Partition.Cyclic: partition_idx.append(AffineDimExpr.get(i) % factor) address_idx.append( AffineExpr.get_floor_div(AffineDimExpr.get(i), factor) ) elif partition_type == Partition.Block: block_size = (size + factor - 1) // factor partition_idx.append( AffineExpr.get_floor_div(AffineDimExpr.get(i), block_size) ) address_idx.append(AffineDimExpr.get(i) % block_size) else: # Complete partition_idx.append(AffineDimExpr.get(i)) address_idx.append(AffineExpr.get_constant(0)) else: partition_idx.append(AffineExpr.get_constant(0)) address_idx.append(AffineDimExpr.get(i)) affine_map = AffineMap.get( dim_count=len(shape), symbol_count=0, exprs=partition_idx + address_idx ) return AffineMapAttr.get(affine_map) def _update_call_types(self, call_op, partition_type, dim, factor): """Update function signature types after partitioning a call result.""" callee_name = FlatSymbolRefAttr(call_op.attributes["callee"]).value callee_func = self._find_function(callee_name, error=False) if not callee_func: return shape = call_op.result.type.shape layout_attr = self._compute_partition_layout(shape, partition_type, dim, factor) # Create new return type with partition layout old_type = callee_func.type.results[0] new_return_type = MemRefType.get( old_type.shape, old_type.element_type, layout_attr, old_type.memory_space ) # Update input types if operands are partitioned new_input_types = list(callee_func.type.inputs) for i, operand in enumerate(call_op.operands): if hasattr(operand.type, "layout"): if operand.type.layout != callee_func.type.inputs[i].layout: new_input_types[i] = operand.type callee_func.arguments[i].set_type(operand.type) # Update function type new_func_type = FunctionType.get(new_input_types, [new_return_type]) callee_func.attributes["function_type"] = TypeAttr.get(new_func_type) # CRITICAL: Update ALL call sites to this function with the new return type self._update_all_call_sites(callee_name, new_return_type) # Update downstream function parameters that use this result for use in call_op.result.uses: if isinstance(use.owner, func_d.CallOp): self._update_downstream_param_type( use.owner, call_op.result, new_return_type ) def _update_all_call_sites(self, callee_name, new_return_type): """Update all call sites to a function when its signature changes.""" for func in self.module.body.operations: if not isinstance(func, func_d.FuncOp): continue for op in func.entry_block.operations: if isinstance(op, func_d.CallOp): call_callee = FlatSymbolRefAttr(op.attributes["callee"]).value if call_callee == callee_name: # Update the call operation's result type op.result.set_type(new_return_type) # If this call's result is used by another function, # we need to update that function's parameter types too for use in op.result.uses: user_op = use.owner if isinstance(user_op, func_d.CallOp): self._update_downstream_param_type( user_op, op.result, new_return_type ) def _update_downstream_param_type(self, call_op, partitioned_value, new_type): """Update parameter type in downstream function call.""" callee_name = FlatSymbolRefAttr(call_op.attributes["callee"]).value callee_func = self._find_function(callee_name, error=False) if not callee_func: return for i, operand in enumerate(call_op.operands): if operand == partitioned_value: callee_func.arguments[i].set_type(new_type) new_inputs = list(callee_func.type.inputs) new_inputs[i] = new_type new_func_type = FunctionType.get(new_inputs, callee_func.type.results) callee_func.attributes["function_type"] = TypeAttr.get(new_func_type) break def _update_global_types(self, buffers, partition_type, dim, factor): """Update global memory reference types.""" buffer_names = {buf.name for buf in buffers} for op in self.module.body.operations: if isinstance(op, memref_d.GlobalOp): if op.attributes["sym_name"].value in buffer_names: old_type = op.attributes["type"].value layout_attr = self._compute_partition_layout( old_type.shape, partition_type, dim, factor ) new_type = MemRefType.get( old_type.shape, old_type.element_type, layout_attr, old_type.memory_space, ) op.attributes["type"] = TypeAttr.get(new_type) # @wrapped_apply
[docs] def buffer_at_regular(self, target, axis): """ Creates a chip buffer to hold the values of `target` written to in loop with index `axis` instead of immediately writing them to memory. Parameters ---------- target: allo.ir.utils.MockBuffer An array written to in a loop. axis: str The loop index whose body contains writes to target """ _, _, target = find_buffer(self.module, target, self.func_args) func, axis = self._get_func_and_axis(axis) band_name, axis = find_loop_in_bands(func, axis) ip = InsertionPoint.at_block_terminator(func.entry_block) op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip) loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip) memref_type = MemRefType.get((1,), F32Type.get()) allo_d.BufferAtOp(memref_type, target.result, loop_hdl.result, ip=ip)
[docs] def buffer_at(self, target, axis): """ Creates a chip buffer to hold the values of `target` written to in loop with index `axis` instead of immediately writing them to memory. Parameters ---------- target: allo.ir.utils.MockBuffer An array written to in a loop. axis: str The loop index whose body contains writes to target """ if self.systolic: return self.buffer_at_systolic(target, axis) with self.module.context, Location.unknown(): self.buffer_at_regular(target, axis) _mlir_lower_pipeline(self.module) # Update top function in the current context for op in self.module.body.operations: if isinstance(op, func_d.FuncOp) and op.name.value == self.top_func_name: self.top_func = op break else: raise RuntimeError("Top function not found") # Update insertion point self.ip = InsertionPoint.at_block_terminator(self.top_func.entry_block) # Record primitive sequences self.primitive_sequences.append(("buffer_at", [target, axis], {}))
[docs] def buffer_at_systolic(self, target, axis): """ Creates a chip buffer to hold the values of `target` written to in loop with index `axis` instead of immediately writing them to memory in a systolic array. Parameters ---------- target: allo.ir.utils.MockBuffer An array written to in a loop. axis: str The loop index whose body contains writes to target """ buff_name = target.name _, _, target = find_buffer(self.module, target, self.func_args) func, axis = self._get_func_and_axis(axis) band_name, axis = find_loop_in_bands(func, axis) band = self._find_band(band_name, func) loops = list(band) outer_loop = loops[0][1].loop middle_loop = loops[1][1].loop # Middle loop inner_loop = loops[-1][1].loop # Last/innermost loop i_size = int( re.findall( r"affine_map<\(\) -> \(([0-9]*)\)>", str(outer_loop.attributes["upperBoundMap"]), )[0] ) j_size = int( re.findall( r"affine_map<\(\) -> \(([0-9]*)\)>", str(middle_loop.attributes["upperBoundMap"]), )[0] ) k_size = int( re.findall( r"affine_map<\(\) -> \(([0-9]*)\)>", str(inner_loop.attributes["upperBoundMap"]), )[0] ) load_type = MemRefType(target.result.type).element_type with self.module.context, Location.unknown(): ip = InsertionPoint.at_block_begin(func.body.blocks[0]) fifo_memref_type = MemRefType.get([i_size, j_size + 1, k_size], load_type) fifo_memref = memref_d.AllocOp(fifo_memref_type, [], [], ip=ip) fifo_memref.attributes["name"] = StringAttr.get(f"{buff_name}_fifo") fifo_mock_buffer = MockBuffer(func.name.value, f"{buff_name}_fifo") fifo_mock_buffer.result = fifo_memref.result setattr(self, f"{buff_name}_fifo", fifo_mock_buffer) return fifo_mock_buffer
[docs] @wrapped_apply def reshape(self, target, shape): """ Takes an array in the kernel, `target`, for example if the array is `B`, then would be `target` would be `<schedule>.B`, and reshapes it to tuple `shape`. As an example, if the desired shape is 32 by 4 by 8, the `<shape>` would be `(32, 4, 8)`. Parameters ---------- target: allo.ir.utils.MockBuffer The array, represented by a memory, to reshape. shape: tuple The new shape of the memory. """ _, _, target = find_buffer(self.module, target, self.func_args) eletype = MemRefType(target.result.type).element_type memref_type = MemRefType.get(shape, eletype) allo_d.ReshapeOp(memref_type, target.result, ip=self.ip)
[docs] @wrapped_apply def pipeline(self, axis, initiation_interval=1, rewind=False): """ Pipelines a loop with index `axis` into `initiation_interval` stages. Parameters ---------- axis: str The index of the loop to pipeline. initiation_interval: int The initiation_interval to be used when pipelining. rewind: bool If true, rewinding is allowed, allowing continuous loop pipelining. This is only effective for perfect loop nests inside a top level function. """ i32 = IntegerType.get_unsigned(32) ii = IntegerAttr.get(i32, initiation_interval) func, axis = self._get_func_and_axis(axis) band_name, axis = find_loop_in_bands(func, axis) if rewind: self.get_loops(func)[band_name][axis].loop.attributes[ "rewind" ] = UnitAttr.get() self.get_loops(func)[band_name][axis].loop.attributes["pipeline_ii"] = ii
[docs] @wrapped_apply def parallel(self, axis): """ Instantiates a loop with index `axis` to be computed in parallel with the loops it is nested with. Parameters ---------- axis: str The index of the loop to be computed in parallel. """ func, axis = self._get_func_and_axis(axis) band_name, axis = find_loop_in_bands(func, axis) ip = InsertionPoint.at_block_terminator(func.entry_block) op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip) loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip) allo_d.ParallelOp(loop_hdl.result, ip=ip)
[docs] @wrapped_apply def inline(self, axis=None): """ Inlines a function `axis`. Parameters ---------- axis: str The function to inline. """ assert axis is None or isinstance(axis, str), "Function name must be a string" if axis is None: axis = self.top_func_name func = self._find_function(axis) func.attributes["inline"] = UnitAttr.get()
[docs] @wrapped_apply def dataflow(self, axis): """ Applies a "dataflow" attribute to function `axis`. This allows for parallelism if the given function uses streams or the `to` schedule. Parameters ---------- axis: str | allo.ir.LoopWrapper The function to add the attribute to. """ if isinstance(axis, str): # function func = self._find_function(axis) func.attributes["dataflow"] = UnitAttr.get() return func, _ = self._get_func_and_axis(axis) band_name, loop_name = axis.name.split(".", 1) band_name = band_name.split(":")[1] cnt = 0 def locate_loop(op): nonlocal cnt for ope in op.body.operations: if isinstance(ope, (scf_d.ForOp, affine_d.AffineForOp)): locate_loop(ope) if ( "loop_name" in op.attributes and op.attributes["loop_name"].value == loop_name ): cnt += 1 op.attributes["dataflow"] = UnitAttr.get() for op in func.entry_block.operations: if isinstance(op, (scf_d.ForOp, affine_d.AffineForOp)): if ( "op_name" in op.attributes and op.attributes["op_name"].value == band_name ): locate_loop(op) if cnt == 0: raise RuntimeError(f"Dataflow loop {band_name}.{loop_name} not found")
[docs] @wrapped_apply def compute_at(self, from_loop, target_loop): """ If `from_loop` and `target_loop` are indices over the same range, `<schedule>.compute_at(from_loop, target_loop)` merges the two loops, taking the body of `from_loop` and appending it to the body of `target_loop`. Parameters ---------- from_loop: str The loop whose body is being moved. target_loop: str The loop whose body is being appended to. """ from_band, _ = find_loop_in_bands(self.top_func, from_loop) target_band, target_axis = find_loop_in_bands(self.top_func, target_loop) from_hdl = allo_d.CreateOpHandleOp(from_band, ip=self.ip) target_hdl = allo_d.CreateOpHandleOp(target_band, ip=self.ip) loop_hdl = allo_d.CreateLoopHandleOp( target_hdl.result, StringAttr.get(target_axis), ip=self.ip ) allo_d.ComputeAtOp( from_hdl.result, target_hdl.result, loop_hdl.result, ip=self.ip )
[docs] @wrapped_apply def reuse_at(self, target, axis): """ Takes an array in a kernel, for example if the array is `B`, this would be `<schedule>.B`, accessed by index `axis` and creates a reuse buffer to reuse values from `target` which are accessed in a sequentially moving window. Parameters ---------- target: allo.ir.utils.MockBuffer The array being accessed. axis: str The loop index used to access values in `target` """ _, _, target = find_buffer(self.module, target, self.func_args) func, axis = self._get_func_and_axis(axis) band_name, axis = find_loop_in_bands(func, axis) ip = InsertionPoint.at_block_terminator(func.entry_block) op_hdl = allo_d.CreateOpHandleOp(band_name, ip=ip) loop_hdl = allo_d.CreateLoopHandleOp(op_hdl.result, StringAttr.get(axis), ip=ip) memref_type = MemRefType.get((1,), F32Type.get()) def find_reuse_buffers(res): for func in self.module.body.operations: if isinstance(func, func_d.FuncOp): for op in func.entry_block.operations: if ( isinstance(op, memref_d.AllocOp) and "name" in op.attributes and band_name + "_reuse" in StringAttr(op.attributes["name"]).value ): res.append(op) prev_reuse_buffers = [] find_reuse_buffers(prev_reuse_buffers) allo_d.ReuseAtOp(memref_type, target.result, loop_hdl.result, ip=ip) _mlir_lower_pipeline(self.module) new_reuse_buffers = [] find_reuse_buffers(new_reuse_buffers) new_reuse_buffers = [ buf for buf in new_reuse_buffers if buf not in prev_reuse_buffers ] if len(new_reuse_buffers) != 1: raise RuntimeError("Reuse buffer not found") return MockBuffer( self.top_func_name, StringAttr(new_reuse_buffers[-1].attributes["name"]).value, )
[docs] @wrapped_apply def to(self, target, dst, axis=None, depth=-1): """ Takes an array in the kernel, `target`, for example if the array is `B`, this would be `target` would be `<schedule>.B`, and converts it into a stream. `dst` is the name of the array any value of `target` is written to. For example if `C[i, j] = B[i, j]`, `dst` would be specified as `"C"`. If values of `<target>` get written to multiple arrays. Multiple calls to `<schedule>.to(...)` may be needed. Parameters ---------- target: allo.ir.utils.MockBuffer The array to convert to a stream. dst: str An array which a value of `target` is written to. axis: str Move axis-th loop body to xcel scope. depth: int The streaming channel depth. """ return prim.to( self.module, target, dst, axis, depth, self.func_args, self.top_func_name )
[docs] @wrapped_apply def unfold(self, band_name, axes): """ Finds a set of nested loops with name `band_name` and for every `<i>` in list `axes`. The `<i>th` nested loop is unfolded into a constant number of copies of it's loop body. Parameters ---------- band_name: str The set of nested loops to unroll. axes: list[int] A list of the axes to unroll. """ if self.systolic: prepare_systolic(self, band_name) assert isinstance(axes, list), "Axes must be a list" axes.sort() assert axes == list( range(axes[0], axes[0] + len(axes)) ), "Axes must be consecutive" # start from the inner most loop if ":" in band_name: func = self._find_function(band_name.split(":")[0]) band_name = band_name.split(":")[1] else: func = self.top_func for axis in axes[::-1]: # Need to recompute the loop nests due to the MLIR bug: # https://reviews.llvm.org/D101422 # Otherwise, it may hit invalid operations band = self._find_band(band_name, func) target_outer = band.get_outer_most() loops = list(band) op_to_remove = [] _, loop_wrapper = loops[axis] loop = loop_wrapper.loop lower_bound = loop.attributes["lowerBoundMap"] assert str(lower_bound) == "affine_map<() -> (0)>", "Lower bound must be 0" upper_bound = loop.attributes["upperBoundMap"] upper_bound = int( re.findall(r"affine_map<\(\) -> \(([0-9]*)\)>", str(upper_bound))[0] ) if axis > 0: ip = InsertionPoint.at_block_terminator(loops[axis - 1][1].loop.body) else: ip = InsertionPoint(target_outer) for op in loop.body.operations: if isinstance(op, affine_d.AffineYieldOp): break def update_operand(op, old, new): if isinstance(op, affine_d.AffineForOp): # pylint: disable=cell-var-from-loop for in_op in op.body.operations: update_operand(in_op, old, new) else: op.operation.replace_uses_of_with(old, new) # unfold the body `upper_bound` times for idx in range(upper_bound): # pylint: disable=too-many-function-args cst_op = arith_d.ConstantOp(IndexType.get(), idx, ip=ip) # Directly duplicate the loop itself # (to preserve a scope for replacing the induction variable), # and replace the induction variable with the constant new_loop = loop.operation.clone(ip) for op in new_loop.body.operations: if isinstance(op, affine_d.AffineYieldOp): break update_operand(op, new_loop.induction_variable, cst_op.result) op.move_before(new_loop) if isinstance(op, affine_d.AffineForOp): new_name = ( f"{band_name}_{idx}" if "op_name" not in op.attributes else f"{op.attributes['op_name'].value}_{idx}" ) op.attributes["op_name"] = StringAttr.get(new_name) if isinstance(op, func_d.CallOp): # Also need to duplicate the function outside the top function old_func = self._find_function( FlatSymbolRefAttr(op.attributes["callee"]).value ) dup_func = old_func.operation.clone(InsertionPoint(func)) new_name = ( f"{FlatSymbolRefAttr(op.attributes['callee']).value}_{idx}" ) dup_func.attributes["sym_name"] = StringAttr.get(new_name) # extend self.func_args self.func_args[new_name] = self.func_args[ FlatSymbolRefAttr(op.attributes["callee"]).value ] op.attributes["callee"] = FlatSymbolRefAttr.get(new_name) if old_func not in op_to_remove: op_to_remove.append(old_func) op_to_remove.append(new_loop) # need to erase at the end for op in op_to_remove: op.operation.erase() loop.operation.erase() # TODO: use a class to wrap the results return axes
# pylint: disable=redefined-builtin
[docs] @wrapped_apply def compose(self, schs: list, id=None, instantiate=None): """ Uses `schs`, a schedule for a kernel called in this kernel, in this kernel. A kernel, `<k1>`, may call another kernel, `<k2>`. This means the output of `<k1>.customize()` will contain the MLIR for the compiled `<k2>`, `<s2'>`. `<s2'>` will not have any custom schedule. To use a custom schedule, `<s2>`, the compiled `<k2>` with some schedule can be created. This is inserted into the schedule for this kernel through `self.compose(<s2>)`. Parameters ---------- schs: allo.customize.Schedule The schedule of a kernel used in `self`. id: str Identifies the schedule to replace contained in `self`. This schedule in `self` must be annotated if `id` is specified. instantiate: list This is a list of objects used to instantiate types `schs` is generic over. """ def get_name(arg): if isinstance(arg, (LoopWrapper, MockBuffer)): arg = copy.copy(arg) orig_func_name = arg.func if arg.func is not None else sch.top_func_name func_name = ( orig_func_name if id is None else orig_func_name + "_" + str(id) ) if self._find_function(func_name, error=False) is None: func_name = orig_func_name + "_0" arg.func = func_name return arg orig_func_name = arg.split(":")[0] if ":" in arg else sch.top_func_name arg = arg.split(":")[1] if ":" in arg else arg func_name = orig_func_name if id is None else orig_func_name + "_" + str(id) if self._find_function(func_name, error=False) is None: func_name = orig_func_name + "_0" return f"{func_name}:{arg}" if not isinstance(schs, list): schs = [schs] for sch in schs: if isinstance(sch, PyFunctionType): schedule = customize(sch, instantiate=instantiate) if sch not in KERNEL2SCHEDULE: raise RuntimeError( f"Cannot find schedule for kernel {sch.__name__}" ) sch = KERNEL2SCHEDULE[sch](schedule) if not isinstance(sch, Schedule): raise TypeError("The first argument must be a Schedule object") if hasattr(sch, "stateful_var_map") and sch.stateful_var_map: stateful_seen = getattr(self, "_stateful_seen", {}) self._stateful_seen = stateful_seen func_name = f"{sch.top_func_name}_{id}" if id else sch.top_func_name func = self._find_function(func_name) for _, (global_name, _) in sch.stateful_var_map.items(): seen = stateful_seen.setdefault(global_name, []) suffix = id if id else f"inst{len(seen)}" new_name = f"{global_name}_{suffix}" original_global = next( ( op for op in self.module.body.operations if isinstance(op, memref_d.GlobalOp) and op.attributes["sym_name"].value == global_name ), None, ) if original_global is None: raise RuntimeError(f"Stateful global {global_name} not found") target = ( original_global if not seen else original_global.operation.clone( InsertionPoint(original_global) ) ) target.attributes["sym_name"] = StringAttr.get(new_name) seen.append(new_name) for op in func.entry_block.operations: if isinstance(op, memref_d.GetGlobalOp) and ( FlatSymbolRefAttr(op.attributes["name"]).value == global_name ): op.attributes["name"] = FlatSymbolRefAttr.get(new_name) for primitive in sch.primitive_sequences: args, kwargs = primitive[1:] # Avoid changing the original schedule args = args.copy() kwargs = kwargs.copy() # Update axes if primitive[0] in {"reorder", "fuse"}: args = [get_name(arg) for arg in args] elif primitive[0] in { "split", "unroll", "pipeline", "parallel", "dataflow", }: if "axis" in kwargs: kwargs["axis"] = get_name(kwargs["axis"]) else: args[0] = get_name(args[0]) elif primitive[0] in {"buffer_at", "reuse_at"}: if "axis" in kwargs: kwargs["axis"] = get_name(kwargs["axis"]) else: args[1] = get_name(args[1]) elif primitive[0] == "unfold": if "band_name" in kwargs: kwargs["band_name"] = get_name(kwargs["band_name"]) else: args[0] = get_name(args[0]) # Update target buffers if primitive[0] in { "partition", "to", "buffer_at", "reuse_at", "reshape", }: if "target" in kwargs: kwargs["target"] = get_name(kwargs["target"]) else: args[0] = get_name(args[0]) with self.module.context, Location.unknown(): primitive_func = getattr(self, primitive[0]) # directly apply primitives to new functions primitive_func(*args, **kwargs)
def get_equivalent_variables(self, name): use_def = analyze_use_def(self.module) for ele in use_def: if name in ele: return ele return [] def build( self, target=None, mode=None, project=None, configs=None, wrap_io=True, use_memory=False, ): if target is None or target == "llvm": target = "llvm" return LLVMModule( self.module, top_func_name=self.top_func_name, ext_libs=self.ext_libs, ) # TODO Add XLS DSLX Backend if target in {"xls", "xlscc"}: return XLSCCModule( self.module, top_func_name=self.top_func_name, project=project, use_memory=use_memory, ) if target in {"vhls", "vivado_hls", "vitis_hls", "pynq", "tapa", "ihls"}: match target: case "vitis_hls": platform = "vitis_hls" case "tapa": platform = "tapa" case "ihls": platform = "intel_hls" case "pynq": platform = "pynq" case _: platform = "vivado_hls" return HLSModule( self.module, top_func_name=self.top_func_name, platform=platform, mode=mode, project=project, ext_libs=self.ext_libs, configs=configs, func_args=self.func_args, wrap_io=wrap_io, ) raise NotImplementedError(f"Target {target} is not supported")
def customize( fn: Union[Callable, str], verbose: bool = False, enable_tensor: bool = False, lower_linalg: bool = False, global_vars: dict = None, instantiate: list = None, context: Context = None, typing_rule_set: str = "default", unroll: bool = True, ) -> Schedule: """ Args: - typing_rule_set (str): Identifier of the typing rule set used during IR building. This controls implicit type casting behavior. Currently supported values include `"default"`, which is primarily intended for HLS backends, and `"cpp-style"`, which follows C++-like typing rules and is used for the AIE backend. Defaults to `"default"`. """ # Get Python AST if isinstance(fn, str): src, starting_line_no = fn, 1 file_name = None else: src, starting_line_no = inspect.getsourcelines(fn) src = [textwrap.fill(line, tabsize=4, width=9999) for line in src] src = textwrap.dedent("\n".join(src)) file_name = inspect.getfile(fn) tree = parse_ast(src, starting_line_no=starting_line_no, verbose=verbose) if instantiate is None: instantiate = [] if global_vars is None: global_vars = get_global_vars(fn) # Type construction ctx_type_inf = ASTContext( tree=tree, global_vars=global_vars.copy(), mlir_ctx=Context() if context is None else context, inst=instantiate, unroll=unroll, enable_tensor=enable_tensor, typing_rule_set=typing_rule_set, verbose=verbose, ) tree = TypeInferer()(ctx_type_inf, tree) # Start building IR ctx = ASTContext( tree=tree, global_vars=global_vars, mlir_ctx=Context() if context is None else context, inst=instantiate, func_predicate_tags=ctx_type_inf.func_predicate_tags, unroll=unroll, meta_fors_to_unroll=ctx_type_inf.meta_fors_to_unroll, enable_tensor=enable_tensor, verbose=verbose, ) module = ASTTransformer()(ctx, tree, file_name) func_instances = { orig_name: { dim: f"{orig_name}_{str(freeze_list(predicate_tag))}" for dim, predicate_tag in kernel_instance_info.items() } for orig_name, kernel_instance_info in ctx.func_predicate_tags.items() } if lower_linalg: lower_linalg_and_attach_names(module) ctx.top_func = find_func_in_module(module, fn.__name__) sch = Schedule( module, ctx.top_func, ctx.func_args, InsertionPoint.at_block_terminator(ctx.top_func.entry_block), ext_libs=ctx.ext_libs, inst_list=instantiate, func_instances=func_instances, ) sch.stateful_var_map = getattr(ctx, "stateful_var_map", {}) # Attach buffers to schedule: # The reason why we do not attach buffers to function is that # we may have multiple schedules referring to the same function, # which will cause conflicts of different buffers in different contexts. if isinstance(fn, Callable): for name, buffer in ctx.buffers.items(): if isinstance(buffer, MockArg): # Function arguments setattr( sch, name, MockBuffer(fn.__name__, name, buffer.idx), ) elif isinstance( buffer, (memref_d.AllocOp, func_d.CallOp, memref_d.GetGlobalOp) ): # Intermediate buffers setattr(sch, name, MockBuffer(fn.__name__, name)) # Check if there are memory leaks # All live operations = {top_func} + {top_func_ip} buffer = None ctx.buffers = None global_vars = {} # Functions are stored in ctx.global_vars, which should also be removed ctx = None # assert module.context._get_live_operation_count() == 2, ( # "All live operations = 1 (top_func) + 1 (top_func_ip), " # f"expected 2, but got {module.context._get_live_operation_count()}" # ) return sch