diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index a2d4ca074f..0840283468 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -2,6 +2,7 @@ import unittest from typing import ( Any, + Optional, Tuple, ) @@ -107,6 +108,15 @@ def data(self) -> dict: "seed": 1145141919810, } + def is_meaningless_zero_attention_layer_tests( + self, + attn_layer: int, + attn_dotr: bool, + normalize: bool, + temperature: Optional[float], + ) -> bool: + return attn_layer == 0 and (attn_dotr or normalize or temperature is not None) + @property def skip_pt(self) -> bool: ( @@ -128,7 +138,12 @@ def skip_pt(self) -> bool: concat_output_tebd, precision, ) = self.param - return CommonTest.skip_pt + return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) @property def skip_dp(self) -> bool: @@ -151,7 +166,12 @@ def skip_dp(self) -> bool: concat_output_tebd, precision, ) = self.param - return CommonTest.skip_pt + return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) @property def skip_tf(self) -> bool: @@ -176,12 +196,21 @@ def skip_tf(self) -> bool: ) = self.param # TODO (excluded_types != [] and attn_layer > 0) need fix return ( - env_protection != 0.0 - or smooth_type_embedding - or not normalize - or temperature != 1.0 - or (excluded_types != [] and attn_layer > 0) - or (type_one_side and tebd_input_mode == "strip") # not consistent yet + CommonTest.skip_tf + or ( + env_protection != 0.0 + or smooth_type_embedding + or not normalize + or temperature != 1.0 + or (excluded_types != [] and attn_layer > 0) + or (type_one_side and tebd_input_mode == "strip") # not consistent yet + ) + or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) ) tf_class = DescrptDPA1TF