Hey! Thanks for releasing your work!
I was wondering if you looked at k-means specifically for 1D data. As I understand it, the 1D case can be solved exactly, so you get the globally optimal centroids rather than a local optimum from a random init; I thought that might be interesting here.
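For context, here's a minimal sketch of the exact 1D solver via the `kmeans1d` package (this only uses the upstream `cluster(array, k)` API; the fork referenced below additionally accepts per-point weights, as in the diff):

```python
import kmeans1d

# Toy 1D data with a few obvious groups and an outlier.
x = [4.0, 4.1, 4.2, -50.0, 200.2, 200.4, 200.9, 80.0]
k = 4

# kmeans1d solves 1D k-means exactly with dynamic programming, so the
# centroids are globally optimal and deterministic (no init sensitivity).
clusters, centroids = kmeans1d.cluster(x, k)
print(clusters)   # cluster index for each input point
print(centroids)  # the k optimal centroids, in ascending order
```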
Ran some tests with Llama 3.1 8B Instruct (models are on huggingface):
| Model | Bits | ArcC | ArcE | MMLU STEM | MMLU Human. | MMLU Social | MMLU Other | Avg |
|---|---|---|---|---|---|---|---|---|
| float16 | 16 | 51.62 | 81.86 | 58.61 | 64.42 | 75.88 | 74.31 | 67.78 |
| bfloat16 | 16 | 51.79 | 81.86 | 58.64 | 64.25 | 76.86 | 74.18 | 67.93 |
| sklearn | 4.05 | 51.02 | 81.52 | 56.58 | 60.31 | 75.88 | 72.96 | 66.37 |
| kmeans1d | 4.05 | 52.04 | 82.36 | 57.24 | 62.23 | 75.33 | 73.09 | 67.04 |
| sklearn | 3.02 | 46.67 | 78.74 | 52.39 | 54.79 | 71.01 | 70.09 | 62.28 |
| kmeans1d | 3.02 | 42.49 | 74.53 | 53.82 | 55.70 | 69.48 | 68.13 | 60.69 |
Curious what you make of it.
The change is minimal. Install `pip install git+https://github.com/smpanaro/kmeans1d@master` (credit to apple/coremltools) and apply this diff:
```diff
diff --git a/lean_quantizer.py b/lean_quantizer.py
index a860f64..250beb1 100644
--- a/lean_quantizer.py
+++ b/lean_quantizer.py
@@ -6,6 +6,7 @@ import numpy as np
 from sklearn.cluster import KMeans
 from multiprocessing import Pool
 from tqdm import tqdm
+import kmeans1d
 import torch
 import torch.nn as nn
@@ -14,13 +15,18 @@ import transformers
 from quant import *
-DEBUG = False
+DEBUG = False
+USE_KMEANS1D = True
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
 def kmeans_fit(row_data):
     weights_np, sample_weight, n_cluster, random_seed = row_data
+    if USE_KMEANS1D:
+        _, centroids = kmeans1d.cluster(weights_np, n_cluster, weights=sample_weight)
+        return np.array(centroids, dtype=np.float32)
+
     kmeans = KMeans(
         n_clusters=n_cluster,
         init=np.linspace(weights_np.min(), weights_np.max(), num=n_cluster)[:, None] if n_cluster <= 8 else 'k-means++',
```
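If anyone wants to try it, here's a quick smoke test of the patched path (a sketch only: `kmeans_fit` and its tuple argument come from the diff above, while the array sizes and cluster count are just illustrative placeholders):

```python
import numpy as np
from lean_quantizer import kmeans_fit  # module patched by the diff above

# Fake a single weight row with per-element importance, matching the
# (weights_np, sample_weight, n_cluster, random_seed) tuple kmeans_fit unpacks.
weights_np = np.random.randn(4096).astype(np.float32)
sample_weight = np.abs(np.random.randn(4096)).astype(np.float32)

# 16 clusters ~ a 4-bit codebook. The random seed doesn't matter on the
# kmeans1d path since the exact 1D solver is deterministic.
centroids = kmeans_fit((weights_np, sample_weight, 16, 0))
print(centroids.shape, centroids.dtype)  # expect (16,) float32
```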