3 changes: 3 additions & 0 deletions python/tvm/testing/utils.py
@@ -903,6 +903,9 @@ def _multi_gpu_exists():
# Mark a test as requiring NCCL support
requires_nccl = Feature("nccl", "NCCL", cmake_flag="USE_NCCL", parent_features="cuda")

# Mark a test as requiring RCCL support
requires_rccl = Feature("rccl", "RCCL", cmake_flag="USE_RCCL", parent_features="rocm")

# Mark a test as requiring the NVPTX compilation on the CUDA runtime
requires_nvptx = Feature(
"nvptx",
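
Note: the new requires_rccl marker is used the same way as the existing requires_nccl one, either as a test decorator or to supply pytest marks for parametrization (both patterns appear later in this diff). A minimal sketch, assuming a hypothetical test name and body:

import tvm.testing

# Hypothetical test: skipped unless TVM was built with USE_RCCL and the parent
# "rocm" feature (a ROCm-enabled build and device) is available.
@tvm.testing.requires_rccl
def test_allreduce_on_rocm():
    ...
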
12 changes: 7 additions & 5 deletions tests/python/disco/test_ccl.py
@@ -32,7 +32,13 @@
from tvm.script import relax as R

_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
_ccl = [get_global_func("runtime.disco.compiled_ccl")()]
_ccl = [
pytest.param(
ccl,
marks=tvm.testing.Feature._all_features[ccl].marks(),
)
for ccl in ["nccl", "rccl"]
]


def create_device_target(ccl):
@@ -445,7 +451,6 @@ def main(
W1: R.Tensor((128, 128), "float32"),
W2: R.Tensor((128, 128), "float32"),
) -> R.Tensor((128, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv0: R.Tensor((128, 128), "float32") = R.matmul(x, W1)
lv1: R.Tensor((128, 128), "float32") = R.nn.gelu(lv0)
@@ -461,7 +466,6 @@ def main(
W1: R.Tensor((128, 64), "float32"), # shard along axis 1
W2: R.Tensor((64, 128), "float32"), # shard along axis 0
) -> R.Tensor((128, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
broadcast_x: R.Tensor((128, 128), "float32") = R.ccl.broadcast_from_worker0(x)
lv0: R.Tensor((128, 64), "float32") = R.matmul(broadcast_x, W1)
@@ -538,7 +542,6 @@ def main( # pylint: disable=too-many-locals
Wv: R.Tensor((128, 512), "float32"),
Wo: R.Tensor((512, 128), "float32"),
) -> R.Tensor((128, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
# q
lv0: R.Tensor((1, 10, 512), "float32") = R.matmul(x, Wq)
@@ -578,7 +581,6 @@ def main( # pylint: disable=too-many-locals
Wv: R.Tensor((128, 256), "float32"), # shard along axis 1
Wo: R.Tensor((256, 128), "float32"), # shard along axis 0
) -> R.Tensor((128, 128), "float32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
broadcast_x: R.Tensor((1, 10, 128), "float32") = R.ccl.broadcast_from_worker0(x)
# q
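
With _ccl expressed as pytest.param entries carrying the feature marks, each CCL test is collected once per backend ("nccl", "rccl") and skipped automatically when the corresponding build flag is absent. A sketch of the assumed consumption pattern in this file (the test name and decorator stacking are illustrative, not part of the diff):

import pytest

# Uses the module-level _all_session_kinds and _ccl lists defined above.
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_allreduce(session_kind, ccl):
    ...
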
8 changes: 7 additions & 1 deletion tests/python/disco/test_custom_allreduce.py
@@ -44,7 +44,13 @@ class AllReduceStrategyType(enum.IntEnum):
AllReduceStrategyType.AUTO,
]

_ccl = [ccl for ccl in tvm.get_global_func("runtime.disco.compiled_ccl")() if ccl == "nccl"]
_ccl = [
pytest.param(
ccl,
marks=tvm.testing.Feature._all_features[ccl].marks(),
)
for ccl in ["nccl"]
]


@pytest.mark.parametrize("shape", _shapes)
9 changes: 8 additions & 1 deletion tests/python/disco/test_loader.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Test sharded loader"""

# pylint: disable=missing-docstring
import json
import tempfile
@@ -120,6 +121,7 @@ def _simulate_presharded_weights(base_path, param_dict, num_shards, shard_info):
)


@tvm.testing.requires_nccl
def test_load_shard():
devices = [0, 1]
num_shards = len(devices)
@@ -177,6 +179,7 @@ def _create_presharded_loader(sess, path):
return loader


@tvm.testing.requires_nccl
def test_load_presharded():
devices = [0, 1]
param_dict = {
@@ -217,6 +220,7 @@ def test_load_presharded():
)


@tvm.testing.requires_nccl
def test_load_shard_in_relax():
devices = [0, 1]
num_shards = len(devices)
@@ -248,7 +252,6 @@ class Module: # pylint: disable=too-few-public-methods
def main(
loader: R.Object,
) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128), "float32")):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv0: R.Tensor((64, 64), "float32") = R.call_pure_packed(
"runtime.disco.ShardLoaderLoad",
@@ -309,6 +312,7 @@ def relax_build(mod, target):
)


@tvm.testing.requires_nccl
def test_load_shard_all():
devices = [0, 1]
num_shards = len(devices)
@@ -346,6 +350,7 @@ def test_load_shard_all():
np.testing.assert_equal(param_dict["param_1"][16:32, :], p_1[1].numpy())


@tvm.testing.requires_nccl
def test_load_all_presharded():
devices = [0, 1]
num_shards = len(devices)
@@ -375,6 +380,7 @@ def test_load_all_presharded():
np.testing.assert_equal(param_dict["param_1"][:, 64:128], p_1[1].numpy())


@tvm.testing.requires_nccl
def test_load_shard_broadcast():
devices = [0, 1]
param_dict = {
@@ -396,6 +402,7 @@ def test_load_shard_broadcast():
np.testing.assert_equal(param_dict["param_1"], p_1[1].numpy())


@tvm.testing.requires_nccl
def test_load_qkv_proj_shard(): # pylint: disable=too-many-locals
devices = [0, 1]
num_shards = len(devices)
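
Each loader test now carries @tvm.testing.requires_nccl, so the file skips cleanly on builds without NCCL instead of failing at session setup. Conceptually this is close to an explicit skipif guard; a rough sketch under stated assumptions (the packed-function name probed below is an assumption, not taken from this diff):

import pytest
import tvm

# Approximation only: the real Feature machinery also checks the USE_NCCL cmake flag
# and the parent "cuda" feature. "runtime.disco.nccl.init_ccl" is an assumed name.
_nccl_built = tvm.get_global_func("runtime.disco.nccl.init_ccl", allow_missing=True) is not None

@pytest.mark.skipif(not _nccl_built, reason="NCCL support not enabled in this TVM build")
def test_load_shard():
    ...
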
3 changes: 1 addition & 2 deletions tests/python/disco/test_session.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Basic tests for a Disco session"""

# pylint: disable=missing-docstring
import tempfile

@@ -260,7 +261,6 @@ def t2(A: T.Buffer((16, 8), "float32"), B: T.Buffer((8, 16), "float32")):
def transpose_1(
A: R.Tensor((8, 16), dtype="float32")
) -> R.Tensor((16, 8), dtype="float32"):
R.func_attr({"global_symbol": "transpose_1"})
cls = TestMod
with R.dataflow():
B = R.call_tir(cls.t1, (A,), out_sinfo=R.Tensor((16, 8), dtype="float32"))
@@ -271,7 +271,6 @@ def transpose_1(
def transpose_2(
A: R.Tensor((16, 8), dtype="float32")
) -> R.Tensor((8, 16), dtype="float32"):
R.func_attr({"global_symbol": "transpose_2"})
cls = TestMod
with R.dataflow():
B = R.call_tir(cls.t2, (A,), out_sinfo=R.Tensor((8, 16), dtype="float32"))
1 change: 1 addition & 0 deletions tests/scripts/task_python_unittest.sh
@@ -39,6 +39,7 @@ TEST_FILES=(
"auto_scheduler"
"autotvm"
"codegen"
"disco"
"ir"
"meta_schedule"
"micro"