Merged
36 commits
cdaa42a
add alphafold benchmark
oahzxl Feb 1, 2023
752e37a
rename alphafold test
oahzxl Feb 1, 2023
d78a737
rename tests
oahzxl Feb 1, 2023
901126f
rename diffuser
oahzxl Feb 1, 2023
14b32be
rename
oahzxl Feb 1, 2023
e93a865
rename
oahzxl Feb 1, 2023
8d6c176
update transformer
oahzxl Feb 1, 2023
a548157
update benchmark
oahzxl Feb 1, 2023
ef4bf3d
update benchmark
oahzxl Feb 1, 2023
4255f1c
update bench memory
oahzxl Feb 1, 2023
75ed562
update transformer benchmark
oahzxl Feb 2, 2023
a5940dc
rename
oahzxl Feb 2, 2023
d5f39a6
support diffuser
oahzxl Feb 2, 2023
b6992e0
support unet metainfo prop
oahzxl Feb 2, 2023
ff4e9ba
Merge branch 'fx' of https://github.com/oahzxl/ColossalAI into unet
oahzxl Feb 2, 2023
a5a27dd
Merge https://github.com/oahzxl/ColossalAI into unet
oahzxl Feb 2, 2023
71b3bae
Merge branch 'hpcaitech:main' into unet
oahzxl Feb 2, 2023
1fde822
fix bug and simplify code
oahzxl Feb 2, 2023
6c22507
Merge branch 'unet' of https://github.com/oahzxl/ColossalAI into unet
oahzxl Feb 2, 2023
05ca225
update linear and support some op
oahzxl Feb 2, 2023
b4566dd
optimize max region search, support conv
oahzxl Feb 3, 2023
b532e29
update unet test
oahzxl Feb 6, 2023
baf001e
support some op
oahzxl Feb 6, 2023
dba9f78
support groupnorm and interpolate
oahzxl Feb 6, 2023
ed59541
update flow search
oahzxl Feb 6, 2023
4b10d38
add fix dim in node flow
oahzxl Feb 6, 2023
f5a370f
fix utils
oahzxl Feb 6, 2023
b07c0b0
rename
oahzxl Feb 6, 2023
95cf822
support diffusion
oahzxl Feb 6, 2023
c6529ad
update diffuser
oahzxl Feb 6, 2023
b2c4fdd
update chunk search
oahzxl Feb 6, 2023
0f9b0f5
optimize imports
oahzxl Feb 6, 2023
46d4b86
import
oahzxl Feb 6, 2023
04651bd
Merge branch 'hpcaitech:main' into unet
oahzxl Feb 7, 2023
717b27b
Merge branch 'unet' of https://github.com/oahzxl/ColossalAI into unet
oahzxl Feb 7, 2023
24bdbf9
finish autochunk
oahzxl Feb 7, 2023
50 changes: 34 additions & 16 deletions colossalai/autochunk/autochunk_codegen.py
@@ -9,18 +9,7 @@
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()

if AUTOCHUNK_AVAILABLE:
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods

from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg

@@ -143,7 +132,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
return context


def _replace_ones_like(
def _replace_new_tensor_like_shape(
search_chunk: SearchChunk,
chunk_infos: List[Dict],
region_idx: int,
@@ -154,7 +143,7 @@
"""
add chunk slice for new tensor ops such as ones_like
"""
if "ones_like" in node.name:
if get_node_name(node) in ["ones_like", "zeros_like", "empty_like"]:
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
@@ -166,6 +155,33 @@
return body


def _replace_new_tensor_shape(
search_chunk: SearchChunk,
chunk_infos: List[Dict],
region_idx: int,
node_idx: int,
node: Node,
body: List[str],
) -> List[str]:
"""
add chunk slice for new tensor ops such as ones
"""
if get_node_name(node) in ["ones", "zeros", "empty"]:
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if chunk_dim is None:
return body
if get_node_shape(meta_node)[chunk_dim] == 1:
return body
origin_shape = str(node.args)
new_shape = list(node.args)
new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % get_node_shape(meta_node)[chunk_dim]
new_shape = str(new_shape)
new_shape = new_shape.replace("'", "")
body[-1] = _replace_name(body[-1], origin_shape[1:-1], new_shape[1:-1])
return body
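
Note: the shape rewrite above operates on the stringified argument tuple of the traced call. A minimal standalone sketch of the same substitution (hypothetical helper name; simplified from the `_replace_name`-based version above):

def rewrite_new_tensor_shape(line: str, args: tuple, chunk_dim: int, full_size: int) -> str:
    # stringify the original shape, e.g. (4, 1024, 64) -> "4, 1024, 64"
    origin_shape = str(args)
    # swap the size along the chunk dim for a runtime expression that
    # clamps the last, possibly smaller, chunk
    new_shape = list(args)
    new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % full_size
    new_shape = str(new_shape).replace("'", "")
    return line.replace(origin_shape[1:-1], new_shape[1:-1])

# rewrite_new_tensor_shape("x = torch.ones(4, 1024, 64)", (4, 1024, 64), 1, 1024)
# -> "x = torch.ones(4, min(chunk_size, 1024 - chunk_idx), 64)"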


def _add_node_slice(
chunk_nodes: List[Node],
region_idx: int,
@@ -265,8 +281,10 @@ def emit_code_with_chunk(
body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
# replace output var with chunk var
body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
# ones like
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# new tensor like
body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# new tensor
body = _replace_new_tensor_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# reassign reshape size
body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
body[-1] = " " + body[-1]
52 changes: 23 additions & 29 deletions colossalai/autochunk/search_chunk.py
@@ -8,14 +8,7 @@
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
NodeMgr,
find_chunk_compute_input_and_output_nodes,
get_logger,
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder


class SearchChunk(object):
@@ -75,8 +68,8 @@ def _init_trace(self) -> None:
max_chunk_region_list = []
while True:
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
cur_node_idx = max_chunk_region[1]
if cur_node_idx == len(active_nodes) - 1:
cur_node_idx = max_chunk_region[1] + 1
if cur_node_idx >= len(active_nodes) - 1:
break
max_chunk_region_list.append(max_chunk_region)

@@ -135,6 +128,7 @@ def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_
min_active_node_num = min(active_node_num[free_var_num:])
threshold = max(free_var_num, min_active_node_num)

# normal search
# from peak_node to free_var
inside_flag = False
chunk_region_start = free_var_num
@@ -144,7 +138,6 @@ def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_
if inside_flag and active_node_num[i] > threshold:
chunk_region_start = i + 1
break

# from peak_node to len-2
inside_flag = False
chunk_region_end = len(active_node) - 1
@@ -155,6 +148,22 @@ def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_
chunk_region_end = i
break

# if normal search fails, use approximate search
if (chunk_region_end - chunk_region_start) > 250:
window_size = 100
# search min for start
min_num = 1e3
for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_start = i
# search min for end
min_num = 1e3
for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_end = i

# avoid chunk regions overlap
if chunk_regions is not None:
for i in chunk_regions:
@@ -271,12 +280,6 @@ def _step_search(
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region

def _stop_search(self, init_mem_peak, mem_peak):
sorted_init_mem_peak = sorted(init_mem_peak)
if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
return True
return False

def search_region(self) -> Dict:
"""
Search all chunk regions:
@@ -291,11 +294,7 @@ def search_region(self) -> Dict:
get_logger().info("AutoChunk start searching chunk regions")

chunk_infos = []
(
init_mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
mem_peak = init_mem_peak

while True:
@@ -304,18 +303,13 @@ def search_region(self) -> Dict:
break
chunk_infos.append(chunk_info)

(
mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos)

if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))

if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
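
The approximate search added to _search_max_chunk_region trades precision for speed: when the exact region spans more than 250 nodes, it takes the index with the fewest active nodes inside a 100-node window on each side of the memory peak. A standalone sketch of that boundary pick (illustrative; tie-breaking may differ from the loops above):

from typing import List, Tuple

def approx_chunk_region(active_node_num: List[int], peak_node_idx: int, window_size: int = 100) -> Tuple[int, int]:
    # cheapest start: fewest active nodes in the window before the peak
    lo = max(peak_node_idx - window_size, 0)
    start = min(range(lo, peak_node_idx + 1), key=lambda i: active_node_num[i])
    # cheapest end: fewest active nodes in the window after the peak
    hi = min(peak_node_idx + window_size, len(active_node_num) - 1)
    end = min(range(peak_node_idx, hi + 1), key=lambda i: active_node_num[i])
    return start, end
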
51 changes: 20 additions & 31 deletions colossalai/autochunk/trace_flow.py
@@ -100,6 +110,16 @@ def _assgin_single_node_flow(
if not (start_idx <= arg_idx < end_idx):
return True

# get fix dim
arg_fix_dim = []
if cur_node_dim is not None:
for i in cur_node_fix_dim:
fix_dim_source = cur_node_source[i]
if arg_idx in fix_dim_source:
arg_fix_dim.append(fix_dim_source[arg_idx][0])
if arg_node in all_node_info:
arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))

# find arg dim
if cur_node_dim is not None:
# dim is computed
Expand All @@ -109,6 +119,9 @@ def _assgin_single_node_flow(
arg_dim = None
else:
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
# chunk dim cannot be in fix dims
if arg_dim in arg_fix_dim:
return False
# chunk dim should be None if shape size is 1
if get_node_shape(arg_node)[arg_dim] == 1:
arg_dim = None
Expand All @@ -120,19 +133,16 @@ def _assgin_single_node_flow(
else:
arg_dim = None

# get fix dim
arg_fix_dim = []
if cur_node_dim is not None:
for i in cur_node_fix_dim:
fix_dim_source = cur_node_source[i]
if arg_idx in fix_dim_source:
arg_fix_dim.append(fix_dim_source[arg_idx][0])
# add arg rest dim as fix dim
arg_fix_dim = list(range(len(get_node_shape(arg_node))))
if arg_dim is not None:
arg_fix_dim.remove(arg_dim)

# if already in node_info, arg dim must be same
if arg_node in all_node_info:
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
return False
all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
all_node_info[arg_node]["fix_dim"] = arg_fix_dim
# else add it to list
else:
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
@@ -164,6 +174,8 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx):
continue
if is_non_compute_node(arg):
continue
if get_node_shape(arg) is None:
continue
arg_list.append(arg)
flow_flag = self._assgin_single_node_flow(
arg,
Expand All @@ -180,29 +192,6 @@ def _get_all_node_info(self, end_dim, start_idx, end_idx):
if flow_flag == False:
return None

if len(arg_list) >= 2:
# need to mark fix dim
if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
for arg in arg_list:
if get_node_shape(arg) is None:
continue
if not (start_idx <= self.node_mgr.find_node_idx(arg) < end_idx):
continue
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
arg_fix_dim = all_node_info[arg]["fix_dim"]
arg_shape = get_node_shape(arg)
# add all dim as fix dim except chunk dim
for i, shape in enumerate(arg_shape):
if shape != 1 and i != cur_node_chunk_dim:
if i == arg_chunk_dim:
return None
if i not in arg_fix_dim:
arg_fix_dim.append(i)
elif any(i == get_node_name(cur_node)
for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
pass
else:
raise NotImplementedError()
cur_node_list = next_node_list
return all_node_info
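
The reworked flow assignment enforces two invariants: a node's chunk dim must never appear among its fix dims, and once an arg's chunk dim is decided, every remaining dim is pinned. A simplified sketch of that bookkeeping (the shape of all_node_info is assumed from the dict literal above):

from typing import Dict, List, Optional

def update_node_info(all_node_info: Dict[str, dict], arg_name: str, arg_ndim: int, arg_dim: Optional[int]) -> bool:
    # pin every dim except the chunk dim
    fix_dim: List[int] = [d for d in range(arg_ndim) if d != arg_dim]
    if arg_name in all_node_info:
        # revisiting a node with a conflicting chunk dim aborts this flow
        if all_node_info[arg_name]["chunk_dim"] != arg_dim:
            return False
        all_node_info[arg_name]["fix_dim"] = fix_dim
    else:
        all_node_info[arg_name] = {"chunk_dim": arg_dim, "fix_dim": fix_dim}
    return True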
