Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions colossalai/fx/graph_module.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, Union

import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent
from typing import Type, Dict, List, Any, Union, Optional, Set
from pathlib import Path

try:
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall

from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
COLOGM = True
except:
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
COLOGM = False

if COLOGM:

class ColoGraphModule(GraphModule):

def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
graph.set_codegen(ActivationCheckpointCodeGen())
Comment thread
super-dainiu marked this conversation as resolved.
super().__init__(root, graph, class_name)

def bind(self, ckpt_def, globals):
Expand Down
3 changes: 2 additions & 1 deletion colossalai/fx/tracer/_symbolic_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
trace_act_ckpt=False,
) -> ColoGraphModule:
"""
Symbolic tracing API
Expand Down Expand Up @@ -49,6 +50,6 @@ def symbolic_trace(
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.

"""
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
42 changes: 25 additions & 17 deletions colossalai/fx/tracer/experimental.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
import functools
import operator
import inspect
import operator
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

Expand Down Expand Up @@ -286,7 +286,6 @@ def _check_arg_name_valid(names):
self.graph.lint()
return self.graph


@contextmanager
def trace_activation_checkpoint(self, enabled: bool):
if enabled:
Expand Down Expand Up @@ -316,7 +315,6 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any:
# recover the checkpoint function upon exit
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func


def _post_check(self, non_concrete_arg_names: Set[str]):
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
Expand Down Expand Up @@ -385,18 +383,23 @@ def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
trace_act_ckpt=False,
) -> ColoGraphModule:
if is_compatible_with_meta():
if meta_args is not None:
root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
concrete_args=concrete_args,
meta_args=tree_map(wrap_fn, meta_args))
root.cpu()
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
else:
from .tracer import ColoTracer as OrigColoTracer
graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
concrete_args=concrete_args,
meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)

Expand Down Expand Up @@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
node.kwargs)


def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
if kind == 'placeholder':
meta_out = meta_args[target] if target in meta_args else concrete_args.get(
_truncate_suffix(target), None)
meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
elif kind == 'get_attr':
attr_itr = root
atoms = target.split(".")
Expand All @@ -490,14 +493,15 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
else:
if target not in _TensorPropertyMethod:
meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
**tree_map(unwrap_fn, kwargs))
elif kind == 'call_module':
mod = root.get_submodule(target)
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
else:
meta_out = None
return meta_out


def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
meta_out = meta_args[target]
Expand Down Expand Up @@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
return meta_out


def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
result_graph = Graph()
value_remap = {}
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
Expand Down Expand Up @@ -601,20 +605,24 @@ def wrap_fn(n):
if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
else:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)

elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)

elif kind == "call_module":
# if not hasattr(self, "orig_forward"):
Expand All @@ -623,20 +631,20 @@ def wrap_fn(n):
mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type]
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)

if handle is not None:
handle.generate()
for node_inserted in tracer.graph.nodes:
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n])
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
last_node = value_remap[node_inserted]
value_remap[orig_node] = last_node
else:
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n])
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
Comment thread
super-dainiu marked this conversation as resolved.

del tracer

gm.graph = result_graph
gm.recompile()
meta_prop_pass(gm, root_model, meta_args)