diff --git a/loopy/symbolic.py b/loopy/symbolic.py index ba6d71a80..f47e32f9d 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -48,6 +48,7 @@ from constantdict import constantdict from typing_extensions import Self, override +import namedisl as nisl import islpy as isl import pymbolic.primitives as p import pytools.lex @@ -2044,30 +2045,45 @@ def map_subscript(self, expr: p.Subscript) -> Set[p.Subscript]: # {{{ (pw)aff to expr conversion -def aff_to_expr(aff: isl.Aff) -> ArithmeticExpression: +def aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression: from pymbolic import var + # FIXME: remove this once namedisl is the standard in loopy denom = aff.get_denominator_val().to_python() - result = (aff.get_constant_val()*denom).to_python() - for dt in [isl.dim_type.in_, isl.dim_type.param]: - for i in range(aff.dim(dt)): - coeff = (aff.get_coefficient_val(dt, i)*denom).to_python() + if isinstance(aff, isl.Aff): + for dt in [isl.dim_type.in_, isl.dim_type.param]: + for i in range(aff.dim(dt)): + coeff = (aff.get_coefficient_val(dt, i)*denom).to_python() + if coeff: + dim_name = not_none(aff.get_dim_name(dt, i)) + result += coeff*var(dim_name) + + for i in range(aff.dim(isl.dim_type.div)): + coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python() + if coeff: + result += coeff*aff_to_expr(aff.get_div(i)) + + else: + in_names = set(aff.dim_type_names(isl.dim_type.in_)) + param_names = set(aff.dim_type_names(isl.dim_type.param)) + + for name in in_names | param_names: + coeff = (aff.get_coefficient_val(name) * denom).to_python() if coeff: - dim_name = not_none(aff.get_dim_name(dt, i)) - result += coeff*var(dim_name) + result = coeff * var(name) - for i in range(aff.dim(isl.dim_type.div)): - coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python() - if coeff: - result += coeff*aff_to_expr(aff.get_div(i)) + for name in aff.dim_type_names(isl.dim_type.div): + coeff = (aff.get_coefficient_val(name) * denom).to_python() + if coeff: + result += coeff * aff_to_expr(aff.get_div(name)) assert not isinstance(result, complex) return flatten(result // denom) def pw_aff_to_expr( - pw_aff: int | isl.PwAff | isl.Aff, + pw_aff: int | isl.PwAff | isl.Aff | nisl.PwAff | nisl.Aff, int_ok: bool = False ) -> ArithmeticExpression: if isinstance(pw_aff, int): diff --git a/loopy/transform/compute.py b/loopy/transform/compute.py new file mode 100644 index 000000000..59ddf8a2e --- /dev/null +++ b/loopy/transform/compute.py @@ -0,0 +1,120 @@ +import islpy as isl +import namedisl as nisl + +import loopy as lp +from loopy.kernel import LoopKernel +from loopy.kernel.data import AddressSpace +from loopy.kernel.instruction import MultiAssignmentBase +from loopy.match import parse_stack_match +from loopy.symbolic import ( + RuleAwareSubstitutionMapper, + SubstitutionRuleMappingContext, + pw_aff_to_expr +) +from loopy.transform.precompute import contains_a_subst_rule_invocation +from loopy.translation_unit import for_each_kernel + +from pymbolic import var +from pymbolic.mapper.substitutor import make_subst_func + +from pytools.tag import Tag + + +@for_each_kernel +def compute( + kernel: LoopKernel, + substitution: str, + compute_map: isl.Map | nisl.Map, + storage_inames: list[str], + default_tag: Tag | str | None = None, + temporary_address_space: AddressSpace | None = None + ) -> LoopKernel: + """ + Inserts an instruction to compute an expression given by :arg:`substitution` + and replaces all invocations of :arg:`substitution` with the result of the + compute instruction. + + :arg substitution: The substitution rule for which the compute + transform should be applied. + + :arg compute_map: An :class:`isl.Map` representing a relation between + substitution rule indices and tuples `(a, l)`, where `a` is a vector of + storage indices and `l` is a vector of "timestamps". + """ + if isinstance(compute_map, isl.Map): + compute_map = nisl.make_map(compute_map) + + if not temporary_address_space: + temporary_address_space = AddressSpace.GLOBAL + + # {{{ normalize names + + iname_to_storage_map = { + iname : (iname + "_store" if iname in kernel.all_inames() else iname) + for iname in storage_inames + } + + compute_map = compute_map.rename_dims(iname_to_storage_map) + + # }}} + + # {{{ update kernel domain to contain storage inames + + new_storage_axes = list(iname_to_storage_map.values()) + + # FIXME: use DomainChanger to add domain to kernel + storage_domain = compute_map.range().project_out_except(new_storage_axes) + new_domain = kernel.domains[0] + + # }}} + + # {{{ express substitution inputs as pw affs of (storage, time) names + + compute_pw_aff = compute_map.reverse().as_pw_multi_aff() + + storage_ax_to_global_expr = { + dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name)) + for dim_name in compute_map.dim_type_names(isl.dim_type.in_) + } + + # }}} + + # {{{ generate instruction from compute map + + rule_mapping_ctx = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + + expr_subst_map = RuleAwareSubstitutionMapper( + rule_mapping_ctx, + make_subst_func(storage_ax_to_global_expr), + within=parse_stack_match(None) + ) + + subst_expr = kernel.substitutions[substitution].expression + compute_expression = expr_subst_map(subst_expr, kernel, None) + + temporary_name = substitution + "_temp" + assignee = var(temporary_name)[tuple( + var(iname) for iname in new_storage_axes + )] + + compute_insn_id = substitution + "_compute" + compute_insn = lp.Assignment( + id=compute_insn_id, + assignee=assignee, + expression=compute_expression, + ) + + # }}} + + # {{{ replace substitution rule with newly created instruction + + for insn in kernel.instructions: + if contains_a_subst_rule_invocation(kernel, insn) \ + and isinstance(insn, MultiAssignmentBase): + print(insn) + + + # }}} + + return kernel