Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions python/tvm/exec/disco_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,84 @@
import os
import sys

from tvm import runtime as _ # pylint: disable=unused-import
from typing import Callable

import tvm
from tvm._ffi import get_global_func, register_func
from tvm.runtime import NDArray, ShapeTuple, String
from tvm.runtime.ndarray import array


# NOTE(review): the two decorator/def pairs below look like the before/after
# sides of a rendered diff (the second pair adds ``override=True``) — confirm.
@register_func("tests.disco.add_one")
def _add_one(x: int) -> int: # pylint: disable=invalid-name
@register_func("tests.disco.add_one", override=True)
def _add_one(x: int) -> int:
"""Return ``x + 1``; registered under the global name "tests.disco.add_one"."""
return x + 1


# NOTE(review): the two `def` lines below look like the before/after sides of
# a rendered diff (only the pylint comment differs) — confirm.
@register_func("tests.disco.add_one_float", override=True)
def _add_one_float(x: float): # pylint: disable=invalid-name
def _add_one_float(x: float):
"""Return ``x + 0.5``; registered under the global name "tests.disco.add_one_float".

Despite the name, the increment applied here is 0.5, not 1.
"""
return x + 0.5


# NOTE(review): the two `def` lines below look like the before/after sides of
# a rendered diff (only the pylint comment differs) — confirm.
@register_func("tests.disco.add_one_ndarray", override=True)
def _add_one_ndarray(x: NDArray) -> NDArray: # pylint: disable=invalid-name
def _add_one_ndarray(x: NDArray) -> NDArray:
"""Return a new array holding ``x.numpy() + 1``.

Registered under the global name "tests.disco.add_one_ndarray".  The input is
converted with ``x.numpy()`` and the sum is wrapped again with
``tvm.runtime.ndarray.array``.
"""
return array(x.numpy() + 1)


# NOTE(review): the two `def` lines below look like the before/after sides of
# a rendered diff (only the pylint comment differs) — confirm.
@register_func("tests.disco.str", override=True)
def _str_func(x: str): # pylint: disable=invalid-name
def _str_func(x: str):
"""Return ``x`` with "_suffix" appended; registered as "tests.disco.str"."""
return x + "_suffix"


# NOTE(review): the two `def` lines below look like the before/after sides of
# a rendered diff (only the pylint comment differs) — confirm.
@register_func("tests.disco.str_obj", override=True)
def _str_obj_func(x: String): # pylint: disable=invalid-name
def _str_obj_func(x: String):
"""Return ``String(x + "_suffix")``; registered as "tests.disco.str_obj".

Unlike "tests.disco.str", this variant asserts that the argument arrives as a
``tvm.runtime.String`` and wraps the result in ``String`` as well.
"""
assert isinstance(x, String)
return String(x + "_suffix")


# NOTE(review): the two `def` lines below look like the before/after sides of
# a rendered diff (only the pylint comment differs) — confirm.
@register_func("tests.disco.shape_tuple", override=True)
def _shape_tuple_func(x: ShapeTuple): # pylint: disable=invalid-name
def _shape_tuple_func(x: ShapeTuple):
"""Return a ``ShapeTuple`` equal to ``x`` followed by 4 and 5.

Registered under the global name "tests.disco.shape_tuple".  Asserts that the
argument arrives as a ``ShapeTuple`` before extending it.
"""
assert isinstance(x, ShapeTuple)
return ShapeTuple(list(x) + [4, 5])


@register_func("tests.disco.test_callback", override=True)
def _make_callback(device: tvm.runtime.Device) -> Callable[[str, int], NDArray]:
    """Build the parameter-loading callback used by tests/python/disco/test_callback.py

    The returned closure stands in for a lazy parameter loader: given
    a parameter's name and index, it fabricates that parameter and
    places it on ``device``.

    Parameters
    ----------
    device: tvm.runtime.Device

        Device on which every array handed back by the callback is
        placed.

    Returns
    -------
    fget_item: Callable[[str,int], NDArray]

        Closure taking ``(param_name, param_index)``.  Index 0 must be
        named "A" and yields ``arange(16)`` as a 4x4 int32 array;
        index 1 must be named "B" and yields ``arange(4)`` as a 2x2
        float32 array.  Any other index raises ``ValueError``.

    """
    import numpy as np  # pylint: disable=import-outside-toplevel

    def fget_item(param_name: str, param_index: int) -> NDArray:
        # Select the description of the requested parameter first; the
        # array itself is built once, after the name has been checked.
        if param_index == 0:
            expected_name, shape, dtype = "A", [4, 4], "int32"
        elif param_index == 1:
            expected_name, shape, dtype = "B", [2, 2], "float32"
        else:
            raise ValueError(f"Unexpected index {param_index}")

        assert param_name == expected_name

        element_count = shape[0] * shape[1]
        host_values = np.arange(element_count).reshape(shape).astype(dtype)
        return tvm.nd.array(host_values, device=device)

    return fget_item


def main():
"""Main worker function"""
if len(sys.argv) != 5:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/disco/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def __init__(self, num_workers: int) -> None:
class ProcessSession(Session):
"""A Disco session backed by pipe-based multi-processing."""

def __init__(self, num_workers: int, entrypoint: str) -> None:
def __init__(self, num_workers: int, entrypoint: str = "tvm.exec.disco_worker") -> None:
self.__init_handle_by_constructor__(
_ffi_api.SessionProcess, # type: ignore # pylint: disable=no-member
num_workers,
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,9 @@ def _multi_gpu_exists():
# Test marker for tests that need the cuBLAS library.
requires_cublas = Feature(
    "cublas",
    "cuBLAS",
    cmake_flag="USE_CUBLAS",
    parent_features="cuda",
)

# Test marker for tests that need NCCL support.
requires_nccl = Feature(
    "nccl",
    "NCCL",
    cmake_flag="USE_NCCL",
    parent_features="cuda",
)

# Mark a test as requiring the NVPTX compilation on the CUDA runtime
requires_nvptx = Feature(
"nvptx",
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/disco/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWo
// Global "runtime.disco.worker_id": returns a one-element ShapeTuple holding
// the value of WorkerId().
TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple {
return ShapeTuple({WorkerId()});
});
// Global "runtime.disco.worker_rank": returns the value of WorkerId() directly
// as an int64_t, without the ShapeTuple wrapper used by "worker_id".
TVM_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t {
return WorkerId();
});
// Global "runtime.disco.device": returns the default_device member of the
// DiscoWorker obtained from DiscoWorker::ThreadLocal().
TVM_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device {
return DiscoWorker::ThreadLocal()->default_device;
});

} // namespace runtime
} // namespace tvm
130 changes: 130 additions & 0 deletions tests/python/disco/test_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test sharded loader"""
# pylint: disable=missing-docstring

import pathlib
import tempfile

import numpy as np

import tvm
import tvm.testing

from tvm.script import relax as R, tir as T


@tvm.testing.requires_nccl
def test_callback():
# End-to-end check: a callback obtained through the disco session is passed
# into a compiled Relax function, and the result fetched from each worker is
# compared against the slice expected for that worker.
@R.function
def transform_params(
rank_arg: R.Prim(value="rank"),
fget_item: R.Callable([R.Object, R.Prim("int64")], R.Object),
):
"""Simulate lazy loading of parameters through a callback

Stands in for the output of a lazy parameter-loading transform:
a function that accepts a callback and uses it to load the
parameters.
"""
# NOTE(review): `rank` looks like a symbolic variable tied to the
# value of `rank_arg` via `R.Prim(value="rank")` — confirm.
rank = T.int64()

# Fetch "A" (index 0) through the callback, cast it to a 4x4 int32
# tensor, then keep rows [rank * 2, (rank + 1) * 2).
A = fget_item(R.str("A"), R.prim_value(0))
A = R.match_cast(A, R.Tensor([4, 4], "int32"))
A = R.strided_slice(A, axes=[0], begin=[rank * 2], end=[(rank + 1) * 2])

# Fetch "B" (index 1) through the callback, cast it to a 2x2 float32
# tensor, then keep column [rank, rank + 1).
B = fget_item(R.str("B"), R.prim_value(1))
B = R.match_cast(B, R.Tensor([2, 2], "float32"))
B = R.strided_slice(B, axes=[1], begin=[rank * 1], end=[(rank + 1) * 1])

return (A, B)

# Legalize the Relax operators, then apply dlight's default schedule using
# the GPU fallback rule.
pipeline = tvm.ir.transform.Sequential(
[
tvm.relax.transform.LegalizeOps(),
tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback()),
],
name="pipeline",
)

with tvm.target.Target("cuda"):
mod = tvm.IRModule.from_expr(transform_params)
mod = pipeline(mod)
built = tvm.relax.build(mod, "cuda")

# One worker per shard; the slicing in `transform_params` splits each tensor
# in two along the sliced axis.
num_shards = 2

session = tvm.runtime.disco.ProcessSession(num_workers=num_shards)
# NOTE(review): importing the worker module presumably makes
# "tests.disco.test_callback" available on every worker — verify against
# `import_python_module`.
session.import_python_module("tvm.exec.disco_worker")
session.init_ccl("nccl", *range(num_shards))

# Resolve the device, the rank, and the callback factory through the session;
# the callback is built from the device reported by "runtime.disco.device".
worker_device = session.get_global_func("runtime.disco.device")()
worker_id = session.get_global_func("runtime.disco.worker_rank")()
callback_maker = session.get_global_func("tests.disco.test_callback")
fget_item = callback_maker(worker_device)

with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = pathlib.Path(temp_dir)

# TODO(Lunderberg): Update `disco.Session.load_vm_module` to
# allow a `tvm.runtime.Module` argument. This would avoid the
# need for a temporary file.
shlib_path = temp_dir.joinpath("libtemp.so")
built.export_library(shlib_path)
vm = session.load_vm_module(shlib_path.as_posix())
transform_params = vm["transform_params"]

params = transform_params(worker_id, fget_item)

# Worker 0 is the same PID as the controlling scope, so
# `debug_get_from_remote(0)` returns the NDArray containing
# the output.
# Expected for worker 0: rows 0-1 of A and column 0 of B.
params_gpu0 = params.debug_get_from_remote(0)
assert params_gpu0[0].device == tvm.cuda(0)
assert params_gpu0[1].device == tvm.cuda(0)
np.testing.assert_array_equal(
params_gpu0[0].numpy(),
[
[0, 1, 2, 3],
[4, 5, 6, 7],
],
)
np.testing.assert_array_equal(
params_gpu0[1].numpy(),
[[0], [2]],
)

# Worker 1 is a different PID altogether, so
# `debug_get_from_remote(1)` returns a new NDArray within the
# calling scope's PID.
# Expected for worker 1: rows 2-3 of A and column 1 of B, both reported on
# the CPU device.
params_gpu1 = params.debug_get_from_remote(1)
assert params_gpu1[0].device == tvm.cpu()
assert params_gpu1[1].device == tvm.cpu()
np.testing.assert_array_equal(
params_gpu1[0].numpy(),
[
[8, 9, 10, 11],
[12, 13, 14, 15],
],
)
np.testing.assert_array_equal(
params_gpu1[1].numpy(),
[[1], [3]],
)


if __name__ == "__main__":
tvm.testing.main()