diff --git a/doc/ref_kernel.rst b/doc/ref_kernel.rst index 922315685..09eceb1d1 100644 --- a/doc/ref_kernel.rst +++ b/doc/ref_kernel.rst @@ -220,6 +220,8 @@ Tag Meaning Identifiers ----------- +.. _reserved-identifiers: + Reserved Identifiers ^^^^^^^^^^^^^^^^^^^^ diff --git a/doc/ref_other.rst b/doc/ref_other.rst index b13f39869..d41109b9d 100644 --- a/doc/ref_other.rst +++ b/doc/ref_other.rst @@ -26,9 +26,16 @@ Automatic Testing .. autofunction:: auto_test_vs_ref +Checking Dependencies at the Statement-Instance Level +----------------------------------------------------- + +.. automodule:: loopy.schedule.checker + Troubleshooting --------------- +.. currentmodule:: loopy + Printing :class:`LoopKernel` objects ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/loopy/__init__.py b/loopy/__init__.py index 177fae61c..ad245c014 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -129,6 +129,9 @@ from loopy.schedule import ( generate_loop_schedules, get_one_scheduled_kernel, get_one_linearized_kernel, linearize) +from loopy.schedule.checker import ( + get_pairwise_statement_orderings, +) from loopy.statistics import (ToCountMap, ToCountPolynomialMap, CountGranularity, stringify_stats_mapping, Op, MemAccess, get_op_map, get_mem_access_map, get_synchronization_map, gather_access_footprints, @@ -268,6 +271,7 @@ "generate_loop_schedules", "get_one_scheduled_kernel", "get_one_linearized_kernel", "linearize", + "get_pairwise_statement_orderings", "GeneratedProgram", "CodeGenerationResult", "PreambleInfo", diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py index 57183109b..733a00c55 100644 --- a/loopy/isl_helpers.py +++ b/loopy/isl_helpers.py @@ -823,10 +823,10 @@ def add_eq_constraint_from_names(isl_obj, var1, var2): # }}} -# {{{ find_and_rename_dim +# {{{ find_and_rename_dims -def find_and_rename_dim(isl_obj, dt, old_name, new_name): - """Rename a dimension in an ISL object. 
+def find_and_rename_dims(isl_obj, dt, rename_dict): + """Rename dimensions in an ISL object. :arg isl_obj: An :class:`islpy.Set` or :class:`islpy.Map` containing the dimension to be renamed. @@ -834,18 +834,22 @@ def find_and_rename_dim(isl_obj, dt, old_name, new_name): :arg dt: An :class:`islpy.dim_type` (i.e., :class:`int`) specifying the dimension type containing the dimension to be renamed. - :arg old_name: A :class:`str` specifying the name of the dimension to be - renamed. + :arg rename_dict: A :class:`dict` mapping current :class:`string` dimension + names to replacement names. - :arg new_name: A :class:`str` specifying the new name of the dimension to - be renamed. - - :returns: An object of the same type as *isl_obj* with the dimension - *old_name* renamed to *new_name*. + :returns: An object of the same type as *isl_obj* with the dimension names + changed according to *rename_dict*. """ - return isl_obj.set_dim_name( + for old_name, new_name in rename_dict.items(): + idx = isl_obj.find_dim_by_name(dt, old_name) + if idx == -1: + raise ValueError( + "find_and_rename_dims did not find dimension %s" + % (old_name)) + isl_obj = isl_obj.set_dim_name( dt, isl_obj.find_dim_by_name(dt, old_name), new_name) + return isl_obj # }}} diff --git a/loopy/schedule/checker/__init__.py b/loopy/schedule/checker/__init__.py new file mode 100644 index 000000000..b987255d4 --- /dev/null +++ b/loopy/schedule/checker/__init__.py @@ -0,0 +1,161 @@ +""" +.. autofunction:: get_pairwise_statement_orderings + +.. 
automodule:: loopy.schedule.checker.schedule +""" + + +__copyright__ = "Copyright (C) 2019 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +# {{{ get pairwise statement orderings + +def get_pairwise_statement_orderings( + knl, + lin_items, + stmt_id_pairs, + perform_closure_checks=False, + ): + r"""For each statement pair in a subset of all statement pairs found in a + linearized kernel, determine the (relative) order in which the statement + instances are executed. For each pair, represent this relative ordering + using three ``statement instance orderings`` (SIOs): + + - The intra-thread SIO: A :class:`islpy.Map` from each instance of the + first statement to all instances of the second statement that occur + later, such that both statement instances in each before-after pair are + executed within the same work-item (thread). 
+ + - The intra-group SIO: A :class:`islpy.Map` from each instance of the first + statement to all instances of the second statement that occur later, such + that both statement instances in each before-after pair are executed + within the same work-group (though potentially by different work-items). + + - The global SIO: A :class:`islpy.Map` from each instance of the first + statement to all instances of the second statement that occur later, even + if the two statement instances in a given before-after pair are executed + within different work-groups. + + :arg knl: A preprocessed :class:`loopy.LoopKernel` containing the + linearization items that will be used to create the SIOs. + + :arg lin_items: A list of :class:`loopy.schedule.ScheduleItem` + (to be renamed to `loopy.schedule.LinearizationItem`) containing all + linearization items for which SIOs will be created. To allow usage of + this routine during linearization, a truncated (i.e. partial) + linearization may be passed through this argument. + + :arg stmt_id_pairs: A sequence containing pairs of statement identifiers. + + :returns: A dictionary mapping each two-tuple of statement identifiers + provided in `stmt_id_pairs` to a + :class:`~loopy.schedule.checker.schedule.StatementOrdering`, which + contains the three SIOs described above. + + .. doctest: + + >>> import loopy as lp + >>> import numpy as np + >>> # Make kernel ----------------------------------------------------------- + >>> knl = lp.make_kernel( + ... "{[j,k]: 0<=j>> knl = lp.add_and_infer_dtypes(knl, {"a": np.float32, "b": np.float32}) + >>> # Preprocess + >>> knl = lp.preprocess_kernel(knl) + >>> # Get a linearization + >>> knl = lp.get_one_linearized_kernel( + ... knl["loopy_kernel"], knl.callables_table) + >>> # Get pairwise order info ----------------------------------------------- + >>> from loopy.schedule.checker import get_pairwise_statement_orderings + >>> sio_dict = get_pairwise_statement_orderings( + ... knl, + ... 
knl.linearization, + ... [("stmt_a", "stmt_b")], + ... ) + >>> # Print map + >>> print(str(sio_dict[("stmt_a", "stmt_b")].sio_intra_thread + ... ).replace("{ ", "{\n").replace(" :", "\n:")) + [pj, pk] -> { + [_lp_linchk_stmt' = 0, j'] -> [_lp_linchk_stmt = 1, k] + : pj > 0 and pk > 0 and 0 <= j' < pj and 0 <= k < pk } + + """ + + # {{{ make sure kernel has been preprocessed + + from loopy.kernel import KernelState + assert knl.state in [ + KernelState.PREPROCESSED, + KernelState.LINEARIZED] + + # }}} + + # {{{ Find any EnterLoop inames that are tagged as concurrent + # so that get_pairwise_statement_orderings_inner() knows to ignore them + # (In the future, this should only include inames tagged with 'vec'.) + + # FIXME Consider just putting this ilp/vec logic inside + # get_pairwise_statement_orderings_inner; passing these in as + # 'loops_to_ignore' made more sense when we were just dealing with the + # intra-thread case. + from loopy.schedule.checker.utils import ( + partition_inames_by_concurrency, + get_EnterLoop_inames, + ) + conc_inames, _ = partition_inames_by_concurrency(knl) + enterloop_inames = get_EnterLoop_inames(lin_items) + ilp_and_vec_inames = conc_inames & enterloop_inames + + # The only concurrent EnterLoop inames should be Vec and ILP + from loopy.kernel.data import (VectorizeTag, IlpBaseTag) + for conc_iname in ilp_and_vec_inames: + # Assert that there exists an ilp or vectorize tag (out of the + # potentially multiple other tags on this concurrent iname). 
+ assert any( + isinstance(tag, (VectorizeTag, IlpBaseTag)) + for tag in knl.iname_to_tags[conc_iname]) + + # }}} + + # {{{ Create the SIOs + + from loopy.schedule.checker.schedule import ( + get_pairwise_statement_orderings_inner + ) + return get_pairwise_statement_orderings_inner( + knl, + lin_items, + stmt_id_pairs, + ilp_and_vec_inames=ilp_and_vec_inames, + perform_closure_checks=perform_closure_checks, + ) + + # }}} + +# }}} + +# vim: foldmethod=marker diff --git a/loopy/schedule/checker/lexicographic_order_map.py b/loopy/schedule/checker/lexicographic_order_map.py new file mode 100644 index 000000000..5821202cb --- /dev/null +++ b/loopy/schedule/checker/lexicographic_order_map.py @@ -0,0 +1,201 @@ +# coding: utf-8 +__copyright__ = "Copyright (C) 2019 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import islpy as isl + + +def get_statement_ordering_map( + sched_before, sched_after, lex_map, before_mark): + """Return a statement ordering represented as a map from each statement + instance to all statement instances occurring later. + + :arg sched_before: An :class:`islpy.Map` representing a schedule + as a mapping from statement instances (for one particular statement) + to lexicographic time. The statement represented will typically + be the dependee in a dependency relationship. + + :arg sched_after: An :class:`islpy.Map` representing a schedule + as a mapping from statement instances (for one particular statement) + to lexicographic time. The statement represented will typically + be the depender in a dependency relationship. + + :arg lex_map: An :class:`islpy.Map` representing a lexicographic + ordering as a mapping from each point in lexicographic time + to every point that occurs later in lexicographic time. E.g.:: + + {[i0', i1', i2', ...] -> [i0, i1, i2, ...] : + i0' < i0 or (i0' = i0 and i1' < i1) + or (i0' = i0 and i1' = i1 and i2' < i2) ...} + + :arg before_mark: A :class:`str` to be appended to the names of the + map dimensions representing the 'before' statement in the + 'happens before' relationship. + + :returns: An :class:`islpy.Map` representing the statement odering as + a mapping from each statement instance to all statement instances + occurring later. I.e., we compose relations B, L, and A as + B ∘ L ∘ A^-1, where B is `sched_before`, A is `sched_after`, + and L is `lex_map`. 
+ + """ + + # Perform the composition of relations + sio = sched_before.apply_range( + lex_map).apply_range(sched_after.reverse()) + + # Append mark to in_ dims + from loopy.schedule.checker.utils import ( + append_mark_to_isl_map_var_names, + ) + return append_mark_to_isl_map_var_names( + sio, isl.dim_type.in_, before_mark) + + +def _create_lex_order_set( + dim_names, + in_dim_mark, + var_name_to_pwaff=None, + ): + """Return an :class:`islpy.Set` representing a lexicographic ordering + over a space with the number of dimensions provided in `dim_names` + (the set itself will have twice this many dimensions in order to + represent the ordering as before-after pairs of points). + + :arg dim_names: A list of :class:`str` variable names to be used + to describe lexicographic space dimensions for a point in a lexicographic + ordering. (see example below) + + :arg in_dim_mark: A :class:`str` to be appended to dimension names to + distinguish corresponding dimensions in before-after pairs of points. + (see example below) + + :arg var_name_to_pwaff: A dictionary mapping variable names in `dim_names` to + :class:`islpy.PwAff` instances that represent each of the variables + (var_name_to_pwaff may be produced by `islpy.make_zero_and_vars`). + The key '0' is also included and represents a :class:`islpy.PwAff` zero + constant. This dictionary defines the space to be used for the set and + must also include versions of `dim_names` with the `in_dim_mark` + appended. If no value is passed, the dictionary will be made using + `dim_names` and `dim_names` with the `in_dim_mark` appended. + + :returns: An :class:`islpy.Set` representing a big-endian lexicographic + ordering with the number of dimensions provided in `dim_names`. The set + has two dimensions for each name in `dim_names`, one identified by the + given name and another identified by the same name with `in_dim_mark` + appended. 
The set contains all points which meet a 'happens before' + constraint defining the lexicographic ordering. E.g., if + `dim_names = [i0, i1, i2]` and `in_dim_mark="'"`, + return the set containing all points in a 3-dimensional, big-endian + lexicographic ordering such that point + `[i0', i1', i2']` happens before `[i0, i1, i2]`. I.e., return:: + + {[i0', i1', i2', i0, i1, i2] : + i0' < i0 or (i0' = i0 and i1' < i1) + or (i0' = i0 and i1' = i1 and i2' < i2)} + + """ + + from loopy.schedule.checker.utils import ( + append_mark_to_strings, + ) + + in_dim_names = append_mark_to_strings(dim_names, mark=in_dim_mark) + + # If no var_name_to_pwaff passed, make them using the names provided + # (make sure to pass var names in desired order of space dims) + if var_name_to_pwaff is None: + var_name_to_pwaff = isl.make_zero_and_vars( + in_dim_names+dim_names, + []) + + # Initialize set with constraint i0' < i0 + lex_order_set = var_name_to_pwaff[in_dim_names[0]].lt_set( + var_name_to_pwaff[dim_names[0]]) + + # For each dim d, starting with d=1, equality_conj_set will be constrained + # by d equalities, e.g., (i0' = i0 and i1' = i1 and ... i(d-1)' = i(d-1)). + equality_conj_set = var_name_to_pwaff[0].eq_set( + var_name_to_pwaff[0]) # initialize to 'true' + + for i in range(1, len(in_dim_names)): + + # Add the next equality constraint to equality_conj_set + equality_conj_set = equality_conj_set & \ + var_name_to_pwaff[in_dim_names[i-1]].eq_set( + var_name_to_pwaff[dim_names[i-1]]) + + # Create a set constrained by adding a less-than constraint for this dim, + # e.g., (i1' < i1), to the current equality conjunction set. + # For each dim d, starting with d=1, this full conjunction will have + # d equalities and one inequality, e.g., + # (i0' = i0 and i1' = i1 and ... 
i(d-1)' = i(d-1) and id' < id) + full_conj_set = var_name_to_pwaff[in_dim_names[i]].lt_set( + var_name_to_pwaff[dim_names[i]]) & equality_conj_set + + # Union this new constraint with the current lex_order_set + lex_order_set = lex_order_set | full_conj_set + + return lex_order_set + + +def create_lex_order_map( + dim_names, + in_dim_mark, + ): + """Return a map from each point in a lexicographic ordering to every + point that occurs later in the lexicographic ordering. + + :arg dim_names: A list of :class:`str` variable names for the + lexicographic space dimensions. + + :arg in_dim_mark: A :class:`str` to be appended to `dim_names` to create + the names for the input dimensions of the map, thereby distinguishing + them from the corresponding output dimensions in before-after pairs of + points. (see example below) + + :returns: An :class:`islpy.Map` representing a lexicographic + ordering as a mapping from each point in lexicographic time + to every point that occurs later in lexicographic time. + E.g., if `dim_names = [i0, i1, i2]` and `in_dim_mark = "'"`, + return the map:: + + {[i0', i1', i2'] -> [i0, i1, i2] : + i0' < i0 or (i0' = i0 and i1' < i1) + or (i0' = i0 and i1' = i1 and i2' < i2)} + + """ + + n_dims = len(dim_names) + dim_type = isl.dim_type + + # First, get a set representing the lexicographic ordering. + lex_order_set = _create_lex_order_set( + dim_names, + in_dim_mark=in_dim_mark, + ) + + # Now convert that set to a map. + lex_map = isl.Map.from_domain(lex_order_set) + return lex_map.move_dims( + dim_type.out, 0, dim_type.in_, + n_dims, n_dims) diff --git a/loopy/schedule/checker/schedule.py b/loopy/schedule/checker/schedule.py new file mode 100644 index 000000000..39b44c2ce --- /dev/null +++ b/loopy/schedule/checker/schedule.py @@ -0,0 +1,1464 @@ +""" +.. data:: LIN_CHECK_IDENTIFIER_PREFIX + + The :class:`str` prefix for identifiers involved in linearization + checking. + +.. 
data:: LEX_VAR_PREFIX + + The :class:`str` prefix for the variables representing the + dimensions in the lexicographic ordering used in a pairwise schedule. E.g., + a prefix of ``_lp_linchk_lex`` might yield lexicographic dimension + variables ``_lp_linchk_lex0``, ``_lp_linchk_lex1``, ``_lp_linchk_lex2``. + Cf. :ref:`reserved-identifiers`. + +.. data:: STATEMENT_VAR_NAME + + The :class:`str` name for the statement-identifying dimension of maps + representing schedules and statement instance orderings. + +.. data:: LTAG_VAR_NAMES + + An array of :class:`str` names for map dimensions carrying values for local + (intra work-group) thread identifiers in maps representing schedules and + statement instance orderings. + +.. data:: GTAG_VAR_NAMES + + An array of :class:`str` names for map dimensions carrying values for group + identifiers in maps representing schedules and statement instance orderings. + +.. data:: BEFORE_MARK + + The :class:`str` identifier to be appended to input dimension names in + maps representing schedules and statement instance orderings. + +.. autoclass:: SpecialLexPointWRTLoop +.. autoclass:: StatementOrdering +.. autofunction:: get_pairwise_statement_orderings_inner +""" + + +__copyright__ = "Copyright (C) 2019 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import islpy as isl +from dataclasses import dataclass +from loopy.schedule.checker.utils import ( + add_and_name_isl_dims, + add_eq_isl_constraint_from_names, + append_mark_to_isl_map_var_names, + move_dims_by_name, + remove_dims_by_name, +) +from loopy.schedule.checker.utils import ( # noqa + prettier_map_string, +) +from loopy.isl_helpers import ( + find_and_rename_dims, +) +dim_type = isl.dim_type + + +# {{{ Constants + +LIN_CHECK_IDENTIFIER_PREFIX = "_lp_linchk_" +LEX_VAR_PREFIX = "%slex" % (LIN_CHECK_IDENTIFIER_PREFIX) +STATEMENT_VAR_NAME = "%sstmt" % (LIN_CHECK_IDENTIFIER_PREFIX) +LTAG_VAR_NAMES = [] +GTAG_VAR_NAMES = [] +for par_level in [0, 1, 2]: + LTAG_VAR_NAMES.append("%slid%d" % (LIN_CHECK_IDENTIFIER_PREFIX, par_level)) + GTAG_VAR_NAMES.append("%sgid%d" % (LIN_CHECK_IDENTIFIER_PREFIX, par_level)) +BEFORE_MARK = "'" + +# }}} + + +# {{{ Helper Functions + +# {{{ _pad_tuple_with_zeros + +def _pad_tuple_with_zeros(tup, desired_length): + return tup[:] + tuple([0]*(desired_length-len(tup))) + +# }}} + + +# {{{ _simplify_lex_dims + +def _simplify_lex_dims(tup0, tup1): + """Simplify a pair of lex tuples in order to reduce the complexity of + resulting maps. Remove lex tuple dimensions with matching integer values + since these do not provide information on relative ordering. Once a + dimension is found where both tuples have non-matching integer values, + remove any faster-updating lex dimensions since they are not necessary + to specify a relative ordering. 
+ """ + + new_tup0 = [] + new_tup1 = [] + + # Loop over dims from slowest updating to fastest + for d0, d1 in zip(tup0, tup1): + if isinstance(d0, int) and isinstance(d1, int): + + # Both vals are ints for this dim + if d0 == d1: + # Do not keep this dim + continue + elif d0 > d1: + # These ints inform us about the relative ordering of + # two statements. While their values may be larger than 1 in + # the lexicographic ordering describing a larger set of + # statements, in a pairwise schedule, only ints 0 and 1 are + # necessary to specify relative order. To keep the pairwise + # schedules as simple and comprehensible as possible, use only + # integers 0 and 1 to specify this relative ordering. + # (doesn't take much extra time since we are already going + # through these to remove unnecessary lex tuple dims) + new_tup0.append(1) + new_tup1.append(0) + + # No further dims needed to fully specify ordering + break + else: # d1 > d0 + new_tup0.append(0) + new_tup1.append(1) + + # No further dims needed to fully specify ordering + break + else: + # Keep this dim without modifying + new_tup0.append(d0) + new_tup1.append(d1) + + if len(new_tup0) == 0: + # Statements map to the exact same point(s) in the lex ordering, + # which is okay, but to represent this, our lex tuple cannot be empty. + return (0, ), (0, ) + else: + return tuple(new_tup0), tuple(new_tup1) + +# }}} + +# }}} + + +# {{{ class SpecialLexPointWRTLoop + +class SpecialLexPointWRTLoop: + """Strings identifying a particular point or set of points in a + lexicographic ordering of statements, specified relative to a loop. + + .. attribute:: PRE + A :class:`str` indicating the last lexicographic point that + precedes the loop. + + .. attribute:: TOP + A :class:`str` indicating the first lexicographic point in + an arbitrary loop iteration. + + .. attribute:: BOTTOM + A :class:`str` indicating the last lexicographic point in + an arbitrary loop iteration. + + .. 
attribute:: POST + A :class:`str` indicating the first lexicographic point that + follows the loop. + """ + + PRE = "pre" + TOP = "top" + BOTTOM = "bottom" + POST = "post" + +# }}} + + +# {{{ class StatementOrdering + +@dataclass +class StatementOrdering: + r"""A container for the three statement instance orderings (described + below) used to formalize the ordering of statement instances for a pair of + statements. + + Also included (mostly for testing and debugging) are the + intra-thread pairwise schedule (`pwsched_intra_thread`), intra-group + pairwise schedule (`pwsched_intra_group`), and global pairwise schedule + (`pwsched_global`), each containing a pair of mappings from statement + instances to points in a lexicographic ordering, one for each statement. + Each SIO is created by composing the two mappings in the corresponding + pairwise schedule with an associated mapping defining the ordering of + points in the lexicographical space (not included). + """ + + sio_intra_thread: isl.Map + sio_intra_group: isl.Map + sio_global: isl.Map + pwsched_intra_thread: tuple + pwsched_intra_group: tuple + pwsched_global: tuple + +# }}} + + +# {{{ _gather_blex_ordering_info + +# {{{ Helper functions + +def _assert_exact_closure(mapping): + closure_test, closure_exact = mapping.transitive_closure() + assert closure_exact + assert closure_test == mapping + + +INAME_DIMS = slice(1, None, 2) # Odd indices of (unpadded) blex tuples track inames +CODE_SEC_DIMS = slice(0, None, 2) # Even indices track code sections + + +def _add_one_blex_tuple( + all_blex_points, blex_tuple, all_seq_blex_dim_names, + conc_inames, knl): + """Create the (bounded) set of blex points represented by *blex_tuple* and + add it to *all_blex_points*. + """ + + # blex_tuple: (int, iname, int, iname, int, ...) 
+ # - Contains 1 initial dim plus 2 dims for each sequential loop surrounding + # the *current* linearization item + # - Will need padding with zeros for any trailing blex dims + # - blex_tuple[INAME_DIMS] is a subset of all sequential inames + + # {{{ Get inames domain for current inames + + # (need to account for concurrent inames here rather than adding them on + # to blex map at the end because a sequential iname domain may depend on a + # concurrent iname domain) + + # Get set of inames nested outside (including this iname) + all_within_inames = set(blex_tuple[INAME_DIMS]) | conc_inames + + dom = knl.get_inames_domain( + all_within_inames).project_out_except( + all_within_inames, [dim_type.set]) + + # }}} + + # {{{ Prepare for union between dom and all_blex_points + + # Rename sequential iname dims in dom to corresponding blex dim names + dom = find_and_rename_dims( + dom, dim_type.set, + dict(zip(blex_tuple[INAME_DIMS], all_seq_blex_dim_names[INAME_DIMS]))) + + # Move concurrent inames in dom to params + dom = move_dims_by_name( + dom, dim_type.param, dom.n_param(), + dim_type.set, conc_inames) + + # Add any new params found in dom to all_blex_points prior to aligning dom + # with all_blex_points + missing_params = set( + dom.get_var_names(dim_type.param) # needed params + ) - set(all_blex_points.get_var_names(dim_type.param)) # current params + all_blex_points = add_and_name_isl_dims( + all_blex_points, dim_type.param, missing_params) + + # Add missing blex dims to dom and align it with all_blex_points + dom = isl.align_spaces(dom, all_blex_points) + + # Set values for non-iname (integer) blex dims in dom (excludes 0-padding at end) + for blex_dim_name, blex_val in zip( + all_seq_blex_dim_names[CODE_SEC_DIMS], blex_tuple[CODE_SEC_DIMS]): + dom = add_eq_isl_constraint_from_names(dom, blex_dim_name, blex_val) + # Set values for any unused (rightmost, fastest-updating) dom blex dims to zero + for blex_dim_name in all_seq_blex_dim_names[len(blex_tuple):]: + dom 
= add_eq_isl_constraint_from_names(dom, blex_dim_name, 0) + + # }}} + + # Add this blex set to full set of blex points + return all_blex_points | dom + +# }}} + + +def _gather_blex_ordering_info( + knl, + sync_kind, + lin_items, + seq_loops_with_barriers, + max_seq_loop_depth, + conc_inames, + loop_bounds, + all_stmt_ids, + all_conc_lex_dim_names, + gid_lex_dim_names, + conc_iname_constraint_dicts, + perform_closure_checks=False, + ): + r"""For the given sync_kind ("local" or "global"), create a mapping from + statement instances to blex space (dict), as well as a mapping + defining the blex ordering (isl map from blex space -> blex space) + + Note that, unlike in the intra-thread case, there will be a single + blex ordering map defining the blex ordering for all statement pairs, + rather than separate (smaller) lex ordering maps for each pair + + :arg knl: A preprocessed :class:`loopy.LoopKernel` containing the + linearization items that will be used to create the SIOs. This + kernel will be used to get the domains associated with the inames + used in the statements. + + :sync_kind: A :class:`str` indicating whether we are creating the + intra-group blex ordering ("local") or the global blex ordering + ("global"). + + :arg lin_items: A list of :class:`loopy.schedule.ScheduleItem` + (to be renamed to `loopy.schedule.LinearizationItem`) containing + all linearization items for which SIOs will be + created. To allow usage of this routine during linearization, a + truncated (i.e. partial) linearization may be passed through this + argument + + :arg seq_loops_with_barriers: A set of :class:`str` inames identifying the + non-concurrent loops that contain barriers whose scope affects this + blex ordering. I.e., global barriers affect the global blex ordering, + and both global *and* local barriers affect the intra-group blex + ordering. 
+ + :arg max_seq_loop_depth: A :class:`int` containing the maximum number of + nested non-concurrent loops among those found in + *seq_loops_with_barriers*. + + :arg conc_inames: The set of all :class:`str` inames tagged with a + :class:`loopy.kernel.data.ConcurrentTag`. + + :arg loop_bounds: A :class:`dict` mapping each non-concurrent iname to a + two-tuple containing two :class:`islpy.Set`\ s representing the lower + and upper bounds for the iname. + + :arg all_stmt_ids: A set of all statement identifiers to include in the + mapping from statements to blex time. + + :arg all_conc_lex_dim_names: A list containing the subset of the + :data:`LTAG_VAR_NAMES` and :data:`GTAG_VAR_NAMES` used in this kernel. + + :arg gid_lex_dim_names: A list containing the subset of the + :data:`GTAG_VAR_NAMES` used in this kernel. + + :arg conc_iname_constraint_dicts: A set of :class:`dict`\ s that will be + passed to :func:`islpy.Constratint.eq_from_names` to create constraints + that set each of the concurrent lex dimensions equal to its + corresponding iname. + + :arg perform_closure_checks: A :class:`bool` specifying whether to perform + checks ensuring that the blex map that results after we subtract some + pairs from the full blex map is transitively closed. 
+ + :returns: A :class:`dict` mapping each statement id in :attr:`all_stmt_ids` + to a tuple representing its instances in blex time, an + :class:`islpy.Map` imposing an ordering on the points in blex time, and + a list of the blex dimension names corresponding to sequential + execution (i.e., not the :data:`LTAG_VAR_NAMES` and :data:`GTAG_VAR_NAMES`) + + """ + from loopy.schedule import (EnterLoop, LeaveLoop, Barrier, RunInstruction) + from loopy.schedule.checker.lexicographic_order_map import ( + create_lex_order_map, + ) + from loopy.schedule.checker.utils import ( + add_and_name_isl_dims, + append_mark_to_strings, + add_eq_isl_constraint_from_names, + ) + slex = SpecialLexPointWRTLoop + + # {{{ First, create map from stmt instances to blex space. + + # At the same time, + # - Gather information necessary to create the blex ordering map, i.e., for + # each loop, gather the 6 lex order tuples defined above in + # SpecialLexPointWRTLoop that will be required to create sub-maps which + # will be *excluded* (subtracted) from a standard lexicographic ordering in + # order to create the blex ordering + # - Create all_blex_points, a set containing *all* blex points, which will + # be used later to impose bounds on the full blex map and any blex maps to + # be subtracted from it + + # {{{ Create the initial (pre-subtraction) blex order map, initially w/o bounds + + # Determine the number of blex dims we will need + n_seq_blex_dims = max_seq_loop_depth*2 + 1 + + # Create names for the blex dimensions for sequential loops + seq_blex_dim_names = [ + LEX_VAR_PREFIX+str(i) for i in range(n_seq_blex_dims)] + seq_blex_dim_names_prime = append_mark_to_strings( + seq_blex_dim_names, mark=BEFORE_MARK) + + # Begin with the blex order map created as a standard lexicographical order + # (bounds will be applied later by intersecting this with map containing + # all blex points) + blex_order_map = create_lex_order_map( + dim_names=seq_blex_dim_names, + in_dim_mark=BEFORE_MARK) + + # 
}}}
+
+    # {{{ Create a template set for the space of all blex points
+
+    # Create set of all blex points by starting with (0, 0, 0, ...)
+    # and then unioning this with each new set of blex points we find
+    all_blex_points = isl.align_spaces(
+        isl.Map("[ ] -> { [ ] -> [ ] }"), blex_order_map).range()
+    for var_name in seq_blex_dim_names:
+        all_blex_points = add_eq_isl_constraint_from_names(
+            all_blex_points, var_name, 0)
+    # Add concurrent inames as params
+    # (iname domains found in the pass below may depend on concurrent inames)
+    all_blex_points = add_and_name_isl_dims(
+        all_blex_points, dim_type.param, conc_inames)
+
+    # }}}
+
+    stmt_inst_to_blex = {}  # Map stmt instances to blex space
+    iname_to_blex_dim = {}  # Map from inames to corresponding blex space dim
+    blex_exclusion_info = {}  # Info for creating maps to exclude from blex order
+    next_blex_tuple = [0]  # Next tuple of points in blex order
+    sync_kinds_affecting_ordering = set([sync_kind])
+    # Global barriers also synchronize across threads within a group
+    if sync_kind == "local":
+        sync_kinds_affecting_ordering.add("global")
+
+    for lin_item in lin_items:
+        if isinstance(lin_item, EnterLoop):
+            enter_iname = lin_item.iname
+            if enter_iname in seq_loops_with_barriers:
+                # Save the blex point prior to this loop
+                pre_loop_blex_pt = next_blex_tuple[:]
+
+                # Increment next_blex_tuple[-1] for statements in the section
+                # of code between this EnterLoop and the matching LeaveLoop.
+ next_blex_tuple[-1] += 1 + + # Upon entering a loop, add one blex dimension for the loop + # iteration, add second blex dim to enumerate sections of + # code within new loop + next_blex_tuple.append(enter_iname) + next_blex_tuple.append(0) + + # Store 2 tuples that will later be used to create mappings + # between blex points that will be subtracted from the full + # blex order map + blex_exclusion_info[enter_iname] = { + slex.PRE: tuple(pre_loop_blex_pt), + slex.TOP: tuple(next_blex_tuple), + } + # (copy these blex points when creating dict because + # the lists will continue to be updated) + + # {{{ Create the blex set for this point, add it to all_blex_points + + all_blex_points = _add_one_blex_tuple( + all_blex_points, next_blex_tuple, + seq_blex_dim_names, conc_inames, knl) + + # }}} + + elif isinstance(lin_item, LeaveLoop): + leave_iname = lin_item.iname + if leave_iname in seq_loops_with_barriers: + + # Record the blex dim for this loop iname + iname_to_blex_dim[leave_iname] = len(next_blex_tuple) - 2 + + # Save the blex tuple prior to exiting loop + pre_end_loop_blex_pt = next_blex_tuple[:] + + # Upon leaving a loop: + # - Pop lex dim for enumerating code sections within this loop + # - Pop lex dim for the loop iteration + # - Increment lex dim val enumerating items in current section + next_blex_tuple.pop() + next_blex_tuple.pop() + next_blex_tuple[-1] += 1 + + # Store 2 tuples that will later be used to create mappings + # between blex points that will be subtracted from the full + # blex order map + blex_exclusion_info[leave_iname][slex.BOTTOM] = tuple( + pre_end_loop_blex_pt) + blex_exclusion_info[leave_iname][slex.POST] = tuple( + next_blex_tuple) + # (copy these blex points when creating dict because + # the lists will continue to be updated) + + # {{{ Create the blex set for this point, add it to all_blex_points + + all_blex_points = _add_one_blex_tuple( + all_blex_points, next_blex_tuple, + seq_blex_dim_names, conc_inames, knl) + + # }}} + + elif 
isinstance(lin_item, RunInstruction): + # Add stmt->blex pair to stmt_inst_to_blex + stmt_inst_to_blex[lin_item.insn_id] = tuple(next_blex_tuple) + + # (Don't increment blex dim val) + + elif isinstance(lin_item, Barrier): + # Increment blex dim val if the sync scope matches + if lin_item.synchronization_kind in sync_kinds_affecting_ordering: + next_blex_tuple[-1] += 1 + + # {{{ Create the blex set for this point, add it to all_blex_points + + all_blex_points = _add_one_blex_tuple( + all_blex_points, next_blex_tuple, + seq_blex_dim_names, conc_inames, knl) + + # }}} + + lp_stmt_id = lin_item.originating_insn_id + + if lp_stmt_id is None: + # Barriers without stmt ids were inserted as a result of a + # dependency. They don't themselves have dependencies. + # Don't map this barrier to a blex tuple. + continue + + # This barrier has a stmt id. + # If it was included in listed stmts, process it. + # Otherwise, there's nothing left to do (we've already + # incremented next_blex_tuple if necessary, and this barrier + # does not need to be assigned to a designated point in blex + # time) + if lp_stmt_id in all_stmt_ids: + + # Assign a blex point to this barrier just as we would for an + # assignment stmt + stmt_inst_to_blex[lp_stmt_id] = tuple(next_blex_tuple) + + # If sync scope matches, give this barrier its *own* point in + # lex time by updating blex tuple after barrier. 
+ if lin_item.synchronization_kind in sync_kinds_affecting_ordering: + next_blex_tuple[-1] += 1 + + # {{{ Create the blex set for this point, add it to + # all_blex_points + + all_blex_points = _add_one_blex_tuple( + all_blex_points, next_blex_tuple, + seq_blex_dim_names, conc_inames, knl) + + # }}} + else: + from loopy.schedule import (CallKernel, ReturnFromKernel) + # No action needed for these types of linearization item + assert isinstance( + lin_item, (CallKernel, ReturnFromKernel)) + pass + + # At this point, some blex tuples may have more dimensions than others; + # the missing dims are the fastest-updating dims, and their values should + # be zero. Add them. + for stmt, tup in stmt_inst_to_blex.items(): + stmt_inst_to_blex[stmt] = _pad_tuple_with_zeros(tup, n_seq_blex_dims) + + # }}} + + # {{{ Second, create the blex order map + + # {{{ Bound the full (pre-subtraction) blex order map + + conc_iname_to_iname_prime = { + conc_iname: conc_iname+BEFORE_MARK for conc_iname in conc_inames} + all_blex_points_prime = append_mark_to_isl_map_var_names( + all_blex_points, dim_type.set, BEFORE_MARK) + all_blex_points_prime = find_and_rename_dims( + all_blex_points_prime, dim_type.param, conc_iname_to_iname_prime, + ) + blex_order_map = blex_order_map.intersect_domain( + all_blex_points_prime).intersect_range(all_blex_points) + + # }}} + + # {{{ Subtract unwanted pairs from full blex order map + + # Create mapping (dict) from iname to corresponding blex dim name + seq_iname_to_blex_var = {} + for iname, dim in iname_to_blex_dim.items(): + seq_iname_to_blex_var[iname] = seq_blex_dim_names[dim] + seq_iname_to_blex_var[iname+BEFORE_MARK] = seq_blex_dim_names_prime[dim] + + # {{{ Get a template map matching blex_order_map.space that will serve as + # the starting point when creating the maps to subtract from blex_order_map + + # This template includes concurrent inames as params, both marked + # ('before') and unmarked ('after'). 
+ # Note that this template cannot be created until *after* the intersection + # of blex_order_map with all_blex_points above, otherwise the template will + # be missing necessary parameters. + blex_map_template = isl.align_spaces( + isl.Map("[ ] -> { [ ] -> [ ] }"), blex_order_map) + blex_set_template = blex_map_template.range() + + # }}} + + # {{{ _pad_tuples_and_assign_integer_vals_to_map_template() helper + + seq_blex_in_out_dim_names = seq_blex_dim_names_prime + seq_blex_dim_names + + def _pad_tuples_and_assign_integer_vals_to_map_template( + in_tuple, out_tuple): + # External variables read (not written): + # n_seq_blex_dims, seq_blex_in_out_dim_names, blex_map_template + + # Pad the tuples + in_tuple_padded = _pad_tuple_with_zeros(in_tuple, n_seq_blex_dims) + out_tuple_padded = _pad_tuple_with_zeros(out_tuple, n_seq_blex_dims) + + # Assign map values for ints only + map_with_int_vals_assigned = blex_map_template + for dim_name, val in zip( + seq_blex_in_out_dim_names, + in_tuple_padded+out_tuple_padded): + if isinstance(val, int): + map_with_int_vals_assigned = add_eq_isl_constraint_from_names( + map_with_int_vals_assigned, dim_name, val) + + return map_with_int_vals_assigned + + # }}} + + # {{{ Create blex map to subtract for each iname in blex_exclusion_info + + maps_to_subtract = [] + for iname, key_lex_tuples in blex_exclusion_info.items(): + + # {{{ Create blex map to subtract for one iname + + """Create the maps that must be subtracted from the + initial blex order map for this particular loop using the 6 blex + tuples in key_lex_tuples: + PRE->FIRST, BOTTOM(iname')->TOP(iname'+1), LAST->POST + + The PRE, TOP, BOTTOM, and POST blex points for a given loop are defined + above in doc for SpecialLexPointWRTLoop. + + FIRST indicates the first lexicographic point in the + first loop iteration (i.e., TOP, with the iname set to its min. val). 
+ + LAST indicates the last lexicographic point in the + last loop iteration (i.e., BOTTOM, with the iname set to its max val). + + """ + + # {{{ Create PRE->FIRST, BOTTOM(iname')->TOP(iname'+1), LAST->POST + # maps (initially without iname domain bounds) + + # We know which blex dims correspond to inames due to their + # position in blex tuples (int, iname, int, iname, int, ...), and their + # iname domain bounds will be set later by intersecting the subtraction + # map with the (bounded) full blex map. + + # Perform the following: + # - For map domains/ranges corresponding to the PRE, BOTTOM, TOP, and + # POST sets, leave the blex dims corresponding to inames unbounded and + # set the values for blex dims that will be ints, i.e., the + # even-indexed (intra-loop-section) blex dims and any trailing zeros. + # - For map domains/ranges corresponding to the FIRST and LAST sets, + # start with the TOP and BOTTOM sets and then set the map dimension + # corresponding to this iname to loop_bounds[iname][0] and + # loop_bounds[iname][1]. + # - For the BOTTOM->TOP map, add constraint iname = iname' + 1 + + # We will add a condition to fix iteration values for + # *surrounding* sequential loops (iname = iname') after combining the three + # maps (PRE-FIRST, BOTTOM->TOP, LAST->POST) below + + # BOTTOM/TOP tuples will be used multiple times, so grab them now + top_tuple = key_lex_tuples[slex.TOP] + bottom_tuple = key_lex_tuples[slex.BOTTOM] + + # {{{ Create PRE->FIRST map + + # PRE dim vals should all be inames (bounded later) or ints (assign now). + # FIRST dim values will be inames, ints, and the lexmin bound for this iname. + + # Create FIRST by starting with TOP blex tuple and then intersecting + # it with a set that imposes the lexmin bound for this loop. + + # Create initial PRE->FIRST map and assign int (non-iname) dim values. 
+ pre_to_first_map = _pad_tuples_and_assign_integer_vals_to_map_template( + key_lex_tuples[slex.PRE], top_tuple) + + # Get the set representing the value of the iname on the first + # iteration of the loop + loop_min_bound = loop_bounds[iname][0] + # (concurrent inames included in set params) + + # Prepare the loop_min_bound set for intersection with the range of + # pre_to_first_map by renaming iname dims to blex dims and aligning + # spaces + loop_min_bound = find_and_rename_dims( + loop_min_bound, dim_type.set, + {k: seq_iname_to_blex_var[k] for k in top_tuple[INAME_DIMS]}) + # Align with blex space (adds needed dims) + loop_first_set = isl.align_spaces(loop_min_bound, blex_set_template) + + # Finish making PRE->FIRST pair by intersecting this with the range of + # our pre_to_first_map + pre_to_first_map = pre_to_first_map.intersect_range(loop_first_set) + + # }}} + + # {{{ Create BOTTOM->TOP map + + # BOTTOM/TOP dim vals should all be inames (bounded later) or ints + # (assign now). + + # Create BOTTOM->TOP map and assign int (non-iname) dim values + bottom_to_top_map = _pad_tuples_and_assign_integer_vals_to_map_template( + bottom_tuple, top_tuple) + + # Add constraint iname = iname' + 1 + blex_var_for_iname = seq_iname_to_blex_var[iname] + bottom_to_top_map = bottom_to_top_map.add_constraint( + isl.Constraint.eq_from_names( + bottom_to_top_map.space, + {1: 1, blex_var_for_iname + BEFORE_MARK: 1, blex_var_for_iname: -1})) + + # }}} + + # {{{ LAST->POST + + # POST dim vals should all be inames (bounded later) or ints (assign now). + # LAST dim values will be inames, ints, and our lexmax bound for this iname. + + # Create LAST by starting with BOTTOM blex tuple and then intersecting + # it with a set that imposes the lexmax bound for this loop. + + # Create initial LAST->POST map and assign int (non-iname) dim values. 
+ last_to_post_map = _pad_tuples_and_assign_integer_vals_to_map_template( + bottom_tuple, key_lex_tuples[slex.POST]) + + # Get the set representing the value of the iname on the last + # iteration of the loop + loop_max_bound = loop_bounds[iname][1] + # (concurrent inames included in set params) + + # {{{ Prepare the loop_max_bound set for intersection with the domain of + # last_to_post_map by renaming iname dims to blex dims and aligning + # spaces + loop_max_bound = find_and_rename_dims( + loop_max_bound, dim_type.set, + {k: seq_iname_to_blex_var[k] for k in bottom_tuple[INAME_DIMS]}) + + # There may be concurrent inames in the dim_type.param dimensions of + # the loop_max_bound, and we need to append the BEFORE_MARK to those + # inames to ensure that they are distinguished from the corresponding + # non-marked 'after' (concurrent) inames. + # (While the other dims in loop_max_bound also correspond to 'before' + # dimensions of last_to_post_map, which carry the 'before' mark, we do + # not need to append the mark to them in loop_max_bound because calling + # last_to_post_map.intersect_domain(loop_last_set) below will match the + # space.in_ dims by position rather than name) + loop_max_bound = find_and_rename_dims( + loop_max_bound, dim_type.param, conc_iname_to_iname_prime) + + # Align with blex space (adds needed dims) + loop_last_set = isl.align_spaces(loop_max_bound, blex_set_template) + + # }}} + + # Make LAST->POST pair by intersecting this with the range of our map + # Finish making LAST->POST pair by intersecting this with the range of + # our last_to_post_map + last_to_post_map = last_to_post_map.intersect_domain(loop_last_set) + + # }}} + + # }}} + + map_to_subtract = pre_to_first_map | bottom_to_top_map | last_to_post_map + + # Add condition to fix iter value for *surrounding* sequential loops (j = j') + # (odd indices in key_lex_tuples[PRE] contain the sounding inames) + for seq_surrounding_iname in key_lex_tuples[slex.PRE][INAME_DIMS]: + s_blex_var 
= seq_iname_to_blex_var[seq_surrounding_iname] + map_to_subtract = add_eq_isl_constraint_from_names( + map_to_subtract, s_blex_var, s_blex_var+BEFORE_MARK) + + # Bound the blex dims by intersecting with the full blex map, which + # contains all the bound constraints + map_to_subtract &= blex_order_map + + # }}} + + maps_to_subtract.append(map_to_subtract) + + # }}} + + # {{{ Subtract transitive closure of union of blex maps to subtract + + if maps_to_subtract: + + # Get union of maps + map_to_subtract = maps_to_subtract[0] + for other_map in maps_to_subtract[1:]: + map_to_subtract |= other_map + + # Get transitive closure of maps + map_to_subtract_closure, closure_exact = map_to_subtract.transitive_closure() + + assert closure_exact # FIXME warn instead? + + # {{{ Check assumptions about map transitivity + + if perform_closure_checks: + + # Make sure map_to_subtract_closure is subset of blex_order_map + assert map_to_subtract <= blex_order_map + assert map_to_subtract_closure <= blex_order_map + + # Make sure blex_order_map and map_to_subtract are closures + _assert_exact_closure(blex_order_map) + _assert_exact_closure(map_to_subtract_closure) + + # }}} + + # Subtract closure from blex order map + blex_order_map -= map_to_subtract_closure + + # {{{ Check assumptions about map transitivity + + # Make sure blex_order_map is closure after subtraction + if perform_closure_checks: + _assert_exact_closure(blex_order_map) + + # }}} + + # }}} + + # Add LID/GID dims to blex order map: + + # At this point, all concurrent inames should be params in blex order map. + # Rename them to the corresponding concurrent lex dim name and move them to + # in_/out dims. + + # NOTE: + # Even though all parallel thread dims are active throughout the + # whole kernel, they may be assigned (tagged) to one iname for some + # subset of statements and another iname for a different subset of + # statements (e.g., tiled, parallel matmul). 
+ # There could, e.g., be *multiple* inames that correspond to LID0, and each + # of these inames could be involved in defining the domain set for other + # inames. We don't want to lose any of this information. For this reason, + # we first creat the LID/GID dims, then set each one equal to *all* + # corresponding concurrent inames (which are in param dims), and then + # remove the (param) iname dims. + + # Add conc lex dim names to both in_ and out dims + blex_order_map = add_and_name_isl_dims( + blex_order_map, dim_type.in_, + [v+BEFORE_MARK for v in all_conc_lex_dim_names]) + blex_order_map = add_and_name_isl_dims( + blex_order_map, dim_type.out, all_conc_lex_dim_names) + + # Set each of the new conc lex dims equal to *all* corresponding inames + # (here, conc_iname_constraint_dicts includes primed inames) + for constraint_dict in conc_iname_constraint_dicts: + blex_order_map = blex_order_map.add_constraint( + isl.Constraint.eq_from_names(blex_order_map.space, constraint_dict)) + + # Now remove conc inames from params + blex_order_map = remove_dims_by_name( + blex_order_map, dim_type.param, + conc_inames | set([v+BEFORE_MARK for v in conc_inames])) + + if sync_kind == "local": + # For intra-group case, constrain GID 'before' to equal GID 'after' + + # (in the current implementation, all gid_lex_dim_names should be + # present in blex_order_map) + for var_name in gid_lex_dim_names: + blex_order_map = add_eq_isl_constraint_from_names( + blex_order_map, var_name, var_name+BEFORE_MARK) + + # (if sync_kind == "global", don't need constraints on LID/GID vars) + + # }}} + + # }}} + + return ( + stmt_inst_to_blex, # map stmt instances to blex space + blex_order_map, + seq_blex_dim_names, + ) + +# }}} + + +# {{{ get_pairwise_statement_orderings_inner + +def get_pairwise_statement_orderings_inner( + knl, + lin_items, + stmt_id_pairs, + ilp_and_vec_inames=frozenset(), + perform_closure_checks=False, + ): + r"""For each statement pair in a subset of all statement pairs 
found in a + linearized kernel, determine the (relative) order in which the statement + instances are executed. For each pair, represent this relative ordering + using three ``statement instance orderings`` (SIOs): + + - The intra-thread SIO: A :class:`islpy.Map` from each instance of the + first statement to all instances of the second statement that occur + later, such that both statement instances in each before-after pair are + executed within the same work-item (thread). + + - The intra-group SIO: A :class:`islpy.Map` from each instance of the first + statement to all instances of the second statement that occur later, such + that both statement instances in each before-after pair are executed + within the same work-group (though potentially by different work-items). + + - The global SIO: A :class:`islpy.Map` from each instance of the first + statement to all instances of the second statement that occur later, even + if the two statement instances in a given before-after pair are executed + within different work-groups. + + :arg knl: A preprocessed :class:`loopy.LoopKernel` containing the + linearization items that will be used to create the SIOs. This + kernel will be used to get the domains associated with the inames + used in the statements, and to determine which inames have been + tagged with parallel tags. + + :arg lin_items: A list of :class:`loopy.schedule.ScheduleItem` + (to be renamed to `loopy.schedule.LinearizationItem`) containing + all linearization items for which SIOs will be + created. To allow usage of this routine during linearization, a + truncated (i.e. partial) linearization may be passed through this + argument + + :arg stmt_id_pairs: A list containing pairs of statement identifiers. + + :arg ilp_and_vec_inames: A set of inames that will be ignored when + determining the relative ordering of statements. This will typically + contain concurrent inames tagged with the ``vec`` or ``ilp`` array + access tags. 
+ + :returns: A dictionary mapping each two-tuple of statement identifiers + provided in `stmt_id_pairs` to a :class:`StatementOrdering`, which + contains the three SIOs described above. + """ + + from loopy.schedule import (EnterLoop, LeaveLoop, Barrier, RunInstruction) + from loopy.kernel.data import (LocalInameTag, GroupInameTag) + from loopy.schedule.checker.lexicographic_order_map import ( + create_lex_order_map, + get_statement_ordering_map, + ) + from loopy.schedule.checker.utils import ( + add_and_name_isl_dims, + append_mark_to_strings, + sorted_union_of_names_in_isl_sets, + create_symbolic_map_from_tuples, + insert_and_name_isl_dims, + partition_inames_by_concurrency, + ) + + all_stmt_ids = set().union(*stmt_id_pairs) + conc_inames = partition_inames_by_concurrency(knl)[0] + + # {{{ Intra-thread lex order creation + + # First, use one pass through lin_items to generate an *intra-thread* + # lexicographic ordering describing the relative order of all statements + # represented by all_stmt_ids + + # For each statement, map the stmt_id to a tuple representing points + # in the intra-thread lexicographic ordering containing items of :class:`int` or + # :class:`str` :mod:`loopy` inames + stmt_inst_to_lex_intra_thread = {} + + # Keep track of the next tuple of points in our lexicographic + # ordering, initially this as a 1-d point with value 0 + next_lex_tuple = [0] + + # While we're passing through, determine which loops contain barriers, + # this information will be used later when creating *intra-group* and + # *global* lexicographic orderings + seq_loops_with_barriers = {"local": set(), "global": set()} + max_depth_of_barrier_loop = {"local": 0, "global": 0} + current_seq_inames = [] + + # While we're passing through, also determine the values of the active + # inames on the first and last iteration of each loop that contains + # barriers (dom.lexmin/lexmax). 
+ # This information will be used later when creating *intra-group* and + # *global* lexicographic orderings + loop_bounds = {} + + for lin_item in lin_items: + if isinstance(lin_item, EnterLoop): + iname = lin_item.iname + + if iname in ilp_and_vec_inames: + continue + + current_seq_inames.append(iname) + + # Increment next_lex_tuple[-1] for statements in the section + # of code between this EnterLoop and the matching LeaveLoop. + # (not technically necessary if no statement was added in the + # previous section; gratuitous incrementing is counteracted + # in the simplification step below) + next_lex_tuple[-1] += 1 + + # Upon entering a loop, add one lex dimension for the loop iteration, + # add second lex dim to enumerate sections of code within new loop + next_lex_tuple.append(iname) + next_lex_tuple.append(0) + + elif isinstance(lin_item, LeaveLoop): + iname = lin_item.iname + + if iname in ilp_and_vec_inames: + continue + + current_seq_inames.pop() + + # Upon leaving a loop: + # - Pop lex dim for enumerating code sections within this loop + # - Pop lex dim for the loop iteration + # - Increment lex dim val enumerating items in current section of code + next_lex_tuple.pop() + next_lex_tuple.pop() + next_lex_tuple[-1] += 1 + + # (not technically necessary if no statement was added in the + # previous section; gratuitous incrementing is counteracted + # in the simplification step below) + + elif isinstance(lin_item, RunInstruction): + lp_stmt_id = lin_item.insn_id + + # Only process listed stmts, otherwise ignore + if lp_stmt_id in all_stmt_ids: + # Add item to stmt_inst_to_lex_intra_thread + stmt_inst_to_lex_intra_thread[lp_stmt_id] = tuple(next_lex_tuple) + + # Increment lex dim val enumerating items in current section of code + next_lex_tuple[-1] += 1 + + elif isinstance(lin_item, Barrier): + lp_stmt_id = lin_item.originating_insn_id + sync_kind = lin_item.synchronization_kind + seq_loops_with_barriers[sync_kind] |= set(current_seq_inames) + 
max_depth_of_barrier_loop[sync_kind] = max( + len(current_seq_inames), max_depth_of_barrier_loop[sync_kind]) + + # {{{ Store bounds for loops containing barriers + + # Only compute the bounds we haven't already stored; bounds finding + # will only happen once for each barrier-containing loop + for depth, iname in enumerate(current_seq_inames): + + # If we haven't already stored bounds for this iname, do so + if iname not in loop_bounds: + + # Get set of inames that might be involved in this bound + # (this iname plus any nested outside this iname, including + # concurrent inames) + seq_surrounding_inames = set(current_seq_inames[:depth]) + all_surrounding_inames = seq_surrounding_inames | conc_inames + + # Get inames domain + inames_involved_in_bound = all_surrounding_inames | {iname} + dom = knl.get_inames_domain( + inames_involved_in_bound).project_out_except( + inames_involved_in_bound, [dim_type.set]) + + # {{{ Move domain dims for surrounding inames to parameters + + dom = move_dims_by_name( + dom, dim_type.param, dom.n_param(), + dim_type.set, all_surrounding_inames) + + # }}} + + lmin = dom.lexmin() + lmax = dom.lexmax() + + # Now move non-concurrent param inames back to set dim + lmin = move_dims_by_name( + lmin, dim_type.set, 0, + dim_type.param, seq_surrounding_inames) + lmax = move_dims_by_name( + lmax, dim_type.set, 0, + dim_type.param, seq_surrounding_inames) + + loop_bounds[iname] = (lmin, lmax) + + # }}} + + if lp_stmt_id is None: + # Barriers without stmt ids were inserted as a result of a + # dependency. They don't themselves have dependencies. Ignore them. + + # FIXME: It's possible that we could record metadata about them + # (e.g. what dependency produced them) and verify that they're + # adequately protecting all statement instance pairs. 
+
+                continue
+
+            # If barrier was identified in listed stmts, process it
+            if lp_stmt_id in all_stmt_ids:
+                # Add item to stmt_inst_to_lex_intra_thread
+                stmt_inst_to_lex_intra_thread[lp_stmt_id] = tuple(next_lex_tuple)
+
+                # Increment lex dim val enumerating items in current section of code
+                next_lex_tuple[-1] += 1
+
+        else:
+            from loopy.schedule import (CallKernel, ReturnFromKernel)
+            # No action needed for these types of linearization item
+            assert isinstance(
+                lin_item, (CallKernel, ReturnFromKernel))
+            pass
+
+    # Since global barriers also synchronize threads *within* a work-group, our
+    # mechanisms that account for the effect of *local* barriers on execution
+    # order need to view *global* barriers as also having that effect.
+    # Include global barriers in seq_loops_with_barriers["local"] and
+    # max_depth_of_barrier_loop["local"].
+    seq_loops_with_barriers["local"] |= seq_loops_with_barriers["global"]
+    max_depth_of_barrier_loop["local"] = max(
+        max_depth_of_barrier_loop["local"], max_depth_of_barrier_loop["global"])
+
+    # }}}
+
+    # {{{ Create lex dim names representing parallel axes
+
+    # Create lex dim names representing lid/gid axes.
+    # At the same time, create the dicts that will be used later to create map
+    # constraints that match each parallel iname to the corresponding lex dim
+    # name in schedules, i.e., i = lid0, j = lid1, etc.
+    lid_lex_dim_names = set()
+    gid_lex_dim_names = set()
+
+    # Dicts that will be used to create constraints i = lid0, j = lid1, etc.
+    # (for efficiency, create these dicts one time per concurrent iname here,
+    # rather than recreating the dicts multiple times later)
+    conc_iname_constraint_dicts = {}
+    conc_iname_constraint_dicts_prime = {}
+
+    # NOTE: Even though all parallel thread dims are active throughout the
+    # whole kernel, they may be assigned (tagged) to one iname for some
+    # subset of statements and another iname for a different subset of
+    # statements (e.g., tiled, parallel matmul).
+ for iname in knl.all_inames(): + conc_tag = knl.iname_tags_of_type(iname, (LocalInameTag, GroupInameTag)) + if conc_tag: + assert len(conc_tag) == 1 # (should always be true) + conc_tag = conc_tag.pop() + if isinstance(conc_tag, LocalInameTag): + tag_var = LTAG_VAR_NAMES[conc_tag.axis] + lid_lex_dim_names.add(tag_var) + else: # Must be GroupInameTag + tag_var = GTAG_VAR_NAMES[conc_tag.axis] + gid_lex_dim_names.add(tag_var) + + tag_var_prime = tag_var+BEFORE_MARK + iname_prime = iname+BEFORE_MARK + conc_iname_constraint_dicts[iname] = {1: 0, iname: 1, tag_var: -1} + conc_iname_constraint_dicts_prime[iname_prime] = { + 1: 0, iname_prime: 1, tag_var_prime: -1} + + # Sort for consistent dimension ordering + lid_lex_dim_names = sorted(lid_lex_dim_names) + gid_lex_dim_names = sorted(gid_lex_dim_names) + + # }}} + + # {{{ Intra-group and global blex ("barrier-lex") order creation + + # (may be combined with pass above in future) + + """In blex space, we order barrier-delimited sections of code. + Each statement instance within a single barrier-delimited section will + map to the same blex point. The resulting statement instance ordering + (SIO) will map each statement to all statements that occur in a later + barrier-delimited section. + + To achieve this, we will first create a map from statement instances to + lexicographic space almost as we did above in the intra-thread case, + though we will not increment the fastest-updating lex dim with each + statement, and we will increment it with each barrier encountered. To + denote these differences, we refer to this space as 'blex' space. + + The resulting pairwise schedule, if composed with a map defining a + standard lexicographic ordering to create an SIO, would include a number + of unwanted 'before->after' pairs of statement instances, so before + creating the SIO, we will subtract unwanted pairs from a standard + lex order map, yielding the 'blex' order map. 
+ """ + + # {{{ Create blex order maps and blex tuples defining statement ordering (x2) + + all_conc_lex_dim_names = lid_lex_dim_names + gid_lex_dim_names + all_conc_iname_constraint_dicts = list( + conc_iname_constraint_dicts.values() + ) + list(conc_iname_constraint_dicts_prime.values()) + + # Get the blex schedule blueprint (dict will become a map below) and + # blex order map w.r.t. local and global barriers + (stmt_inst_to_lblex, + lblex_order_map, + seq_lblex_dim_names) = _gather_blex_ordering_info( + knl, + "local", + lin_items, + seq_loops_with_barriers["local"], + max_depth_of_barrier_loop["local"], + conc_inames, + loop_bounds, + all_stmt_ids, + all_conc_lex_dim_names, + gid_lex_dim_names, + all_conc_iname_constraint_dicts, + perform_closure_checks=perform_closure_checks, + ) + (stmt_inst_to_gblex, + gblex_order_map, + seq_gblex_dim_names) = _gather_blex_ordering_info( + knl, + "global", + lin_items, + seq_loops_with_barriers["global"], + max_depth_of_barrier_loop["global"], + conc_inames, + loop_bounds, + all_stmt_ids, + all_conc_lex_dim_names, + gid_lex_dim_names, + all_conc_iname_constraint_dicts, + perform_closure_checks=perform_closure_checks, + ) + + # }}} + + # }}} end intra-group and global blex order creation + + # {{{ Create pairwise schedules (ISL maps) for each stmt pair + + # {{{ _get_map_for_stmt() + + def _get_map_for_stmt( + stmt_id, lex_points, int_sid, lex_dim_names): + + # Get inames domain for statement instance (a BasicSet) + within_inames = knl.id_to_insn[stmt_id].within_inames + dom = knl.get_inames_domain( + within_inames).project_out_except(within_inames, [dim_type.set]) + + # Create map space (an isl space in current implementation) + # {('statement', ) -> + # (lexicographic ordering dims)} + dom_inames_ordered = sorted_union_of_names_in_isl_sets([dom]) + + in_names_sched = [STATEMENT_VAR_NAME] + dom_inames_ordered[:] + sched_space = isl.Space.create_from_names( + isl.DEFAULT_CONTEXT, + in_=in_names_sched, + out=lex_dim_names, + 
params=[],
+            )
+
+        # Insert 'statement' dim into domain so that its space allows
+        # for intersection with sched map later
+        dom_to_intersect = insert_and_name_isl_dims(
+            dom, dim_type.set, [STATEMENT_VAR_NAME], 0)
+
+        # Each map will map statement instances -> lex time.
+        # At this point, statement instance tuples consist of single int.
+        # Add all inames from domains to each map domain tuple.
+        tuple_pair = [(
+            (int_sid, ) + tuple(dom_inames_ordered),
+            lex_points
+            )]
+
+        # Note that lex_points may have fewer dims than the out-dim of sched_space
+        # if sched_space includes concurrent LID/GID dims. This is okay because
+        # the following symbolic map creation step, when assigning dim values,
+        # zips the space dims with the lex tuple, and any leftover LID/GID dims
+        # will not be assigned a value yet, which is what we want.
+
+        # Create map
+        sched_map = create_symbolic_map_from_tuples(
+            tuple_pairs_with_domains=zip(tuple_pair, [dom_to_intersect, ]),
+            space=sched_space,
+            )
+
+        # Set inames equal to relevant GID/LID var names
+        for iname, constraint_dict in conc_iname_constraint_dicts.items():
+            # Even though all parallel thread dims are active throughout the
+            # whole kernel, they may be assigned (tagged) to one iname for some
+            # subset of statements and another iname for a different subset of
+            # statements (e.g., tiled, parallel matmul).
+ # So before adding each parallel iname constraint, make sure the + # iname applies to this statement: + if iname in dom_inames_ordered: + sched_map = sched_map.add_constraint( + isl.Constraint.eq_from_names(sched_map.space, constraint_dict)) + + return sched_map + + # }}} + + pairwise_sios = {} + + for stmt_ids in stmt_id_pairs: + # Determine integer IDs that will represent each statement in mapping + # (dependency map creation assumes sid_before=0 and sid_after=1, unless + # before and after refer to same stmt, in which case + # sid_before=sid_after=0) + int_sids = [0, 0] if stmt_ids[0] == stmt_ids[1] else [0, 1] + + # {{{ Create SIO for intra-thread case (lid0' == lid0, gid0' == gid0, etc) + + # Simplify tuples to the extent possible ------------------------------------ + + lex_tuples = [stmt_inst_to_lex_intra_thread[stmt_id] for stmt_id in stmt_ids] + + # At this point, one of the lex tuples may have more dimensions than + # another; the missing dims are the fastest-updating dims, and their + # values should be zero. Add them. 
+ max_lex_dims = max([len(lex_tuple) for lex_tuple in lex_tuples]) + lex_tuples_padded = [ + _pad_tuple_with_zeros(lex_tuple, max_lex_dims) + for lex_tuple in lex_tuples] + + # Now generate maps from the blueprint -------------------------------------- + + lex_tuples_simplified = _simplify_lex_dims(*lex_tuples_padded) + + # Create names for the output dimensions for sequential loops + seq_lex_dim_names = [ + LEX_VAR_PREFIX+str(i) for i in range(len(lex_tuples_simplified[0]))] + + intra_thread_sched_maps = [ + _get_map_for_stmt( + stmt_id, lex_tuple, int_sid, + seq_lex_dim_names+all_conc_lex_dim_names) + for stmt_id, lex_tuple, int_sid + in zip(stmt_ids, lex_tuples_simplified, int_sids) + ] + + # Create pairwise lex order map (pairwise only in the intra-thread case) + lex_order_map = create_lex_order_map( + dim_names=seq_lex_dim_names, + in_dim_mark=BEFORE_MARK, + ) + + # Add lid/gid dims to lex order map + lex_order_map = add_and_name_isl_dims( + lex_order_map, dim_type.out, all_conc_lex_dim_names) + lex_order_map = add_and_name_isl_dims( + lex_order_map, dim_type.in_, + append_mark_to_strings(all_conc_lex_dim_names, mark=BEFORE_MARK)) + # Constrain lid/gid vars to be equal (this is the intra-thread case) + for var_name in all_conc_lex_dim_names: + lex_order_map = add_eq_isl_constraint_from_names( + lex_order_map, var_name, var_name+BEFORE_MARK) + + # Create statement instance ordering, + # maps each statement instance to all statement instances occurring later + sio_intra_thread = get_statement_ordering_map( + *intra_thread_sched_maps, # note, func accepts exactly two maps + lex_order_map, + before_mark=BEFORE_MARK, + ) + + # }}} + + # {{{ Create SIOs for intra-group case (gid0' == gid0, etc) and global case + + def _get_sched_maps_and_sio_for_conc_exec( + stmt_inst_to_blex, blex_order_map, seq_blex_dim_names): + # (Vars from outside func used here: + # stmt_ids, int_sids, all_conc_lex_dim_names) + + # Use *unsimplified* lex tuples w/ blex map, which are already 
padded + blex_tuples_padded = [stmt_inst_to_blex[stmt_id] for stmt_id in stmt_ids] + + sched_maps = [ + _get_map_for_stmt( + stmt_id, blex_tuple, int_sid, + seq_blex_dim_names+all_conc_lex_dim_names) # all par names + for stmt_id, blex_tuple, int_sid + in zip(stmt_ids, blex_tuples_padded, int_sids) + ] + + # Note that for the intra-group case, we already constrained GID + # 'before' to equal GID 'after' earlier in _gather_blex_ordering_info() + + # Create statement instance ordering + sio = get_statement_ordering_map( + *sched_maps, # note, func accepts exactly two maps + blex_order_map, + before_mark=BEFORE_MARK, + ) + + return sched_maps, sio + + pwsched_intra_group, sio_intra_group = _get_sched_maps_and_sio_for_conc_exec( + stmt_inst_to_lblex, lblex_order_map, seq_lblex_dim_names) + pwsched_global, sio_global = _get_sched_maps_and_sio_for_conc_exec( + stmt_inst_to_gblex, gblex_order_map, seq_gblex_dim_names) + + # }}} + + # Store sched maps along with SIOs + pairwise_sios[tuple(stmt_ids)] = StatementOrdering( + sio_intra_thread=sio_intra_thread, + sio_intra_group=sio_intra_group, + sio_global=sio_global, + pwsched_intra_thread=tuple(intra_thread_sched_maps), + pwsched_intra_group=tuple(pwsched_intra_group), + pwsched_global=tuple(pwsched_global), + ) + + # }}} + + return pairwise_sios + +# }}} + +# vim: foldmethod=marker diff --git a/loopy/schedule/checker/utils.py b/loopy/schedule/checker/utils.py new file mode 100644 index 000000000..0fc0971da --- /dev/null +++ b/loopy/schedule/checker/utils.py @@ -0,0 +1,357 @@ +__copyright__ = "Copyright (C) 2019 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is 
+furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import islpy as isl +dim_type = isl.dim_type + + +def prettier_map_string(map_obj): + return str( + map_obj + ).replace("{ ", "{\n").replace(" }", "\n}").replace("; ", ";\n").replace( + "(", "\n (") + + +def insert_and_name_isl_dims(isl_set, dt, names, new_idx_start): + new_set = isl_set.insert_dims(dt, new_idx_start, len(names)) + for i, name in enumerate(names): + new_set = new_set.set_dim_name(dt, new_idx_start+i, name) + return new_set + + +def add_and_name_isl_dims(isl_map, dt, names): + new_idx_start = isl_map.dim(dt) + new_map = isl_map.add_dims(dt, len(names)) + for i, name in enumerate(names): + new_map = new_map.set_dim_name(dt, new_idx_start+i, name) + return new_map + + +def reorder_dims_by_name( + isl_set, dt, desired_dims_ordered): + """Return an isl_set with the dimensions of the specified dim type + in the specified order. + + :arg isl_set: A :class:`islpy.Set` whose dimensions are + to be reordered. + + :arg dt: A :class:`islpy.dim_type`, i.e., an :class:`int`, + specifying the dimension to be reordered. + + :arg desired_dims_ordered: A :class:`list` of :class:`str` elements + representing the desired dimensions in order by dimension name. + + :returns: An :class:`islpy.Set` matching `isl_set` with the + dimension order matching `desired_dims_ordered`. 
+ + """ + + assert dt != dim_type.param + assert set(isl_set.get_var_names(dt)) == set(desired_dims_ordered) + + other_dt = dim_type.param + other_dim_len = len(isl_set.get_var_names(other_dt)) + + new_set = isl_set.copy() + for desired_idx, name in enumerate(desired_dims_ordered): + + current_idx = new_set.find_dim_by_name(dt, name) + if current_idx != desired_idx: + # First move to other dim because isl is stupid + new_set = new_set.move_dims( + other_dt, other_dim_len, dt, current_idx, 1) + # Now move it where we actually want it + new_set = new_set.move_dims( + dt, desired_idx, other_dt, other_dim_len, 1) + + return new_set + + +def move_dims_by_name( + isl_obj, dst_type, dst_pos_start, src_type, dim_names): + dst_pos = dst_pos_start + for dim_name in dim_names: + src_idx = isl_obj.find_dim_by_name(src_type, dim_name) + if src_idx == -1: + raise ValueError( + "move_dims_by_name did not find dimension %s" + % (dim_name)) + isl_obj = isl_obj.move_dims( + dst_type, dst_pos, src_type, src_idx, 1) + dst_pos += 1 + return isl_obj + + +def remove_dims_by_name(isl_obj, dt, dim_names): + for dim_name in dim_names: + idx = isl_obj.find_dim_by_name(dt, dim_name) + if idx == -1: + raise ValueError( + "remove_dims_by_name did not find dimension %s" + % (dim_name)) + isl_obj = isl_obj.remove_dims(dt, idx, 1) + return isl_obj + + +def rename_dims( + isl_set, rename_map, + dts=(dim_type.in_, dim_type.out, dim_type.param)): + new_isl_set = isl_set.copy() + for dt in dts: + for idx, old_name in enumerate(isl_set.get_var_names(dt)): + if old_name in rename_map: + new_isl_set = new_isl_set.set_dim_name( + dt, idx, rename_map[old_name]) + return new_isl_set + + +def ensure_dim_names_match_and_align(obj_map, tgt_map): + + # first make sure names match + if not all( + set(obj_map.get_var_names(dt)) == set(tgt_map.get_var_names(dt)) + for dt in + [dim_type.in_, dim_type.out, dim_type.param]): + raise ValueError( + "Cannot align spaces; names don't match:\n%s\n%s" + % 
(prettier_map_string(obj_map), prettier_map_string(tgt_map))
+            )
+
+    return isl.align_spaces(obj_map, tgt_map)
+
+
+def add_eq_isl_constraint_from_names(isl_map, var1, var2):
+    # add constraint var1 = var2
+    assert isinstance(var1, str)
+    # var2 may be an int or a string
+    if isinstance(var2, str):
+        return isl_map.add_constraint(
+            isl.Constraint.eq_from_names(
+                isl_map.space,
+                {1: 0, var1: 1, var2: -1}))
+    else:
+        assert isinstance(var2, int)
+        return isl_map.add_constraint(
+            isl.Constraint.eq_from_names(
+                isl_map.space,
+                {1: var2, var1: -1}))
+
+
+def add_int_bounds_to_isl_var(isl_map, var, lbound, ubound):
+    # NOTE: these are inclusive bounds
+    # add constraints lbound <= var <= ubound
+    return isl_map.add_constraint(
+        isl.Constraint.ineq_from_names(
+            isl_map.space, {1: -1*lbound, var: 1})
+        ).add_constraint(
+        isl.Constraint.ineq_from_names(
+            isl_map.space, {1: ubound, var: -1}))
+
+
+def append_mark_to_isl_map_var_names(old_isl_map, dt, mark):
+    """Return an :class:`islpy.Map` with a mark appended to the specified
+    dimension names.
+
+    :arg old_isl_map: An :class:`islpy.Map`.
+
+    :arg dt: An :class:`islpy.dim_type`, i.e., an :class:`int`,
+        specifying the dimension to be marked.
+
+    :arg mark: A :class:`str` to be appended to the specified dimension
+        names.
+
+    :returns: An :class:`islpy.Map` matching `old_isl_map` with
+        `mark` appended to the `dt` dimension names.
+
+    """
+
+    new_map = old_isl_map.copy()
+    for i in range(len(old_isl_map.get_var_names(dt))):
+        new_map = new_map.set_dim_name(dt, i, old_isl_map.get_dim_name(
+            dt, i)+mark)
+    return new_map
+
+
+def append_mark_to_strings(strings, mark):
+    assert isinstance(strings, list)
+    return [s+mark for s in strings]
+
+
+def sorted_union_of_names_in_isl_sets(
+        isl_sets,
+        set_dim=dim_type.set):
+    r"""Return a sorted list of the union of all variable names found in
+    the provided :class:`islpy.Set`\ s.
+ """ + + inames = set().union(*[isl_set.get_var_names(set_dim) for isl_set in isl_sets]) + + # Sorting is not necessary, but keeps results consistent between runs + return sorted(inames) + + +def create_symbolic_map_from_tuples( + tuple_pairs_with_domains, + space, + ): + """Return an :class:`islpy.Map` constructed using the provided space, + mapping input->output tuples provided in `tuple_pairs_with_domains`, + with each set of tuple variables constrained by the domains provided. + + :arg tuple_pairs_with_domains: A :class:`list` with each element being + a tuple of the form `((tup_in, tup_out), domain)`. + `tup_in` and `tup_out` are tuples containing elements of type + :class:`int` and :class:`str` representing values for the + input and output dimensions in `space`, and `domain` is a + :class:`islpy.Set` constraining variable bounds. + + :arg space: A :class:`islpy.Space` to be used to create the map. + + :returns: A :class:`islpy.Map` constructed using the provided space + as follows. For each `((tup_in, tup_out), domain)` in + `tuple_pairs_with_domains`, map + `(tup_in)->(tup_out) : domain`, where `tup_in` and `tup_out` are + numeric or symbolic values assigned to the input and output + dimension variables in `space`, and `domain` specifies conditions + on these values. 
+ + """ + # FIXME allow None for domains + + space_out_names = space.get_var_names(dim_type.out) + space_in_names = space.get_var_names(dim_type.in_) + + def _conjunction_of_dim_eq_conditions(dim_names, values, var_name_to_pwaff): + condition = var_name_to_pwaff[0].eq_set(var_name_to_pwaff[0]) + for dim_name, val in zip(dim_names, values): + if isinstance(val, int): + condition = condition \ + & var_name_to_pwaff[dim_name].eq_set(var_name_to_pwaff[0]+val) + else: + condition = condition \ + & var_name_to_pwaff[dim_name].eq_set(var_name_to_pwaff[val]) + return condition + + # Get islvars from space + var_name_to_pwaff = isl.affs_from_space( + space.move_dims( + dim_type.out, 0, + dim_type.in_, 0, + len(space_in_names), + ).range() + ) + + # Initialize union of maps to empty + union_of_maps = isl.Map.from_domain( + var_name_to_pwaff[0].eq_set(var_name_to_pwaff[0]+1) # 0 == 1 (false) + ).move_dims( + dim_type.out, 0, dim_type.in_, len(space_in_names), len(space_out_names)) + + # Loop through tuple pairs + for (tup_in, tup_out), dom in tuple_pairs_with_domains: + + # Set values for 'in' dimension using tuple vals + condition = _conjunction_of_dim_eq_conditions( + space_in_names, tup_in, var_name_to_pwaff) + + # Set values for 'out' dimension using tuple vals + condition = condition & _conjunction_of_dim_eq_conditions( + space_out_names, tup_out, var_name_to_pwaff) + + # Convert set to map by moving dimensions around + map_from_set = isl.Map.from_domain(condition) + map_from_set = map_from_set.move_dims( + dim_type.out, 0, dim_type.in_, + len(space_in_names), len(space_out_names)) + + # Align the *out* dims of dom with the space *in_* dims + # in preparation for intersection + dom_with_set_dim_aligned = reorder_dims_by_name( + dom, dim_type.set, + space_in_names, + ) + + # Intersect domain with this map + union_of_maps = union_of_maps.union( + map_from_set.intersect_domain(dom_with_set_dim_aligned)) + + return union_of_maps + + +def partition_inames_by_concurrency(knl): 
+    from loopy.kernel.data import ConcurrentTag
+    conc_inames = set()
+    non_conc_inames = set()
+
+    all_inames = knl.all_inames()
+    for iname in all_inames:
+        if knl.iname_tags_of_type(iname, ConcurrentTag):
+            conc_inames.add(iname)
+        else:
+            non_conc_inames.add(iname)
+
+    return conc_inames, all_inames-conc_inames
+
+
+def get_EnterLoop_inames(linearization_items):
+    from loopy.schedule import EnterLoop
+
+    # Note: each iname must live in len-1 list to avoid char separation
+    return set().union(*[
+        [item.iname, ] for item in linearization_items
+        if isinstance(item, EnterLoop)
+        ])
+
+
+def create_elementwise_comparison_conjunction_set(
+        names0, names1, var_name_to_pwaff, op="eq"):
+    """Create a set constrained by the conjunction of conditions comparing
+    `names0` to `names1`.
+
+    :arg names0: A list of :class:`str` representing variable names.
+
+    :arg names1: A list of :class:`str` representing variable names.
+
+    :arg var_name_to_pwaff: A dictionary from variable names to :class:`islpy.PwAff`
+        instances that represent each of the variables
+        (var_name_to_pwaff may be produced by `islpy.make_zero_and_vars`). The key
+        '0' is also included and represents a :class:`islpy.PwAff` zero constant.
+
+    :arg op: A :class:`str` describing the operator to use when creating
+        the set constraints. Options: `eq` for `=`, `ne` for `!=`, `lt` for `<`
+
+    :returns: A set involving `var_name_to_pwaff` constrained by the constraints
+        `{names0[0] <op> names1[0] and names0[1] <op> names1[1] and ...}`.
+ + """ + + # initialize set with constraint that is always true + conj_set = var_name_to_pwaff[0].eq_set(var_name_to_pwaff[0]) + for n0, n1 in zip(names0, names1): + if op == "eq": + conj_set = conj_set & var_name_to_pwaff[n0].eq_set(var_name_to_pwaff[n1]) + elif op == "ne": + conj_set = conj_set & var_name_to_pwaff[n0].ne_set(var_name_to_pwaff[n1]) + elif op == "lt": + conj_set = conj_set & var_name_to_pwaff[n0].lt_set(var_name_to_pwaff[n1]) + + return conj_set diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 548f9ec01..85a548896 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -2158,10 +2158,10 @@ def process_set(s): # Now rename any proxy dims back to their original names - from loopy.isl_helpers import find_and_rename_dim - for real_iname, proxy_iname in proxy_name_pairs: - new_s = find_and_rename_dim( - new_s, dim_type.set, proxy_iname, real_iname) + from loopy.isl_helpers import find_and_rename_dims + new_s = find_and_rename_dims( + new_s, dim_type.set, + dict([pair[::-1] for pair in proxy_name_pairs])) # (reverse pair order) return new_s diff --git a/setup.py b/setup.py index 2e907c1b9..701f796d5 100644 --- a/setup.py +++ b/setup.py @@ -90,6 +90,7 @@ def write_git_revision(package_name): # https://github.com/inducer/loopy/pull/419 "numpy>=1.19", + "dataclasses>=0.7;python_version<='3.6'", "cgen>=2016.1", "islpy>=2019.1", diff --git a/test/test_linearization_checker.py b/test/test_linearization_checker.py new file mode 100644 index 000000000..4901b1fd8 --- /dev/null +++ b/test/test_linearization_checker.py @@ -0,0 +1,1785 @@ +from __future__ import division, print_function + +__copyright__ = "Copyright (C) 2019 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, 
publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import six # noqa: F401 +import sys +import numpy as np +import loopy as lp +import islpy as isl +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl + as pytest_generate_tests) +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa +import logging +from loopy import ( + preprocess_kernel, + get_one_linearized_kernel, +) +from loopy.schedule.checker.schedule import ( + LEX_VAR_PREFIX, + STATEMENT_VAR_NAME, + LTAG_VAR_NAMES, + GTAG_VAR_NAMES, + BEFORE_MARK, +) +from loopy.schedule.checker.utils import ( + ensure_dim_names_match_and_align, + prettier_map_string, +) +from loopy.schedule.checker import ( + get_pairwise_statement_orderings, +) + +logger = logging.getLogger(__name__) + + +# {{{ Helper functions for map creation/handling + +def _align_and_compare_maps(maps): + + for map1, map2 in maps: + # Align maps and compare + map1_aligned = ensure_dim_names_match_and_align(map1, map2) + if map1_aligned != map2: + print("Maps not equal:") + print(prettier_map_string(map1_aligned)) + print(prettier_map_string(map2)) + assert map1_aligned == map2 + + +def _lex_point_string(dim_vals, lid_inames=(), gid_inames=()): + # Return a string describing a point in a lex space + 
# by assigning values to lex dimension variables + # (used to create maps below) + + return ", ".join( + ["%s%d=%s" % (LEX_VAR_PREFIX, idx, str(val)) + for idx, val in enumerate(dim_vals)] + + ["%s=%s" % (LTAG_VAR_NAMES[idx], iname) + for idx, iname in enumerate(lid_inames)] + + ["%s=%s" % (GTAG_VAR_NAMES[idx], iname) + for idx, iname in enumerate(gid_inames)] + ) + + +def _isl_map_with_marked_dims(s, placeholder_mark="'"): + # For creating legible tests, map strings may be created with a placeholder + # for the 'before' mark. Replace this placeholder with BEFORE_MARK before + # creating the map. + # ALSO, if BEFORE_MARK == "'", ISL will ignore this mark when creating + # variable names, so it must be added manually. + from loopy.schedule.checker.utils import ( + append_mark_to_isl_map_var_names, + ) + dt = isl.dim_type + if BEFORE_MARK == "'": + # ISL will ignore the apostrophe; manually name the in_ vars + return append_mark_to_isl_map_var_names( + isl.Map(s.replace(placeholder_mark, BEFORE_MARK)), + dt.in_, + BEFORE_MARK) + else: + return isl.Map(s.replace(placeholder_mark, BEFORE_MARK)) + + +def _check_orderings_for_stmt_pair( + stmt_id_before, + stmt_id_after, + all_sios, + sio_intra_thread_exp=None, + sched_before_intra_thread_exp=None, + sched_after_intra_thread_exp=None, + sio_intra_group_exp=None, + sched_before_intra_group_exp=None, + sched_after_intra_group_exp=None, + sio_global_exp=None, + sched_before_global_exp=None, + sched_after_global_exp=None, + ): + + order_info = all_sios[(stmt_id_before, stmt_id_after)] + + # Get pairs of maps to compare for equality + map_candidates = zip([ + sio_intra_thread_exp, + sched_before_intra_thread_exp, sched_after_intra_thread_exp, + sio_intra_group_exp, + sched_before_intra_group_exp, sched_after_intra_group_exp, + sio_global_exp, + sched_before_global_exp, sched_after_global_exp, + ], [ + order_info.sio_intra_thread, + order_info.pwsched_intra_thread[0], order_info.pwsched_intra_thread[1], + 
order_info.sio_intra_group, + order_info.pwsched_intra_group[0], order_info.pwsched_intra_group[1], + order_info.sio_global, + order_info.pwsched_global[0], order_info.pwsched_global[1], + ]) + + # Only compare to maps that were passed + maps_to_compare = [(m1, m2) for m1, m2 in map_candidates if m1 is not None] + _align_and_compare_maps(maps_to_compare) + + +def _process_and_linearize(knl, knl_name="loopy_kernel"): + # Return linearization items along with the preprocessed kernel and + # linearized kernel + proc_knl = preprocess_kernel(knl) + lin_knl = get_one_linearized_kernel( + proc_knl[knl_name], proc_knl.callables_table) + return lin_knl.linearization, proc_knl[knl_name], lin_knl + +# }}} + + +# {{{ test_intra_thread_pairwise_schedule_creation() + +def test_intra_thread_pairwise_schedule_creation(): + from loopy.schedule.checker import ( + get_pairwise_statement_orderings, + ) + + # Example kernel + # stmt_c depends on stmt_b only to create deterministic order + # stmt_d depends on stmt_c only to create deterministic order + knl = lp.make_kernel( + [ + "{[i]: 0<=itemp = b[i,k] {id=stmt_a} + end + for j + a[i,j] = temp + 1 {id=stmt_b,dep=stmt_a} + c[i,j] = d[i,j] {id=stmt_c,dep=stmt_b} + end + end + for t + e[t] = f[t] {id=stmt_d, dep=stmt_c} + end + """, + assumptions="pi,pj,pk,pt >= 1", + ) + knl = lp.add_and_infer_dtypes( + knl, + {"b": np.float32, "d": np.float32, "f": np.float32}) + knl = lp.prioritize_loops(knl, "i,k") + knl = lp.prioritize_loops(knl, "i,j") + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [ + ("stmt_a", "stmt_b"), + ("stmt_a", "stmt_c"), + ("stmt_a", "stmt_d"), + ("stmt_b", "stmt_c"), + ("stmt_b", "stmt_d"), + ("stmt_c", "stmt_d"), + ] + pworders = get_pairwise_statement_orderings( + lin_knl, + lin_items, + stmt_id_pairs, + perform_closure_checks=True, + ) + + # {{{ Relationship between stmt_a and stmt_b + + # Create expected maps and compare + + sched_stmt_a_intra_thread_exp = 
isl.Map( + "[pi, pk] -> { [%s=0, i, k] -> [%s] : 0 <= i < pi and 0 <= k < pk }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string(["i", "0"]), + ) + ) + + sched_stmt_b_intra_thread_exp = isl.Map( + "[pi, pj] -> { [%s=1, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string(["i", "1"]), + ) + ) + + _check_orderings_for_stmt_pair( + "stmt_a", "stmt_b", pworders, + sched_before_intra_thread_exp=sched_stmt_a_intra_thread_exp, + sched_after_intra_thread_exp=sched_stmt_b_intra_thread_exp, + ) + + # }}} + + # {{{ Relationship between stmt_a and stmt_c + + # Create expected maps and compare + + sched_stmt_a_intra_thread_exp = isl.Map( + "[pi, pk] -> { [%s=0, i, k] -> [%s] : 0 <= i < pi and 0 <= k < pk }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string(["i", "0"]), + ) + ) + + sched_stmt_c_intra_thread_exp = isl.Map( + "[pi, pj] -> { [%s=1, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string(["i", "1"]), + ) + ) + + _check_orderings_for_stmt_pair( + "stmt_a", "stmt_c", pworders, + sched_before_intra_thread_exp=sched_stmt_a_intra_thread_exp, + sched_after_intra_thread_exp=sched_stmt_c_intra_thread_exp, + ) + + # }}} + + # {{{ Relationship between stmt_a and stmt_d + + # Create expected maps and compare + + sched_stmt_a_intra_thread_exp = isl.Map( + "[pi, pk] -> { [%s=0, i, k] -> [%s] : 0 <= i < pi and 0 <= k < pk }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string([0, ]), + ) + ) + + sched_stmt_d_intra_thread_exp = isl.Map( + "[pt] -> { [%s=1, t] -> [%s] : 0 <= t < pt }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string([1, ]), + ) + ) + + _check_orderings_for_stmt_pair( + "stmt_a", "stmt_d", pworders, + sched_before_intra_thread_exp=sched_stmt_a_intra_thread_exp, + sched_after_intra_thread_exp=sched_stmt_d_intra_thread_exp, + ) + + # }}} + + # {{{ Relationship between stmt_b and stmt_c + + # Create expected maps and compare + + sched_stmt_b_intra_thread_exp = isl.Map( + "[pi, pj] -> { [%s=0, i, 
j] -> [%s] : 0 <= i < pi and 0 <= j < pj }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string(["i", "j", 0]), + ) + ) + + sched_stmt_c_intra_thread_exp = isl.Map( + "[pi, pj] -> { [%s=1, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string(["i", "j", 1]), + ) + ) + + _check_orderings_for_stmt_pair( + "stmt_b", "stmt_c", pworders, + sched_before_intra_thread_exp=sched_stmt_b_intra_thread_exp, + sched_after_intra_thread_exp=sched_stmt_c_intra_thread_exp, + ) + + # }}} + + # {{{ Relationship between stmt_b and stmt_d + + # Create expected maps and compare + + sched_stmt_b_intra_thread_exp = isl.Map( + "[pi, pj] -> { [%s=0, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string([0, ]), + ) + ) + + sched_stmt_d_intra_thread_exp = isl.Map( + "[pt] -> { [%s=1, t] -> [%s] : 0 <= t < pt }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string([1, ]), + ) + ) + + _check_orderings_for_stmt_pair( + "stmt_b", "stmt_d", pworders, + sched_before_intra_thread_exp=sched_stmt_b_intra_thread_exp, + sched_after_intra_thread_exp=sched_stmt_d_intra_thread_exp, + ) + + # }}} + + # {{{ Relationship between stmt_c and stmt_d + + # Create expected maps and compare + + sched_stmt_c_intra_thread_exp = isl.Map( + "[pi, pj] -> { [%s=0, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string([0, ]), + ) + ) + + sched_stmt_d_intra_thread_exp = isl.Map( + "[pt] -> { [%s=1, t] -> [%s] : 0 <= t < pt }" + % ( + STATEMENT_VAR_NAME, + _lex_point_string([1, ]), + ) + ) + + _check_orderings_for_stmt_pair( + "stmt_c", "stmt_d", pworders, + sched_before_intra_thread_exp=sched_stmt_c_intra_thread_exp, + sched_after_intra_thread_exp=sched_stmt_d_intra_thread_exp, + ) + + # }}} + +# }}} + + +# {{{ test_pairwise_schedule_creation_with_hw_par_tags() + +def test_pairwise_schedule_creation_with_hw_par_tags(): + # (further sched testing in SIO tests below) + + from loopy.schedule.checker import ( + 
get_pairwise_statement_orderings, + ) + + # Example kernel + knl = lp.make_kernel( + [ + "{[i,ii]: 0<=i,iitemp = b[i,ii,j,jj] {id=stmt_a} + a[i,ii,j,jj] = temp + 1 {id=stmt_b,dep=stmt_a} + end + end + end + end + """, + assumptions="pi,pj >= 1", + lang_version=(2018, 2) + ) + knl = lp.add_and_infer_dtypes(knl, {"a": np.float32, "b": np.float32}) + knl = lp.tag_inames(knl, {"j": "l.1", "jj": "l.0", "i": "g.0"}) + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [ + ("stmt_a", "stmt_b"), + ] + pworders = get_pairwise_statement_orderings( + lin_knl, + lin_items, + stmt_id_pairs, + perform_closure_checks=True, + ) + + # {{{ Relationship between stmt_a and stmt_b + + # Create expected maps and compare + + sched_stmt_a_intra_thread_exp = isl.Map( + "[pi,pj] -> {[%s=0,i,ii,j,jj] -> [%s] : 0 <= i,ii < pi and 0 <= j,jj < pj}" + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["ii", "0"], + lid_inames=["jj", "j"], gid_inames=["i"], + ), + ) + ) + + sched_stmt_b_intra_thread_exp = isl.Map( + "[pi,pj] -> {[%s=1,i,ii,j,jj] -> [%s] : 0 <= i,ii < pi and 0 <= j,jj < pj}" + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["ii", "1"], + lid_inames=["jj", "j"], gid_inames=["i"], + ), + ) + ) + + _check_orderings_for_stmt_pair( + "stmt_a", "stmt_b", pworders, + sched_before_intra_thread_exp=sched_stmt_a_intra_thread_exp, + sched_after_intra_thread_exp=sched_stmt_b_intra_thread_exp, + ) + + # }}} + +# }}} + + +# {{{ test_lex_order_map_creation() + +def test_lex_order_map_creation(): + from loopy.schedule.checker.lexicographic_order_map import ( + create_lex_order_map, + ) + + def _check_lex_map(exp_lex_order_map, n_dims): + + lex_order_map = create_lex_order_map( + dim_names=["%s%d" % (LEX_VAR_PREFIX, i) for i in range(n_dims)], + in_dim_mark=BEFORE_MARK, + ) + + assert lex_order_map == exp_lex_order_map + assert lex_order_map.get_var_dict() == exp_lex_order_map.get_var_dict() + + exp_lex_order_map = _isl_map_with_marked_dims( + 
"{{ " + "[{0}0', {0}1', {0}2', {0}3', {0}4'] -> [{0}0, {0}1, {0}2, {0}3, {0}4] :" + "(" + "{0}0' < {0}0 " + ") or (" + "{0}0'={0}0 and {0}1' < {0}1 " + ") or (" + "{0}0'={0}0 and {0}1'={0}1 and {0}2' < {0}2 " + ") or (" + "{0}0'={0}0 and {0}1'={0}1 and {0}2'={0}2 and {0}3' < {0}3 " + ") or (" + "{0}0'={0}0 and {0}1'={0}1 and {0}2'={0}2 and {0}3'={0}3 and {0}4' < {0}4" + ")" + "}}".format(LEX_VAR_PREFIX)) + + _check_lex_map(exp_lex_order_map, 5) + + exp_lex_order_map = _isl_map_with_marked_dims( + "{{ " + "[{0}0'] -> [{0}0] :" + "(" + "{0}0' < {0}0 " + ")" + "}}".format(LEX_VAR_PREFIX)) + + _check_lex_map(exp_lex_order_map, 1) + +# }}} + + +# {{{ test_intra_thread_statement_instance_ordering() + +def test_intra_thread_statement_instance_ordering(): + from loopy.schedule.checker import ( + get_pairwise_statement_orderings, + ) + + # Example kernel (add deps to fix loop order) + knl = lp.make_kernel( + [ + "{[i]: 0<=itemp = b[i,k] {id=stmt_a} + end + for j + a[i,j] = temp + 1 {id=stmt_b,dep=stmt_a} + c[i,j] = d[i,j] {id=stmt_c,dep=stmt_b} + end + end + for t + e[t] = f[t] {id=stmt_d, dep=stmt_c} + end + """, + assumptions="pi,pj,pk,pt >= 1", + lang_version=(2018, 2) + ) + knl = lp.add_and_infer_dtypes( + knl, + {"b": np.float32, "d": np.float32, "f": np.float32}) + knl = lp.prioritize_loops(knl, "i,k") + knl = lp.prioritize_loops(knl, "i,j") + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + # Get pairwise schedules + stmt_id_pairs = [ + ("stmt_a", "stmt_b"), + ("stmt_a", "stmt_c"), + ("stmt_a", "stmt_d"), + ("stmt_b", "stmt_c"), + ("stmt_b", "stmt_d"), + ("stmt_c", "stmt_d"), + ] + pworders = get_pairwise_statement_orderings( + proc_knl, + lin_items, + stmt_id_pairs, + perform_closure_checks=True, + ) + + # {{{ Relationship between stmt_a and stmt_b + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[pi, pj, pk] -> {{ " + "[{0}'=0, i', k'] -> [{0}=1, i, j] : " + "0 <= i,i' < pi and 0 <= k' < pk and 0 <= j < pj and i >= i' 
" + "}}".format(STATEMENT_VAR_NAME) + ) + + _check_orderings_for_stmt_pair( + "stmt_a", "stmt_b", pworders, sio_intra_thread_exp=sio_intra_thread_exp) + + # }}} + + # {{{ Relationship between stmt_a and stmt_c + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[pi, pj, pk] -> {{ " + "[{0}'=0, i', k'] -> [{0}=1, i, j] : " + "0 <= i,i' < pi and 0 <= k' < pk and 0 <= j < pj and i >= i' " + "}}".format(STATEMENT_VAR_NAME) + ) + + _check_orderings_for_stmt_pair( + "stmt_a", "stmt_c", pworders, sio_intra_thread_exp=sio_intra_thread_exp) + + # }}} + + # {{{ Relationship between stmt_a and stmt_d + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[pt, pi, pk] -> {{ " + "[{0}'=0, i', k'] -> [{0}=1, t] : " + "0 <= i' < pi and 0 <= k' < pk and 0 <= t < pt " + "}}".format(STATEMENT_VAR_NAME) + ) + + _check_orderings_for_stmt_pair( + "stmt_a", "stmt_d", pworders, sio_intra_thread_exp=sio_intra_thread_exp) + + # }}} + + # {{{ Relationship between stmt_b and stmt_c + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[pi, pj] -> {{ " + "[{0}'=0, i', j'] -> [{0}=1, i, j] : " + "0 <= i,i' < pi and 0 <= j,j' < pj and i > i'; " + "[{0}'=0, i', j'] -> [{0}=1, i=i', j] : " + "0 <= i' < pi and 0 <= j,j' < pj and j >= j'; " + "}}".format(STATEMENT_VAR_NAME) + ) + + _check_orderings_for_stmt_pair( + "stmt_b", "stmt_c", pworders, sio_intra_thread_exp=sio_intra_thread_exp) + + # }}} + + # {{{ Relationship between stmt_b and stmt_d + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[pt, pi, pj] -> {{ " + "[{0}'=0, i', j'] -> [{0}=1, t] : " + "0 <= i' < pi and 0 <= j' < pj and 0 <= t < pt " + "}}".format(STATEMENT_VAR_NAME) + ) + + _check_orderings_for_stmt_pair( + "stmt_b", "stmt_d", pworders, sio_intra_thread_exp=sio_intra_thread_exp) + + # }}} + + # {{{ Relationship between stmt_c and stmt_d + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[pt, pi, pj] -> {{ " + "[{0}'=0, i', j'] -> [{0}=1, t] : " + "0 <= i' < pi and 0 <= j' < pj and 0 <= t < pt " + 
"}}".format(STATEMENT_VAR_NAME) + ) + + _check_orderings_for_stmt_pair( + "stmt_c", "stmt_d", pworders, sio_intra_thread_exp=sio_intra_thread_exp) + + # }}} + +# }}} + + +# {{{ test_statement_instance_ordering_with_hw_par_tags() + +def test_statement_instance_ordering_with_hw_par_tags(): + from loopy.schedule.checker import ( + get_pairwise_statement_orderings, + ) + from loopy.schedule.checker.utils import ( + partition_inames_by_concurrency, + ) + + # Example kernel + knl = lp.make_kernel( + [ + "{[i,ii]: 0<=i,iitemp = b[i,ii,j,jj] {id=stmt_a} + a[i,ii,j,jj] = temp + 1 {id=stmt_b,dep=stmt_a} + end + end + end + end + """, + assumptions="pi,pj >= 1", + lang_version=(2018, 2) + ) + knl = lp.add_and_infer_dtypes(knl, {"a": np.float32, "b": np.float32}) + knl = lp.tag_inames(knl, {"j": "l.1", "jj": "l.0", "i": "g.0"}) + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + # Get pairwise schedules + stmt_id_pairs = [ + ("stmt_a", "stmt_b"), + ] + pworders = get_pairwise_statement_orderings( + lin_knl, + lin_items, + stmt_id_pairs, + perform_closure_checks=True, + ) + + # Create string for representing parallel iname condition in sio + conc_inames, _ = partition_inames_by_concurrency(knl["loopy_kernel"]) + par_iname_condition = " and ".join( + "{0} = {0}'".format(iname) for iname in conc_inames) + + # {{{ Relationship between stmt_a and stmt_b + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[pi, pj] -> {{ " + "[{0}'=0, i', ii', j', jj'] -> [{0}=1, i, ii, j, jj] : " + "0 <= i,ii,i',ii' < pi and 0 <= j,jj,j',jj' < pj and ii >= ii' " + "and {1} " + "}}".format( + STATEMENT_VAR_NAME, + par_iname_condition, + ) + ) + + _check_orderings_for_stmt_pair( + "stmt_a", "stmt_b", pworders, sio_intra_thread_exp=sio_intra_thread_exp) + + # }}} + +# }}} + + +# {{{ test_statement_instance_ordering_of_barriers() + +def test_statement_instance_ordering_of_barriers(): + from loopy.schedule.checker import ( + get_pairwise_statement_orderings, + 
) + from loopy.schedule.checker.utils import ( + partition_inames_by_concurrency, + ) + + # Example kernel + knl = lp.make_kernel( + [ + "{[i,ii]: 0<=i,iitemp = b[i,ii,j,jj] {id=stmt_a,dep=gbar} + ... lbarrier {id=lbar0,dep=stmt_a} + a[i,ii,j,jj] = temp + 1 {id=stmt_b,dep=lbar0} + ... lbarrier {id=lbar1,dep=stmt_b} + end + end + end + end + <>temp2 = 0.5 {id=stmt_c,dep=lbar1} + """, + assumptions="pi,pj >= 1", + lang_version=(2018, 2) + ) + knl = lp.add_and_infer_dtypes(knl, {"a,b": np.float32}) + knl = lp.tag_inames(knl, {"j": "l.0", "i": "g.0"}) + knl = lp.prioritize_loops(knl, "ii,jj") + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + # Get pairwise schedules + stmt_id_pairs = [ + ("stmt_a", "stmt_b"), + ("gbar", "stmt_a"), + ("stmt_b", "lbar1"), + ("lbar1", "stmt_c"), + ] + pworders = get_pairwise_statement_orderings( + lin_knl, + lin_items, + stmt_id_pairs, + perform_closure_checks=True, + ) + + # Create string for representing parallel iname SAME condition in sio + conc_inames, _ = partition_inames_by_concurrency(knl["loopy_kernel"]) + par_iname_condition = " and ".join( + "{0} = {0}'".format(iname) for iname in conc_inames) + + # {{{ Intra-thread relationship between stmt_a and stmt_b + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[pi, pj] -> {{ " + "[{0}'=0, i', ii', j', jj'] -> [{0}=1, i, ii, j, jj] : " + "0 <= i,ii,i',ii' < pi and 0 <= j,jj,j',jj' < pj " + "and (ii > ii' or (ii = ii' and jj >= jj')) " + "and {1} " + "}}".format( + STATEMENT_VAR_NAME, + par_iname_condition, + ) + ) + + _check_orderings_for_stmt_pair( + "stmt_a", "stmt_b", pworders, + sio_intra_thread_exp=sio_intra_thread_exp) + + # }}} + + # {{{ Relationship between gbar and stmt_a + + # intra-thread case + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[pi, pj] -> {{ " + "[{0}'=0, i', ii'] -> [{0}=1, i, ii, j, jj] : " + "0 <= i,ii,i',ii' < pi and 0 <= j,jj < pj " # domains + "and i = i' " # parallel inames must be same + "and ii >= 
# that global barriers *also* synchronize across threads *within* a group,
"and ii > ii'" # before->after condition
# (only happens before if this is not the last iteration of ii)
# fact that global barriers *also* synchronize across threads *within* a
+    # group, which is why dim 2 below is assigned the value 3 instead of 2)
"[ij_start, ij_end, lg_end] -> {{ " + "[{0}'=0, i', j', l0', l1', g0'] -> [{0}=1, l0, l1, g0] : " + "(ij_start <= j' < ij_end-1 or " # not last iteration of j + " ij_start <= i' < ij_end-1) " # not last iteration of i + "and g0 = g0' " # within a single group + "and {1} and {2} and {3} " # iname bounds + "and {4}" # param assumptions + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + assumptions, + ) + ) + + # }}} + + # {{{ Global + + sched_stmt_j1_global_exp = isl.Map( + "[ij_start, ij_end, lg_end] -> {" + "[%s=0, i, j, l0, l1, g0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["1", "i", "1"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sched_stmt_2_global_exp = isl.Map( + "[lg_end] -> {[%s=1, l0, l1, g0] -> [%s] : " + "%s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["2", "0", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + conc_iname_bound_str, + ) + ) + + sio_global_exp = _isl_map_with_marked_dims( + "[ij_start,ij_end,lg_end] -> {{ " + "[{0}'=0, i', j', l0', l1', g0'] -> [{0}=1, l0, l1, g0] : " + "ij_start <= i' < ij_end-1 " # not last iteration of i + "and {1} and {2} and {3} " # iname bounds + "and {4}" # param assumptions + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + assumptions, + ) + ) + + # }}} + + _check_orderings_for_stmt_pair( + "stmt_j1", "stmt_2", pworders, + sio_intra_group_exp=sio_intra_group_exp, + sched_before_intra_group_exp=sched_stmt_j1_intra_group_exp, + sched_after_intra_group_exp=sched_stmt_2_intra_group_exp, + sio_global_exp=sio_global_exp, + sched_before_global_exp=sched_stmt_j1_global_exp, + sched_after_global_exp=sched_stmt_2_global_exp, + ) + + # {{{ Check for some key example pairs in the sio_intra_group map + + # Get maps + order_info = 
"and ij_end >= {2} " # constrain ij_end
_lex_point_string( + ["1", "0", "0", "0", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + conc_iname_bound_str, + ) + ) + + sched_stmt_i0_intra_group_exp = isl.Map( + "[ij_start, ij_end, lg_end] -> {" + "[%s=1, i, l0, l1, g0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["2", "i", "0", "0", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + i_bound_str, + conc_iname_bound_str, + ) + ) + + sio_intra_group_exp = _isl_map_with_marked_dims( + "[ij_start, ij_end, lg_end] -> {{ " + "[{0}'=0, l0', l1', g0'] -> [{0}=1, i, l0, l1, g0] : " + "ij_start + 1 <= i < ij_end " # not first iteration of i + "and g0 = g0' " # within a single group + "and {1} and {2} and {3} " # iname bounds + "and {4}" # param assumptions + "}}".format( + STATEMENT_VAR_NAME, + i_bound_str, + conc_iname_bound_str, + conc_iname_bound_str_p, + assumptions, + ) + ) + + # }}} + + # {{{ Global + + sched_stmt_1_global_exp = isl.Map( + "[lg_end] -> {[%s=0, l0, l1, g0] -> [%s] : " + "%s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["0", "0", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + conc_iname_bound_str, + ) + ) + + sched_stmt_i0_global_exp = isl.Map( + "[ij_start, ij_end, lg_end] -> {" + "[%s=1, i, l0, l1, g0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["1", "i", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + i_bound_str, + conc_iname_bound_str, + ) + ) + + sio_global_exp = _isl_map_with_marked_dims( + "[ij_start, ij_end, lg_end] -> {{ " + "[{0}'=0, l0', l1', g0'] -> [{0}=1, i, l0, l1, g0] : " + "ij_start + 1 <= i < ij_end " # not first iteration of i + "and {1} and {2} and {3} " # iname bounds + "and {4}" # param assumptions + "}}".format( + STATEMENT_VAR_NAME, + i_bound_str, + conc_iname_bound_str, + conc_iname_bound_str_p, + assumptions, + ) + ) + + # }}} + + 
_check_orderings_for_stmt_pair( + "stmt_1", "stmt_i0", pworders, + sio_intra_group_exp=sio_intra_group_exp, + sched_before_intra_group_exp=sched_stmt_1_intra_group_exp, + sched_after_intra_group_exp=sched_stmt_i0_intra_group_exp, + sio_global_exp=sio_global_exp, + sched_before_global_exp=sched_stmt_1_global_exp, + sched_after_global_exp=sched_stmt_i0_global_exp, + ) + + # }}} + +# }}} + + +# {{{ test_sios_and_schedules_with_vec_and_barriers() + +def test_sios_and_schedules_with_vec_and_barriers(): + from loopy.schedule.checker import ( + get_pairwise_statement_orderings, + ) + + knl = lp.make_kernel( + "{[i, j, l0] : 0 <= i < 4 and 0 <= j < n and 0 <= l0 < 32}", + """ + for l0 + for i + for j + b[i,j,l0] = 1 {id=stmt_1} + ... lbarrier {id=b,dep=stmt_1} + c[i,j,l0] = 2 {id=stmt_2, dep=b} + end + end + end + """) + knl = lp.add_and_infer_dtypes(knl, {"b": "float32", "c": "float32"}) + + knl = lp.tag_inames(knl, {"i": "vec", "l0": "l.0"}) + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [("stmt_1", "stmt_2")] + pworders = get_pairwise_statement_orderings( + lin_knl, lin_items, stmt_id_pairs, perform_closure_checks=True) + + # {{{ Relationship between stmt_1 and stmt_2 + + # Create expected maps and compare + + # Iname bound strings to facilitate creation of expected maps + ij_bound_str = "0 <= i < 4 and 0 <= j < n" + ij_bound_str_p = "0 <= i' < 4 and 0 <= j' < n" + conc_iname_bound_str = "0 <= l0 < 32" + conc_iname_bound_str_p = "0 <= l0' < 32" + + # {{{ Intra-thread + + sched_stmt_1_intra_thread_exp = isl.Map( + "[n] -> {" + "[%s=0, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["j", "0"], # lex points (initial matching dim gets removed) + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sched_stmt_2_intra_thread_exp = isl.Map( + "[n] -> {" + "[%s=1, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + 
STATEMENT_VAR_NAME, + _lex_point_string( + ["j", "1"], # lex points (initial matching dim gets removed) + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[n] -> {{ " + "[{0}'=0, i', j', l0'] -> [{0}=1, i, j, l0] : " + "j' <= j " + "and l0 = l0' " # within a single thread + "and {1} and {2} and {3} and {4}" # iname bounds + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + ) + ) + + # }}} + + # {{{ Intra-group + + # Intra-group scheds would be same due to lbarrier, + # but since lex tuples are not simplified in intra-group/global + # cases, there's an extra lex dim: + + sched_stmt_1_intra_group_exp = isl.Map( + "[n] -> {" + "[%s=0, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["1", "j", "0"], # lex points + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sched_stmt_2_intra_group_exp = isl.Map( + "[n] -> {" + "[%s=1, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["1", "j", "1"], # lex points + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sio_intra_group_exp = _isl_map_with_marked_dims( + "[n] -> {{ " + "[{0}'=0, i', j', l0'] -> [{0}=1, i, j, l0] : " + "j' <= j " + "and {1} and {2} and {3} and {4}" # iname bounds + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + ) + ) + + # }}} + + # {{{ Global + + sched_stmt_1_global_exp = isl.Map( + "[n] -> {" + "[%s=0, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["0"], # lex points + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + # (same as stmt_1 except for statement id because no global barriers) + sched_stmt_2_global_exp = isl.Map( + "[n] -> {" + 
"[%s=1, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["0"], # lex points + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sio_global_exp = _isl_map_with_marked_dims( + "[n] -> {{ " + "[{0}'=0, i', j', l0'] -> [{0}=1, i, j, l0] : " + "False " + "and {1} and {2} and {3} and {4}" # iname bounds + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + ) + ) + + # }}} + + _check_orderings_for_stmt_pair( + "stmt_1", "stmt_2", pworders, + sio_intra_thread_exp=sio_intra_thread_exp, + sched_before_intra_thread_exp=sched_stmt_1_intra_thread_exp, + sched_after_intra_thread_exp=sched_stmt_2_intra_thread_exp, + sio_intra_group_exp=sio_intra_group_exp, + sched_before_intra_group_exp=sched_stmt_1_intra_group_exp, + sched_after_intra_group_exp=sched_stmt_2_intra_group_exp, + sio_global_exp=sio_global_exp, + sched_before_global_exp=sched_stmt_1_global_exp, + sched_after_global_exp=sched_stmt_2_global_exp, + ) + + # }}} + +# }}} + + +# {{{ test_sios_with_matmul + +def test_sios_with_matmul(): + from loopy.schedule.checker import ( + get_pairwise_statement_orderings, + ) + # For now, this test just ensures all pairwise SIOs can be created + # for a complex parallel kernel without any errors/exceptions. Later PRs + # will examine this kernel's SIOs and related dependencies for accuracy. + + bsize = 16 + knl = lp.make_kernel( + "{[i,k,j]: 0<=itemp0 = 0 {id=stmt_i0} + ... lbarrier {id=stmt_b0,dep=stmt_i0} + <>temp1 = 1 {id=stmt_i1,dep=stmt_b0} + for j + <>tempj0 = 0 {id=stmt_j0,dep=stmt_i1} + ... lbarrier {id=stmt_jb0,dep=stmt_j0} + ... gbarrier {id=stmt_jbb0,dep=stmt_j0} + <>tempj1 = 0 {id=stmt_j1,dep=stmt_jb0} + <>tempj2 = 0 {id=stmt_j2,dep=stmt_j1} + for k + <>tempk0 = 0 {id=stmt_k0,dep=stmt_j2} + ... 
lbarrier {id=stmt_kb0,dep=stmt_k0} + <>tempk1 = 0 {id=stmt_k1,dep=stmt_kb0} + end + end + <>temp2 = 0 {id=stmt_i2,dep=stmt_j0} + end + """, + assumptions=assumptions, + lang_version=(2018, 2) + ) + + ref_knl = knl + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [ + ("stmt_i0", "stmt_i1"), + ("stmt_i1", "stmt_j0"), + ("stmt_j0", "stmt_j1"), + ("stmt_j1", "stmt_j2"), + ("stmt_j2", "stmt_k0"), + ("stmt_k0", "stmt_k1"), + ("stmt_k1", "stmt_i2"), + ] + # Set perform_closure_checks=True and get the orderings + get_pairwise_statement_orderings( + lin_knl, lin_items, stmt_id_pairs, perform_closure_checks=True) + + # Now try it with concurrent i loop + knl = lp.tag_inames(knl, "i:g.0") + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [ + ("stmt_i0", "stmt_i1"), + ("stmt_i1", "stmt_j0"), + ("stmt_j0", "stmt_j1"), + ("stmt_j1", "stmt_j2"), + ("stmt_j2", "stmt_k0"), + ("stmt_k0", "stmt_k1"), + ("stmt_k1", "stmt_i2"), + ] + # Set perform_closure_checks=True and get the orderings + get_pairwise_statement_orderings( + lin_knl, lin_items, stmt_id_pairs, perform_closure_checks=True) + + # Now try it with concurrent i and j loops + knl = lp.tag_inames(knl, "j:g.1") + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [ + ("stmt_i0", "stmt_i1"), + ("stmt_i1", "stmt_j0"), + ("stmt_j0", "stmt_j1"), + ("stmt_j1", "stmt_j2"), + ("stmt_j2", "stmt_k0"), + ("stmt_k0", "stmt_k1"), + ("stmt_k1", "stmt_i2"), + ] + # Set perform_closure_checks=True and get the orderings + get_pairwise_statement_orderings( + lin_knl, lin_items, stmt_id_pairs, perform_closure_checks=True) + + # Now try it with concurrent i and k loops + knl = ref_knl + knl = lp.tag_inames(knl, {"i": "g.0", "k": "g.1"}) + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [ + ("stmt_i0", "stmt_i1"), + 
("stmt_i1", "stmt_j0"), + ("stmt_j0", "stmt_j1"), + ("stmt_j1", "stmt_j2"), + ("stmt_j2", "stmt_k0"), + ("stmt_k0", "stmt_k1"), + ("stmt_k1", "stmt_i2"), + ] + # Set perform_closure_checks=True and get the orderings + get_pairwise_statement_orderings( + lin_knl, lin_items, stmt_id_pairs, perform_closure_checks=True) + + # FIXME create some expected sios and compare + +# }}} + + +# {{{ test_blex_map_transitivity_with_duplicate_conc_inames + +def test_blex_map_transitivity_with_duplicate_conc_inames(): + + knl = lp.make_kernel( + [ + "{[i,j,ii,jj]: 0 <= i,j,jj < n and i <= ii < n}", + "{[k, kk]: 0 <= k,kk < n}", + ], + """ + for i + for ii + <> si = 0 {id=si} + ... lbarrier {id=bari, dep=si} + end + end + for j + for jj + <> sj = 0 {id=sj, dep=si} + ... lbarrier {id=barj, dep=sj} + end + end + for k + for kk + <> sk = 0 {id=sk, dep=sj} + ... lbarrier {id=bark, dep=sk} + end + end + """, + assumptions="0 < n", + lang_version=(2018, 2) + ) + + knl = lp.tag_inames(knl, {"i": "l.0", "j": "l.0", "k": "l.0"}) + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [ + ("si", "si"), + ("si", "sj"), + ("si", "sk"), + ("sj", "sj"), + ("sj", "sk"), + ("sk", "sk"), + ] + + # Set perform_closure_checks=True and get the orderings + get_pairwise_statement_orderings( + lin_knl, lin_items, stmt_id_pairs, perform_closure_checks=True) + + # print(prettier_map_string(pw_sios[("si", "sj")].sio_intra_thread)) + # print(prettier_map_string(pw_sios[("si", "sj")].sio_intra_group)) + # print(prettier_map_string(pw_sios[("si", "sj")].sio_global)) + + # FIXME create some expected sios and compare + +# }}} + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) + +# vim: foldmethod=marker