Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,12 @@ if(BUILD_FOR_HEXAGON)
add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0)
endif()

# The distributed disco runtime is disabled for Hexagon builds: its sources under
# src/runtime/disco/distributed/ are only globbed and appended to RUNTIME_SRCS
# when BUILD_FOR_HEXAGON is off.
if (NOT BUILD_FOR_HEXAGON)
tvm_file_glob(GLOB RUNTIME_DISCO_DISTRIBUTED_SRCS src/runtime/disco/distributed/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_DISCO_DISTRIBUTED_SRCS})
endif()

# Package runtime rules
if(NOT USE_RTTI)
add_definitions(-DDMLC_ENABLE_RTTI=0)
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/runtime/disco/disco_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class DiscoWorker {
explicit DiscoWorker(int worker_id, int num_workers, int num_groups,
WorkerZeroData* worker_zero_data, DiscoChannel* channel)
: worker_id(worker_id),
local_worker_id(worker_id),
num_workers(num_workers),
num_groups(num_groups),
default_device(Device{DLDeviceType::kDLCPU, 0}),
Expand All @@ -68,6 +69,9 @@ class DiscoWorker {

/*! \brief The id of the worker.*/
int worker_id;
/*! \brief The local id of the worker. This can be different from worker_id if the session
 * consists of multiple sub-sessions. */
int local_worker_id;
/*! \brief Total number of workers */
int num_workers;
/*! \brief Total number of workers */
Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/disco/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ class Session : public ObjectRef {
*/
TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
String process_pool_creator, String entrypoint);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj);
};

Expand Down
33 changes: 33 additions & 0 deletions python/tvm/exec/disco_remote_socket_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Launch disco session in the remote node and connect to the server."""
import sys
import tvm
from . import disco_worker as _ # pylint: disable=unused-import


if __name__ == "__main__":
    # Exactly three positional arguments are expected after the script name.
    if len(sys.argv) != 4:
        print("Usage: <server_host> <server_port> <num_workers>")
        sys.exit(1)

    # Parse the arguments first so a malformed integer fails before the
    # global function lookup, as in the original statement order.
    host_arg, port_arg, workers_arg = sys.argv[1:4]
    port_number = int(port_arg)
    worker_count = int(workers_arg)

    # Look up the registered entry point and hand it the parsed arguments.
    launch_remote_session = tvm.get_global_func("runtime.disco.RemoteSocketSession")
    launch_remote_session(host_arg, port_number, worker_count)
1 change: 1 addition & 0 deletions python/tvm/runtime/disco/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
ProcessSession,
Session,
ThreadedSession,
SocketSession,
)
23 changes: 23 additions & 0 deletions python/tvm/runtime/disco/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,29 @@ def _configure_structlog(self) -> None:
func(config, os.getpid())


@register_func("runtime.disco.create_socket_session_local_workers")
def _create_socket_session_local_workers(num_workers) -> Session:
    """Build the node-local session used by each node of a socket session.

    Parameters
    ----------
    num_workers :
        Worker count, forwarded unchanged to ``ProcessSession``.

    Returns
    -------
    Session
        The newly constructed ``ProcessSession``.
    """
    local_session = ProcessSession(num_workers)
    return local_session


@register_object("runtime.disco.SocketSession")
class SocketSession(Session):
    """Disco session that communicates across multiple nodes over sockets.

    Parameters
    ----------
    num_nodes : int
        Number of nodes, forwarded to the ``SocketSession`` FFI constructor.
    num_workers_per_node : int
        Number of workers on each node.
    num_groups : int
        Number of worker groups.
    host : str
        Host forwarded to the FFI constructor (presumably the address the
        remote nodes connect to -- confirm against the C++ implementation).
    port : int
        Port forwarded to the FFI constructor together with ``host``.
    """

    def __init__(
        self, num_nodes: int, num_workers_per_node: int, num_groups: int, host: str, port: int
    ) -> None:
        # Keep the positional order expected by the FFI constructor.
        ctor_args = (num_nodes, num_workers_per_node, num_groups, host, port)
        self.__init_handle_by_constructor__(
            _ffi_api.SocketSession,  # type: ignore # pylint: disable=no-member
            *ctor_args,
        )


@register_func("runtime.disco._configure_structlog")
def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None:
"""Configure structlog for all disco workers
Expand Down
20 changes: 20 additions & 0 deletions src/runtime/disco/bcast_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ class BcastSessionObj : public SessionObj {
* \param TVMArgs The input arguments in TVM's PackedFunc calling convention
*/
virtual void BroadcastPacked(const TVMArgs& args) = 0;

/*!
 * \brief Send a packed sequence to a worker. This function is usually called by the controller
 * to communicate with worker-0, because worker-0 is assumed to be always co-located with the
 * controller. Sending to other workers may not be supported.
 * \param worker_id The worker id to send the packed sequence to.
 * \param args The packed sequence to send.
 */
virtual void SendPacked(int worker_id, const TVMArgs& args) = 0;

/*!
* \brief Receive a packed sequence from a worker. This function is usually called by the
* controller to communicate with worker-0, because the worker-0 is assumed to be always
Expand All @@ -83,6 +93,16 @@ class BcastSessionObj : public SessionObj {

struct Internal;
friend struct Internal;
friend class SocketSessionObj;
friend class RemoteSocketSession;
};

/*!
 * \brief Managed reference to BcastSessionObj.
 *
 * Derives from Session and declares its reference methods through
 * TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS with BcastSessionObj as the node type.
 * NOTE(review): the macro used here lacks the NOTNULLABLE qualifier that the Session
 * reference class uses; confirm that the difference is intentional.
 */
class BcastSession : public Session {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BcastSession, Session, BcastSessionObj);
};

} // namespace runtime
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/disco/disco_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ struct DiscoWorker::Impl {
}

static void CopyFromWorker0(DiscoWorker* self, int reg_id) {
if (self->worker_zero_data != nullptr) {
if (self->worker_id == 0) {
NDArray tgt = GetNDArrayFromHost(self);
NDArray src = GetReg(self, reg_id);
tgt.CopyFrom(src);
}
}

static void CopyToWorker0(DiscoWorker* self, int reg_id) {
if (self->worker_zero_data != nullptr) {
if (self->worker_id == 0) {
NDArray src = GetNDArrayFromHost(self);
NDArray tgt = GetReg(self, reg_id);
tgt.CopyFrom(src);
Expand Down
Loading