diff --git a/loopy/__init__.py b/loopy/__init__.py index a73f83bb9..177fae61c 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -76,7 +76,7 @@ affine_map_inames, find_unused_axis_tag, make_reduction_inames_unique, has_schedulable_iname_nesting, get_iname_duplication_options, - add_inames_to_insn, add_inames_for_unused_hw_axes) + add_inames_to_insn, add_inames_for_unused_hw_axes, map_domain) from loopy.transform.instruction import ( find_instructions, map_instructions, @@ -202,7 +202,7 @@ "affine_map_inames", "find_unused_axis_tag", "make_reduction_inames_unique", "has_schedulable_iname_nesting", "get_iname_duplication_options", - "add_inames_to_insn", "add_inames_for_unused_hw_axes", + "add_inames_to_insn", "add_inames_for_unused_hw_axes", "map_domain", "add_prefetch", "change_arg_to_image", "tag_array_axes", "tag_data_axes", diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index d67df1154..57183109b 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -767,4 +767,87 @@ def subst_into_pwaff(new_space, pwaff, subst_dict): # }}} + +# {{{ add_and_name_dims + +def add_and_name_dims(isl_obj, dt, names): + """Append dimensions of the specified dimension type to the provided ISL + object, and set their names. + + :arg isl_obj: An :class:`islpy.Set` or :class:`islpy.Map` to which + new dimensions will be added. + + :arg dt: An :class:`islpy.dim_type`, i.e., an :class:`int`, specifying the + dimension type for the new dimensions. + + :arg names: An iterable of :class:`str` values specifying the names of the + new dimensions to be added. + + :returns: An object of the same type as *isl_obj* with the new dimensions + added and named. + + """ + + new_idx_start = isl_obj.dim(dt) + isl_obj = isl_obj.add_dims(dt, len(names)) + for i, name in enumerate(names): + isl_obj = isl_obj.set_dim_name(dt, new_idx_start+i, name) + return isl_obj + +# }}} + + +# {{{ add_eq_constraint_from_names + +def add_eq_constraint_from_names(isl_obj, var1, var2): + """Add constraint *var1* = *var2* to an ISL object. + + :arg isl_obj: An :class:`islpy.Set` or :class:`islpy.Map` to which + a new constraint will be added. + + :arg var1: A :class:`str` specifying the name of the first variable + involved in constraint *var1* = *var2*. + + :arg var2: A :class:`str` specifying the name of the second variable + involved in constraint *var1* = *var2*. + + :returns: An object of the same type as *isl_obj* with the constraint + *var1* = *var2*. + + """ + return isl_obj.add_constraint( + isl.Constraint.eq_from_names( + isl_obj.space, + {1: 0, var1: 1, var2: -1})) + +# }}} + + +# {{{ find_and_rename_dim + +def find_and_rename_dim(isl_obj, dt, old_name, new_name): + """Rename a dimension in an ISL object. + + :arg isl_obj: An :class:`islpy.Set` or :class:`islpy.Map` containing the + dimension to be renamed. + + :arg dt: An :class:`islpy.dim_type` (i.e., :class:`int`) specifying the + dimension type containing the dimension to be renamed. + + :arg old_name: A :class:`str` specifying the name of the dimension to be + renamed. + + :arg new_name: A :class:`str` specifying the new name of the dimension to + be renamed. + + :returns: An object of the same type as *isl_obj* with the dimension + *old_name* renamed to *new_name*. + + """ + return isl_obj.set_dim_name( + dt, isl_obj.find_dim_by_name(dt, old_name), new_name) + +# }}} + + # vim: foldmethod=marker diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index c3b4a42ee..548f9ec01 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -72,6 +72,8 @@ .. autofunction:: add_inames_to_insn +.. autofunction:: map_domain + .. autofunction:: add_inames_for_unused_hw_axes """ @@ -1832,6 +1834,436 @@ def add_inames_to_insn(kernel, inames, insn_match): # }}} +# {{{ map_domain and associated functions + +# {{{ _MapDomainMapper + +class _MapDomainMapper(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, new_inames, substitutions): + super(_MapDomainMapper, self).__init__(rule_mapping_context) + + self.old_inames = frozenset(substitutions) + self.new_inames = new_inames + + self.substitutions = substitutions + + def map_reduction(self, expr, expn_state): + red_overlap = frozenset(expr.inames) & self.old_inames + arg_ctx_overlap = frozenset(expn_state.arg_context) & self.old_inames + if red_overlap: + if len(red_overlap) != len(self.old_inames): + raise LoopyError("Reduction '%s' involves a part " + "of the map domain inames. Reductions must " + "either involve all or none of the map domain " + "inames." % str(expr)) + + if arg_ctx_overlap: + if arg_ctx_overlap == red_overlap: + # All variables are shadowed by context, that's OK. + return super(_MapDomainMapper, self).map_reduction( + expr, expn_state) + else: + raise LoopyError("Reduction '%s' has" + "some of the reduction variables affected " + "by the map_domain shadowed by context. " + "Either all or none must be shadowed." + % str(expr)) + + new_inames = list(expr.inames) + for old_iname in self.old_inames: + new_inames.remove(old_iname) + new_inames.extend(self.new_inames) + + from loopy.symbolic import Reduction + return Reduction(expr.operation, tuple(new_inames), + self.rec(expr.expr, expn_state), + expr.allow_simultaneous) + else: + return super(_MapDomainMapper, self).map_reduction(expr, expn_state) + + def map_variable(self, expr, expn_state): + if (expr.name in self.old_inames + and expr.name not in expn_state.arg_context): + return self.substitutions[expr.name] + else: + return super(_MapDomainMapper, self).map_variable(expr, expn_state) + +# }}} + + +# {{{ _find_aff_subst_from_map(iname, isl_map) + +def _find_aff_subst_from_map(iname, isl_map): + if not isinstance(isl_map, isl.BasicMap): + raise RuntimeError("isl_map must be a BasicMap") + + dt, dim_idx = isl_map.get_var_dict()[iname] + + assert dt == dim_type.in_ + + # Force isl to solve for only this iname on its side of the map, by + # projecting out all other "in" variables. + isl_map = isl_map.project_out(dt, dim_idx+1, isl_map.dim(dt)-(dim_idx+1)) + isl_map = isl_map.project_out(dt, 0, dim_idx) + dim_idx = 0 + + # Convert map to set to avoid "domain of affine expression should be a set". + # The old "in" variable will be the last of the out_dims. + new_dim_idx = isl_map.dim(dim_type.out) + isl_map = isl_map.move_dims( + dim_type.out, isl_map.dim(dim_type.out), + dt, dim_idx, 1) + isl_map = isl_map.range() # now a set + dt = dim_type.set + dim_idx = new_dim_idx + del new_dim_idx + + for cns in isl_map.get_constraints(): + if cns.is_equality() and cns.involves_dims(dt, dim_idx, 1): + coeff = cns.get_coefficient_val(dt, dim_idx) + cns_zeroed = cns.set_coefficient_val(dt, dim_idx, 0) + if cns_zeroed.involves_dims(dt, dim_idx, 1): + # not suitable, constraint still involves dim, perhaps in a div + continue + + if coeff.is_one(): + return -cns_zeroed.get_aff() + elif coeff.is_negone(): + return cns_zeroed.get_aff() + else: + # not suitable, coefficient does not have unit coefficient + continue + + raise LoopyError("No suitable equation for '%s' found" % iname) + +# }}} + + +# {{{ _apply_identity_for_missing_map_dims(mapping, desired_dims) + +def _apply_identity_for_missing_map_dims(mapping, desired_dims): + """For every variable v in *desired_dims* that is not found in the + input space for *mapping*, add input dimension v, output dimension + v_'proxy'_, and constraint v = v_'proxy'_ to the mapping. Also return a + list of the (v, v_'proxy'_) pairs. + + :arg mapping: An :class:`islpy.Map`. + + :arg desired_dims: An iterable of :class:`str` specifying the names of the + desired map input dimensions. + + :returns: A two-tuple containing the mapping with the new dimensions and + constraints added, and a list of two-tuples of :class:`str` values + specifying the (v, v_'proxy'_) pairs. + + """ + + # If the transform map in map_domain (below) does not contain all the + # inames in the iname domain (set) to which it is applied, the missing + # inames must be added to the transform map so that intersect_domain() + # doesn't remove them from the iname domain when the map is applied. + + # No two map dimension names can match, so we create a unique name for each + # new variable in the output dimension by appending _'proxy'_, and return a + # list of the (v, v_'proxy'_) pairs so that the proxy dims can be + # identified and replaced later. + + # (Apostrophes are not allowed in inames, so this suffix + # will not match any existing inames. This function is also used on + # dependency maps, which may contain variable names consisting of an iname + # suffixed with a single apostrophe.) + + from loopy.isl_helpers import ( + add_and_name_dims, add_eq_constraint_from_names) + + # {{{ Find any missing vars and add them to the input and output space + + missing_dims = list( + set(desired_dims) - set(mapping.get_var_names(dim_type.in_))) + augmented_mapping = add_and_name_dims( + mapping, dim_type.in_, missing_dims) + + missing_dims_proxies = [d+"_'prox'_" for d in missing_dims] + assert not set(missing_dims_proxies) & set( + augmented_mapping.get_var_dict().keys()) + + augmented_mapping = add_and_name_dims( + augmented_mapping, dim_type.out, missing_dims_proxies) + + proxy_name_pairs = list(zip(missing_dims, missing_dims_proxies)) + + # }}} + + # {{{ Add identity constraint (v = v_'proxy'_) for each new pair of dims + + for real_iname, proxy_iname in proxy_name_pairs: + augmented_mapping = add_eq_constraint_from_names( + augmented_mapping, proxy_iname, real_iname) + + # }}} + + return augmented_mapping, proxy_name_pairs + +# }}} + + +# {{{ _error_if_any_iname_in_constraint + +def _error_if_any_iname_in_constraint( + inames, nest_constraints, constraint_descriptor_str): + """Raise informative error if any iname in *inames* is constrained by any + nest constraint in *nest_constraints*. + """ + # (This function is only used when new machinery from + # new-loop-nest-constraints branch is detected.) + + for constraint in nest_constraints: + for tier in constraint: + for iname in inames: + if tier.contains(iname): + raise ValueError( + "%s constraint %s contains iname(s) " + "transformed by map in map_domain." + % (constraint_descriptor_str, constraint)) + +# }}} + + +# {{{ map_domain + +@for_each_kernel +def map_domain(kernel, transform_map): + """Transform an iname domain by applying a mapping from existing inames to + new inames. + + :arg transform_map: A bijective :class:`islpy.Map` from existing inames to + new inames. To be applicable to a kernel domain, all input inames in + the map must be found in the domain. The map must be applicable to + exactly one domain found in *kernel.domains*. + + """ + + # FIXME: Express _split_iname_backend in terms of this + # Missing/deleted for now: + # - slab processing + # - priorities processing + # FIXME: Express affine_map_inames in terms of this, deprecate + + # Make sure the map is bijective + if not transform_map.is_bijective(): + raise LoopyError("transform_map must be bijective") + + transform_map_out_dims = frozenset(transform_map.get_var_dict(dim_type.out)) + transform_map_in_dims = frozenset(transform_map.get_var_dict(dim_type.in_)) + + # {{{ Make sure that none of the mapped inames are involved in loop priorities + + # kernel.loop_priority is being replaced with kernel.loop_nest_constraints, + # handle both attributes. + if hasattr(kernel, "loop_priority") and kernel.loop_priority: + for prio in kernel.loop_priority: + if set(prio) & transform_map_in_dims: + raise ValueError( + "Loop priority %s contains iname(s) transformed by " + "map %s in map_domain." % (prio, transform_map)) + if hasattr(kernel, "loop_nest_constraints") and kernel.loop_nest_constraints: + _error_if_any_iname_in_constraint( + transform_map_in_dims, + kernel.loop_nest_constraints.must_nest, "Must-nest") + _error_if_any_iname_in_constraint( + transform_map_in_dims, + kernel.loop_nest_constraints.must_not_nest, "Must-not-nest") + + # }}} + + # {{{ Solve for representation of old inames in terms of new + + substitutions = {} + var_substitutions = {} + applied_iname_rewrites = kernel.applied_iname_rewrites[:] + + from loopy.symbolic import aff_to_expr + from pymbolic import var + for iname in transform_map_in_dims: + subst_from_map = aff_to_expr( + _find_aff_subst_from_map(iname, transform_map)) + substitutions[iname] = subst_from_map + var_substitutions[var(iname)] = subst_from_map + + applied_iname_rewrites.append(var_substitutions) + del var_substitutions + + # }}} + + # {{{ Function to apply mapping to one set + + def process_set(s): + """Return the transformed set. Assume that map is applicable to this + set.""" + + # {{{ Align dims of transform_map and s so that map can be applied + + # Create a map whose input space matches the set + map_with_s_domain = isl.Map.from_domain(s) + + # {{{ Check for missing map dims and add them + + # For every iname v in the domain that is *not* found in the input + # space of the transform map, add input dimension v, output dimension + # v_'proxy'_, and constraint v = v_'proxy'_ to the transform map. + # Otherwise, v will be dropped from the domain when the map is applied. + + augmented_transform_map, proxy_name_pairs = \ + _apply_identity_for_missing_map_dims( + transform_map, s.get_var_names(dim_type.set)) + + # }}} + + # {{{ Align transform map input dims with set dims + + # FIXME: Make an exported/documented interface of this in islpy + + dim_types = [dim_type.param, dim_type.in_, dim_type.out] + # Variables found in iname domain set + s_names = { + map_with_s_domain.get_dim_name(dt, i) + for dt in dim_types + for i in range(map_with_s_domain.dim(dt)) + } + # Variables found in transform map + map_names = { + augmented_transform_map.get_dim_name(dt, i) + for dt in dim_types + for i in range(augmented_transform_map.dim(dt)) + } + # (_align_dim_type uses these two sets to determine which names are in + # both the obj and template) + + from islpy import _align_dim_type + aligned_map = _align_dim_type( + dim_type.param, + augmented_transform_map, map_with_s_domain, False, + map_names, s_names) + aligned_map = _align_dim_type( + dim_type.in_, + aligned_map, map_with_s_domain, False, + map_names, s_names) + + # }}} + + # }}} + + # Apply the transform map to the domain + new_s = aligned_map.intersect_domain(s).range() + + # Now rename any proxy dims back to their original names + + from loopy.isl_helpers import find_and_rename_dim + for real_iname, proxy_iname in proxy_name_pairs: + new_s = find_and_rename_dim( + new_s, dim_type.set, proxy_iname, real_iname) + + return new_s + + # FIXME: Revive _project_out_only_if_all_instructions_in_within + + # }}} + + # {{{ Apply the transform map to exactly one domain + + map_applied_to_one_dom = False + new_domains = [] + transform_map_rules = ( + "Transform map must be applicable to exactly one domain. " + "A transform map is applicable to a domain if its input " + "inames are a subset of the domain inames.") + + for old_domain in kernel.domains: + + # Make sure transform map is applicable to this set. Then transform. + + if not transform_map_in_dims <= frozenset(old_domain.get_var_dict()): + + # Map not applicable to this set because map transforms at least + # one iname that is not present in the set. Don't transform. + new_domains.append(old_domain) + continue + + elif map_applied_to_one_dom: + + # Map is applicable to this domain, but this map was + # already applied. Error. + raise LoopyError( + "Transform map %s was applicable to more than one domain. %s" + % (transform_map, transform_map_rules)) + + else: + + # Map is applicable to this domain, and this map has not yet + # been applied. Transform. + new_domains.append(process_set(old_domain)) + map_applied_to_one_dom = True + + # If we get this far, either the map has been applied to 1 domain (good) + # or the map could not be applied to any domain, which should produce an error. + if not map_applied_to_one_dom: + raise LoopyError( + "Transform map %s was not applicable to any domain. %s" + % (transform_map, transform_map_rules)) + + # }}} + + # {{{ Update within_inames for each statement + + # If we get this far, we know that the map was applied to exactly one domain, + # and that all the inames in transform_map_in_dims were transformed to + # inames in transform_map_out_dims. However, it's still possible that for some + # statements, stmt.within_inames will contain at least one but not all of the + # transformed inames (transform_map_in_dims). + # In this case, it's not clear what within_inames should be. Therefore, we + # require that if any transformed inames are found in stmt.within_inames, + # ALL transformed inames must be found in stmt.within_inames. + + new_stmts = [] + for stmt in kernel.instructions: + overlap = transform_map_in_dims & stmt.within_inames + if overlap: + if len(overlap) != len(transform_map_in_dims): + raise LoopyError("Statement '%s' is within only a part " + "of the mapped inames in transformation map %s. " + "Statements must be within all or none of the mapped " + "inames." % (stmt.id, transform_map)) + + stmt = stmt.copy(within_inames=( + stmt.within_inames - transform_map_in_dims) | transform_map_out_dims) + else: + # Leave stmt unmodified + pass + + new_stmts.append(stmt) + + # }}} + + kernel = kernel.copy( + domains=new_domains, + instructions=new_stmts, + applied_iname_rewrites=applied_iname_rewrites) + + rule_mapping_context = SubstitutionRuleMappingContext( + kernel.substitutions, kernel.get_var_name_generator()) + ins = _MapDomainMapper(rule_mapping_context, + transform_map_out_dims, substitutions) + + kernel = ins.map_kernel(kernel) + kernel = rule_mapping_context.finish_kernel(kernel) + + return kernel + +# }}} + +# }}} + + @for_each_kernel def add_inames_for_unused_hw_axes(kernel, within=None): """ diff --git a/test/test_transform.py b/test/test_transform.py index 51e7c2636..3df2d6c2c 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -594,6 +594,382 @@ def test_nested_substs_in_insns(ctx_factory): lp.auto_test_vs_ref(ref_prg, ctx, t_unit) +# {{{ test_map_domain_vs_split_iname + +def _ensure_dim_names_match_and_align(obj_map, tgt_map): + # (This function is also defined in independent, unmerged branch + # new-dependency-and-nest-constraint-semantics-development, and used in + # child branches thereof. Once these branches are all merged, it may make + # sense to move this function to a location for more general-purpose + # machinery. In the other branches, this function's name excludes the + # leading underscore.) + from islpy import align_spaces + from islpy import dim_type as dt + + # first make sure names match + if not all( + set(obj_map.get_var_names(dt)) == set(tgt_map.get_var_names(dt)) + for dt in + [dt.in_, dt.out, dt.param]): + raise ValueError( + "Cannot align spaces; names don't match:\n%s\n%s" + % (obj_map, tgt_map)) + + return align_spaces(obj_map, tgt_map) + + +def test_map_domain_vs_split_iname(ctx_factory): + + # {{{ Make kernel + + knl = lp.make_kernel( + [ + "[nx,nt] -> {[x, t]: 0 <= x < nx and 0 <= t < nt}", + "[ni] -> {[i]: 0 <= i < ni}", + ], + """ + a[x,t] = b[x,t] {id=stmta} + c[x,t] = d[x,t] {id=stmtc} + e[i] = f[i] + """, + lang_version=(2018, 2), + ) + knl = lp.add_and_infer_dtypes(knl, {"b,d,f": np.float32}) + ref_knl = knl + + # }}} + + # {{{ Apply domain change mapping + + knl_map_dom = ref_knl + + # Create map_domain mapping: + import islpy as isl + transform_map = isl.BasicMap( + "[nt] -> {[t] -> [t_outer, t_inner]: " + "0 <= t_inner < 32 and " + "32*t_outer + t_inner = t and " + "0 <= 32*t_outer + t_inner < nt}") + + # Call map_domain to transform kernel + knl_map_dom = lp.map_domain(knl_map_dom, transform_map) + + # Prioritize loops (prio should eventually be updated in map_domain?) + loop_priority = "x, t_outer, t_inner" + knl_map_dom = lp.prioritize_loops(knl_map_dom, loop_priority) + + # Get a linearization + proc_knl_map_dom = lp.preprocess_kernel(knl_map_dom) + lin_knl_map_dom = lp.get_one_linearized_kernel( + proc_knl_map_dom["loopy_kernel"], proc_knl_map_dom.callables_table) + + # }}} + + # {{{ Split iname and see if we get the same result + + knl_split_iname = ref_knl + knl_split_iname = lp.split_iname(knl_split_iname, "t", 32) + knl_split_iname = lp.prioritize_loops(knl_split_iname, loop_priority) + proc_knl_split_iname = lp.preprocess_kernel(knl_split_iname) + lin_knl_split_iname = lp.get_one_linearized_kernel( + proc_knl_split_iname["loopy_kernel"], proc_knl_split_iname.callables_table) + + for d_map_domain, d_split_iname in zip( + knl_map_dom["loopy_kernel"].domains, + knl_split_iname["loopy_kernel"].domains): + d_map_domain_aligned = _ensure_dim_names_match_and_align( + d_map_domain, d_split_iname) + assert d_map_domain_aligned == d_split_iname + + for litem_map_domain, litem_split_iname in zip( + lin_knl_map_dom.linearization, lin_knl_split_iname.linearization): + assert litem_map_domain == litem_split_iname + + # Can't easily compare instructions because equivalent subscript + # expressions may have different orders + + lp.auto_test_vs_ref(proc_knl_split_iname, ctx_factory(), proc_knl_map_dom, + parameters={"nx": 128, "nt": 128, "ni": 128}) + + # }}} + +# }}} + + +# {{{ test_map_domain_transform_map_validity_and_errors + +def test_map_domain_transform_map_validity_and_errors(ctx_factory): + + # {{{ Make kernel + + knl = lp.make_kernel( + [ + "[nx,nt] -> {[x, y, z, t]: 0 <= x,y,z < nx and 0 <= t < nt}", + "[m] -> {[j]: 0 <= j < m}", + ], + """ + a[y,x,t,z] = b[y,x,t,z] {id=stmta} + for j + <>temp = j {dep=stmta} + end + """, + lang_version=(2018, 2), + ) + knl = lp.add_and_infer_dtypes(knl, {"b": np.float32}) + ref_knl = knl + + # }}} + + # {{{ Make sure map_domain *succeeds* when map includes 2 of 4 dims in one + # domain. + + # {{{ Apply domain change mapping that splits t and renames y; (similar to + # split_iname test above, but doesn't hurt to test this slightly different + # scenario) + + knl_map_dom = ref_knl + + # Create map_domain mapping that only includes t and y + # (x and z should be unaffected) + import islpy as isl + transform_map = isl.BasicMap( + "[nx,nt] -> {[t, y] -> [t_outer, t_inner, y_new]: " + "0 <= t_inner < 16 and " + "16*t_outer + t_inner = t and " + "0 <= 16*t_outer + t_inner < nt and " + "y = y_new" + "}") + + # Call map_domain to transform kernel; this should *not* produce an error + knl_map_dom = lp.map_domain(knl_map_dom, transform_map) + + # Prioritize loops + desired_prio = "x, t_outer, t_inner, z, y_new" + + # Use constrain_loop_nesting if it's available + cln_attr = getattr(lp, "constrain_loop_nesting", None) + if cln_attr is not None: + knl_map_dom = lp.constrain_loop_nesting( # noqa pylint:disable=no-member + knl_map_dom, desired_prio) + else: + knl_map_dom = lp.prioritize_loops(knl_map_dom, desired_prio) + + # Get a linearization + proc_knl_map_dom = lp.preprocess_kernel(knl_map_dom) + lin_knl_map_dom = lp.get_one_linearized_kernel( + proc_knl_map_dom["loopy_kernel"], proc_knl_map_dom.callables_table) + + # }}} + + # {{{ Use split_iname and rename_iname, and make sure we get the same result + + knl_split_iname = ref_knl + knl_split_iname = lp.split_iname(knl_split_iname, "t", 16) + knl_split_iname = lp.rename_iname(knl_split_iname, "y", "y_new") + try: + # Use constrain_loop_nesting if it's available + knl_split_iname = lp.constrain_loop_nesting(knl_split_iname, desired_prio) + except AttributeError: + knl_split_iname = lp.prioritize_loops(knl_split_iname, desired_prio) + proc_knl_split_iname = lp.preprocess_kernel(knl_split_iname) + lin_knl_split_iname = lp.get_one_linearized_kernel( + proc_knl_split_iname["loopy_kernel"], proc_knl_split_iname.callables_table) + + for d_map_domain, d_split_iname in zip( + knl_map_dom["loopy_kernel"].domains, + knl_split_iname["loopy_kernel"].domains): + d_map_domain_aligned = _ensure_dim_names_match_and_align( + d_map_domain, d_split_iname) + assert d_map_domain_aligned == d_split_iname + + for litem_map_domain, litem_split_iname in zip( + lin_knl_map_dom.linearization, lin_knl_split_iname.linearization): + assert litem_map_domain == litem_split_iname + + # Can't easily compare instructions because equivalent subscript + # expressions may have different orders + + lp.auto_test_vs_ref(proc_knl_split_iname, ctx_factory(), proc_knl_map_dom, + parameters={"nx": 32, "nt": 32, "m": 32}) + + # }}} + + # }}} + + # {{{ Make sure we error on a map that is not bijective + + # Not bijective + transform_map = isl.BasicMap( + "[nx,nt] -> {[t, y, rogue] -> [t_new, y_new]: " + "y = y_new and t = t_new" + "}") + + from loopy.diagnostic import LoopyError + knl = ref_knl + try: + knl = lp.map_domain(knl, transform_map) + raise AssertionError() + except LoopyError as err: + assert "map must be bijective" in str(err) + + # }}} + + # {{{ Make sure there's an error if transform map does not apply to + # exactly one domain. + + test_maps = [ + # Map where some inames match exactly one domain but there's also a + # rogue dim + isl.BasicMap( + "[nx,nt] -> {[t, y, rogue] -> [t_new, y_new, rogue_new]: " + "y = y_new and t = t_new and rogue = rogue_new" + "}"), + # Map where all inames match exactly one domain but there's also a + # rogue dim + isl.BasicMap( + "[nx,nt] -> {[t, y, x, z, rogue] -> " + "[t_new, y_new, x_new, z_new, rogue_new]: " + "y = y_new and t = t_new and x = x_new and z = z_new " + "and rogue = rogue_new" + "}"), + # Map where no inames match any domain + isl.BasicMap( + "[nx,nt] -> {[rogue] -> [rogue_new]: " + "rogue = rogue_new" + "}"), + ] + + for transform_map in test_maps: + try: + knl = lp.map_domain(knl, transform_map) + raise AssertionError() + except LoopyError as err: + assert ( + "was not applicable to any domain. " + "Transform map must be applicable to exactly one domain." + in str(err)) + + # }}} + + # {{{ Make sure there's an error if we try to map inames in priorities + + knl = ref_knl + knl = lp.prioritize_loops(knl, "y, z") + knl = lp.prioritize_loops(knl, "x, z") + try: + transform_map = isl.BasicMap( + "[nx,nt] -> {[t, y] -> [t_new, y_new]: " + "y = y_new and t = t_new }") + knl = lp.map_domain(knl, transform_map) + raise AssertionError() + except ValueError as err: + assert ( + "Loop priority ('y', 'z') contains iname(s) " + "transformed by map" in str(err)) + + # }}} + + # {{{ Make sure we error when stmt.within_inames contains at least one but + # not all mapped inames + + # {{{ Make potentially problematic kernel + + knl = lp.make_kernel( + [ + "[n, m] -> { [i, j]: 0 <= i < n and 0 <= j < m }", + "[ell] -> { [k]: 0 <= k < ell }", + ], + """ + for i + <>t0 = i {id=stmt0} + for j + <>t1 = j {id=stmt1, dep=stmt0} + end + <>t2 = i + 1 {id=stmt2, dep=stmt1} + end + for k + <>t3 = k {id=stmt3, dep=stmt2} + end + """, + lang_version=(2018, 2), + ) + + # }}} + + # This should fail: + try: + transform_map = isl.BasicMap( + "[n, m] -> {[i, j] -> [i_new, j_new]: " + "i_new = i + j and j_new = 2 + i }") + knl = lp.map_domain(knl, transform_map) + raise AssertionError() + except LoopyError as err: + assert ( + "Statements must be within all or none of the mapped inames" + in str(err)) + + # This should succeed: + transform_map = isl.BasicMap( + "[n, m] -> {[i] -> [i_new]: i_new = i + 2 }") + knl = lp.map_domain(knl, transform_map) + + # }}} + +# }}} + + +def test_diamond_tiling(ctx_factory, interactive=False): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + ref_knl = lp.make_kernel( + "[nx,nt] -> {[ix, it]: 1<=ix {[ix, it] -> [tx, tt, tparity, itt, itx]: " + "16*(tx - tt) + itx - itt = ix - it and " + "16*(tx + tt + tparity) + itt + itx = ix + it and " + "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}") + knl = lp.map_domain(knl_for_transform, m) + knl = lp.prioritize_loops(knl, "tt,tparity,tx,itt,itx") + + if interactive: + nx = 43 + u = np.zeros((nx, 200)) + x = np.linspace(-1, 1, nx) + dx = x[1] - x[0] + u[:, 0] = u[:, 1] = np.exp(-100*x**2) + + u_dev = cl.array.to_device(queue, u) + knl(queue, u=u_dev, dx=dx, dt=dx) + + u = u_dev.get() + import matplotlib.pyplot as plt + plt.imshow(u.T) + plt.show() + else: + types = {"dt,dx,u": np.float64} + knl = lp.add_and_infer_dtypes(knl, types) + ref_knl = lp.add_and_infer_dtypes(ref_knl, types) + + lp.auto_test_vs_ref(ref_knl, ctx, knl, + parameters={ + "nx": 200, "nt": 300, + "dx": 1, "dt": 1 + }) + + def test_extract_subst_with_iname_deps_in_templ(ctx_factory): knl = lp.make_kernel( "{[i, j, k]: 0<=i<100 and 0<=j,k<5}",