From 501be9ebd6d605d52ce9f42ccdf1c4edfead5637 Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Tue, 2 Sep 2025 05:48:52 -0700 Subject: [PATCH] [mpmd] Use global mesh. PiperOrigin-RevId: 802106940 --- shardy/dialect/mpmd/ir/dialect.cc | 11 +- .../test/mark_fragment_reserved_memory.mlir | 215 +++++++++++++----- 2 files changed, 158 insertions(+), 68 deletions(-) diff --git a/shardy/dialect/mpmd/ir/dialect.cc b/shardy/dialect/mpmd/ir/dialect.cc index 5fc8c90a9..ea39563a4 100644 --- a/shardy/dialect/mpmd/ir/dialect.cc +++ b/shardy/dialect/mpmd/ir/dialect.cc @@ -258,19 +258,18 @@ RankedTensorType MeshTensorType::getLocalTensorType(sdy::MeshAttr sdy_mesh) { } RankedTensorType MeshTensorType::getLocalTensorType(Operation* op) { + sdy::TensorShardingAttr sharding = getSharding(); + if (!sharding) { + return getGlobalTensorType(); + } auto func_op = sdy::getEnclosingOfType(op); if (HasHomogeneousTopology(func_op)) { // TODO(b/439770762): Remove this once we have correct global meshes. return MeshTensorType::getLocalTensorType( GetTopologyMeshes(func_op).front().getMesh()); } - sdy::TensorShardingAttr sharding = getSharding(); - if (!sharding) { - return getGlobalTensorType(); - } - // TODO(b/441487083): Look up the mesh in the global mesh registry. return MeshTensorType::getLocalTensorType( - GetMeshOrFail(op, sharding.getMeshName())); + sdy::getMeshOp(op, sharding.getMeshName()).getMeshAttr()); } // Functions for the ShapedTypeInterface. 
diff --git a/shardy/dialect/mpmd/transforms/export/test/mark_fragment_reserved_memory.mlir b/shardy/dialect/mpmd/transforms/export/test/mark_fragment_reserved_memory.mlir index 73b7753f3..6b83f640b 100644 --- a/shardy/dialect/mpmd/transforms/export/test/mark_fragment_reserved_memory.mlir +++ b/shardy/dialect/mpmd/transforms/export/test/mark_fragment_reserved_memory.mlir @@ -1,77 +1,99 @@ -// RUN: mpmd_opt %s -mpmd-mark-fragment-reserved-memory 2>&1 | FileCheck %s +// RUN: mpmd_opt %s -mpmd-mark-fragment-reserved-memory -split-input-file \ +// RUN: 2>&1 | FileCheck %s // NOTE: -// - mesh_1_tensor and mesh_2_tensor is 128 bytes. -// - mesh_1_tensor_dist_x is 32 bytes. +// - mesh_1_tensor is 128 bytes. !mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> -!mesh_1_tensor_dist_x = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>, sharding=<@m1, [{"x"}, {?}]>> -!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> -// CHECK-LABEL: func @single_mesh -func.func @single_mesh(%arg0: !mesh_1_tensor, %arg1: !mesh_1_tensor) - -> (!mesh_1_tensor) attributes { - "topology"=#mpmd.topology< - <"m1": <["x"=4]>>>} { +// CHECK-LABEL: module @single_mesh +module @single_mesh { + sdy.mesh @mesh = <["x"=4]> + + func.func @main(%arg0: !mesh_1_tensor, %arg1: !mesh_1_tensor) + -> (!mesh_1_tensor) attributes { + "topology"=#mpmd.topology< + <"m1": <["x"=4]>>>} { + + // Fragment only takes inputs from the function, no intermediates, so 0. + // CHECK: mpmd.fragment + // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 0 + %0 = mpmd.fragment (%arg0, %arg1) + (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { + %4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32> + mpmd.return %4 : tensor<4x8xf32> + } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor + + // Fragment takes one input from the function and one intermediate, so 128 due to + // %arg0. 
+ // CHECK: mpmd.fragment + // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 128 + %1 = mpmd.fragment (%0, %arg1) + (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { + %4 = stablehlo.multiply %arg2, %arg3 : tensor<4x8xf32> + mpmd.return %4 : tensor<4x8xf32> + } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor + + // Fragment takes two intermediates, so 256 due to %arg0 and %arg1 + // CHECK: mpmd.fragment + // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 256 + %2 = mpmd.fragment (%0, %1) + (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { + %4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32> + mpmd.return %4 : tensor<4x8xf32> + } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor + + // Fragment takes an input and an intermediate %2. Note %0's and %1's last use + // CHECK: mpmd.fragment + // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 128 + %3 = mpmd.fragment (%arg0, %2) + (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { + %4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32> + mpmd.return %4 : tensor<4x8xf32> + } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor + + func.return %3 : !mesh_1_tensor + } +} - // Fragment only takes inputs from the function, no intermediates, so 0. - // CHECK: mpmd.fragment - // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 0 - %0 = mpmd.fragment (%arg0, %arg1) - (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { - %4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32> - mpmd.return %4 : tensor<4x8xf32> - } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor - // Fragment takes one input from the function and one intermediates, so 128 due to - // %arg0. 
- // CHECK: mpmd.fragment - // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 128 - %1 = mpmd.fragment (%0, %arg1) - (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { - %4 = stablehlo.multiply %arg2, %arg3 : tensor<4x8xf32> - mpmd.return %4 : tensor<4x8xf32> - } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor +// ----- - // Fragment takes two intermediates, so 256 due to %arg0 and %arg1 - // CHECK: mpmd.fragment - // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 256 - %2 = mpmd.fragment (%0, %1) - (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { - %4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32> - mpmd.return %4 : tensor<4x8xf32> - } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor - // Fragment takes an input and a intermediate %2. Note %0's and %1's last use - // CHECK: mpmd.fragment - // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 128 - %3 = mpmd.fragment (%arg0, %2) - (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { - %4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32> - mpmd.return %4 : tensor<4x8xf32> - } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> - func.return %3 : !mesh_1_tensor +// CHECK-LABEL: module @duplicate_input +module @duplicate_input { + sdy.mesh @mesh = <["x"=4]> + + // Make sure we don't subtract twice from live memory usage if a fragment takes + // two of the same inputs. + // CHECK-LABEL: func @main + func.func @main(%arg0: !mesh_1_tensor) + -> (!mesh_1_tensor) attributes { + "topology"=#mpmd.topology< + <"m1": <["x"=4]>>>} { + + // Fragment only takes inputs from the function, no intermediates, so 0. 
+ // CHECK: mpmd.fragment + // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 0 + %0 = mpmd.fragment (%arg0, %arg0) + (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { + %1 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32> + mpmd.return %1 : tensor<4x8xf32> + } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor + + func.return %0 : !mesh_1_tensor + } } -// Make sure we don't subtract twice from live memory usage if a fragment takes -// two of the same inputs. -// CHECK-LABEL: func @duplicate_input -func.func @duplicate_input(%arg0: !mesh_1_tensor) - -> (!mesh_1_tensor) attributes { - "topology"=#mpmd.topology< - <"m1": <["x"=4]>>>} { - // Fragment only takes inputs from the function, no intermediates, so 0. - // CHECK: mpmd.fragment - // CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 0 - %0 = mpmd.fragment (%arg0, %arg0) - (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { - %1 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32> - mpmd.return %1 : tensor<4x8xf32> - } : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor +// ----- - func.return %0 : !mesh_1_tensor -} + +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + +module { + sdy.mesh @mesh = <["x"=4]> // CHECK-LABEL: func @offloaded_value func.func @offloaded_value(%arg0: !mesh_1_tensor, %arg1: !mesh_1_tensor) @@ -112,6 +134,19 @@ func.func @offloaded_value(%arg0: !mesh_1_tensor, %arg1: !mesh_1_tensor) func.return %2 : !mesh_1_tensor } +} + + +// ----- + + +// NOTE: +// - mesh_1_tensor and mesh_2_tensor are 128 bytes each. +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> +!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> + +module { + sdy.mesh @mesh = <["x"=4]> // Same test as `@single_mesh` but now with some tensors existing on other // meshes. 
@@ -119,7 +154,7 @@ func.func @offloaded_value(%arg0: !mesh_1_tensor, %arg1: !mesh_1_tensor) func.func @multiple_meshes(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mesh_2_tensor) attributes { "topology"=#mpmd.topology< - <"m1": <["x"=4]>>, <"m2": <["x"=2]>>>} { + <"m1": <["x"=4]>>, <"m2": <["y"=2]>>>} { %0 = mpmd.transfer %arg0 : (!mesh_1_tensor) -> !mesh_2_tensor %1 = mpmd.transfer %arg1 : (!mesh_2_tensor) -> !mesh_1_tensor @@ -166,6 +201,20 @@ func.func @multiple_meshes(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) func.return %7 : !mesh_2_tensor } +} + + +// ----- + + +// NOTE: +// - mesh_1_tensor is 128 bytes. +// - mesh_1_tensor_dist_x is 32 bytes. +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> +!mesh_1_tensor_dist_x = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>, sharding=<@m1, [{"x"}, {?}]>> + +module { + sdy.mesh @mesh = <["x"=4, "y"=2]> // Tests that the pass accounts for the per device local shape, not global // shape, if the tensor is distributed. @@ -208,6 +257,16 @@ func.func @distributed_tensor(%arg0: !mesh_1_tensor) func.return %3 : !mesh_1_tensor } +} + + +// ----- + + +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + +module { + sdy.mesh @mesh = <["x"=2]> // Test that verifies the unused output of a fragment is not accounted for in // the live buffers. @@ -239,6 +298,16 @@ func.func @unused_fragment_result_is_not_counted( return %1, %arg3 : !mesh_1_tensor, !mesh_1_tensor } +} + + +// ----- + + +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + +module { + sdy.mesh @mesh = <["x"=2]> // Test that verifies a donated program argument is not accounted for after its // last use. 
The test verifies both jax.buffer_donor and tf.aliasing_output @@ -282,6 +351,17 @@ func.func @donated_program_arg_is_not_counted_after_last_use( return %1, %2 : !mesh_1_tensor, !mesh_1_tensor } +} + + +// ----- + + +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + +module { + sdy.mesh @mesh = <["x"=4]> + // Test that verifies args on hosts or donated args are not accounted for. // CHECK-LABEL: func @offloaded_or_unused_donated_args_are_not_counted @@ -313,6 +393,16 @@ func.func @offloaded_or_unused_donated_args_are_not_counted( func.return %0#0, %1#0 : !mesh_1_tensor, !mesh_1_tensor } +} + + +// ----- + + +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + +module { + sdy.mesh @mesh = <["x"=4]> // CHECK-LABEL: func @unused_input_not_donated func.func @unused_input_not_donated(%arg0: !mesh_1_tensor, %unused_arg1: !mesh_1_tensor) @@ -330,3 +420,4 @@ func.func @unused_input_not_donated(%arg0: !mesh_1_tensor, %unused_arg1: !mesh_1 func.return %0 : !mesh_1_tensor } +}