160 changes: 148 additions & 12 deletions examples/apple/coreml/llama/llama_transformer.py
@@ -33,6 +33,13 @@ def find_multiple(n: int, k: int) -> int:
return n
return n + k - (n % k)

def silu_approx(x):
x = x.clamp(-3, 3)
x2 = x * x
x4 = x2 * x2
x6 = x4 * x2
res = 0.0017 + 0.5 * x + 0.2423 * x2 - 0.0153 * x4 + 0.00057 * x6
return res
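
Aside: a quick way to sanity-check the polynomial against the exact SiLU on the clamped range (illustrative only, not part of this diff; the expected error level is an estimate):

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=1001)
max_err = (F.silu(x) - silu_approx(x)).abs().max()
print(max_err)  # roughly on the order of 1e-3 with these coefficients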

@dataclass
class ModelArgs:
@@ -108,6 +115,15 @@ def __post_init__(self):
if self.head_dim is None:
self.head_dim = self.dim // self.n_heads

def rms_norm_fp16_stable(x, eps=1e-5, min_scale=1e-3):
amax = x.abs().amax(dim=-1, keepdim=True)
scale = amax.clamp(min=min_scale)
x_scaled = x / scale

var = torch.square(x_scaled).mean(dim=-1, keepdim=True)
rms = torch.sqrt(var + eps)
y = x_scaled / rms
return y
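
Aside: a small illustration of why the amax pre-scaling matters in half precision (illustrative only; the input values are made up):

import torch

x = torch.full((1, 8), 300.0, dtype=torch.float16)
print(torch.square(x).mean(dim=-1))  # inf: 300**2 overflows fp16
print(rms_norm_fp16_stable(x))       # finite: values are pre-scaled to ~1 before squaring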

class CoreMLRMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
@@ -146,10 +162,12 @@ def _norm(self, x):
# In future, we want to add CoreML support for the functional RMSNorm op
# We have yet to do large scale evaluations on the numeric stability of this solution, but note that
# it appears better than what exists currently (removing FP32 casts and using FP16)

norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True)
rms_norm_eps0 = (
x
* torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
* torch.reciprocal(torch.linalg.vector_norm(x, dim=-1, keepdim=True))
* (torch.sqrt(torch.tensor(self.dim, dtype=x.dtype)) / norm)
# * torch.reciprocal(torch.linalg.vector_norm(x, dim=-1, keepdim=True))
)
return rms_norm_eps0
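
Aside: the rewrite relies on sqrt(dim) / ||x|| being equal to 1 / rms(x); a quick check of that identity (illustrative only, not part of this diff):

import torch

x = torch.randn(2, 4, 64)
norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True)
lhs = x * (64 ** 0.5) / norm
rhs = x / torch.sqrt(torch.square(x).mean(dim=-1, keepdim=True))
print(torch.allclose(lhs, rhs, atol=1e-6))  # True (eps is ignored on both sides)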

@@ -167,6 +185,10 @@ def forward(self, x):
output = self._norm(x)
return output * self.weight

_RMS_NORM = CoreMLRMSNorm
_DECOMPOSE_SDPA = True
_USE_SOFTMAX = True
_USE_SILU_APPROX = False
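
Aside: these module-level flags act as export-time switches read inside forward(); one hypothetical way to flip them before building the model (the import path is assumed, adjust to your checkout):

from examples.apple.coreml.llama import llama_transformer as lt

lt._DECOMPOSE_SDPA = True    # manual matmul + softmax instead of the fused SDPA op
lt._USE_SOFTMAX = False      # clamped exp/sum variant instead of torch.softmax
lt._USE_SILU_APPROX = True   # polynomial SiLU inside FeedForward
model = lt.Transformer(lt.ModelArgs())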

class Rope(torch.nn.Module):
def __init__(self, params: ModelArgs):
@@ -249,7 +271,15 @@ def __init__(self, args: ModelArgs):
self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)

def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
t1 = self.w1(x)
if _USE_SILU_APPROX:
t1 = silu_approx(t1)
else:
t1 = F.silu(t1)
t2 = self.w3(x)
out = t1 * t2
out = self.w2(out)
return out


class ConditionalFeedForward(nn.Module):
@@ -327,8 +357,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
if self.use_qk_norm:
q_norm_dim = self.head_dim
k_norm_dim = self.head_dim
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
self.q_norm_fn = _RMS_NORM(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = _RMS_NORM(k_norm_dim, eps=args.norm_eps)

def forward(
self,
@@ -369,9 +399,40 @@ def forward(
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)

output = torch.ops.aten.scaled_dot_product_attention.default(
q, k, v, attn_mask=attn_mask
)
if not _DECOMPOSE_SDPA:
output = torch.ops.aten.scaled_dot_product_attention.default(
q, k, v, attn_mask=attn_mask
)
else:

# ------------------------------
# Manual SDPA: matmuls + softmax
# q: (B, H, T_q, D)
# k: (B, H, T_k, D)
# v: (B, H, T_k, D)
# attn_mask: broadcastable to (B, H, T_q, T_k)
# ------------------------------
d = q.size(-1)
# (B, H, T_q, T_k)
scores = torch.matmul(q, k.transpose(-2, -1)) / (d ** 0.5)

if attn_mask is not None:
# attn_mask is already used this way with SDPA, keep same semantics:
# 0.0 for allowed, -inf for disallowed, added to scores.
scores = scores + attn_mask

if _USE_SOFTMAX:
# (B, H, T_q, T_k)
attn_weights = torch.softmax(scores, dim=-1)
else:
scores = scores.clamp(min=-60.0, max=60.0)
scores_max, _ = scores.max(dim=-1, keepdim=True) # (B, H, T_q, 1)
scores_exp = torch.exp(scores - scores_max)
attn_weights = scores_exp / scores_exp.sum(dim=-1, keepdim=True)

# (B, H, T_q, D)
output = torch.matmul(attn_weights, v)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
output = self.wo(output)
return output, new_k, new_v
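
Aside: a standalone check that the decomposed matmul + softmax path matches the fused SDPA op it replaces (illustrative sketch; shapes and the zero mask are made up):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 32, 64)
v = torch.randn(1, 8, 32, 64)
mask = torch.zeros(16, 32)  # 0.0 = allowed, -inf = masked

ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
scores = torch.matmul(q, k.transpose(-2, -1)) / (64 ** 0.5) + mask
manual = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(ref, manual, atol=1e-5))  # True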
@@ -388,8 +449,8 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
self.block_sparse_moe = MOEFeedForward(args)
else:
self.feed_forward = FeedForward(args)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.attention_norm = _RMS_NORM(args.dim, eps=args.norm_eps)
self.ffn_norm = _RMS_NORM(args.dim, eps=args.norm_eps)

def forward(
self,
@@ -406,9 +467,84 @@ def forward(
)

h = x + h
out = h + self.feed_forward(self.ffn_norm(h))
tmp = self.feed_forward(self.ffn_norm(h))
out = h + tmp
return out, new_k, new_v

class AttentionBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.head_dim
self.attention = Attention(args, layer_id, rope)
self.attention_norm = _RMS_NORM(args.dim, eps=args.norm_eps)

def forward(
self,
x,
freqs_cos,
freqs_sin,
k_cache,
v_cache,
attn_mask,
): # x: 1xN
norm_emb = self.attention_norm(x)
h, new_k, new_v = self.attention.forward(
norm_emb, freqs_cos, freqs_sin, k_cache, v_cache, attn_mask
)
h = x + h
return h, new_k, new_v


class FeedForwardBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.head_dim
if args.moe:
self.block_sparse_moe = MOEFeedForward(args)
else:
self.feed_forward = FeedForward(args)
self.ffn_norm = _RMS_NORM(args.dim, eps=args.norm_eps)

def forward(
self,
h,
): # x: 1xN
tmp = self.feed_forward(self.ffn_norm(h))
out = h + tmp
return out



class InputBlock(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.rope = Rope(params)
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

def forward(self, tokens: torch.LongTensor, input_pos: torch.LongTensor):
h = self.tok_embeddings(tokens)
seqlen = h.shape[1]
freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
return h, freqs_cos, freqs_sin

class OutputBlock(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.generate_full_logits = params.generate_full_logits
self.norm = _RMS_NORM(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

def forward(self, h, input_length: torch.LongTensor):
if not self.generate_full_logits:
# Only the last logit is used for the new generated token
h = h[:, input_length - 1, :].squeeze(1)
h = self.norm(h)
logits = self.output(h)
return logits
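
Aside: these blocks read like a split of TransformerBlock into separately exportable pieces; one plausible way they compose end to end (hypothetical wiring and names, not shown in this diff):

h, freqs_cos, freqs_sin = input_block(tokens, input_pos)
for attn_block, ffn_block in zip(attn_blocks, ffn_blocks):
    # per-layer KV-cache plumbing omitted for brevity
    h, new_k, new_v = attn_block(h, freqs_cos, freqs_sin, k_cache, v_cache, attn_mask)
    h = ffn_block(h)
logits = output_block(h, input_length)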

class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
@@ -422,7 +558,7 @@ def __init__(self, params: ModelArgs):
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params, self.rope))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.norm = _RMS_NORM(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.generate_full_logits = params.generate_full_logits
self.max_seq_len = params.max_seq_len
8 changes: 6 additions & 2 deletions examples/apple/coreml/llama/utils.py
@@ -16,6 +16,7 @@ def __init__(
out_max_splits=1,
in_target_split_size=1,
in_max_splits=1,
fqn_filter=None,
):
super(SplitLinearModule, self).__init__()
self.out_split_sizes = self._get_split_sizes(
@@ -91,10 +92,11 @@ def forward(self, x):


def replace_linear_with_split_linear(
model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1
model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1, fqn_filter=None,
):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
should_split = isinstance(module, torch.nn.Linear) and (fqn_filter is None or fqn_filter(name))
if should_split:
assert module.bias is None, "SplitLinearModule does not support bias"
new_module = SplitLinearModule(
module.in_features,
@@ -103,6 +105,7 @@ def replace_linear_with_split_linear(
out_max_splits,
in_target_split_size,
in_max_splits,
fqn_filter,
)
new_module.set_params(module.weight)
setattr(model, name, new_module)
@@ -113,4 +116,5 @@ def replace_linear_with_split_linear(
out_max_splits,
in_target_split_size,
in_max_splits,
fqn_filter,
)
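
Aside: with the new filter hook, callers can restrict which Linear layers get split; a hypothetical usage example (the filter logic and split sizes are made up):

def my_fqn_filter(name: str) -> bool:
    # split projection layers but leave the LM head intact
    return name != "output"

replace_linear_with_split_linear(
    model,
    out_target_split_size=1024,
    out_max_splits=8,
    in_target_split_size=1024,
    in_max_splits=1,
    fqn_filter=my_fqn_filter,
)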
7 changes: 6 additions & 1 deletion examples/models/llama/static_attention.py
@@ -1036,7 +1036,12 @@ def _forward_mha(
if masks:
cache_len = k.size(-2) - seq_len
mask = masks[cache_len]
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
# y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
attn = q @ k.transpose(-2, -1)
attn = attn * self.inv_scale
attn = attn + mask
attn = F.softmax(attn, dim=-1)
y = attn @ v

return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state
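
Aside: assuming self.inv_scale is 1 / sqrt(head_dim), as the name suggests (its definition is outside this hunk), the decomposed path should be numerically equivalent to the SDPA call it replaces:

import torch
import torch.nn.functional as F

q, k, v = torch.randn(3, 1, 8, 4, 64).unbind(0)
mask = torch.zeros(4, 4)
inv_scale = 1.0 / (64 ** 0.5)

ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
manual = F.softmax(q @ k.transpose(-2, -1) * inv_scale + mask, dim=-1) @ v
print(torch.allclose(ref, manual, atol=1e-5))  # True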
