diff --git a/shardy/dialect/mpmd/ir/utils.h b/shardy/dialect/mpmd/ir/utils.h index 556d00a66..4e9b28635 100644 --- a/shardy/dialect/mpmd/ir/utils.h +++ b/shardy/dialect/mpmd/ir/utils.h @@ -62,7 +62,11 @@ inline constexpr StringRef kIsSdyPartitioned = "mpmd.is_sdy_partitioned"; inline constexpr StringRef kIsGspmdPartitioned = "mpmd.is_gspmd_partitioned"; // The suffix of the mesh name for a CPU mesh. +// LINT.IfChange constexpr StringRef kCpuMeshSuffix = "/cpu"; +// LINT.ThenChange( +// https://github.com/openxla/shardy/blob/main/shardy/integrations/python/jax/mpmd/types.py +// ) // Memory kind attributes. // Attr on func args and results to indicate whether the value lives on host or diff --git a/shardy/integrations/python/jax/mpmd/__init__.py b/shardy/integrations/python/jax/mpmd/__init__.py index d5dee7a26..f6fb33794 100644 --- a/shardy/integrations/python/jax/mpmd/__init__.py +++ b/shardy/integrations/python/jax/mpmd/__init__.py @@ -24,14 +24,14 @@ from shardy.integrations.python.jax.mpmd.ops import named_computation from shardy.integrations.python.jax.mpmd.ops import named_tensor from shardy.integrations.python.jax.mpmd.ops import reduce +from shardy.integrations.python.jax.mpmd.pipeline import FragmentInfo +from shardy.integrations.python.jax.mpmd.pipeline import FragmentMergeRule +from shardy.integrations.python.jax.mpmd.pipeline import FragmentMergeRules +from shardy.integrations.python.jax.mpmd.pipeline import FragmentOrigin from shardy.integrations.python.jax.mpmd.stages import MpmdCompiled as Compiled from shardy.integrations.python.jax.mpmd.stages import MpmdExecutable as Executable from shardy.integrations.python.jax.mpmd.stages import MpmdJitShardingInfo from shardy.integrations.python.jax.mpmd.stages import MpmdLowered as Lowered -from shardy.integrations.python.jax.mpmd.types import FragmentInfo -from shardy.integrations.python.jax.mpmd.types import FragmentMergeRule -from shardy.integrations.python.jax.mpmd.types import FragmentMergeRules -from shardy.integrations.python.jax.mpmd.types import FragmentOrigin from shardy.integrations.python.jax.mpmd.types import FunctionIOMeshAssignment from shardy.integrations.python.jax.mpmd.types import make_config from shardy.integrations.python.jax.mpmd.types import mesh_names diff --git a/shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.cc b/shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.cc index acd6b8240..4c7becf46 100644 --- a/shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.cc +++ b/shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.cc @@ -169,29 +169,35 @@ PartitioningResult MpmdProgram::ApplyPartitioning(PartitioningPhase phases) { func::FuncOp main_func = GetMainFunction(module); SetTopology(named_meshes, main_func); - SetArgDonationAttributes(main_func, donate_argnums); - // It is not necessary to do this - // validation after the export pipeline because here we're only checking that - // the attributes set on the main func are consistent with the received donate - // args. - VerifyOnlyDonatedArgsHaveDonationAttributes(main_func, donate_argnums); + if (phases & PartitioningPhase::kImport) { + SetArgDonationAttributes(main_func, donate_argnums); - SDY_LOG(INFO) << "Importing function named " << func_name - << " for MPMD partitioning."; + // It is not necessary to do this validation after the export pipeline + // because here we're only checking that the attributes set on the main func + // are consistent with the received donate args. 
+ VerifyOnlyDonatedArgsHaveDonationAttributes(main_func, donate_argnums); - Import(module); + SDY_LOG(INFO) << "Importing function named " << func_name + << " for MPMD partitioning."; - SDY_LOG(INFO) << "Optimizing function named " << func_name - << " for pipeline parallelism."; - Optimize(module); + Import(module); + } + + if (phases & PartitioningPhase::kOptimize) { + SDY_LOG(INFO) << "Optimizing function named " << func_name + << " for pipeline parallelism."; + Optimize(module); + } - SDY_LOG(INFO) << "Applying SDY propagation to function named " << func_name - << "."; - PropagateSharding(module); + if (phases & PartitioningPhase::kPartition) { + SDY_LOG(INFO) << "Applying SDY propagation to function named " << func_name + << "."; + PropagateSharding(module); - SDY_LOG(INFO) << "Exporting MPMD function named " << func_name << "."; - Export(module); + SDY_LOG(INFO) << "Exporting MPMD function named " << func_name << "."; + Export(module); + } return PartitioningResult(module); } @@ -206,7 +212,8 @@ void MpmdProgram::Import(ModuleOp module) { ConvertMeshVectorToMap(input_meshes)}; import_options.outputIndexToMeshAssignment = { ConvertMeshVectorToMap(output_meshes)}; - import_options.mergeAfterScheduling = options.mpmd_merge_after_scheduling; + import_options.mergeAfterScheduling = + options.mpmd_merge_inferred_after_scheduling; import_options.absorbInferredFragmentsOnEntryPointFunction = options.mpmd_absorb_inferred_fragments_on_entry_point_function; import_options.cloneInferredFragments = @@ -227,7 +234,8 @@ void MpmdProgram::Optimize(ModuleOp module) { OptimizeOptions optimize_options; optimize_options.fragmentMergeRules = llvm::to_vector(fragment_merge_rules); - optimize_options.mergeAfterScheduling = options.mpmd_merge_after_scheduling; + optimize_options.mergeAfterScheduling = + options.mpmd_merge_inferred_after_scheduling; optimize_options.applyFragmentRemat = options.mpmd_fragment_remat; optimize_options.mergeRematFragments = options.mpmd_merge_remat_fragments; optimize_options.absorbInferredFragmentsOnEntryPointFunction = diff --git a/shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.h b/shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.h index 24d8e11bc..b83fcfa53 100644 --- a/shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.h +++ b/shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.h @@ -103,7 +103,7 @@ struct PartitioningOptions { bool mpmd_absorb_inferred_fragments_on_entry_point_function = false; bool mpmd_copy_constant_creation_from_producer_to_consumer = false; bool mpmd_apply_merge_transfers_pass = false; - bool mpmd_merge_after_scheduling = false; + bool mpmd_merge_inferred_after_scheduling = false; }; PartitioningOptions ParsePartitioningOptions( @@ -120,6 +120,7 @@ struct MpmdProgram { const std::vector>& output_meshes; const std::vector& donate_argnums; const mlir::mpmd::FragmentMergeRules& fragment_merge_rules; + const mlir::mpmd::FragmentScheduleRules& fragment_schedule_rules; // Runs the PartIR MPMD partitioning passes on the MPMD program. // diff --git a/shardy/integrations/python/jax/mpmd/jaxlib_utils.py b/shardy/integrations/python/jax/mpmd/jaxlib_utils.py new file mode 100644 index 000000000..22e5a7ac0 --- /dev/null +++ b/shardy/integrations/python/jax/mpmd/jaxlib_utils.py @@ -0,0 +1,117 @@ +# Copyright 2025 The MPMD Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for converting between Python types and jaxlib pybind pipeline.""" + +from jaxlib import _sdy_mpmd as jaxlib_mpmd + +from shardy.integrations.python.jax.mpmd import pipeline + + +def _to_jaxlib_split_type( + split_type: pipeline.SplitFragmentType | None, +) -> jaxlib_mpmd.SplitFragmentType | None: + """Convert native Python enum to pybinded enum.""" + if split_type is None: + return None + if split_type == pipeline.SplitFragmentType.KEEP_TRANSFERRED: + return jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED + elif split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED: + return jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED + else: + raise ValueError(f'Unknown SplitFragmentType: {split_type}') + + +def _from_jaxlib_split_type( + split_type: jaxlib_mpmd.SplitFragmentType | None, +) -> pipeline.SplitFragmentType | None: + """Convert pybinded enum to native Python enum.""" + if split_type is None: + return None + if split_type == jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED: + return pipeline.SplitFragmentType.KEEP_TRANSFERRED + elif split_type == jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED: + return pipeline.SplitFragmentType.DROP_TRANSFERRED + else: + raise ValueError(f'Unknown jaxlib_mpmd.SplitFragmentType: {split_type}') + + +def convert_fragment_info_to_pybind( + fragment: pipeline.FragmentInfo, +) -> jaxlib_mpmd.FragmentInfo: + """Converts FragmentInfo to jaxlib_mpmd.FragmentInfo.""" + return jaxlib_mpmd.FragmentInfo( + origins=[ + jaxlib_mpmd.FragmentOrigin( + origin.computation_name, origin.transpose_count + ) + for origin in fragment.origins + ], + stage_id=fragment.stage_id, + call_counter=fragment.call_counter, + split_type=_to_jaxlib_split_type(fragment.split_type), + mesh_name=fragment.mesh_name, + ) + + +def convert_pybind_fragment_info_to_types( + fragment: jaxlib_mpmd.FragmentInfo, +) -> pipeline.FragmentInfo: + """Converts jaxlib_mpmd.FragmentInfo to FragmentInfo.""" + return pipeline.FragmentInfo( + origins=tuple( + pipeline.FragmentOrigin( + origin.computation_name, origin.transpose_count + ) + for origin in fragment.origins + ), + stage_id=fragment.stage_id, + call_counter=fragment.call_counter, + split_type=_from_jaxlib_split_type(fragment.split_type), + mesh_name=fragment.mesh_name, + ) + + +def convert_fragment_merge_rules_to_pybind( + fragment_merge_rules: pipeline.FragmentMergeRules, +) -> list[jaxlib_mpmd.FragmentMergeRule]: + """Converts fragment merge rules to jaxlib_mpmd.FragmentMergeRules.""" + pybind_fragment_merge_rules = [] + for rule in fragment_merge_rules: + fragments = [ + convert_fragment_info_to_pybind(fragment) for fragment in rule.sources + ] + pybind_fragment_merge_rules.append( + jaxlib_mpmd.FragmentMergeRule( + sources=fragments, + target=convert_fragment_info_to_pybind(rule.target), + ) + ) + return pybind_fragment_merge_rules + + +def convert_fragment_schedule_rules_to_pybind( + fragment_schedule_rules: pipeline.FragmentScheduleRules, +) -> list[jaxlib_mpmd.FragmentScheduleRule]: + """Converts fragment schedule rules to 
jaxlib_mpmd.FragmentScheduleRules.""" + pybind_fragment_schedule_rules = [] + for rule in fragment_schedule_rules: + fragments = [ + convert_fragment_info_to_pybind(fragment) + for fragment in rule.ordered_fragments + ] + pybind_fragment_schedule_rules.append( + jaxlib_mpmd.FragmentScheduleRule(ordered_fragments=fragments) + ) + return pybind_fragment_schedule_rules diff --git a/shardy/integrations/python/jax/mpmd/jaxlib_utils_test.py b/shardy/integrations/python/jax/mpmd/jaxlib_utils_test.py new file mode 100644 index 000000000..1fb6789ef --- /dev/null +++ b/shardy/integrations/python/jax/mpmd/jaxlib_utils_test.py @@ -0,0 +1,146 @@ +# Copyright 2025 The MPMD Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for jaxlib conversion utilities.""" + +from absl.testing import absltest +from absl.testing import parameterized +from jaxlib import _sdy_mpmd as jaxlib_mpmd + +from shardy.integrations.python.jax.mpmd import jaxlib_utils +from shardy.integrations.python.jax.mpmd import pipeline + + +class SplitTypeConversionTest(parameterized.TestCase): + """Tests for SplitFragmentType conversion functions.""" + + @parameterized.named_parameters( + ( + 'keep_transferred', + pipeline.SplitFragmentType.KEEP_TRANSFERRED, + jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED, + ), + ( + 'drop_transferred', + pipeline.SplitFragmentType.DROP_TRANSFERRED, + jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED, + ), + ('none', None, None), + ) + def test_bidirectional_conversion(self, python_val, pybind_val): + """Test SplitFragmentType conversion in both directions.""" + self.assertEqual(jaxlib_utils._to_jaxlib_split_type(python_val), pybind_val) + self.assertEqual( + jaxlib_utils._from_jaxlib_split_type(pybind_val), python_val + ) + + +class FragmentInfoConversionTest(parameterized.TestCase): + """Tests for FragmentInfo conversion functions.""" + + @parameterized.named_parameters( + ( + 'single_origin', + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('comp1', 0),), mesh_name='mesh1' + ), + ), + ( + 'multiple_origins', + pipeline.FragmentInfo( + origins=( + pipeline.FragmentOrigin('comp1', 0), + pipeline.FragmentOrigin('comp2', 1), + ), + mesh_name='mesh1', + ), + ), + ( + 'all_fields', + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('comp1', 2),), + stage_id=5, + call_counter=3, + split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED, + mesh_name='mesh2', + ), + ), + ) + def test_roundtrip(self, fragment): + """Test Python → pybind → Python roundtrip preserves data.""" + pybind_fragment = jaxlib_utils.convert_fragment_info_to_pybind(fragment) + result = jaxlib_utils.convert_pybind_fragment_info_to_types(pybind_fragment) + self.assertEqual(result, fragment) + + +class FragmentMergeRulesConversionTest(absltest.TestCase): + """Tests for FragmentMergeRule conversion functions.""" + + def test_single_rule(self): + """Test converting single merge rule.""" + f1 = pipeline.FragmentInfo( + 
origins=(pipeline.FragmentOrigin('f1', 0),), mesh_name='m1' + ) + f2 = pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('f2', 0),), mesh_name='m1' + ) + target = pipeline.FragmentInfo( + origins=( + pipeline.FragmentOrigin('f1', 0), + pipeline.FragmentOrigin('f2', 0), + ), + mesh_name='m1', + ) + + rule = pipeline.FragmentMergeRule(sources={f1, f2}, target=target) + result = jaxlib_utils.convert_fragment_merge_rules_to_pybind([rule]) + + self.assertLen(result, 1) + self.assertLen(result[0].sources, 2) + self.assertLen(result[0].target.origins, 2) + + +class FragmentScheduleRulesConversionTest(absltest.TestCase): + """Tests for FragmentScheduleRule conversion functions.""" + + def test_preserves_order(self): + """Test that ordered_fragments order is preserved.""" + frags = [ + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('first', 0),), mesh_name='m1' + ), + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('second', 0),), mesh_name='m1' + ), + pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin('third', 0),), mesh_name='m1' + ), + ] + + rule = pipeline.FragmentScheduleRule(ordered_fragments=frags) + result = jaxlib_utils.convert_fragment_schedule_rules_to_pybind([rule]) + + self.assertEqual( + result[0].ordered_fragments[0].origins[0].computation_name, 'first' + ) + self.assertEqual( + result[0].ordered_fragments[1].origins[0].computation_name, 'second' + ) + self.assertEqual( + result[0].ordered_fragments[2].origins[0].computation_name, 'third' + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/shardy/integrations/python/jax/mpmd/jit.py b/shardy/integrations/python/jax/mpmd/jit.py index ff84bd81c..650119d4b 100644 --- a/shardy/integrations/python/jax/mpmd/jit.py +++ b/shardy/integrations/python/jax/mpmd/jit.py @@ -28,7 +28,9 @@ import numpy as np import typing_extensions +from shardy.integrations.python.jax.mpmd import jaxlib_utils from shardy.integrations.python.jax.mpmd import ops +from shardy.integrations.python.jax.mpmd import pipeline from shardy.integrations.python.jax.mpmd import stages from shardy.integrations.python.jax.mpmd import types from shardy.integrations.python.jax.mpmd import utils @@ -38,7 +40,7 @@ @dataclasses.dataclass(frozen=True) class _MpmdPartitioningArgs: - """Arguments for mpmd_py.apply_mpmd_partitioning. + """Arguments for jaxlib_mpmd.apply_mpmd_partitioning. This is essentially a processed version of a MpmdConfig dataclass, but in a format that is more convenient for the C++ function. Note that users should @@ -72,8 +74,15 @@ class _MpmdPartitioningArgs: tpu_topology_args_proto: See `types.MpmdConfig.tpu_info`. This is required for TPUs when using GSPMD partitioning. partitioning_options: See `types.MpmdConfig.partitioning_options`. - fragment_merge_rules: See `types.MpmdConfig.fragment_merge_rules`. - fragment_schedule_rules: See `types.MpmdConfig.fragment_schedule_rules`. + fragment_merge_rules: A sequence of fragment merge rules. Each merge rule + contains a sequence of fragment metadata objects that should be merged + into a single fragment, together with metadata for the resulting fragment. + These rules are generated from the `pipeline_schedule` in the + `MpmdConfig`. + fragment_schedule_rules: A sequence of fragment schedule rules. Each + schedule rule contains a sequence of fragment metadata objects in the + order that they should be scheduled. These rules are generated from the + `pipeline_schedule` in the `MpmdConfig`. 
""" func_name: str @@ -83,8 +92,12 @@ class _MpmdPartitioningArgs: output_meshes: Sequence[str | None] donate_argnums: Sequence[int] partitioning_options: types.PartitioningOptions | None - fragment_merge_rules: Sequence[jaxlib_mpmd.FragmentMergeRule] - fragment_schedule_rules: Sequence[jaxlib_mpmd.FragmentScheduleRule] + fragment_merge_rules: Sequence[jaxlib_mpmd.FragmentMergeRule] = ( + dataclasses.field(default_factory=list) + ) + fragment_schedule_rules: Sequence[jaxlib_mpmd.FragmentScheduleRule] = ( + dataclasses.field(default_factory=list) + ) @dataclasses.dataclass(frozen=True) @@ -99,6 +112,16 @@ class MpmdLoweredArgs: flat_input_mesh_assignment: Sequence[str] | None = None +def _get_fragment_info( + mlir_module: mlir.ir.Module, +) -> list[pipeline.FragmentInfo]: + """Returns the fragment info for the given MLIR module.""" + return [ + jaxlib_utils.convert_pybind_fragment_info_to_types(info) + for info in jaxlib_mpmd.get_fragment_info(mlir_module) + ] + + def _apply_partitioning( mlir_module: mlir.ir.Module, partitioning_args: _MpmdPartitioningArgs, @@ -115,8 +138,7 @@ def _apply_partitioning( donate_argnums=partitioning_args.donate_argnums, partitioning_options=partitioning_args.partitioning_options, fragment_merge_rules=partitioning_args.fragment_merge_rules, - # TODO: b/424385447 - Reenable fragment_schedule_rules once - # we update jaxlib. + fragment_schedule_rules=partitioning_args.fragment_schedule_rules, phases=phases, ) @@ -281,11 +303,6 @@ def _shaped_abstractify(x): if arg_info.donated ] - assert not self._mpmd_config.fragment_merge_rules - fragment_merge_rules = [] - assert not self._mpmd_config.fragment_schedule_rules - fragment_schedule_rules = [] - partitioning_args = _MpmdPartitioningArgs( func_name=func_name, named_meshes=topology_shape, @@ -294,8 +311,8 @@ def _shaped_abstractify(x): output_meshes=flat_output_mesh_assignment, donate_argnums=donate_argnums, partitioning_options=self._mpmd_config.partitioning_options, - fragment_merge_rules=fragment_merge_rules, - fragment_schedule_rules=fragment_schedule_rules, + # Rules will be generated in _import_and_generate_rules, as long as a + # PipelineSchedule has been passed into MpmdConfig ) lowered_args = MpmdLoweredArgs( stablehlo_mlir_module=stablehlo_mlir_module, @@ -307,6 +324,70 @@ def _shaped_abstractify(x): ) return mlir_module, partitioning_args, lowered_args + def _import_and_generate_rules( + self, + mlir_module: mlir.ir.Module, + partitioning_args: _MpmdPartitioningArgs, + ) -> tuple[jaxlib_mpmd.PartitioningResult, _MpmdPartitioningArgs]: + if self._mpmd_config.pipeline_schedule is None: + raise ValueError('Pipeline schedule is not defined') + + # Validate and merge partitioning options with options required by the + # pipeline schedule + validated_options = types.validate_and_merge_partitioning_options( + pipeline_required_options=self._mpmd_config.pipeline_schedule.required_mpmd_options, + user_provided_options=partitioning_args.partitioning_options, + ) + partitioning_args_with_pipeline_options = dataclasses.replace( + partitioning_args, + partitioning_options=validated_options, + ) + + imported_result = _apply_partitioning( + mlir_module, + partitioning_args_with_pipeline_options, + jaxlib_mpmd.PartitioningPhase.IMPORT, + ) + context = pipeline.PipelineContext( + num_meshes=len(types.get_schedulable_meshes(self._mpmd_config.topology)) + ) + schedule_rules, merge_rules = pipeline.build_rules_from_pipeline( + _get_fragment_info(imported_result.mpmd_module), + self._mpmd_config.pipeline_schedule, + context, + 
) + + # Populate the partitioning args with the generated rules + partitioning_args_with_rules = dataclasses.replace( + partitioning_args_with_pipeline_options, + fragment_schedule_rules=jaxlib_utils.convert_fragment_schedule_rules_to_pybind( + schedule_rules + ), + fragment_merge_rules=jaxlib_utils.convert_fragment_merge_rules_to_pybind( + merge_rules + ), + ) + + return imported_result.mpmd_module, partitioning_args_with_rules + + def _partition_with_pipeline_schedule( + self, + mlir_module: mlir.ir.Module, + partitioning_args: _MpmdPartitioningArgs, + ) -> jaxlib_mpmd.PartitioningResult: + + imported_module, partitioning_args_with_rules = ( + self._import_and_generate_rules(mlir_module, partitioning_args) + ) + partitioning_result = _apply_partitioning( + imported_module, + partitioning_args_with_rules, + jaxlib_mpmd.PartitioningPhase.OPTIMIZE + | jaxlib_mpmd.PartitioningPhase.PARTITION, + ) + + return partitioning_result + @typing_extensions.override def lower( self, @@ -327,9 +408,14 @@ def lower( self._prepare_partitioning_args(_private_parameters) ) - partitioning_result = _apply_partitioning( - mlir_module, partitioning_args, jaxlib_mpmd.PartitioningPhase.ALL - ) + if self._mpmd_config.pipeline_schedule: + partitioning_result = self._partition_with_pipeline_schedule( + mlir_module, partitioning_args + ) + else: + partitioning_result = _apply_partitioning( + mlir_module, partitioning_args, jaxlib_mpmd.PartitioningPhase.ALL + ) ifrt_ir_module = jaxlib_mpmd.clone_mlir_module( partitioning_result.mpmd_module ) @@ -514,9 +600,11 @@ def __init__( """Initializes an MpmdWrapped object.""" if override_func_name: + @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) + wrapper.__name__ = override_func_name self.func = wrapper else: @@ -627,8 +715,8 @@ def jit( out_shardings: See `jax.jit`. donate_argnums: See `jax.jit`. keep_unused: See `jax.jit`. - override_func_name: If provided, the function name will be overridden to - the provided value. + override_func_name: If provided, the function name will be overridden to the + provided value. Returns: An MpmdWrapped object. diff --git a/shardy/integrations/python/jax/mpmd/partitioning_options.py b/shardy/integrations/python/jax/mpmd/partitioning_options.py index b775abc88..fbaabdbbc 100644 --- a/shardy/integrations/python/jax/mpmd/partitioning_options.py +++ b/shardy/integrations/python/jax/mpmd/partitioning_options.py @@ -26,7 +26,7 @@ 'mpmd_absorb_inferred_fragments_on_entry_point_function', 'mpmd_copy_constant_creation_from_producer_to_consumer', 'mpmd_apply_merge_transfers_pass', - 'mpmd_merge_after_scheduling', + 'mpmd_merge_inferred_after_scheduling', }) MPMD_PIPELINE_SCHEDULE_OPTION = 'mpmd_pipeline_schedule' diff --git a/shardy/integrations/python/jax/mpmd/pipeline.py b/shardy/integrations/python/jax/mpmd/pipeline.py new file mode 100644 index 000000000..2c7a5e9c8 --- /dev/null +++ b/shardy/integrations/python/jax/mpmd/pipeline.py @@ -0,0 +1,464 @@ +# Copyright 2025 The MPMD Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Core data structures and helper functions for MPMD pipeline scheduling. + +The primary entry point for defining a schedule is the `PipelineSchedule` +object, which uses rule builders to determine the execution order and +merging of fragments. Rule builders take lists of fragments and build +concrete scheduling/merging rules. + +There are two main approaches to defining pipeline schedules: + +1. Predicate-based approach (recommended for simple patterns): + Use binary predicates with helper functions to automatically generate rules. + `schedule_impl.py` contains implementations of common schedules using these + predicates. + +2. Direct construction (for complex custom schedules): + Explicitly build execution order and merge rules for full control. + +This is best shown through example: see `pipeline_test.py` for concrete +PipelineSchedule definitions using both approaches. +""" + +import collections +from collections.abc import Collection, Mapping, Sequence, Set +import dataclasses +import enum +from typing import Callable + +FragmentMergeRules = Sequence['FragmentMergeRule'] +FragmentScheduleRules = Sequence['FragmentScheduleRule'] + +# Function that builds schedule rules from fragments and pipeline context. +ScheduleRuleBuilder = Callable[ + [Sequence['FragmentInfo'], 'PipelineContext'], FragmentScheduleRules +] + +# Function that constructs a target FragmentInfo from a sequence of source +# fragments that will be merged together into the target. +TargetInfoBuilder = Callable[[Sequence['FragmentInfo']], 'FragmentInfo'] +# Function that builds merge rules from fragments and pipeline context. +MergeRuleBuilder = Callable[ + [Sequence['FragmentInfo'], 'PipelineContext'], FragmentMergeRules +] +# Function that builds both schedule and merge rules from fragments and pipeline +# context. +ScheduleMergeRuleBuilder = Callable[ + [Sequence['FragmentInfo'], 'PipelineContext'], + tuple[FragmentScheduleRules, FragmentMergeRules], +] + +# Binary predicate determining if two fragments should be merged or scheduled +# together. +RuleGeneratorPredicate = Callable[ + ['FragmentInfo', 'FragmentInfo', 'PipelineContext'], bool +] + + +@dataclasses.dataclass(frozen=True) +class FragmentOrigin: + """The origin of a fragment.""" + + computation_name: str + transpose_count: int = 0 + + +@enum.unique +class SplitFragmentType(enum.Enum): + """Fragment split behavior for transferred data. + + These values indicate how fragment portions handle transferred data from + the original fragment if the fragment is split during compilation: + - KEEP_TRANSFERRED: Fragment portion retains transferred data + - DROP_TRANSFERRED: Fragment portion drops transferred data + """ + + KEEP_TRANSFERRED = enum.auto() + DROP_TRANSFERRED = enum.auto() + + +@dataclasses.dataclass(frozen=True) +class FragmentInfo: + """A fragment of a computation.""" + + origins: tuple[FragmentOrigin, ...]
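+ # Named computations (and their transpose counts) this fragment was derived + # from. A fragment with no origins is not a user fragment and is never a + # valid scheduling unit (see `get_scheduling_unit_info`).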
+ stage_id: int | None = None + call_counter: int | None = None + split_type: SplitFragmentType | None = None + mesh_name: str = '' + + +def validate_fragment_rule_origins( + fragment_collection: Collection[FragmentInfo], +) -> None: + """Validates that all fragments have at least one origin.""" + for fragment in fragment_collection: + if not fragment.origins: + raise ValueError( + f'Each fragment must have at least one origin, but got {fragment} in' + f' {fragment_collection}.' + ) + + +def validate_fragment_rule_meshes( + fragment_collection: Collection[FragmentInfo], +) -> None: + """Validates that all fragments are on the same mesh.""" + first_fragment = next(iter(fragment_collection)) + first_mesh = first_fragment.mesh_name + if not all( + fragment.mesh_name == first_mesh for fragment in fragment_collection + ): + raise ValueError( + 'Fragments being merged/scheduled must be on the same mesh, but got' + f' {fragment_collection}.' + ) + + +@dataclasses.dataclass(frozen=True) +class FragmentMergeRule: + """A rule for merging fragments of a computation. + + Attributes: + sources: The source fragments to be merged. The order does not affect the + final position of the merged fragment. + target: The target fragment metadata that results from merging the sources. + """ + + sources: Set[FragmentInfo] + target: FragmentInfo + + def __post_init__(self): + # Validate the fragment merge rule. + if len(self.sources) < 2: + raise ValueError( + 'FragmentMergeRule must contain at least 2 source fragments, but got' + f' {self}.' + ) + validate_fragment_rule_origins(self.sources) + validate_fragment_rule_meshes(self.sources) + + if not self.target.origins: + raise ValueError( + f'Target fragment must have at least one origin, but got {self}.' + ) + + +@dataclasses.dataclass(frozen=True) +class FragmentScheduleRule: + """A rule for scheduling fragments in a specific execution order. + + Attributes: + ordered_fragments: Fragments in the order they should execute. Must contain + at least 2 fragments, and all fragments must be on the same mesh. + """ + + ordered_fragments: Sequence[FragmentInfo] + + def __post_init__(self): + # Validate the fragment schedule rule. + if len(self.ordered_fragments) < 2: + raise ValueError( + 'FragmentScheduleRule must contain at least 2 fragments, but got' + f' {self}.' + ) + validate_fragment_rule_origins(self.ordered_fragments) + validate_fragment_rule_meshes(self.ordered_fragments) + + +@dataclasses.dataclass(frozen=True) +class PipelineContext: + """Context for pipeline scheduling and merging predicates.""" + + num_meshes: int + + +@dataclasses.dataclass(frozen=True) +class PipelineSchedule: + """A set of rules and options which define an MPMD pipeline. + + Attributes: + merge_rule_builders: A sequence of functions that build merge rules for + fragments. + schedule_rule_builders: A sequence of functions that build schedule rules + for fragments. + schedule_merge_rule_builders: A sequence of functions that build both + schedule and merge rules for fragments. + required_mpmd_options: A mapping of PartitioningEnvironment flags that are + required for this schedule to function correctly. See + `partitioning_options.py` for available options. Relevant options + include: - mpmd_split_bwd_fragments: Set to True to split backward + fragments into separate weight gradient and activation gradient + fragments. This enables independent scheduling of weight and activation + gradients. 
- + mpmd_merge_inferred_after_scheduling: Set to True to defer merging of + inferred fragments until after scheduling. If False (default), inferred + fragments are merged before scheduling, which may create unintended data + dependencies that constrain your scheduling order. + """ + + merge_rule_builders: Sequence[MergeRuleBuilder] | None = None + schedule_rule_builders: Sequence[ScheduleRuleBuilder] | None = None + schedule_merge_rule_builders: Sequence[ScheduleMergeRuleBuilder] | None = None + required_mpmd_options: Mapping[str, bool | str] | None = None + + +def fragment_origins_contain(fragment: FragmentInfo, substring: str) -> bool: + """Checks if any computation name in fragment origins contains the substring.""" + return any( + substring in origin.computation_name for origin in fragment.origins + ) + + +def build_schedule_rules_from_predicate( + fragment_infos: Sequence[FragmentInfo], + context: PipelineContext, + *, + before_pred: RuleGeneratorPredicate, +) -> FragmentScheduleRules: + """Builds a list of scheduling rules using a binary predicate function.""" + res = [] + for i, a in enumerate(fragment_infos): + for j, b in enumerate(fragment_infos): + if i == j: + continue + if a.mesh_name != b.mesh_name: + continue + + if before_pred(a, b, context): + res.append(FragmentScheduleRule(ordered_fragments=[a, b])) + return res + + +def union_fragment_origins( + source_fragments: Sequence[FragmentInfo], +) -> tuple[FragmentOrigin, ...]: + """Union all origins from a sequence of fragment infos.""" + merged_origins = [] + seen_origins = set() + for fragment in source_fragments: + for origin in fragment.origins: + origin_key = (origin.computation_name, origin.transpose_count) + if origin_key not in seen_origins: + merged_origins.append(origin) + seen_origins.add(origin_key) + return tuple(merged_origins) + + +def minimal_create_target_info( + source_fragments: Sequence[FragmentInfo], +) -> FragmentInfo: + """Creates a target FragmentInfo based on a sequence of source FragmentInfos. + + FragmentMergeRule takes in a FragmentInfo which describes the final fragment + metadata after all sources have been merged. This functions creates a target + info with the minimal amount of information needed to create this target + FragmentInfo. + + Args: + source_fragments: List of source fragment infos to create target info from. + + Returns: + FragmentInfo object representing the target fragment info. + + Raises: + ValueError: If `source_fragments` is empty or fragments have inconsistent + `mesh_name` values. + """ + if not source_fragments: + raise ValueError( + 'Cannot create target info from empty source fragments sequence' + ) + + mesh_name = source_fragments[0].mesh_name + for fragment in source_fragments: + if fragment.mesh_name != mesh_name: + raise ValueError( + f'Inconsistent mesh_name values: {mesh_name} vs {fragment.mesh_name}' + ) + + return FragmentInfo( + origins=union_fragment_origins(source_fragments), + stage_id=None, + call_counter=None, + split_type=None, + mesh_name=mesh_name, + ) + + +def build_merge_rules_from_predicate( + fragment_infos: Sequence[FragmentInfo], + context: PipelineContext, + target_info_builder: TargetInfoBuilder = minimal_create_target_info, + *, + pred: RuleGeneratorPredicate, +) -> list[FragmentMergeRule]: + """Creates a list of fragment merge rules based on a binary predicate. + + Args: + fragment_infos: List of fragments to create merge rules for. + context: PipelineContext object containing additional context for the + scheduling and merging process. 
+ target_info_builder: Function that creates a target fragment info based on + a list of source fragment infos. Defaults to minimal_create_target_info. + pred: Binary predicate function that determines if fragments should be + merged. + + Returns: + List of FragmentMergeRule objects. + """ + merge_rules = [] + for i, fragment_a in enumerate(fragment_infos): + # Order of fragments should not matter for merge rules, so we can skip + # checking pairs of fragments that have already been checked. + for fragment_b in fragment_infos[i + 1 :]: + if fragment_a.mesh_name != fragment_b.mesh_name: + continue + + if pred(fragment_a, fragment_b, context): + merge_rules.append( + FragmentMergeRule( + sources={fragment_a, fragment_b}, + target=target_info_builder([fragment_a, fragment_b]), + ) + ) + return merge_rules + + + def build_rules_from_pipeline( + fragment_infos: Sequence[FragmentInfo], + pipeline: PipelineSchedule, + context: PipelineContext, + ) -> tuple[FragmentScheduleRules, FragmentMergeRules]: + """Builds scheduling and merging rules from a PipelineSchedule. + + Args: + fragment_infos: List of fragments to build rules for. + pipeline: PipelineSchedule containing rule generators and options. + context: PipelineContext with pipeline configuration. + + Returns: + Tuple of (schedule_rules, merge_rules) built from rule builders. + """ + # Create a list of fragments for each mesh once + mesh_fragments = collections.defaultdict(list) + for fragment in fragment_infos: + mesh_fragments[fragment.mesh_name].append(fragment) + + all_schedule_rules = [] + if pipeline.schedule_rule_builders: + for builder in pipeline.schedule_rule_builders: + # Run each builder on fragments from each mesh separately + for _, single_mesh_fragments in mesh_fragments.items(): + all_schedule_rules.extend(builder(single_mesh_fragments, context)) + + all_merge_rules = [] + if pipeline.merge_rule_builders: + for builder in pipeline.merge_rule_builders: + # Run each builder on fragments from each mesh separately + for _, single_mesh_fragments in mesh_fragments.items(): + all_merge_rules.extend(builder(single_mesh_fragments, context)) + + if pipeline.schedule_merge_rule_builders: + for builder in pipeline.schedule_merge_rule_builders: + for _, single_mesh_fragments in mesh_fragments.items(): + schedule_rules, merge_rules = builder(single_mesh_fragments, context) + all_schedule_rules.extend(schedule_rules) + all_merge_rules.extend(merge_rules) + + return all_schedule_rules, all_merge_rules + + + def maybe_unique_transpose_count( + fragment: FragmentInfo, + ) -> int | None: + """Returns transpose count if all fragment origins have the same value.""" + if not fragment.origins: + return None + + # Check if all origins have the same transpose count. + transpose_counts = {origin.transpose_count for origin in fragment.origins} + if len(transpose_counts) == 1: + return transpose_counts.pop() + + return None + + + def get_scheduling_unit_info(fragment: FragmentInfo) -> tuple[int, int] | None: + """Returns (call_counter, transpose_count) if fragment is a valid scheduling unit. + + A fragment is a valid scheduling unit if it meets all of the following + conditions: + - It is a user fragment (has origins) + - It has a call_counter + - It has a single transpose_count which is 0 or 1 + + Args: + fragment: Fragment to check scheduling unit for. + + Returns: + A tuple of (call_counter, transpose_count) if valid, None otherwise.
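+ + Example (illustrative values): a forward fragment with origins + (FragmentOrigin('block', 0),) and call_counter=3 yields (3, 0), while a + fragment whose origins mix transpose counts 0 and 1 yields None.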
+ """ + if not fragment.origins: + return None + + if fragment.call_counter is None: + return None + + transpose_count = maybe_unique_transpose_count(fragment) + if transpose_count is not None and ( + transpose_count == 0 or transpose_count == 1 + ): + return (fragment.call_counter, transpose_count) + + return None + + +def get_staged_scheduling_info( + f1: FragmentInfo, f2: FragmentInfo, error_msg: str +) -> tuple[int, int, int, int] | None: + """Validates two fragments for scheduling and returns their info. + + Args: + f1: First fragment to validate + f2: Second fragment to validate + error_msg: Error message for stage_id validation + + Returns: + Tuple of (call_counter_f1, transpose_count_f1, call_counter_f2, + transpose_count_f2) if both fragments are valid scheduling units with + stages, None otherwise. + + Raises: + ValueError: If `stage_id` is not set on either of the fragments. + """ + f1_info = get_scheduling_unit_info(f1) + f2_info = get_scheduling_unit_info(f2) + if f1_info is None or f2_info is None: + return None + + if f1.stage_id is None or f2.stage_id is None: + raise ValueError(error_msg) + + call_counter_f1, transpose_count_f1 = f1_info + call_counter_f2, transpose_count_f2 = f2_info + return ( + call_counter_f1, + transpose_count_f1, + call_counter_f2, + transpose_count_f2, + ) diff --git a/shardy/integrations/python/jax/mpmd/pipeline_registry.py b/shardy/integrations/python/jax/mpmd/pipeline_registry.py new file mode 100644 index 000000000..2a468def3 --- /dev/null +++ b/shardy/integrations/python/jax/mpmd/pipeline_registry.py @@ -0,0 +1,178 @@ +# Copyright 2025 The MPMD Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Pipeline schedule registry. + +Central registry mapping schedule names to PipelineSchedule objects. Each +schedule defines fragment merging and ordering using binary predicate functions. 
+ +Usage: + schedule = get_pipeline_schedule('1F1B') + config = make_config( + topology=topology, + name_to_mesh_assignment=mesh_assignment, + pipeline_schedule=schedule, + ) +""" + +import functools + +import immutabledict + +from shardy.integrations.python.jax.mpmd import pipeline +from shardy.integrations.python.jax.mpmd import schedule_impl + +ImmutableDict = immutabledict.immutabledict + +PIPELINE_SCHEDULES: ImmutableDict[str, pipeline.PipelineSchedule] = ( + ImmutableDict({ + 'ONE_FWD_ONE_BWD': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=schedule_impl.one_fwd_one_bwd_schedule_predicate, + ) + ], + required_mpmd_options={'mpmd_pipeline_schedule': '1F1B'}, + ), + 'GPIPE': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=schedule_impl.gpipe_schedule_predicate, + ) + ], + required_mpmd_options={'mpmd_pipeline_schedule': 'GPipe'}, + ), + 'GPIPE_BUT_1F1B_FOR_LAST_MESH': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=schedule_impl.gpipe_with_1f1b_on_last_mesh_schedule_predicate, + ) + ], + required_mpmd_options={ + 'mpmd_pipeline_schedule': 'GPipeBut1F1BForLastMesh' + }, + ), + 'ZERO_BUBBLE_H1': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=schedule_impl.zero_bubble_h1_schedule_predicate, + ) + ], + required_mpmd_options={ + 'mpmd_pipeline_schedule': 'ZeroBubbleH1', + 'mpmd_split_bwd_fragments': True, + }, + ), + 'ZERO_BUBBLE_H2_ZERO_TX_LATENCY': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + schedule_impl.latency_hiding_zero_bubble_h2_schedule_predicate, + latency_stage_fraction=0.0, + ), + ) + ], + required_mpmd_options={ + 'mpmd_split_bwd_fragments': True, + 'mpmd_pipeline_schedule': 'ZeroBubbleH2ZeroTxLatency', + }, + ), + 'ZERO_BUBBLE_H2_HALF_TX_LATENCY': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + schedule_impl.latency_hiding_zero_bubble_h2_schedule_predicate, + latency_stage_fraction=0.5, + ), + ) + ], + required_mpmd_options={ + 'mpmd_split_bwd_fragments': True, + 'mpmd_pipeline_schedule': 'ZeroBubbleH2HalfTxLatency', + }, + ), + 'ZERO_BUBBLE_H2_FULL_TX_LATENCY': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + schedule_impl.latency_hiding_zero_bubble_h2_schedule_predicate, + latency_stage_fraction=1.0, + ), + ) + ], + required_mpmd_options={ + 'mpmd_split_bwd_fragments': True, + 'mpmd_pipeline_schedule': 'ZeroBubbleH2FullTxLatency', + }, + ), + 'PARALLEL_PIPELINES_WITH_WRAP_AROUND': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=schedule_impl.parallel_pipelines_with_wraparound_schedule_predicate, + ) + ], + required_mpmd_options={ + 'mpmd_pipeline_schedule': 'ParallelPipelinesWithWrapAround', + }, + ), + 'CIRCULAR': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + 
schedule_impl.circular_schedule_predicate_base, + reverse_backward=False, + ), + ) + ], + required_mpmd_options={ + 'mpmd_pipeline_schedule': 'Circular', + }, + ), + 'CIRCULAR_WITH_REVERSED_BACKWARD': pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=functools.partial( + schedule_impl.circular_schedule_predicate_base, + reverse_backward=True, + ), + ) + ], + required_mpmd_options={ + 'mpmd_pipeline_schedule': 'CircularWithReversedBackward', + }, + ), + }) +) + + +def get_pipeline_schedule(schedule_name: str) -> pipeline.PipelineSchedule: + """Get a PipelineSchedule object for the given schedule name.""" + if schedule_name not in PIPELINE_SCHEDULES: + valid_schedules = sorted(PIPELINE_SCHEDULES.keys()) + raise KeyError( + f"Unknown pipeline schedule '{schedule_name}'. " + f'Valid schedules are: {valid_schedules!r}' + ) + return PIPELINE_SCHEDULES[schedule_name] diff --git a/shardy/integrations/python/jax/mpmd/pipeline_test.py b/shardy/integrations/python/jax/mpmd/pipeline_test.py new file mode 100644 index 000000000..a5a1afee7 --- /dev/null +++ b/shardy/integrations/python/jax/mpmd/pipeline_test.py @@ -0,0 +1,213 @@ +# Copyright 2025 The MPMD Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for MPMD pipeline functions.""" + +from collections.abc import Sequence +import functools +import unittest +from absl.testing import parameterized +from shardy.integrations.python.jax.mpmd import pipeline + + +def _make_fragment( + mesh_name: str = "mesh1", + origins: Sequence[pipeline.FragmentOrigin] | None = None, + **kwargs, +) -> pipeline.FragmentInfo: + """Helper to create FragmentInfo with common defaults.""" + # Use None instead of [] to avoid shared mutable default argument + if origins is None: + origins = () + return pipeline.FragmentInfo(origins=origins, mesh_name=mesh_name, **kwargs) + + +class BasicScheduleBuildTest(unittest.TestCase): + + def test_predicate_based_schedule_example(self): + def my_schedule_predicate(f1, f2, _): + return f1.call_counter < f2.call_counter + + def my_merge_predicate(f1, f2, _): + return f1.stage_id == f2.stage_id and f1.call_counter == f2.call_counter + + schedule = pipeline.PipelineSchedule( + schedule_rule_builders=[ + functools.partial( + pipeline.build_schedule_rules_from_predicate, + before_pred=my_schedule_predicate, + ) + ], + merge_rule_builders=[ + functools.partial( + pipeline.build_merge_rules_from_predicate, + pred=my_merge_predicate, + ) + ], + required_mpmd_options={}, + ) + + self.assertIsNotNone(schedule) + self.assertEqual(len(schedule.schedule_rule_builders), 1) + self.assertEqual(len(schedule.merge_rule_builders), 1) + + def test_direct_construction_schedule_example(self): + # This test implements the same logic as the predicate-based example above. 
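+ # It assumes the convention used by `schedule_impl.py`: forward fragments + # carry transpose_count 0 and backward fragments transpose_count 1.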
+ def custom_schedule_builder(fragment_infos, _): + forward = sorted( + [ + f + for f in fragment_infos + if f.origins + and pipeline.maybe_unique_transpose_count(f) == 0 + ], + key=lambda f: f.call_counter or 0, + ) + backward = sorted( + [ + f + for f in fragment_infos + if f.origins + and pipeline.maybe_unique_transpose_count(f) == 1 + ], + key=lambda f: f.call_counter or 0, + ) + + execution_order = [] + for fwd, bwd in zip(forward, backward): + execution_order.extend([fwd, bwd]) + + merge_rules = [ + pipeline.FragmentMergeRule( + sources={fwd, bwd}, + target=pipeline.minimal_create_target_info([fwd, bwd]), + ) + for fwd, bwd in zip(forward, backward) + if fwd.stage_id == bwd.stage_id + ] + + return [ + pipeline.FragmentScheduleRule(ordered_fragments=execution_order) + ], merge_rules + + schedule = pipeline.PipelineSchedule( + schedule_merge_rule_builders=[custom_schedule_builder], + required_mpmd_options={}, + ) + + self.assertIsNotNone(schedule) + self.assertEqual(len(schedule.schedule_merge_rule_builders), 1) + + +class MinimalCreateTargetInfoTest(parameterized.TestCase): + + def test_empty_source_fragments_raises_error(self): + with self.assertRaises(ValueError): + pipeline.minimal_create_target_info([]) + + def test_single_fragment(self): + origin = pipeline.FragmentOrigin("comp1", transpose_count=1) + fragment = _make_fragment( + origins=(origin,), + stage_id=5, + call_counter=10, + split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED, + ) + + result = pipeline.minimal_create_target_info([fragment]) + + self.assertEqual(result.origins, (origin,)) + # minimal_create_target_info always sets these to None + self.assertIsNone(result.stage_id) + self.assertIsNone(result.call_counter) + self.assertIsNone(result.split_type) + self.assertEqual(result.mesh_name, "mesh1") + + def test_origins_union_preserves_all_transpose_counts(self): + origin1 = pipeline.FragmentOrigin("comp1", transpose_count=0) + origin2 = pipeline.FragmentOrigin("comp2", transpose_count=1) + origin3 = pipeline.FragmentOrigin( + "comp1", transpose_count=1 + ) # Different transpose_count + + fragment1 = pipeline.FragmentInfo( + origins=(origin1, origin2), mesh_name="mesh1" + ) + fragment2 = pipeline.FragmentInfo(origins=(origin3,), mesh_name="mesh1") + + result = pipeline.minimal_create_target_info([fragment1, fragment2]) + + self.assertCountEqual(result.origins, (origin1, origin2, origin3)) + + def test_origins_union_removes_duplicates(self): + origin1 = pipeline.FragmentOrigin("comp1", transpose_count=0) + origin2 = pipeline.FragmentOrigin("comp2", transpose_count=1) + + fragment1 = pipeline.FragmentInfo( + origins=(origin1, origin2), mesh_name="mesh1" + ) + # `origin1` also exists in fragment1 origins + fragment2 = pipeline.FragmentInfo(origins=(origin1,), mesh_name="mesh1") + + result = pipeline.minimal_create_target_info([fragment1, fragment2]) + # Verify that the duplicate `origin1` does not remain + self.assertCountEqual(result.origins, (origin1, origin2)) + + def test_mesh_name_inconsistency_raises_error(self): + """Test that inconsistent mesh_name values raise ValueError.""" + fragment1 = _make_fragment(mesh_name="mesh1") + fragment2 = _make_fragment(mesh_name="mesh2") + + with self.assertRaises(ValueError) as cm: + pipeline.minimal_create_target_info([fragment1, fragment2]) + self.assertIn( + "Inconsistent mesh_name values: mesh1 vs mesh2", str(cm.exception) + ) + + def test_mesh_name_from_first_fragment(self): + fragment1 = _make_fragment(mesh_name="mesh1") + fragment2 = _make_fragment(mesh_name="mesh1") + 
+ result = pipeline.minimal_create_target_info((fragment1, fragment2)) + + self.assertEqual(result.mesh_name, "mesh1") + + def test_always_returns_none_for_optional_fields(self): + """Test that stage_id, call_counter, and split_type are always None.""" + fragment1 = pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin("comp1", transpose_count=0),), + stage_id=5, + call_counter=10, + split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED, + mesh_name="mesh1", + ) + fragment2 = pipeline.FragmentInfo( + origins=(pipeline.FragmentOrigin("comp2", transpose_count=1),), + stage_id=5, + call_counter=10, + split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED, + mesh_name="mesh1", + ) + + result = pipeline.minimal_create_target_info([fragment1, fragment2]) + + # Regardless of input values, these should always be None + self.assertIsNone(result.stage_id) + self.assertIsNone(result.call_counter) + self.assertIsNone(result.split_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/shardy/integrations/python/jax/mpmd/schedule_impl.py b/shardy/integrations/python/jax/mpmd/schedule_impl.py new file mode 100644 index 000000000..394572f3e --- /dev/null +++ b/shardy/integrations/python/jax/mpmd/schedule_impl.py @@ -0,0 +1,368 @@ +# Copyright 2025 The MPMD Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementations of common pipeline scheduling predicates for MPMD.""" + +from typing import Callable + +from shardy.integrations.python.jax.mpmd import pipeline + + +def gpipe_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + _: pipeline.PipelineContext, +) -> bool: + """Returns true if `f1` must happen before `f2` in a GPipe schedule.""" + transpose_count_f1 = pipeline.maybe_unique_transpose_count(f1) + transpose_count_f2 = pipeline.maybe_unique_transpose_count(f2) + if ( + transpose_count_f1 is None + or transpose_count_f2 is None + or f1.call_counter is None + or f2.call_counter is None + ): + return False + + return (transpose_count_f1, f1.call_counter) < ( + transpose_count_f2, + f2.call_counter, + ) + + +def one_fwd_one_bwd_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, +) -> bool: + """Returns true if f1 must happen before f2 in a 1F1B schedule.""" + result = pipeline.get_staged_scheduling_info( + f1, f2, "All fragments must have a stage id for 1F1B scheduling." + ) + if result is None: + return False + call_counter_f1, transpose_count_f1, call_counter_f2, transpose_count_f2 = ( + result + ) + + # The following two conditions guarantee the forward and backward fragments + # are interleaved in the steady state of the pipeline. + + # Example: in mesh/stage 0 of pipeline of depth 4, the backward computation + # of microbatch 0 must be scheduled before the forward computation of + # microbatch 4: 0 == 4 - 4 + 0. 
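+ # (In these examples `call_counter` is the microbatch index and `stage_id` + # the pipeline stage, i.e. the mesh the fragment is assigned to.)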
+ if transpose_count_f1 == 1 and transpose_count_f2 == 0: + return call_counter_f1 == call_counter_f2 - context.num_meshes + f1.stage_id + + # Example: in mesh/stage 0 of pipeline of depth 4, the forward computation of + # microbatch 5 must be scheduled before the backward computation of + # microbatch 2: 5 == 2 + 4 - (0 + 1). + if transpose_count_f1 == 0 and transpose_count_f2 == 1: + return call_counter_f1 == call_counter_f2 + context.num_meshes - ( + f1.stage_id + 1 + ) + + # If the fragments have the same transpose count, guarantee that the + # call_counter ordering is preserved. + if transpose_count_f1 == transpose_count_f2: + return call_counter_f1 < call_counter_f2 + + return False + + +def gpipe_with_1f1b_on_last_mesh_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, +) -> bool: + """Returns true if f1 must happen before f2 in a GPipe schedule with 1F1B on the last mesh.""" + result = pipeline.get_staged_scheduling_info( + f1, + f2, + "All fragments must have a stage id for GPipe with 1F1B on the last mesh" + " scheduling.", + ) + if result is None: + return False + # Validation successful - delegate to other functions + _ = result + + if f1.stage_id == context.num_meshes - 1: + return one_fwd_one_bwd_schedule_predicate(f1, f2, context) + return gpipe_schedule_predicate(f1, f2, context) + + +def circular_schedule_predicate_base( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, + reverse_backward: bool, +) -> bool: + """Returns true if f1 must happen before f2 in circular schedule.""" + # Check that both fragments are scheduling units + result = pipeline.get_staged_scheduling_info( + f1, f2, "Cannot schedule for circular pipelining without stages." + ) + if result is None: + return False + call_counter_f1, transpose_count_f1, call_counter_f2, transpose_count_f2 = ( + result + ) + + if transpose_count_f1 != transpose_count_f2: + # Forward fragments always happen before backward fragments + return transpose_count_f1 < transpose_count_f2 + + # Both forward or both backward - use phase-based ordering + phase_f1 = call_counter_f1 // context.num_meshes + phase_f2 = call_counter_f2 // context.num_meshes + + f1_list = [phase_f1, f1.stage_id, call_counter_f1] + f2_list = [phase_f2, f2.stage_id, call_counter_f2] + + # Forward fragments - ascending order + if transpose_count_f1 == 0: + return f1_list < f2_list + + # Backward fragments + if reverse_backward: + # Descending order + return f1_list > f2_list + + # Backward fragments with stage in descending order + f1_list[1], f2_list[1] = f2_list[1], f1_list[1] # Swap stage IDs + return f1_list < f2_list + + +def zero_bubble_h1_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, +) -> bool: + """Returns true if f1 must happen before f2 in a ZeroBubbleH1 schedule.""" + result = pipeline.get_staged_scheduling_info( + f1, f2, "All fragments must have a stage id for ZeroBubbleH1 scheduling." + ) + if result is None: + return False + call_counter_f1, transpose_count_f1, call_counter_f2, transpose_count_f2 = ( + result + ) + + is_wgrad_f1 = f1.split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED + is_wgrad_f2 = f2.split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED + + # The following two conditions guarantee the forward and backward fragments + # are interleaved in the steady state of the pipeline. 
They are just like + # 1F1B but specialized to actual back-propagation fragments. + + # Clause 1: Ba(i) < F(i + num_meshes - stage_id) + if transpose_count_f1 == 1 and not is_wgrad_f1 and transpose_count_f2 == 0: + return call_counter_f1 == call_counter_f2 - context.num_meshes + f1.stage_id + + # Clause 2: F(i + num_meshes - stage_id - 1) < Ba(i) + if transpose_count_f1 == 0 and transpose_count_f2 == 1 and not is_wgrad_f2: + return call_counter_f1 == call_counter_f2 + context.num_meshes - ( + f1.stage_id + 1 + ) + + # The rest of the conditions position the parameter gradient fragments. + # Clause 3: Bw(i) < F(i + num_meshes) + # e.g. Bw(0) < F(4) above. + if ( + transpose_count_f1 == 1 + and (is_wgrad_f1 or f1.stage_id == 0) + and transpose_count_f2 == 0 + ): + return call_counter_f2 - call_counter_f1 == context.num_meshes + + # Clause 4: Ba(i + stage_id) < Bw(i) + # e.g. + # mesh0: Ba(0) < Bw(0) + # mesh1: Ba(1) < Bw(0) + # mesh2: Ba(2) < Bw(0) + # mesh3: Ba(3) < Bw(0) + if ( + transpose_count_f1 == 1 + and not is_wgrad_f1 + and transpose_count_f2 == 1 + and is_wgrad_f2 + ): + return call_counter_f1 - call_counter_f2 == f1.stage_id + + # This is just needed for transitively completing Clauses 3 and 2, needed for + # the final phase where there may be no remaining forward to anchor to. + # Bw(i) < Ba(i + stage_id + 1) + if ( + transpose_count_f1 == 1 + and is_wgrad_f1 + and transpose_count_f2 == 1 + and not is_wgrad_f2 + ): + return call_counter_f2 - call_counter_f1 == f1.stage_id + 1 + + return False + + +def zero_bubble_h2_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, + init_fwd_per_stage_fn: Callable[[int], int], +) -> bool: + """Returns true if f1 must happen before f2 in a ZeroBubbleH2 schedule.""" + result = pipeline.get_staged_scheduling_info( + f1, f2, "All fragments must have a stage id for ZeroBubbleH2 scheduling." 
+ ) + if result is None: + return False + _, transpose_count_f1, _, transpose_count_f2 = result + + is_wgrad_f1 = f1.split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED + is_wgrad_f2 = f2.split_type == pipeline.SplitFragmentType.DROP_TRANSFERRED + + # How many fwd we are allowed to stream before entering steady state + init_fwd = init_fwd_per_stage_fn(f1.stage_id) + # The ZeroBubbleH2 pipeline is diagonally symmetric + complement_init_fwd = init_fwd_per_stage_fn( + context.num_meshes - f1.stage_id - 1 + ) + + # Initial phase + # Clause 1: F(i) <= B(_) for i < init_fwd + if ( + transpose_count_f1 == 0 + and transpose_count_f2 == 1 + and f1.call_counter < init_fwd + ): + return True + + # Clause 2: Ba(i) < F(i + init_fwd) + if ( + transpose_count_f1 == 1 + and not is_wgrad_f1 + and transpose_count_f2 == 0 + and f2.call_counter >= init_fwd + ): + return f2.call_counter - f1.call_counter == init_fwd + + # Clause 3: F(i + init_fwd - 1) < Ba(i) + if ( + transpose_count_f1 == 0 + and f1.call_counter >= init_fwd + and transpose_count_f2 == 1 + and not is_wgrad_f2 + ): + return f1.call_counter - f2.call_counter == init_fwd - 1 + + # Clause 4: Ba(i + complement_init_fwd - 1) < Bw(i) + if ( + transpose_count_f1 == 1 + and not is_wgrad_f1 + and transpose_count_f2 == 1 + and is_wgrad_f2 + ): + return f1.call_counter - f2.call_counter == complement_init_fwd - 1 + + # Clause 5: Bw(i) < Ba(i + complement_init_fwd) + if ( + transpose_count_f1 == 1 + and is_wgrad_f1 + and transpose_count_f2 == 1 + and not is_wgrad_f2 + ): + return f2.call_counter - f1.call_counter == complement_init_fwd + + return False + + +def latency_hiding_zero_bubble_h2_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + context: pipeline.PipelineContext, + latency_stage_fraction: float, +) -> bool: + """Returns true if f1 must happen before f2 in a latency-hiding ZeroBubbleH2 schedule. + + Args: + f1: First fragment to compare. + f2: Second fragment to compare. + context: Pipeline context with configuration. + latency_stage_fraction: Float between 0.0 and 1.0 specifying how much time + activation forwarding transfers take compared to a stage compute time. + """ + if not (0.0 <= latency_stage_fraction <= 1.0): + raise ValueError("latency_stage_fraction must be between 0.0 and 1.0") + + def init_fwds_per_stage(stage_id: int) -> int: + """Calculate number of forward microbatches before first backward.""" + # Number of transfers from beginning until first backward can execute + num_init_transfers = 2.0 * (context.num_meshes - stage_id - 1) + # Compute that has happened in initial first microbatch path + num_init_compute = 2.0 * (context.num_meshes - stage_id) - 1.0 + return int(num_init_compute + num_init_transfers * latency_stage_fraction) + + return zero_bubble_h2_schedule_predicate(f1, f2, context, init_fwds_per_stage) + + +def parallel_pipelines_with_wraparound_schedule_predicate( + f1: pipeline.FragmentInfo, + f2: pipeline.FragmentInfo, + _: pipeline.PipelineContext, +) -> bool: + """Returns true if f1 must happen before f2 in parallel pipelines with wraparound. + + Only supports forward fragments. The entrypoint for mesh{i} is call_counter + {i}. + For each mesh, the order is [F{n}, F{n-1}, ..., F{1}] rotated such that + the leading fragment is F{mesh_index}. + + Args: + f1: First fragment to compare. + f2: Second fragment to compare. 
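+
+  Example:
+    With four microbatches (call counters 0 through 3) on mesh/stage 1, the
+    induced order is 1, 0, 3, 2, i.e. the descending list 3, 2, 1, 0 rotated
+    so that call counter 1, the entrypoint for mesh 1, leads.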
+ """ + result = pipeline.get_staged_scheduling_info( + f1, + f2, + "All fragments must have a stage id for parallel pipelines scheduling.", + ) + if result is None: + return False + call_counter_f1, transpose_count_f1, call_counter_f2, transpose_count_f2 = ( + result + ) + + # Only forward fragments supported + if transpose_count_f1 != 0 or transpose_count_f2 != 0: + raise ValueError("Only forward fragments supported for parallel pipelines") + + if call_counter_f1 == call_counter_f2: + raise ValueError( + "Should not have duplicate call counter in parallel pipelines" + ) + + # The entrypoint to stage{i} is call_counter {i}, so this always happens + # before + if call_counter_f1 == f1.stage_id or call_counter_f2 == f1.stage_id: + return call_counter_f1 == f1.stage_id + + # stage_id is the pivot. If both call_counters are on the same side of + # the pivot, we flip the order. But if they are on different + # sides, then we take the order as per normal. + if (call_counter_f1 > f1.stage_id and call_counter_f2 > f1.stage_id) or ( + call_counter_f1 < f1.stage_id and call_counter_f2 < f1.stage_id + ): + return call_counter_f1 > call_counter_f2 + + return call_counter_f1 < call_counter_f2 diff --git a/shardy/integrations/python/jax/mpmd/types.py b/shardy/integrations/python/jax/mpmd/types.py index eb9b4ad97..5a9c98367 100644 --- a/shardy/integrations/python/jax/mpmd/types.py +++ b/shardy/integrations/python/jax/mpmd/types.py @@ -15,14 +15,15 @@ """Common types used by PartIR:MPMD.""" -from collections.abc import Mapping, Sequence +from collections.abc import Mapping import dataclasses -import enum +from absl import logging import jax import jaxtyping from shardy.integrations.python.jax.mpmd import partitioning_options as part_options +from shardy.integrations.python.jax.mpmd import pipeline PyTree = jaxtyping.PyTree @@ -40,58 +41,21 @@ PartitioningOptions = dict[str, bool | str] -@dataclasses.dataclass(frozen=True) -class FragmentOrigin: - """The origin of a fragment.""" - - computation_name: str - transpose_count: int = 0 - - -@enum.unique -class SplitFragmentType(enum.Enum): - """Fragment split behavior for transferred data. 
- - These values indicate how fragment portions handle transferred data from - the original fragment if the fragment is split during compilation: - - KEEP_TRANSFERRED: Fragment portion retains transferred data - - DROP_TRANSFERRED: Fragment portion drops transferred data - """ - - KEEP_TRANSFERRED = enum.auto() - DROP_TRANSFERRED = enum.auto() - - -@dataclasses.dataclass(frozen=True) -class FragmentInfo: - """A fragment of a computation.""" +# LINT.IfChange +CPU_MESH_SUFFIX = '/cpu' +# LINT.ThenChange( +# https://github.com/openxla/shardy/blob/main/shardy/dialect/mpmd/ir/utils.h +# ) - origins: Sequence[FragmentOrigin] - stage_id: int | None = None - call_counter: int | None = None - split_type: SplitFragmentType | None = None - mesh_name: str = '' - -@dataclasses.dataclass(frozen=True) -class FragmentMergeRule: - """A rule for merging fragments of a computation.""" - - sources: Sequence[FragmentInfo] - target: FragmentInfo - - -FragmentMergeRules = Sequence[FragmentMergeRule] +def mesh_is_on_cpu(mesh_name: str) -> bool: + """Returns whether the mesh name is for a cpu mesh.""" + return mesh_name.endswith(CPU_MESH_SUFFIX) -@dataclasses.dataclass(frozen=True) -class FragmentScheduleRule: - """A rule for scheduling fragments of a computation.""" - - ordered_fragments: Sequence[FragmentInfo] - - -FragmentScheduleRules = Sequence[FragmentScheduleRule] +def get_schedulable_meshes(topology: Topology) -> list[str]: + """Returns the names of meshes in the topology that are not CPU meshes.""" + return [name for name in topology if not mesh_is_on_cpu(name)] @dataclasses.dataclass(frozen=True) @@ -127,12 +91,9 @@ class MpmdConfig: reading from arg shardings is not supported. TODO: b/377706756 - Read from arg shardings too, and migrate users to this and remove this option once stabilized. - fragment_merge_rules: A sequence of fragment merge rules. Each merge rule - contains a sequence of fragment metadata objects that should be merged - into a single fragment, together with metadata for the resulting fragment. - fragment_schedule_rules: A sequence of fragment schedule rules. Each - schedule rule contains a sequence of fragment metadata objects in the - order that they should be scheduled. + pipeline_schedule: A PipelineSchedule object used to generate merge and/or + schedule rules for partitioning, as well as set any required MPMD options + for the pipeline. """ topology: Topology @@ -142,8 +103,7 @@ class MpmdConfig: output_mesh_assignment: PyTree[str | None] partitioning_options: PartitioningOptions | None read_input_output_mesh_from_shardings: bool - fragment_merge_rules: FragmentMergeRules | None - fragment_schedule_rules: FragmentScheduleRules | None + pipeline_schedule: pipeline.PipelineSchedule | None @property def _spmd_mesh(self) -> jax.sharding.Mesh: @@ -211,8 +171,7 @@ def make_config( output_mesh_assignment: PyTree[str | None] = (), partitioning_options: PartitioningOptions | None = None, read_input_output_mesh_from_shardings: bool = False, - fragment_merge_rules: FragmentMergeRules | None = None, - fragment_schedule_rules: FragmentScheduleRules | None = None, + pipeline_schedule: pipeline.PipelineSchedule | None = None, ) -> MpmdConfig: """Creates a `MpmdConfig`, inferring the tpu topology if not provided. @@ -227,8 +186,7 @@ def make_config( output_mesh_assignment: See `MpmdConfig`. partitioning_options: See `MpmdConfig`. read_input_output_mesh_from_shardings: see `MpmdConfig`. - fragment_merge_rules: See `MpmdConfig`. - fragment_schedule_rules: See `MpmdConfig`. 
+ pipeline_schedule: See `MpmdConfig`. Returns: An `MpmdConfig` object. @@ -255,13 +213,6 @@ def make_config( input_mesh_assignment, output_mesh_assignment, ) - if fragment_merge_rules is None: - fragment_merge_rules = [] - validate_fragment_merge_rules(fragment_merge_rules) - - if fragment_schedule_rules is None: - fragment_schedule_rules = [] - validate_fragment_schedule_rules(fragment_schedule_rules) return MpmdConfig( topology, @@ -271,8 +222,7 @@ def make_config( output_mesh_assignment, partitioning_options, read_input_output_mesh_from_shardings, - fragment_merge_rules, - fragment_schedule_rules, + pipeline_schedule, ) @@ -343,65 +293,6 @@ def validate_input_output_mesh_assignments( ) -def validate_fragment_rule_origins( - fragment_sequence: Sequence[FragmentInfo], -) -> None: - for fragment in fragment_sequence: - if not fragment.origins: - raise ValueError( - f'Each fragment must have at least one origin, but got {fragment} in' - f' {fragment_sequence}.' - ) - - -def validate_fragment_rule_meshes( - fragment_sequence: Sequence[FragmentInfo], -) -> None: - first_mesh = fragment_sequence[0].mesh_name - if not all( - fragment.mesh_name == first_mesh for fragment in fragment_sequence - ): - raise ValueError( - 'Fragments being merged/scheduled must be on the same mesh, but got' - f' {fragment_sequence}.' - ) - - -def validate_fragment_merge_rules( - fragment_merge_rules: FragmentMergeRules, -) -> None: - """Validates the fragment merge rules.""" - - for rule in fragment_merge_rules: - if len(rule.sources) < 2: - raise ValueError( - 'Fragment merge rule must contain at least two source fragments, but' - f' got {rule}.' - ) - validate_fragment_rule_origins(rule.sources) - validate_fragment_rule_meshes(rule.sources) - - if not rule.target.origins: - raise ValueError( - f'Target fragment must have at least one origin, but got {rule}.' - ) - - -def validate_fragment_schedule_rules( - fragment_schedule_rules: FragmentScheduleRules, -) -> None: - """Validates the fragment schedule rules.""" - for rule in fragment_schedule_rules: - if len(rule.ordered_fragments) < 2: - raise ValueError( - 'Fragment schedule rule must contain at least two fragments, but' - f' got {rule}.' - ) - - validate_fragment_rule_origins(rule.ordered_fragments) - validate_fragment_rule_meshes(rule.ordered_fragments) - - def mesh_names( pytree: PyTree[ jax.Array @@ -453,6 +344,106 @@ class FunctionIOMeshAssignment: output_meshes: PyTree[str] +def override_partitioning_options( + mpmd_options: Mapping[str, bool | str] | None, + base_options_to_override: PartitioningOptions | None = None, +) -> PartitioningOptions | None: + """Overrides the base partitioning options with the given MPMD options.""" + if mpmd_options is None: + return base_options_to_override + + _validate_partitioning_options(mpmd_options) + + options = {} + if base_options_to_override is not None: + options.update(base_options_to_override) + + options.update(mpmd_options) + return options + + +def check_partitioning_option_conflicts( + pipeline_required_options: Mapping[str, bool | str], + user_options_dict: Mapping[str, bool | str], +) -> list[str]: + """Checks for conflicts between pipeline requirements and user options. + + Args: + pipeline_required_options: Options required by the pipeline schedule. + user_options_dict: Options explicitly set by the user. + + Returns: + List of conflict error messages. Empty if no conflicts. + Logs warnings for options that will be set automatically. 
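+
+  Example:
+    check_partitioning_option_conflicts(
+        {'mpmd_infer_transfers': True}, {'mpmd_infer_transfers': False}
+    )
+    returns a single-entry list reporting that the pipeline schedule requires
+    True but the user specified False.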
+ """ + conflicts = [] + for k, required_value in pipeline_required_options.items(): + if k in user_options_dict: + user_value = user_options_dict[k] + if user_value != required_value: + conflicts.append( + f" - '{k}': pipeline schedule requires {required_value}, " + f'but user specified {user_value}' + ) + else: + # Option not set by user, will be set automatically + logging.warning( + 'Setting partitioning option %r to %s (required by pipeline' + ' schedule)', + k, + required_value, + ) + return conflicts + + +def validate_and_merge_partitioning_options( + pipeline_required_options: Mapping[str, bool | str] | None, + user_provided_options: PartitioningOptions | None, +) -> PartitioningOptions | None: + """Validates and merges user options with pipeline requirements. + + Ensures that user-provided partitioning options don't conflict with options + required by the pipeline schedule. If conflicts are found, it raises an error. + If the user hasn't set all required options, it logs warnings for the options + that still need to be set and will set them automatically. + + Args: + pipeline_required_options: Options required by the pipeline schedule. + user_provided_options: Options provided by the user via MpmdConfig. This is + expected to be a dict or None. + + Returns: + Merged partitioning options with pipeline requirements taking precedence + where the user hasn't specified them. + + Raises: + ValueError: If user options conflict with pipeline required options. + """ + if pipeline_required_options is None: + return user_provided_options + + user_options_dict = user_provided_options if user_provided_options else {} + + # Check for conflicts and log warnings + conflicts = check_partitioning_option_conflicts( + pipeline_required_options, user_options_dict + ) + + if conflicts: + conflict_msg = '\n'.join(conflicts) + raise ValueError( + f'Conflicting partitioning options detected:\n{conflict_msg}\n' + 'Please remove these options from your MpmdConfig or ensure they ' + 'match the pipeline schedule requirements.' 
+ ) + + # Merge the options + return override_partitioning_options( + mpmd_options=pipeline_required_options, + base_options_to_override=user_provided_options, + ) + + def _validate_partitioning_options( partitioning_options: Mapping[str, bool | str] | None, ): diff --git a/shardy/integrations/python/jax/mpmd/types_test.py b/shardy/integrations/python/jax/mpmd/types_test.py index 09876a2f4..d9f270cc2 100644 --- a/shardy/integrations/python/jax/mpmd/types_test.py +++ b/shardy/integrations/python/jax/mpmd/types_test.py @@ -259,5 +259,77 @@ def test_invalid_pipeline_schedule_raises(self): ) +class ValidateAndMergePartitioningOptionsTest(parameterized.TestCase): + """Tests for validate_and_merge_partitioning_options function.""" + + def test_no_pipeline_options_returns_user_options(self): + user_options = {'mpmd_infer_transfers': True} + result = types.validate_and_merge_partitioning_options( + pipeline_required_options=None, user_provided_options=user_options + ) + self.assertEqual(user_options, result) + + def test_no_user_options_returns_pipeline_options(self): + result = types.validate_and_merge_partitioning_options( + pipeline_required_options={'mpmd_infer_transfers': True}, + user_provided_options=None, + ) + self.assertEqual({'mpmd_infer_transfers': True}, result) + + def test_no_conflict_with_compatible_options(self): + result = types.validate_and_merge_partitioning_options( + pipeline_required_options={ + 'mpmd_infer_transfers': True, + 'mpmd_split_bwd_fragments': True, + }, + user_provided_options={ + 'mpmd_infer_transfers': True, + 'mpmd_fragment_remat': False, + }, + ) + expected = { + 'mpmd_infer_transfers': True, + 'mpmd_fragment_remat': False, + 'mpmd_split_bwd_fragments': True, + } + self.assertEqual(expected, result) + + def test_conflict_raises_error(self): + with self.assertRaisesRegex( + ValueError, + r'(?s)Conflicting partitioning options detected:.*' + r'mpmd_infer_transfers.*' + r'pipeline schedule requires True.*' + r'user specified False', + ): + types.validate_and_merge_partitioning_options( + pipeline_required_options={'mpmd_infer_transfers': True}, + user_provided_options={'mpmd_infer_transfers': False}, + ) + + def test_pipeline_schedule_conflict_raises_error(self): + with self.assertRaisesRegex( + ValueError, + r'(?s)Conflicting partitioning options detected:.*' + r'mpmd_pipeline_schedule.*' + r'pipeline schedule requires 1F1B.*' + r'user specified GPipe', + ): + types.validate_and_merge_partitioning_options( + pipeline_required_options={'mpmd_pipeline_schedule': '1F1B'}, + user_provided_options={'mpmd_pipeline_schedule': 'GPipe'}, + ) + + def test_warnings_for_new_options(self): + with self.assertLogs(level='WARNING') as logs: + types.validate_and_merge_partitioning_options( + pipeline_required_options={'mpmd_infer_transfers': True}, + user_provided_options=None, + ) + self.assertLen(logs.output, 1) + self.assertIn('mpmd_infer_transfers', logs.output[0]) + self.assertIn('required by pipeline schedule', logs.output[0]) + + if __name__ == '__main__': absltest.main()