Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions colossalai/legacy/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .layer_spec import LayerSpec
from .pipelinable import PipelinableContext, PipelinableModel

__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch

from colossalai.utils.model.utils import call_to_str


class LayerSpec:
"""

"""

def __init__(self, typename, *module_args, **module_kwargs):
Expand Down Expand Up @@ -52,4 +54,4 @@ def count_params(self):
return self._param_count

def reset_param_count(self):
    # Reset the cached parameter count to zero.
    self._param_count = 0
self._param_count = 0
3 changes: 3 additions & 0 deletions colossalai/legacy/pipeline/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo

__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .fx import get_topology as get_fx_topology

__all__ = ['get_fx_topology']
__all__ = ['get_fx_topology']
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from torch.fx.graph_module import GraphModule
from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
import torch
from torch.fx.graph_module import GraphModule

from colossalai.legacy.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo


def partition_name_to_id(partition_name, is_input=False, is_output=False):
if is_input:
Expand All @@ -12,6 +14,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
partition_id = int(partition_name.split(prefix)[-1]) + 2
return partition_id


# There are two kinds of def in fx.graph
# 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value.
# e.g. submod1 = call_module(...)
Expand All @@ -20,6 +23,8 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
# 2. direct_use & direct_def, which means the output is used by next partition directly.
# e.g. submod1 = call_module(...)
# submod2 = call_module(submod1, ...)


def find_input_in_partition(node, partitions, input_partitions=None):
p_input_val = None
direct_def = not node.name.startswith('getitem')
Expand All @@ -45,9 +50,10 @@ def find_input_in_partition(node, partitions, input_partitions=None):
partition_id = partition_name_to_id(partition.name)
p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
return p_input_val

return p_input_val



def find_output_in_partition(node, partitions, output_partitions=None):
p_output_val = PartitionOutputVal()
for user in node.users:
Expand All @@ -70,7 +76,7 @@ def find_output_in_partition(node, partitions, output_partitions=None):
if arg == user:
p_output_val.add(partition_id=partition_id, offset=i)
break

# user is output
if output_partitions is not None:
output_node = output_partitions[0]
Expand All @@ -84,10 +90,11 @@ def find_output_in_partition(node, partitions, output_partitions=None):
break
return p_output_val


def get_topology(gm: GraphModule):
topo = Topo()
topo_output_partition = Partition()

input_partitions = []
partitions = []
output_partitions = []
Expand All @@ -109,7 +116,7 @@ def get_topology(gm: GraphModule):
topo_input_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=0, partition=topo_input_partition)
topo.set_input_partition_id(partition_id=0)

for i, partition in enumerate(partitions):
topo_mid_partition = Partition()
# set input for submodule
Expand All @@ -131,15 +138,16 @@ def get_topology(gm: GraphModule):
for user in partition.users:
cur_node = user
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_mid_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=i+2, partition=topo_mid_partition)
topo_mid_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition)

# set input for output_partition
for partition in output_partitions:
topo_output_partition = Partition()
torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
find_input_in_partition(n, partitions, input_partitions)))
torch.fx.graph.map_arg(
partition.args[0],
lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)))
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition_id(partition_id=1)

return topo
return topo
Original file line number Diff line number Diff line change
@@ -1,77 +1,84 @@
from typing import Dict, List
from dataclasses import dataclass
from typing import Dict, List

# This file includes data structure used by Pipeline Middleware.


@dataclass
class ValPosition:
    """One value slot in the pipeline topology: the partition it lives in
    and the positional offset of the value inside that partition."""
    partition_id: int
    offset: int

    def __str__(self) -> str:
        return f'[partition_id:{self.partition_id},offset:{self.offset}]'

    def __repr__(self) -> str:
        return str(self)


class PartitionInputVal(object):
    """Describes where one partition input comes from: a single
    (partition_id, offset) source position."""

    def __init__(self, partition_id, offset) -> None:
        # An input slot is fed by exactly one upstream position.
        self._from_partition_and_offset: ValPosition = ValPosition(partition_id, offset)

    def get(self):
        # The ValPosition this input is read from.
        return self._from_partition_and_offset

    def __str__(self) -> str:
        return f'<-({self._from_partition_and_offset})'

    def __repr__(self) -> str:
        return str(self)



class PartitionOutputVal(object):
    """Describes where one partition output goes: a list of
    (partition_id, offset) destination positions (fan-out is allowed)."""

    def __init__(self) -> None:
        # One output value may be consumed at several downstream positions.
        self._to_partition_and_offset: List[ValPosition] = []

    def add(self, partition_id, offset):
        # Register one more consumer position for this output.
        self._to_partition_and_offset.append(ValPosition(partition_id, offset))

    def get(self):
        # All destination positions, in insertion order.
        return self._to_partition_and_offset

    def __str__(self) -> str:
        body = ''.join(f'{val_pos},' for val_pos in self._to_partition_and_offset)
        return f'->({body})'

    def __repr__(self) -> str:
        return str(self)


class Partition(object):

def __init__(self) -> None:
    # Input slots of this partition; a slot's offset is its list index.
    self._input_vals: List[PartitionInputVal] = []
    # Output slots of this partition; a slot's offset is its list index.
    self._output_vals: List[PartitionOutputVal] = []

def add_input_val(self, input_val: PartitionInputVal):
    # Append one input slot; its offset becomes its position in the list.
    self._input_vals.append(input_val)

def add_output_val(self, output_val: PartitionOutputVal):
    # Append one output slot; its offset becomes its position in the list.
    self._output_vals.append(output_val)

def get_input_vals(self):
    # All input slots, ordered by offset.
    return self._input_vals

def get_output_vals(self):
    # All output slots, ordered by offset.
    return self._output_vals

# get the output offsets sent to dst_partition_id
def get_output_offsets(self, dst_partition_id):
res = []
Expand All @@ -80,9 +87,9 @@ def get_output_offsets(self, dst_partition_id):
for val_pos in outputs:
if val_pos.partition_id == dst_partition_id:
res.append(offset)

return res

# get all input dst partition_ids
def get_input_partition_ids(self):
res = []
Expand All @@ -91,7 +98,7 @@ def get_input_partition_ids(self):
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res

# get all output dst partition_ids
def get_output_partition_ids(self):
res = []
Expand All @@ -101,24 +108,25 @@ def get_output_partition_ids(self):
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res

def __str__(self) -> str:
    """Render the partition: inputs then outputs, one 'offset=i:<val>' line each."""
    # NOTE(review): leading spaces inside these f-string literals may have been
    # collapsed by the diff paste — verify exact formatting against the repo.
    res = ''
    res += f' input:\n'
    res += f' length:{len(self._input_vals)}\n'
    for i, input_val in enumerate(self._input_vals):
        res += f' offset={i}:{input_val}\n'

    res += f' output:\n'
    res += f' length:{len(self._output_vals)}\n'
    for i, output_val in enumerate(self._output_vals):
        res += f' offset={i}:{output_val}\n'

    return res

def __repr__(self) -> str:
    # Mirror __str__ for debugging output.
    return self.__str__()


# This class is a middleware between partition splitter
# and Pipeline Scheduler. It records the graph info about
# partition input/output and provides it to scheduler.
Expand All @@ -132,50 +140,51 @@ def __repr__(self) -> str:
# _input_partition_id: the key represents input_partition
# _output_partition_id: the key represents output_partition
class Topo(object):

def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
    # partition_id -> Partition; holds mid partitions and, when set,
    # the input/output partitions as well.
    self._partitions: Dict[int, Partition] = {}
    # Id of the partition acting as the model input (None until set).
    self._input_partition_id = input_partition_id
    # Id of the partition acting as the model output (None until set).
    self._output_partition_id = output_partition_id

def set_input_partition_id(self, partition_id: int):
    # Mark which partition id plays the role of the input partition.
    self._input_partition_id = partition_id

def set_output_partition_id(self, partition_id: int):
    # Mark which partition id plays the role of the output partition.
    self._output_partition_id = partition_id

def get_input_partition_id(self):
    # May be None if set_input_partition_id was never called.
    return self._input_partition_id

def get_output_partition_id(self):
    # May be None if set_output_partition_id was never called.
    return self._output_partition_id

def set_partitions(self, partition_id: int, partition: Partition):
    # Register (or overwrite) the Partition stored under partition_id.
    self._partitions[partition_id] = partition

def get_mid_partitions(self):
    """Return {partition_id: Partition} for every partition that is neither
    the input partition nor the output partition."""
    # Fix: the pasted diff carried both the old and new version of the
    # initializer line; only one assignment is needed.
    res = {}    # {partition_id: Partition}
    for partition_id, partition in self._partitions.items():
        # Skip the dedicated input/output partitions.
        if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
            continue
        res[partition_id] = partition
    return res

def get_mid_partition_ids(self):
    # Ids of all partitions except the input/output partitions.
    return list(self.get_mid_partitions().keys())

def get_input_partition(self):
    # Returns None until set_input_partition_id has been called.
    if self._input_partition_id is not None:
        return self._partitions[self._input_partition_id]
    return None

def get_output_partition(self):
    # Returns None until set_output_partition_id has been called.
    if self._output_partition_id is not None:
        return self._partitions[self._output_partition_id]
    return None

def get_partition_by_id(self, partition_id):
    # Raises KeyError if partition_id was never registered.
    return self._partitions[partition_id]

def __str__(self) -> str:
res = ''
if len(self._partitions) == 0:
Expand All @@ -186,21 +195,20 @@ def __str__(self) -> str:
res += '{\n'
res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
res += '}\n'

mid_parts = self.get_mid_partitions()
for i, (partition_id, part) in enumerate(mid_parts.items()):
res += '{\n'
res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
res += '}\n'

output_part = self.get_output_partition()
if output_part is not None:
res += '{\n'
res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
res += '}\n'

return res

def __repr__(self) -> str:
    # Mirror __str__ for debugging output.
    return self.__str__()

Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List, Dict, Tuple
import os
import threading
from typing import Dict, List, Tuple

from torch.distributed import rpc
import torch.distributed as dist
from torch.distributed import rpc

from colossalai.tensor import ProcessGroup

Expand Down
4 changes: 4 additions & 0 deletions colossalai/legacy/pipeline/rpc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine
from .utils import pytree_map

__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']
Loading