The -insert-explicit-reshards pass inserts reshards on the inputs of a concatenate when its output is sharded; however, this produces a mathematically incorrect result.
For example:
module {
sdy.mesh @mesh = <["_axis_0_updated"=1, "_axis_0"=8]>
func.func @concatenate_reshape_test(%arg0: tensor<128x2880xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}, %arg1: tensor<32x2880x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0"}, {}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x128x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0", ?}, {?}, {?}]>, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
%0 = stablehlo.concatenate %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}]>]>} : (tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>) -> tensor<4096x2880xbf16>
%1 = stablehlo.reshape %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>} : (tensor<4096x2880xbf16>) -> tensor<32x128x2880xbf16>
%2 = stablehlo.dot_general %1, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>} : (tensor<32x128x2880xbf16>, tensor<32x2880x5760xbf16>) -> tensor<32x128x5760xbf16>
return %2 : tensor<32x128x5760xbf16>
}
}
This results in 32 reshard operations on %arg0 and looks like this:
module {
sdy.mesh @mesh = <["_axis_0_updated"=1, "_axis_0"=8]>
func.func @concatenate_reshape_test(%arg0: tensor<128x2880xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}, %arg1: tensor<32x2880x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0"}, {}, {}]>, sdy.sharding_origins = {_axis_0 = "self"}, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x128x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0", ?}, {?}, {?}]>, sdy.sharding_origins = {_axis_0 = "input: 1"}, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
%0 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%1 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%2 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%3 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%4 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%5 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%6 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%7 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%8 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%9 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%10 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%11 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%12 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%13 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%14 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%15 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%16 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%17 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%18 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%19 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%20 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%21 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%22 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%23 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%24 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%25 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%26 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%27 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%28 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%29 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%30 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%31 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
%32 = stablehlo.concatenate %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}]>]>, sdy.sharding_origins = [{_axis_0 = "input: 1"}]} : (tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>) -> tensor<4096x2880xbf16>
%33 = stablehlo.reshape %32 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>, sdy.sharding_origins = [{_axis_0 = "input: 1"}]} : (tensor<4096x2880xbf16>) -> tensor<32x128x2880xbf16>
%34 = stablehlo.dot_general %33, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>, sdy.sharding_origins = [{_axis_0 = "input: 1"}]} : (tensor<32x128x2880xbf16>, tensor<32x2880x5760xbf16>) -> tensor<32x128x5760xbf16>
return %34 : tensor<32x128x5760xbf16>
}
}
However, this is incorrect. The reshard should instead be inserted on the output of the concatenate:
module {
sdy.mesh @mesh = <["_axis_0_updated"=1, "_axis_0"=8]>
func.func @concatenate_reshape_test(%arg0: tensor<128x2880xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}, %arg1: tensor<32x2880x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0"}, {}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x128x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0", ?}, {?}, {?}]>, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
%0 = stablehlo.concatenate %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>) -> tensor<4096x2880xbf16>
%1 = sdy.reshard %0 <@mesh, [{"_axis_0"}, {}]> : tensor<4096x2880xbf16>
%2 = stablehlo.reshape %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>} : (tensor<4096x2880xbf16>) -> tensor<32x128x2880xbf16>
%3 = stablehlo.dot_general %2, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>} : (tensor<32x128x2880xbf16>, tensor<32x2880x5760xbf16>) -> tensor<32x128x5760xbf16>
return %3 : tensor<32x128x5760xbf16>
}
}
I wrote a patch in shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc specifically for the concatenate op and can put up a PR for it. However, I wanted to check whether this behaviour is intended, or whether there are other ways of achieving the reshard on the output (instead of the inputs).
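One alternative I have not verified end to end (so treat this as a rough sketch, not a confirmed workaround): instead of annotating the concatenate op itself, express the desired sharding with an sdy.sharding_constraint on the concatenate result, so that the sharded value is anchored to the output rather than to the operands. A scaled-down, hypothetical example (4 operands instead of 32; the mesh, function name, and the assumption that the constraint ends up materialized as a single output-side reshard are mine, not from the original module):
module {
sdy.mesh @mesh = <["_axis_0"=8]>
func.func @concat_constraint_sketch(%arg0: tensor<128x2880xbf16>) -> tensor<512x2880xbf16> {
// The concatenate itself carries no sdy.sharding annotation.
%0 = stablehlo.concatenate %arg0, %arg0, %arg0, %arg0, dim = 0 : (tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>) -> tensor<512x2880xbf16>
// Hypothetical: constrain the result value, hoping the pass materializes one reshard here instead of 4 (or 32) on the operands.
%1 = sdy.sharding_constraint %0 <@mesh, [{"_axis_0"}, {}]> : tensor<512x2880xbf16>
return %1 : tensor<512x2880xbf16>
}
}
I do not know whether this interacts well with propagation in the full example above, which is part of why I am asking.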