From 2cf906719158ffea9f2e89a8469a827d38f2343b Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 23 Apr 2024 21:02:17 +0000 Subject: [PATCH 01/12] initialize tp_group for FP8 DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 90da9e06b6..fcc83047ad 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2718,8 +2718,14 @@ def __init__( if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - self.tp_size = tp_size - self.tp_group = tp_group + if tp_group is None: + self.tp_size = tp_size + if tp_size == 1: + self.set_tensor_parallel_group(tp_group) + else: + self.tp_size = get_distributed_world_size(tp_group) + self.set_tensor_parallel_group(tp_group) + self.set_nccl_overlap_warning_if_tp() def get_fp8_weights_scratchpad( self, From df6fea0a8e228e99a7f56a1182ea918f5438804c Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 23 Apr 2024 21:02:44 +0000 Subject: [PATCH 02/12] fix cuDNN version in unit tests for cuDNN v9 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/fused_attn/test_fused_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 40cfdd34b7..caba385d46 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -70,7 +70,8 @@ def reset_global_fp8_state(): def _cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" encoded_version = ext.get_cudnn_version() - major, encoded_version = divmod(encoded_version, 1000) + major_version_magnitude = 1000 if encoded_version < 90000 else 10000 + major, encoded_version = divmod(encoded_version, major_version_magnitude) minor, patch = divmod(encoded_version, 100) return (major, minor, patch) From ae3de4217ffb05e9de6ebf0a5df1a01551d41b3e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 24 Apr 2024 00:41:57 +0000 Subject: [PATCH 03/12] add hook to ignore missing fused_attn._extra_states if training from old checkpoints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 30 ++++++++++++++++++++----- transformer_engine/pytorch/attention.py | 19 ++++++++++++++++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 90cfce8a6f..3b49826ade 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -60,11 +60,11 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq module_inference = ["TransformerLayer", "MultiheadAttention"] input_formats_inference = ["sbhd", "bshd"] -param_types = [torch.float32, torch.float16] +param_types = []#torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) -batch_sizes = [1, 2] +batch_sizes = [2]#1, 2] all_boolean = [True, False] @@ -627,7 +627,7 @@ def _test_e2e_checkpointing_get_model(config, dtype): ) -def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"): +def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path="checkpoint.pt"): reset_rng_states() te_inp_hidden_states = torch.randn( @@ -639,13 +639,16 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= te_inp_hidden_states.retain_grad() block = _test_e2e_checkpointing_get_model(config, dtype) + print('block._modules', block._modules.keys()) - for _ in range(steps // 2): + for i in range(steps // 2): + print(f'>>>>>>>>>>>> iter {i} fwd >>>>>>>>>>>>>') te_out = block( te_inp_hidden_states, None, ) loss = te_out.sum() + print(f'>>>>>>>>>>>> iter {i} bwd >>>>>>>>>>>>>') loss.backward() if checkpoint: @@ -654,7 +657,16 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= # loading from a checkpoint gives bitwise identical results. # Since gradients are being accumulated, it is important to # restore them post loading the checkpoint. - torch.save(block.state_dict(), path) + #print('dir(block) ',dir(block)) + sd = block.state_dict() + #del block.state_dict()['self_attention.core_attention.fused_attention._extra_state'] + del sd['self_attention.core_attention.fused_attention._extra_state'] + #torch.save(block.state_dict(), path) + torch.save(sd, path) + print('block.state_dict(): ') + #for k,v in block.state_dict().items(): + for k,v in sd.items(): + print(k) param_grads = [] for p in block.parameters(): @@ -667,7 +679,11 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= del block block = _test_e2e_checkpointing_get_model(config, dtype) + print('>>>>>>>>>>>> loading >>>>>>>>>>>>>') block.load_state_dict(torch.load(path)) + print('block.state_dict(): after') + for k,v in block.state_dict().items(): + print(k) reset_rng_states() for p in block.parameters(): @@ -701,7 +717,11 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", model_configs.keys()) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] + print() + print('=================== checkpoint=False ====================') outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) + print() + print('=================== checkpoint=True ====================') outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) # Check that results match diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index fcc83047ad..e6e123d188 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2727,6 +2727,25 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + def remove_extra_states_check(self, incompatible_keys): + for key in incompatible_keys.missing_keys: + if 'fused_attention._extra_state' in key: + incompatible_keys.missing_keys.remove(key) + warnings.warn( + f"""Ignoring missing key "{key}" in checkpoints. """ + "The likely cause is that checkpoints were collected using " + "pre-v1.6 TransformerEngine. While no functionality impact, " + "please use v1.6+ TransformerEngine for checkpointing " + "next time.") + + self.register_load_state_dict_post_hook(remove_extra_states_check) + + #def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + # missing_keys, unexpected_keys, error_msgs): + # """Overrides FusedAttention from loading _extra_states very strictly""" + # super()._load_from_state_dict(state_dict, prefix, local_metadata, False, + # missing_keys, unexpected_keys, error_msgs) + def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], From 753277368e452e20bc3225c2e463989bf02b943a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 24 Apr 2024 00:45:02 +0000 Subject: [PATCH 04/12] remove test and redundant implementation from last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 30 +++++-------------------- transformer_engine/pytorch/attention.py | 7 ------ 2 files changed, 5 insertions(+), 32 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 3b49826ade..90cfce8a6f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -60,11 +60,11 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq module_inference = ["TransformerLayer", "MultiheadAttention"] input_formats_inference = ["sbhd", "bshd"] -param_types = []#torch.float32, torch.float16] +param_types = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) -batch_sizes = [2]#1, 2] +batch_sizes = [1, 2] all_boolean = [True, False] @@ -627,7 +627,7 @@ def _test_e2e_checkpointing_get_model(config, dtype): ) -def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path="checkpoint.pt"): +def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"): reset_rng_states() te_inp_hidden_states = torch.randn( @@ -639,16 +639,13 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" te_inp_hidden_states.retain_grad() block = _test_e2e_checkpointing_get_model(config, dtype) - print('block._modules', block._modules.keys()) - for i in range(steps // 2): - print(f'>>>>>>>>>>>> iter {i} fwd >>>>>>>>>>>>>') + for _ in range(steps // 2): te_out = block( te_inp_hidden_states, None, ) loss = te_out.sum() - print(f'>>>>>>>>>>>> iter {i} bwd >>>>>>>>>>>>>') loss.backward() if checkpoint: @@ -657,16 +654,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" # loading from a checkpoint gives bitwise identical results. # Since gradients are being accumulated, it is important to # restore them post loading the checkpoint. - #print('dir(block) ',dir(block)) - sd = block.state_dict() - #del block.state_dict()['self_attention.core_attention.fused_attention._extra_state'] - del sd['self_attention.core_attention.fused_attention._extra_state'] - #torch.save(block.state_dict(), path) - torch.save(sd, path) - print('block.state_dict(): ') - #for k,v in block.state_dict().items(): - for k,v in sd.items(): - print(k) + torch.save(block.state_dict(), path) param_grads = [] for p in block.parameters(): @@ -679,11 +667,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" del block block = _test_e2e_checkpointing_get_model(config, dtype) - print('>>>>>>>>>>>> loading >>>>>>>>>>>>>') block.load_state_dict(torch.load(path)) - print('block.state_dict(): after') - for k,v in block.state_dict().items(): - print(k) reset_rng_states() for p in block.parameters(): @@ -717,11 +701,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" @pytest.mark.parametrize("model", model_configs.keys()) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] - print() - print('=================== checkpoint=False ====================') outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) - print() - print('=================== checkpoint=True ====================') outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) # Check that results match diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e6e123d188..e8a06aaf38 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2737,15 +2737,8 @@ def remove_extra_states_check(self, incompatible_keys): "pre-v1.6 TransformerEngine. While no functionality impact, " "please use v1.6+ TransformerEngine for checkpointing " "next time.") - self.register_load_state_dict_post_hook(remove_extra_states_check) - #def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - # missing_keys, unexpected_keys, error_msgs): - # """Overrides FusedAttention from loading _extra_states very strictly""" - # super()._load_from_state_dict(state_dict, prefix, local_metadata, False, - # missing_keys, unexpected_keys, error_msgs) - def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], From befc86d6558e1f8f2c29533a3430f964809909fe Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 24 Apr 2024 21:28:28 +0000 Subject: [PATCH 05/12] remove warning message and replace with docstring Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index c194fb6cd1..135ea72762 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2729,15 +2729,14 @@ def __init__( self.set_nccl_overlap_warning_if_tp() def remove_extra_states_check(self, incompatible_keys): + """ + Temporarily remove fused_attention._extra_state as a missing key + when loading older TransformerEngine checkpoints. Will phase out + this hook in TransformerEngine 2.0. + """ for key in incompatible_keys.missing_keys: if 'fused_attention._extra_state' in key: incompatible_keys.missing_keys.remove(key) - warnings.warn( - f"""Ignoring missing key "{key}" in checkpoints. """ - "The likely cause is that checkpoints were collected using " - "pre-v1.6 TransformerEngine. While no functionality impact, " - "please use v1.6+ TransformerEngine for checkpointing " - "next time.") self.register_load_state_dict_post_hook(remove_extra_states_check) def get_fp8_weights_scratchpad( From 273cd4e973ba7e3d6056c764decac8ccbaeb13c9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 24 Apr 2024 21:46:17 +0000 Subject: [PATCH 06/12] remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 30 +++++-------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 135ea72762..3678bea5b2 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1937,7 +1937,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, use_FAv2_bwd, - fp8, fp8_meta, tp_size, tp_group): + fp8, fp8_meta): if fp8: if _NVTE_DEBUG: print('[DotProductAttention]: using FP8 forward') @@ -2011,8 +2011,6 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors) ctx.fp8_meta = fp8_meta - ctx.tp_size = tp_size - ctx.tp_group = tp_group ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen = max_seqlen ctx.qkv_dtype = qkv_dtype @@ -2133,7 +2131,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, - use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group): + use_FAv2_bwd, fp8, fp8_meta): if fp8: if _NVTE_DEBUG: print('[DotProductAttention]: using FP8 forward') @@ -2214,8 +2212,6 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors) ctx.fp8_meta = fp8_meta - ctx.tp_size = tp_size - ctx.tp_group = tp_group ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2350,7 +2346,7 @@ class FusedAttnFunc(torch.autograd.Function): def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, - use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group): + use_FAv2_bwd, fp8, fp8_meta): if fp8: if _NVTE_DEBUG: print('[DotProductAttention]: using FP8 forward') @@ -2488,8 +2484,6 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors) ctx.fp8_meta = fp8_meta - ctx.tp_size = tp_size - ctx.tp_group = tp_group ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2691,8 +2685,6 @@ def __init__( attention_type: str = "self", layer_number: Optional[int] = None, deterministic: bool = False, - tp_size: int = 1, - tp_group: Optional[dist_group_type] = None, ) -> None: super().__init__() @@ -2719,15 +2711,6 @@ def __init__( if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - if tp_group is None: - self.tp_size = tp_size - if tp_size == 1: - self.set_tensor_parallel_group(tp_group) - else: - self.tp_size = get_distributed_world_size(tp_group) - self.set_tensor_parallel_group(tp_group) - self.set_nccl_overlap_warning_if_tp() - def remove_extra_states_check(self, incompatible_keys): """ Temporarily remove fused_attention._extra_state as a missing key @@ -2892,8 +2875,6 @@ def forward( use_FAv2_bwd, self.fp8 and self.fp8_meta["recipe"].fp8_dpa, self.fp8_meta, - self.tp_size, - self.tp_group, ) # ...hd -> ...(hd) @@ -3092,9 +3073,8 @@ def __init__( attention_type=attention_type, layer_number=layer_number, deterministic=self.deterministic, - **attn_kwargs, - tp_size=self.tp_size, - tp_group=self.tp_group) + **attn_kwargs) + self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) From b94a1ee4df8a9d0688f93e7e8989ceeb531fa8af Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 26 Apr 2024 21:56:53 +0000 Subject: [PATCH 07/12] move core_attention.fused_attention._extra_state to core_attention._extra_state Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 40 +++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 3678bea5b2..7bf9cc93bf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2713,15 +2713,30 @@ def __init__( def remove_extra_states_check(self, incompatible_keys): """ - Temporarily remove fused_attention._extra_state as a missing key + Temporarily remove core_attention._extra_state as a missing key when loading older TransformerEngine checkpoints. Will phase out this hook in TransformerEngine 2.0. """ for key in incompatible_keys.missing_keys: - if 'fused_attention._extra_state' in key: + if 'core_attention._extra_state' in key: incompatible_keys.missing_keys.remove(key) self.register_load_state_dict_post_hook(remove_extra_states_check) + def _save_to_state_dict(self, destination, prefix, keep_vars): + """ + Override to save to core_attention._extra_state. + """ + super()._save_to_state_dict(destination, prefix.replace('fused_attention.',''), keep_vars) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """ + Override to load from core_attention._extra_state. + """ + super()._load_from_state_dict(state_dict, prefix.replace('fused_attention.',''), + local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], @@ -3078,6 +3093,17 @@ def __init__( self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) + def remove_extra_states_check(self, incompatible_keys): + """ + Temporarily remove fused_attention._extra_state as a missing key + when loading older TransformerEngine checkpoints. Will phase out + this hook in TransformerEngine 2.0. + """ + for key in incompatible_keys.missing_keys: + if 'core_attention._extra_state' in key: + incompatible_keys.missing_keys.remove(key) + self.register_load_state_dict_post_hook(remove_extra_states_check) + def _checkpointed_attention_forward( self, attention_func: Callable, @@ -3123,6 +3149,16 @@ def set_context_parallel_group( self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream + def get_extra_state(self) -> torch.Tensor: + """ + Override to add core_attention._extra_state to state_dict when _save_to_state_dict(). + """ + + def set_extra_state(self, state: torch.Tensor) -> None: + """ + Override to load core_attention._extra_state when _load_from_state_dict(). + """ + @no_torch_dynamo(recursive=False) def forward( self, From 84b3d78c5a75333819f3ae05d7f4383086cf7ed5 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 30 Apr 2024 02:11:35 +0000 Subject: [PATCH 08/12] simplify post_state_dict_hooks between FU and DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6f0ced7563..926666dbbc 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2927,17 +2927,6 @@ def __init__( if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - def remove_extra_states_check(self, incompatible_keys): - """ - Temporarily remove core_attention._extra_state as a missing key - when loading older TransformerEngine checkpoints. Will phase out - this hook in TransformerEngine 2.0. - """ - for key in incompatible_keys.missing_keys: - if 'core_attention._extra_state' in key: - incompatible_keys.missing_keys.remove(key) - self.register_load_state_dict_post_hook(remove_extra_states_check) - def _save_to_state_dict(self, destination, prefix, keep_vars): """ Override to save to core_attention._extra_state. @@ -3316,9 +3305,14 @@ def remove_extra_states_check(self, incompatible_keys): when loading older TransformerEngine checkpoints. Will phase out this hook in TransformerEngine 2.0. """ + num = 0 + keys = [] for key in incompatible_keys.missing_keys: if 'core_attention._extra_state' in key: - incompatible_keys.missing_keys.remove(key) + num = num + 1 + keys.append(key) + for i in range(num): + incompatible_keys.missing_keys.remove(keys[i]) self.register_load_state_dict_post_hook(remove_extra_states_check) def _checkpointed_attention_forward( From 8ed65cd9468220a98ed54c3790080e78f607979a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 30 Apr 2024 02:19:45 +0000 Subject: [PATCH 09/12] add temporary test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 85 +++++++++++++++++------ transformer_engine/pytorch/module/base.py | 3 + 2 files changed, 68 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 90cfce8a6f..6dbaa870cb 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -12,6 +12,7 @@ import torch.nn as nn from torch.nn import Parameter +from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init from transformer_engine.pytorch.utils import ( init_method_normal, @@ -610,7 +611,21 @@ def _test_e2e_checkpointing_get_model(config, dtype): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - return TransformerLayer( + fp8 = False + fp8_recipe = None + if dtype == 'fp8': + fp8 = True + fp8_recipe = recipe.DelayedScaling( + margin=0, + interval=1, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + ) + + dtype = torch.bfloat16 if dtype == 'fp8' else dtype + #with fp8_model_init(enabled=fp8): + block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, @@ -625,11 +640,15 @@ def _test_e2e_checkpointing_get_model(config, dtype): params_dtype=dtype, device="cuda", ) + return block, fp8_recipe -def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"): +def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path="checkpoint.pt"): reset_rng_states() + fp8 = True if dtype == 'fp8' else False + orig_dtype = dtype + dtype = torch.bfloat16 if dtype == 'fp8' else dtype te_inp_hidden_states = torch.randn( (config.seq_len, bs, config.hidden_size), dtype=dtype, @@ -638,15 +657,16 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= ) te_inp_hidden_states.retain_grad() - block = _test_e2e_checkpointing_get_model(config, dtype) + block, fp8_recipe = _test_e2e_checkpointing_get_model(config, orig_dtype) for _ in range(steps // 2): - te_out = block( - te_inp_hidden_states, - None, - ) - loss = te_out.sum() - loss.backward() + with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + te_out = block( + te_inp_hidden_states, + None, + ) + loss = te_out.sum() + loss.backward() if checkpoint: # This process is necessary so that we can start afresh with @@ -654,7 +674,19 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= # loading from a checkpoint gives bitwise identical results. # Since gradients are being accumulated, it is important to # restore them post loading the checkpoint. - torch.save(block.state_dict(), path) + #torch.save(block.state_dict(), path) + sd = block.state_dict() + for k,v in sd.items(): + if 'extra_state' in k: + print(k) + + # simulate old checkpoints where _extra_state didn't exist for fused attn + #del sd['self_attention.core_attention._extra_state'] + #for k,v in sd.items(): + # if 'extra_state' in k: + # print(k) + + torch.save(sd, path) param_grads = [] for p in block.parameters(): @@ -666,8 +698,19 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= _cuda_rng_state = torch.cuda.get_rng_state() del block - block = _test_e2e_checkpointing_get_model(config, dtype) - block.load_state_dict(torch.load(path)) + block, fp8_recipe = _test_e2e_checkpointing_get_model(config, orig_dtype) + with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + print('------- loading ') + block.load_state_dict(torch.load(path)) + print('------- state_dict()') + sd = block.state_dict() + for k,v in sd.items(): + if 'extra_state' in k: + print(k) + state=sd['self_attention.core_attention._extra_state'] + state.seek(0) + state = torch.load(state, map_location='cuda') + print('------ state ',state) reset_rng_states() for p in block.parameters(): @@ -677,12 +720,13 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= assert not param_grads, "Oops!" for _ in range(steps // 2): - te_out = block( - te_inp_hidden_states, - None, - ) - loss = te_out.sum() - loss.backward() + with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + te_out = block( + te_inp_hidden_states, + None, + ) + loss = te_out.sum() + loss.backward() torch.cuda.synchronize() @@ -696,8 +740,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= return outputs -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("dtype", ['fp8'])#torch.bfloat16])#param_types) +@pytest.mark.parametrize("bs", [2])#batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] @@ -705,6 +749,7 @@ def test_gpt_checkpointing(dtype, bs, model): outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) # Check that results match + dtype = torch.bfloat16 if dtype == 'fp8' else dtype tols = dtype_tols(dtype) if dtype in (torch.float16, torch.bfloat16): tols.update(dict(rtol=2e-2, atol=2e-3)) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0803b474f6..ba7eaf9907 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -370,6 +370,7 @@ def get_extra_state(self) -> torch.Tensor: state = None fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration + print('get_extra_state: ',self.__class__,fp8_checkpoint, self.fp8_meta["fp8_checkpoint"], self.fp8, self.fp8_calibration) if fp8_checkpoint: state = {} @@ -379,6 +380,7 @@ def get_extra_state(self) -> torch.Tensor: state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history + print('>>>>saving.... ',state["scale_fwd"]) # Store other pickelable values. extra = {} @@ -408,6 +410,7 @@ def set_extra_state(self, state: torch.Tensor) -> None: else: raise RuntimeError("Unsupported checkpoint format.") + print('>>>>loaded.... ',state["scale_fwd"] if state is not None else None) if state is None: return From 4635fdc2ff90b87032db9b924926ae513d089475 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:42:03 +0000 Subject: [PATCH 10/12] remove previous attempts to move core_attention.fused_attention to core_attention; keep the test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 2 +- transformer_engine/pytorch/attention.py | 50 +++++------------------ transformer_engine/pytorch/module/base.py | 3 -- 3 files changed, 11 insertions(+), 44 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 6dbaa870cb..a7be21e4b7 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -707,7 +707,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" for k,v in sd.items(): if 'extra_state' in k: print(k) - state=sd['self_attention.core_attention._extra_state'] + state=sd['self_attention.core_attention.fused_attention._extra_state'] state.seek(0) state = torch.load(state, map_location='cuda') print('------ state ',state) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d5577d1fc0..cf5a9d84ac 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2929,20 +2929,16 @@ def __init__( if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - def _save_to_state_dict(self, destination, prefix, keep_vars): - """ - Override to save to core_attention._extra_state. - """ - super()._save_to_state_dict(destination, prefix.replace('fused_attention.',''), keep_vars) - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - """ - Override to load from core_attention._extra_state. - """ - super()._load_from_state_dict(state_dict, prefix.replace('fused_attention.',''), - local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + def remove_extra_states_check(self, incompatible_keys): + """ + Temporarily remove fused_attention._extra_state as a missing key + when loading older TransformerEngine checkpoints. Will phase out + this hook in TransformerEngine 2.0. + """ + for key in incompatible_keys.missing_keys: + if 'fused_attention._extra_state' in key: + incompatible_keys.missing_keys.remove(key) + self.register_load_state_dict_post_hook(remove_extra_states_check) def get_fp8_weights_scratchpad( self, @@ -3301,22 +3297,6 @@ def __init__( self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) - def remove_extra_states_check(self, incompatible_keys): - """ - Temporarily remove fused_attention._extra_state as a missing key - when loading older TransformerEngine checkpoints. Will phase out - this hook in TransformerEngine 2.0. - """ - num = 0 - keys = [] - for key in incompatible_keys.missing_keys: - if 'core_attention._extra_state' in key: - num = num + 1 - keys.append(key) - for i in range(num): - incompatible_keys.missing_keys.remove(keys[i]) - self.register_load_state_dict_post_hook(remove_extra_states_check) - def _checkpointed_attention_forward( self, attention_func: Callable, @@ -3362,16 +3342,6 @@ def set_context_parallel_group( self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream - def get_extra_state(self) -> torch.Tensor: - """ - Override to add core_attention._extra_state to state_dict when _save_to_state_dict(). - """ - - def set_extra_state(self, state: torch.Tensor) -> None: - """ - Override to load core_attention._extra_state when _load_from_state_dict(). - """ - @no_torch_dynamo(recursive=False) def forward( self, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ba7eaf9907..0803b474f6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -370,7 +370,6 @@ def get_extra_state(self) -> torch.Tensor: state = None fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration - print('get_extra_state: ',self.__class__,fp8_checkpoint, self.fp8_meta["fp8_checkpoint"], self.fp8, self.fp8_calibration) if fp8_checkpoint: state = {} @@ -380,7 +379,6 @@ def get_extra_state(self) -> torch.Tensor: state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history - print('>>>>saving.... ',state["scale_fwd"]) # Store other pickelable values. extra = {} @@ -410,7 +408,6 @@ def set_extra_state(self, state: torch.Tensor) -> None: else: raise RuntimeError("Unsupported checkpoint format.") - print('>>>>loaded.... ',state["scale_fwd"] if state is not None else None) if state is None: return From cd9777bbbe2e97fbaac65b32585e10dec804fd1e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:42:39 +0000 Subject: [PATCH 11/12] remove the test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 85 ++++++++-------------------------- 1 file changed, 20 insertions(+), 65 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a7be21e4b7..90cfce8a6f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -12,7 +12,6 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init from transformer_engine.pytorch.utils import ( init_method_normal, @@ -611,21 +610,7 @@ def _test_e2e_checkpointing_get_model(config, dtype): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - fp8 = False - fp8_recipe = None - if dtype == 'fp8': - fp8 = True - fp8_recipe = recipe.DelayedScaling( - margin=0, - interval=1, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - ) - - dtype = torch.bfloat16 if dtype == 'fp8' else dtype - #with fp8_model_init(enabled=fp8): - block = TransformerLayer( + return TransformerLayer( config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, @@ -640,15 +625,11 @@ def _test_e2e_checkpointing_get_model(config, dtype): params_dtype=dtype, device="cuda", ) - return block, fp8_recipe -def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path="checkpoint.pt"): +def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path="checkpoint.pt"): reset_rng_states() - fp8 = True if dtype == 'fp8' else False - orig_dtype = dtype - dtype = torch.bfloat16 if dtype == 'fp8' else dtype te_inp_hidden_states = torch.randn( (config.seq_len, bs, config.hidden_size), dtype=dtype, @@ -657,16 +638,15 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" ) te_inp_hidden_states.retain_grad() - block, fp8_recipe = _test_e2e_checkpointing_get_model(config, orig_dtype) + block = _test_e2e_checkpointing_get_model(config, dtype) for _ in range(steps // 2): - with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): - te_out = block( - te_inp_hidden_states, - None, - ) - loss = te_out.sum() - loss.backward() + te_out = block( + te_inp_hidden_states, + None, + ) + loss = te_out.sum() + loss.backward() if checkpoint: # This process is necessary so that we can start afresh with @@ -674,19 +654,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" # loading from a checkpoint gives bitwise identical results. # Since gradients are being accumulated, it is important to # restore them post loading the checkpoint. - #torch.save(block.state_dict(), path) - sd = block.state_dict() - for k,v in sd.items(): - if 'extra_state' in k: - print(k) - - # simulate old checkpoints where _extra_state didn't exist for fused attn - #del sd['self_attention.core_attention._extra_state'] - #for k,v in sd.items(): - # if 'extra_state' in k: - # print(k) - - torch.save(sd, path) + torch.save(block.state_dict(), path) param_grads = [] for p in block.parameters(): @@ -698,19 +666,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" _cuda_rng_state = torch.cuda.get_rng_state() del block - block, fp8_recipe = _test_e2e_checkpointing_get_model(config, orig_dtype) - with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): - print('------- loading ') - block.load_state_dict(torch.load(path)) - print('------- state_dict()') - sd = block.state_dict() - for k,v in sd.items(): - if 'extra_state' in k: - print(k) - state=sd['self_attention.core_attention.fused_attention._extra_state'] - state.seek(0) - state = torch.load(state, map_location='cuda') - print('------ state ',state) + block = _test_e2e_checkpointing_get_model(config, dtype) + block.load_state_dict(torch.load(path)) reset_rng_states() for p in block.parameters(): @@ -720,13 +677,12 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" assert not param_grads, "Oops!" for _ in range(steps // 2): - with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): - te_out = block( - te_inp_hidden_states, - None, - ) - loss = te_out.sum() - loss.backward() + te_out = block( + te_inp_hidden_states, + None, + ) + loss = te_out.sum() + loss.backward() torch.cuda.synchronize() @@ -740,8 +696,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=2, path=" return outputs -@pytest.mark.parametrize("dtype", ['fp8'])#torch.bfloat16])#param_types) -@pytest.mark.parametrize("bs", [2])#batch_sizes) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] @@ -749,7 +705,6 @@ def test_gpt_checkpointing(dtype, bs, model): outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) # Check that results match - dtype = torch.bfloat16 if dtype == 'fp8' else dtype tols = dtype_tols(dtype) if dtype in (torch.float16, torch.bfloat16): tols.update(dict(rtol=2e-2, atol=2e-3)) From ab8a7d3788b385b0ab787edfc979a81c6c8f7ae8 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 1 May 2024 17:57:12 +0000 Subject: [PATCH 12/12] disable pylint self arg for hook which is required by hook Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index cf5a9d84ac..2f5a6aa671 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2929,7 +2929,7 @@ def __init__( if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - def remove_extra_states_check(self, incompatible_keys): + def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ Temporarily remove fused_attention._extra_state as a missing key when loading older TransformerEngine checkpoints. Will phase out