diff --git a/tests/python/multidevice/test_matmul.py b/tests/python/multidevice/test_matmul.py index 76ce1939edf..1c877b4acb5 100644 --- a/tests/python/multidevice/test_matmul.py +++ b/tests/python/multidevice/test_matmul.py @@ -81,7 +81,7 @@ def definition(self): self.add_output(self.out) def multidevice_schedule(self): - for t in [self.inp, self.weight, self.bias, self.out]: + for t in [self.inp, self.weight, self.bias]: self.sched._set_device_mesh(t, mesh) # Shard N for weight (N, K) and bias (N) @@ -90,12 +90,6 @@ def multidevice_schedule(self): self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x) self.sched.set_allocation_as_loop(t) - # Output of linear: {.., i{M}, i{N}, r{K}} - # Shard N -> axis(-2) - self.sched.split(self.out, -2, d, False) - self.sched.parallelize(self.out, -3, nvfuser.ParallelType.mesh_x) - self.sched.set_allocation_as_loop(self.out) - torch.cuda.set_device(multidevice_test.local_rank) b, s = 2, 1024 @@ -135,7 +129,7 @@ def definition(self): self.add_output(self.out) def multidevice_schedule(self): - for t in [self.inp, self.weight, self.out]: + for t in [self.inp, self.weight]: self.sched._set_device_mesh(t, mesh) self.sched.split(t, -1, d, False) self.sched.parallelize(t, -2, nvfuser.ParallelType.mesh_x)