Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion training/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface
LDFLAGS = $(FRAMEWORKS) -ldl

HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h
HEADERS_OPT = $(HEADERS_LARGE) stories_cpu_ops_opt.h

HEADERS_ANE = $(HEADERS_LARGE) ane_rmsnorm_bwd.h ane_classifier.h

Expand All @@ -16,6 +17,9 @@ train_large: train_large.m $(HEADERS_LARGE)
train_large_ane: train_large_ane.m $(HEADERS_ANE)
$(CC) $(CFLAGS) -o $@ train_large_ane.m $(LDFLAGS) -framework Accelerate

train_opt: train_opt.m $(HEADERS_OPT)
$(CC) $(CFLAGS) -o $@ train_opt.m $(LDFLAGS) -framework Accelerate -framework Metal -framework MetalPerformanceShaders

PROBES = test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced

test_rmsnorm_bwd: test_rmsnorm_bwd.m $(HEADERS_ANE)
Expand All @@ -42,7 +46,7 @@ tokenize:
python3 tokenize.py

clean:
rm -f train train_large train_large_ane $(PROBES) test_rmsnorm_bwd test_classifier
rm -f train train_large train_large_ane train_opt $(PROBES) test_rmsnorm_bwd test_classifier

.PHONY: clean tokenize probes

110 changes: 110 additions & 0 deletions training/stories_cpu_ops_opt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// stories_cpu_ops_opt.h — Optimized CPU operations: NEON Adam, vectorized embedding
#pragma once
#include "stories_cpu_ops.h"
#include <arm_neon.h>

// ===== NEON-vectorized Adam optimizer =====
// ~3-3.5x faster than scalar version for large param counts
// Uses vrsqrteq_f32 + one Newton-Raphson step for fast reciprocal sqrt
static void adam_update_opt(float *w, const float *g, AdamState *s, int t,
float lr, float b1, float b2, float eps) {
float bc1 = 1.0f - powf(b1, t);
float bc2 = 1.0f - powf(b2, t);
float inv_bc1 = 1.0f / bc1;
float inv_bc2 = 1.0f / bc2;
float one_minus_b1 = 1.0f - b1;
float one_minus_b2 = 1.0f - b2;

float32x4_t vb1 = vdupq_n_f32(b1);
float32x4_t vb2 = vdupq_n_f32(b2);
float32x4_t v1mb1 = vdupq_n_f32(one_minus_b1);
float32x4_t v1mb2 = vdupq_n_f32(one_minus_b2);
float32x4_t vinv_bc1 = vdupq_n_f32(inv_bc1);
float32x4_t vinv_bc2 = vdupq_n_f32(inv_bc2);
float32x4_t vneg_lr = vdupq_n_f32(-lr);
float32x4_t veps = vdupq_n_f32(eps);

size_t n = s->n;
size_t i = 0;

// Process 4 elements at a time
for (; i + 3 < n; i += 4) {
// Load
float32x4_t vm = vld1q_f32(s->m + i);
float32x4_t vv = vld1q_f32(s->v + i);
float32x4_t vg = vld1q_f32(g + i);
float32x4_t vw = vld1q_f32(w + i);

// m = b1*m + (1-b1)*g
vm = vmlaq_f32(vmulq_f32(vb1, vm), v1mb1, vg);
// v = b2*v + (1-b2)*g*g
float32x4_t g2 = vmulq_f32(vg, vg);
vv = vmlaq_f32(vmulq_f32(vb2, vv), v1mb2, g2);

// Store updated m, v
vst1q_f32(s->m + i, vm);
vst1q_f32(s->v + i, vv);

// mhat = m / bc1, vhat = v / bc2
float32x4_t mhat = vmulq_f32(vm, vinv_bc1);
float32x4_t vhat = vmulq_f32(vv, vinv_bc2);

// Fast reciprocal sqrt: vrsqrteq + one Newton-Raphson iteration
// rsqrt_est ≈ 1/sqrt(vhat)
float32x4_t rsqrt_est = vrsqrteq_f32(vhat);
// Newton-Raphson: rsqrt *= (3 - vhat * rsqrt^2) / 2
float32x4_t rsqrt_sq = vmulq_f32(rsqrt_est, rsqrt_est);
float32x4_t nr_step = vrsqrtsq_f32(vhat, rsqrt_sq);
rsqrt_est = vmulq_f32(rsqrt_est, nr_step);

// w -= lr * mhat / (sqrt(vhat) + eps)
// = w + (-lr) * mhat * (1/(sqrt(vhat) + eps))
// Compute sqrt(vhat) from rsqrt: sqrt = vhat * rsqrt(vhat) (avoids division)
float32x4_t sqrt_vhat = vmulq_f32(vhat, rsqrt_est);
float32x4_t denom = vaddq_f32(sqrt_vhat, veps);

// Use vdivq_f32 for the final division (accurate, eps-adjusted)
float32x4_t update = vmulq_f32(vneg_lr, vdivq_f32(mhat, denom));
vw = vaddq_f32(vw, update);

vst1q_f32(w + i, vw);
}

// Scalar tail
for (; i < n; i++) {
s->m[i] = b1 * s->m[i] + one_minus_b1 * g[i];
s->v[i] = b2 * s->v[i] + one_minus_b2 * g[i] * g[i];
float mh = s->m[i] * inv_bc1;
float vh = s->v[i] * inv_bc2;
w[i] -= lr * mh / (sqrtf(vh) + eps);
}
}

// ===== Vectorized embedding lookup =====
// Gather rows from [VOCAB, DIM] row-major embed table → x [DIM, SEQ] channel-first
// Strategy: gather token rows into temp buffer [SEQ, DIM], then transpose via vDSP_mtrans
static void embed_lookup_opt(float *x, const float *embed, const uint16_t *tokens,
int dim, int seq, float *tmp) {
// Gather: tmp[t*dim + d] = embed[tokens[t]*dim + d]
for (int t = 0; t < seq; t++) {
memcpy(tmp + t * dim, embed + tokens[t] * dim, dim * sizeof(float));
}
// Transpose [SEQ, DIM] → [DIM, SEQ]: x[d*seq + t] = tmp[t*dim + d]
vDSP_mtrans(tmp, 1, x, 1, (vDSP_Length)dim, (vDSP_Length)seq);
}

// ===== Vectorized embedding backward =====
// Accumulate dE[tok] += dx[:,t] for each position
// Strategy: transpose dx [DIM, SEQ] → tmp [SEQ, DIM], then accumulate rows
static void embed_backward_opt(float *d_embed, const float *dx, const uint16_t *tokens,
int dim, int seq, float *tmp) {
// Transpose [DIM, SEQ] → [SEQ, DIM]: tmp[t*dim + d] = dx[d*seq + t]
vDSP_mtrans(dx, 1, tmp, 1, (vDSP_Length)seq, (vDSP_Length)dim);
// Scatter-add: d_embed[tok*dim .. (tok+1)*dim] += tmp[t*dim .. (t+1)*dim]
for (int t = 0; t < seq; t++) {
vDSP_vadd(tmp + t * dim, 1,
d_embed + tokens[t] * dim, 1,
d_embed + tokens[t] * dim, 1,
(vDSP_Length)dim);
}
}
6 changes: 6 additions & 0 deletions training/stories_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int
cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, data, channels * sp);
IOSurfaceUnlock(s, 0, NULL);
}
// Read raw fp16 from IOSurface without conversion (for fp16 activation cache)
static void io_read_raw_fp16(IOSurfaceRef s, _Float16 *data, int ch_off, int channels, int sp) {
IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
memcpy(data, (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, channels * sp * sizeof(_Float16));
IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}

// Kernel compile/eval
static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_bytes, int oc_bytes) {
Expand Down
Loading