Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 79 additions & 43 deletions python/mscclpp/language/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class MemoryChannel:
"""

_channel_counts = defaultdict(int)
_channel_peer_counts = defaultdict(int)

@classmethod
def reset(cls):
Expand Down Expand Up @@ -52,6 +53,8 @@ def __init__(self, dst_rank: int, src_rank: int):

self.channel_id = MemoryChannel._channel_counts[src_rank]
MemoryChannel._channel_counts[src_rank] += 1
self.channel_peer_id = MemoryChannel._channel_peer_counts[(src_rank, dst_rank)]
MemoryChannel._channel_peer_counts[(src_rank, dst_rank)] += 1

self.dst_rank = dst_rank
self.src_rank = src_rank
Expand All @@ -76,7 +79,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = F
>>> channel.signal(tb=0, data_sync=SyncType.before)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed)
get_program().add_operation(self.src_rank, tb, op)

def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False):
Expand All @@ -97,7 +100,7 @@ def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = Fal
>>> channel.wait(tb=0, data_sync=SyncType.after)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed)
get_program().add_operation(self.src_rank, tb, op)

def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
Expand Down Expand Up @@ -138,21 +141,25 @@ def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb, self)
op = GetOperation(
rank=self.src_rank,
threadblock=tb_id,
src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
if tb_group is not None
else None
),
tbg=(tb_group if tb_group is not None else None),
)
get_program().add_operation(self.src_rank, tb_id, op)
operations.append(op)

if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)

def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Send data from local memory to remote memory.
Expand Down Expand Up @@ -197,21 +204,25 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb_id, self)
op = PutOperation(
rank=self.src_rank,
threadblock=tb_id,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
if tb_group is not None
else None
),
tbg=(tb_group if tb_group is not None else None),
)
get_program().add_operation(self.src_rank, tb_id, op)
operations.append(op)

if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)

def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Transfer data in packet format from local to remote scratch buffer.
Expand Down Expand Up @@ -261,23 +272,27 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, t
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb_id, self)
op = PutOperation(
rank=self.src_rank,
threadblock=tb_id,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
if tb_group is not None
else None
),
tbg=(tb_group if tb_group is not None else None),
from_packet=True,
to_packet=True,
)
get_program().add_operation(self.src_rank, tb_id, op)
operations.append(op)

if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)

def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Transfer data from local buffer to remote scratch buffer in packet format.
Expand Down Expand Up @@ -325,24 +340,27 @@ def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_gro
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb_id, self)
op = PutOperation(
rank=self.src_rank,
threadblock=tb_id,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
if tb_group is not None
else None
),
tbg=(tb_group if tb_group is not None else None),
from_packet=False,
to_packet=True,
)
operations.append(op)

get_program().add_operation(self.src_rank, tb_id, op)
if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)

def reduce(
self,
Expand Down Expand Up @@ -405,6 +423,7 @@ def reduce(
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
remote_chunks = [
RemoteChunk(
Expand All @@ -423,21 +442,23 @@ def reduce(
tb_channel_ids = get_program().setup_channel(tb_id, self)

op = ReduceOperation(
rank=self.src_rank,
threadblock=tb_id,
local_src_buff=[LocalChunk(local_src_chunk.buffer, local_src_chunk.index, local_src_chunk.size)],
local_dst_buff=[LocalChunk(local_dst_chunk.buffer, local_dst_chunk.index, local_dst_chunk.size)],
remote_src_buff=remote_chunks,
remote_dst_buff=[],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
if tb_group is not None
else None
),
tbg=(tb_group if tb_group is not None else None),
reduce_operation=reduce_op,
)
operations.apend(op)

get_program().add_operation(self.src_rank, tb_id, op)
if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)


@dataclass
Expand All @@ -457,6 +478,7 @@ class PortChannel:
"""

_channel_counts = defaultdict(int)
_channel_peer_counts = defaultdict(int)

@classmethod
def reset(cls):
Expand Down Expand Up @@ -484,6 +506,8 @@ def __init__(self, dst_rank: int, src_rank: int):

self.channel_id = PortChannel._channel_counts[src_rank]
PortChannel._channel_counts[src_rank] += 1
self.channel_peer_id = PortChannel._channel_peer_counts[(src_rank, dst_rank)]
PortChannel._channel_peer_counts[(src_rank, dst_rank)] += 1

self.dst_rank = dst_rank
self.src_rank = src_rank
Expand All @@ -506,7 +530,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both):
>>> channel.signal(tb=0, data_sync=SyncType.before)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync)
op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)

def wait(self, tb: int, data_sync: SyncType = SyncType.both):
Expand All @@ -525,7 +549,7 @@ def wait(self, tb: int, data_sync: SyncType = SyncType.both):
>>> channel.wait(tb=0, data_sync=SyncType.after)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync)
op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)

def flush(self, tb: int, data_sync: SyncType = SyncType.both):
Expand All @@ -544,7 +568,7 @@ def flush(self, tb: int, data_sync: SyncType = SyncType.both):
>>> channel.flush(tb=0, data_sync=SyncType.after)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = FlushOperation(tb_channel_ids, self.channel_type, data_sync)
op = FlushOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)

def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
Expand Down Expand Up @@ -583,6 +607,8 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
tb_channel_ids = get_program().setup_channel(tb, self)

op = PutOperation(
rank=self.src_rank,
threadblock=tb,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
Expand Down Expand Up @@ -628,6 +654,8 @@ def put_with_signal(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
tb_channel_ids = get_program().setup_channel(tb, self)

op = PutOperation(
rank=self.src_rank,
threadblock=tb,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
Expand Down Expand Up @@ -673,6 +701,8 @@ def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int)
tb_channel_ids = get_program().setup_channel(tb, self)

op = PutOperation(
rank=self.src_rank,
threadblock=tb,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
Expand Down Expand Up @@ -772,6 +802,8 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
tb_channel_ids = get_program().setup_channel(tb, self)

op = PutOperation(
rank=self.src_rank,
threadblock=tb,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
Expand Down Expand Up @@ -900,13 +932,15 @@ def reduce(self, rank, buffer_offset, size, dst_chunk: Chunk, tb, reduce_op=Redu

tb_channel_ids = get_program().setup_channel(tb, self)
op = GroupLoadReduce(
self.buffer_type,
buffer_offset,
size,
dst_chunk,
tb_channel_ids,
self.channel_type,
reduce_op,
rank=self.src_rank,
threadblock=tb,
buffer_type=self.buffer_type,
buffer_offset=buffer_offset,
size=size,
dst_chunk=dst_chunk,
tb_channel_ids=tb_channel_ids,
channel_type=self.channel_type,
reduce_operation=reduce_op,
)
get_program().add_operation(self.src_rank, tb, op)

Expand Down Expand Up @@ -950,7 +984,9 @@ def broadcast(self, rank, src_chunk: Chunk, buffer_offset, size, tb):
)

tb_channel_ids = get_program().setup_channel(tb, self)
op = GroupStore(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
op = GroupStore(
self.src_rank, tb, src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type
)
get_program().add_operation(self.src_rank, tb, op)

class SwitchChannelRankView:
Expand Down
Loading