diff --git a/cortex/model/branch/_conv1d_branch.py b/cortex/model/branch/_conv1d_branch.py
index a4b84bb..e70877a 100644
--- a/cortex/model/branch/_conv1d_branch.py
+++ b/cortex/model/branch/_conv1d_branch.py
@@ -91,7 +91,7 @@ def forward(

         padding_mask = trunk_outputs.padding_mask
         branch_features, branch_mask = self.encoder((trunk_features, padding_mask.to(trunk_features)))
-        pooled_features = self.pooling_op(branch_features, branch_mask)
+        pooled_features = self.pooling_op((branch_features, branch_mask))

         branch_outputs = Conv1dBranchOutput(
             branch_features=branch_features.contiguous(),
diff --git a/cortex/model/branch/_transformer_branch.py b/cortex/model/branch/_transformer_branch.py
index 137e3f9..fe29dc0 100644
--- a/cortex/model/branch/_transformer_branch.py
+++ b/cortex/model/branch/_transformer_branch.py
@@ -76,7 +76,10 @@ def __init__(
         elif pooling_type == "weighted_mean":
             self.pooling_op = WeightedMeanPooling(out_dim)
         elif pooling_type == "attention":
-            self.pooling_op = PoolingSelfAttention(num_heads=num_heads, embed_dim=out_dim, dropout_p=dropout_prob)
+            self.pooling_op = nn.Sequential(
+                Apply(nn.LayerNorm(out_dim, bias=False)),
+                PoolingSelfAttention(num_heads=num_heads, embed_dim=out_dim, dropout_p=dropout_prob),
+            )
         else:
             raise NotImplementedError

@@ -94,7 +97,7 @@ def forward(

         padding_mask = trunk_outputs.padding_mask
         branch_features, branch_mask = self.encoder((trunk_features, padding_mask.to(trunk_features)))
-        pooled_features = self.pooling_op(branch_features, branch_mask)
+        pooled_features = self.pooling_op((branch_features, branch_mask))

         branch_outputs = TransformerBranchOutput(
             branch_features=branch_features.contiguous(),
diff --git a/cortex/model/elemental/_bidirectional_self_attention.py b/cortex/model/elemental/_bidirectional_self_attention.py
index 173f4b2..7694e9a 100644
--- a/cortex/model/elemental/_bidirectional_self_attention.py
+++ b/cortex/model/elemental/_bidirectional_self_attention.py
@@ -8,6 +8,7 @@ def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0
             raise ValueError("num_heads must evenly divide embed_dim")

         self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
+        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.dropout = nn.Dropout(dropout_p)
         self.dropout_p = dropout_p
         self.head_dim = embed_dim // num_heads
@@ -35,4 +36,5 @@ def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
         )

         res = res.transpose(-2, -3).contiguous().flatten(start_dim=-2)
+        res = self.c_proj(res)
         return self.dropout(res), padding_mask
diff --git a/cortex/model/elemental/_causal_self_attention.py b/cortex/model/elemental/_causal_self_attention.py
index 0f1b76b..83a7ae6 100644
--- a/cortex/model/elemental/_causal_self_attention.py
+++ b/cortex/model/elemental/_causal_self_attention.py
@@ -8,6 +8,7 @@ def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0
             raise ValueError("num_heads must evenly divide embed_dim")

         self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
+        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.dropout = nn.Dropout(dropout_p)
         self.dropout_p = dropout_p
         self.head_dim = embed_dim // num_heads
@@ -32,4 +33,5 @@ def forward(self, inputs: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
         )

         res = res.transpose(-2, -3).flatten(start_dim=-2)
+        res = self.c_proj(res)
         return self.dropout(res), padding_mask
diff --git a/cortex/model/elemental/_mean_pooling.py b/cortex/model/elemental/_mean_pooling.py
index 4f4ceb3..1847820 100644
--- a/cortex/model/elemental/_mean_pooling.py
+++ b/cortex/model/elemental/_mean_pooling.py
@@ -7,7 +7,8 @@ class MeanPooling(nn.Module):
     Average pooling over the sequence dimension excluding padding token positions.
     """

-    def forward(self, x, padding_mask):
+    def forward(self, inputs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        x, padding_mask = inputs
         weights = torch.where(padding_mask.bool(), 0.0, float("-inf"))
         weights = weights.softmax(dim=-1).to(x)
         pooled_x = (x * weights[..., None]).sum(-2)
@@ -24,7 +25,8 @@ def __init__(self, in_dim):
         super().__init__()
         self.encoder = nn.Linear(in_dim, in_dim)

-    def forward(self, x, padding_mask):
+    def forward(self, inputs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        x, padding_mask = inputs
         weights = self.encoder(x)
         weights = torch.where(padding_mask.bool().unsqueeze(-1), weights, float("-inf"))
         weights = weights.softmax(dim=-2).to(x)
diff --git a/cortex/model/elemental/_pooling_self_attention.py b/cortex/model/elemental/_pooling_self_attention.py
index f039e14..11e2613 100644
--- a/cortex/model/elemental/_pooling_self_attention.py
+++ b/cortex/model/elemental/_pooling_self_attention.py
@@ -8,12 +8,14 @@ def __init__(self, num_heads: int = 4, embed_dim: int = 32, dropout_p: float = 0
             raise ValueError("num_heads must evenly divide embed_dim")

         self.c_attn = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
+        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.dropout = nn.Dropout(dropout_p)
         self.dropout_p = dropout_p
         self.head_dim = embed_dim // num_heads
         self.num_heads = num_heads

-    def forward(self, x: Tensor, padding_mask: Tensor) -> tuple[Tensor, Tensor]:
+    def forward(self, inputs: tuple[Tensor, Tensor]) -> Tensor:
+        x, padding_mask = inputs
         seq_len = x.size(-2)

         queries, keys, values = self.c_attn(x).chunk(3, dim=-1)
@@ -38,5 +40,6 @@ def forward(self, x: Tensor, padding_mask: Tensor) -> tuple[Tensor, Tensor]:
         )

         res = res.transpose(-2, -3).contiguous().flatten(start_dim=-2)
+        res = self.c_proj(res)
         res = self.dropout(res)[..., 0, :]  # drop 1D query dim
         return res
diff --git a/cortex/model/leaf/_autoregressive_lm_leaf.py b/cortex/model/leaf/_autoregressive_lm_leaf.py
index 1fcee02..f713c7f 100644
--- a/cortex/model/leaf/_autoregressive_lm_leaf.py
+++ b/cortex/model/leaf/_autoregressive_lm_leaf.py
@@ -38,6 +38,7 @@ def __init__(
         *args,
         corruption_process: Optional[CorruptionProcess] = None,
         corruption_rate: float = 0.1,
+        layernorm: bool = True,
         **kwargs,
     ):
         """
@@ -49,7 +50,7 @@ def __init__(
             *args: Additional positional arguments to pass to the parent class
             **kwargs: Additional keyword arguments to pass to the parent class
         """
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, layernorm=layernorm, **kwargs)
         self.corruption_process = corruption_process
         self.corruption_rate = corruption_rate

diff --git a/cortex/model/leaf/_classifier_leaf.py b/cortex/model/leaf/_classifier_leaf.py
index 2d43ce1..fe84fa6 100644
--- a/cortex/model/leaf/_classifier_leaf.py
+++ b/cortex/model/leaf/_classifier_leaf.py
@@ -75,6 +75,7 @@ def __init__(
         last_layer_bias: bool = True,
         label_smoothing: Union[float, str] = 0.0,
         root_key: Optional[str] = None,
+        layernorm: bool = False,
     ) -> None:
         super().__init__()
         self.in_dim = in_dim
@@ -83,7 +84,7 @@ def __init__(
         self.root_key = root_key

         # testing out normalizing the penultimate activations
-        encoder_modules = [nn.LayerNorm(in_dim, bias=False)]
+        encoder_modules = [nn.LayerNorm(in_dim, bias=False)] if layernorm else []
         if num_layers >= 1:
             for _ in range(num_layers):
                 encoder_modules.extend(
diff --git a/cortex/model/leaf/_denoising_lm_leaf.py b/cortex/model/leaf/_denoising_lm_leaf.py
index e37cf91..ab72666 100644
--- a/cortex/model/leaf/_denoising_lm_leaf.py
+++ b/cortex/model/leaf/_denoising_lm_leaf.py
@@ -38,6 +38,7 @@ def __init__(
         *args,
         corruption_process: Optional[CorruptionProcess] = None,
         corruption_rate: float = 0.1,
+        layernorm: bool = True,
         **kwargs,
     ):
         """
@@ -49,7 +50,7 @@ def __init__(
             *args: Additional positional arguments to pass to the parent class
             **kwargs: Additional keyword arguments to pass to the parent class
         """
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, layernorm=layernorm, **kwargs)
         self.corruption_process = corruption_process
         self.corruption_rate = corruption_rate
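
Not part of the diff: the pooling modules now take a single `(features, padding_mask)` tuple instead of two positional arguments, which is what lets the attention pooling path wrap them in `nn.Sequential` behind `Apply(nn.LayerNorm(...))`. Below is a minimal standalone sketch of that calling convention; it reuses the `MeanPooling` body from the diff, while the toy shapes, the driver code, and the mask values are illustrative rather than the cortex API.

```python
import torch
from torch import nn


class MeanPooling(nn.Module):
    """Mean over the sequence dim, ignoring padded positions (mask == 0, assumed convention)."""

    def forward(self, inputs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        # single tuple argument, so the module composes inside nn.Sequential
        x, padding_mask = inputs
        # padded positions get -inf logits -> zero softmax weight
        weights = torch.where(padding_mask.bool(), 0.0, float("-inf"))
        weights = weights.softmax(dim=-1).to(x)
        return (x * weights[..., None]).sum(-2)


features = torch.randn(2, 4, 8)                            # (batch, seq, embed)
padding_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # trailing positions padded
pooled = MeanPooling()((features, padding_mask))           # note the single tuple argument
print(pooled.shape)  # torch.Size([2, 8])
```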