Merged: 78 commits (changes from all commits)
cf77131  lintrunner (Priya2698, Feb 12, 2025)
dddfe64  rm duplicate test from rebase (Priya2698, Feb 13, 2025)
a66ab55  split and merge reshape if ViewOp does not reshard (Priya2698, Feb 24, 2025)
41c6bcb  reshape-permute-sdpa-reshape block (Priya2698, Feb 26, 2025)
a238bab  clean test (Priya2698, Feb 26, 2025)
02d32cf  add war to identify view op as not resharding, comment (Priya2698, Mar 4, 2025)
f2aff84  add parallelize input flag (Priya2698, Mar 4, 2025)
aa7aada  rm import (Priya2698, Mar 4, 2025)
92f6ff9  fix rebase (Priya2698, Mar 6, 2025)
3f7f3a5  fix rebase (Priya2698, Mar 6, 2025)
02d2ce3  return did pos from reorder (Priya2698, Mar 7, 2025)
db7ac1e  reshape sharding for transformer case (Priya2698, Mar 12, 2025)
4502436  add ordering of inputs, custom selector for directioned propagation, … (Priya2698, Mar 13, 2025)
af5ad27  support multiple merges or splits in a reshape (Priya2698, Mar 13, 2025)
0b9577c  use producer's parallel type (Priya2698, Mar 13, 2025)
e0c4026  return number of did shardings on reshape (Priya2698, Mar 13, 2025)
f798f1e  update handling of view op (Priya2698, Mar 14, 2025)
f90b478  move tests into separate files (Priya2698, Mar 18, 2025)
4b066d1  reorder back as logical, more tests (Priya2698, Mar 18, 2025)
3b673ab  fix rebase (Priya2698, Mar 18, 2025)
748949a  fix rebase (Priya2698, Mar 18, 2025)
39e03ee  clean (Priya2698, Mar 18, 2025)
ca80c00  lintrunner (Priya2698, Mar 18, 2025)
19bca7e  reorder did to front (Priya2698, Mar 19, 2025)
5eabe96  check if ref input has device mesh (Priya2698, Mar 19, 2025)
99c3574  update preseg pass (Priya2698, Mar 20, 2025)
1a20a7c  reorder to original in the interim (Priya2698, Mar 20, 2025)
de91943  allocation domain util fn (Priya2698, Mar 22, 2025)
d56a742  rm allocation domain reorder (Priya2698, Mar 22, 2025)
7be563d  rebase, reorder as alloc (Priya2698, Apr 5, 2025)
fc80957  set device mesh on fusion inputs (Priya2698, Apr 7, 2025)
841e8b3  Merge branch 'main' into pm/preseg_sharding_prop (Priya2698, Apr 7, 2025)
443aef4  early return if ref_inputs do not have mesh (Priya2698, Apr 7, 2025)
9b86c5b  lintrunner (Priya2698, Apr 7, 2025)
83f7133  remove view op specific changes (Priya2698, Apr 11, 2025)
7919ef2  undo changes in lower.cpp (Priya2698, Apr 11, 2025)
fc43602  Merge branch 'main' into pm/preseg_sharding_prop (Priya2698, Apr 11, 2025)
7a34ea4  undo changes in scheduler_utils merged in another PR (Priya2698, Apr 11, 2025)
6e6b1d0  fix reorderLoopAsAllocation (Priya2698, Apr 11, 2025)
cdc8d5d  set allocation (Priya2698, Apr 11, 2025)
07fbcb7  shard each input individually in backprop (Priya2698, Apr 11, 2025)
5df0a6b  Merge branch 'main' into pm/preseg_sharding_prop (Priya2698, Apr 11, 2025)
7956dac  cleanup (Priya2698, Apr 11, 2025)
c271701  rm reshape tests (Priya2698, Apr 11, 2025)
c61f548  extraneous change (Priya2698, Apr 11, 2025)
61dbdc6  shardAllLike (Priya2698, Apr 11, 2025)
24331ad  fix build error, test without serial (Priya2698, Apr 11, 2025)
3f146b6  test without changing preseg order (Priya2698, Apr 11, 2025)
916a693  include parallel type serial (Priya2698, Apr 12, 2025)
c0e96cf  undo debug changes (Priya2698, Apr 12, 2025)
632f743  propagate only for nonsharded inputs (Priya2698, Apr 12, 2025)
2bb4624  revert to only propagating for unsharded inputs (Priya2698, Apr 12, 2025)
ad931e4  derive contiguity through transforms (Priya2698, Apr 14, 2025)
3757b3c  fix test (Priya2698, Apr 15, 2025)
bd945e3  set allocation domain last (Priya2698, Apr 15, 2025)
a40cf3d  undo shardAllLike changes (Priya2698, Apr 15, 2025)
3b7e77f  move allocation domain to makeReshardingContiguous (Priya2698, Apr 15, 2025)
993da4c  Merge branch 'main' into pm/preseg_sharding_prop (Priya2698, Apr 15, 2025)
7f9b8fb  Merge branch 'main' into pm/preseg_sharding_prop (Priya2698, Apr 16, 2025)
044cbc8  undo changes (Priya2698, Apr 16, 2025)
29de7ff  lintrunner (Priya2698, Apr 16, 2025)
af1c71d  modify tests (Priya2698, Apr 16, 2025)
6152055  Update csrc/preseg_passes/make_resharding_contiguous.cpp (Priya2698, Apr 17, 2025)
7d54f58  Update csrc/preseg_passes/make_resharding_contiguous.cpp (Priya2698, Apr 17, 2025)
9f94484  restore tests (Priya2698, Apr 18, 2025)
7c539a2  move tests (Priya2698, Apr 19, 2025)
7dab9da  Merge branch 'main' into pm/preseg_sharding_prop (Priya2698, Apr 22, 2025)
11d3d04  clean tests (Priya2698, Apr 22, 2025)
00d93fc  comment; (Priya2698, Apr 22, 2025)
bbb835e  comment (Priya2698, Apr 22, 2025)
3192018  propagate only selected types (Priya2698, Apr 22, 2025)
a527540  Merge branch 'main' into pm/preseg_sharding_prop (Priya2698, Apr 23, 2025)
afb580c  check contiguity for both inp and output (Priya2698, Apr 23, 2025)
3b48014  Update csrc/preseg_passes/propagate_shardings.cpp (Priya2698, Apr 23, 2025)
354f8da  Update csrc/preseg_passes/propagate_shardings.cpp (Priya2698, Apr 23, 2025)
52573e7  review comments (Priya2698, Apr 23, 2025)
5eea3ab  Simplify some tests since sharding propagation is in place (wujingyue, Apr 23, 2025)
21834ae  Merge branch 'main' into wjy/prop (wujingyue, Apr 24, 2025)
10 changes: 2 additions & 8 deletions tests/python/multidevice/test_matmul.py
@@ -81,7 +81,7 @@ def definition(self):
         self.add_output(self.out)
 
     def multidevice_schedule(self):
-        for t in [self.inp, self.weight, self.bias, self.out]:
+        for t in [self.inp, self.weight, self.bias]:
             self.sched._set_device_mesh(t, mesh)
 
         # Shard N for weight (N, K) and bias (N)
@@ -90,12 +90,6 @@ def multidevice_schedule(self):
             self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x)
             self.sched.set_allocation_as_loop(t)
 
-        # Output of linear: {.., i{M}, i{N}, r{K}}
-        # Shard N -> axis(-2)
-        self.sched.split(self.out, -2, d, False)
-        self.sched.parallelize(self.out, -3, nvfuser.ParallelType.mesh_x)
-        self.sched.set_allocation_as_loop(self.out)
-
 torch.cuda.set_device(multidevice_test.local_rank)
 
 b, s = 2, 1024
@@ -135,7 +129,7 @@ def definition(self):
         self.add_output(self.out)
 
     def multidevice_schedule(self):
-        for t in [self.inp, self.weight, self.out]:
+        for t in [self.inp, self.weight]:
             self.sched._set_device_mesh(t, mesh)
             self.sched.split(t, -1, d, False)
             self.sched.parallelize(t, -2, nvfuser.ParallelType.mesh_x)
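The tests no longer set a mesh or sharding on `self.out` because the preseg sharding-propagation pass now derives output shardings from the inputs. As a rough illustration of that idea only — a toy pure-Python sketch, not nvFuser's implementation; `Tensor`, `Op`, and `propagate_shardings` are made-up names — a forward pass visits ops in topological order, picks a reference input that already has a device mesh, and copies its mesh and sharding onto any output that is still unsharded:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Tensor:
    name: str
    mesh: Optional[List[int]] = None    # device mesh, e.g. [0, 1]
    sharded_axis: Optional[int] = None  # axis split across the mesh

@dataclass
class Op:
    inputs: List[Tensor]
    outputs: List[Tensor]

def propagate_shardings(ops: List[Op]) -> None:
    """Copy mesh and sharding from a reference input onto unsharded outputs.

    `ops` is assumed to be in topological order, so producers are visited
    before consumers (a forward propagation pass).
    """
    for op in ops:
        # The reference is the first input that already has a device mesh;
        # skip the op entirely if no input has one.
        ref = next((t for t in op.inputs if t.mesh is not None), None)
        if ref is None:
            continue
        for out in op.outputs:
            if out.mesh is None:  # propagate only to unsharded outputs
                out.mesh = ref.mesh
                out.sharded_axis = ref.sharded_axis

# Usage: the output of a linear-like op inherits the weight's sharding,
# so the test no longer needs to schedule the output by hand.
inp = Tensor("inp")
weight = Tensor("weight", mesh=[0, 1], sharded_axis=0)
out = Tensor("out")
propagate_shardings([Op(inputs=[inp, weight], outputs=[out])])
print(out.mesh, out.sharded_axis)  # → [0, 1] 0
```

The real pass also handles details this sketch ignores, such as ordering reference inputs, restricting which parallel types propagate, and fixing up allocation domains and contiguity for resharding expressions.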