Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions shardy/dialect/mpmd/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ inline constexpr StringRef kIsSdyPartitioned = "mpmd.is_sdy_partitioned";
inline constexpr StringRef kIsGspmdPartitioned = "mpmd.is_gspmd_partitioned";

// The suffix of the mesh name for a CPU mesh.
//
// Declared `inline constexpr`, matching the neighboring attribute-name
// constants, so that all translation units including this header share one
// definition instead of each getting its own internal-linkage copy.
// LINT.IfChange
inline constexpr StringRef kCpuMeshSuffix = "/cpu";
// LINT.ThenChange(
// https://github.com/openxla/shardy/blob/main/shardy/integrations/python/jax/mpmd/types.py
// )

// Memory kind attributes.
// Attr on func args and results to indicate whether the value lives on host or
Expand Down
8 changes: 4 additions & 4 deletions shardy/integrations/python/jax/mpmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
from shardy.integrations.python.jax.mpmd.ops import named_computation
from shardy.integrations.python.jax.mpmd.ops import named_tensor
from shardy.integrations.python.jax.mpmd.ops import reduce
from shardy.integrations.python.jax.mpmd.pipeline import FragmentInfo
from shardy.integrations.python.jax.mpmd.pipeline import FragmentMergeRule
from shardy.integrations.python.jax.mpmd.pipeline import FragmentMergeRules
from shardy.integrations.python.jax.mpmd.pipeline import FragmentOrigin
from shardy.integrations.python.jax.mpmd.stages import MpmdCompiled as Compiled
from shardy.integrations.python.jax.mpmd.stages import MpmdExecutable as Executable
from shardy.integrations.python.jax.mpmd.stages import MpmdJitShardingInfo
from shardy.integrations.python.jax.mpmd.stages import MpmdLowered as Lowered
from shardy.integrations.python.jax.mpmd.types import FragmentInfo
from shardy.integrations.python.jax.mpmd.types import FragmentMergeRule
from shardy.integrations.python.jax.mpmd.types import FragmentMergeRules
from shardy.integrations.python.jax.mpmd.types import FragmentOrigin
from shardy.integrations.python.jax.mpmd.types import FunctionIOMeshAssignment
from shardy.integrations.python.jax.mpmd.types import make_config
from shardy.integrations.python.jax.mpmd.types import mesh_names
Expand Down
46 changes: 27 additions & 19 deletions shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,29 +169,35 @@ PartitioningResult MpmdProgram::ApplyPartitioning(PartitioningPhase phases) {

func::FuncOp main_func = GetMainFunction(module);
SetTopology(named_meshes, main_func);
SetArgDonationAttributes(main_func, donate_argnums);

// It is not necessary to do this
// validation after the export pipeline because here we're only checking that
// the attributes set on the main func are consistent with the received donate
// args.
VerifyOnlyDonatedArgsHaveDonationAttributes(main_func, donate_argnums);
if (phases & PartitioningPhase::kImport) {
SetArgDonationAttributes(main_func, donate_argnums);

SDY_LOG(INFO) << "Importing function named " << func_name
<< " for MPMD partitioning.";
// It is not necessary to do this validation after the export pipeline
// because here we're only checking that the attributes set on the main func
// are consistent with the received donate args.
VerifyOnlyDonatedArgsHaveDonationAttributes(main_func, donate_argnums);

Import(module);
SDY_LOG(INFO) << "Importing function named " << func_name
<< " for MPMD partitioning.";

SDY_LOG(INFO) << "Optimizing function named " << func_name
<< " for pipeline parallelism.";
Optimize(module);
Import(module);
}

if (phases & PartitioningPhase::kOptimize) {
SDY_LOG(INFO) << "Optimizing function named " << func_name
<< " for pipeline parallelism.";
Optimize(module);
}

SDY_LOG(INFO) << "Applying SDY propagation to function named " << func_name
<< ".";
PropagateSharding(module);
if (phases & PartitioningPhase::kPartition) {
SDY_LOG(INFO) << "Applying SDY propagation to function named " << func_name
<< ".";
PropagateSharding(module);

SDY_LOG(INFO) << "Exporting MPMD function named " << func_name << ".";
Export(module);
SDY_LOG(INFO) << "Exporting MPMD function named " << func_name << ".";
Export(module);
}

return PartitioningResult(module);
}
Expand All @@ -206,7 +212,8 @@ void MpmdProgram::Import(ModuleOp module) {
ConvertMeshVectorToMap(input_meshes)};
import_options.outputIndexToMeshAssignment = {
ConvertMeshVectorToMap(output_meshes)};
import_options.mergeAfterScheduling = options.mpmd_merge_after_scheduling;
import_options.mergeAfterScheduling =
options.mpmd_merge_inferred_after_scheduling;
import_options.absorbInferredFragmentsOnEntryPointFunction =
options.mpmd_absorb_inferred_fragments_on_entry_point_function;
import_options.cloneInferredFragments =
Expand All @@ -227,7 +234,8 @@ void MpmdProgram::Optimize(ModuleOp module) {

OptimizeOptions optimize_options;
optimize_options.fragmentMergeRules = llvm::to_vector(fragment_merge_rules);
optimize_options.mergeAfterScheduling = options.mpmd_merge_after_scheduling;
optimize_options.mergeAfterScheduling =
options.mpmd_merge_inferred_after_scheduling;
optimize_options.applyFragmentRemat = options.mpmd_fragment_remat;
optimize_options.mergeRematFragments = options.mpmd_merge_remat_fragments;
optimize_options.absorbInferredFragmentsOnEntryPointFunction =
Expand Down
3 changes: 2 additions & 1 deletion shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct PartitioningOptions {
bool mpmd_absorb_inferred_fragments_on_entry_point_function = false;
bool mpmd_copy_constant_creation_from_producer_to_consumer = false;
bool mpmd_apply_merge_transfers_pass = false;
bool mpmd_merge_after_scheduling = false;
bool mpmd_merge_inferred_after_scheduling = false;
};

PartitioningOptions ParsePartitioningOptions(
Expand All @@ -120,6 +120,7 @@ struct MpmdProgram {
const std::vector<std::optional<std::string>>& output_meshes;
const std::vector<int64_t>& donate_argnums;
const mlir::mpmd::FragmentMergeRules& fragment_merge_rules;
const mlir::mpmd::FragmentScheduleRules& fragment_schedule_rules;

// Runs the PartIR MPMD partitioning passes on the MPMD program.
//
Expand Down
117 changes: 117 additions & 0 deletions shardy/integrations/python/jax/mpmd/jaxlib_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2025 The MPMD Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for converting between native Python MPMD types and their jaxlib pybind equivalents."""

from jaxlib import _sdy_mpmd as jaxlib_mpmd

from shardy.integrations.python.jax.mpmd import pipeline


def _to_jaxlib_split_type(
    split_type: pipeline.SplitFragmentType | None,
) -> jaxlib_mpmd.SplitFragmentType | None:
  """Maps a native `SplitFragmentType` onto its pybind counterpart.

  Args:
    split_type: Native Python enum member, or None.

  Returns:
    The `jaxlib_mpmd.SplitFragmentType` member with the same name, or None
    when `split_type` is None.

  Raises:
    ValueError: If `split_type` is neither None nor one of the handled members.
  """
  if split_type is None:
    return None

  native = pipeline.SplitFragmentType
  bound = jaxlib_mpmd.SplitFragmentType
  if split_type == native.KEEP_TRANSFERRED:
    return bound.KEEP_TRANSFERRED
  if split_type == native.DROP_TRANSFERRED:
    return bound.DROP_TRANSFERRED
  raise ValueError(f'Unknown SplitFragmentType: {split_type}')


def _from_jaxlib_split_type(
    split_type: jaxlib_mpmd.SplitFragmentType | None,
) -> pipeline.SplitFragmentType | None:
  """Maps a pybind `SplitFragmentType` back onto the native Python enum.

  Args:
    split_type: Pybind enum member, or None.

  Returns:
    The `pipeline.SplitFragmentType` member with the same name, or None when
    `split_type` is None.

  Raises:
    ValueError: If `split_type` is neither None nor one of the handled members.
  """
  if split_type is None:
    return None

  bound = jaxlib_mpmd.SplitFragmentType
  native = pipeline.SplitFragmentType
  if split_type == bound.KEEP_TRANSFERRED:
    return native.KEEP_TRANSFERRED
  if split_type == bound.DROP_TRANSFERRED:
    return native.DROP_TRANSFERRED
  raise ValueError(f'Unknown jaxlib_mpmd.SplitFragmentType: {split_type}')


def convert_fragment_info_to_pybind(
    fragment: pipeline.FragmentInfo,
) -> jaxlib_mpmd.FragmentInfo:
  """Builds the pybind `FragmentInfo` equivalent of a native `FragmentInfo`.

  Args:
    fragment: Native fragment description to convert.

  Returns:
    A `jaxlib_mpmd.FragmentInfo` whose origins (as a list), stage id, call
    counter, split type and mesh name are copied from `fragment`.

  Raises:
    ValueError: Propagated from the split-type conversion when
      `fragment.split_type` is not a handled member.
  """
  # Origins are converted first, one pybind FragmentOrigin per native origin,
  # keeping their order.
  pybind_origins = []
  for origin in fragment.origins:
    pybind_origins.append(
        jaxlib_mpmd.FragmentOrigin(
            origin.computation_name, origin.transpose_count
        )
    )

  return jaxlib_mpmd.FragmentInfo(
      origins=pybind_origins,
      stage_id=fragment.stage_id,
      call_counter=fragment.call_counter,
      split_type=_to_jaxlib_split_type(fragment.split_type),
      mesh_name=fragment.mesh_name,
  )


def convert_pybind_fragment_info_to_types(
    fragment: jaxlib_mpmd.FragmentInfo,
) -> pipeline.FragmentInfo:
  """Builds the native `FragmentInfo` equivalent of a pybind `FragmentInfo`.

  Args:
    fragment: Pybind fragment description to convert.

  Returns:
    A `pipeline.FragmentInfo` whose origins (as a tuple), stage id, call
    counter, split type and mesh name are copied from `fragment`.

  Raises:
    ValueError: Propagated from the split-type conversion when
      `fragment.split_type` is not a handled member.
  """
  # Origins are converted first, preserving order, then frozen into a tuple.
  native_origins = tuple([
      pipeline.FragmentOrigin(origin.computation_name, origin.transpose_count)
      for origin in fragment.origins
  ])

  return pipeline.FragmentInfo(
      origins=native_origins,
      stage_id=fragment.stage_id,
      call_counter=fragment.call_counter,
      split_type=_from_jaxlib_split_type(fragment.split_type),
      mesh_name=fragment.mesh_name,
  )


def convert_fragment_merge_rules_to_pybind(
    fragment_merge_rules: pipeline.FragmentMergeRules,
) -> list[jaxlib_mpmd.FragmentMergeRule]:
  """Converts native fragment merge rules into pybind `FragmentMergeRule`s.

  Args:
    fragment_merge_rules: Iterable of native merge rules.

  Returns:
    A list with one `jaxlib_mpmd.FragmentMergeRule` per input rule, in
    iteration order. Each rule's sources are converted (in the iteration order
    of `rule.sources`) before its target.
  """
  return [
      jaxlib_mpmd.FragmentMergeRule(
          sources=[
              convert_fragment_info_to_pybind(source)
              for source in rule.sources
          ],
          target=convert_fragment_info_to_pybind(rule.target),
      )
      for rule in fragment_merge_rules
  ]


def convert_fragment_schedule_rules_to_pybind(
    fragment_schedule_rules: pipeline.FragmentScheduleRules,
) -> list[jaxlib_mpmd.FragmentScheduleRule]:
  """Converts native fragment schedule rules into pybind `FragmentScheduleRule`s.

  Args:
    fragment_schedule_rules: Iterable of native schedule rules.

  Returns:
    A list with one `jaxlib_mpmd.FragmentScheduleRule` per input rule, in
    iteration order. The order of each rule's `ordered_fragments` is kept.
  """
  return [
      jaxlib_mpmd.FragmentScheduleRule(
          ordered_fragments=[
              convert_fragment_info_to_pybind(fragment)
              for fragment in rule.ordered_fragments
          ]
      )
      for rule in fragment_schedule_rules
  ]
146 changes: 146 additions & 0 deletions shardy/integrations/python/jax/mpmd/jaxlib_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2025 The MPMD Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for jaxlib conversion utilities."""

from absl.testing import absltest
from absl.testing import parameterized
from jaxlib import _sdy_mpmd as jaxlib_mpmd

from shardy.integrations.python.jax.mpmd import jaxlib_utils
from shardy.integrations.python.jax.mpmd import pipeline


class SplitTypeConversionTest(parameterized.TestCase):
  """Checks the private SplitFragmentType converters in both directions."""

  @parameterized.named_parameters(
      dict(
          testcase_name='keep_transferred',
          python_val=pipeline.SplitFragmentType.KEEP_TRANSFERRED,
          pybind_val=jaxlib_mpmd.SplitFragmentType.KEEP_TRANSFERRED,
      ),
      dict(
          testcase_name='drop_transferred',
          python_val=pipeline.SplitFragmentType.DROP_TRANSFERRED,
          pybind_val=jaxlib_mpmd.SplitFragmentType.DROP_TRANSFERRED,
      ),
      dict(testcase_name='none', python_val=None, pybind_val=None),
  )
  def test_bidirectional_conversion(self, python_val, pybind_val):
    """Native -> pybind and pybind -> native both yield the paired value."""
    as_pybind = jaxlib_utils._to_jaxlib_split_type(python_val)
    self.assertEqual(as_pybind, pybind_val)

    as_python = jaxlib_utils._from_jaxlib_split_type(pybind_val)
    self.assertEqual(as_python, python_val)


class FragmentInfoConversionTest(parameterized.TestCase):
  """Checks that FragmentInfo survives a native -> pybind -> native trip."""

  @parameterized.named_parameters(
      dict(
          testcase_name='single_origin',
          fragment=pipeline.FragmentInfo(
              origins=(pipeline.FragmentOrigin('comp1', 0),),
              mesh_name='mesh1',
          ),
      ),
      dict(
          testcase_name='multiple_origins',
          fragment=pipeline.FragmentInfo(
              origins=(
                  pipeline.FragmentOrigin('comp1', 0),
                  pipeline.FragmentOrigin('comp2', 1),
              ),
              mesh_name='mesh1',
          ),
      ),
      dict(
          testcase_name='all_fields',
          fragment=pipeline.FragmentInfo(
              origins=(pipeline.FragmentOrigin('comp1', 2),),
              stage_id=5,
              call_counter=3,
              split_type=pipeline.SplitFragmentType.KEEP_TRANSFERRED,
              mesh_name='mesh2',
          ),
      ),
  )
  def test_roundtrip(self, fragment):
    """Converting to pybind and back must compare equal to the input."""
    as_pybind = jaxlib_utils.convert_fragment_info_to_pybind(fragment)
    restored = jaxlib_utils.convert_pybind_fragment_info_to_types(as_pybind)
    self.assertEqual(restored, fragment)


class FragmentMergeRulesConversionTest(absltest.TestCase):
  """Checks conversion of FragmentMergeRule lists to their pybind form."""

  def test_single_rule(self):
    """One rule with two sources converts to one pybind rule of same shape."""
    origin_f1 = pipeline.FragmentOrigin('f1', 0)
    origin_f2 = pipeline.FragmentOrigin('f2', 0)
    source_f1 = pipeline.FragmentInfo(origins=(origin_f1,), mesh_name='m1')
    source_f2 = pipeline.FragmentInfo(origins=(origin_f2,), mesh_name='m1')
    merged = pipeline.FragmentInfo(
        origins=(origin_f1, origin_f2), mesh_name='m1'
    )
    rule = pipeline.FragmentMergeRule(
        sources={source_f1, source_f2}, target=merged
    )

    converted = jaxlib_utils.convert_fragment_merge_rules_to_pybind([rule])

    self.assertLen(converted, 1)
    only_rule = converted[0]
    self.assertLen(only_rule.sources, 2)
    self.assertLen(only_rule.target.origins, 2)


class FragmentScheduleRulesConversionTest(absltest.TestCase):
  """Checks conversion of FragmentScheduleRule lists to their pybind form."""

  def test_preserves_order(self):
    """Converted fragments appear in the same order as `ordered_fragments`."""
    names = ('first', 'second', 'third')
    frags = [
        pipeline.FragmentInfo(
            origins=(pipeline.FragmentOrigin(name, 0),), mesh_name='m1'
        )
        for name in names
    ]
    rule = pipeline.FragmentScheduleRule(ordered_fragments=frags)

    result = jaxlib_utils.convert_fragment_schedule_rules_to_pybind([rule])

    # Compare position by position so each name is checked at its own index.
    for position, expected_name in enumerate(names):
      self.assertEqual(
          result[0].ordered_fragments[position].origins[0].computation_name,
          expected_name,
      )


if __name__ == '__main__':
absltest.main()
Loading
Loading