From d1940c2b3617ebfe655c7441b4ba649c12d26a7d Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 18 May 2023 16:20:42 +0800 Subject: [PATCH 1/2] [mixed_precison] add naive amp demo --- colossalai/booster/mixed_precision/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py index 3cf0ad28cdbe..0df9d84159f9 100644 --- a/colossalai/booster/mixed_precision/__init__.py +++ b/colossalai/booster/mixed_precision/__init__.py @@ -1,17 +1,19 @@ from .bf16 import BF16MixedPrecision from .fp8 import FP8MixedPrecision from .fp16_apex import FP16ApexMixedPrecision +from .fp16_naive import FP16NaiveMixedPrecision from .fp16_torch import FP16TorchMixedPrecision from .mixed_precision_base import MixedPrecision __all__ = [ 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision', - 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision' + 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision', 'FP16NaiveMixedPrecision' ] _mixed_precision_mapping = { 'fp16': FP16TorchMixedPrecision, 'fp16_apex': FP16ApexMixedPrecision, + 'fp16_naive': FP16NaiveMixedPrecision, 'bf16': BF16MixedPrecision, 'fp8': FP8MixedPrecision } From e8b7029662a586a774ee7ee578f06c8476f9da3c Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 18 May 2023 16:21:03 +0800 Subject: [PATCH 2/2] [mixed_precison] add naive amp demo --- colossalai/booster/mixed_precision/fp16_naive.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 colossalai/booster/mixed_precision/fp16_naive.py diff --git a/colossalai/booster/mixed_precision/fp16_naive.py b/colossalai/booster/mixed_precision/fp16_naive.py new file mode 100644 index 000000000000..ef1ec1f42d70 --- /dev/null +++ b/colossalai/booster/mixed_precision/fp16_naive.py @@ -0,0 +1,5 @@ +from .mixed_precision_base import MixedPrecision + + +class FP16NaiveMixedPrecision(MixedPrecision): + pass