diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 0a2b151d4274..66d77b48aa0c 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -160,7 +160,7 @@ def _criterion(outputs, inputs): input_shape = data["input_ids"].shape for k, v in data.items(): if v.shape == input_shape: - data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) + data[k] = v.repeat((1, ) * (v.dim() - 1) + (times,)) sharded_model.train() if booster.plugin.stage_manager is not None: