Merged
Commits
23 commits
2cf9067
initialize tp_group for FP8 DPA
cyanguwa Apr 23, 2024
df6fea0
fix cuDNN version in unit tests for cuDNN v9
cyanguwa Apr 23, 2024
ae3de42
add hook to ignore missing fused_attn._extra_states if training from …
cyanguwa Apr 24, 2024
7532773
remove test and redundant implementation from last commit
cyanguwa Apr 24, 2024
b4105f4
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa Apr 24, 2024
4d45bba
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa Apr 24, 2024
befc86d
remove warning message and replace with docstring
cyanguwa Apr 24, 2024
273cd4e
remove tp_size/tp_group in FusedAttention; amax reduction is handled …
cyanguwa Apr 24, 2024
8e168dc
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa Apr 24, 2024
9bfd19e
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa Apr 25, 2024
b94a1ee
move core_attention.fused_attention._extra_state to core_attention._e…
cyanguwa Apr 26, 2024
5d35ff4
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa Apr 26, 2024
3924565
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa Apr 29, 2024
8f52de6
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa Apr 29, 2024
84b3d78
simplify post_state_dict_hooks between FU and DPA
cyanguwa Apr 30, 2024
8ed65cd
add temporary test
cyanguwa Apr 30, 2024
de0f072
Merge branch 'NVIDIA:main' into fp8_dpa/misc_fixes
cyanguwa Apr 30, 2024
4635fdc
remove previous attempts to move core_attention.fused_attention to co…
cyanguwa Apr 30, 2024
cd9777b
remove the test
cyanguwa Apr 30, 2024
2840f19
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa Apr 30, 2024
ab8a7d3
disable pylint self arg for hook which is required by hook
cyanguwa May 1, 2024
9304d26
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa May 1, 2024
b78163d
Merge branch 'main' into fp8_dpa/misc_fixes
cyanguwa May 1, 2024
3 changes: 2 additions & 1 deletion tests/pytorch/fused_attn/test_fused_attn.py
@@ -70,7 +70,8 @@ def reset_global_fp8_state():
 def _cudnn_version() -> Tuple[int, int, int]:
     """Runtime cuDNN version (major, minor, patch)"""
     encoded_version = ext.get_cudnn_version()
-    major, encoded_version = divmod(encoded_version, 1000)
+    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
+    major, encoded_version = divmod(encoded_version, major_version_magnitude)
     minor, patch = divmod(encoded_version, 100)
     return (major, minor, patch)
 
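For context, the test helper now has to handle both cuDNN version encodings: releases before 9 pack the version as major*1000 + minor*100 + patch, while cuDNN 9+ uses major*10000 + minor*100 + patch. A minimal standalone sketch of the same decoding (the decode_cudnn_version name is illustrative only, not part of the patch):

from typing import Tuple

def decode_cudnn_version(encoded_version: int) -> Tuple[int, int, int]:
    # cuDNN < 9: major*1000 + minor*100 + patch; cuDNN >= 9: major*10000 + minor*100 + patch
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)

assert decode_cudnn_version(8907) == (8, 9, 7)    # cuDNN 8.9.7
assert decode_cudnn_version(90100) == (9, 1, 0)   # cuDNN 9.1.0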
12 changes: 12 additions & 0 deletions transformer_engine/pytorch/attention.py
@@ -2929,6 +2929,17 @@ def __init__(
         if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
             os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
 
+        def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
+            """
+            Temporarily remove fused_attention._extra_state as a missing key
+            when loading older TransformerEngine checkpoints. Will phase out
+            this hook in TransformerEngine 2.0.
+            """
+            for key in incompatible_keys.missing_keys:
+                if 'fused_attention._extra_state' in key:
+                    incompatible_keys.missing_keys.remove(key)
+        self.register_load_state_dict_post_hook(remove_extra_states_check)
+
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
@@ -3282,6 +3293,7 @@ def __init__(
                 layer_number=layer_number,
                 deterministic=self.deterministic,
                 **attn_kwargs)
+
         self.unfused_attention = UnfusedDotProductAttention(
             norm_factor, **attn_kwargs, layer_number=layer_number)
 
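The hook above relies on torch.nn.Module.register_load_state_dict_post_hook: the registered function receives the incompatible_keys result and may mutate missing_keys before load_state_dict decides whether strict loading fails. A self-contained sketch of the same pattern, using hypothetical ToyCore/ToyAttention modules rather than TransformerEngine's actual classes:

import torch

class ToyCore(torch.nn.Module):
    # Hypothetical stand-in for a submodule that serializes metadata via extra state.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def get_extra_state(self):
        return {"meta": "..."}

    def set_extra_state(self, state):  # pylint: disable=unused-argument
        pass

class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fused_attention = ToyCore()

        def remove_extra_states_check(self, incompatible_keys):  # pylint: disable=unused-argument
            # Drop the known-benign missing key so older checkpoints load even with strict=True.
            for key in list(incompatible_keys.missing_keys):
                if "fused_attention._extra_state" in key:
                    incompatible_keys.missing_keys.remove(key)

        self.register_load_state_dict_post_hook(remove_extra_states_check)

# An "older" checkpoint saved before the extra state existed: no "_extra_state" entry.
old_checkpoint = {"fused_attention.proj.weight": torch.zeros(4, 4),
                  "fused_attention.proj.bias": torch.zeros(4)}
model = ToyAttention()
model.load_state_dict(old_checkpoint)  # strict=True by default; succeeds because the hook filtered the key

Without the hook, the missing "fused_attention._extra_state" entry would make the strict load raise a RuntimeError; with it, only genuinely unexpected or missing keys are reported.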