From 8e6655820c9bc602918450b69eaeff6fa6851b7c Mon Sep 17 00:00:00 2001 From: Min Guo Date: Tue, 25 Nov 2025 18:05:47 -0800 Subject: [PATCH 1/3] add split_linear Differential Revision: D87606892 --- .../apple/coreml/llama/llama_transformer.py | 38 ++++++++++++++++--- examples/apple/coreml/llama/utils.py | 8 +++- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index ae98c327b45..e3c5f5914d6 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -238,6 +238,34 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): return freqs_cos, freqs_sin +class CoreMLRMSNormV2(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The normalized tensor. + """ + + return torch.nn.functional.rms_norm(x, normalized_shape=[self.dim], weight=self.weight, eps=None) + +_RMS_NORM = CoreMLRMSNorm class FeedForward(nn.Module): def __init__(self, args: ModelArgs): @@ -327,8 +355,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): if self.use_qk_norm: q_norm_dim = self.head_dim k_norm_dim = self.head_dim - self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps) - self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps) + self.q_norm_fn = _RMS_NORM(q_norm_dim, eps=args.norm_eps) + self.k_norm_fn = _RMS_NORM(k_norm_dim, eps=args.norm_eps) def forward( self, @@ -388,8 +416,8 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): self.block_sparse_moe = MOEFeedForward(args) else: self.feed_forward = FeedForward(args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = _RMS_NORM(args.dim, eps=args.norm_eps) + self.ffn_norm = _RMS_NORM(args.dim, eps=args.norm_eps) def forward( self, @@ -422,7 +450,7 @@ def __init__(self, params: ModelArgs): self.layers = torch.nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(layer_id, params, self.rope)) - self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.norm = _RMS_NORM(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.generate_full_logits = params.generate_full_logits self.max_seq_len = params.max_seq_len diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py index 1e5a842fed5..11effbf9678 100644 --- a/examples/apple/coreml/llama/utils.py +++ b/examples/apple/coreml/llama/utils.py @@ -16,6 +16,7 @@ def __init__( out_max_splits=1, in_target_split_size=1, in_max_splits=1, + fqn_filer=None ): super(SplitLinearModule, self).__init__() self.out_split_sizes = self._get_split_sizes( @@ -91,10 +92,11 @@ def forward(self, x): def replace_linear_with_split_linear( - model, out_target_split_size, out_max_splits, 
in_target_split_size, in_max_splits=1 + model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1,fqn_filer=None, ): for name, module in model.named_children(): - if isinstance(module, torch.nn.Linear): + should_split = isinstance(module, torch.nn.Linear) and fqn_filer(name) + if should_split: assert module.bias is None, "SplitLinearModule does not support bias" new_module = SplitLinearModule( module.in_features, @@ -103,6 +105,7 @@ def replace_linear_with_split_linear( out_max_splits, in_target_split_size, in_max_splits, + fqn_filer, ) new_module.set_params(module.weight) setattr(model, name, new_module) @@ -113,4 +116,5 @@ def replace_linear_with_split_linear( out_max_splits, in_target_split_size, in_max_splits, + fqn_filer, ) From 36e517f332d118ac9af4aec20001f09d3f025acd Mon Sep 17 00:00:00 2001 From: Min Guo Date: Wed, 26 Nov 2025 10:14:21 -0800 Subject: [PATCH 2/3] coreml experiment disable sdpa Differential Revision: D87906465 --- .../apple/coreml/llama/llama_transformer.py | 178 ++++++++++++++---- 1 file changed, 143 insertions(+), 35 deletions(-) diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py index e3c5f5914d6..aefa2178670 100644 --- a/examples/apple/coreml/llama/llama_transformer.py +++ b/examples/apple/coreml/llama/llama_transformer.py @@ -33,6 +33,13 @@ def find_multiple(n: int, k: int) -> int: return n return n + k - (n % k) +def silu_approx(x): + x = x.clamp(-3, 3) + x2 = x * x + x4 = x2 * x2 + x6 = x4 * x2 + res = 0.0017 + 0.5 * x + 0.2423 * x2 -0.0153 * x4 + 0.00057 * x6 + return res @dataclass class ModelArgs: @@ -108,6 +115,15 @@ def __post_init__(self): if self.head_dim is None: self.head_dim = self.dim // self.n_heads +def rms_norm_fp16_stable(x, eps=1e-5, min_scale=1e-3): + amax = x.abs().amax(dim=-1, keepdim=True) + scale = amax.clamp(min=min_scale) + x_scaled = x / scale + + var = torch.square(x_scaled).mean(dim=-1, keepdim=True) + rms = torch.sqrt(var + eps) + y = x_scaled / rms + return y class CoreMLRMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): @@ -146,10 +162,12 @@ def _norm(self, x): # In future, we want to add CoreML support for the functional RMSNorm op # We have yet to do large scale evaluations on the numeric stability of this solution, but note that # it appears better than what exists currently (removing FP32 casts and using FP16) + + norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True) rms_norm_eps0 = ( x - * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype)) - * torch.reciprocal(torch.linalg.vector_norm(x, dim=-1, keepdim=True)) + * (torch.sqrt(torch.tensor(self.dim, dtype=x.dtype)) / norm) + # * torch.reciprocal(torch.linalg.vector_norm(x, dim=-1, keepdim=True)) ) return rms_norm_eps0 @@ -167,6 +185,10 @@ def forward(self, x): output = self._norm(x) return output * self.weight +_RMS_NORM = CoreMLRMSNorm +_DECOMPOSE_SDPA = True +_USE_SOFTMAX = True +_USE_SILU_APPROX = False class Rope(torch.nn.Module): def __init__(self, params: ModelArgs): @@ -238,34 +260,6 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): return freqs_cos, freqs_sin -class CoreMLRMSNormV2(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - """ - Initialize the RMSNorm normalization layer. - Args: - dim (int): The dimension of the input tensor. - eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 
- Attributes: - eps (float): A small value added to the denominator for numerical stability. - weight (nn.Parameter): Learnable scaling parameter. - """ - super().__init__() - self.dim = dim - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - """ - Apply the RMSNorm normalization to the input tensor. - Args: - x (torch.Tensor): The input tensor. - Returns: - torch.Tensor: The normalized tensor. - """ - - return torch.nn.functional.rms_norm(x, normalized_shape=[self.dim], weight=self.weight, eps=None) - -_RMS_NORM = CoreMLRMSNorm class FeedForward(nn.Module): def __init__(self, args: ModelArgs): @@ -277,7 +271,15 @@ def __init__(self, args: ModelArgs): self.w3 = nn.Linear(args.dim, hidden_dim, bias=False) def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + t1 = self.w1(x) + if _USE_SILU_APPROX: + t1 = silu_approx(t1) + else: + t1 = F.silu(t1) + t2 = self.w3(x) + out = t1 * t2 + out = self.w2(out) + return out class ConditionalFeedForward(nn.Module): @@ -397,9 +399,40 @@ def forward( k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) - output = torch.ops.aten.scaled_dot_product_attention.default( - q, k, v, attn_mask=attn_mask - ) + if not _DECOMPOSE_SDPA: + output = torch.ops.aten.scaled_dot_product_attention.default( + q, k, v, attn_mask=attn_mask + ) + else: + + # ------------------------------ + # Manual SDPA: matmuls + softmax + # q: (B, H, T_q, D) + # k: (B, H, T_k, D) + # v: (B, H, T_k, D) + # attn_mask: broadcastable to (B, H, T_q, T_k) + # ------------------------------ + d = q.size(-1) + # (B, H, T_q, T_k) + scores = torch.matmul(q, k.transpose(-2, -1)) / (d ** 0.5) + + if attn_mask is not None: + # attn_mask is already used this way with SDPA, keep same semantics: + # 0.0 for allowed, -inf for disallowed, added to scores. 
+ scores = scores + attn_mask + + if _USE_SOFTMAX: + # (B, H, T_q, T_k) + attn_weights = torch.softmax(scores, dim=-1) + else: + scores = scores.clamp(min=-60.0, max=60.0) + scores_max, _ = scores.max(dim=-1, keepdim=True) # (B, H, T_q, 1) + scores_exp = torch.exp(scores - scores_max) + attn_weights = scores_exp / scores_exp.sum(dim=-1, keepdim=True) + + # (B, H, T_q, D) + output = torch.matmul(attn_weights, v) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) output = self.wo(output) return output, new_k, new_v @@ -434,9 +467,84 @@ def forward( ) h = x + h - out = h + self.feed_forward(self.ffn_norm(h)) + tmp = self.feed_forward(self.ffn_norm(h)) + out = h + tmp return out, new_k, new_v +class AttentionBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.head_dim + self.attention = Attention(args, layer_id, rope) + self.attention_norm = _RMS_NORM(args.dim, eps=args.norm_eps) + + def forward( + self, + x, + freqs_cos, + freqs_sin, + k_cache, + v_cache, + attn_mask, + ): # x: 1xN + norm_emb = self.attention_norm(x) + h, new_k, new_v = self.attention.forward( + norm_emb, freqs_cos, freqs_sin, k_cache, v_cache, attn_mask + ) + h = x + h + return h, new_k, new_v + + +class FeedForwardBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.head_dim + if args.moe: + self.block_sparse_moe = MOEFeedForward(args) + else: + self.feed_forward = FeedForward(args) + self.ffn_norm = _RMS_NORM(args.dim, eps=args.norm_eps) + + def forward( + self, + h, + ): # x: 1xN + tmp = self.feed_forward(self.ffn_norm(h)) + out = h + tmp + return out + + + +class InputBlock(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.rope = Rope(params) + self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + + def forward(self, tokens: torch.LongTensor, input_pos: torch.LongTensor): + h = self.tok_embeddings(tokens) + seqlen = h.shape[1] + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen) + return h, freqs_cos, freqs_sin + +class OutputBlock(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.generate_full_logits = params.generate_full_logits + self.norm = _RMS_NORM(params.dim, eps=params.norm_eps) + self.output = nn.Linear(params.dim, params.vocab_size, bias=False) + + def forward(self, h, input_length: torch.LongTensor): + if not self.generate_full_logits: + # Only the last logit is used for the new generated token + h = h[:, input_length - 1, :].squeeze(1) + h = self.norm(h) + logits = self.output(h) + return logits class Transformer(nn.Module): def __init__(self, params: ModelArgs): From 06fd91247fde4b08444768818d1c162baabb90c6 Mon Sep 17 00:00:00 2001 From: Shen Xu Date: Thu, 4 Dec 2025 12:24:28 -0800 Subject: [PATCH 3/3] disable sdpa in static attention mha forward Differential Revision: D87944616 --- examples/models/llama/static_attention.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index 4d5b9c1da57..04efcd08116 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -1036,7 +1036,12 @@ def _forward_mha( if masks: cache_len = k.size(-2) - seq_len mask = masks[cache_len] - y = F.scaled_dot_product_attention(q, 
k, v, attn_mask=mask)
+            # y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+            attn = q @ k.transpose(-2, -1)
+            attn = attn * self.inv_scale
+            attn = attn + mask
+            attn = F.softmax(attn, dim=-1)
+            y = attn @ v

         return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state
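
A minimal usage sketch for the fqn-filtered split-linear rewrite added in PATCH 1/3, assuming the helper is importable as executorch.examples.apple.coreml.llama.utils. The TinyMLP module, the split sizes, and the filter lambda are illustrative only. Two details follow directly from the diff: the new argument is spelled fqn_filer, and it is invoked whenever an nn.Linear child is visited, so a callable must be supplied (the None default would raise there); despite its name, the callback receives the local child name at each recursion level, not a fully qualified name.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed import path for examples/apple/coreml/llama/utils.py.
from executorch.examples.apple.coreml.llama.utils import (
    replace_linear_with_split_linear,
)


class TinyMLP(nn.Module):
    # Hypothetical stand-in for a transformer sub-module; not part of the patch.
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(64, 256, bias=False)
        self.w2 = nn.Linear(256, 64, bias=False)
        self.output = nn.Linear(64, 32, bias=False)

    def forward(self, x):
        return self.output(self.w2(F.silu(self.w1(x))))


model = TinyMLP()

# Split only the feed-forward projections; leave the output head untouched.
replace_linear_with_split_linear(
    model,
    out_target_split_size=64,
    out_max_splits=4,
    in_target_split_size=64,
    in_max_splits=4,
    fqn_filer=lambda name: name in ("w1", "w2"),  # spelling as in the patch
)

with torch.no_grad():
    y = model(torch.randn(2, 64))
print(type(model.w1).__name__, type(model.output).__name__, tuple(y.shape))
# Expected under these assumptions: SplitLinearModule Linear (2, 32)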
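
Both PATCH 2/3 (the _DECOMPOSE_SDPA path in llama_transformer.py) and PATCH 3/3 (_forward_mha in static_attention.py) replace the fused scaled_dot_product_attention call with an explicit matmul -> scale -> add-mask -> softmax -> matmul pipeline. The sketch below is a self-contained numerical check of that decomposition against the fused op; the tensor shapes and the additive causal mask are made up for the comparison, and it exercises the plain-softmax branch (_USE_SOFTMAX).

import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

# Additive mask: 0.0 where attention is allowed, -inf where it is not.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# Decomposed attention, mirroring the patched forward paths.
scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
scores = scores + mask
attn_weights = torch.softmax(scores, dim=-1)
manual = torch.matmul(attn_weights, v)

# Fused reference.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(manual, fused, atol=1e-5))  # expected: True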
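
PATCH 2/3 also adds silu_approx, a degree-6 polynomial evaluated on inputs clamped to [-3, 3]; it is off by default (_USE_SILU_APPROX = False). A quick sketch that measures the polynomial's deviation from torch.nn.functional.silu over the clamp range; outside that range the approximation saturates, so the comparison is only meaningful inside it.

import torch
import torch.nn.functional as F


def silu_approx(x):
    # Same clamp and coefficients as in the patch.
    x = x.clamp(-3, 3)
    x2 = x * x
    x4 = x2 * x2
    x6 = x4 * x2
    return 0.0017 + 0.5 * x + 0.2423 * x2 - 0.0153 * x4 + 0.00057 * x6


x = torch.linspace(-3.0, 3.0, steps=1001)
err = (silu_approx(x) - F.silu(x)).abs()
print(f"max abs error on [-3, 3]: {err.max().item():.4f}")  # a few 1e-3 for these coefficients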