Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from constantdict import constantdict
from typing_extensions import Self, override

import namedisl as nisl

Check failure on line 51 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Import "namedisl" could not be resolved (reportMissingImports)
import islpy as isl
import pymbolic.primitives as p
import pytools.lex
Expand Down Expand Up @@ -2044,30 +2045,45 @@

# {{{ (pw)aff to expr conversion

def aff_to_expr(aff: isl.Aff) -> ArithmeticExpression:
def aff_to_expr(aff: isl.Aff | nisl.Aff) -> ArithmeticExpression:

Check warning on line 2048 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "Aff" is unknown (reportUnknownMemberType)

Check warning on line 2048 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of parameter "aff" is partially unknown   Parameter type is "Aff | Unknown" (reportUnknownParameterType)
from pymbolic import var

# FIXME: remove this once namedisl is the standard in loopy
denom = aff.get_denominator_val().to_python()

Check warning on line 2052 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "to_python" is partially unknown   Type of "to_python" is "(() -> int) | Unknown" (reportUnknownMemberType)

Check warning on line 2052 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "get_denominator_val" is partially unknown   Type of "get_denominator_val" is "(() -> Val) | Unknown" (reportUnknownMemberType)

result = (aff.get_constant_val()*denom).to_python()

Check warning on line 2053 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "get_constant_val" is partially unknown   Type of "get_constant_val" is "(() -> Val) | Unknown" (reportUnknownMemberType)

Check warning on line 2053 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "to_python" is partially unknown   Type of "to_python" is "(() -> int) | Unknown" (reportUnknownMemberType)
for dt in [isl.dim_type.in_, isl.dim_type.param]:
for i in range(aff.dim(dt)):
coeff = (aff.get_coefficient_val(dt, i)*denom).to_python()
if isinstance(aff, isl.Aff):
for dt in [isl.dim_type.in_, isl.dim_type.param]:
for i in range(aff.dim(dt)):
coeff = (aff.get_coefficient_val(dt, i)*denom).to_python()

Check warning on line 2057 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "to_python" is partially unknown   Type of "to_python" is "(() -> int) | Unknown" (reportUnknownMemberType)
if coeff:
dim_name = not_none(aff.get_dim_name(dt, i))
result += coeff*var(dim_name)

for i in range(aff.dim(isl.dim_type.div)):
coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python()

Check warning on line 2063 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "to_python" is partially unknown   Type of "to_python" is "(() -> int) | Unknown" (reportUnknownMemberType)
if coeff:
result += coeff*aff_to_expr(aff.get_div(i))

else:
in_names = set(aff.dim_type_names(isl.dim_type.in_))

Check warning on line 2068 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument type is unknown   Argument corresponds to parameter "iterable" in function "__init__" (reportUnknownArgumentType)

Check warning on line 2068 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "dim_type_names" is unknown (reportUnknownMemberType)
param_names = set(aff.dim_type_names(isl.dim_type.param))

for name in in_names | param_names:
coeff = (aff.get_coefficient_val(name) * denom).to_python()
if coeff:
dim_name = not_none(aff.get_dim_name(dt, i))
result += coeff*var(dim_name)
result = coeff * var(name)

for i in range(aff.dim(isl.dim_type.div)):
coeff = (aff.get_coefficient_val(isl.dim_type.div, i)*denom).to_python()
if coeff:
result += coeff*aff_to_expr(aff.get_div(i))
for name in aff.dim_type_names(isl.dim_type.div):
coeff = (aff.get_coefficient_val(name) * denom).to_python()
if coeff:
result += coeff * aff_to_expr(aff.get_div(name))

assert not isinstance(result, complex)
return flatten(result // denom)


def pw_aff_to_expr(
pw_aff: int | isl.PwAff | isl.Aff,
pw_aff: int | isl.PwAff | isl.Aff | nisl.PwAff | nisl.Aff,
int_ok: bool = False
) -> ArithmeticExpression:
if isinstance(pw_aff, int):
Expand Down
120 changes: 120 additions & 0 deletions loopy/transform/compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import islpy as isl
import namedisl as nisl

Check failure on line 2 in loopy/transform/compute.py

View workflow job for this annotation

GitHub Actions / basedpyright

Import "namedisl" could not be resolved (reportMissingImports)

import loopy as lp
from loopy.kernel import LoopKernel
from loopy.kernel.data import AddressSpace
from loopy.kernel.instruction import MultiAssignmentBase
from loopy.match import parse_stack_match
from loopy.symbolic import (
RuleAwareSubstitutionMapper,
SubstitutionRuleMappingContext,
pw_aff_to_expr
)
from loopy.transform.precompute import contains_a_subst_rule_invocation
from loopy.translation_unit import for_each_kernel

from pymbolic import var
from pymbolic.mapper.substitutor import make_subst_func

from pytools.tag import Tag


@for_each_kernel
def compute(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this could just be a @for_each_kernel decorator?

kernel: LoopKernel,
substitution: str,
compute_map: isl.Map | nisl.Map,
storage_inames: list[str],
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is storage_inames?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed this the first time around. storage_inames corresponds to the inames that would be generated in something like a tiled matmul to fill shared memory with input tiles. Maybe storage_axes is a better name. This corresponds to the a in the (a, l) range of compute_map.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer Sequence to list on input.

default_tag: Tag | str | None = None,
temporary_address_space: AddressSpace | None = None
) -> LoopKernel:
"""
Inserts an instruction to compute an expression given by :arg:`substitution`
and replaces all invocations of :arg:`substitution` with the result of the
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all

Not really: only the relevant ones, where relevant should be defined below.

compute instruction.

:arg substitution: The substitution rule for which the compute
transform should be applied.

:arg compute_map: An :class:`isl.Map` representing a relation between
substitution rule indices and tuples `(a, l)`, where `a` is a vector of
storage indices and `l` is a vector of "timestamps".
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is the boundary of a and l determined?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In it's current form, it relies on user input to determine what a is (this is storage_inames).

"""
if isinstance(compute_map, isl.Map):
compute_map = nisl.make_map(compute_map)

if not temporary_address_space:
temporary_address_space = AddressSpace.GLOBAL

# {{{ normalize names

iname_to_storage_map = {
iname : (iname + "_store" if iname in kernel.all_inames() else iname)
for iname in storage_inames
}

compute_map = compute_map.rename_dims(iname_to_storage_map)

Check failure on line 57 in loopy/transform/compute.py

View workflow job for this annotation

GitHub Actions / basedpyright

Cannot access attribute "rename_dims" for class "Map"   Attribute "rename_dims" is unknown (reportAttributeAccessIssue)

# }}}

# {{{ update kernel domain to contain storage inames

new_storage_axes = list(iname_to_storage_map.values())

# FIXME: use DomainChanger to add domain to kernel
storage_domain = compute_map.range().project_out_except(new_storage_axes)

Check failure on line 66 in loopy/transform/compute.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument missing for parameter "types" (reportCallIssue)
new_domain = kernel.domains[0]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the DomainChanger.


# }}}

# {{{ express substitution inputs as pw affs of (storage, time) names

compute_pw_aff = compute_map.reverse().as_pw_multi_aff()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does .as_multi_pw_aff() do in this context? I've never used it.

Copy link
Contributor Author

@a-alveyblanc a-alveyblanc Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this particular instance, it expresses substitution inputs as piecewise affine functions of (a, l). compute uses the output of the resulting PwMultiAff to determine the multidimensional index expressions of the RHS of a substitution rule.


storage_ax_to_global_expr = {
dim_name : pw_aff_to_expr(compute_pw_aff.get_at(dim_name))
for dim_name in compute_map.dim_type_names(isl.dim_type.in_)

Check failure on line 77 in loopy/transform/compute.py

View workflow job for this annotation

GitHub Actions / basedpyright

Cannot access attribute "dim_type_names" for class "Map"   Attribute "dim_type_names" is unknown (reportAttributeAccessIssue)
}

# }}}

# {{{ generate instruction from compute map

rule_mapping_ctx = SubstitutionRuleMappingContext(
kernel.substitutions, kernel.get_var_name_generator())

expr_subst_map = RuleAwareSubstitutionMapper(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to subclass this guy. Otherwise you won't be able to decide whether the usage site is "in-footprint".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding was that RuleAwareSubstitutionMapper only mapped storage axes to index expressions in pymbolic.

Do you mean RuleInvocationReplacer? This explicitly checks footprints with ArrayToBufferMap and some other information computed earlier in precompute.

rule_mapping_ctx,
make_subst_func(storage_ax_to_global_expr),
within=parse_stack_match(None)
)

subst_expr = kernel.substitutions[substitution].expression
compute_expression = expr_subst_map(subst_expr, kernel, None)

Check failure on line 94 in loopy/transform/compute.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument of type "None" cannot be assigned to parameter "insn" of type "InstructionBase" in function "__call__"   "None" is not assignable to "InstructionBase" (reportArgumentType)

temporary_name = substitution + "_temp"
assignee = var(temporary_name)[tuple(
var(iname) for iname in new_storage_axes
)]

compute_insn_id = substitution + "_compute"
compute_insn = lp.Assignment(
id=compute_insn_id,
assignee=assignee,
expression=compute_expression,
)

# }}}

# {{{ replace substitution rule with newly created instruction

for insn in kernel.instructions:
if contains_a_subst_rule_invocation(kernel, insn) \
and isinstance(insn, MultiAssignmentBase):
print(insn)


# }}}

return kernel
Loading