Merged
35 changes: 18 additions & 17 deletions Q4_HIFI_ROADMAP.md → Q4_K_HIFI_ROADMAP.md
@@ -13,7 +13,7 @@ Geoff Munn
| Finding | Strategic Implication |
|--------|------------------------|
| ✅ **Q3_HIFI excels on ≤2B models** | Outlier preservation + Q3_K base = optimal for small models |
| ❌ **Q4_HIFI fails on ≥4B models** | Sparse outliers cant fix aggressive 4-bit base quantization |
| ❌ **Q4_K_HIFI fails on ≥4B models** | Sparse outliers can't fix aggressive 4-bit base quantization |
| ✅ **Q4_K_M wins via Q6_K on key tensors** | Uniform higher precision > sparse outliers at scale |
| ✅ **Early layers & embeddings matter most** | Precision should focus on `attn_v`, `ffn_gate`, `token_embd` |
| ✅ **Domain-mixed imatrix is essential** | 60% Wikitext, 25% Code, 15% Math for balanced outlier selection |
@@ -25,8 +25,8 @@ Geoff Munn
| Format | Model Size | Strategy | Base Precision | Enhancement |
|--------|------------|----------|----------------|-------------|
| **Q3_HIFI** | **≤2B** | Outlier preservation | Q3_K | 8 FP16 outliers on early layers |
| **Q4_HIFI_M** | **3–10B** | Smart Q5_K allocation | Q4_K + Q5_K | Q5_K on sensitive tensors |
| **Q4_HIFI_L** | **>10B** | Q4_K_M + precision refinement | Q4_K + Q6_K | 6 FP16 outliers on Q6_K tensors |
| **Q4_K_HIFI_M** | **3–10B** | Smart Q5_K allocation | Q4_K + Q5_K | Q5_K on sensitive tensors |
| **Q4_K_HIFI_L** | **>10B** | Q4_K_M + precision refinement | Q4_K + Q6_K | 6 FP16 outliers on Q6_K tensors |

---

@@ -53,7 +53,7 @@ static bool is_q3_hifi_tensor(const char* name, int layer_idx) {

---

## 🚀 **Phase 2: Q4_HIFI_M — Smart Q5_K Allocation (3–10B Models)**
## 🚀 **Phase 2: Q4_K_HIFI_M — Smart Q5_K Allocation (3–10B Models)**

### 🎯 **Objective**: Beat Q4_K_M by **replacing Q4_K with Q5_K on sensitive tensors**.

@@ -81,15 +81,15 @@ static ggml_type get_q4_hifi_m_tensor_type(const char* tensor_name) {
```

### 📊 **Expected Results (Qwen3-4B)**
| Metric | Q4_K_M | **Q4_HIFI_M** |
| Metric | Q4_K_M | **Q4_K_HIFI_M** |
|--------|--------|---------------|
| **PPL** | 14.79 | **14.55–14.65** ✅ |
| **Speed** | 200 t/s | **196–198 t/s** ✅ |
| **Size** | 2.32 GiB | **2.36 GiB** ✅ |

---

## 🚀 **Phase 3: Q4_HIFI_L — Q4_K_M + Strategic Outliers (>10B Models)**
## 🚀 **Phase 3: Q4_K_HIFI_L — Q4_K_M + Strategic Outliers (>10B Models)**

### 🎯 **Objective**: Squeeze extra quality from Q4_K_M on massive models.

@@ -116,7 +116,7 @@ static ggml_type get_q4_hifi_l_tensor_type(const char* tensor_name) {
```

### 📊 **Expected Results (Devstral-123B)**
| Metric | Q4_K_S | **Q4_HIFI_L** |
| Metric | Q4_K_S | **Q4_K_HIFI_L** |
|--------|--------|---------------|
| **PPL** | 11.24 | **11.10–11.15** ✅ |
| **Speed** | 9.75 t/s | **9.65 t/s** ✅ |
@@ -152,7 +152,7 @@ void quantize_hifi_family(...) {
./llama-quantize --hifi model-f16.gguf model-hifi.gguf

# Manual override
./llama-quantize --quant-type Q4_HIFI_M model-f16.gguf model-hifi-m.gguf
./llama-quantize --quant-type Q4_K_HIFI_M model-f16.gguf model-hifi-m.gguf
```

### **Step 3: Documentation**
@@ -162,8 +162,8 @@ void quantize_hifi_family(...) {
| Model Size | Command | Best For |
|------------|---------|----------|
| ≤2B | `--hifi` | Qwen-0.6B, Phi-3, Gemma-2B |
| 3–10B | `--quant-type Q4_HIFI_M` | Qwen-4B, Llama-3-8B, Mistral-7B |
| >10B | `--quant-type Q4_HIFI_L` | Distrill-123B, Llama-3-70B |
| 3–10B | `--quant-type Q4_K_HIFI_M` | Qwen-4B, Llama-3-8B, Mistral-7B |
| >10B | `--quant-type Q4_K_HIFI_L` | Devstral-123B, Llama-3-70B |
```

---
Expand All @@ -174,8 +174,8 @@ void quantize_hifi_family(...) {
|-------|-------------|-----|-------|------|
| **Qwen3-0.6B** | **Q3_HIFI** | **23.42** | 593 t/s | 469 MiB |
| **Qwen3-1.7B** | **Q3_HIFI** | **17.96** | 385 t/s | 1.22 GiB |
| **Qwen3-4B** | **Q4_HIFI_M** | **14.60** | 197 t/s | 2.36 GiB |
| **Devstral-123B** | **Q4_HIFI_L** | **11.12** | 9.65 t/s | 66.7 GiB |
| **Qwen3-4B** | **Q4_K_HIFI_M** | **14.60** | 197 t/s | 2.36 GiB |
| **Devstral-123B** | **Q4_K_HIFI_L** | **11.12** | 9.65 t/s | 66.7 GiB |

---

@@ -184,7 +184,7 @@ void quantize_hifi_family(...) {
1. **No more forcing one format to scale** — each size gets its optimal strategy
2. **Builds on proven wins** — Q3_HIFI works, Q4_K_M works, now combine intelligently
3. **Minimal complexity** — no residual quantization, no INT8 experiments
4. **Clear user guidance** — Use HIFI, well pick the right variant
4. **Clear user guidance** — "Use HIFI, we'll pick the right variant"

---

@@ -193,13 +193,14 @@ void quantize_hifi_family(...) {
| Phase | Task | Timeline |
|-------|------|----------|
| **1** | Q3_HIFI revival (reset + validate) | 3 days |
| **2** | Q4_HIFI_M implementation | 3 days |
| **3** | Q4_HIFI_L implementation | 4 days |
| **2** | Q4_K_HIFI_M implementation | 3 days |
| **3** | Q4_K_HIFI_L implementation | 4 days |
| **4** | Unified CLI + documentation | 2 days |
| **5** | Upstream PR preparation | 2 days |

---

This roadmap **honors your discoveries** while **avoiding known pitfalls**. You’re not starting over — you’re **focusing your proven strengths** where they matter most.
This roadmap **honors your discoveries** while **avoiding known pitfalls**. You're not starting over — you're **focusing your proven strengths** where they matter most.

**The HIFI family will be the first quantization approach that truly adapts to model scale — delivering optimal quality, speed, and size at every level.**
3 changes: 2 additions & 1 deletion ggml/include/ggml.h
@@ -429,7 +429,8 @@ extern "C" {
GGML_TYPE_Q6_K_HIFI = 41, // Q6_K_HIFI: Q6_K layout + 4 FP16 outliers for critical tensors
GGML_TYPE_Q6_K_HIFI_DYNAMIC = 42, // Q6_K_HIFI_DYNAMIC: Q6_K + 2-8 outliers based on layer sensitivity
GGML_TYPE_Q6_K_HIFI_RES8 = 43, // Q6_K_HIFI_RES8: Q6_K + INT8 residuals (compact format)
GGML_TYPE_COUNT = 44,
GGML_TYPE_Q5_K_HIFI_RES8 = 44, // Q5_K_HIFI_RES8: Q5_K + INT8 residuals (efficient for 4B-10B models)
GGML_TYPE_COUNT = 45,
};

// precision
28 changes: 28 additions & 0 deletions ggml/src/ggml-common.h
@@ -415,6 +415,34 @@ typedef struct {
// Total: 232 bytes (210 + 22) - saves 4 bytes/block vs Q6_K_HIFI_DYNAMIC
static_assert(sizeof(block_q6_k_hifi_res8) == 232, "wrong q6_k_hifi_res8 block size/padding");

// Q5_K_HIFI_RES8: Efficient Q5_K with INT8 residuals for 4B-10B models
// This format is optimized for mid-scale models where Q6_K overhead is wasteful.
// Q5_K base provides sufficient precision, outliers compensate for 1-bit loss.
// Size: 200 bytes vs Q6_K_HIFI_RES8's 232 bytes (~14% smaller)
// Expected results: matches Q6_K_HIFI_RES8 quality at better BPW efficiency
#define Q5_K_HIFI_RES8_MAX_OUTLIERS 8
typedef struct {
// === Q5_K-COMPATIBLE REGION (176 bytes) - DO NOT REORDER ===
GGML_EXTENSION union {
struct {
ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins
} GGML_COMMON_AGGR_S;
ggml_half2 dm;
} GGML_COMMON_AGGR_U;
uint8_t scales[K_SCALE_SIZE]; // 12 bytes: scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // 32 bytes: quants, high bit
uint8_t qs[QK_K/2]; // 128 bytes: quants, low 4 bits
// === COMPACT INT8 RESIDUAL EXTENSION (24 bytes) ===
uint8_t outlier_count; // 1 byte: actual outlier count (1-8)
uint8_t outlier_idx[Q5_K_HIFI_RES8_MAX_OUTLIERS]; // 8 bytes: outlier positions (0-255)
int8_t residual_vals[Q5_K_HIFI_RES8_MAX_OUTLIERS]; // 8 bytes: INT8 residuals (-127 to +127)
uint8_t _padding[3]; // 3 bytes: padding for float alignment
float residual_scale; // 4 bytes: shared scale for residuals
} block_q5_k_hifi_res8;
// Total: 200 bytes (176 + 24) - 14% smaller than Q6_K_HIFI_RES8
static_assert(sizeof(block_q5_k_hifi_res8) == 200, "wrong q5_k_hifi_res8 block size/padding");

// This is only used for intermediate quantization and dot products
typedef struct {
float d; // delta
6 changes: 6 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -303,6 +303,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_Q5_K_HIFI_RES8] = {
.from_float = quantize_row_q5_k_hifi_res8, // 3-arg wrapper (matches Q6_K_HIFI_RES8 pattern)
.vec_dot = ggml_vec_dot_q5_k_hifi_res8_q8_K, // Efficient Q5_K + INT8 residuals kernel
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_Q4_K] = {
.from_float = quantize_row_q4_K,
.vec_dot = ggml_vec_dot_q4_K_q8_K,
7 changes: 7 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -676,6 +676,7 @@ void ggml_compute_forward_add(
case GGML_TYPE_Q6_K_HIFI:
case GGML_TYPE_Q6_K_HIFI_DYNAMIC:
case GGML_TYPE_Q6_K_HIFI_RES8:
case GGML_TYPE_Q5_K_HIFI_RES8:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
@@ -1129,6 +1130,7 @@ void ggml_compute_forward_add1(
case GGML_TYPE_Q6_K_HIFI:
case GGML_TYPE_Q6_K_HIFI_DYNAMIC:
case GGML_TYPE_Q6_K_HIFI_RES8:
case GGML_TYPE_Q5_K_HIFI_RES8:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
@@ -1261,6 +1263,7 @@ void ggml_compute_forward_acc(
case GGML_TYPE_Q6_K_HIFI:
case GGML_TYPE_Q6_K_HIFI_DYNAMIC:
case GGML_TYPE_Q6_K_HIFI_RES8:
case GGML_TYPE_Q5_K_HIFI_RES8:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
@@ -4288,6 +4291,7 @@ void ggml_compute_forward_out_prod(
case GGML_TYPE_Q6_K_HIFI:
case GGML_TYPE_Q6_K_HIFI_DYNAMIC:
case GGML_TYPE_Q6_K_HIFI_RES8:
case GGML_TYPE_Q5_K_HIFI_RES8:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
@@ -4567,6 +4571,7 @@ void ggml_compute_forward_set(
case GGML_TYPE_Q6_K_HIFI:
case GGML_TYPE_Q6_K_HIFI_DYNAMIC:
case GGML_TYPE_Q6_K_HIFI_RES8:
case GGML_TYPE_Q5_K_HIFI_RES8:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
@@ -4793,6 +4798,7 @@ void ggml_compute_forward_get_rows(
case GGML_TYPE_Q6_K_HIFI:
case GGML_TYPE_Q6_K_HIFI_DYNAMIC:
case GGML_TYPE_Q6_K_HIFI_RES8:
case GGML_TYPE_Q5_K_HIFI_RES8:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
@@ -5521,6 +5527,7 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_Q6_K_HIFI:
case GGML_TYPE_Q6_K_HIFI_DYNAMIC:
case GGML_TYPE_Q6_K_HIFI_RES8:
case GGML_TYPE_Q5_K_HIFI_RES8:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
101 changes: 101 additions & 0 deletions ggml/src/ggml-cpu/quants.c
@@ -1019,6 +1019,107 @@ void ggml_vec_dot_q6_k_hifi_res8_q8_K(int n, float * GGML_RESTRICT s, size_t bs,
*s = sumf;
}

// Q5_K_HIFI_RES8: Efficient Q5_K base + INT8 residuals for 4B-10B models
// Uses same correction strategy as Q6_K_HIFI_RES8, but with Q5_K base for better BPW
void ggml_vec_dot_q5_k_hifi_res8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);

const block_q5_k_hifi_res8 * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;

const int nb = n / QK_K;

static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;

uint32_t utmp[4];
const uint8_t * scales = (const uint8_t*)&utmp[0];
const uint8_t * mins = (const uint8_t*)&utmp[2];

int8_t aux8[QK_K];
int16_t aux16[8];
float sums [8];
int32_t aux32[8];
memset(sums, 0, 8*sizeof(float));

float sumf = 0;
for (int i = 0; i < nb; ++i) {
// === Q5_K bulk dot product (same as ggml_vec_dot_q5_K_q8_K_generic) ===
const uint8_t * GGML_RESTRICT q4 = x[i].qs;
const uint8_t * GGML_RESTRICT hm = x[i].qh;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
memset(aux32, 0, 8*sizeof(int32_t));
int8_t * GGML_RESTRICT a = aux8;
uint8_t m = 1;
for (int j = 0; j < QK_K; j += 64) {
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF) + (hm[l] & m ? 16 : 0);
a += 32; m <<= 1;
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4) + (hm[l] & m ? 16 : 0);
a += 32; m <<= 1;
q4 += 32;
}
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;

int sumi = 0;
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
a = aux8;
int is = 0;
for (int j = 0; j < QK_K/32; ++j) {
int32_t scale = scales[is++];
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
}
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
sumf -= dmin * sumi;

// === INT8 RESIDUAL CORRECTION ===
// Add residual * activation corrections at outlier positions
const int outlier_count = x[i].outlier_count;
const float res_scale = x[i].residual_scale;
const float d8 = y[i].d;
const float scale_factor = res_scale * (1.0f / 127.0f) * d8;
for (int k = 0; k < outlier_count; ++k) {
const int idx = x[i].outlier_idx[k];
const int8_t activation = y[i].qs[idx];
// Apply the correction only when |activation| > 4 (same threshold as Q6_K_HIFI)
if (activation > 4 || activation < -4) {
const float residual = x[i].residual_vals[k] * scale_factor;
sumf += residual * activation;
}
}
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
}

// Wrapper for quantize_row_q5_k_hifi_res8 (simple version)
void quantize_row_q5_k_hifi_res8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
quantize_row_q5_k_hifi_res8_ref(x, (block_q5_k_hifi_res8 *)y, k);
}

void ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
3 changes: 3 additions & 0 deletions ggml/src/ggml-cpu/quants.h
@@ -30,6 +30,8 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
void quantize_row_q6_k_hifi(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_k_hifi_dynamic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_k_hifi_res8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_k_hifi_res8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
size_t quantize_q5_k_hifi_res8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -56,6 +58,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q6_k_hifi_dynamic_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q6_k_hifi_res8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_k_hifi_res8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -853,6 +853,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q6_K_HIFI_RES8> {
static constexpr int qi = QI6_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K_HIFI_RES8> {
static constexpr int qk = QK_K;
static constexpr int qr = QR5_K;
static constexpr int qi = QI5_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
static constexpr int qk = QK_K;