
Unintended InsertExplicitReshardsPass behaviour with stablehlo.concatenate #945

@ddilbazTT

Description


-insert-explicit-reshards inserts reshards on the inputs of stablehlo.concatenate when the output is sharded; however, this produces a mathematically incorrect result.

For example:

module {
  sdy.mesh @mesh = <["_axis_0_updated"=1, "_axis_0"=8]>
  func.func @concatenate_reshape_test(%arg0: tensor<128x2880xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}, %arg1: tensor<32x2880x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0"}, {}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x128x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0", ?}, {?}, {?}]>, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
    %0 = stablehlo.concatenate %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}]>]>} : (tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>) -> tensor<4096x2880xbf16>
    %1 = stablehlo.reshape %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>} : (tensor<4096x2880xbf16>) -> tensor<32x128x2880xbf16>
    %2 = stablehlo.dot_general %1, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>} : (tensor<32x128x2880xbf16>, tensor<32x2880x5760xbf16>) -> tensor<32x128x5760xbf16>
    return %2 : tensor<32x128x5760xbf16>
  }
}

This results in 32 reshard operations on %arg0, and the module looks like this:

module {
  sdy.mesh @mesh = <["_axis_0_updated"=1, "_axis_0"=8]>
  func.func @concatenate_reshape_test(%arg0: tensor<128x2880xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}, %arg1: tensor<32x2880x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0"}, {}, {}]>, sdy.sharding_origins = {_axis_0 = "self"}, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x128x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0", ?}, {?}, {?}]>, sdy.sharding_origins = {_axis_0 = "input: 1"}, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
    %0 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %1 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %2 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %3 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %4 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %5 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %6 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %7 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %8 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %9 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %10 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %11 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %12 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %13 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %14 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %15 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %16 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %17 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %18 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %19 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %20 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %21 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %22 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %23 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %24 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %25 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %26 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %27 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %28 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %29 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %30 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %31 = sdy.reshard %arg0 <@mesh, [{"_axis_0"}, {}]> : tensor<128x2880xbf16>
    %32 = stablehlo.concatenate %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}]>]>, sdy.sharding_origins = [{_axis_0 = "input: 1"}]} : (tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>) -> tensor<4096x2880xbf16>
    %33 = stablehlo.reshape %32 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>, sdy.sharding_origins = [{_axis_0 = "input: 1"}]} : (tensor<4096x2880xbf16>) -> tensor<32x128x2880xbf16>
    %34 = stablehlo.dot_general %33, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>, sdy.sharding_origins = [{_axis_0 = "input: 1"}]} : (tensor<32x128x2880xbf16>, tensor<32x2880x5760xbf16>) -> tensor<32x128x5760xbf16>
    return %34 : tensor<32x128x5760xbf16>
  }
}

However, this is incorrect: once each operand is resharded along dim 0, every device holds only a 16-row slice of %arg0 per operand, so concatenating those local shards yields 32 repetitions of that slice rather than the dim-0 shard of the full 4096x2880 result (device d should instead hold four complete copies of %arg0, i.e. rows 512*d to 512*(d+1) of the concatenation). The reshard should instead be applied to the output of the concatenate:

module {
  sdy.mesh @mesh = <["_axis_0_updated"=1, "_axis_0"=8]>
  func.func @concatenate_reshape_test(%arg0: tensor<128x2880xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}, %arg1: tensor<32x2880x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0"}, {}, {}]>, ttcore.shard_status = #ttcore.shard_status<presharded>}) -> (tensor<32x128x5760xbf16> {sdy.sharding = #sdy.sharding<@mesh, [{"_axis_0", ?}, {?}, {?}]>, ttcore.shard_status = #ttcore.shard_status<unsharded>}) {
    %0 = stablehlo.concatenate %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {}]>]>} : (tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>, tensor<128x2880xbf16>) -> tensor<4096x2880xbf16>
    %1 = sdy.reshard %0 <@mesh, [{"_axis_0"}, {}]> : tensor<4096x2880xbf16>
    %2 = stablehlo.reshape %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>} : (tensor<4096x2880xbf16>) -> tensor<32x128x2880xbf16>
    %3 = stablehlo.dot_general %2, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"_axis_0", ?}, {?}, {?}]>]>} : (tensor<32x128x2880xbf16>, tensor<32x2880x5760xbf16>) -> tensor<32x128x5760xbf16>
    return %3 : tensor<32x128x5760xbf16>
  }
}
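
For reference, here is a minimal numpy sketch of the discrepancy. It is hypothetical and only models the per-device data (it is not the actual shardy/tt-mlir lowering, and bf16 is replaced by float32), assuming the concatenate is executed as a purely local op on each device once its inputs are sharded:

import numpy as np

num_devices = 8                                          # size of the "_axis_0" mesh axis
arg0 = np.random.rand(128, 2880).astype(np.float32)      # bf16 replaced by f32 for numpy

# Intended semantics: concatenate 32 copies, then shard the *output* on dim 0.
full = np.concatenate([arg0] * 32, axis=0)               # (4096, 2880)
expected_shards = np.split(full, num_devices, axis=0)    # 8 x (512, 2880)

# What the rewritten module computes if each device concatenates its local
# input shards: shard each *input* on dim 0, then concatenate per device.
input_shards = np.split(arg0, num_devices, axis=0)       # 8 x (16, 2880)
local_results = [np.concatenate([input_shards[d]] * 32, axis=0)
                 for d in range(num_devices)]             # 8 x (512, 2880)

# Device d should hold four full copies of arg0, not 32 repetitions of its own
# 16-row slice, so the per-device results do not match the expected shards.
print(all(np.array_equal(e, l) for e, l in zip(expected_shards, local_results)))
# -> False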

I wrote a patch in shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc specifically for the concatenate op, and I can put up a PR for it. However, I wanted to check whether this behaviour is intended, or whether there is another way to achieve the reshard on the output (instead of the inputs).
