diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 470c21231..8cd51ea91 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -600,7 +600,7 @@ def calculate_sin_cos_rotary(
         self,
         rotary_dim: int,
         n_ctx: int,
-        base: int = 10000,
+        base: Union[float, int] = 10000,
         dtype: torch.dtype = torch.float32,
     ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]:
         """
diff --git a/transformer_lens/config/HookedTransformerConfig.py b/transformer_lens/config/HookedTransformerConfig.py
index 6b3d58bb9..e8450df1f 100644
--- a/transformer_lens/config/HookedTransformerConfig.py
+++ b/transformer_lens/config/HookedTransformerConfig.py
@@ -206,7 +206,7 @@ class HookedTransformerConfig(TransformerLensConfig):
             YARN extension. Defaults to 4096.
         use_qk_norm (bool): Whether to apply RMSNorm to the query and key projections
             before computing attention scores. Used by Gemma 3 models. Defaults to False.
-        rotary_base_local (int, *optional*): The base for rotary positional embeddings in local
+        rotary_base_local (float, *optional*): The base for rotary positional embeddings in local
             attention layers. Used by models with hybrid local/global attention (e.g., Gemma 3)
             which use different RoPE bases for local (10k) and global (1M) attention.
             Defaults to None, which means the standard rotary_base is used for all layers.
@@ -250,9 +250,9 @@ class HookedTransformerConfig(TransformerLensConfig):
     dtype: torch.dtype = torch.float32
     tokenizer_prepends_bos: Optional[bool] = None
     post_embedding_ln: bool = False
-    rotary_base: int = 10000
+    rotary_base: Union[float, int] = 10000
     rotary_base_local: Optional[
-        int
+        Union[float, int]
     ] = None  # For models with different RoPE bases per attention type (e.g., Gemma 3)
     rotary_scaling_factor: float = (
         1.0  # Linear RoPE scaling factor for global attention (e.g., 8.0 for Gemma 3 4B)
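
For context, here is a minimal sketch (not part of the diff; `rotary_freqs` is a hypothetical helper) of the standard RoPE inverse-frequency formula, illustrating that widening `base` to `Union[float, int]` is purely a typing fix: an int base and the equivalent float base produce identical tables, so no numerics change. This matters for checkpoints whose configs store the RoPE base as a float (e.g., Gemma 3's 1M global base expressed as `1_000_000.0`).

```python
from typing import Union

import torch


def rotary_freqs(rotary_dim: int, base: Union[float, int] = 10000) -> torch.Tensor:
    """Inverse RoPE frequencies 1 / base^(2i / rotary_dim) for i = 0, 2, 4, ..."""
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim
    return 1.0 / (base**exponents)


# An int base and the equivalent float base yield the same frequency table,
# so the Union annotation only lets float-valued config fields pass type checks.
assert torch.allclose(rotary_freqs(64, 10000), rotary_freqs(64, 10000.0))
```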