diff --git a/python/mscclpp/language/channel.py b/python/mscclpp/language/channel.py index 1b22e4e27..d0d29b040 100644 --- a/python/mscclpp/language/channel.py +++ b/python/mscclpp/language/channel.py @@ -25,6 +25,7 @@ class MemoryChannel: """ _channel_counts = defaultdict(int) + _channel_peer_counts = defaultdict(int) @classmethod def reset(cls): @@ -52,6 +53,8 @@ def __init__(self, dst_rank: int, src_rank: int): self.channel_id = MemoryChannel._channel_counts[src_rank] MemoryChannel._channel_counts[src_rank] += 1 + self.channel_peer_id = MemoryChannel._channel_peer_counts[(src_rank, dst_rank)] + MemoryChannel._channel_peer_counts[(src_rank, dst_rank)] += 1 self.dst_rank = dst_rank self.src_rank = src_rank @@ -76,7 +79,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = F >>> channel.signal(tb=0, data_sync=SyncType.before) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) + op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False): @@ -97,7 +100,7 @@ def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = Fal >>> channel.wait(tb=0, data_sync=SyncType.after) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) + op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): @@ -138,21 +141,25 @@ def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." 
) + operations = [] for tb_id in tb_list: tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type) tb_channel_ids = get_program().setup_channel(tb, self) op = GetOperation( + rank=self.src_rank, + threadblock=tb_id, src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)], dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), ) - get_program().add_operation(self.src_rank, tb_id, op) + operations.append(op) + + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Send data from local memory to remote memory. @@ -197,21 +204,25 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." 
) + operations = [] for tb_id in tb_list: tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type) tb_channel_ids = get_program().setup_channel(tb_id, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb_id, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), ) - get_program().add_operation(self.src_rank, tb_id, op) + operations.append(op) + + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Transfer data in packet format from local to remote scratch buffer. @@ -261,23 +272,27 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, t "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." 
) + operations = [] for tb_id in tb_list: tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type) tb_channel_ids = get_program().setup_channel(tb_id, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb_id, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), from_packet=True, to_packet=True, ) - get_program().add_operation(self.src_rank, tb_id, op) + operations.append(op) + + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Transfer data from local buffer to remote scratch buffer in packet format. @@ -325,24 +340,27 @@ def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_gro "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." 
) + operations = [] for tb_id in tb_list: tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type) tb_channel_ids = get_program().setup_channel(tb_id, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb_id, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), from_packet=False, to_packet=True, ) + operations.append(op) - get_program().add_operation(self.src_rank, tb_id, op) + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def reduce( self, @@ -405,6 +423,7 @@ def reduce( "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." 
) + operations = [] for tb_id in tb_list: remote_chunks = [ RemoteChunk( @@ -423,21 +442,23 @@ def reduce( tb_channel_ids = get_program().setup_channel(tb_id, self) op = ReduceOperation( + rank=self.src_rank, + threadblock=tb_id, local_src_buff=[LocalChunk(local_src_chunk.buffer, local_src_chunk.index, local_src_chunk.size)], local_dst_buff=[LocalChunk(local_dst_chunk.buffer, local_dst_chunk.index, local_dst_chunk.size)], remote_src_buff=remote_chunks, remote_dst_buff=[], channel_ids=tb_channel_ids, channel_type=self.channel_type, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), reduce_operation=reduce_op, ) + operations.apend(op) - get_program().add_operation(self.src_rank, tb_id, op) + if tb_group is None: + get_program().add_operation(self.src_rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) @dataclass @@ -457,6 +478,7 @@ class PortChannel: """ _channel_counts = defaultdict(int) + _channel_peer_counts = defaultdict(int) @classmethod def reset(cls): @@ -484,6 +506,8 @@ def __init__(self, dst_rank: int, src_rank: int): self.channel_id = PortChannel._channel_counts[src_rank] PortChannel._channel_counts[src_rank] += 1 + self.channel_peer_id = PortChannel._channel_peer_counts[(src_rank, dst_rank)] + PortChannel._channel_peer_counts[(src_rank, dst_rank)] += 1 self.dst_rank = dst_rank self.src_rank = src_rank @@ -506,7 +530,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both): >>> channel.signal(tb=0, data_sync=SyncType.before) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = SignalOperation(tb_channel_ids, self.channel_type, data_sync) + op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) def wait(self, tb: int, data_sync: SyncType = SyncType.both): @@ -525,7 +549,7 @@ def wait(self, tb: 
int, data_sync: SyncType = SyncType.both): >>> channel.wait(tb=0, data_sync=SyncType.after) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = WaitOperation(tb_channel_ids, self.channel_type, data_sync) + op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) def flush(self, tb: int, data_sync: SyncType = SyncType.both): @@ -544,7 +568,7 @@ def flush(self, tb: int, data_sync: SyncType = SyncType.both): >>> channel.flush(tb=0, data_sync=SyncType.after) """ tb_channel_ids = get_program().setup_channel(tb, self) - op = FlushOperation(tb_channel_ids, self.channel_type, data_sync) + op = FlushOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): @@ -583,6 +607,8 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): tb_channel_ids = get_program().setup_channel(tb, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, @@ -628,6 +654,8 @@ def put_with_signal(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): tb_channel_ids = get_program().setup_channel(tb, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, @@ -673,6 +701,8 @@ def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int) tb_channel_ids = get_program().setup_channel(tb, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, 
dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, @@ -772,6 +802,8 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): tb_channel_ids = get_program().setup_channel(tb, self) op = PutOperation( + rank=self.src_rank, + threadblock=tb, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)], channel_ids=tb_channel_ids, @@ -900,13 +932,15 @@ def reduce(self, rank, buffer_offset, size, dst_chunk: Chunk, tb, reduce_op=Redu tb_channel_ids = get_program().setup_channel(tb, self) op = GroupLoadReduce( - self.buffer_type, - buffer_offset, - size, - dst_chunk, - tb_channel_ids, - self.channel_type, - reduce_op, + rank=self.src_rank, + threadblock=tb, + buffer_type=self.buffer_type, + buffer_offset=buffer_offset, + size=size, + dst_chunk=dst_chunk, + tb_channel_ids=tb_channel_ids, + channel_type=self.channel_type, + reduce_operation=reduce_op, ) get_program().add_operation(self.src_rank, tb, op) @@ -950,7 +984,9 @@ def broadcast(self, rank, src_chunk: Chunk, buffer_offset, size, tb): ) tb_channel_ids = get_program().setup_channel(tb, self) - op = GroupStore(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type) + op = GroupStore( + self.src_rank, tb, src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type + ) get_program().add_operation(self.src_rank, tb, op) class SwitchChannelRankView: diff --git a/python/mscclpp/language/internal/buffer_access.py b/python/mscclpp/language/internal/buffer_access.py index ab8a7bdcc..f0d06b3e8 100644 --- a/python/mscclpp/language/internal/buffer_access.py +++ b/python/mscclpp/language/internal/buffer_access.py @@ -3,92 +3,236 @@ from sortedcontainers import SortedDict from typing import List -from mscclpp.language.internal.types import BufferType, DataAccessType +from mscclpp.language.internal.types import * from 
mscclpp.language.internal.operations import * from enum import Enum class BuffersAccess: - def __init__(self): - self.intervals = { - BufferType.input: SortedDict(), - BufferType.output: SortedDict(), - BufferType.scratch: SortedDict(), - } + def __init__(self, num_ranks, intra_rank_sync): + self.intra_rank_sync = intra_rank_sync + self.rank_intervals = [ + { + BufferType.input: SortedDict(), + BufferType.output: SortedDict(), + BufferType.scratch: SortedDict(), + } + for _ in range(num_ranks) + ] + self.track_sync = {} + self.track_barrier = {} def process_operations(self, operations): result_operations = [] - for operation in operations: + for i in range(len(operations)): + operation = operations[i] if operation.name == Instruction.nop or operation.name == Instruction.barrier: - self.clear_data_access() + self.track_sync[operation.rank, operation.threadblock] = i + if operation.name == Instruction.barrier: + self.update_barrier(operation, i) + if operation.name == Instruction.sem_acquire: + self.update_semaphore(operation, i) else: - if operation.name == Instruction.pipeline: - pipeline_buffer_access = BuffersAccess() - pipeline_result_operations = pipeline_buffer_access.process_operations(operation.operations) - operation.operations = pipeline_result_operations - data_access = operation.local_data_access() - sync_added = False + data_access = operation.local_data_access(i) + data_access_conflict_same_ctx = DataAccessConflict(operation.rank) + data_access_conflict_diff_ctx = DataAccessConflict(operation.rank) for data_access_element in data_access: - if self.compute_data_access(data_access_element) and not sync_added: - result_operations.append(SyncOperation()) - sync_added = True + computed_data_access_conflict_same_ctx, computed_data_access_conflict_diff_ctx = ( + self.compute_data_access(data_access_element) + ) + data_access_conflict_same_ctx = ( + data_access_conflict_same_ctx + computed_data_access_conflict_same_ctx + ) + data_access_conflict_diff_ctx = ( + 
data_access_conflict_diff_ctx + computed_data_access_conflict_diff_ctx + ) + fix_operations = self.resolve_conflicts( + operation.rank, + operation.threadblock, + operation.pipeline_context, + i, + data_access_conflict_same_ctx, + data_access_conflict_diff_ctx, + ) + result_operations.extend(fix_operations) result_operations.append(operation) + result_operations = self.add_pipeline_context_sync_operations(result_operations) + return result_operations + def update_barrier(self, operation, order_id): + for tb in operation.barrier_info.tb_list: + if operation.threadblock != tb: + self.track_barrier[operation.rank, operation.threadblock, tb] = order_id + self.track_sync[operation.rank, operation.threadblock] = order_id + + def update_semaphore(self, operation, order_id): + for tb in operation.tb_sync: + if operation.threadblock != tb: + self.track_barrier[operation.rank, operation.threadblock, tb] = order_id + def compute_data_access(self, data_access: DataAccess) -> bool: - keys = self.intervals[data_access.buffer_type].keys() + intervals = self.rank_intervals[data_access.rank] + keys = intervals[data_access.buffer_type].keys() idx = self.lower_bound(0, len(keys) - 1, keys, data_access) - conflict = False + conflict_same_ctx = DataAccessConflict(data_access.rank) + conflict_diff_ctx = DataAccessConflict(data_access.rank) while len(keys) > 0 and data_access.overlaps(keys[idx]): conflict_data_access = keys[idx] - conflict_operation_type = self.intervals[data_access.buffer_type][conflict_data_access] - if data_access.check_conflict(conflict_data_access): - self.clear_data_access() - conflict = True - break + conflict_operation_type = intervals[data_access.buffer_type][conflict_data_access] + if ( + data_access.pipeline_context is None + or data_access.pipeline_context == conflict_data_access.pipeline_context + ): + conflict_same_ctx = conflict_same_ctx + data_access.check_conflict(conflict_data_access) + else: + conflict_diff_ctx = conflict_diff_ctx + 
data_access.check_conflict(conflict_data_access) - self.intervals[data_access.buffer_type].pop(conflict_data_access) + intervals[data_access.buffer_type].pop(conflict_data_access) if conflict_data_access.end > data_access.end: - self.intervals[data_access.buffer_type][ + intervals[data_access.buffer_type][ DataAccess( - conflict_data_access.operation_id, - data_access.end + 1, + conflict_data_access.rank, + conflict_data_access.threadblock, + conflict_data_access.operation_global_id, + conflict_data_access.operation_order_id, + data_access.end, conflict_data_access.end, conflict_data_access.buffer_type, conflict_operation_type, + conflict_data_access.tb_group, + conflict_data_access.pipeline_context, ) ] = conflict_operation_type if conflict_data_access.start < data_access.start: - self.intervals[data_access.buffer_type][ + intervals[data_access.buffer_type][ DataAccess( - conflict_data_access.operation_id, + conflict_data_access.rank, + conflict_data_access.threadblock, + conflict_data_access.operation_global_id, + conflict_data_access.operation_order_id, conflict_data_access.start, - data_access.start - 1, + data_access.start, conflict_data_access.buffer_type, conflict_operation_type, + conflict_data_access.tb_group, + conflict_data_access.pipeline_context, ) ] = conflict_operation_type - keys = self.intervals[data_access.buffer_type].keys() + keys = intervals[data_access.buffer_type].keys() idx = self.lower_bound(0, len(keys) - 1, keys, data_access) - self.intervals[data_access.buffer_type][data_access] = data_access.data_access_type - return conflict + intervals[data_access.buffer_type][data_access] = data_access.data_access_type + return (conflict_same_ctx, conflict_diff_ctx) + + def resolve_conflicts( + self, + rank, + threadblock, + pipeline_context, + order_id, + data_access_conflict_same_ctx: DataAccessConflict, + data_access_conflict_diff_ctx: DataAccessConflict, + ): + fix_operations = [] + if data_access_conflict_same_ctx.conflict_type == 
DataAccessConflictType.intra_threadblock: + for tb in data_access_conflict_same_ctx.threadblocks: + if (rank, threadblock) not in self.track_sync or tb[1] > self.track_sync[(rank, threadblock)]: + fix_operations.append(SyncOperation(rank, threadblock)) + self.track_sync[(rank, threadblock)] = order_id + break + if ( + data_access_conflict_same_ctx.conflict_type == DataAccessConflictType.inter_threadblock + and self.intra_rank_sync + ): + conflict_tb = set([threadblock]) + for tb in data_access_conflict_same_ctx.threadblocks: + if threadblock != tb[0] and ( + (rank, threadblock, tb[0]) not in self.track_barrier + or self.track_barrier[(rank, threadblock, tb[0])] < tb[1] + ): + if not tb[2]: + raise RuntimeError("Operations order not defined.") + conflict_tb.add(tb[0]) + if len(conflict_tb) > 1: + for tb in conflict_tb: + op = BarrierOperation(rank, tb, conflict_tb) + self.update_barrier(op, order_id) + fix_operations.append(op) + + if pipeline_context is not None: + if (rank, threadblock) not in pipeline_context.pre_operations: + pipeline_context.pre_operations[(rank, threadblock)] = [] - def clear_data_access(self): - self.intervals[BufferType.input].clear() - self.intervals[BufferType.output].clear() - self.intervals[BufferType.scratch].clear() + if data_access_conflict_diff_ctx.conflict_type == DataAccessConflictType.intra_threadblock: + for tb in data_access_conflict_diff_ctx.threadblocks: + if (rank, threadblock) not in self.track_sync or tb[1] > self.track_sync[(rank, threadblock)]: + self.track_sync[(rank, threadblock)] = order_id + pipeline_context.pre_operations[(rank, threadblock)].append(SyncOperation(rank, threadblock)) + break + if ( + data_access_conflict_diff_ctx.conflict_type == DataAccessConflictType.inter_threadblock + and self.intra_rank_sync + ): + conflict_tb = set([threadblock]) + for tb in data_access_conflict_diff_ctx.threadblocks: + if threadblock != tb[0] and ( + (rank, threadblock, tb[0]) not in self.track_barrier + or 
self.track_barrier[(rank, threadblock, tb[0])] < tb[1] + ): + if not tb[2]: + raise RuntimeError("Operations order not defined.") + conflict_tb.add(tb[0]) + if len(conflict_tb) > 1: + for tb in conflict_tb: + op = BarrierOperation(rank, tb, conflict_tb) + self.update_barrier(op, order_id) + pipeline_context.pre_operations[(rank, threadblock)].append(op) + + return fix_operations + + def add_pipeline_context_sync_operations(self, operations): + result_operations = [] + pipeline_operations = dict() + for i in range(len(operations)): + operation = operations[i] + if operation.pipeline_context is not None: + pipeline_context = operation.pipeline_context + result_operations.extend( + pipeline_context.pre_operations.get((operation.rank, operation.threadblock), []) + ) + pipeline_context.pre_operations.pop((operation.rank, operation.threadblock), None) + + if (operation.rank, operation.threadblock, operation.pipeline_context) not in pipeline_operations: + pipeline_operations[(operation.rank, operation.threadblock, operation.pipeline_context)] = ( + PipelineOperation( + operation.rank, + operation.threadblock, + operation.pipeline_context.unit, + operation.pipeline_context.num_chunks, + ) + ) + result_operations.append( + pipeline_operations[(operation.rank, operation.threadblock, operation.pipeline_context)] + ) + pipeline_operations[(operation.rank, operation.threadblock, operation.pipeline_context)].add_operation( + operation + ) + else: + result_operations.append(operation) + + return result_operations def lower_bound(self, init_pos, final_pos, data_access_list, data_access): if init_pos >= final_pos: return init_pos mid_pos = (init_pos + final_pos) // 2 - if data_access.start <= data_access_list[mid_pos].end: + if data_access.lower_overlaps(data_access_list[mid_pos]): final_pos = mid_pos else: init_pos = mid_pos + 1 diff --git a/python/mscclpp/language/internal/op_dep_graph.py b/python/mscclpp/language/internal/op_dep_graph.py new file mode 100644 index 
000000000..848cf1328 --- /dev/null +++ b/python/mscclpp/language/internal/op_dep_graph.py @@ -0,0 +1,384 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from mscclpp.language.internal.globals import * +from mscclpp.language.internal.operations import * +from mscclpp.language.internal.types import * +from mscclpp.language.internal.register import ChannelRegister, SemaphoreRegister +from queue import Queue +from typing import Set, Dict, Tuple +import warnings + + +class OperationDependencyGraph: + """ + A DAG structure to enforce correct execution order of collective communication operations. + Supports topological sorting based on rank/threadblock execution and signal/wait synchronization. + """ + + def __init__(self): + self.root_nodes: Set[OperationDependencyGraph.Node] = set() + self.last_node: Dict[Tuple[int, int], int] = {} + self.signalling: Dict[Tuple[int, int, int], Queue] = {} + self.waiting: Dict[Tuple[int, int, int], Queue] = {} + + self.barrier_nodes: Dict[Tuple[int, int], List[OperationDependencyGraph.Node]] = {} + self.tb_barriers: Dict[Tuple[int, int, int], int] = {} + self.node_list = [] + + def add_operation(self, operation, agg_node=None): + """ + Inserts an operation into the DAG, adding edges based on dependencies. 
+ """ + rank = operation.rank + threadblock = operation.threadblock + node = self.Node(operation) + if agg_node is not None: + agg_node.add_node(node) + + if isinstance(operation, BarrierOperation): + if (rank, threadblock, operation.barrier_id) not in self.tb_barriers: + self.tb_barriers[(rank, threadblock, operation.barrier_id)] = 0 + if (rank, operation.barrier_id) not in self.barrier_nodes: + self.barrier_nodes[(rank, operation.barrier_id)] = [] + + barrier_count = self.tb_barriers[(rank, threadblock, operation.barrier_id)] + if barrier_count > len(self.barrier_nodes[(rank, operation.barrier_id)]): + raise RuntimeError( + f"Barrier node not create correctly for rank {rank}, threadblock {threadblock}, barrier_id {operation.barrier_id}." + ) + elif barrier_count == len(self.barrier_nodes[(rank, operation.barrier_id)]): + agg_node = self.AggregateNode() + self.barrier_nodes[(rank, operation.barrier_id)].append(agg_node) + else: + agg_node = self.barrier_nodes[(rank, operation.barrier_id)][barrier_count] + + self.tb_barriers[(rank, threadblock, operation.barrier_id)] += 1 + agg_node.add_node(node) + node = agg_node + + self.node_list.append(node) + if (rank, threadblock) not in self.last_node: + self.last_node[(rank, threadblock)] = node + if node.get_input() == 0: + self.root_nodes.add(node) + else: + prev_node = self.last_node[(rank, threadblock)] + if prev_node is not node: + prev_node.next_nodes.append(node) + node.previous_nodes.append(prev_node) + node.add_input() + self.last_node[(rank, threadblock)] = node + if node in self.root_nodes: + self.root_nodes.remove(node) + + if isinstance(operation, SignalOperation) or ( + isinstance(operation, PutOperation) and (operation.with_signal or operation.with_signal_and_flush) + ): + for tb_channel_id in operation.channel_ids: + channel = ChannelRegister.get_channel(rank, threadblock, tb_channel_id) + op_info = (channel.src_rank, channel.dst_rank, channel.channel_peer_id) + if op_info not in self.waiting or 
self.waiting[op_info].empty(): + if op_info not in self.signalling: + self.signalling[op_info] = Queue() + self.signalling[op_info].put(node) + else: + waiting_node = self.waiting[op_info].get() + node.next_nodes.append(waiting_node) + waiting_node.previous_nodes.append(node) + waiting_node.add_input() + + if isinstance(operation, WaitOperation): + for tb_channel_id in operation.channel_ids: + channel = ChannelRegister.get_channel(rank, threadblock, tb_channel_id) + op_info = (channel.dst_rank, channel.src_rank, channel.channel_peer_id) + if op_info not in self.signalling or self.signalling[op_info].empty(): + if op_info not in self.waiting: + self.waiting[op_info] = Queue() + self.waiting[op_info].put(node) + else: + signalling_node = self.signalling[op_info].get() + signalling_node.next_nodes.append(node) + node.previous_nodes.append(signalling_node) + node.add_input() + + return node + + def add_tbg_operation(self, operations): + agg_node = self.AggregateNode() + for operation in operations: + self.add_operation(operation, agg_node) + + def add_semaphore_dependency(self): + queue = Queue() + processed_node = set() + sem_rel = {} + sem_acq = {} + sem_val = {} + + self.reset() + + def compute_sem_op(sem_op, node): + operation = node.operation + for id in operation.semaphore_ids: + if (operation.rank, id) not in sem_op: + sem_op[(operation.rank, id)] = [] + sem_val[(operation.rank, id)] = SemaphoreRegister.get_semaphore(operation.rank, id).initial_value + sem_op[(operation.rank, id)].append((node, operation.pipeline_context)) + + return True + + def process_node(node): + if node in processed_node: + return + processed_node.add(node) + + for next_node in node.next_nodes: + next_node.add_reach() + if next_node.get_reach() == next_node.get_input(): + if isinstance(next_node, self.Node) and next_node.agg_node is not None: + for sub_node in next_node.agg_node.nodes: + queue.put(sub_node) + else: + queue.put(next_node) + + for node in self.root_nodes: + queue.put(node) + 
+ while True: + sem_ops_found = False + new_sem_rel_node = [] + while not queue.empty(): + node = queue.get() + if isinstance(node, self.Node) and isinstance(node.operation, SemaphoreReleaseOperation): + sem_ops_found = compute_sem_op(sem_rel, node) + new_sem_rel_node.append(node) + elif isinstance(node, self.Node) and isinstance(node.operation, SemaphoreAcquireOperation): + sem_ops_found = compute_sem_op(sem_acq, node) + else: + process_node(node) + + if not sem_ops_found: + break + else: + removed_keys = [] + for key in sem_acq.keys(): + if key not in sem_rel: + sem_rel[key] = [] + if len(sem_acq[key]) > 1 or sem_val[key] < len(sem_rel[key]) - len(sem_acq[key]): + get_program().disable_inter_tb_sync() + warnings.warn(f"Undefined Behaviour Semaphore Id.", UserWarning) + return + + for sem_rel_node in new_sem_rel_node: + process_node(sem_rel_node) + + if sem_val[key] == len(sem_rel[key]) - len(sem_acq[key]): + sem_acq_node, sem_acq_ctx = sem_acq[key][0] + sem_val[key] = 0 + if sem_acq_node in self.root_nodes: + self.root_nodes.remove(sem_acq_node) + process_node(sem_acq_node) + for sem_rel_node, sem_rel_ctx in sem_rel[key]: + if sem_rel_ctx is not sem_acq_ctx: + raise RuntimeError(f"Semaphore cross pipeline context violation.") + sem_rel_node.next_nodes.append(sem_acq_node) + sem_acq_node.operation.add_tb_sync(sem_rel_node.operation.threadblock) + sem_acq_node.previous_nodes.append(sem_rel_node) + sem_acq_node.add_input() + + removed_keys.append(key) + + for key in removed_keys: + sem_rel.pop(key) + sem_acq.pop(key) + + if len(sem_acq.keys()) > 0: + raise RuntimeError(f"Semaphore acquire hanging.") + + def reset(self): + for node in self.node_list: + node.reset() + + def print(self): + self.reset() + self.check() + + queue = Queue() + for node in self.root_nodes: + queue.put(node) + + while not queue.empty(): + node = queue.get() + print(f"node {node.print()}") + for next_node in node.next_nodes: + next_node.add_reach() + print(f"next_node {next_node.print()}") + 
if next_node.get_reach() == next_node.get_input(): + if isinstance(next_node, self.Node) and next_node.agg_node is not None: + for sub_node in next_node.agg_node.nodes: + queue.put(sub_node) + else: + queue.put(next_node) + print() + + def check(self): + """ + Validates the DAG structure, ensuring all nodes are reachable and dependencies are correctly set. + """ + if len(self.signalling) > 0: + for key, queue in self.signalling.items(): + if not queue.empty(): + raise RuntimeError( + f"Signalling from {key[0]} to {key[1]} on channel {key[2]} hasn't equivalent wait operation." + ) + if len(self.waiting) > 0: + for key, queue in self.waiting.items(): + if not queue.empty(): + raise RuntimeError( + f"Waiting for {key[0]} to {key[1]} on channel {key[2]} hasn't equivalent signal operation." + ) + + def fusion_operations(self): + self.reset() + self.check() + + for node in self.root_nodes: + for next_node in node.next_nodes: + if isinstance(node, self.Node) and isinstance(next_node, self.Node): + fused_op = node.operation + next_node.operation + if fused_op is not None: + node.operation = fused_op + node.next_nodes.remove(next_node) + next_node.previous_nodes.remove(node) + for nn in next_node.next_nodes: + node.next_nodes.append(nn) + ## Change from list to another data sctruct that allos insert and remove in log(n) + nn.previous_nodes.remove(next_node) + nn.previous_nodes.append(node) + for pn in next_node.previous_nodes: + node.previous_nodes.append(pn) + pn.next_nodes.remove(next_node) + pn.next_nodes.append(node) + node.add_input() + del next_node + + def get_execution_order(self): + """ + Returns the order of operations in the DAG. 
+ """ + self.reset() + self.check() + + order = [] + queue = Queue() + for node in self.root_nodes: + queue.put(node) + + while not queue.empty(): + node = queue.get() + order.extend(node.get_operations()) + for next_node in node.next_nodes: + next_node.add_reach() + if next_node.get_reach() == next_node.get_input(): + if isinstance(next_node, self.Node) and next_node.agg_node is not None: + for sub_node in next_node.agg_node.nodes: + queue.put(sub_node) + else: + queue.put(next_node) + + return order + + class BaseNode: + def __init__(self): + ## Change from list to another data sctruct that allos insert and remove in log(n) + self.previous_nodes = [] + self.next_nodes = [] + self.input = 0 + self.reach = 0 + + def add_input(self): + self.input += 1 + + def add_reach(self): + self.reach += 1 + + def get_input(self): + return self.input + + def get_reach(self): + return self.reach + + def reset(self): + self.reach = 0 + + class Node(BaseNode): + def __init__(self, operation): + self.operation = operation + self.agg_node = None + super().__init__() + + def __del__(self): + self.decrease_input(len(self.previous_nodes)) + if self.agg_node is not None: + self.agg_node.nodes.remove(self) + + def get_operations(self): + return [self.operation] + + def add_input(self): + if self.agg_node is not None: + self.agg_node.input += 1 + else: + self.input += 1 + + def decrease_input(self, amount=1): + if self.agg_node is not None: + self.agg_node.input -= amount + else: + self.input -= amount + + def add_reach(self): + if self.agg_node is not None: + self.agg_node.reach += 1 + else: + self.reach += 1 + + def get_input(self): + if self.agg_node is not None: + return self.agg_node.input + else: + return self.input + + def get_reach(self): + if self.agg_node is not None: + return self.agg_node.reach + else: + return self.reach + + def reset(self): + if self.agg_node is not None: + self.agg_node.reset() + else: + self.reach = 0 + + def print(self): + return f"rank 
{self.operation.rank} tb {self.operation.threadblock} {self.operation.name}" + + class AggregateNode(BaseNode): + def __init__(self): + self.nodes = [] + super().__init__() + + def add_node(self, node): + self.nodes.append(node) + node.agg_node = self + + def get_operations(self): + operations = [] + for node in self.nodes: + operations.extend(node.get_operations()) + return operations + + def print(self): + return f"rank {self.operations[0].rank} tb {self.operations[0].threadblock} {self.operations[0].name}" diff --git a/python/mscclpp/language/internal/operations.py b/python/mscclpp/language/internal/operations.py index 127f4a03c..62f627294 100644 --- a/python/mscclpp/language/internal/operations.py +++ b/python/mscclpp/language/internal/operations.py @@ -11,6 +11,8 @@ DataAccess, DataAccessType, ) +from mscclpp.language.thread_block_group import ThreadBlockGroup +from mscclpp.language.loop import LoopIterationContext from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import List @@ -32,7 +34,11 @@ class BaseOperation(ABC): """ id: uuid.UUID = field(default_factory=uuid.uuid4, init=False) + rank: int + threadblock: int name: str + # TODO: Only fuse operation with the same pipeline_context + pipeline_context: LoopIterationContext = field(default=None) def local_data_access(self, sync_purpose=True): """Get list of local data accesses performed by this operation. @@ -81,6 +87,16 @@ def shift_ids(self, instance, num_instances, replication_function): """ return + def set_pipeline_context(self, pipeline_context): + self.pipeline_context = pipeline_context + + def basic_fusion_check(self, other_op): + return ( + self.rank == other_op.rank + and self.threadblock == other_op.threadblock + and self.pipeline_context is other_op.pipeline_context + ) + def __add__(self, other): """Attempt to fuse this operation with another operation. 
@@ -112,29 +128,21 @@ def to_dict(self): return {"buffer_id": self.buffer_id, "index": self.index, "size": self.size} -@dataclass -class ThreadBlockGroupInfo: - tb_id: int - tbg_size: int - - def to_dict(self): - return {"tb_id": self.tb_id, "tbg_size": self.tbg_size} - - class SyncOperation(BaseOperation): - def __init__(self): - super().__init__(Instruction.nop) + def __init__(self, rank: int, threadblock: int): + super().__init__(rank, threadblock, Instruction.nop) def __add__(self, other): fused_operation = None - if isinstance(other, SyncOperation): - fused_operation = SyncOperation() - elif isinstance(other, BarrierOperation): - fused_operation = other - elif isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before: - fused_operation = other - elif check_data_sync_op(other): - other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) + if self.basic_fusion_check(other): + if isinstance(other, SyncOperation): + fused_operation = SyncOperation(self.rank, self.threadblock) + elif isinstance(other, BarrierOperation): + fused_operation = other + elif isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before: + fused_operation = other + elif check_data_sync_op(other): + other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) return fused_operation @@ -146,36 +154,76 @@ def to_dict(self): class CopyOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, src_buff: List[LocalChunk], dst_buff: List[LocalChunk], - tbg_info: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroup = None, from_packet: bool = False, to_packet: bool = False, ): if from_packet and to_packet: raise RuntimeError(f"Copy Operation from Packet to Packet is not Supported.") elif from_packet: - super().__init__(Instruction.unpack_packet) + super().__init__(rank, threadblock, Instruction.unpack_packet) elif to_packet: - super().__init__(Instruction.copy_packet) + 
super().__init__(rank, threadblock, Instruction.copy_packet) else: - super().__init__(Instruction.copy) + super().__init__(rank, threadblock, Instruction.copy) self.src_buff = src_buff self.dst_buff = dst_buff - self.tbg_info = tbg_info + self.tbg = tbg - def local_data_access(self, sync_purpose=True): + def local_data_access(self, order_id, sync_purpose=True): data_access = [] if self.name != Instruction.unpack_packet or not sync_purpose: for chunk in self.src_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.read) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + ( + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + ), + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + chunk.size + ), + chunk.type, + DataAccessType.read, + self.tbg, + self.pipeline_context, + ) ) if self.name != Instruction.copy_packet or not sync_purpose: for chunk in self.dst_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.write) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + ( + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + ), + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + chunk.size + ), + chunk.type, + DataAccessType.write, + self.tbg, + self.pipeline_context, + ) ) return data_access @@ -193,16 +241,20 @@ def to_dict(self): result["dst_buff"] = [] for chunk in self.dst_buff: result["dst_buff"].append(chunk.to_dict()) - if self.tbg_info is not None: - result["tbg_info"] = self.tbg_info.to_dict() + if self.tbg is not None: + result["tbg"] = self.tbg.to_dict(self.threadblock) return result class SemaphoreAcquireOperation(BaseOperation): - def __init__(self, semaphore_ids: List[int], data_sync: 
SyncType = SyncType.none): - super().__init__(Instruction.sem_acquire) + def __init__(self, rank: int, threadblock: int, semaphore_ids: List[int], data_sync: SyncType = SyncType.none): + super().__init__(rank, threadblock, Instruction.sem_acquire) self.semaphore_ids = semaphore_ids self.data_sync = data_sync + self.tb_sync = set() + + def add_tb_sync(self, tb): + self.tb_sync.add(tb) def shift_ids(self, instance, num_instances, replication_function): for i in range(len(self.semaphore_ids)): @@ -210,18 +262,24 @@ def shift_ids(self, instance, num_instances, replication_function): def __add__(self, other): fused_operation = None - if isinstance(other, SemaphoreAcquireOperation): - fused_operation = SemaphoreAcquireOperation( - semaphore_ids=self.semaphore_ids + other.semaphore_ids, - data_sync=self.data_sync | other.data_sync, - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if isinstance(other, SemaphoreAcquireOperation): + fused_operation = SemaphoreAcquireOperation( + self.rank, + self.threadblock, + semaphore_ids=self.semaphore_ids + other.semaphore_ids, + data_sync=self.data_sync | other.data_sync, + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -232,8 +290,8 @@ def to_dict(self): class SemaphoreReleaseOperation(BaseOperation): - def __init__(self, semaphore_ids: 
List[int], data_sync: SyncType = SyncType.none): - super().__init__(Instruction.sem_release) + def __init__(self, rank: int, threadblock: int, semaphore_ids: List[int], data_sync: SyncType = SyncType.none): + super().__init__(rank, threadblock, Instruction.sem_release) self.semaphore_ids = semaphore_ids self.data_sync = data_sync @@ -243,18 +301,24 @@ def shift_ids(self, instance, num_instances, replication_function): def __add__(self, other): fused_operation = None - if isinstance(other, SemaphoreReleaseOperation): - fused_operation = SemaphoreReleaseOperation( - semaphore_ids=self.semaphore_ids + other.semaphore_ids, - data_sync=self.data_sync | other.data_sync, - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if isinstance(other, SemaphoreReleaseOperation): + fused_operation = SemaphoreReleaseOperation( + self.rank, + self.threadblock, + semaphore_ids=self.semaphore_ids + other.semaphore_ids, + data_sync=self.data_sync | other.data_sync, + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -267,40 +331,48 @@ def to_dict(self): class SignalOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, channels_ids: List[int], channel_type: ChannelType, data_sync: SyncType = SyncType.none, relaxed: bool = False, ): if relaxed: - 
super().__init__(Instruction.relaxed_signal) + super().__init__(rank, threadblock, Instruction.relaxed_signal) else: - super().__init__(Instruction.signal) + super().__init__(rank, threadblock, Instruction.signal) self.channel_ids = set(channels_ids) self.channel_type = channel_type self.data_sync = data_sync def __add__(self, other): fused_operation = None - if ( - isinstance(other, SignalOperation) - and self.channel_type == other.channel_type - and self.name == other.name - and not self.channel_ids & other.channel_ids - ): - fused_operation = SignalOperation( - channels_ids=self.channel_ids | other.channel_ids, - channel_type=self.channel_type, - data_sync=self.data_sync | other.data_sync, - relaxed=(self.name == Instruction.relaxed_signal), - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if ( + isinstance(other, SignalOperation) + and self.channel_type == other.channel_type + and self.name == other.name + and not self.channel_ids & other.channel_ids + ): + fused_operation = SignalOperation( + self.rank, + self.threadblock, + channels_ids=self.channel_ids | other.channel_ids, + channel_type=self.channel_type, + data_sync=self.data_sync | other.data_sync, + relaxed=(self.name == Instruction.relaxed_signal), + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -314,40 +386,48 
@@ def to_dict(self): class WaitOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, channels_ids: List[int], channel_type: ChannelType, data_sync: SyncType = SyncType.none, relaxed: bool = False, ): if relaxed: - super().__init__(Instruction.relaxed_wait) + super().__init__(rank, threadblock, Instruction.relaxed_wait) else: - super().__init__(Instruction.wait) + super().__init__(rank, threadblock, Instruction.wait) self.channel_ids = set(channels_ids) self.channel_type = channel_type self.data_sync = data_sync def __add__(self, other): fused_operation = None - if ( - isinstance(other, WaitOperation) - and self.name == other.name - and not self.channel_ids & other.channel_ids - and self.channel_type == other.channel_type - ): - fused_operation = WaitOperation( - channels_ids=self.channel_ids | other.channel_ids, - channel_type=self.channel_type, - data_sync=self.data_sync | other.data_sync, - relaxed=(self.name == Instruction.relaxed_wait), - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if ( + isinstance(other, WaitOperation) + and self.name == other.name + and not self.channel_ids & other.channel_ids + and self.channel_type == other.channel_type + ): + fused_operation = WaitOperation( + self.rank, + self.threadblock, + channels_ids=self.channel_ids | other.channel_ids, + channel_type=self.channel_type, + data_sync=self.data_sync | other.data_sync, + relaxed=(self.name == Instruction.relaxed_wait), + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == 
SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -361,7 +441,7 @@ def to_dict(self): class BarrierOperation(BaseOperation): __current_barriers = [] - def __init__(self, rank: int, tb_list: List[int]): + def __init__(self, rank: int, threadblock: int, tb_list: List[int]): for _ in range(len(BarrierOperation.__current_barriers), rank + 1): BarrierOperation.__current_barriers.append({}) barrier_info = BarrierOperation.BarrierInfo(tb_list) @@ -372,7 +452,7 @@ def __init__(self, rank: int, tb_list: List[int]): else: self.barrier_id = BarrierOperation.__current_barriers[rank][barrier_info] - super().__init__(Instruction.barrier) + super().__init__(rank, threadblock, Instruction.barrier) self.barrier_info = barrier_info def shift_ids(self, instance, num_instances, replication_function): @@ -380,10 +460,11 @@ def shift_ids(self, instance, num_instances, replication_function): def __add__(self, other): fused_operation = None - if check_data_sync_op(other): - other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) - elif isinstance(other, SyncOperation): - fused_operation = self + if self.basic_fusion_check(other): + if check_data_sync_op(other): + other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) + elif isinstance(other, SyncOperation): + fused_operation = self return fused_operation @@ -406,27 +487,40 @@ def __hash__(self): class FlushOperation(BaseOperation): - def __init__(self, channels_ids: List[int], channel_type: ChannelType, data_sync: SyncType = SyncType.none): - super().__init__(Instruction.flush) + def __init__( + self, + rank: int, + threadblock: int, + channels_ids: List[int], + channel_type: ChannelType, + data_sync: SyncType = SyncType.none, + ): + super().__init__(rank, threadblock, Instruction.flush) self.channel_ids = set(channels_ids) self.channel_type = channel_type 
self.data_sync = data_sync def __add__(self, other): fused_operation = None - if isinstance(other, FlushOperation) and self.channel_type == other.channel_type: - fused_operation = FlushOperation( - channels_ids=self.channel_ids | other.channel_ids, - channel_type=self.channel_type, - data_sync=self.data_sync | other.data_sync, - ) - elif ( - (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) - or (isinstance(other, PipelineOperation) and (other.get_data_sync() & SyncType.before) == SyncType.before) - or isinstance(other, SyncOperation) - or isinstance(other, BarrierOperation) - ): - self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) + if self.basic_fusion_check(other): + if isinstance(other, FlushOperation) and self.channel_type == other.channel_type: + fused_operation = FlushOperation( + self.rank, + self.threadblock, + channels_ids=self.channel_ids | other.channel_ids, + channel_type=self.channel_type, + data_sync=self.data_sync | other.data_sync, + ) + elif ( + (check_data_sync_op(other) and (other.data_sync & SyncType.before) == SyncType.before) + or ( + isinstance(other, PipelineOperation) + and (other.get_data_sync() & SyncType.before) == SyncType.before + ) + or isinstance(other, SyncOperation) + or isinstance(other, BarrierOperation) + ): + self.data_sync = self.data_sync ^ (SyncType.after & self.data_sync) return fused_operation @@ -440,24 +534,41 @@ def to_dict(self): class GetOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, src_buff: List[RemoteChunk], dst_buff: List[LocalChunk], channel_ids: List[int], channel_type: ChannelType, - tbg_info: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroup = None, ): - super().__init__(Instruction.get) + super().__init__(rank, threadblock, Instruction.get) self.src_buff = src_buff self.dst_buff = dst_buff self.channel_ids = channel_ids self.channel_type = channel_type - self.tbg_info = tbg_info + self.tbg = tbg - def local_data_access(self, 
sync_purpose=True): + def local_data_access(self, order_id, sync_purpose=True): data_access = [] for chunk in self.dst_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.write) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.index, + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + chunk.size + ), + chunk.type, + DataAccessType.write, + self.tbg, + self.pipeline_context, + ) ) return data_access @@ -469,19 +580,22 @@ def shift_buffers(self, instance, num_instances, replication_function): def __add__(self, other): fused_operation = None - if ( - isinstance(other, GetOperation) - and self.src_buff[0].size == other.src_buff[0].size - and self.channel_type == other.channel_type - and self.tbg_info == other.tbg_info - ): - fused_operation = GetOperation( - src_buff=self.src_buff + other.src_buff, - dst_buff=self.dst_buff + other.dst_buff, - channel_ids=self.channel_ids + other.channel_ids, - channel_type=self.channel_type, - tbg_info=self.tbg_info, - ) + if self.basic_fusion_check(other): + if ( + isinstance(other, GetOperation) + and self.src_buff[0].size == other.src_buff[0].size + and self.channel_type == other.channel_type + and self.tbg == other.tbg + ): + fused_operation = GetOperation( + self.rank, + self.threadblock, + src_buff=self.src_buff + other.src_buff, + dst_buff=self.dst_buff + other.dst_buff, + channel_ids=self.channel_ids + other.channel_ids, + channel_type=self.channel_type, + tbg=self.tbg, + ) return fused_operation @@ -495,40 +609,42 @@ def to_dict(self): result["dst_buff"].append(chunk.to_dict()) result["channel_ids"] = self.channel_ids result["channel_type"] = self.channel_type.value - if self.tbg_info is not None: - result["tbg_info"] = self.tbg_info.to_dict() + if self.tbg is not None: + result["tbg"] = 
self.tbg.to_dict(self.threadblock) return result class PutOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, src_buff: List[LocalChunk], dst_buff: List[RemoteChunk], channel_ids: List[int], channel_type: ChannelType, - tbg_info: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroup = None, from_packet: bool = False, to_packet: bool = False, with_signal: bool = False, with_signal_and_flush: bool = False, ): if from_packet and to_packet: - super().__init__(Instruction.read_put_packet) + super().__init__(rank, threadblock, Instruction.read_put_packet) elif to_packet: - super().__init__(Instruction.put_packet) + super().__init__(rank, threadblock, Instruction.put_packet) elif from_packet: raise RuntimeError(f"Put Operation from Packet is not Supported.") else: if with_signal: if with_signal_and_flush: - super().__init__(Instruction.put_with_signal_and_flush) + super().__init__(rank, threadblock, Instruction.put_with_signal_and_flush) else: - super().__init__(Instruction.put_with_signal) + super().__init__(rank, threadblock, Instruction.put_with_signal) elif with_signal_and_flush: - super().__init__(Instruction.put_with_signal_and_flush) + super().__init__(rank, threadblock, Instruction.put_with_signal_and_flush) else: - super().__init__(Instruction.put) + super().__init__(rank, threadblock, Instruction.put) self.src_buff = src_buff self.dst_buff = dst_buff @@ -537,14 +653,33 @@ def __init__( self.to_packet = to_packet self.with_signal = with_signal self.with_signal_and_flush = with_signal_and_flush - self.tbg_info = tbg_info + self.tbg = tbg - def local_data_access(self, sync_purpose=True): + def local_data_access(self, order_id, sync_purpose=True): data_access = [] if self.name != Instruction.read_put_packet or not sync_purpose: for chunk in self.src_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.read) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + ( + 
chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + ), + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + chunk.size + ), + chunk.type, + DataAccessType.read, + self.tbg, + self.pipeline_context, + ) ) return data_access @@ -556,29 +691,32 @@ def shift_buffers(self, instance, num_instances, replication_function): def __add__(self, other): fused_operation = None - if ( - isinstance(other, PutOperation) - and ( - self.name == Instruction.put - or self.name == Instruction.put_packet - or self.name == Instruction.put_with_signal - or self.name == Instruction.put_with_signal_and_flush - ) - and self.name == other.name - and self.src_buff[0].size == other.src_buff[0].size - and self.channel_type == other.channel_type - and self.tbg_info == other.tbg_info - ): - fused_operation = PutOperation( - src_buff=self.src_buff + other.src_buff, - dst_buff=self.dst_buff + other.dst_buff, - channel_ids=self.channel_ids + other.channel_ids, - channel_type=self.channel_type, - tbg_info=self.tbg_info, - to_packet=self.to_packet, - with_signal=self.with_signal, - with_signal_and_flush=self.with_signal_and_flush, - ) + if self.basic_fusion_check(other): + if ( + isinstance(other, PutOperation) + and ( + self.name == Instruction.put + or self.name == Instruction.put_packet + or self.name == Instruction.put_with_signal + or self.name == Instruction.put_with_signal_and_flush + ) + and self.name == other.name + and self.src_buff[0].size == other.src_buff[0].size + and self.channel_type == other.channel_type + and self.tbg == other.tbg + ): + fused_operation = PutOperation( + self.rank, + self.threadblock, + src_buff=self.src_buff + other.src_buff, + dst_buff=self.dst_buff + other.dst_buff, + channel_ids=self.channel_ids + other.channel_ids, + channel_type=self.channel_type, + tbg=self.tbg, + to_packet=self.to_packet, + with_signal=self.with_signal, + 
with_signal_and_flush=self.with_signal_and_flush, + ) return fused_operation @@ -593,8 +731,8 @@ def to_dict(self): if self.channel_type == ChannelType.port: result["channel_ids"] = self.channel_ids result["channel_type"] = self.channel_type.value - if self.tbg_info is not None: - result["tbg_info"] = self.tbg_info.to_dict() + if self.tbg is not None: + result["tbg"] = self.tbg.to_dict(self.threadblock) return result @@ -602,6 +740,8 @@ def to_dict(self): class ReduceOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, local_src_buff: List[LocalChunk], local_dst_buff: List[LocalChunk], local_pkt_dst_buff: List[LocalChunk] = None, @@ -611,7 +751,7 @@ def __init__( put_channel_ids: List[int] = None, channel_type: ChannelType = ChannelType.none, reduce_operation: ReduceOperationType = ReduceOperationType.sum, - tbg_info: ThreadBlockGroupInfo = None, + tbg: ThreadBlockGroup = None, packet: bool = False, ): local_pkt_dst_buff = local_pkt_dst_buff if local_pkt_dst_buff is not None else [] @@ -623,23 +763,23 @@ def __init__( if len(remote_src_buff) == 0 and len(remote_dst_buff) == 0: if packet: if len(local_pkt_dst_buff) == 0: - super().__init__(Instruction.reduce_packet) + super().__init__(rank, threadblock, Instruction.reduce_packet) else: - super().__init__(Instruction.reduce_copy_packet) + super().__init__(rank, threadblock, Instruction.reduce_copy_packet) else: - super().__init__(Instruction.reduce) + super().__init__(rank, threadblock, Instruction.reduce) elif len(remote_src_buff) == 0: if packet: if len(local_pkt_dst_buff) == 0: - super().__init__(Instruction.reduce_send_packet) + super().__init__(rank, threadblock, Instruction.reduce_send_packet) else: - super().__init__(Instruction.reduce_copy_send_packet) + super().__init__(rank, threadblock, Instruction.reduce_copy_send_packet) else: - super().__init__(Instruction.reduce_send) + super().__init__(rank, threadblock, Instruction.reduce_send) elif len(remote_dst_buff) == 0 and not packet: - 
super().__init__(Instruction.read_reduce) + super().__init__(rank, threadblock, Instruction.read_reduce) elif not packet: - super().__init__(Instruction.read_reduce_send) + super().__init__(rank, threadblock, Instruction.read_reduce_send) else: raise RuntimeError(f"Reduce Operation invalid parameters.") @@ -652,20 +792,54 @@ def __init__( self.put_channel_ids = put_channel_ids self.channel_type = channel_type self.reduce_operation = reduce_operation - self.tbg_info = tbg_info + self.tbg = tbg self.packet = packet - def local_data_access(self, sync_purpose=True): + def local_data_access(self, order_id, sync_purpose=True): data_access = [] for i in range(len(self.local_src_buff)): chunk = self.local_src_buff[i] if not self.packet or i != 0 or not sync_purpose: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.read) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + ( + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + ), + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + chunk.size + ), + chunk.type, + DataAccessType.read, + self.tbg, + self.pipeline_context, + ) ) for chunk in self.local_dst_buff: data_access.append( - DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.write) + DataAccess( + self.rank, + self.threadblock, + self.id, + order_id, + chunk.index + self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.index, + ( + chunk.index + self.tbg.end_offset(self.threadblock, chunk.size) + if self.tbg is not None + else chunk.index + chunk.size + ), + chunk.type, + DataAccessType.write, + self.tbg, + self.pipeline_context, + ) ) return data_access @@ -681,118 +855,129 @@ def shift_buffers(self, instance, num_instances, replication_function): def __add__(self, other): fused_operation = None - if ( - isinstance(other, ReduceOperation) 
- and ( - self.name == Instruction.reduce - or self.name == Instruction.reduce_packet - or self.name == Instruction.read_reduce - ) - and self.name == other.name - and self.local_src_buff[0] == other.local_src_buff[0] - and self.local_dst_buff == other.local_dst_buff - and self.channel_type == other.channel_type - and self.reduce_operation == other.reduce_operation - and self.tbg_info == other.tbg_info - ): - fused_operation = ReduceOperation( - self.local_src_buff + other.local_src_buff[1:], - self.local_dst_buff, - remote_src_buff=self.remote_src_buff + other.remote_src_buff, - channel_ids=self.channel_ids + other.channel_ids, - channel_type=self.channel_type, - reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, - packet=self.packet, - ) - if ( - isinstance(other, PutOperation) - and ( - self.name == Instruction.reduce - or self.name == Instruction.reduce_send - or self.name == Instruction.read_reduce - or self.name == Instruction.read_reduce_send - ) - and other.name == Instruction.put - and self.local_dst_buff[0] == other.src_buff[0] - and other.channel_type == ChannelType.memory - and self.tbg_info == other.tbg_info - ): - fused_operation = ReduceOperation( - self.local_src_buff, - self.local_dst_buff, - remote_src_buff=self.remote_src_buff, - remote_dst_buff=self.remote_dst_buff + other.dst_buff, - channel_ids=self.channel_ids, - put_channel_ids=self.put_channel_ids + other.channel_ids, - channel_type=self.channel_type, - reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, - packet=self.packet, - ) - if ( - isinstance(other, PutOperation) - and (self.name == Instruction.reduce_packet or self.name == Instruction.reduce_send_packet) - and other.name == Instruction.put_packet - and self.local_dst_buff[0] == other.src_buff[0] - and other.channel_type == ChannelType.memory - and self.tbg_info == other.tbg_info - ): - fused_operation = ReduceOperation( - self.local_src_buff, - self.local_dst_buff, - remote_src_buff=self.remote_src_buff, 
- remote_dst_buff=self.remote_dst_buff + other.dst_buff, - channel_ids=self.channel_ids, - put_channel_ids=self.put_channel_ids + other.channel_ids, - channel_type=other.channel_type, - reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, - packet=self.packet, - ) - if ( - isinstance(other, CopyOperation) - and self.name == Instruction.reduce_packet - and other.name == Instruction.copy_packet - and self.local_dst_buff[0] == other.src_buff[0] - and self.tbg_info == other.tbg_info - ): - fused_operation = ReduceOperation( - self.local_src_buff, - self.local_dst_buff, - local_pkt_dst_buff=other.dst_buff, - remote_src_buff=self.remote_src_buff, - remote_dst_buff=self.remote_dst_buff, - channel_ids=self.channel_ids, - put_channel_ids=self.put_channel_ids, - channel_type=self.channel_type, - reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, - packet=self.packet, - ) - if ( - isinstance(other, PutOperation) - and (self.name == Instruction.reduce_copy_packet or self.name == Instruction.reduce_copy_send_packet) - and ( - (other.name == Instruction.put_packet and self.local_dst_buff[0] == other.src_buff[0]) - or (other.name == Instruction.read_put_packet and self.local_pkt_dst_buff[0] == other.src_buff[0]) - ) - and other.channel_type == ChannelType.memory - and self.tbg_info == other.tbg_info - ): - fused_operation = ReduceOperation( - self.local_src_buff, - self.local_dst_buff, - local_pkt_dst_buff=self.local_pkt_dst_buff, - remote_src_buff=self.remote_src_buff, - remote_dst_buff=self.remote_dst_buff + other.dst_buff, - channel_ids=self.channel_ids, - put_channel_ids=self.put_channel_ids + other.channel_ids, - channel_type=other.channel_type, - reduce_operation=self.reduce_operation, - tbg_info=self.tbg_info, - packet=self.packet, - ) + if self.basic_fusion_check(other): + if ( + isinstance(other, ReduceOperation) + and ( + self.name == Instruction.reduce + or self.name == Instruction.reduce_packet + or self.name == Instruction.read_reduce + ) 
+ and self.name == other.name + and self.local_src_buff[0] == other.local_src_buff[0] + and self.local_dst_buff == other.local_dst_buff + and self.channel_type == other.channel_type + and self.reduce_operation == other.reduce_operation + and self.tbg == other.tbg + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff + other.local_src_buff[1:], + self.local_dst_buff, + remote_src_buff=self.remote_src_buff + other.remote_src_buff, + channel_ids=self.channel_ids + other.channel_ids, + channel_type=self.channel_type, + reduce_operation=self.reduce_operation, + tbg=self.tbg, + packet=self.packet, + ) + if ( + isinstance(other, PutOperation) + and ( + self.name == Instruction.reduce + or self.name == Instruction.reduce_send + or self.name == Instruction.read_reduce + or self.name == Instruction.read_reduce_send + ) + and other.name == Instruction.put + and self.local_dst_buff[0] == other.src_buff[0] + and other.channel_type == ChannelType.memory + and self.tbg == other.tbg + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff, + self.local_dst_buff, + remote_src_buff=self.remote_src_buff, + remote_dst_buff=self.remote_dst_buff + other.dst_buff, + channel_ids=self.channel_ids, + put_channel_ids=self.put_channel_ids + other.channel_ids, + channel_type=self.channel_type, + reduce_operation=self.reduce_operation, + tbg=self.tbg, + packet=self.packet, + ) + if ( + isinstance(other, PutOperation) + and (self.name == Instruction.reduce_packet or self.name == Instruction.reduce_send_packet) + and other.name == Instruction.put_packet + and self.local_dst_buff[0] == other.src_buff[0] + and other.channel_type == ChannelType.memory + and self.tbg == other.tbg + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff, + self.local_dst_buff, + remote_src_buff=self.remote_src_buff, + remote_dst_buff=self.remote_dst_buff + other.dst_buff, + 
channel_ids=self.channel_ids, + put_channel_ids=self.put_channel_ids + other.channel_ids, + channel_type=other.channel_type, + reduce_operation=self.reduce_operation, + tbg=self.tbg, + packet=self.packet, + ) + if ( + isinstance(other, CopyOperation) + and self.name == Instruction.reduce_packet + and other.name == Instruction.copy_packet + and self.local_dst_buff[0] == other.src_buff[0] + and self.tbg == other.tbg + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff, + self.local_dst_buff, + local_pkt_dst_buff=other.dst_buff, + remote_src_buff=self.remote_src_buff, + remote_dst_buff=self.remote_dst_buff, + channel_ids=self.channel_ids, + put_channel_ids=self.put_channel_ids, + channel_type=self.channel_type, + reduce_operation=self.reduce_operation, + tbg=self.tbg, + packet=self.packet, + ) + if ( + isinstance(other, PutOperation) + and (self.name == Instruction.reduce_copy_packet or self.name == Instruction.reduce_copy_send_packet) + and ( + (other.name == Instruction.put_packet and self.local_dst_buff[0] == other.src_buff[0]) + or (other.name == Instruction.read_put_packet and self.local_pkt_dst_buff[0] == other.src_buff[0]) + ) + and other.channel_type == ChannelType.memory + and self.tbg == other.tbg + ): + fused_operation = ReduceOperation( + self.rank, + self.threadblock, + self.local_src_buff, + self.local_dst_buff, + local_pkt_dst_buff=self.local_pkt_dst_buff, + remote_src_buff=self.remote_src_buff, + remote_dst_buff=self.remote_dst_buff + other.dst_buff, + channel_ids=self.channel_ids, + put_channel_ids=self.put_channel_ids + other.channel_ids, + channel_type=other.channel_type, + reduce_operation=self.reduce_operation, + tbg=self.tbg, + packet=self.packet, + ) return fused_operation @@ -817,8 +1002,8 @@ def to_dict(self): if self.channel_type != ChannelType.none: result["channel_type"] = self.channel_type.value result["reduce_op"] = self.reduce_operation.value - if 
self.tbg_info is not None: - result["tbg_info"] = self.tbg_info.to_dict() + if self.tbg is not None: + result["tbg"] = self.tbg.to_dict(self.threadblock) return result @@ -826,6 +1011,8 @@ def to_dict(self): class GroupLoadReduce(BaseOperation): def __init__( self, + rank, + threadblock: int, buffer_type: BufferType, buffer_offset: int, size: int, @@ -834,7 +1021,7 @@ def __init__( channel_type: ChannelType = ChannelType.switch, reduce_operation: ReduceOperationType = ReduceOperationType.sum, ): - super().__init__(Instruction.group_load_reduce) + super().__init__(rank, threadblock, Instruction.group_load_reduce) self.buffer_type = buffer_type self.buffer_offset = buffer_offset self.size = size @@ -849,23 +1036,26 @@ def shift_buffers(self, instance, num_instances, replication_function): def __add__(self, other): fused_operation = None - if ( - isinstance(other, GroupStore) - and self.buffer_type == other.buffer_type - and self.size == other.size - and self.dst_chunk == other.src_chunk - and self.channel_ids == other.channel_ids - and self.channel_type == other.channel_type - ): - fused_operation = GroupLoadReduceStore( - buffer_type=self.buffer_type, - size=self.size, - src_index=[self.buffer_offset], - dst_index=[other.buffer_offset], - channel_ids=self.channel_ids, - channel_type=self.channel_type, - reduce_operation=self.reduce_operation, - ) + if self.basic_fusion_check(other): + if ( + isinstance(other, GroupStore) + and self.buffer_type == other.buffer_type + and self.size == other.size + and self.dst_chunk == other.src_chunk + and self.channel_ids == other.channel_ids + and self.channel_type == other.channel_type + ): + fused_operation = GroupLoadReduceStore( + self.rank, + self.threadblock, + buffer_type=self.buffer_type, + size=self.size, + src_index=[self.buffer_offset], + dst_index=[other.buffer_offset], + channel_ids=self.channel_ids, + channel_type=self.channel_type, + reduce_operation=self.reduce_operation, + ) return fused_operation @@ -885,6 +1075,8 
@@ def to_dict(self): class GroupStore(BaseOperation): def __init__( self, + rank, + threadblock: int, src_chunk: Chunk, buffer_type: BufferType, buffer_offset: int, @@ -892,7 +1084,7 @@ def __init__( channel_ids: List[int], channel_type: ChannelType = ChannelType.switch, ): - super().__init__(Instruction.group_store) + super().__init__(rank, threadblock, Instruction.group_store) self.src_chunk = src_chunk self.buffer_type = buffer_type self.buffer_offset = buffer_offset @@ -919,6 +1111,8 @@ def to_dict(self): class GroupLoadReduceStore(BaseOperation): def __init__( self, + rank, + threadblock: int, buffer_type: BufferType, size: int, src_index: List[int], @@ -927,7 +1121,7 @@ def __init__( channel_type: ChannelType = ChannelType.switch, reduce_operation: ReduceOperationType = ReduceOperationType.sum, ): - super().__init__(Instruction.group_load_reduce_store) + super().__init__(rank, threadblock, Instruction.group_load_reduce_store) self.buffer_type = buffer_type self.size = size self.src_index = src_index @@ -961,8 +1155,8 @@ def to_dict(self): @dataclass class PipelineOperation(BaseOperation): - def __init__(self, unit_size: int, num_chunks: int, operations=None): - super().__init__(Instruction.pipeline) + def __init__(self, rank: int, threadblock: int, unit_size: int, num_chunks: int, operations=None): + super().__init__(rank, threadblock, Instruction.pipeline) self.unit_size = unit_size self.num_chunks = num_chunks self.operations = operations if operations is not None else [] @@ -1006,10 +1200,11 @@ def shift_ids(self, instance, num_instances, replication_function): def __add__(self, other): fused_operation = None - if (self.get_data_sync() & SyncType.after) == SyncType.after and check_data_sync_op(other): - other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) - elif isinstance(other, SyncOperation) and (self.get_data_sync() & SyncType.after) == SyncType.after: - fused_operation = self + if self.basic_fusion_check(other): + if 
(self.get_data_sync() & SyncType.after) == SyncType.after and check_data_sync_op(other): + other.data_sync = other.data_sync ^ (SyncType.before & other.data_sync) + elif isinstance(other, SyncOperation) and (self.get_data_sync() & SyncType.after) == SyncType.after: + fused_operation = self return fused_operation @@ -1052,11 +1247,11 @@ def add_data_sync(operations): if operation.name in data_sync_operations and ( operation.data_sync == SyncType.before or operation.data_sync == SyncType.both ): - result_operations.append(SyncOperation()) + result_operations.append(SyncOperation(operation.rank, operation.threadblock)) result_operations.append(operation) if operation.name in data_sync_operations and ( operation.data_sync == SyncType.after or operation.data_sync == SyncType.both ): - result_operations.append(SyncOperation()) + result_operations.append(SyncOperation(operation.rank, operation.threadblock)) return result_operations diff --git a/python/mscclpp/language/internal/register.py b/python/mscclpp/language/internal/register.py new file mode 100644 index 000000000..833415615 --- /dev/null +++ b/python/mscclpp/language/internal/register.py @@ -0,0 +1,22 @@ +class ChannelRegister: + channels = {} + + @staticmethod + def add_channel(rank, tb, tb_channel_id, channel): + ChannelRegister.channels[(rank, tb, tb_channel_id)] = channel + + @staticmethod + def get_channel(rank: int, threadblock: int, tb_channel_id: int): + return ChannelRegister.channels.get((rank, threadblock, tb_channel_id)) + + +class SemaphoreRegister: + semaphores = {} + + @staticmethod + def add_semaphore(semaphore): + SemaphoreRegister.semaphores[(semaphore.rank, semaphore.id)] = semaphore + + @staticmethod + def get_semaphore(rank: int, semaphore_id: int): + return SemaphoreRegister.semaphores.get((rank, semaphore_id)) diff --git a/python/mscclpp/language/internal/types.py b/python/mscclpp/language/internal/types.py index 9bfe1c76f..99e71e57a 100644 --- a/python/mscclpp/language/internal/types.py +++ 
b/python/mscclpp/language/internal/types.py @@ -1,9 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from dataclasses import dataclass +from mscclpp.language.loop import LoopIterationContext +from mscclpp.language.thread_block_group import ThreadBlockGroup +from dataclasses import dataclass, field from enum import Enum from typing import List, Set +from collections import defaultdict +import uuid class SyncType(Enum): @@ -173,31 +177,113 @@ def __str__(self): @dataclass class DataAccess: - operation_id: int - start: int - end: int + rank: int + threadblock: int + operation_global_id: uuid.UUID + operation_order_id: int + start: float + end: float buffer_type: BufferType data_access_type: DataAccessType + tb_group: ThreadBlockGroup + pipeline_context: LoopIterationContext def __lt__(self, other): if self.start != other.start: return self.start < other.start return self.end < other.end - def __eq__(self, other): - return self.start == other.start and self.end == other.end + def __eq__(self, other, tolerance=1e-5): + return abs(self.start - other.start) < tolerance and abs(self.end - other.end) < tolerance def __hash__(self): return hash((self.start, self.end)) - def overlaps(self, other) -> bool: - return self.start <= other.end and other.start <= self.end + def lower_overlaps(self, other, tolerance=1e-5) -> bool: + return self.start + tolerance < other.end + + def overlaps(self, other, tolerance=1e-5) -> bool: + return (self.start + tolerance < other.end) and (other.start + tolerance < self.end) def check_conflict(self, other) -> bool: - return ( + if ( self.overlaps(other) - and self.operation_id != other.operation_id + and self.operation_global_id != other.operation_global_id and (self.data_access_type != DataAccessType.read or other.data_access_type != DataAccessType.read) + ): + if self.threadblock == other.threadblock: + return DataAccessConflict( + self.rank, + {(other.threadblock, other.operation_order_id, True)}, + 
DataAccessConflictType.intra_threadblock, + ) + else: + is_order_defined = ( + ( + self.tb_group is not None + and other.tb_group is not None + and self.tb_group.tbg_overlap(other.tb_group) + ) + or ( + self.tb_group is not None + and other.tb_group is None + and self.tb_group.tb_overlap(other.threadblock) + ) + or ( + self.tb_group is None + and other.tb_group is not None + and other.tb_group.tb_overlap(self.threadblock) + ) + ) + return DataAccessConflict( + self.rank, + { + (self.threadblock, other.operation_order_id, True), + (other.threadblock, other.operation_order_id, is_order_defined), + }, + DataAccessConflictType.inter_threadblock, + ) + else: + return DataAccessConflict(self.rank) + + +class DataAccessConflictType(Enum): + inter_threadblock = "inter_tb" + intra_threadblock = "intra_tb" + none = "none" + + def __add__(self, other): + if not isinstance(other, DataAccessConflictType): + return NotImplemented + + map_to_num = { + DataAccessConflictType.none: 0, + DataAccessConflictType.intra_threadblock: 1, + DataAccessConflictType.inter_threadblock: 3, + } + map_to_dact = { + 0: DataAccessConflictType.none, + 1: DataAccessConflictType.intra_threadblock, + 3: DataAccessConflictType.inter_threadblock, + } + return map_to_dact[map_to_num[self] | map_to_num[other]] + + def __str__(self): + return self.value + + +@dataclass +class DataAccessConflict: + rank: int + threadblocks: Set[int] = field(default_factory=set) + conflict_type: DataAccessConflictType = DataAccessConflictType.none + + def __add__(self, other): + if not isinstance(other, DataAccessConflict): + return NotImplemented + + return DataAccessConflict( + self.rank, self.threadblocks | other.threadblocks, self.conflict_type + other.conflict_type ) diff --git a/python/mscclpp/language/loop.py b/python/mscclpp/language/loop.py index 06ca90f40..ad70ecedb 100644 --- a/python/mscclpp/language/loop.py +++ b/python/mscclpp/language/loop.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
from mscclpp.language.internal.globals import * -from mscclpp.language.internal.operations import PipelineOperation +from typing import Dict class LoopIterationContext: @@ -35,7 +35,7 @@ def __init__(self, unit, num_chunks): """ self.unit = unit self.num_chunks = num_chunks - self.operations = [] + self.pre_operations = dict() def __enter__(self): """Enter the context and set this as the active loop context. @@ -54,17 +54,7 @@ def __exit__(self, exc_type, exc_value, traceback): """ get_program().set_loop_context(None) - pipeline_operation = {} - for rank, tb, operation in self.operations: - key = (rank, tb) - if key not in pipeline_operation: - pipeline_operation[key] = PipelineOperation(self.unit, self.num_chunks) - pipeline_operation[key].add_operation(operation) - - for (rank, tb), pipeline in pipeline_operation.items(): - get_program().add_operation(rank, tb, pipeline) - - def add_operation(self, rank, tb, operation): + def process_operation(self, operations): """Add an operation to be included in the pipeline. This method is called internally to collect operations that should be @@ -76,4 +66,5 @@ def add_operation(self, rank, tb, operation): tb (int): The thread block ID that will execute this operation. operation: The operation object to be added to the pipeline. 
""" - self.operations.append((rank, tb, operation)) + for operation in operations: + operation.set_pipeline_context(self) diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index c29e9ab75..61554a592 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -5,10 +5,14 @@ from mscclpp.language.internal.globals import set_program from mscclpp.language.internal.types import BufferType, RemoteBuffer, ChannelType from mscclpp.language.internal.gpu import Gpu +from mscclpp.language.internal.register import ChannelRegister, SemaphoreRegister +from mscclpp.language.internal.op_dep_graph import OperationDependencyGraph +from mscclpp.language.internal.buffer_access import BuffersAccess from mscclpp.language.channel import * from mscclpp.language.rank import Semaphore from mscclpp.language.collectives import * from mscclpp.language.utils import AlgoSpec, ReplicationPolicy +from mscclpp.language.internal.operations import add_data_sync from typing import List import json @@ -49,6 +53,7 @@ def __init__( protocol: str = "Simple", instr_fusion: bool = True, auto_sync: bool = True, + intra_rank_sync: bool = True, replication_policy: ReplicationPolicy = ReplicationPolicy.interleaved, reuse_resources: bool = False, num_threads_per_block: int = 1024, @@ -103,6 +108,8 @@ def __init__( self.min_message_size = min_message_size self.max_message_size = max_message_size assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. 
Must be either Simple, LL" + self.op_dep_dag = OperationDependencyGraph() + self.buffers_access = BuffersAccess(num_ranks, intra_rank_sync) self.buffers = collective.init_buffers() self.gpus: List[Gpu] = [] for rank in range(self.num_ranks): @@ -182,6 +189,9 @@ def __exit__(self, exc_type, exc_value, traceback): Semaphore.reset() set_program(None) + def disable_inter_tb_sync(self): + self.buffers_access.intra_rank_sync = False + def add_channel(self, channel): if channel.channel_type == ChannelType.switch: for gpu in channel.rank_group.ranks: @@ -192,22 +202,37 @@ def add_channel(self, channel): def setup_channel(self, tb, channel): tb_channel_ids = [] tb_channel_ids.append(self.gpus[channel.src_rank].setup_channel(tb, channel)) + for tb_channel_id in tb_channel_ids: + ChannelRegister.add_channel(channel.src_rank, tb, tb_channel_id, channel) return tb_channel_ids def setup_remote_chunk(self, rank, tb, remote_chunk: RemoteBuffer, channel_access: ChannelType): return self.gpus[rank].add_remote_buffer(tb, remote_chunk, channel_access) def add_semaphore(self, semaphore): + SemaphoreRegister.add_semaphore(semaphore) self.gpus[semaphore.rank].add_semaphore(semaphore) def add_operation(self, rank, tb, operation): if self.loop_context != None: - self.loop_context.add_operation(rank, tb, operation) - else: - self.gpus[rank].add_operation(tb, operation) + self.loop_context.process_operation([operation]) + self.op_dep_dag.add_operation(operation) + + def add_tbg_operation(self, operations): + if self.loop_context != None: + self.loop_context.process_operation(operations) + self.op_dep_dag.add_tbg_operation(operations) def post_process_operations(self): - for gpu in self.gpus: + self.op_dep_dag.add_semaphore_dependency() + self.op_dep_dag.fusion_operations() + list_op = self.op_dep_dag.get_execution_order() + list_op = add_data_sync(list_op) + list_op = self.buffers_access.process_operations(list_op) + for op in list_op: + self.gpus[op.rank].add_operation(op.threadblock, op) + 
+ """ for gpu in self.gpus: if self.instr_fusion: gpu.optimize_operations() gpu.adding_data_sync() @@ -217,7 +242,7 @@ def post_process_operations(self): self.instances, self.get_default_replication_policy_function(), self.get_buffer_replication_policy_function(), - ) + ) """ def get_default_replication_policy_function(self): return lambda value, instance, num_instances: value * num_instances + instance diff --git a/python/mscclpp/language/rank.py b/python/mscclpp/language/rank.py index e5b7aab89..b3fbdf88f 100644 --- a/python/mscclpp/language/rank.py +++ b/python/mscclpp/language/rank.py @@ -110,20 +110,23 @@ def _copy( "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." ) + operations = [] for tb_id in tb_list: op = CopyOperation( + rank=self.rank, + threadblock=tb_id, src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)], dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), from_packet=from_packet, to_packet=to_packet, ) + operations.append(op) - get_program().add_operation(self.rank, tb_id, op) + if tb_group is None: + get_program().add_operation(self.rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def copy(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Copy data from source chunk to destination chunk. @@ -240,21 +243,24 @@ def reduce( "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None." 
) + operations = [] for tb_id in tb_list: op = ReduceOperation( - [LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)] + rank=self.rank, + threadblock=tb_id, + local_src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)] + [LocalChunk(chunk.buffer, chunk.index, chunk.size) for chunk in other_chunks], - [LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], + local_dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)], reduce_operation=reduce_op, - tbg_info=( - ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb()) - if tb_group is not None - else None - ), + tbg=(tb_group if tb_group is not None else None), packet=packet, ) + operations.append(op) - get_program().add_operation(self.rank, tb_id, op) + if tb_group is None: + get_program().add_operation(self.rank, tb_id, operations[0]) + else: + get_program().add_tbg_operation(operations) def barrier(self, tb_list: List[int]): """Create a synchronization barrier between thread blocks. 
@@ -275,11 +281,11 @@ def barrier(self, tb_list: List[int]): if len(tb_list) == 0: raise RuntimeError("Barrier requires at least thread block.") elif len(tb_list) == 1: - op = SyncOperation() + op = SyncOperation(self.rank, tb_list[0]) get_program().add_operation(self.rank, tb_list[0], op) else: - op = BarrierOperation(self.rank, tb_list) for tb in tb_list: + op = BarrierOperation(self.rank, tb, tb_list) get_program().add_operation(self.rank, tb, op) @@ -413,7 +419,7 @@ def acquire(self, tb: int, data_sync: SyncType = SyncType.both): Example: >>> sem.acquire(tb=0, data_sync=SyncType.before) """ - op = SemaphoreAcquireOperation([self.id], data_sync) + op = SemaphoreAcquireOperation(self.rank, tb, [self.id], data_sync) get_program().add_operation(self.rank, tb, op) def release(self, tb: int, data_sync: SyncType = SyncType.both): @@ -431,5 +437,5 @@ def release(self, tb: int, data_sync: SyncType = SyncType.both): Example: >>> sem.release(tb=0, data_sync=SyncType.after) """ - op = SemaphoreReleaseOperation([self.id], data_sync) + op = SemaphoreReleaseOperation(self.rank, tb, [self.id], data_sync) get_program().add_operation(self.rank, tb, op) diff --git a/python/mscclpp/language/thread_block_group.py b/python/mscclpp/language/thread_block_group.py index 9ffe4e2dd..811d30f88 100644 --- a/python/mscclpp/language/thread_block_group.py +++ b/python/mscclpp/language/thread_block_group.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from typing import List, Dict +from typing import List, Dict, Set class ThreadBlockGroup: @@ -20,10 +20,9 @@ def __init__(self, tb_list: List[int]): tb_list: List of thread block objects """ - self.tb_list: List[int] = tb_list + self.tb_list: Set[int] = set(tb_list) self._tb_id: Dict[int, int] = {} - # Check for duplicates and build ID mapping seen = set() for i, tb in enumerate(self.tb_list): if tb in seen: @@ -51,3 +50,23 @@ def get_internal_id(self, tb: int) -> int: def numtb(self) -> int: """Return the number of thread blocks in the group.""" return len(self.tb_list) + + def tbg_overlap(self, other): + for tb in self.tb_list: + if tb in other.tb_list: + return True + return False + + def tb_overlap(self, tb_id): + return tb_id in self.tb_list + + def to_dict(self, tb): + return {"tb_id": self.get_internal_id(tb), "tbg_size": self.numtb()} + + def start_offset(self, tb, size): + tb_id = self.get_internal_id(tb) + return (size / self.numtb()) * tb_id + + def end_offset(self, tb, size): + tb_id = self.get_internal_id(tb) + return (size / self.numtb()) * (tb_id + 1)