From 42091c224771d25bd4a8e8922bff7f507020dea1 Mon Sep 17 00:00:00 2001
From: oahzxl <43881818+oahzxl@users.noreply.github.com>
Date: Wed, 18 Jan 2023 15:49:12 +0800
Subject: [PATCH 1/2] [fx] allow control of ckpt_codegen init

Currently in ColoGraphModule, ActivationCheckpointCodeGen is set
automatically in __init__, which means no other codegen can be used.
This adds an argument to control whether ActivationCheckpointCodeGen
is set in __init__.
---
 colossalai/fx/graph_module.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py
index 2d6a71f19e16..761c8841e93b 100644
--- a/colossalai/fx/graph_module.py
+++ b/colossalai/fx/graph_module.py
@@ -22,8 +22,9 @@
 
 class ColoGraphModule(GraphModule):
 
-    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
-        graph.set_codegen(ActivationCheckpointCodeGen())
+    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule', ckpt_codegen: bool = True):
+        if ckpt_codegen:
+            graph.set_codegen(ActivationCheckpointCodeGen())
         super().__init__(root, graph, class_name)
 
     def bind(self, ckpt_def, globals):

From 9c7818f07a4596675075de154b5df8cb526fe167 Mon Sep 17 00:00:00 2001
From: oahzxl <43881818+oahzxl@users.noreply.github.com>
Date: Wed, 18 Jan 2023 15:58:55 +0800
Subject: [PATCH 2/2] code style

---
 colossalai/fx/graph_module.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py
index 761c8841e93b..ebb9975f27db 100644
--- a/colossalai/fx/graph_module.py
+++ b/colossalai/fx/graph_module.py
@@ -22,7 +22,11 @@
 
 class ColoGraphModule(GraphModule):
 
-    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule', ckpt_codegen: bool = True):
+    def __init__(self,
+                 root: Union[torch.nn.Module, Dict[str, Any]],
+                 graph: Graph,
+                 class_name: str = 'GraphModule',
+                 ckpt_codegen: bool = True):
         if ckpt_codegen:
             graph.set_codegen(ActivationCheckpointCodeGen())
         super().__init__(root, graph, class_name)