Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions shardy/dialect/mpmd/ir/dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,19 +258,18 @@ RankedTensorType MeshTensorType::getLocalTensorType(sdy::MeshAttr sdy_mesh) {
}

RankedTensorType MeshTensorType::getLocalTensorType(Operation* op) {
sdy::TensorShardingAttr sharding = getSharding();
if (!sharding) {
return getGlobalTensorType();
}
auto func_op = sdy::getEnclosingOfType<FuncOp>(op);
if (HasHomogeneousTopology(func_op)) {
// TODO(b/439770762): Remove this once we have correct global meshes.
return MeshTensorType::getLocalTensorType(
GetTopologyMeshes(func_op).front().getMesh());
}
sdy::TensorShardingAttr sharding = getSharding();
if (!sharding) {
return getGlobalTensorType();
}
// TODO(b/441487083): Look up the mesh in the global mesh registry.
return MeshTensorType::getLocalTensorType(
GetMeshOrFail(op, sharding.getMeshName()));
sdy::getMeshOp(op, sharding.getMeshName()).getMeshAttr());
}

// Functions for the ShapedTypeInterface.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,77 +1,99 @@
// RUN: mpmd_opt %s -mpmd-mark-fragment-reserved-memory 2>&1 | FileCheck %s
// RUN: mpmd_opt %s -mpmd-mark-fragment-reserved-memory -split-input-file \
// RUN:   2>&1 | FileCheck %s

// NOTE:
// - mesh_1_tensor and mesh_2_tensor are each 128 bytes.
// - mesh_1_tensor_dist_x is 32 bytes.
// - mesh_1_tensor is 128 bytes.
!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
!mesh_1_tensor_dist_x = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>, sharding=<@m1, [{"x"}, {?}]>>
!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>>

// CHECK-LABEL: func @single_mesh
func.func @single_mesh(%arg0: !mesh_1_tensor, %arg1: !mesh_1_tensor)
-> (!mesh_1_tensor) attributes {
"topology"=#mpmd.topology<
<"m1": <["x"=4]>>>} {
// CHECK-LABEL: module @single_mesh
module @single_mesh {
sdy.mesh @mesh = <["x"=4]>

func.func @main(%arg0: !mesh_1_tensor, %arg1: !mesh_1_tensor)
-> (!mesh_1_tensor) attributes {
"topology"=#mpmd.topology<
<"m1": <["x"=4]>>>} {

// Fragment only takes inputs from the function, no intermediates, so 0.
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 0
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0, %arg1)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor

// Fragment takes one input from the function and one intermediate, so 128 due
// to %arg0 (still live, but not an operand of this fragment).
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 128
%1 = mpmd.fragment<mesh="m1", origin=["f2"]> (%0, %arg1)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%4 = stablehlo.multiply %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor

// Fragment takes two intermediates, so 256 due to %arg0 and %arg1
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 256
%2 = mpmd.fragment<mesh="m1", origin=["f3"]> (%0, %1)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor

// Fragment takes an input and an intermediate %2. Note %0's and %1's last use
// was in fragment f3, so only %arg1 is counted: 128.
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 128
%3 = mpmd.fragment<mesh="m1", origin=["f4"]> (%arg0, %2)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor

func.return %3 : !mesh_1_tensor
}
}

// Fragment only takes inputs from the function, no intermediates, so 0.
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 0
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0, %arg1)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor

// Fragment takes one input from the function and one intermediate, so 128 due
// to %arg0 (still live, but not an operand of this fragment).
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 128
%1 = mpmd.fragment<mesh="m1", origin=["f2"]> (%0, %arg1)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%4 = stablehlo.multiply %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor
// -----

// Fragment takes two intermediates, so 256 due to %arg0 and %arg1
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 256
%2 = mpmd.fragment<mesh="m1", origin=["f3"]> (%0, %1)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor

// Fragment takes an input and an intermediate %2. Note %0's and %1's last use
// was in fragment f3, so only %arg1 is counted: 128.
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 128
%3 = mpmd.fragment<mesh="m1", origin=["f4"]> (%arg0, %2)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%4 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %4 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor
!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>

func.return %3 : !mesh_1_tensor
// CHECK-LABEL: module @duplicate_input
module @duplicate_input {
sdy.mesh @mesh = <["x"=4]>

// Make sure we don't subtract twice from live memory usage if a fragment takes
// two of the same inputs.
// CHECK-LABEL: func @duplicate_input
func.func @main(%arg0: !mesh_1_tensor)
-> (!mesh_1_tensor) attributes {
"topology"=#mpmd.topology<
<"m1": <["x"=4]>>>} {

// Fragment only takes inputs from the function, no intermediates, so 0.
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 0
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0, %arg0)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%1 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %1 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor

func.return %0 : !mesh_1_tensor
}
}

// Make sure we don't subtract twice from live memory usage if a fragment takes
// two of the same inputs.
// CHECK-LABEL: func @duplicate_input
func.func @duplicate_input(%arg0: !mesh_1_tensor)
-> (!mesh_1_tensor) attributes {
"topology"=#mpmd.topology<
<"m1": <["x"=4]>>>} {

// Fragment only takes inputs from the function, no intermediates, so 0.
// CHECK: mpmd.fragment
// CHECK-SAME: xla_tpu_user_reserved_hbm_bytes = 0
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0, %arg0)
(%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) {
%1 = stablehlo.add %arg2, %arg3 : tensor<4x8xf32>
mpmd.return %1 : tensor<4x8xf32>
} : (!mesh_1_tensor, !mesh_1_tensor) -> !mesh_1_tensor
// -----

func.return %0 : !mesh_1_tensor
}

!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>

module {
sdy.mesh @mesh = <["x"=4]>

// CHECK-LABEL: func @offloaded_value
func.func @offloaded_value(%arg0: !mesh_1_tensor, %arg1: !mesh_1_tensor)
Expand Down Expand Up @@ -112,14 +134,27 @@ func.func @offloaded_value(%arg0: !mesh_1_tensor, %arg1: !mesh_1_tensor)

func.return %2 : !mesh_1_tensor
}
}


// -----


// NOTE:
// - mesh_1_tensor and mesh_2_tensor are each 128 bytes.
!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4x8xf32>>

module {
sdy.mesh @mesh = <["x"=4]>

// Same test as `@single_mesh` but now with some tensors existing on other
// meshes.
// CHECK-LABEL: func @multiple_meshes
func.func @multiple_meshes(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor)
-> (!mesh_2_tensor) attributes {
"topology"=#mpmd.topology<
<"m1": <["x"=4]>>, <"m2": <["x"=2]>>>} {
<"m1": <["x"=4]>>, <"m2": <["y"=2]>>>} {

%0 = mpmd.transfer %arg0 : (!mesh_1_tensor) -> !mesh_2_tensor
%1 = mpmd.transfer %arg1 : (!mesh_2_tensor) -> !mesh_1_tensor
Expand Down Expand Up @@ -166,6 +201,20 @@ func.func @multiple_meshes(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor)

func.return %7 : !mesh_2_tensor
}
}


// -----


// NOTE:
// - mesh_1_tensor is 128 bytes.
// - mesh_1_tensor_dist_x is 32 bytes.
!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>
!mesh_1_tensor_dist_x = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>, sharding=<@m1, [{"x"}, {?}]>>

module {
sdy.mesh @mesh = <["x"=4, "y"=2]>

// Tests that the pass accounts for the per device local shape, not global
// shape, if the tensor is distributed.
Expand Down Expand Up @@ -208,6 +257,16 @@ func.func @distributed_tensor(%arg0: !mesh_1_tensor)

func.return %3 : !mesh_1_tensor
}
}


// -----


!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>

module {
sdy.mesh @mesh = <["x"=2]>

// Test that verifies the unused output of a fragment is not accounted for in
// the live buffers.
Expand Down Expand Up @@ -239,6 +298,16 @@ func.func @unused_fragment_result_is_not_counted(

return %1, %arg3 : !mesh_1_tensor, !mesh_1_tensor
}
}


// -----


!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>

module {
sdy.mesh @mesh = <["x"=2]>

// Test that verifies a donated program argument is not accounted for after its
// last use. The test verifies both jax.buffer_donor and tf.aliasing_output
Expand Down Expand Up @@ -282,6 +351,17 @@ func.func @donated_program_arg_is_not_counted_after_last_use(

return %1, %2 : !mesh_1_tensor, !mesh_1_tensor
}
}


// -----


!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>

module {
sdy.mesh @mesh = <["x"=4]>


// Test that verifies args on hosts or donated args are not accounted for.
// CHECK-LABEL: func @offloaded_or_unused_donated_args_are_not_counted
Expand Down Expand Up @@ -313,6 +393,16 @@ func.func @offloaded_or_unused_donated_args_are_not_counted(

func.return %0#0, %1#0 : !mesh_1_tensor, !mesh_1_tensor
}
}


// -----


!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>>

module {
sdy.mesh @mesh = <["x"=4]>

// CHECK-LABEL: func @unused_input_not_donated
func.func @unused_input_not_donated(%arg0: !mesh_1_tensor, %unused_arg1: !mesh_1_tensor)
Expand All @@ -330,3 +420,4 @@ func.func @unused_input_not_donated(%arg0: !mesh_1_tensor, %unused_arg1: !mesh_1

func.return %0 : !mesh_1_tensor
}
}
Loading