From c3b76ef65288a7d3161919a94f666ac0e0fc9c82 Mon Sep 17 00:00:00 2001 From: Caio Rocha Date: Mon, 8 Sep 2025 23:16:02 +0000 Subject: [PATCH 1/7] wip --- python/mscclpp/language/channel.py | 110 +++++-- .../language/internal/buffer_access.py | 104 ++++--- .../mscclpp/language/internal/op_dep_graph.py | 283 ++++++++++++++++++ .../mscclpp/language/internal/operations.py | 236 ++++++++++----- python/mscclpp/language/internal/register.py | 22 ++ python/mscclpp/language/internal/types.py | 71 ++++- python/mscclpp/language/program.py | 25 +- python/mscclpp/language/rank.py | 34 ++- python/mscclpp/language/thread_block_group.py | 27 +- 9 files changed, 749 insertions(+), 163 deletions(-) create mode 100644 python/mscclpp/language/internal/op_dep_graph.py create mode 100644 python/mscclpp/language/internal/register.py diff --git a/python/mscclpp/language/channel.py b/python/mscclpp/language/channel.py index b413a7728..453024c17 100644 --- a/python/mscclpp/language/channel.py +++ b/python/mscclpp/language/channel.py @@ -25,6 +25,7 @@ class MemoryChannel: """ _channel_counts = defaultdict(int) + _channel_peer_counts = defaultdict(int) def __init__(self, dst_rank: int, src_rank: int): """Initialize a new MemoryChannel. @@ -47,6 +48,8 @@ def __init__(self, dst_rank: int, src_rank: int): self.channel_id = MemoryChannel._channel_counts[src_rank] MemoryChannel._channel_counts[src_rank] += 1 + self.channel_peer_id = MemoryChannel._channel_peer_counts[(src_rank, dst_rank)] + MemoryChannel._channel_peer_counts[(src_rank, dst_rank)] += 1 self.dst_rank = dst_rank self.src_rank = src_rank @@ -71,7 +74,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = F >>> channel.signal(tb=0, data_sync=SyncType.before) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) + op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False): @@ -92,7 +95,7 @@ def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = Fal >>> channel.wait(tb=0, data_sync=SyncType.after) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) + op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): @@ -133,21 +136,29 @@ def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." 
) + operations = [] for tb_id in tb_list: tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type) tb_channel_ids = get_program().setup_channel(tb, self) op = GetOperation( + rank=self.src_rank, + threadblock=tb_id, src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)], dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) + tbg=( + tb_group if tb_group is not None else None ), ) - get_program().add_operation(self.src_rank, tb_id, op) + operations.append(op) + + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Send data from local memory to remote memory. @@ -192,21 +203,29 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." ) + operations = [] for tb_id in tb_list: tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type) tb_channel_ids = get_program().setup_channel(tb_id, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb_id, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) + tbg=( + tb_group if tb_group is not None else None ), ) - get_program().add_operation(self.src_rank, tb_id, op) + operations.append(op) + + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Transfer data in packet format from local to remote scratch buffer. @@ -256,23 +275,31 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, t "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." ) + operations = [] for tb_id in tb_list: tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type) tb_channel_ids = get_program().setup_channel(tb_id, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb_id, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) + tbg=( + tb_group if tb_group is not None else None ), from_packet=True, to_packet=True, ) - get_program().add_operation(self.src_rank, tb_id, op) + operations.append(op) + + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Transfer data from local buffer to remote scratch buffer in packet format. 
@@ -320,24 +347,31 @@ def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_gro "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." ) + operations = [] for tb_id in tb_list: tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type) tb_channel_ids = get_program().setup_channel(tb_id, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb_id, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) + tbg=( + tb_group if tb_group is not None else None ), from_packet=False, to_packet=True, ) + operations.append(op) - get_program().add_operation(self.src_rank, tb_id, op) + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def reduce( self, @@ -400,6 +434,7 @@ def reduce( "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." ) + operations = [] for tb_id in tb_list: remote_chunks = [ RemoteChunk( @@ -418,21 +453,27 @@ def reduce( tb_channel_ids = get_program().setup_channel(tb_id, self) op = ReduceOperation( + rank=self.src_rank, + threadblock=tb_id, local_src_buff=[LocalChunk(local_src_chunk.buffer, local_src_chunk.index, local_src_chunk.size)], local_dst_buff=[LocalChunk(local_dst_chunk.buffer, local_dst_chunk.index, local_dst_chunk.size)], remote_src_buff=remote_chunks, remote_dst_buff=[], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) + tbg=( + tb_group if tb_group is not None else None ), reduce_operation=reduce_op, ) + operations.apend(op) - get_program().add_operation(self.src_rank, tb_id, op) + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) @dataclass @@ -452,6 +493,7 @@ class PortChannel: """ _channel_counts = defaultdict(int) + _channel_peer_counts = defaultdict(int) def __init__(self, dst_rank: int, src_rank: int): """Initialize a new PortChannel. 
@@ -474,6 +516,8 @@ def __init__(self, dst_rank: int, src_rank: int): self.channel_id = PortChannel._channel_counts[src_rank] PortChannel._channel_counts[src_rank] += 1 + self.channel_peer_id = PortChannel._channel_peer_counts[(src_rank, dst_rank)] + PortChannel._channel_peer_counts[(src_rank, dst_rank)] += 1 self.dst_rank = dst_rank self.src_rank = src_rank @@ -496,7 +540,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both): >>> channel.signal(tb=0, data_sync=SyncType.before) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = SignalOperation(tb_channel_ids, self.channel_type, data_sync) + op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) def wait(self, tb: int, data_sync: SyncType = SyncType.both): @@ -515,7 +559,7 @@ def wait(self, tb: int, data_sync: SyncType = SyncType.both): >>> channel.wait(tb=0, data_sync=SyncType.after) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = WaitOperation(tb_channel_ids, self.channel_type, data_sync) + op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) def flush(self, tb: int, data_sync: SyncType = SyncType.both): @@ -534,7 +578,7 @@ def flush(self, tb: int, data_sync: SyncType = SyncType.both): >>> channel.flush(tb=0, data_sync=SyncType.after) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = FlushOperation(tb_channel_ids, self.channel_type, data_sync) + op = FlushOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): @@ -573,6 +617,8 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): tb_channel_ids = get_program().setup_channel(tb, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, @@ -618,6 +664,8 @@ def put_with_signal(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): tb_channel_ids = get_program().setup_channel(tb, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, @@ -663,6 +711,8 @@ def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int) tb_channel_ids = get_program().setup_channel(tb, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, @@ -713,6 +763,8 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): tb_channel_ids = get_program().setup_channel(tb, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, @@ -836,13 +888,15 @@ def reduce(self, rank, buffer_offset, size, dst_chunk: Chunk, tb, reduce_op=Redu tb_channel_ids = get_program().setup_channel(tb, self) op = GroupLoadReduce( - self.buffer_type, - buffer_offset, - size, - 
dst_chunk, - tb_channel_ids, - self.channel_type, - reduce_op, + rank=self.src_rank, + threadblock=tb, + buffer_type=self.buffer_type, + buffer_offset=buffer_offset, + size=size, + dst_chunk=dst_chunk, + tb_channel_ids=tb_channel_ids, + channel_type=self.channel_type, + reduce_operation=reduce_op, ) get_program().add_operation(self.src_rank, tb, op) @@ -886,7 +940,7 @@ def broadcast(self, rank, src_chunk: Chunk, buffer_offset, size, tb): ) tb_channel_ids = get_program().setup_channel(tb, self) - op = GroupStore(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type) + op = GroupStore(self.src_rank, tb, src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type) get_program().add_operation(self.src_rank, tb, op) class SwitchChannelRankView: diff --git a/python/mscclpp/language/internal/buffer_access.py b/python/mscclpp/language/internal/buffer_access.py index ab8a7bdcc..f2b4987b9 100644 --- a/python/mscclpp/language/internal/buffer_access.py +++ b/python/mscclpp/language/internal/buffer_access.py @@ -3,92 +3,126 @@ from sortedcontainers import SortedDict from typing import List -from mscclpp.language.internal.types import BufferType, DataAccessType +from mscclpp.language.internal.types import * from mscclpp.language.internal.operations import * from enum import Enum class BuffersAccess: - def __init__(self): - self.intervals = { - BufferType.input: SortedDict(), - BufferType.output: SortedDict(), - BufferType.scratch: SortedDict(), - } + def __init__(self, num_ranks): + self.rank_intervals = [ + { + BufferType.input: SortedDict(), + BufferType.output: SortedDict(), + BufferType.scratch: SortedDict(), + } + for _ in range(num_ranks) + ] + self.track_sync = {} + self.track_barrier = {} def process_operations(self, operations): result_operations = [] - for operation in operations: + for i in range(len(operations)): + operation = operations[i] if operation.name == Instruction.nop or operation.name == Instruction.barrier: - self.clear_data_access() + self.track_sync[operation.rank, operation.threadblock] = i + if operation.name == Instruction.barrier: + self.update_barrier(operation, i) else: if operation.name == Instruction.pipeline: pipeline_buffer_access = BuffersAccess() pipeline_result_operations = pipeline_buffer_access.process_operations(operation.operations) operation.operations = pipeline_result_operations - data_access = operation.local_data_access() - sync_added = False + data_access = operation.local_data_access(i) + data_access_conflict = DataAccessConflict(operation.rank) for data_access_element in data_access: - if self.compute_data_access(data_access_element) and not sync_added: - result_operations.append(SyncOperation()) - sync_added = True + data_access_conflict = data_access_conflict + self.compute_data_access(data_access_element) + fix_operations = self.resolve_conflicts(operation.rank, operation.threadblock, i, data_access_conflict) + result_operations.extend(fix_operations) result_operations.append(operation) return result_operations + def update_barrier(self, operation, order_id): + for tb in operation.barrier_info.tb_list: + if operation.threadblock != tb: + self.track_barrier[operation.rank, operation.threadblock, tb] = order_id + def compute_data_access(self, data_access: DataAccess) -> bool: - keys = self.intervals[data_access.buffer_type].keys() + intervals = self.rank_intervals[data_access.rank] + keys = intervals[data_access.buffer_type].keys() idx = self.lower_bound(0, len(keys) - 1, keys, data_access) - 
conflict = False + conflict = DataAccessConflict(data_access.rank) while len(keys) > 0 and data_access.overlaps(keys[idx]): conflict_data_access = keys[idx] - conflict_operation_type = self.intervals[data_access.buffer_type][conflict_data_access] - if data_access.check_conflict(conflict_data_access): - self.clear_data_access() - conflict = True - break + conflict_operation_type = intervals[data_access.buffer_type][conflict_data_access] + conflict = conflict + data_access.check_conflict(conflict_data_access) - self.intervals[data_access.buffer_type].pop(conflict_data_access) + intervals[data_access.buffer_type].pop(conflict_data_access) if conflict_data_access.end > data_access.end: - self.intervals[data_access.buffer_type][ + intervals[data_access.buffer_type][ DataAccess( - conflict_data_access.operation_id, - data_access.end + 1, + conflict_data_access.rank, + conflict_data_access.threadblock, + conflict_data_access.operation_global_id, + conflict_data_access.operation_order_id, + data_access.end, conflict_data_access.end, conflict_data_access.buffer_type, conflict_operation_type, ) ] = conflict_operation_type if conflict_data_access.start < data_access.start: - self.intervals[data_access.buffer_type][ + intervals[data_access.buffer_type][ DataAccess( - conflict_data_access.operation_id, + conflict_data_access.rank, + conflict_data_access.threadblock, + conflict_data_access.operation_global_id, + conflict_data_access.operation_order_id, conflict_data_access.start, - data_access.start - 1, + data_access.start, conflict_data_access.buffer_type, conflict_operation_type, ) ] = conflict_operation_type - keys = self.intervals[data_access.buffer_type].keys() + keys = intervals[data_access.buffer_type].keys() idx = self.lower_bound(0, len(keys) - 1, keys, data_access) - self.intervals[data_access.buffer_type][data_access] = data_access.data_access_type + intervals[data_access.buffer_type][data_access] = data_access.data_access_type return conflict + + def resolve_conflicts(self, rank, threadblock, order_id, data_access_conflict: DataAccessConflict): + fix_operations = [] + if data_access_conflict.conflict_type == DataAccessConflictType.intra_threadblock: + for tb in data_access_conflict.threadblocks: + if tb[1] > self.data_sync[(rank, threadblock)]: + fix_operations.append(SyncOperation(rank, threadblock)) + self.data_sync[(rank, threadblock)] = order_id + break + if data_access_conflict.conflict_type == DataAccessConflictType.inter_threadblock: + conflict_tb = [threadblock] + for tb in data_access_conflict.threadblocks: + if (threadblock, tb[0]) not in self.track_barrier or self.track_barrier[(threadblock, tb[0])] < tb[1]: + if not tb[2]: + raise RuntimeError("Operations order not defined.") + conflict_tb.append(tb[0]) + for tb in conflict_tb: + op = BarrierOperation(rank, tb, conflict_tb) + self.update_barrier(op, order_id) + fix_operations.append(op) - def clear_data_access(self): - self.intervals[BufferType.input].clear() - self.intervals[BufferType.output].clear() - self.intervals[BufferType.scratch].clear() + return fix_operations def lower_bound(self, init_pos, final_pos, data_access_list, data_access): if init_pos >= final_pos: return init_pos mid_pos = (init_pos + final_pos) // 2 - if data_access.start <= data_access_list[mid_pos].end: + if data_access.lower_overlaps(data_access_list[mid_pos]): final_pos = mid_pos else: init_pos = mid_pos + 1 diff --git a/python/mscclpp/language/internal/op_dep_graph.py b/python/mscclpp/language/internal/op_dep_graph.py new file mode 100644 index 
000000000..d884ffffb
--- /dev/null
+++ b/python/mscclpp/language/internal/op_dep_graph.py
@@ -0,0 +1,283 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from mscclpp.language.internal.operations import *
+from mscclpp.language.internal.types import *
+from mscclpp.language.internal.register import ChannelRegister, SemaphoreRegister
+from queue import Queue
+from typing import Set, Dict, Tuple
+
+
+class OperationDependencyGraph:
+    """
+    A DAG structure to enforce correct execution order of collective communication operations.
+    Supports topological sorting based on rank/threadblock execution and signal/wait synchronization.
+    """
+
+    def __init__(self):
+        self.root_nodes: Set[OperationDependencyGraph.Node] = set()
+        self.last_node: Dict[Tuple[int, int], int] = {}
+        self.signalling: Dict[Tuple[int, int, int], Queue] = {}
+        self.waiting: Dict[Tuple[int, int, int], Queue] = {}
+
+        self.barrier_nodes: Dict[Tuple[int, int], List[OperationDependencyGraph.Node]] = {}
+        self.tb_barriers: Dict[Tuple[int, int, int], int] = {}
+
+    def add_operation(self, operation, agg_node=None):
+        """
+        Inserts an operation into the DAG, adding edges based on dependencies.
+        """
+        rank = operation.rank
+        threadblock = operation.threadblock
+        node = self.Node(operation)
+        if agg_node is not None:
+            agg_node.add_node(node)
+
+        if isinstance(operation, BarrierOperation):
+            if (rank, threadblock, operation.barrier_id) not in self.tb_barriers:
+                self.tb_barriers[(rank, threadblock, operation.barrier_id)] = 0
+            if (rank, operation.barrier_id) not in self.barrier_nodes:
+                self.barrier_nodes[(rank, operation.barrier_id)] = []
+
+            barrier_count = self.tb_barriers[(rank, threadblock, operation.barrier_id)]
+            if barrier_count > len(self.barrier_nodes[(rank, operation.barrier_id)]):
+                raise RuntimeError(f"Barrier node not created correctly for rank {rank}, threadblock {threadblock}, barrier_id {operation.barrier_id}.")
+            elif barrier_count == len(self.barrier_nodes[(rank, operation.barrier_id)]):
+                agg_node = self.AggregateNode()
+                self.barrier_nodes[(rank, operation.barrier_id)].append(agg_node)
+            else:
+                agg_node = self.barrier_nodes[(rank, operation.barrier_id)][barrier_count]
+
+            self.tb_barriers[(rank, threadblock, operation.barrier_id)] += 1
+            agg_node.add_node(node)
+            node = agg_node
+
+        if (rank, threadblock) not in self.last_node:
+            self.last_node[(rank, threadblock)] = node
+            if node.get_input() == 0:
+                self.root_nodes.add(node)
+        else:
+            prev_node = self.last_node[(rank, threadblock)]
+            if prev_node is not node:
+                prev_node.next_nodes.append(node)
+                node.previous_nodes.append(prev_node)
+                node.add_input()
+                self.last_node[(rank, threadblock)] = node
+                if node in self.root_nodes:
+                    self.root_nodes.remove(node)
+
+        if isinstance(operation, SignalOperation) or (isinstance(operation, PutOperation) and (operation.with_signal or operation.with_signal_and_flush)):
+            for tb_channel_id in operation.channel_ids:
+                channel = ChannelRegister.get_channel(rank, threadblock, tb_channel_id)
+                op_info = (channel.src_rank, channel.dst_rank, channel.channel_peer_id)
+                if op_info not in self.waiting or self.waiting[op_info].empty():
+                    if op_info not in self.signalling:
+                        self.signalling[op_info] = Queue()
+                    self.signalling[op_info].put(node)
+                else:
+                    waiting_node = self.waiting[op_info].get()
+                    node.next_nodes.append(waiting_node)
+                    waiting_node.previous_nodes.append(node)
+                    waiting_node.add_input()
+
+        if isinstance(operation, WaitOperation):
+            for tb_channel_id in operation.channel_ids:
+                channel = ChannelRegister.get_channel(rank, threadblock, tb_channel_id)
+                op_info = (channel.dst_rank, channel.src_rank, channel.channel_peer_id)
+                if op_info not in self.signalling or self.signalling[op_info].empty():
+                    if op_info not in self.waiting:
+                        self.waiting[op_info] = Queue()
+                    self.waiting[op_info].put(node)
+                else:
+                    signalling_node = self.signalling[op_info].get()
+                    signalling_node.next_nodes.append(node)
+                    node.previous_nodes.append(signalling_node)
+                    node.add_input()
+
+        return node
+
+    def add_tbg_operation(self, operations):
+        agg_node = self.AggregateNode()
+        for operation in operations:
+            self.add_operation(operation, agg_node)
+
+    def add_semaphore_dependency(self):
+        queue = Queue()
+        sem_rel = {}
+        sem_acq = {}
+        sem_val = {}
+
+        def compute_sem_op(sem_op, node):
+            for id in node.operation.semaphore_ids:
+                if (node.operation.rank, id) not in sem_op:
+                    sem_op[(node.operation.rank, id)] = []
+                    sem_val[(node.operation.rank, id)] = SemaphoreRegister.semaphores[(node.operation.rank, id)].initial_value
+                sem_op[(node.operation.rank, id)].append(node)
+
+        def process_node(node):
+            for next_node in node.next_nodes:
+                next_node.add_reach()
+                if next_node.get_reach() == next_node.get_input():
+                    queue.put(next_node)
+
+        for node in self.root_nodes:
+            queue.put(node)
+
+        while True:
+            while not queue.empty():
+                node = queue.get()
+                if isinstance(node, self.Node) and isinstance(node.operation, SemaphoreReleaseOperation):
+                    compute_sem_op(sem_rel, node)
+                elif isinstance(node, self.Node) and isinstance(node.operation, SemaphoreAcquireOperation):
+                    compute_sem_op(sem_acq, node)
+                else:
+                    process_node(node)
+
+            if not sem_rel and not sem_acq:
+                break
+            else:
+                if sem_rel.keys() != sem_acq.keys():
+                    raise RuntimeError("Undefined Semaphore Behaviour.")
+                else:
+                    for key, sem_rel_nodes in sem_rel.items():
+                        if len(sem_acq[key]) > 1 or sem_val[key] != sem_rel[key] - sem_acq[key]:
+                            raise RuntimeError(f"Undefined Behaviour Semaphore Id {key[1]}.")
+                        else:
+                            sem_acq_node = sem_acq[key][0]
+                            process_node(sem_acq_node)
+                            for sem_rel_node in sem_rel_nodes:
+                                sem_rel_node.next_nodes.append(sem_acq_node)
+                                sem_acq_node.previous_nodes.append(sem_rel_node)
+                                sem_acq_node.add_input()
+                                process_node(sem_rel_node)
+
+    def reset(self):
+        queue = Queue()
+        visited = set()
+        for node in self.root_nodes:
+            visited.add(node)
+            queue.put(node)
+
+        while not queue.empty():
+            node = queue.get()
+            node.reset()
+            for next_node in node.next_nodes:
+                if next_node not in visited:
+                    visited.add(next_node)
+                    queue.put(next_node)
+
+    def check(self):
+        """
+        Validates the DAG structure, ensuring all nodes are reachable and dependencies are correctly set.
+        """
+        if len(self.signalling) > 0:
+            for key, queue in self.signalling.items():
+                if not queue.empty():
+                    raise RuntimeError(f"Signalling from {key[0]} to {key[1]} on channel {key[2]} has no matching wait operation.")
+        if len(self.waiting) > 0:
+            for key, queue in self.waiting.items():
+                if not queue.empty():
+                    raise RuntimeError(f"Waiting for {key[0]} to {key[1]} on channel {key[2]} has no matching signal operation.")
+
+    def get_execution_order(self):
+        """
+        Returns the order of operations in the DAG.
+ """ + self.reset() + self.check() + + order = [] + queue = Queue() + for node in self.root_nodes: + queue.put(node) + + while not queue.empty(): + node = queue.get() + order.extend(node.get_operations()) + for next_node in node.next_nodes: + next_node.add_reach() + if next_node.get_reach() == next_node.get_input(): + if isinstance(next_node, self.Node) and next_node.agg_node is not None: + for sub_node in next_node.agg_node.nodes: + queue.put(sub_node) + else: + queue.put(next_node) + + return order + + class BaseNode(): + def __init__(self): + self.previous_nodes = [] + self.next_nodes = [] + self.input = 0 + self.reach = 0 + + def add_input(self): + self.input += 1 + + def add_reach(self): + self.reach += 1 + + def get_input(self): + return self.input + + def get_reach(self): + return self.reach + + def reset(self): + self.reach = 0 + + class Node(BaseNode): + def __init__(self, operation): + self.operation = operation + self.agg_node = None + super().__init__() + + def get_operations(self): + return [self.operation] + + def add_input(self): + if self.agg_node is not None: + self.agg_node.input += 1 + else: + self.input += 1 + + def add_reach(self): + if self.agg_node is not None: + self.agg_node.reach += 1 + else: + self.reach += 1 + + def get_input(self): + if self.agg_node is not None: + return self.agg_node.input + else: + return self.input + + def get_reach(self): + if self.agg_node is not None: + return self.agg_node.reach + else: + return self.reach + + def reset(self): + if self.agg_node is not None: + self.agg_node.reset() + else: + self.reach = 0 + + + class AggregateNode(BaseNode): + def __init__(self): + self.nodes = [] + super().__init__() + + def add_node(self, node): + self.nodes.append(node) + node.agg_node = self + + def get_operations(self): + operations = [] + for node in self.nodes: + operations.append(node) + return operations \ No newline at end of file diff --git a/python/mscclpp/language/internal/operations.py b/python/mscclpp/language/internal/operations.py index 1083c45bd..9bdd90aab 100644 --- a/python/mscclpp/language/internal/operations.py +++ b/python/mscclpp/language/internal/operations.py @@ -11,6 +11,7 @@ DataAccess, DataAccessType, ) +from mscclpp.language.thread_block_group import ThreadBlockGroup from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import List @@ -32,6 +33,8 @@ class BaseOperation(ABC): """ id: uuid.UUID = field(default_factory=uuid.uuid4, init=False) + rank: int + threadblock: int name: str def local_data_access(self, sync_purpose=True): @@ -120,10 +123,16 @@ class ThreadBlockGroupInfo: def to_dict(self): return {"tb_id": self.tb_id, "tbg_size": self.tbg_size} + def start_offset(self, size): + return (size / self.tbg_size) * self.tb_id + + def end_offset(self, size): + return (size / self.tbg_size) * (self.tb_id + 1) + class SyncOperation(BaseOperation): - def __init__(self): - super().__init__(Instruction.nop) + def __init__(self, rank: int, threadblock: int): + super().__init__(rank, threadblock, Instruction.nop) def __add__(self, other): fused_operation = None @@ -146,36 +155,58 @@ def to_dict(self): class CopyOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, src_buff: List[LocalChunk], dst_buff: List[LocalChunk], - tbg_info: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroupInfo = None, from_packet: bool = False, to_packet: bool = False, ): if from_packet and to_packet: raise RuntimeError(f"Copy Operation from Packet to Packet is not Supported.") elif from_packet: - 
super().__init__(Instruction.unpack_packet) + super().__init__(rank, threadblock, Instruction.copy_packet) elif to_packet: - super().__init__(Instruction.copy_packet) + super().__init__(rank, threadblock, Instruction.transform_to_packet) else: - super().__init__(Instruction.copy) + super().__init__(rank, threadblock, Instruction.copy) self.src_buff = src_buff self.dst_buff = dst_buff - self.tbg_info = tbg_info + self.tbg = tbg - def local_data_access(self, sync_purpose=True): + def local_data_access(self, order_id, sync_purpose=True): data_access = [] if self.name != Instruction.unpack_packet or not sync_purpose: for chunk in self.src_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.read) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.type, + DataAccessType.read, + self.tbg + ) ) if self.name != Instruction.copy_packet or not sync_purpose: for chunk in self.dst_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.write) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.type, + DataAccessType.write, + self.tbg + ) ) return data_access @@ -193,14 +224,14 @@ def to_dict(self): result["dst_buff"] = [] for chunk in self.dst_buff: result["dst_buff"].append(chunk.to_dict()) - if self.tbg_info is not None: - result["tbg_info"] = self.tbg_info.to_dict() + if self.tbg is not None: + result["tbg"] = self.tbg.to_dict(self.threadblock) return result class SemaphoreAcquireOperation(BaseOperation): - def __init__(self, semaphore_ids: List[int], data_sync: SyncType = SyncType.none): - super().__init__(Instruction.sem_acquire) + def __init__(self, rank: int, threadblock: int, semaphore_ids: List[int], data_sync: SyncType = SyncType.none): + super().__init__(rank, threadblock, Instruction.sem_acquire) self.semaphore_ids = semaphore_ids self.data_sync = data_sync @@ -232,8 +263,8 @@ def to_dict(self): class SemaphoreReleaseOperation(BaseOperation): - def __init__(self, semaphore_ids: List[int], data_sync: SyncType = SyncType.none): - super().__init__(Instruction.sem_release) + def __init__(self, rank: int, threadblock: int, semaphore_ids: List[int], data_sync: SyncType = SyncType.none): + super().__init__(rank, threadblock, Instruction.sem_release) self.semaphore_ids = semaphore_ids self.data_sync = data_sync @@ -267,15 +298,17 @@ def to_dict(self): class SignalOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, channels_ids: List[int], channel_type: ChannelType, data_sync: SyncType = SyncType.none, relaxed: bool = False, ): if relaxed: - super().__init__(Instruction.relaxed_signal) + super().__init__(rank, threadblock, Instruction.relaxed_signal) else: - super().__init__(Instruction.signal) + super().__init__(rank, threadblock, Instruction.signal) self.channel_ids = set(channels_ids) self.channel_type = channel_type self.data_sync = data_sync @@ -314,15 +347,17 @@ def to_dict(self): class WaitOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, channels_ids: List[int], 
channel_type: ChannelType, data_sync: SyncType = SyncType.none, relaxed: bool = False, ): if relaxed: - super().__init__(Instruction.relaxed_wait) + super().__init__(rank, threadblock, Instruction.relaxed_wait) else: - super().__init__(Instruction.wait) + super().__init__(rank, threadblock, Instruction.wait) self.channel_ids = set(channels_ids) self.channel_type = channel_type self.data_sync = data_sync @@ -361,7 +396,7 @@ def to_dict(self): class BarrierOperation(BaseOperation): __current_barriers = [] - def __init__(self, rank: int, tb_list: List[int]): + def __init__(self, rank: int, threadblock: int, tb_list: List[int]): for _ in range(len(BarrierOperation.__current_barriers), rank + 1): BarrierOperation.__current_barriers.append({}) barrier_info = BarrierOperation.BarrierInfo(tb_list) @@ -372,7 +407,7 @@ def __init__(self, rank: int, tb_list: List[int]): else: self.barrier_id = BarrierOperation.__current_barriers[rank][barrier_info] - super().__init__(Instruction.barrier) + super().__init__(rank, threadblock, Instruction.barrier) self.barrier_info = barrier_info def shift_ids(self, instance, num_instances, replication_function): @@ -406,8 +441,15 @@ def __hash__(self): class FlushOperation(BaseOperation): - def __init__(self, channels_ids: List[int], channel_type: ChannelType, data_sync: SyncType = SyncType.none): - super().__init__(Instruction.flush) + def __init__( + self, + rank: int, + threadblock: int, + channels_ids: List[int], + channel_type: ChannelType, + data_sync: SyncType = SyncType.none, + ): + super().__init__(rank, threadblock, Instruction.flush) self.channel_ids = set(channels_ids) self.channel_type = channel_type self.data_sync = data_sync @@ -440,24 +482,36 @@ def to_dict(self): class GetOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, src_buff: List[RemoteChunk], dst_buff: List[LocalChunk], channel_ids: List[int], channel_type: ChannelType, - tbg_info: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroup = None, ): - super().__init__(Instruction.get) + super().__init__(rank, threadblock, Instruction.get) self.src_buff = src_buff self.dst_buff = dst_buff self.channel_ids = channel_ids self.channel_type = channel_type - self.tbg_info = tbg_info + self.tbg = tbg - def local_data_access(self, sync_purpose=True): + def local_data_access(self, order_id, sync_purpose=True): data_access = [] for chunk in self.dst_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.write) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.type, + DataAccessType.write, + self.tbg + ) ) return data_access @@ -473,14 +527,14 @@ def __add__(self, other): isinstance(other, GetOperation) and self.src_buff[0].size == other.src_buff[0].size and self.channel_type == other.channel_type - and self.tbg_info == other.tbg_info + and self.tbg == other.tbg ): fused_operation = GetOperation( src_buff=self.src_buff + other.src_buff, dst_buff=self.dst_buff + other.dst_buff, channel_ids=self.channel_ids + other.channel_ids, channel_type=self.channel_type, - tbg_info=self.tbg_info, + tbg=self.tbg, ) return fused_operation @@ -495,40 +549,42 @@ def to_dict(self): result["dst_buff"].append(chunk.to_dict()) result["channel_ids"] = self.channel_ids result["channel_type"] = self.channel_type.value 
- if self.tbg_info is not None: - result["tbg_info"] = self.tbg_info.to_dict() + if self.tbg is not None: + result["tbg"] = self.tbg.to_dict(self.threadblock) return result class PutOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, src_buff: List[LocalChunk], dst_buff: List[RemoteChunk], channel_ids: List[int], channel_type: ChannelType, - tbg_info: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroupInfo = None, from_packet: bool = False, to_packet: bool = False, with_signal: bool = False, with_signal_and_flush: bool = False, ): if from_packet and to_packet: - super().__init__(Instruction.read_put_packet) + super().__init__(rank, threadblock, Instruction.read_put_packet) elif to_packet: - super().__init__(Instruction.put_packet) + super().__init__(rank, threadblock, Instruction.put_packet) elif from_packet: raise RuntimeError(f"Put Operation from Packet is not Supported.") else: if with_signal: if with_signal_and_flush: - super().__init__(Instruction.put_with_signal_and_flush) + super().__init__(rank, threadblock, Instruction.put_with_signal_and_flush) else: - super().__init__(Instruction.put_with_signal) + super().__init__(rank, threadblock, Instruction.put_with_signal) elif with_signal_and_flush: - super().__init__(Instruction.put_with_signal_and_flush) + super().__init__(rank, threadblock, Instruction.put_with_signal_and_flush) else: - super().__init__(Instruction.put) + super().__init__(rank, threadblock, Instruction.put) self.src_buff = src_buff self.dst_buff = dst_buff @@ -537,14 +593,24 @@ def __init__( self.to_packet = to_packet self.with_signal = with_signal self.with_signal_and_flush = with_signal_and_flush - self.tbg_info = tbg_info + self.tbg = tbg - def local_data_access(self, sync_purpose=True): + def local_data_access(self, order_id, sync_purpose=True): data_access = [] if self.name != Instruction.read_put_packet or not sync_purpose: for chunk in self.src_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.read) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.type, + DataAccessType.read, + self.tbg + ) ) return data_access @@ -567,14 +633,14 @@ def __add__(self, other): and self.name == other.name and self.src_buff[0].size == other.src_buff[0].size and self.channel_type == other.channel_type - and self.tbg_info == other.tbg_info + and self.tbg == other.tbg ): fused_operation = PutOperation( src_buff=self.src_buff + other.src_buff, dst_buff=self.dst_buff + other.dst_buff, channel_ids=self.channel_ids + other.channel_ids, channel_type=self.channel_type, - tbg_info=self.tbg_info, + tbg=self.tbg, to_packet=self.to_packet, with_signal=self.with_signal, with_signal_and_flush=self.with_signal_and_flush, @@ -593,8 +659,8 @@ def to_dict(self): if self.channel_type == ChannelType.port: result["channel_ids"] = self.channel_ids result["channel_type"] = self.channel_type.value - if self.tbg_info is not None: - result["tbg_info"] = self.tbg_info.to_dict() + if self.tbg is not None: + result["tbg"] = self.tbg.to_dict(self.threadblock) return result @@ -602,6 +668,8 @@ def to_dict(self): class ReduceOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, local_src_buff: List[LocalChunk], local_dst_buff: List[LocalChunk], remote_src_buff: 
List[RemoteChunk] = None, @@ -610,7 +678,7 @@ def __init__( put_channel_ids: List[int] = None, channel_type: ChannelType = ChannelType.none, reduce_operation: ReduceOperationType = ReduceOperationType.sum, - tbg_info: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroupInfo = None, packet: bool = False, ): remote_src_buff = remote_src_buff if remote_src_buff is not None else [] @@ -620,18 +688,18 @@ def __init__( if len(remote_src_buff) == 0 and len(remote_dst_buff) == 0: if packet: - super().__init__(Instruction.reduce_packet) + super().__init__(rank, threadblock, Instruction.reduce_packet) else: - super().__init__(Instruction.reduce) + super().__init__(rank, threadblock, Instruction.reduce) elif len(remote_src_buff) == 0: if packet: - super().__init__(Instruction.reduce_send_packet) + super().__init__(rank, threadblock, Instruction.reduce_send_packet) else: - super().__init__(Instruction.reduce_send) + super().__init__(rank, threadblock, Instruction.reduce_send) elif len(remote_dst_buff) == 0 and not packet: - super().__init__(Instruction.read_reduce) + super().__init__(rank, threadblock, Instruction.read_reduce) elif not packet: - super().__init__(Instruction.read_reduce_send) + super().__init__(rank, threadblock, Instruction.read_reduce_send) else: raise RuntimeError(f"Reduce Operation invalid parameters.") @@ -643,20 +711,40 @@ def __init__( self.put_channel_ids = put_channel_ids self.channel_type = channel_type self.reduce_operation = reduce_operation - self.tbg_info = tbg_info + self.tbg = tbg self.packet = packet - def local_data_access(self, sync_purpose=True): + def local_data_access(self, order_id, sync_purpose=True): data_access = [] for i in range(len(self.local_src_buff)): chunk = self.local_src_buff[i] if not self.packet or i != 0 or not sync_purpose: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.read) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.type, + DataAccessType.read, + self.tbg + ) ) for chunk in self.local_dst_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.write) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.type, + DataAccessType.write, + self.tbg + ) ) return data_access @@ -684,7 +772,7 @@ def __add__(self, other): and self.local_dst_buff == other.local_dst_buff and self.channel_type == other.channel_type and self.reduce_operation == other.reduce_operation - and self.tbg_info == other.tbg_info + and self.tbg == other.tbg ): fused_operation = ReduceOperation( self.local_src_buff + other.local_src_buff[1:], @@ -693,7 +781,7 @@ def __add__(self, other): channel_ids=self.channel_ids + other.channel_ids, channel_type=self.channel_type, reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, + tbg=self.tbg, packet=self.packet, ) if ( @@ -707,7 +795,7 @@ def __add__(self, other): and other.name == Instruction.put and self.local_dst_buff[0] == other.src_buff[0] and other.channel_type == ChannelType.memory - and self.tbg_info == other.tbg_info + and self.tbg 
== other.tbg ): fused_operation = ReduceOperation( self.local_src_buff, @@ -718,7 +806,7 @@ def __add__(self, other): put_channel_ids=self.put_channel_ids + other.channel_ids, channel_type=self.channel_type, reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, + tbg=self.tbg, packet=self.packet, ) if ( @@ -727,7 +815,7 @@ def __add__(self, other): and other.name == Instruction.put_packet and self.local_dst_buff[0] == other.src_buff[0] and other.channel_type == ChannelType.memory - and self.tbg_info == other.tbg_info + and self.tbg == other.tbg ): fused_operation = ReduceOperation( self.local_src_buff, @@ -738,7 +826,7 @@ def __add__(self, other): put_channel_ids=self.put_channel_ids + other.channel_ids, channel_type=other.channel_type, reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, + tbg=self.tbg, packet=self.packet, ) @@ -763,8 +851,8 @@ def to_dict(self): if self.channel_type != ChannelType.none: result["channel_type"] = self.channel_type.value result["reduce_op"] = self.reduce_operation.value - if self.tbg_info is not None: - result["tbg_info"] = self.tbg_info.to_dict() + if self.tbg is not None: + result["tbg"] = self.tbg.to_dict(self.threadblock) return result @@ -772,6 +860,8 @@ def to_dict(self): class GroupLoadReduce(BaseOperation): def __init__( self, + rank, + threadblock: int, buffer_type: BufferType, buffer_offset: int, size: int, @@ -780,7 +870,7 @@ def __init__( channel_type: ChannelType = ChannelType.switch, reduce_operation: ReduceOperationType = ReduceOperationType.sum, ): - super().__init__(Instruction.group_load_reduce) + super().__init__(rank, threadblock, Instruction.group_load_reduce) self.buffer_type = buffer_type self.buffer_offset = buffer_offset self.size = size @@ -831,6 +921,8 @@ def to_dict(self): class GroupStore(BaseOperation): def __init__( self, + rank, + threadblock: int, src_chunk: Chunk, buffer_type: BufferType, buffer_offset: int, @@ -838,7 +930,7 @@ def __init__( channel_ids: List[int], channel_type: ChannelType = ChannelType.switch, ): - super().__init__(Instruction.group_store) + super().__init__(rank, threadblock, Instruction.group_store) self.src_chunk = src_chunk self.buffer_type = buffer_type self.buffer_offset = buffer_offset @@ -865,6 +957,8 @@ def to_dict(self): class GroupLoadReduceStore(BaseOperation): def __init__( self, + rank, + threadblock: int, buffer_type: BufferType, size: int, src_index: List[int], @@ -873,7 +967,7 @@ def __init__( channel_type: ChannelType = ChannelType.switch, reduce_operation: ReduceOperationType = ReduceOperationType.sum, ): - super().__init__(Instruction.group_load_reduce_store) + super().__init__(rank, threadblock, Instruction.group_load_reduce_store) self.buffer_type = buffer_type self.size = size self.src_index = src_index @@ -907,8 +1001,8 @@ def to_dict(self): @dataclass class PipelineOperation(BaseOperation): - def __init__(self, unit_size: int, num_chunks: int, operations=None): - super().__init__(Instruction.pipeline) + def __init__(self, rank: int, threadblock: int, unit_size: int, num_chunks: int, operations=None): + super().__init__(rank, threadblock, Instruction.pipeline) self.unit_size = unit_size self.num_chunks = num_chunks self.operations = operations if operations is not None else [] diff --git a/python/mscclpp/language/internal/register.py b/python/mscclpp/language/internal/register.py new file mode 100644 index 000000000..95e696cb0 --- /dev/null +++ b/python/mscclpp/language/internal/register.py @@ -0,0 +1,22 @@ + +class ChannelRegister: + channels = {} + + 
@staticmethod + def add_channel(rank, tb, tb_channel_id, channel): + ChannelRegister.channels[(rank, tb, tb_channel_id)] = channel + + @staticmethod + def get_channel(rank: int, threadblock: int, tb_channel_id: int): + return ChannelRegister.channels.get((rank, threadblock, tb_channel_id)) + +class SemaphoreRegister: + semaphores = {} + + @staticmethod + def add_semaphore(semaphore): + SemaphoreRegister.semaphores[(semaphore.rank, semaphore.id)] = semaphore + + @staticmethod + def get_channel(rank: int, semaphore_id: int): + return SemaphoreRegister.semaphores.get((rank, semaphore_id)) \ No newline at end of file diff --git a/python/mscclpp/language/internal/types.py b/python/mscclpp/language/internal/types.py index 411a46959..8f92cea64 100644 --- a/python/mscclpp/language/internal/types.py +++ b/python/mscclpp/language/internal/types.py @@ -1,11 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from dataclasses import dataclass +from mscclpp.language.thread_block_group import ThreadBlockGroup +from dataclasses import dataclass, field from enum import Enum from typing import List, Set from collections import defaultdict - +import uuid class SyncType(Enum): none = "none" @@ -172,33 +173,77 @@ def __str__(self): @dataclass class DataAccess: - operation_id: int - start: int - end: int + rank: int + threadblock: int + operation_global_id: uuid.UUID + operation_order_id: int + start: float + end: float buffer_type: BufferType data_access_type: DataAccessType - + tb_group: ThreadBlockGroup = None + def __lt__(self, other): if self.start != other.start: return self.start < other.start return self.end < other.end - def __eq__(self, other): - return self.start == other.start and self.end == other.end + def __eq__(self, other, tolerance=1e-5): + return (abs(self.start - other.start) < tolerance and + abs(self.end - other.end) < tolerance) def __hash__(self): return hash((self.start, self.end)) - def overlaps(self, other) -> bool: - return self.start <= other.end and other.start <= self.end + def lower_overlaps(self, other, tolerance=1e-5) -> bool: + return (self.start + tolerance < other.end) + + def overlaps(self, other, tolerance=1e-5) -> bool: + return (self.start + tolerance < other.end) and (other.start + tolerance < self.end) def check_conflict(self, other) -> bool: - return ( + if ( self.overlaps(other) - and self.operation_id != other.operation_id + and self.operation_global_id != other.operation_global_id and (self.data_access_type != DataAccessType.read or other.data_access_type != DataAccessType.read) - ) + ): + if self.threadblock == other.threadblock: + return DataAccessConflict(self.rank, {}, DataAccessConflictType.intra_threadblock) + else: + is_order_defined = ((self.tb_group is not None and other.tb_group is not None and self.tb_group.tbg_overlap(other.tb_group)) + or (self.tb_group is not None and other.tb_group is None and self.tb_group.tb_overlap(other.threadblock)) + or (self.tb_group is None and other.tb_group is not None and other.tb_group.tb_overlap(self.threadblock))) + return DataAccessConflict(self.rank, {(other.threadblock, other.operation_order_id, is_order_defined)}, DataAccessConflictType.inter_threadblock) + else: + return DataAccessConflict(self.rank) + +class DataAccessConflictType(Enum): + inter_threadblock = "inter_tb" + intra_threadblock = "intra_tb" + none = "none" + + def __add__(self, other): + if not isinstance(other, DataAccessConflictType): + return NotImplemented + + map_to_num = {DataAccessConflictType.none: 0, 
DataAccessConflictType.intra_threadblock: 1, DataAccessConflictType.inter_threadblock: 3} + map_to_dact = {0: DataAccessConflictType.none, 1: DataAccessConflictType.intra_threadblock, 3: DataAccessConflictType.inter_threadblock} + return map_to_dact[map_to_num[self] | map_to_num[other]] + + def __str__(self): + return self.value + +@dataclass +class DataAccessConflict(): + rank: int + threadblocks: Set[int] = field(default_factory=set) + conflict_type: DataAccessConflictType = DataAccessConflictType.none + + def __add__(self, other): + if not isinstance(other, DataAccessConflict): + return NotImplemented + return DataAccessConflict(self.rank, self.threadblocks | other.threadblocks, self.conflict_type + other.conflict_type) class ReplicationPolicy(Enum): interleaved = "interleaved" diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index 007a8bcf9..6fc0fd7e3 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -5,6 +5,9 @@ from mscclpp.language.internal.globals import set_program from mscclpp.language.internal.types import BufferType, RemoteBuffer, ChannelType, ReplicationPolicy from mscclpp.language.internal.gpu import Gpu +from mscclpp.language.internal.register import ChannelRegister, SemaphoreRegister +from mscclpp.language.internal.op_dep_graph import OperationDependencyGraph +from mscclpp.language.internal.buffer_access import BuffersAccess from typing import List import json @@ -99,6 +102,8 @@ def __init__( self.min_message_size = min_message_size self.max_message_size = max_message_size assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL" + self.op_dep_dag = OperationDependencyGraph() + self.buffers_access = BuffersAccess(num_ranks) self.buffers = collective.init_buffers() self.gpus: List[Gpu] = [] for rank in range(self.num_ranks): @@ -134,22 +139,36 @@ def add_channel(self, channel): def setup_channel(self, tb, channel): tb_channel_ids = [] tb_channel_ids.append(self.gpus[channel.src_rank].setup_channel(tb, channel)) + for tb_channel_id in tb_channel_ids: + ChannelRegister.add_channel(channel.src_rank, tb, tb_channel_id, channel) return tb_channel_ids def setup_remote_chunk(self, rank, tb, remote_chunk: RemoteBuffer, channel_access: ChannelType): return self.gpus[rank].add_remote_buffer(tb, remote_chunk, channel_access) def add_semaphore(self, semaphore): + SemaphoreRegister.add_semaphore(semaphore) self.gpus[semaphore.rank].add_semaphore(semaphore) def add_operation(self, rank, tb, operation): if self.loop_context != None: self.loop_context.add_operation(rank, tb, operation) else: - self.gpus[rank].add_operation(tb, operation) + self.op_dep_dag.add_operation(operation) + + def add_tbg_operation(self, operations): + self.op_dep_dag.add_tbg_operation(operations) def post_process_operations(self): - for gpu in self.gpus: + #self.op_dep_dag.add_semaphore_dependency() + list_op = self.op_dep_dag.get_execution_order() + print(f"execution order operation: {list_op}") + list_op = self.buffers_access.process_operations(list_op) + print(f"adding sync operations: {list_op}") + for op in list_op: + self.gpus[op.rank].add_operation(op.threadblock, op) + + """ for gpu in self.gpus: if self.instr_fusion: gpu.optimize_operations() gpu.adding_data_sync() @@ -159,7 +178,7 @@ def post_process_operations(self): self.instances, self.get_default_replication_policy_function(), self.get_buffer_replication_policy_function(), - ) + ) """ def 
get_default_replication_policy_function(self): return lambda value, instance, num_instances: value * num_instances + instance diff --git a/python/mscclpp/language/rank.py b/python/mscclpp/language/rank.py index 02374c83a..e09905c00 100644 --- a/python/mscclpp/language/rank.py +++ b/python/mscclpp/language/rank.py @@ -110,20 +110,27 @@ def _copy( "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." ) + operations = [] for tb_id in tb_list: op = CopyOperation( + rank=self.rank, + threadblock=tb_id, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) + tbg=( + tb_group if tb_group is not None else None ), from_packet=from_packet, to_packet=to_packet, ) - - get_program().add_operation(self.rank, tb_id, op) + operations.append(op) + + if tb_group is None: + get_program().add_operation(self.rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def copy(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Copy data from source chunk to destination chunk. @@ -240,21 +247,28 @@ def reduce( "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." ) + operations = [] for tb_id in tb_list: op = ReduceOperation( - [LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)] + rank=self.rank, + threadblock=tb_id, + local_src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)] + [LocalChunk(chunk.buffer, chunk.index, chunk.size) for chunk in other_chunks], - [LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], + local_dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], reduce_operation=reduce_op, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) + tbg=( + tb_group if tb_group is not None else None ), packet=packet, ) + operations.append(op) - get_program().add_operation(self.rank, tb_id, op) + if tb_group is None: + get_program().add_operation(self.rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def barrier(self, tb_list: List[int]): """Create a synchronization barrier between thread blocks. @@ -278,8 +292,8 @@ def barrier(self, tb_list: List[int]): op = SyncOperation() get_program().add_operation(self.rank, tb_list[0], op) else: - op = BarrierOperation(self.rank, tb_list) for tb in tb_list: + op = BarrierOperation(self.rank, tb, tb_list) get_program().add_operation(self.rank, tb, op) diff --git a/python/mscclpp/language/thread_block_group.py b/python/mscclpp/language/thread_block_group.py index 9ffe4e2dd..a0d63c498 100644 --- a/python/mscclpp/language/thread_block_group.py +++ b/python/mscclpp/language/thread_block_group.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from typing import List, Dict +from typing import List, Dict, Set class ThreadBlockGroup: @@ -20,10 +20,9 @@ def __init__(self, tb_list: List[int]): tb_list: List of thread block objects """ - self.tb_list: List[int] = tb_list + self.tb_list: Set[int] = set(tb_list) self._tb_id: Dict[int, int] = {} - # Check for duplicates and build ID mapping seen = set() for i, tb in enumerate(self.tb_list): if tb in seen: @@ -51,3 +50,25 @@ def get_internal_id(self, tb: int) -> int: def numtb(self) -> int: """Return the number of thread blocks in the group.""" return len(self.tb_list) + + def tbg_overlap(self, other): + for tb in self.tb_list: + if tb in other.tb_list: + return True + return False + + def tb_overlap(self, tb_id): + return tb_id in self.tb_list + + def to_dict(self, tb): + return {"tb_id": self.get_internal_id(tb), "tbg_size": self.numtb()} + + def start_offset(self, tb, size): + tb_id = self.get_internal_id(tb) + return (size / self.numtb()) * tb_id + + def end_offset(self, tb, size): + tb_id = self.get_internal_id(tb) + return (size / self.numtb()) * (tb_id + 1) + + From 4615007c73ce5e242b57c34abb86ec51945715a8 Mon Sep 17 00:00:00 2001 From: Caio Rocha Date: Tue, 9 Sep 2025 17:18:40 +0000 Subject: [PATCH 2/7] wip --- python/mscclpp/language/internal/buffer_access.py | 14 +++++++++----- python/mscclpp/language/internal/types.py | 4 ++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/mscclpp/language/internal/buffer_access.py b/python/mscclpp/language/internal/buffer_access.py index f2b4987b9..c2b645670 100644 --- a/python/mscclpp/language/internal/buffer_access.py +++ b/python/mscclpp/language/internal/buffer_access.py @@ -49,6 +49,7 @@ def update_barrier(self, operation, order_id): for tb in operation.barrier_info.tb_list: if operation.threadblock != tb: self.track_barrier[operation.rank, operation.threadblock, tb] = order_id + self.track_sync[operation.rank, operation.threadblock] = order_id def compute_data_access(self, data_access: DataAccess) -> bool: intervals = self.rank_intervals[data_access.rank] @@ -73,6 +74,8 @@ def compute_data_access(self, data_access: DataAccess) -> bool: conflict_data_access.end, conflict_data_access.buffer_type, conflict_operation_type, + conflict_data_access.tb_group + ) ] = conflict_operation_type if conflict_data_access.start < data_access.start: @@ -86,6 +89,7 @@ def compute_data_access(self, data_access: DataAccess) -> bool: data_access.start, conflict_data_access.buffer_type, conflict_operation_type, + conflict_data_access.tb_group ) ] = conflict_operation_type @@ -99,17 +103,17 @@ def resolve_conflicts(self, rank, threadblock, order_id, data_access_conflict: D fix_operations = [] if data_access_conflict.conflict_type == DataAccessConflictType.intra_threadblock: for tb in data_access_conflict.threadblocks: - if tb[1] > self.data_sync[(rank, threadblock)]: + if (rank, threadblock) not in self.track_sync or tb[1] > self.track_sync[(rank, threadblock)]: fix_operations.append(SyncOperation(rank, threadblock)) - self.data_sync[(rank, threadblock)] = order_id + self.track_sync[(rank, threadblock)] = order_id break if data_access_conflict.conflict_type == DataAccessConflictType.inter_threadblock: - conflict_tb = [threadblock] + conflict_tb = set([threadblock]) for tb in data_access_conflict.threadblocks: - if (threadblock, tb[0]) not in self.track_barrier or self.track_barrier[(threadblock, tb[0])] < tb[1]: + if threadblock != tb[0] and ((rank, threadblock, tb[0]) not in self.track_barrier or self.track_barrier[(rank, threadblock, 
tb[0])] < tb[1]): if not tb[2]: raise RuntimeError("Operations order not defined.") - conflict_tb.append(tb[0]) + conflict_tb.add(tb[0]) for tb in conflict_tb: op = BarrierOperation(rank, tb, conflict_tb) self.update_barrier(op, order_id) diff --git a/python/mscclpp/language/internal/types.py b/python/mscclpp/language/internal/types.py index 8f92cea64..19000fdd3 100644 --- a/python/mscclpp/language/internal/types.py +++ b/python/mscclpp/language/internal/types.py @@ -208,12 +208,12 @@ def check_conflict(self, other) -> bool: and (self.data_access_type != DataAccessType.read or other.data_access_type != DataAccessType.read) ): if self.threadblock == other.threadblock: - return DataAccessConflict(self.rank, {}, DataAccessConflictType.intra_threadblock) + return DataAccessConflict(self.rank, {(other.threadblock, other.operation_order_id, True)}, DataAccessConflictType.intra_threadblock) else: is_order_defined = ((self.tb_group is not None and other.tb_group is not None and self.tb_group.tbg_overlap(other.tb_group)) or (self.tb_group is not None and other.tb_group is None and self.tb_group.tb_overlap(other.threadblock)) or (self.tb_group is None and other.tb_group is not None and other.tb_group.tb_overlap(self.threadblock))) - return DataAccessConflict(self.rank, {(other.threadblock, other.operation_order_id, is_order_defined)}, DataAccessConflictType.inter_threadblock) + return DataAccessConflict(self.rank, {(self.threadblock, other.operation_order_id, True), (other.threadblock, other.operation_order_id, is_order_defined)}, DataAccessConflictType.inter_threadblock) else: return DataAccessConflict(self.rank) From f4137b5e95616cec42c286af59a1a291412f8ef4 Mon Sep 17 00:00:00 2001 From: Caio Rocha Date: Wed, 10 Sep 2025 08:24:10 +0000 Subject: [PATCH 3/7] wip --- .../language/internal/buffer_access.py | 16 +++- .../mscclpp/language/internal/op_dep_graph.py | 83 ++++++++++++++----- .../mscclpp/language/internal/operations.py | 49 +++++------ python/mscclpp/language/internal/register.py | 2 +- python/mscclpp/language/program.py | 4 +- python/mscclpp/language/rank.py | 4 +- 6 files changed, 97 insertions(+), 61 deletions(-) diff --git a/python/mscclpp/language/internal/buffer_access.py b/python/mscclpp/language/internal/buffer_access.py index c2b645670..7c5f3d023 100644 --- a/python/mscclpp/language/internal/buffer_access.py +++ b/python/mscclpp/language/internal/buffer_access.py @@ -29,6 +29,8 @@ def process_operations(self, operations): self.track_sync[operation.rank, operation.threadblock] = i if operation.name == Instruction.barrier: self.update_barrier(operation, i) + if operation.name == Instruction.sem_acquire: + self.update_semaphore(operation, i) else: if operation.name == Instruction.pipeline: pipeline_buffer_access = BuffersAccess() @@ -51,6 +53,11 @@ def update_barrier(self, operation, order_id): self.track_barrier[operation.rank, operation.threadblock, tb] = order_id self.track_sync[operation.rank, operation.threadblock] = order_id + def update_semaphore(self, operation, order_id): + for tb in operation.tb_sync: + if operation.threadblock != tb: + self.track_barrier[operation.rank, operation.threadblock, tb] = order_id + def compute_data_access(self, data_access: DataAccess) -> bool: intervals = self.rank_intervals[data_access.rank] keys = intervals[data_access.buffer_type].keys() @@ -114,10 +121,11 @@ def resolve_conflicts(self, rank, threadblock, order_id, data_access_conflict: D if not tb[2]: raise RuntimeError("Operations order not defined.") conflict_tb.add(tb[0]) - for tb in 
conflict_tb: - op = BarrierOperation(rank, tb, conflict_tb) - self.update_barrier(op, order_id) - fix_operations.append(op) + if len(conflict_tb) > 1: + for tb in conflict_tb: + op = BarrierOperation(rank, tb, conflict_tb) + self.update_barrier(op, order_id) + fix_operations.append(op) return fix_operations diff --git a/python/mscclpp/language/internal/op_dep_graph.py b/python/mscclpp/language/internal/op_dep_graph.py index d884ffffb..1f5cb6f0b 100644 --- a/python/mscclpp/language/internal/op_dep_graph.py +++ b/python/mscclpp/language/internal/op_dep_graph.py @@ -22,6 +22,7 @@ def __init__(self): self.barrier_nodes: Dict[Tuple[int, int], List[OperationDependencyGraph.Node]] = {} self.tb_barriers: Dict[Tuple[int, int, int], int] = {} + self.node_list = [] def add_operation(self, operation, agg_node = None): """ @@ -52,6 +53,7 @@ def add_operation(self, operation, agg_node = None): agg_node.add_node(node) node = agg_node + self.node_list.append(node) if (rank, threadblock) not in self.last_node: self.last_node[(rank, threadblock)] = node if node.get_input() == 0: @@ -107,18 +109,25 @@ def add_semaphore_dependency(self): sem_acq = {} sem_val = {} + self.reset() + def compute_sem_op(sem_op, node): - for id in node.operaion.semaphore_ids: - if (node.operaion.rank, id) not in sem_op: - sem_op[(node.operaion.rank, id)] = [] - sem_val[(node.operaion.rank, id)] = SemaphoreRegister[(node.operaion.rank, id)].initial_value - sem_op[(node.operaion.rank, id)].append(node) + for id in node.operation.semaphore_ids: + if (node.operation.rank, id) not in sem_op: + sem_op[(node.operation.rank, id)] = [] + sem_val[(node.operation.rank, id)] = SemaphoreRegister.get_semaphore(node.operation.rank, id).initial_value + sem_op[(node.operation.rank, id)].append(node) def process_node(node): for next_node in node.next_nodes: next_node.add_reach() if next_node.get_reach() == next_node.get_input(): - queue.put(next_node) + if isinstance(next_node, self.Node) and next_node.agg_node is not None: + for sub_node in next_node.agg_node.nodes: + queue.put(sub_node) + else: + queue.put(next_node) + for node in self.root_nodes: queue.put(node) @@ -136,35 +145,61 @@ def process_node(node): if not sem_rel and not sem_acq: break else: - if sem_rel.keys() != sem_acq.keys(): - raise RuntimeError(f"Undefined Semaphore Behaviour.") - else: - for key, sem_rel_nodes in sem_rel.keys(): - if len(sem_acq[key]) > 1 or sem_val[key] != sem_rel[key] - sem_acq[key]: + removed_keys = [] + for key in sem_acq.keys(): + if key in sem_rel: + if len(sem_acq[key]) > 1 or sem_val[key] != len(sem_rel[key]) - len(sem_acq[key]): raise RuntimeError(f"Undefined Behaviour Semaphore Id {key[1]}.") else: sem_acq_node = sem_acq[key][0] + sem_val[key] = 0 + if sem_acq_node in self.root_nodes: + self.root_nodes.remove(sem_acq_node) process_node(sem_acq_node) - for sem_rel_node in sem_rel_nodes: - sem_rel_nodes.next_nodes.append(sem_acq_node) + for sem_rel_node in sem_rel[key]: + process_node(sem_rel_node) + sem_rel_node.next_nodes.append(sem_acq_node) + sem_acq_node.operation.add_tb_sync(sem_rel_node.operation.threadblock) sem_acq_node.previous_nodes.append(sem_rel_node) sem_acq_node.add_input() - process_node(sem_rel_node) + + removed_keys.append(key) + + for key in removed_keys: + sem_rel.pop(key) + sem_acq.pop(key) + + if len(sem_rel.keys()) > 0 or len(sem_acq.keys()): + raise RuntimeError(f"Undefined Semaphore Behaviour.") def reset(self): + for node in self.node_list: + node.reset() + + def print(self): + """ + Returns the order of operations in the DAG. 
+ """ + self.reset() + self.check() + queue = Queue() - visited = set() for node in self.root_nodes: - visited.add(node) queue.put(node) while not queue.empty(): node = queue.get() - node.reset() + print(f"node {node.print()}") for next_node in node.next_nodes: - if next_node not in visited: - visited.add(next_node) - queue.put(next_node) + next_node.add_reach() + print(f"next_node {next_node.print()}") + if next_node.get_reach() == next_node.get_input(): + if isinstance(next_node, self.Node) and next_node.agg_node is not None: + for sub_node in next_node.agg_node.nodes: + queue.put(sub_node) + else: + queue.put(next_node) + print() def check(self): """ @@ -265,6 +300,9 @@ def reset(self): self.agg_node.reset() else: self.reach = 0 + + def print(self): + return f"rank {self.operation.rank} tb {self.operation.threadblock} {self.operation.name}" class AggregateNode(BaseNode): @@ -280,4 +318,7 @@ def get_operations(self): operations = [] for node in self.nodes: operations.append(node) - return operations \ No newline at end of file + return operations + + def print(self): + return f"rank {self.operations[0].rank} tb {self.operations[0].threadblock} {self.operations[0].name}" \ No newline at end of file diff --git a/python/mscclpp/language/internal/operations.py b/python/mscclpp/language/internal/operations.py index 9bdd90aab..c7b731374 100644 --- a/python/mscclpp/language/internal/operations.py +++ b/python/mscclpp/language/internal/operations.py @@ -115,21 +115,6 @@ def to_dict(self): return {"buffer_id": self.buffer_id, "index": self.index, "size": self.size} -@dataclass -class ThreadBlockGroupInfo: - tb_id: int - tbg_size: int - - def to_dict(self): - return {"tb_id": self.tb_id, "tbg_size": self.tbg_size} - - def start_offset(self, size): - return (size / self.tbg_size) * self.tb_id - - def end_offset(self, size): - return (size / self.tbg_size) * (self.tb_id + 1) - - class SyncOperation(BaseOperation): def __init__(self, rank: int, threadblock: int): super().__init__(rank, threadblock, Instruction.nop) @@ -159,7 +144,7 @@ def __init__( threadblock: int, src_buff: List[LocalChunk], dst_buff: List[LocalChunk], - tbg: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroup = None, from_packet: bool = False, to_packet: bool = False, ): @@ -186,8 +171,8 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, chunk.type, DataAccessType.read, self.tbg @@ -201,8 +186,8 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, chunk.type, DataAccessType.write, self.tbg @@ -234,6 +219,10 @@ def __init__(self, rank: int, threadblock: int, semaphore_ids: List[int], data_s super().__init__(rank, threadblock, 
Instruction.sem_acquire) self.semaphore_ids = semaphore_ids self.data_sync = data_sync + self.tb_sync = set() + + def add_tb_sync(self, tb): + self.tb_sync.add(tb) def shift_ids(self, instance, num_instances, replication_function): for i in range(len(self.semaphore_ids)): @@ -506,8 +495,8 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, chunk.type, DataAccessType.write, self.tbg @@ -563,7 +552,7 @@ def __init__( dst_buff: List[RemoteChunk], channel_ids: List[int], channel_type: ChannelType, - tbg: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroup = None, from_packet: bool = False, to_packet: bool = False, with_signal: bool = False, @@ -605,8 +594,8 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, chunk.type, DataAccessType.read, self.tbg @@ -678,7 +667,7 @@ def __init__( put_channel_ids: List[int] = None, channel_type: ChannelType = ChannelType.none, reduce_operation: ReduceOperationType = ReduceOperationType.sum, - tbg: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroup = None, packet: bool = False, ): remote_src_buff = remote_src_buff if remote_src_buff is not None else [] @@ -725,8 +714,8 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, chunk.type, DataAccessType.read, self.tbg @@ -739,8 +728,8 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(chunk.size, self.threadblock) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(chunk.size, self.threadblock) if self.tbg is not None else chunk.size, + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, chunk.type, DataAccessType.write, self.tbg diff --git a/python/mscclpp/language/internal/register.py b/python/mscclpp/language/internal/register.py index 95e696cb0..ed0737d9a 100644 --- a/python/mscclpp/language/internal/register.py +++ b/python/mscclpp/language/internal/register.py @@ -18,5 +18,5 @@ def add_semaphore(semaphore): SemaphoreRegister.semaphores[(semaphore.rank, semaphore.id)] = semaphore @staticmethod - def 
get_channel(rank: int, semaphore_id: int): + def get_semaphore(rank: int, semaphore_id: int): return SemaphoreRegister.semaphores.get((rank, semaphore_id)) \ No newline at end of file diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index 6fc0fd7e3..019063a79 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -160,11 +160,9 @@ def add_tbg_operation(self, operations): self.op_dep_dag.add_tbg_operation(operations) def post_process_operations(self): - #self.op_dep_dag.add_semaphore_dependency() + self.op_dep_dag.add_semaphore_dependency() list_op = self.op_dep_dag.get_execution_order() - print(f"execution order operation: {list_op}") list_op = self.buffers_access.process_operations(list_op) - print(f"adding sync operations: {list_op}") for op in list_op: self.gpus[op.rank].add_operation(op.threadblock, op) diff --git a/python/mscclpp/language/rank.py b/python/mscclpp/language/rank.py index e09905c00..4dcaa6c7a 100644 --- a/python/mscclpp/language/rank.py +++ b/python/mscclpp/language/rank.py @@ -422,7 +422,7 @@ def acquire(self, tb: int, data_sync: SyncType = SyncType.both): Example: >>> sem.acquire(tb=0, data_sync=SyncType.before) """ - op = SemaphoreAcquireOperation([self.id], data_sync) + op = SemaphoreAcquireOperation(self.rank, tb, [self.id], data_sync) get_program().add_operation(self.rank, tb, op) def release(self, tb: int, data_sync: SyncType = SyncType.both): @@ -440,5 +440,5 @@ def release(self, tb: int, data_sync: SyncType = SyncType.both): Example: >>> sem.release(tb=0, data_sync=SyncType.after) """ - op = SemaphoreReleaseOperation([self.id], data_sync) + op = SemaphoreReleaseOperation(self.rank, tb, [self.id], data_sync) get_program().add_operation(self.rank, tb, op) From dec8b10894811b55233fa36df59c12a266203aba Mon Sep 17 00:00:00 2001 From: Caio Rocha Date: Sun, 26 Oct 2025 21:28:29 +0000 Subject: [PATCH 4/7] wip --- python/mscclpp/language/internal/op_dep_graph.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/mscclpp/language/internal/op_dep_graph.py b/python/mscclpp/language/internal/op_dep_graph.py index 1f5cb6f0b..c00bc725c 100644 --- a/python/mscclpp/language/internal/op_dep_graph.py +++ b/python/mscclpp/language/internal/op_dep_graph.py @@ -177,9 +177,6 @@ def reset(self): node.reset() def print(self): - """ - Returns the order of operations in the DAG. 
- """ self.reset() self.check() From e40c410509da005d32f602ab7509f68fc8c8f8f2 Mon Sep 17 00:00:00 2001 From: Caio Rocha Date: Sat, 29 Nov 2025 02:01:43 +0000 Subject: [PATCH 5/7] wip --- .../language/internal/buffer_access.py | 5 +- .../mscclpp/language/internal/op_dep_graph.py | 76 ++++++++++++------- .../mscclpp/language/internal/operations.py | 47 +++++++++++- python/mscclpp/language/internal/types.py | 5 ++ python/mscclpp/language/loop.py | 16 +--- python/mscclpp/language/program.py | 11 ++- 6 files changed, 111 insertions(+), 49 deletions(-) diff --git a/python/mscclpp/language/internal/buffer_access.py b/python/mscclpp/language/internal/buffer_access.py index 7c5f3d023..5631aa02e 100644 --- a/python/mscclpp/language/internal/buffer_access.py +++ b/python/mscclpp/language/internal/buffer_access.py @@ -9,7 +9,8 @@ class BuffersAccess: - def __init__(self, num_ranks): + def __init__(self, num_ranks, intra_rank_sync): + self.intra_rank_sync = intra_rank_sync self.rank_intervals = [ { BufferType.input: SortedDict(), @@ -114,7 +115,7 @@ def resolve_conflicts(self, rank, threadblock, order_id, data_access_conflict: D fix_operations.append(SyncOperation(rank, threadblock)) self.track_sync[(rank, threadblock)] = order_id break - if data_access_conflict.conflict_type == DataAccessConflictType.inter_threadblock: + if data_access_conflict.conflict_type == DataAccessConflictType.inter_threadblock and self.intra_rank_sync: conflict_tb = set([threadblock]) for tb in data_access_conflict.threadblocks: if threadblock != tb[0] and ((rank, threadblock, tb[0]) not in self.track_barrier or self.track_barrier[(rank, threadblock, tb[0])] < tb[1]): diff --git a/python/mscclpp/language/internal/op_dep_graph.py b/python/mscclpp/language/internal/op_dep_graph.py index c00bc725c..bbe51ce88 100644 --- a/python/mscclpp/language/internal/op_dep_graph.py +++ b/python/mscclpp/language/internal/op_dep_graph.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from mscclpp.language.internal.globals import * from mscclpp.language.internal.operations import * from mscclpp.language.internal.types import * from mscclpp.language.internal.register import ChannelRegister, SemaphoreRegister from queue import Queue from typing import Set, Dict, Tuple +import warnings class OperationDependencyGraph: @@ -105,6 +107,7 @@ def add_tbg_operation(self, operations): def add_semaphore_dependency(self): queue = Queue() + processed_node = set() sem_rel = {} sem_acq = {} sem_val = {} @@ -112,13 +115,20 @@ def add_semaphore_dependency(self): self.reset() def compute_sem_op(sem_op, node): - for id in node.operation.semaphore_ids: - if (node.operation.rank, id) not in sem_op: - sem_op[(node.operation.rank, id)] = [] - sem_val[(node.operation.rank, id)] = SemaphoreRegister.get_semaphore(node.operation.rank, id).initial_value - sem_op[(node.operation.rank, id)].append(node) + operation = node.operation + for id in operation.semaphore_ids: + if (operation.rank, id) not in sem_op: + sem_op[(operation.rank, id)] = [] + sem_val[(operation.rank, id)] = SemaphoreRegister.get_semaphore(operation.rank, id).initial_value + sem_op[(operation.rank, id)].append((node, operation.pipeline_context)) + + return True def process_node(node): + if node in processed_node: + return + processed_node.add(node) + for next_node in node.next_nodes: next_node.add_reach() if next_node.get_reach() == next_node.get_input(): @@ -128,49 +138,59 @@ def process_node(node): else: queue.put(next_node) - for node in self.root_nodes: queue.put(node) while True: + sem_ops_found = False + new_sem_rel_node = [] while not queue.empty(): node = queue.get() if isinstance(node, self.Node) and isinstance(node.operation, SemaphoreReleaseOperation): - compute_sem_op(sem_rel, node) + sem_ops_found = compute_sem_op(sem_rel, node) + new_sem_rel_node.append(node) elif isinstance(node, self.Node) and isinstance(node.operation, SemaphoreAcquireOperation): - compute_sem_op(sem_acq, node) + sem_ops_found = compute_sem_op(sem_acq, node) else: process_node(node) - if not sem_rel and not sem_acq: + if not sem_ops_found: break else: removed_keys = [] for key in sem_acq.keys(): - if key in sem_rel: - if len(sem_acq[key]) > 1 or sem_val[key] != len(sem_rel[key]) - len(sem_acq[key]): - raise RuntimeError(f"Undefined Behaviour Semaphore Id {key[1]}.") - else: - sem_acq_node = sem_acq[key][0] - sem_val[key] = 0 - if sem_acq_node in self.root_nodes: - self.root_nodes.remove(sem_acq_node) - process_node(sem_acq_node) - for sem_rel_node in sem_rel[key]: - process_node(sem_rel_node) - sem_rel_node.next_nodes.append(sem_acq_node) - sem_acq_node.operation.add_tb_sync(sem_rel_node.operation.threadblock) - sem_acq_node.previous_nodes.append(sem_rel_node) - sem_acq_node.add_input() - - removed_keys.append(key) + if key not in sem_rel: + sem_rel[key] = [] + if len(sem_acq[key]) > 1 or sem_val[key] < len(sem_rel[key]) - len(sem_acq[key]): + get_program().disable_inter_tb_sync() + warnings.warn(f"Undefined Behaviour Semaphore Id.", UserWarning) + return + + for sem_rel_node in new_sem_rel_node: + process_node(sem_rel_node) + + if sem_val[key] == len(sem_rel[key]) - len(sem_acq[key]): + sem_acq_node, sem_acq_ctx = sem_acq[key][0] + sem_val[key] = 0 + if sem_acq_node in self.root_nodes: + self.root_nodes.remove(sem_acq_node) + process_node(sem_acq_node) + for sem_rel_node, sem_rel_ctx in sem_rel[key]: + if sem_rel_ctx is not sem_acq_ctx: + raise RuntimeError(f"Semaphore cross pipeline context violation.") + 
sem_rel_node.next_nodes.append(sem_acq_node) + sem_acq_node.operation.add_tb_sync(sem_rel_node.operation.threadblock) + sem_acq_node.previous_nodes.append(sem_rel_node) + sem_acq_node.add_input() + + removed_keys.append(key) for key in removed_keys: sem_rel.pop(key) sem_acq.pop(key) - if len(sem_rel.keys()) > 0 or len(sem_acq.keys()): - raise RuntimeError(f"Undefined Semaphore Behaviour.") + if len(sem_acq.keys()) > 0: + raise RuntimeError(f"Semaphore acquire hanging.") def reset(self): for node in self.node_list: diff --git a/python/mscclpp/language/internal/operations.py b/python/mscclpp/language/internal/operations.py index c7b731374..1b09865c4 100644 --- a/python/mscclpp/language/internal/operations.py +++ b/python/mscclpp/language/internal/operations.py @@ -12,6 +12,7 @@ DataAccessType, ) from mscclpp.language.thread_block_group import ThreadBlockGroup +from mscclpp.language.loop import LoopIterationContext from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import List @@ -36,6 +37,8 @@ class BaseOperation(ABC): rank: int threadblock: int name: str + # TODO: Only fuse operation with the same pipeline_context + pipeline_context: LoopIterationContext = field(default=None) def local_data_access(self, sync_purpose=True): """Get list of local data accesses performed by this operation. @@ -84,6 +87,12 @@ def shift_ids(self, instance, num_instances, replication_function): """ return + def set_pipeline_context(self, pipeline_context): + self.pipeline_context = pipeline_context + + def basic_fusion_check(self, other_op): + return self.rank == other_op.rank and self.threadblock == other_op.thread_block and self.pipeline_context == other_op.pipeline_context + def __add__(self, other): """Attempt to fuse this operation with another operation. 
@@ -120,6 +129,9 @@ def __init__(self, rank: int, threadblock: int): super().__init__(rank, threadblock, Instruction.nop) def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if isinstance(other, SyncOperation): fused_operation = SyncOperation() @@ -229,6 +241,9 @@ def shift_ids(self, instance, num_instances, replication_function): self.semaphore_ids[i] = replication_function(self.semaphore_ids[i], instance, num_instances) def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if isinstance(other, SemaphoreAcquireOperation): fused_operation = SemaphoreAcquireOperation( @@ -262,6 +277,9 @@ def shift_ids(self, instance, num_instances, replication_function): self.semaphore_ids[i] = replication_function(self.semaphore_ids[i], instance, num_instances) def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if isinstance(other, SemaphoreReleaseOperation): fused_operation = SemaphoreReleaseOperation( @@ -303,6 +321,9 @@ def __init__( self.data_sync = data_sync def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if ( isinstance(other, SignalOperation) @@ -352,6 +373,9 @@ def __init__( self.data_sync = data_sync def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if ( isinstance(other, WaitOperation) @@ -403,6 +427,9 @@ def shift_ids(self, instance, num_instances, replication_function): self.barrier_id = replication_function(self.barrier_id, instance, num_instances) def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if check_data_sync_op(other): other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) @@ -444,6 +471,9 @@ def __init__( self.data_sync = data_sync def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if isinstance(other, FlushOperation) and self.channel_type == other.channel_type: fused_operation = FlushOperation( @@ -511,6 +541,9 @@ def shift_buffers(self, instance, num_instances, replication_function): chunk.index = replication_function(chunk.index, chunk.size, instance, num_instances) def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if ( isinstance(other, GetOperation) @@ -610,6 +643,9 @@ def shift_buffers(self, instance, num_instances, replication_function): chunk.index = replication_function(chunk.index, chunk.size, instance, num_instances) def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if ( isinstance(other, PutOperation) @@ -748,6 +784,9 @@ def shift_buffers(self, instance, num_instances, replication_function): chunk.index = replication_function(chunk.index, chunk.size, instance, num_instances) def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if ( isinstance(other, ReduceOperation) @@ -873,6 +912,9 @@ def shift_buffers(self, instance, num_instances, replication_function): self.dst_chunk.index = replication_function(self.dst_chunk.index, self.size, instance, num_instances) def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if ( isinstance(other, GroupStore) @@ -1034,6 +1076,9 @@ def shift_ids(self, instance, num_instances, replication_function): operation.shift_ids(instance, num_instances, 
replication_function) def __add__(self, other): + if not self.basic_fused_check(self): + return None + fused_operation = None if (self.get_data_sync() & SyncType.after) == SyncType.after and check_data_sync_op(other): other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) @@ -1088,4 +1133,4 @@ def add_data_sync(operations): ): result_operations.append(SyncOperation()) - return result_operations + return result_operations \ No newline at end of file diff --git a/python/mscclpp/language/internal/types.py b/python/mscclpp/language/internal/types.py index 19000fdd3..ebccc7e16 100644 --- a/python/mscclpp/language/internal/types.py +++ b/python/mscclpp/language/internal/types.py @@ -251,3 +251,8 @@ class ReplicationPolicy(Enum): def __str__(self): return self.value + +@dataclass +class PipelineContext(): + unit: int + num_chunks: int \ No newline at end of file diff --git a/python/mscclpp/language/loop.py b/python/mscclpp/language/loop.py index 06ca90f40..d79c6bd78 100644 --- a/python/mscclpp/language/loop.py +++ b/python/mscclpp/language/loop.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. from mscclpp.language.internal.globals import * -from mscclpp.language.internal.operations import PipelineOperation class LoopIterationContext: @@ -35,7 +34,6 @@ def __init__(self, unit, num_chunks): """ self.unit = unit self.num_chunks = num_chunks - self.operations = [] def __enter__(self): """Enter the context and set this as the active loop context. @@ -54,17 +52,7 @@ def __exit__(self, exc_type, exc_value, traceback): """ get_program().set_loop_context(None) - pipeline_operation = {} - for rank, tb, operation in self.operations: - key = (rank, tb) - if key not in pipeline_operation: - pipeline_operation[key] = PipelineOperation(self.unit, self.num_chunks) - pipeline_operation[key].add_operation(operation) - - for (rank, tb), pipeline in pipeline_operation.items(): - get_program().add_operation(rank, tb, pipeline) - - def add_operation(self, rank, tb, operation): + def process_operation(self, operation): """Add an operation to be included in the pipeline. This method is called internally to collect operations that should be @@ -76,4 +64,4 @@ def add_operation(self, rank, tb, operation): tb (int): The thread block ID that will execute this operation. operation: The operation object to be added to the pipeline. """ - self.operations.append((rank, tb, operation)) + operation.set_pipeline_context(self) diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index 019063a79..4ef110503 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -48,6 +48,7 @@ def __init__( protocol: str = "Simple", instr_fusion: bool = True, auto_sync: bool = True, + intra_rank_sync: bool = True, replication_policy: ReplicationPolicy = ReplicationPolicy.interleaved, reuse_resources: bool = False, num_threads_per_block: int = 1024, @@ -103,7 +104,7 @@ def __init__( self.max_message_size = max_message_size assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. 
Must be either Simple, LL" self.op_dep_dag = OperationDependencyGraph() - self.buffers_access = BuffersAccess(num_ranks) + self.buffers_access = BuffersAccess(num_ranks, intra_rank_sync) self.buffers = collective.init_buffers() self.gpus: List[Gpu] = [] for rank in range(self.num_ranks): @@ -129,6 +130,9 @@ def __exit__(self, exc_type, exc_value, traceback): """ set_program(None) + def disable_inter_tb_sync(self): + self.buffers_access.intra_rank_sync = False + def add_channel(self, channel): if channel.channel_type == ChannelType.switch: for gpu in channel.rank_group.ranks: @@ -152,9 +156,8 @@ def add_semaphore(self, semaphore): def add_operation(self, rank, tb, operation): if self.loop_context != None: - self.loop_context.add_operation(rank, tb, operation) - else: - self.op_dep_dag.add_operation(operation) + self.loop_context.process_operation(operation) + self.op_dep_dag.add_operation(operation) def add_tbg_operation(self, operations): self.op_dep_dag.add_tbg_operation(operations) From 1f36c97670eda6d7ae2b224c3686b1cf042f6c20 Mon Sep 17 00:00:00 2001 From: Caio Rocha Date: Wed, 17 Dec 2025 17:13:51 +0000 Subject: [PATCH 6/7] wip --- python/mscclpp/language/internal/op_dep_graph.py | 2 +- python/mscclpp/language/internal/operations.py | 6 +++--- python/mscclpp/language/internal/types.py | 7 +------ python/mscclpp/language/rank.py | 2 +- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/python/mscclpp/language/internal/op_dep_graph.py b/python/mscclpp/language/internal/op_dep_graph.py index bbe51ce88..938bb2202 100644 --- a/python/mscclpp/language/internal/op_dep_graph.py +++ b/python/mscclpp/language/internal/op_dep_graph.py @@ -334,7 +334,7 @@ def add_node(self, node): def get_operations(self): operations = [] for node in self.nodes: - operations.append(node) + operations.extend(node.get_operations()) return operations def print(self): diff --git a/python/mscclpp/language/internal/operations.py b/python/mscclpp/language/internal/operations.py index 728f3c12b..c47d2b807 100644 --- a/python/mscclpp/language/internal/operations.py +++ b/python/mscclpp/language/internal/operations.py @@ -91,7 +91,7 @@ def set_pipeline_context(self, pipeline_context): self.pipeline_context = pipeline_context def basic_fusion_check(self, other_op): - return self.rank == other_op.rank and self.threadblock == other_op.thread_block and self.pipeline_context == other_op.pipeline_context + return self.rank == other_op.rank and self.threadblock == other_op.thread_block and self.pipeline_context is other_op.pipeline_context def __add__(self, other): """Attempt to fuse this operation with another operation. 
@@ -163,9 +163,9 @@ def __init__( if from_packet and to_packet: raise RuntimeError(f"Copy Operation from Packet to Packet is not Supported.") elif from_packet: - super().__init__(rank, threadblock, Instruction.copy_packet) + super().__init__(rank, threadblock, Instruction.unpack_packet) elif to_packet: - super().__init__(rank, threadblock, Instruction.transform_to_packet) + super().__init__(rank, threadblock, Instruction.copy_packet) else: super().__init__(rank, threadblock, Instruction.copy) diff --git a/python/mscclpp/language/internal/types.py b/python/mscclpp/language/internal/types.py index d796b70ce..e7a5de079 100644 --- a/python/mscclpp/language/internal/types.py +++ b/python/mscclpp/language/internal/types.py @@ -252,9 +252,4 @@ class ReplicationPolicy(Enum): none = "none" def __str__(self): - return self.value - -@dataclass -class PipelineContext(): - unit: int - num_chunks: int \ No newline at end of file + return self.value \ No newline at end of file diff --git a/python/mscclpp/language/rank.py b/python/mscclpp/language/rank.py index 3ec8f30b8..462465e6e 100644 --- a/python/mscclpp/language/rank.py +++ b/python/mscclpp/language/rank.py @@ -289,7 +289,7 @@ def barrier(self, tb_list: List[int]): if len(tb_list) == 0: raise RuntimeError("Barrier requires at least thread block.") elif len(tb_list) == 1: - op = SyncOperation() + op = SyncOperation(self.rank, tb_list[0]) get_program().add_operation(self.rank, tb_list[0], op) else: for tb in tb_list: From d2d4c09dae7f1b4409bccd1dfb6c1425aa847fb6 Mon Sep 17 00:00:00 2001 From: Caio Rocha Date: Fri, 19 Dec 2025 18:46:58 +0000 Subject: [PATCH 7/7] wip --- python/mscclpp/language/channel.py | 36 +- .../language/internal/buffer_access.py | 143 +++- .../mscclpp/language/internal/op_dep_graph.py | 99 ++- .../mscclpp/language/internal/operations.py | 691 ++++++++++-------- python/mscclpp/language/internal/register.py | 6 +- python/mscclpp/language/internal/types.py | 72 +- python/mscclpp/language/loop.py | 7 +- python/mscclpp/language/program.py | 7 +- python/mscclpp/language/rank.py | 14 +- python/mscclpp/language/thread_block_group.py | 2 - 10 files changed, 652 insertions(+), 425 deletions(-) diff --git a/python/mscclpp/language/channel.py b/python/mscclpp/language/channel.py index cb0d8f307..d0d29b040 100644 --- a/python/mscclpp/language/channel.py +++ b/python/mscclpp/language/channel.py @@ -152,14 +152,10 @@ def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg=( - tb_group - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), ) operations.append(op) - + if tb_group is None: get_program().add_operation(self.src_rank, tb_id, operations[0]) else: @@ -219,11 +215,7 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg=( - tb_group - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), ) operations.append(op) @@ -291,11 +283,7 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, t dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg=( - tb_group - if tb_group is not None - else None - ), + 
tbg=(tb_group if tb_group is not None else None), from_packet=True, to_packet=True, ) @@ -363,11 +351,7 @@ def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_gro dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg=( - tb_group - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), from_packet=False, to_packet=True, ) @@ -466,11 +450,7 @@ def reduce( remote_dst_buff=[], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg=( - tb_group - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), reduce_operation=reduce_op, ) operations.apend(op) @@ -1004,7 +984,9 @@ def broadcast(self, rank, src_chunk: Chunk, buffer_offset, size, tb): ) tb_channel_ids = get_program().setup_channel(tb, self) - op = GroupStore(self.src_rank, tb, src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type) + op = GroupStore( + self.src_rank, tb, src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type + ) get_program().add_operation(self.src_rank, tb, op) class SwitchChannelRankView: diff --git a/python/mscclpp/language/internal/buffer_access.py b/python/mscclpp/language/internal/buffer_access.py index 5631aa02e..f0d06b3e8 100644 --- a/python/mscclpp/language/internal/buffer_access.py +++ b/python/mscclpp/language/internal/buffer_access.py @@ -33,29 +33,43 @@ def process_operations(self, operations): if operation.name == Instruction.sem_acquire: self.update_semaphore(operation, i) else: - if operation.name == Instruction.pipeline: - pipeline_buffer_access = BuffersAccess() - pipeline_result_operations = pipeline_buffer_access.process_operations(operation.operations) - operation.operations = pipeline_result_operations data_access = operation.local_data_access(i) - data_access_conflict = DataAccessConflict(operation.rank) + data_access_conflict_same_ctx = DataAccessConflict(operation.rank) + data_access_conflict_diff_ctx = DataAccessConflict(operation.rank) for data_access_element in data_access: - data_access_conflict = data_access_conflict + self.compute_data_access(data_access_element) - fix_operations = self.resolve_conflicts(operation.rank, operation.threadblock, i, data_access_conflict) + computed_data_access_conflict_same_ctx, computed_data_access_conflict_diff_ctx = ( + self.compute_data_access(data_access_element) + ) + data_access_conflict_same_ctx = ( + data_access_conflict_same_ctx + computed_data_access_conflict_same_ctx + ) + data_access_conflict_diff_ctx = ( + data_access_conflict_diff_ctx + computed_data_access_conflict_diff_ctx + ) + fix_operations = self.resolve_conflicts( + operation.rank, + operation.threadblock, + operation.pipeline_context, + i, + data_access_conflict_same_ctx, + data_access_conflict_diff_ctx, + ) result_operations.extend(fix_operations) result_operations.append(operation) + result_operations = self.add_pipeline_context_sync_operations(result_operations) + return result_operations def update_barrier(self, operation, order_id): - for tb in operation.barrier_info.tb_list: + for tb in operation.barrier_info.tb_list: if operation.threadblock != tb: self.track_barrier[operation.rank, operation.threadblock, tb] = order_id self.track_sync[operation.rank, operation.threadblock] = order_id def update_semaphore(self, operation, order_id): - for tb in operation.tb_sync: + for tb in operation.tb_sync: if operation.threadblock 
!= tb: self.track_barrier[operation.rank, operation.threadblock, tb] = order_id @@ -63,12 +77,19 @@ def compute_data_access(self, data_access: DataAccess) -> bool: intervals = self.rank_intervals[data_access.rank] keys = intervals[data_access.buffer_type].keys() idx = self.lower_bound(0, len(keys) - 1, keys, data_access) - conflict = DataAccessConflict(data_access.rank) + conflict_same_ctx = DataAccessConflict(data_access.rank) + conflict_diff_ctx = DataAccessConflict(data_access.rank) while len(keys) > 0 and data_access.overlaps(keys[idx]): conflict_data_access = keys[idx] conflict_operation_type = intervals[data_access.buffer_type][conflict_data_access] - conflict = conflict + data_access.check_conflict(conflict_data_access) + if ( + data_access.pipeline_context is None + or data_access.pipeline_context == conflict_data_access.pipeline_context + ): + conflict_same_ctx = conflict_same_ctx + data_access.check_conflict(conflict_data_access) + else: + conflict_diff_ctx = conflict_diff_ctx + data_access.check_conflict(conflict_data_access) intervals[data_access.buffer_type].pop(conflict_data_access) if conflict_data_access.end > data_access.end: @@ -82,8 +103,8 @@ def compute_data_access(self, data_access: DataAccess) -> bool: conflict_data_access.end, conflict_data_access.buffer_type, conflict_operation_type, - conflict_data_access.tb_group - + conflict_data_access.tb_group, + conflict_data_access.pipeline_context, ) ] = conflict_operation_type if conflict_data_access.start < data_access.start: @@ -97,7 +118,8 @@ def compute_data_access(self, data_access: DataAccess) -> bool: data_access.start, conflict_data_access.buffer_type, conflict_operation_type, - conflict_data_access.tb_group + conflict_data_access.tb_group, + conflict_data_access.pipeline_context, ) ] = conflict_operation_type @@ -105,20 +127,34 @@ def compute_data_access(self, data_access: DataAccess) -> bool: idx = self.lower_bound(0, len(keys) - 1, keys, data_access) intervals[data_access.buffer_type][data_access] = data_access.data_access_type - return conflict - - def resolve_conflicts(self, rank, threadblock, order_id, data_access_conflict: DataAccessConflict): + return (conflict_same_ctx, conflict_diff_ctx) + + def resolve_conflicts( + self, + rank, + threadblock, + pipeline_context, + order_id, + data_access_conflict_same_ctx: DataAccessConflict, + data_access_conflict_diff_ctx: DataAccessConflict, + ): fix_operations = [] - if data_access_conflict.conflict_type == DataAccessConflictType.intra_threadblock: - for tb in data_access_conflict.threadblocks: + if data_access_conflict_same_ctx.conflict_type == DataAccessConflictType.intra_threadblock: + for tb in data_access_conflict_same_ctx.threadblocks: if (rank, threadblock) not in self.track_sync or tb[1] > self.track_sync[(rank, threadblock)]: fix_operations.append(SyncOperation(rank, threadblock)) self.track_sync[(rank, threadblock)] = order_id break - if data_access_conflict.conflict_type == DataAccessConflictType.inter_threadblock and self.intra_rank_sync: + if ( + data_access_conflict_same_ctx.conflict_type == DataAccessConflictType.inter_threadblock + and self.intra_rank_sync + ): conflict_tb = set([threadblock]) - for tb in data_access_conflict.threadblocks: - if threadblock != tb[0] and ((rank, threadblock, tb[0]) not in self.track_barrier or self.track_barrier[(rank, threadblock, tb[0])] < tb[1]): + for tb in data_access_conflict_same_ctx.threadblocks: + if threadblock != tb[0] and ( + (rank, threadblock, tb[0]) not in self.track_barrier + or self.track_barrier[(rank, 
threadblock, tb[0])] < tb[1] + ): if not tb[2]: raise RuntimeError("Operations order not defined.") conflict_tb.add(tb[0]) @@ -128,7 +164,68 @@ def resolve_conflicts(self, rank, threadblock, order_id, data_access_conflict: D self.update_barrier(op, order_id) fix_operations.append(op) - return fix_operations + if pipeline_context is not None: + if (rank, threadblock) not in pipeline_context.pre_operations: + pipeline_context.pre_operations[(rank, threadblock)] = [] + + if data_access_conflict_diff_ctx.conflict_type == DataAccessConflictType.intra_threadblock: + for tb in data_access_conflict_diff_ctx.threadblocks: + if (rank, threadblock) not in self.track_sync or tb[1] > self.track_sync[(rank, threadblock)]: + self.track_sync[(rank, threadblock)] = order_id + pipeline_context.pre_operations[(rank, threadblock)].append(SyncOperation(rank, threadblock)) + break + if ( + data_access_conflict_diff_ctx.conflict_type == DataAccessConflictType.inter_threadblock + and self.intra_rank_sync + ): + conflict_tb = set([threadblock]) + for tb in data_access_conflict_diff_ctx.threadblocks: + if threadblock != tb[0] and ( + (rank, threadblock, tb[0]) not in self.track_barrier + or self.track_barrier[(rank, threadblock, tb[0])] < tb[1] + ): + if not tb[2]: + raise RuntimeError("Operations order not defined.") + conflict_tb.add(tb[0]) + if len(conflict_tb) > 1: + for tb in conflict_tb: + op = BarrierOperation(rank, tb, conflict_tb) + self.update_barrier(op, order_id) + pipeline_context.pre_operations[(rank, threadblock)].append(op) + + return fix_operations + + def add_pipeline_context_sync_operations(self, operations): + result_operations = [] + pipeline_operations = dict() + for i in range(len(operations)): + operation = operations[i] + if operation.pipeline_context is not None: + pipeline_context = operation.pipeline_context + result_operations.extend( + pipeline_context.pre_operations.get((operation.rank, operation.threadblock), []) + ) + pipeline_context.pre_operations.pop((operation.rank, operation.threadblock), None) + + if (operation.rank, operation.threadblock, operation.pipeline_context) not in pipeline_operations: + pipeline_operations[(operation.rank, operation.threadblock, operation.pipeline_context)] = ( + PipelineOperation( + operation.rank, + operation.threadblock, + operation.pipeline_context.unit, + operation.pipeline_context.num_chunks, + ) + ) + result_operations.append( + pipeline_operations[(operation.rank, operation.threadblock, operation.pipeline_context)] + ) + pipeline_operations[(operation.rank, operation.threadblock, operation.pipeline_context)].add_operation( + operation + ) + else: + result_operations.append(operation) + + return result_operations def lower_bound(self, init_pos, final_pos, data_access_list, data_access): if init_pos >= final_pos: diff --git a/python/mscclpp/language/internal/op_dep_graph.py b/python/mscclpp/language/internal/op_dep_graph.py index 938bb2202..848cf1328 100644 --- a/python/mscclpp/language/internal/op_dep_graph.py +++ b/python/mscclpp/language/internal/op_dep_graph.py @@ -25,8 +25,8 @@ def __init__(self): self.barrier_nodes: Dict[Tuple[int, int], List[OperationDependencyGraph.Node]] = {} self.tb_barriers: Dict[Tuple[int, int, int], int] = {} self.node_list = [] - - def add_operation(self, operation, agg_node = None): + + def add_operation(self, operation, agg_node=None): """ Inserts an operation into the DAG, adding edges based on dependencies. 
""" @@ -37,14 +37,16 @@ def add_operation(self, operation, agg_node = None): agg_node.add_node(node) if isinstance(operation, BarrierOperation): - if (rank, threadblock, operation.barrier_id) not in self.tb_barriers: + if (rank, threadblock, operation.barrier_id) not in self.tb_barriers: self.tb_barriers[(rank, threadblock, operation.barrier_id)] = 0 if (rank, operation.barrier_id) not in self.barrier_nodes: self.barrier_nodes[(rank, operation.barrier_id)] = [] barrier_count = self.tb_barriers[(rank, threadblock, operation.barrier_id)] if barrier_count > len(self.barrier_nodes[(rank, operation.barrier_id)]): - raise RuntimeError(f"Barrier node not create correctly for rank {rank}, threadblock {threadblock}, barrier_id {operation.barrier_id}.") + raise RuntimeError( + f"Barrier node not create correctly for rank {rank}, threadblock {threadblock}, barrier_id {operation.barrier_id}." + ) elif barrier_count == len(self.barrier_nodes[(rank, operation.barrier_id)]): agg_node = self.AggregateNode() self.barrier_nodes[(rank, operation.barrier_id)].append(agg_node) @@ -70,7 +72,9 @@ def add_operation(self, operation, agg_node = None): if node in self.root_nodes: self.root_nodes.remove(node) - if isinstance(operation, SignalOperation) or (isinstance(operation, PutOperation) and (operation.with_signal or operation.with_signal_and_flush)): + if isinstance(operation, SignalOperation) or ( + isinstance(operation, PutOperation) and (operation.with_signal or operation.with_signal_and_flush) + ): for tb_channel_id in operation.channel_ids: channel = ChannelRegister.get_channel(rank, threadblock, tb_channel_id) op_info = (channel.src_rank, channel.dst_rank, channel.channel_peer_id) @@ -121,14 +125,14 @@ def compute_sem_op(sem_op, node): sem_op[(operation.rank, id)] = [] sem_val[(operation.rank, id)] = SemaphoreRegister.get_semaphore(operation.rank, id).initial_value sem_op[(operation.rank, id)].append((node, operation.pipeline_context)) - + return True def process_node(node): if node in processed_node: return processed_node.add(node) - + for next_node in node.next_nodes: next_node.add_reach() if next_node.get_reach() == next_node.get_input(): @@ -136,7 +140,7 @@ def process_node(node): for sub_node in next_node.agg_node.nodes: queue.put(sub_node) else: - queue.put(next_node) + queue.put(next_node) for node in self.root_nodes: queue.put(node) @@ -165,10 +169,10 @@ def process_node(node): get_program().disable_inter_tb_sync() warnings.warn(f"Undefined Behaviour Semaphore Id.", UserWarning) return - + for sem_rel_node in new_sem_rel_node: process_node(sem_rel_node) - + if sem_val[key] == len(sem_rel[key]) - len(sem_acq[key]): sem_acq_node, sem_acq_ctx = sem_acq[key][0] sem_val[key] = 0 @@ -182,16 +186,16 @@ def process_node(node): sem_acq_node.operation.add_tb_sync(sem_rel_node.operation.threadblock) sem_acq_node.previous_nodes.append(sem_rel_node) sem_acq_node.add_input() - + removed_keys.append(key) - + for key in removed_keys: sem_rel.pop(key) sem_acq.pop(key) if len(sem_acq.keys()) > 0: raise RuntimeError(f"Semaphore acquire hanging.") - + def reset(self): for node in self.node_list: node.reset() @@ -199,7 +203,7 @@ def reset(self): def print(self): self.reset() self.check() - + queue = Queue() for node in self.root_nodes: queue.put(node) @@ -215,7 +219,7 @@ def print(self): for sub_node in next_node.agg_node.nodes: queue.put(sub_node) else: - queue.put(next_node) + queue.put(next_node) print() def check(self): @@ -225,11 +229,39 @@ def check(self): if len(self.signalling) > 0: for key, queue in 
self.signalling.items(): if not queue.empty(): - raise RuntimeError(f"Signalling from {key[0]} to {key[1]} on channel {key[2]} hasn't equivalent wait operation.") + raise RuntimeError( + f"Signalling from {key[0]} to {key[1]} on channel {key[2]} hasn't equivalent wait operation." + ) if len(self.waiting) > 0: for key, queue in self.waiting.items(): if not queue.empty(): - raise RuntimeError(f"Waiting for {key[0]} to {key[1]} on channel {key[2]} hasn't equivalent signal operation.") + raise RuntimeError( + f"Waiting for {key[0]} to {key[1]} on channel {key[2]} hasn't equivalent signal operation." + ) + + def fusion_operations(self): + self.reset() + self.check() + + for node in self.root_nodes: + for next_node in node.next_nodes: + if isinstance(node, self.Node) and isinstance(next_node, self.Node): + fused_op = node.operation + next_node.operation + if fused_op is not None: + node.operation = fused_op + node.next_nodes.remove(next_node) + next_node.previous_nodes.remove(node) + for nn in next_node.next_nodes: + node.next_nodes.append(nn) + ## Change from list to another data sctruct that allos insert and remove in log(n) + nn.previous_nodes.remove(next_node) + nn.previous_nodes.append(node) + for pn in next_node.previous_nodes: + node.previous_nodes.append(pn) + pn.next_nodes.remove(next_node) + pn.next_nodes.append(node) + node.add_input() + del next_node def get_execution_order(self): """ @@ -237,7 +269,7 @@ def get_execution_order(self): """ self.reset() self.check() - + order = [] queue = Queue() for node in self.root_nodes: @@ -253,12 +285,13 @@ def get_execution_order(self): for sub_node in next_node.agg_node.nodes: queue.put(sub_node) else: - queue.put(next_node) + queue.put(next_node) return order - class BaseNode(): + class BaseNode: def __init__(self): + ## Change from list to another data sctruct that allos insert and remove in log(n) self.previous_nodes = [] self.next_nodes = [] self.input = 0 @@ -266,7 +299,7 @@ def __init__(self): def add_input(self): self.input += 1 - + def add_reach(self): self.reach += 1 @@ -285,6 +318,11 @@ def __init__(self, operation): self.agg_node = None super().__init__() + def __del__(self): + self.decrease_input(len(self.previous_nodes)) + if self.agg_node is not None: + self.agg_node.nodes.remove(self) + def get_operations(self): return [self.operation] @@ -294,6 +332,12 @@ def add_input(self): else: self.input += 1 + def decrease_input(self, amount=1): + if self.agg_node is not None: + self.agg_node.input -= amount + else: + self.input -= amount + def add_reach(self): if self.agg_node is not None: self.agg_node.reach += 1 @@ -302,31 +346,30 @@ def add_reach(self): def get_input(self): if self.agg_node is not None: - return self.agg_node.input + return self.agg_node.input else: return self.input - + def get_reach(self): if self.agg_node is not None: - return self.agg_node.reach + return self.agg_node.reach else: return self.reach def reset(self): if self.agg_node is not None: - self.agg_node.reset() + self.agg_node.reset() else: self.reach = 0 def print(self): return f"rank {self.operation.rank} tb {self.operation.threadblock} {self.operation.name}" - class AggregateNode(BaseNode): def __init__(self): self.nodes = [] super().__init__() - + def add_node(self, node): self.nodes.append(node) node.agg_node = self @@ -338,4 +381,4 @@ def get_operations(self): return operations def print(self): - return f"rank {self.operations[0].rank} tb {self.operations[0].threadblock} {self.operations[0].name}" \ No newline at end of file + return f"rank 
{self.operations[0].rank} tb {self.operations[0].threadblock} {self.operations[0].name}" diff --git a/python/mscclpp/language/internal/operations.py b/python/mscclpp/language/internal/operations.py index c47d2b807..62f627294 100644 --- a/python/mscclpp/language/internal/operations.py +++ b/python/mscclpp/language/internal/operations.py @@ -91,7 +91,11 @@ def set_pipeline_context(self, pipeline_context): self.pipeline_context = pipeline_context def basic_fusion_check(self, other_op): - return self.rank == other_op.rank and self.threadblock == other_op.thread_block and self.pipeline_context is other_op.pipeline_context + return ( + self.rank == other_op.rank + and self.threadblock == other_op.threadblock + and self.pipeline_context is other_op.pipeline_context + ) def __add__(self, other): """Attempt to fuse this operation with another operation. @@ -129,18 +133,16 @@ def __init__(self, rank: int, threadblock: int): super().__init__(rank, threadblock, Instruction.nop) def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if isinstance(other, SyncOperation): - fused_operation = SyncOperation() - elif isinstance(other, BarrierOperation): - fused_operation = other - elif isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before: - fused_operation = other - elif check_data_sync_op(other): - other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) + if self.basic_fusion_check(other): + if isinstance(other, SyncOperation): + fused_operation = SyncOperation(self.rank, self.threadblock) + elif isinstance(other, BarrierOperation): + fused_operation = other + elif isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before: + fused_operation = other + elif check_data_sync_op(other): + other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) return fused_operation @@ -183,11 +185,20 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, + ( + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) + if self.tbg is not None + else 0 + ), + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.size + ), chunk.type, DataAccessType.read, - self.tbg + self.tbg, + self.pipeline_context, ) ) if self.name != Instruction.copy_packet or not sync_purpose: @@ -198,11 +209,20 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, + ( + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) + if self.tbg is not None + else 0 + ), + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.size + ), chunk.type, DataAccessType.write, - self.tbg + self.tbg, + self.pipeline_context, ) ) return data_access @@ -241,22 +261,25 @@ def shift_ids(self, instance, num_instances, replication_function): self.semaphore_ids[i] = replication_function(self.semaphore_ids[i], instance, num_instances) def __add__(self, other): - if not 
self.basic_fused_check(self): - return None - fused_operation = None - if isinstance(other, SemaphoreAcquireOperation): - fused_operation = SemaphoreAcquireOperation( - semaphore_ids=self.semaphore_ids + other.semaphore_ids, - data_sync=self.data_sync | other.data_sync, - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if isinstance(other, SemaphoreAcquireOperation): + fused_operation = SemaphoreAcquireOperation( + self.rank, + self.threadblock, + semaphore_ids=self.semaphore_ids + other.semaphore_ids, + data_sync=self.data_sync | other.data_sync, + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -277,22 +300,25 @@ def shift_ids(self, instance, num_instances, replication_function): self.semaphore_ids[i] = replication_function(self.semaphore_ids[i], instance, num_instances) def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if isinstance(other, SemaphoreReleaseOperation): - fused_operation = SemaphoreReleaseOperation( - semaphore_ids=self.semaphore_ids + other.semaphore_ids, - data_sync=self.data_sync | other.data_sync, - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if isinstance(other, SemaphoreReleaseOperation): + fused_operation = SemaphoreReleaseOperation( + self.rank, + self.threadblock, + semaphore_ids=self.semaphore_ids + other.semaphore_ids, + data_sync=self.data_sync | other.data_sync, + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -321,29 +347,32 @@ def __init__( self.data_sync = data_sync def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if ( - isinstance(other, SignalOperation) - and self.channel_type == other.channel_type - and self.name == other.name - and not self.channel_ids & other.channel_ids - ): - fused_operation = SignalOperation( - channels_ids=self.channel_ids | other.channel_ids, - channel_type=self.channel_type, - data_sync=self.data_sync | other.data_sync, - relaxed=(self.name == Instruction.relaxed_signal), - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and 
(other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if ( + isinstance(other, SignalOperation) + and self.channel_type == other.channel_type + and self.name == other.name + and not self.channel_ids & other.channel_ids + ): + fused_operation = SignalOperation( + self.rank, + self.threadblock, + channels_ids=self.channel_ids | other.channel_ids, + channel_type=self.channel_type, + data_sync=self.data_sync | other.data_sync, + relaxed=(self.name == Instruction.relaxed_signal), + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -373,29 +402,32 @@ def __init__( self.data_sync = data_sync def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if ( - isinstance(other, WaitOperation) - and self.name == other.name - and not self.channel_ids & other.channel_ids - and self.channel_type == other.channel_type - ): - fused_operation = WaitOperation( - channels_ids=self.channel_ids | other.channel_ids, - channel_type=self.channel_type, - data_sync=self.data_sync | other.data_sync, - relaxed=(self.name == Instruction.relaxed_wait), - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if ( + isinstance(other, WaitOperation) + and self.name == other.name + and not self.channel_ids & other.channel_ids + and self.channel_type == other.channel_type + ): + fused_operation = WaitOperation( + self.rank, + self.threadblock, + channels_ids=self.channel_ids | other.channel_ids, + channel_type=self.channel_type, + data_sync=self.data_sync | other.data_sync, + relaxed=(self.name == Instruction.relaxed_wait), + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -427,14 +459,12 @@ def shift_ids(self, instance, num_instances, replication_function): self.barrier_id = replication_function(self.barrier_id, instance, num_instances) def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if check_data_sync_op(other): - other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) - elif isinstance(other, SyncOperation): - fused_operation = self + if self.basic_fusion_check(other): + if check_data_sync_op(other): + other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) + elif isinstance(other, SyncOperation): + fused_operation = self return fused_operation @@ -471,23 +501,26 @@ 
def __init__( self.data_sync = data_sync def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if isinstance(other, FlushOperation) and self.channel_type == other.channel_type: - fused_operation = FlushOperation( - channels_ids=self.channel_ids | other.channel_ids, - channel_type=self.channel_type, - data_sync=self.data_sync | other.data_sync, - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if isinstance(other, FlushOperation) and self.channel_type == other.channel_type: + fused_operation = FlushOperation( + self.rank, + self.threadblock, + channels_ids=self.channel_ids | other.channel_ids, + channel_type=self.channel_type, + data_sync=self.data_sync | other.data_sync, + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -526,10 +559,15 @@ def local_data_access(self, order_id, sync_purpose=True): self.id, order_id, chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.size + ), chunk.type, DataAccessType.write, - self.tbg + self.tbg, + self.pipeline_context, ) ) return data_access @@ -541,23 +579,23 @@ def shift_buffers(self, instance, num_instances, replication_function): chunk.index = replication_function(chunk.index, chunk.size, instance, num_instances) def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if ( - isinstance(other, GetOperation) - and self.src_buff[0].size == other.src_buff[0].size - and self.channel_type == other.channel_type - and self.tbg == other.tbg - ): - fused_operation = GetOperation( - src_buff=self.src_buff + other.src_buff, - dst_buff=self.dst_buff + other.dst_buff, - channel_ids=self.channel_ids + other.channel_ids, - channel_type=self.channel_type, - tbg=self.tbg, - ) + if self.basic_fusion_check(other): + if ( + isinstance(other, GetOperation) + and self.src_buff[0].size == other.src_buff[0].size + and self.channel_type == other.channel_type + and self.tbg == other.tbg + ): + fused_operation = GetOperation( + self.rank, + self.threadblock, + src_buff=self.src_buff + other.src_buff, + dst_buff=self.dst_buff + other.dst_buff, + channel_ids=self.channel_ids + other.channel_ids, + channel_type=self.channel_type, + tbg=self.tbg, + ) return fused_operation @@ -627,11 +665,20 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, + ( + chunk.index + 
self.tbg.start_offset(self.threadblock, chunk.size) + if self.tbg is not None + else 0 + ), + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.size + ), chunk.type, DataAccessType.read, - self.tbg + self.tbg, + self.pipeline_context, ) ) return data_access @@ -643,33 +690,33 @@ def shift_buffers(self, instance, num_instances, replication_function): chunk.index = replication_function(chunk.index, chunk.size, instance, num_instances) def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if ( - isinstance(other, PutOperation) - and ( - self.name == Instruction.put - or self.name == Instruction.put_packet - or self.name == Instruction.put_with_signal - or self.name == Instruction.put_with_signal_and_flush - ) - and self.name == other.name - and self.src_buff[0].size == other.src_buff[0].size - and self.channel_type == other.channel_type - and self.tbg == other.tbg - ): - fused_operation = PutOperation( - src_buff=self.src_buff + other.src_buff, - dst_buff=self.dst_buff + other.dst_buff, - channel_ids=self.channel_ids + other.channel_ids, - channel_type=self.channel_type, - tbg=self.tbg, - to_packet=self.to_packet, - with_signal=self.with_signal, - with_signal_and_flush=self.with_signal_and_flush, - ) + if self.basic_fusion_check(other): + if ( + isinstance(other, PutOperation) + and ( + self.name == Instruction.put + or self.name == Instruction.put_packet + or self.name == Instruction.put_with_signal + or self.name == Instruction.put_with_signal_and_flush + ) + and self.name == other.name + and self.src_buff[0].size == other.src_buff[0].size + and self.channel_type == other.channel_type + and self.tbg == other.tbg + ): + fused_operation = PutOperation( + self.rank, + self.threadblock, + src_buff=self.src_buff + other.src_buff, + dst_buff=self.dst_buff + other.dst_buff, + channel_ids=self.channel_ids + other.channel_ids, + channel_type=self.channel_type, + tbg=self.tbg, + to_packet=self.to_packet, + with_signal=self.with_signal, + with_signal_and_flush=self.with_signal_and_flush, + ) return fused_operation @@ -759,11 +806,20 @@ def local_data_access(self, order_id, sync_purpose=True): self.threadblock, self.id, order_id, - chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, + ( + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) + if self.tbg is not None + else 0 + ), + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.size + ), chunk.type, DataAccessType.read, - self.tbg + self.tbg, + self.pipeline_context, ) ) for chunk in self.local_dst_buff: @@ -774,10 +830,15 @@ def local_data_access(self, order_id, sync_purpose=True): self.id, order_id, chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0, - chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size, + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.size + ), chunk.type, DataAccessType.write, - self.tbg + self.tbg, + self.pipeline_context, ) ) return data_access @@ -793,122 +854,130 @@ def shift_buffers(self, instance, num_instances, replication_function): chunk.index = replication_function(chunk.index, chunk.size, instance, num_instances) def __add__(self, 
other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if ( - isinstance(other, ReduceOperation) - and ( - self.name == Instruction.reduce - or self.name == Instruction.reduce_packet - or self.name == Instruction.read_reduce - ) - and self.name == other.name - and self.local_src_buff[0] == other.local_src_buff[0] - and self.local_dst_buff == other.local_dst_buff - and self.channel_type == other.channel_type - and self.reduce_operation == other.reduce_operation - and self.tbg == other.tbg - ): - fused_operation = ReduceOperation( - self.local_src_buff + other.local_src_buff[1:], - self.local_dst_buff, - remote_src_buff=self.remote_src_buff + other.remote_src_buff, - channel_ids=self.channel_ids + other.channel_ids, - channel_type=self.channel_type, - reduce_operation=self.reduce_operation, - tbg=self.tbg, - packet=self.packet, - ) - if ( - isinstance(other, PutOperation) - and ( - self.name == Instruction.reduce - or self.name == Instruction.reduce_send - or self.name == Instruction.read_reduce - or self.name == Instruction.read_reduce_send - ) - and other.name == Instruction.put - and self.local_dst_buff[0] == other.src_buff[0] - and other.channel_type == ChannelType.memory - and self.tbg == other.tbg - ): - fused_operation = ReduceOperation( - self.local_src_buff, - self.local_dst_buff, - remote_src_buff=self.remote_src_buff, - remote_dst_buff=self.remote_dst_buff + other.dst_buff, - channel_ids=self.channel_ids, - put_channel_ids=self.put_channel_ids + other.channel_ids, - channel_type=self.channel_type, - reduce_operation=self.reduce_operation, - tbg=self.tbg, - packet=self.packet, - ) - if ( - isinstance(other, PutOperation) - and (self.name == Instruction.reduce_packet or self.name == Instruction.reduce_send_packet) - and other.name == Instruction.put_packet - and self.local_dst_buff[0] == other.src_buff[0] - and other.channel_type == ChannelType.memory - and self.tbg == other.tbg - ): - fused_operation = ReduceOperation( - self.local_src_buff, - self.local_dst_buff, - remote_src_buff=self.remote_src_buff, - remote_dst_buff=self.remote_dst_buff + other.dst_buff, - channel_ids=self.channel_ids, - put_channel_ids=self.put_channel_ids + other.channel_ids, - channel_type=other.channel_type, - reduce_operation=self.reduce_operation, - tbg=self.tbg, - packet=self.packet, - ) - if ( - isinstance(other, CopyOperation) - and self.name == Instruction.reduce_packet - and other.name == Instruction.copy_packet - and self.local_dst_buff[0] == other.src_buff[0] - and self.tbg_info == other.tbg_info - ): - fused_operation = ReduceOperation( - self.local_src_buff, - self.local_dst_buff, - local_pkt_dst_buff=other.dst_buff, - remote_src_buff=self.remote_src_buff, - remote_dst_buff=self.remote_dst_buff, - channel_ids=self.channel_ids, - put_channel_ids=self.put_channel_ids, - channel_type=self.channel_type, - reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, - packet=self.packet, - ) - if ( - isinstance(other, PutOperation) - and (self.name == Instruction.reduce_copy_packet or self.name == Instruction.reduce_copy_send_packet) - and ( - (other.name == Instruction.put_packet and self.local_dst_buff[0] == other.src_buff[0]) - or (other.name == Instruction.read_put_packet and self.local_pkt_dst_buff[0] == other.src_buff[0]) - ) - and other.channel_type == ChannelType.memory - and self.tbg_info == other.tbg_info - ): - fused_operation = ReduceOperation( - self.local_src_buff, - self.local_dst_buff, - local_pkt_dst_buff=self.local_pkt_dst_buff, - 
remote_src_buff=self.remote_src_buff, - remote_dst_buff=self.remote_dst_buff + other.dst_buff, - channel_ids=self.channel_ids, - put_channel_ids=self.put_channel_ids + other.channel_ids, - channel_type=other.channel_type, - reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, - packet=self.packet, - ) + if self.basic_fusion_check(other): + if ( + isinstance(other, ReduceOperation) + and ( + self.name == Instruction.reduce + or self.name == Instruction.reduce_packet + or self.name == Instruction.read_reduce + ) + and self.name == other.name + and self.local_src_buff[0] == other.local_src_buff[0] + and self.local_dst_buff == other.local_dst_buff + and self.channel_type == other.channel_type + and self.reduce_operation == other.reduce_operation + and self.tbg == other.tbg + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff + other.local_src_buff[1:], + self.local_dst_buff, + remote_src_buff=self.remote_src_buff + other.remote_src_buff, + channel_ids=self.channel_ids + other.channel_ids, + channel_type=self.channel_type, + reduce_operation=self.reduce_operation, + tbg=self.tbg, + packet=self.packet, + ) + if ( + isinstance(other, PutOperation) + and ( + self.name == Instruction.reduce + or self.name == Instruction.reduce_send + or self.name == Instruction.read_reduce + or self.name == Instruction.read_reduce_send + ) + and other.name == Instruction.put + and self.local_dst_buff[0] == other.src_buff[0] + and other.channel_type == ChannelType.memory + and self.tbg == other.tbg + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff, + self.local_dst_buff, + remote_src_buff=self.remote_src_buff, + remote_dst_buff=self.remote_dst_buff + other.dst_buff, + channel_ids=self.channel_ids, + put_channel_ids=self.put_channel_ids + other.channel_ids, + channel_type=self.channel_type, + reduce_operation=self.reduce_operation, + tbg=self.tbg, + packet=self.packet, + ) + if ( + isinstance(other, PutOperation) + and (self.name == Instruction.reduce_packet or self.name == Instruction.reduce_send_packet) + and other.name == Instruction.put_packet + and self.local_dst_buff[0] == other.src_buff[0] + and other.channel_type == ChannelType.memory + and self.tbg == other.tbg + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff, + self.local_dst_buff, + remote_src_buff=self.remote_src_buff, + remote_dst_buff=self.remote_dst_buff + other.dst_buff, + channel_ids=self.channel_ids, + put_channel_ids=self.put_channel_ids + other.channel_ids, + channel_type=other.channel_type, + reduce_operation=self.reduce_operation, + tbg=self.tbg, + packet=self.packet, + ) + if ( + isinstance(other, CopyOperation) + and self.name == Instruction.reduce_packet + and other.name == Instruction.copy_packet + and self.local_dst_buff[0] == other.src_buff[0] + and self.tbg_info == other.tbg_info + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff, + self.local_dst_buff, + local_pkt_dst_buff=other.dst_buff, + remote_src_buff=self.remote_src_buff, + remote_dst_buff=self.remote_dst_buff, + channel_ids=self.channel_ids, + put_channel_ids=self.put_channel_ids, + channel_type=self.channel_type, + reduce_operation=self.reduce_operation, + tbg_info=self.tbg_info, + packet=self.packet, + ) + if ( + isinstance(other, PutOperation) + and (self.name == Instruction.reduce_copy_packet or self.name == Instruction.reduce_copy_send_packet) + and ( + (other.name == 
Instruction.put_packet and self.local_dst_buff[0] == other.src_buff[0]) + or (other.name == Instruction.read_put_packet and self.local_pkt_dst_buff[0] == other.src_buff[0]) + ) + and other.channel_type == ChannelType.memory + and self.tbg_info == other.tbg_info + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff, + self.local_dst_buff, + local_pkt_dst_buff=self.local_pkt_dst_buff, + remote_src_buff=self.remote_src_buff, + remote_dst_buff=self.remote_dst_buff + other.dst_buff, + channel_ids=self.channel_ids, + put_channel_ids=self.put_channel_ids + other.channel_ids, + channel_type=other.channel_type, + reduce_operation=self.reduce_operation, + tbg_info=self.tbg_info, + packet=self.packet, + ) return fused_operation @@ -966,27 +1035,27 @@ def shift_buffers(self, instance, num_instances, replication_function): self.dst_chunk.index = replication_function(self.dst_chunk.index, self.size, instance, num_instances) def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if ( - isinstance(other, GroupStore) - and self.buffer_type == other.buffer_type - and self.size == other.size - and self.dst_chunk == other.src_chunk - and self.channel_ids == other.channel_ids - and self.channel_type == other.channel_type - ): - fused_operation = GroupLoadReduceStore( - buffer_type=self.buffer_type, - size=self.size, - src_index=[self.buffer_offset], - dst_index=[other.buffer_offset], - channel_ids=self.channel_ids, - channel_type=self.channel_type, - reduce_operation=self.reduce_operation, - ) + if self.basic_fusion_check(other): + if ( + isinstance(other, GroupStore) + and self.buffer_type == other.buffer_type + and self.size == other.size + and self.dst_chunk == other.src_chunk + and self.channel_ids == other.channel_ids + and self.channel_type == other.channel_type + ): + fused_operation = GroupLoadReduceStore( + self.rank, + self.threadblock, + buffer_type=self.buffer_type, + size=self.size, + src_index=[self.buffer_offset], + dst_index=[other.buffer_offset], + channel_ids=self.channel_ids, + channel_type=self.channel_type, + reduce_operation=self.reduce_operation, + ) return fused_operation @@ -1130,14 +1199,12 @@ def shift_ids(self, instance, num_instances, replication_function): operation.shift_ids(instance, num_instances, replication_function) def __add__(self, other): - if not self.basic_fused_check(self): - return None - fused_operation = None - if (self.get_data_sync() & SyncType.after) == SyncType.after and check_data_sync_op(other): - other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) - elif isinstance(other, SyncOperation) and (self.get_data_sync() & SyncType.after) == SyncType.after: - fused_operation = self + if self.basic_fusion_check(other): + if (self.get_data_sync() & SyncType.after) == SyncType.after and check_data_sync_op(other): + other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) + elif isinstance(other, SyncOperation) and (self.get_data_sync() & SyncType.after) == SyncType.after: + fused_operation = self return fused_operation @@ -1180,11 +1247,11 @@ def add_data_sync(operations): if operation.name in data_sync_operations and ( operation.data_sync == SyncType.before or operation.data_sync == SyncType.both ): - result_operations.append(SyncOperation()) + result_operations.append(SyncOperation(operation.rank, operation.threadblock)) result_operations.append(operation) if operation.name in data_sync_operations and ( operation.data_sync == SyncType.after 
or operation.data_sync == SyncType.both ): - result_operations.append(SyncOperation()) + result_operations.append(SyncOperation(operation.rank, operation.threadblock)) - return result_operations \ No newline at end of file + return result_operations diff --git a/python/mscclpp/language/internal/register.py b/python/mscclpp/language/internal/register.py index ed0737d9a..833415615 100644 --- a/python/mscclpp/language/internal/register.py +++ b/python/mscclpp/language/internal/register.py @@ -1,4 +1,3 @@ - class ChannelRegister: channels = {} @@ -9,7 +8,8 @@ def add_channel(rank, tb, tb_channel_id, channel): @staticmethod def get_channel(rank: int, threadblock: int, tb_channel_id: int): return ChannelRegister.channels.get((rank, threadblock, tb_channel_id)) - + + class SemaphoreRegister: semaphores = {} @@ -19,4 +19,4 @@ def add_semaphore(semaphore): @staticmethod def get_semaphore(rank: int, semaphore_id: int): - return SemaphoreRegister.semaphores.get((rank, semaphore_id)) \ No newline at end of file + return SemaphoreRegister.semaphores.get((rank, semaphore_id)) diff --git a/python/mscclpp/language/internal/types.py b/python/mscclpp/language/internal/types.py index e7a5de079..99e71e57a 100644 --- a/python/mscclpp/language/internal/types.py +++ b/python/mscclpp/language/internal/types.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from mscclpp.language.loop import LoopIterationContext from mscclpp.language.thread_block_group import ThreadBlockGroup from dataclasses import dataclass, field from enum import Enum @@ -8,6 +9,7 @@ from collections import defaultdict import uuid + class SyncType(Enum): none = "none" before = "before" @@ -183,22 +185,22 @@ class DataAccess: end: float buffer_type: BufferType data_access_type: DataAccessType - tb_group: ThreadBlockGroup = None - + tb_group: ThreadBlockGroup + pipeline_context: LoopIterationContext + def __lt__(self, other): if self.start != other.start: return self.start < other.start return self.end < other.end def __eq__(self, other, tolerance=1e-5): - return (abs(self.start - other.start) < tolerance and - abs(self.end - other.end) < tolerance) + return abs(self.start - other.start) < tolerance and abs(self.end - other.end) < tolerance def __hash__(self): return hash((self.start, self.end)) def lower_overlaps(self, other, tolerance=1e-5) -> bool: - return (self.start + tolerance < other.end) + return self.start + tolerance < other.end def overlaps(self, other, tolerance=1e-5) -> bool: return (self.start + tolerance < other.end) and (other.start + tolerance < self.end) @@ -210,15 +212,41 @@ def check_conflict(self, other) -> bool: and (self.data_access_type != DataAccessType.read or other.data_access_type != DataAccessType.read) ): if self.threadblock == other.threadblock: - return DataAccessConflict(self.rank, {(other.threadblock, other.operation_order_id, True)}, DataAccessConflictType.intra_threadblock) + return DataAccessConflict( + self.rank, + {(other.threadblock, other.operation_order_id, True)}, + DataAccessConflictType.intra_threadblock, + ) else: - is_order_defined = ((self.tb_group is not None and other.tb_group is not None and self.tb_group.tbg_overlap(other.tb_group)) - or (self.tb_group is not None and other.tb_group is None and self.tb_group.tb_overlap(other.threadblock)) - or (self.tb_group is None and other.tb_group is not None and other.tb_group.tb_overlap(self.threadblock))) - return DataAccessConflict(self.rank, {(self.threadblock, other.operation_order_id, True), 
(other.threadblock, other.operation_order_id, is_order_defined)}, DataAccessConflictType.inter_threadblock) + is_order_defined = ( + ( + self.tb_group is not None + and other.tb_group is not None + and self.tb_group.tbg_overlap(other.tb_group) + ) + or ( + self.tb_group is not None + and other.tb_group is None + and self.tb_group.tb_overlap(other.threadblock) + ) + or ( + self.tb_group is None + and other.tb_group is not None + and other.tb_group.tb_overlap(self.threadblock) + ) + ) + return DataAccessConflict( + self.rank, + { + (self.threadblock, other.operation_order_id, True), + (other.threadblock, other.operation_order_id, is_order_defined), + }, + DataAccessConflictType.inter_threadblock, + ) else: return DataAccessConflict(self.rank) + class DataAccessConflictType(Enum): inter_threadblock = "inter_tb" intra_threadblock = "intra_tb" @@ -228,28 +256,40 @@ def __add__(self, other): if not isinstance(other, DataAccessConflictType): return NotImplemented - map_to_num = {DataAccessConflictType.none: 0, DataAccessConflictType.intra_threadblock: 1, DataAccessConflictType.inter_threadblock: 3} - map_to_dact = {0: DataAccessConflictType.none, 1: DataAccessConflictType.intra_threadblock, 3: DataAccessConflictType.inter_threadblock} + map_to_num = { + DataAccessConflictType.none: 0, + DataAccessConflictType.intra_threadblock: 1, + DataAccessConflictType.inter_threadblock: 3, + } + map_to_dact = { + 0: DataAccessConflictType.none, + 1: DataAccessConflictType.intra_threadblock, + 3: DataAccessConflictType.inter_threadblock, + } return map_to_dact[map_to_num[self] | map_to_num[other]] def __str__(self): return self.value + @dataclass -class DataAccessConflict(): +class DataAccessConflict: rank: int - threadblocks: Set[int] = field(default_factory=set) + threadblocks: Set[int] = field(default_factory=set) conflict_type: DataAccessConflictType = DataAccessConflictType.none def __add__(self, other): if not isinstance(other, DataAccessConflict): return NotImplemented - return DataAccessConflict(self.rank, self.threadblocks | other.threadblocks, self.conflict_type + other.conflict_type) + return DataAccessConflict( + self.rank, self.threadblocks | other.threadblocks, self.conflict_type + other.conflict_type + ) + class ReplicationPolicy(Enum): interleaved = "interleaved" none = "none" def __str__(self): - return self.value \ No newline at end of file + return self.value diff --git a/python/mscclpp/language/loop.py b/python/mscclpp/language/loop.py index d79c6bd78..ad70ecedb 100644 --- a/python/mscclpp/language/loop.py +++ b/python/mscclpp/language/loop.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from mscclpp.language.internal.globals import * +from typing import Dict class LoopIterationContext: @@ -34,6 +35,7 @@ def __init__(self, unit, num_chunks): """ self.unit = unit self.num_chunks = num_chunks + self.pre_operations = dict() def __enter__(self): """Enter the context and set this as the active loop context. @@ -52,7 +54,7 @@ def __exit__(self, exc_type, exc_value, traceback): """ get_program().set_loop_context(None) - def process_operation(self, operation): + def process_operation(self, operations): """Add an operation to be included in the pipeline. This method is called internally to collect operations that should be @@ -64,4 +66,5 @@ def process_operation(self, operation): tb (int): The thread block ID that will execute this operation. operation: The operation object to be added to the pipeline. 
""" - operation.set_pipeline_context(self) + for operation in operations: + operation.set_pipeline_context(self) diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index 0d82db223..61554a592 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -12,6 +12,7 @@ from mscclpp.language.rank import Semaphore from mscclpp.language.collectives import * from mscclpp.language.utils import AlgoSpec, ReplicationPolicy +from mscclpp.language.internal.operations import add_data_sync from typing import List import json @@ -214,15 +215,19 @@ def add_semaphore(self, semaphore): def add_operation(self, rank, tb, operation): if self.loop_context != None: - self.loop_context.process_operation(operation) + self.loop_context.process_operation([operation]) self.op_dep_dag.add_operation(operation) def add_tbg_operation(self, operations): + if self.loop_context != None: + self.loop_context.process_operation(operations) self.op_dep_dag.add_tbg_operation(operations) def post_process_operations(self): self.op_dep_dag.add_semaphore_dependency() + self.op_dep_dag.fusion_operations() list_op = self.op_dep_dag.get_execution_order() + list_op = add_data_sync(list_op) list_op = self.buffers_access.process_operations(list_op) for op in list_op: self.gpus[op.rank].add_operation(op.threadblock, op) diff --git a/python/mscclpp/language/rank.py b/python/mscclpp/language/rank.py index 462465e6e..b3fbdf88f 100644 --- a/python/mscclpp/language/rank.py +++ b/python/mscclpp/language/rank.py @@ -117,16 +117,12 @@ def _copy( threadblock=tb_id, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], - tbg=( - tb_group - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), from_packet=from_packet, to_packet=to_packet, ) operations.append(op) - + if tb_group is None: get_program().add_operation(self.rank, tb_id, operations[0]) else: @@ -256,11 +252,7 @@ def reduce( + [LocalChunk(chunk.buffer, chunk.index, chunk.size) for chunk in other_chunks], local_dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], reduce_operation=reduce_op, - tbg=( - tb_group - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), packet=packet, ) operations.append(op) diff --git a/python/mscclpp/language/thread_block_group.py b/python/mscclpp/language/thread_block_group.py index a0d63c498..811d30f88 100644 --- a/python/mscclpp/language/thread_block_group.py +++ b/python/mscclpp/language/thread_block_group.py @@ -70,5 +70,3 @@ def start_offset(self, tb, size): def end_offset(self, tb, size): tb_id = self.get_internal_id(tb) return (size / self.numtb()) * (tb_id + 1) - -