From 12a9836778a1378f1aaba3cb59ffa5b98e315f25 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Wed, 13 Jul 2022 17:55:10 +0800 Subject: [PATCH 1/2] [fx] added apex normalization to patched modules --- .../meta_patch/patched_module/normalization.py | 13 ++++++++++++- .../test_fx/test_pipeline/test_hf_model/test_t5.py | 9 --------- .../test_fx/test_tracer/test_hf_model/test_hf_t5.py | 10 ---------- 3 files changed, 12 insertions(+), 20 deletions(-) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py index 78a3620cc522..206a80663aef 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -1,3 +1,4 @@ +from ast import Import import torch from ..registry import meta_patched_module @@ -17,4 +18,14 @@ def torch_nn_normalize(self, input): assert input.dim() == 5 # normalization maintain the same shape as the input - return input.clone() \ No newline at end of file + return input.clone() + + +try: + import apex + meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) + meta_patched_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize) +except ImportError: + pass diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py index 0b747cef648d..f24dd705cfe6 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -2,15 +2,6 @@ import transformers import torch from hf_utils import split_model_and_compare_output -from colossalai.fx.tracer.meta_patch import meta_patched_module -try: - import apex - - @meta_patched_module.register(apex.normalization.FusedRMSNorm) - def apex_fused_layernorm(self, input): - return torch.empty(input.shape, device='meta') -except ImportError: - pass BATCH_SIZE = 1 SEQ_LENGHT = 16 diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 989cc9c12cd4..4e2614056d51 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -1,18 +1,8 @@ import pytest import transformers import torch -from colossalai.fx.tracer.meta_patch import meta_patched_module from utils import trace_model_and_compare_output -try: - import apex - - @meta_patched_module.register(apex.normalization.FusedRMSNorm) - def apex_fused_layernorm(self, input): - return torch.empty(input.shape, device='meta') -except ImportError: - pass - BATCH_SIZE = 1 SEQ_LENGHT = 16 From d51e6b0184001fbc4c3cfdd3fc17c05c68d8a16f Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Wed, 13 Jul 2022 18:00:22 +0800 Subject: [PATCH 2/2] remove unused imports --- colossalai/fx/tracer/meta_patch/patched_module/normalization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py index 206a80663aef..120874e70052 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -1,4 +1,3 @@ -from ast import Import import torch from ..registry import meta_patched_module