
Model: Add support for Kimi-K2 #14654

Merged: CISC merged 21 commits into ggml-org:master from gabriellarson:kimi-k2 on Jul 15, 2025

Conversation

@gabriellarson (Contributor) commented Jul 12, 2025

I used the same set_vocab approach as the HunYuanMoE, and attempted to accurately represent the kimi_tokenization.py regex in unicode.cpp.
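
For anyone wanting to sanity-check the pre-tokenizer, a quick spot-check along these lines can help. This is only an illustrative sketch: the sample strings are arbitrary picks meant to exercise contractions, CJK, digit runs, and whitespace, and the idea is to compare the printed IDs against what llama.cpp produces for the converted GGUF.

# Sketch: print reference token IDs from the HF Kimi-K2 tokenizer for strings
# that exercise the tricky parts of the pre-tokenizer regex. Compare against
# llama.cpp's tokenization of the converted GGUF to verify the unicode.cpp regex.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Instruct", trust_remote_code=True)

samples = [
    "Hello, world! I'm testing Kimi-K2 tokenization.",
    "don't we've it'll 1234567 3.14159",   # contractions and digit runs
    "数字123と漢字が混ざった文です。",        # mixed CJK and digits
    "trailing   spaces\tand\nnewlines ",   # whitespace handling
]

for s in samples:
    ids = tok.encode(s, add_special_tokens=False)
    print(repr(s), "->", ids)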

github-actions bot added the python (python script changes) label Jul 12, 2025
@nicoboss mentioned this pull request Jul 13, 2025
@anikifoss

Thanks for the patch, I'm excited to try this model! Running convert_hf_to_gguf based on your branch, will report back with the results.

@gabriellarson (Contributor, Author) commented Jul 13, 2025

Uploaded Q2_K and BF16 GGUFs to Hugging Face (thanks @danielhanchen for the initial BF16 conversion!): https://huggingface.co/gabriellarson/Kimi-K2-Instruct-GGUF

@danielhanchen (Contributor)

Nice work @gabriellarson! I was also trying to get a PR going (master...unslothai:llama.cpp:master) but was primarily stuck on K2's special regex handling; I'll try yours out to see if the regex works!

danielhanchen added a commit to unslothai/llama.cpp that referenced this pull request Jul 13, 2025
@thad0ctor

Has anyone been able to get quantization to work? I got Q3 to quantize successfully but have failed with Q2. It's really a pain with this model because the GGUF ended up being 2 TB when I converted it and I only have 384 GB of RAM; Q3 comes in at about 440 GB, so loading it is a nightmare.

@ubergarm (Contributor) commented Jul 14, 2025

Has anyone been able to get quantization to work? I got Q3 to quantize successfully but have failed with Q2. It's really a pain with this model because the GGUF ended up being 2 TB when I converted it and I only have 384 GB of RAM; Q3 comes in at about 440 GB, so loading it is a nightmare.

I believe @gabriellarson made a Q2_K; if you scroll down to the bottom of their Hugging Face repo files there is a link to a discussion where some folks are sharing information. Going to try this PR shortly and, with luck, release an imatrix file (or two eventually...) 🤞

@gabriellarson (Contributor, Author)

@CISC I think this is ready for review now

@thad0ctor

Has anyone been able to get quantization to work? I got Q3 to quantize successfully but have failed with Q2. It's really a pain with this model because the GGUF ended up being 2 TB when I converted it and I only have 384 GB of RAM; Q3 comes in at about 440 GB, so loading it is a nightmare.

I believe @gabriellarson made a Q2_K; if you scroll down to the bottom of their Hugging Face repo files there is a link to a discussion where some folks are sharing information. Going to try this PR shortly and, with luck, release an imatrix file (or two eventually...) 🤞

Thank you! The Q2 I made outputs complete gibberish; not sure if my own tweaks to the source caused this issue or if it's inherent to the commit as a whole right now. I saw the recent commits, so I will start over from a clean slate. I need to test some more to see whether FA or KV-cache quantization was the issue.

Gibberish aside, I was able to get about 13-15 t/s on 4x RTX 5090s, which is about on par with what I get with DeepSeek.

@csabakecskemeti (Contributor)

I've just started converting -> testing (based on my own dequantized BF16).
Will report back (it will take some time).

@anikifoss commented Jul 14, 2025

I haven't tested any edge-cases involving complex character handling in the patch, but it works for English:

  • convert_hf_to_gguf worked
  • quantized model loads with llama.cpp compiled from this branch and the output looks good (able to one-shot the spinning hexagon with 20 balls)

@ubergarm (Contributor)

Thanks @anikifoss for the confirmation. I too have been able to cast fp8 to bf16, then run convert_hf_to_gguf to get a bf16 GGUF and use that to quantize a "pure" q8_0, which inferenced successfully with llama-server in a few short chats.

I have my methodology details and screenshots in the HF repo discussion and am updating them as I go along.

Running into an issue now with imatrix dropping a lot of experts due to only 99.74% partial data; might need to look at #9400 (comment) to get a better imatrix here on mainline.

@RodriMora (Contributor)

Gibberish aside, I was able to get about 13-15 t/s on 4x RTX 5090s, which is about on par with what I get with DeepSeek.

@thad0ctor How many layers did you have to offload to VRAM, and at what context size did you get that speed?

@usrlocalben

I tested briefly with the provided Q2_K quant and observed a lot of repetition in the output: whole paragraphs repeating verbatim 3-4 times under different but similar headings (Bottom Line, In Summary, etc.) at temp=0.3 and temp=0.6. I'm building Q8 and will try again.

@danielhanchen (Contributor)

@gabriellarson I can confirm the new regex seems to work well based on tokenization ID matches!

@ubergarm Yes you're correct on experts being zeros - I think I also found this to be the case.

I also made some 245GB and 281GB (IQ1_S) dynamic quants plus Q2_K_XL and Q4_K_XL quants at https://huggingface.co/unsloth/Kimi-K2-Instruct-GGUF - they should work fine with this PR or with my fork https://github.com/unslothai/llama.cpp - guide to run them here: https://docs.unsloth.ai/basics/kimi-k2-how-to-run-locally#run-kimi-k2-tutorials

@thad0ctor

Gibberish aside, I was able to get about 13-15 t/s on 4x RTX 5090s, which is about on par with what I get with DeepSeek.

@thad0ctor How many layers did you have to offload to VRAM, and at what context size did you get that speed?

I am pretty sure it was default context (2k) and 5 layers per card, which left about 4-5 GB of headroom on each card (Q4 KV cache). I suspect that llama.cpp's splitting may be mapping layers inefficiently, so MoE enhancements in ik_llama or -ot args may add some performance once this model gets fleshed out a bit more.

It's unfortunate the Kimi team didn't work with the community pre-release to get the ball rolling on compatibility with common inference engines for those of us who aren't swimming in VRAM lol

@ubergarm (Contributor)

@danielhanchen

Yes you're correct on experts being zeros - I think I also found this to be the case.

Okay thanks for confirming! I checked your HF repo but didn't see your Kimi-K2-Instruct-GGUF/imatrix_unsloth.dat there; perhaps you don't release them anymore? Also, given it is a .dat, I assume it is missing data for a lot of experts.

I'm trying @compilade's new imatrix.gguf, which seems to go ahead and save even with partial data.

I also made some 245GB, 281GB (IQ1_S) dynamic quants + Q2_K_XL, Q4_K_XL quants

Oh hey, we already discussed this, but it looks like your scripts mistakenly named another quant TQ1_0: it doesn't actually contain that quantization type, and the name conflates the rough BPW range with an actual ternary-model-only quantization type.

It's great there are a lot of options in the smaller size ranges these days, but just trying to keep the naming conventions accurate! Thanks, and great job getting this beast of a model going!

anikifoss pushed a commit to anikifoss/ik_llama.cpp that referenced this pull request Jul 14, 2025
@danielhanchen (Contributor)

@RodriMora Yes I saw that as well!

  1. tokenizer.encode will now encode special tokens; this does NOT affect this PR (see Feature Request: Support Kimi K2 #14642 (comment)), which tokenizes special tokens correctly.
  2. The multi-turn tool-calling chat template does need an update. One has to enable the new one via --chat-template-file PATH_TO_KIMI_K2_CHAT_TEMPLATE.jinja; I will have to re-update the quants to bake in the new chat template over the next few days.

@gabriellarson (Contributor, Author)

@gabriellarson Thanks again for your effort on this. I got slowed down but hope to have some test quants made using your work on ik_llama.cpp. Sorry I am not sure how to cherry-pick the commits so your name shows up right.

Also, it seems that while Kimi-K2 is close enough to DeepSeek to run, it has a different, unique-looking chat template, as pointed out to me by some folks on the BeaverAI Club Discord:

<|im_system|>system<|im_middle|>example system prompt<|im_end|><|im_user|>user<|im_middle|>example user turn 1<|im_end|><|im_assistant|>assistant<|im_middle|>example assistant turn 1<|im_end|><|im_user|>user<|im_middle|>example user turn 2<|im_end|><|im_assistant|>assistant<|im_middle|>

Kimi-K2-Instruct/blob/main/tokenizer_config.json#L154

Not sure if it even needs to be added given the GGUF has the template (pretty sure), but the recent Hunyuan-A13B did add code in that area FWIW. Just in case you do need something like that, it would go into llama-chat.cpp, pretty sure. Here is the draft chat template code I still need to test.

Thanks, will keep you posted on how it goes tomorrow after I finally have some test quants 🤞

I'm getting stuck on adding the template to llm_chat_detect_template() in llama-chat.cpp:
} else if (tmpl_contains("???")) { return LLM_CHAT_TEMPLATE_KIMI_K2; }
I'm not sure what exactly to put here to uniquely identify the Kimi template. @ubergarm

@CISC (Member) commented Jul 15, 2025

I'm getting stuck on adding the template to llm_chat_detect_template() in llama-chat.cpp } else if (tmpl_contains("???")) { return LLM_CHAT_TEMPLATE_KIMI_K2; } I'm not sure what exactly to put here to uniquely identify the kimi template

<|im_assistant|>assistant<|im_middle|> seems like a good one.
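
A quick way to double-check that this marker is actually present before keying the detection on it is to pull the upstream chat template and search it. A throwaway sketch (assumes network access to the Hugging Face repo; checking other architectures' templates for collisions would follow the same pattern):

# Sketch: confirm the proposed detection marker appears in Kimi-K2's chat template
# (the template lives in tokenizer_config.json, linked above).
from transformers import AutoTokenizer

MARKER = "<|im_assistant|>assistant<|im_middle|>"

tok = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Instruct", trust_remote_code=True)
assert MARKER in tok.chat_template, "marker not found in the Kimi-K2 chat template"
print("marker present in the Kimi-K2 chat template")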

@gabriellarson requested a review from CISC on July 15, 2025 at 16:27
@ubergarm (Contributor) commented Jul 15, 2025

@gabriellarson @CISC

I just got it working with this. My first attempt forgot to add_ass, but without it the model would give "empty replies". I've had some folks testing with good success now using this code, which would go into llama-chat.cpp as seen in this similar PR: ikawrakow/ik_llama.cpp#612. Tested with this model: ubergarm/Kimi-K2-Instruct-GGUF IQ2_KL, 345.687 GiB (2.892 BPW), Final estimate: PPL = 3.2741 +/- 0.01689 (upload complete in 30 minutes lol).

+    } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+        // moonshotai/Kimi-K2-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|im_system|>system<|im_middle|>" << message->content << "<|im_end|>";
+            } else if (role == "user") {
+                ss << "<|im_user|>user<|im_middle|>" << message->content << "<|im_end|>";
+            } else {
+                ss << "<|im_assistant|>assistant<|im_middle|>" << message->content << "<|im_end|>";
+            }
+        }
+        if (add_ass) {
+            ss << "<|im_assistant|>assistant<|im_middle|>";
+        }
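
(For context: add_ass here plays essentially the same role as add_generation_prompt=True when applying the HF template - it appends the trailing <|im_assistant|>assistant<|im_middle|> so the model starts an assistant turn rather than returning an empty reply.)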

@gabriellarson (Contributor, Author)

@ubergarm I'll add the if (add_ass){}.

I also include the tool role in mine; is the tool role not necessary?

@ubergarm (Contributor)

@gabriellarson

I also include the tool role in mine; is the tool role not necessary?

I wasn't sure myself honestly, but yes yours does look correct to me given my understanding of the official template. Should be fine, but I don't have a quant to test it and honestly don't know how to use proper tool calling 😅

👈 chat template decoder script

Output:

$ python chat_template_tester.py moonshotai/Kimi-K2-Instruct
>> chat template <<
<|im_system|>system<|im_middle|>example system prompt<|im_end|><|im_user|>user<|im_middle|>example user turn 1<|im_end|><|im_assistant|>assistant<|im_middle|>example assistant turn 1<|im_end|><|im_user|>user<|im_middle|>example user turn 2<|im_end|><|im_assistant|>assistant<|im_middle|>example assistant turn 2<|im_end|><|im_system|>tool<|im_middle|>## Return of \nsome kind of tool call maybe<|im_end|><|im_assistant|>assistant<|im_middle|>

Python script:

$ cat chat_template_tester.py
# uv pip install transformers jinja2
# (and sometimes also sentencepiece torch statsmodels, looking at you ERNIE4.5)
from transformers import AutoTokenizer
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("model", help="Name of Hugging Face LLM repo (org/model format)")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code = True)

chat = [
    {"role": "system", "content": "example system prompt"},
    {"role": "user", "content": "example user turn 1"},
    {"role": "assistant", "content": "example assistant turn 1"},
    {"role": "user", "content": "example user turn 2"},
    {"role": "assistant", "content": "example assistant turn 2"},
    {"role": "tool", "content": "some kind of tool call maybe"},
]

print(">> chat template <<")
print(tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False))
print(">> end of chat template <<")

@CISC merged commit 4a4f426 into ggml-org:master Jul 15, 2025
1 check passed
@jukofyork (Collaborator)

I've managed to create a draft model, but unsure yet if I can actually train it:

https://huggingface.co/jukofyork/Kimi-K2-Instruct-DRAFT-0.6B-UNTRAINED

https://huggingface.co/jukofyork/Kimi-K2-Instruct-DRAFT-0.6B-UNTRAINED-GGUF

Even the untrained draft gives some improvements for highly "draftable" refactoring prompts:

prompt eval time =   55366.84 ms /  1832 tokens (   30.22 ms per token,    33.09 tokens per second)
       eval time =  241439.49 ms /  1618 tokens (  149.22 ms per token,     6.70 tokens per second)
      total time =  296806.34 ms /  3450 tokens

prompt eval time =   55209.16 ms /  1832 tokens (   30.14 ms per token,    33.18 tokens per second)
       eval time =  169682.33 ms /  1524 tokens (  111.34 ms per token,     8.98 tokens per second)
      total time =  224891.50 ms /  3356 tokens
draft acceptance rate = 0.59985 (  814 accepted /  1357 generated)

I'll try to see if I can fine-tune it, but it all depends on whether I can get transformers to load it and/or figure out a similar TikToken <--> SentencePiece hack...

@jukofyork (Collaborator)

I managed to get them to train:

https://huggingface.co/jukofyork/Kimi-K2-Instruct-DRAFT-0.6B-v2.0
https://huggingface.co/jukofyork/Kimi-K2-Instruct-DRAFT-0.6B-v2.0-GGUF

@ngxson (Contributor) commented Jul 24, 2025

I made a small fix to the chat template; would appreciate it if anyone can test it: #14852

@jukofyork (Collaborator) commented Jul 24, 2025

There's a good chance we can improve Kimi K2 a little by limiting to 32k context and changing the YaRN scaling parameter, e.g.:

--override-kv deepseek2.context_length=int:32768 --override-kv deepseek2.rope.scaling.factor=float:8.0 --ctx_size 32768

The tech report seems to suggest that there was no actual long-context training beyond 32k:

[image: Kimi K2 tech report excerpt]

unlike DeepSeek V3 which used the YaRN parameters during training:

[image: DeepSeek V3 paper excerpt]

It doesn't look like it works that great above 32k anyway:

[image: long-context benchmark results]

so it's probably a good idea to avoid using the 128k YaRN parameters unless you really have to, and only then use the minimum you actually need above 32k, e.g.:

--override-kv deepseek2.context_length=int:65536 --override-kv deepseek2.rope.scaling.factor=float:16.0 --ctx_size 65536

and so on.
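
For reference, these overrides follow the usual YaRN relation scaling factor = target context / original_max_position_embeddings, with original_max_position_embeddings = 4096 for Kimi K2 (per Moonshot's reply further down). A tiny sketch of that arithmetic, purely illustrative (and note the follow-up below before actually altering the factor):

# Sketch of the arithmetic behind the suggested --override-kv values:
# YaRN scaling factor = target context / original_max_position_embeddings (4096 for Kimi K2).
ORIG_MAX_POS = 4096

for target_ctx in (32768, 65536, 131072):
    factor = target_ctx / ORIG_MAX_POS
    print(f"--override-kv deepseek2.context_length=int:{target_ctx} "
          f"--override-kv deepseek2.rope.scaling.factor=float:{factor} "
          f"--ctx_size {target_ctx}")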

@jukofyork (Collaborator)

I raised a question here: MoonshotAI/Kimi-K2#55

@jukofyork (Collaborator)

I raised a question here: MoonshotAI/Kimi-K2#55

They replied:

Starting from 4k -> 32k stage, we set yarn_scale_factor=32, original_max_position_embeddings=4096 and rope_theta=50000, i.e. we use 128k settings even in 32k stages.

This is the same as what deepseek-v3 did, so it's probably not a good idea to alter the YaRN scaling parameter.

@danielhanchen (Contributor)

@jukofyork So during "mid-training" they essentially did long context extension it seems?

@ngxson Nice work :)

blime4 referenced this pull request in blime4/llama.cpp Feb 5, 2026
* Kimi-K2 conversion

* add Kimi_K2  pre type

* Kimi-K2

* Kimi-K2 unicode

* Kimi-K2

* LLAMA_MAX_EXPERTS 384

* fix vocab iteration

* regex space fix

* add kimi-k2 to pre_computed_hashes

* Updated with kimi-k2 get_vocab_base_pre hash

* fix whitespaces

* fix flake errors

* remove more unicode.cpp whitespaces

* change set_vocab() flow

* add moonshotai-Kimi-K2.jinja to /models/templates/

* update moonshotai-Kimi-K2.jinja

* add kimi-k2 chat template

* add kimi-k2

* update NotImplementedError

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* except Exception

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* LLM_CHAT_TEMPLATE_KIMI_K2 if(add_ass){}

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Seunghhon pushed a commit to Seunghhon/llama.cpp that referenced this pull request Apr 26, 2026
phuongncn pushed a commit to phuongncn/llama.cpp-gx10-dgx-sparks-deepseekv4 that referenced this pull request Apr 28, 2026
Original patch by @gabriellarson:
ggml-org#14654

Co-authored-by: anikifoss <anikifoss>

Labels: python (python script changes)