
Implement the OLMo architecture #6741

Merged
phymbert merged 8 commits into ggml-org:master from nopperl:olmo
Apr 19, 2024

Conversation

@nopperl
Contributor

nopperl commented Apr 18, 2024

Implements the recently released open-source OLMo architecture. Tested with allenai/OLMo-1B-hf and allenai/OLMo-7B-hf; it should work with allenai/OLMo-1.7-7B and the future OLMo-70B as well. Fixes #5408.

Implementation differences from Llama (see the sketch after this list):

  • non-parametric layer norm
  • QKV clipping
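
A minimal PyTorch sketch of these two differences (the function names are illustrative, and the clamp bound comes from the model's clip_qkv hyperparameter; 8.0 is used here only as an example default):

import torch
import torch.nn.functional as F

# Non-parametric layer norm: normalization only, no learned scale or shift.
def olmo_layer_norm(x: torch.Tensor) -> torch.Tensor:
    return F.layer_norm(x, normalized_shape=x.shape[-1:])

# QKV clipping: clamp the query/key/value projections element-wise.
def clip_qkv(q, k, v, clamp_value=8.0):
    return tuple(t.clamp(min=-clamp_value, max=clamp_value) for t in (q, k, v))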

Test:

./convert-hf-to-gguf.py OLMo-1B-hf --outtype f16
./main -m OLMo-1B-hf/ggml-model-f16.gguf --temp 0.8 -s 1000 -n 50 -p "Language modeling is "

Output:

Language modeling is 
"a process that allows a computer to mimic the way that humans speak 
in order to enable computers to recognize speech in a wide range of natural 
language utterances.  In practice, it involves training a machine to 

Reference (requires transformers>=4.40.0.dev0):

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

set_seed(1000)  # seed for reproducibility (same value as -s 1000 above)
olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-1B-hf")
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
message = ["Language modeling is "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
# Sample 50 new tokens at temperature 0.8, matching the -n 50 --temp 0.8 flags
response = olmo.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.8)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])

Output:

Language modeling is 
using the same techniques used for text generation. The goal is to create a semantic representation of the world in terms of a language model. In this lecture we will learn some fundamental modeling techniques and will give examples of semantic representations, both text and images

Note: llm_load_vocab shows a mismatch in the special tokens definition (31/50304 vs 52/50304). This is due to special tokens 50254 to 50276, which are runs of the space character (" ") of varying length.
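
A quick way to inspect that token range with the HF tokenizer (a sketch, assuming the allenai/OLMo-1B-hf tokenizer is available):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
# Tokens 50254..50276 decode to runs of spaces of varying length.
for token_id in range(50254, 50277):
    print(token_id, repr(tok.decode([token_id])))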

@nopperl
Contributor Author

nopperl commented Apr 18, 2024

Will upload the GGUF conversions here: https://huggingface.co/collections/nopperl/olmo-gguf-66211a0071b6c3d66303fcf1

phymbert added the model (Model specific) label on Apr 18, 2024

@nopperl
Contributor Author

nopperl commented Apr 18, 2024

@phymbert thanks for checking, I've removed the superfluous code now. The generation results are not affected.

@nopperl
Contributor Author

nopperl commented Apr 18, 2024

As an aside, I confirm that it works with the new allenai/OLMo-1.7-7B-hf as well.

$ main -m OLMo-1.7-7B-hf/ggml-model-f16.gguf --temp 0.8 -s 1000 -n 50 -p "Language modeling is "
...
Language modeling is 
a technique for modeling natural language that has a long history and has been the focus of much recent research.  It is concerned with representing text as a series of words, phrases and sentences.  The words are the basic units, and there is
...

phymbert merged commit 9958c81 into ggml-org:master on Apr 19, 2024
@reneleonhardt
Contributor

> Will upload the GGUF conversions here: https://huggingface.co/collections/nopperl/olmo-gguf-66211a0071b6c3d66303fcf1

Thank you very much! Do you think that some of them would be useful for coding prompts like in CodeGPT?
https://huggingface.co/nopperl/OLMo-1.7-7B-GGUF/tree/main

@nopperl
Contributor Author

nopperl commented Apr 19, 2024

> Will upload the GGUF conversions here: https://huggingface.co/collections/nopperl/olmo-gguf-66211a0071b6c3d66303fcf1
>
> Thank you very much! Do you think that some of them would be useful for coding prompts like in CodeGPT? https://huggingface.co/nopperl/OLMo-1.7-7B-GGUF/tree/main

If you want to test its performance, I recommend starting from the largest one (f16). However, I did try it on a few coding prompts and I cannot really recommend it for that. For the same weight class (7B), there are better models like deepseek-ai/deepseek-coder-7b-instruct-v1.5 (for QA) or bigcode/starcoder2-7b (for code completion).

Bear in mind that this is still a base model, so it will perform worse than instruction-tuned models on these tasks.

Seunghhon pushed a commit to Seunghhon/llama.cpp that referenced this pull request on Apr 26, 2026
* implement olmo architecture

* remove unused variable

* remove unused moe branch

* remove check for weight

* remove superfluous moe, bias and rope tensors

* clarified comment

* fix clamp_kqv setting

* remove obsolete parameter name filter
phuongncn pushed a commit to phuongncn/llama.cpp-gx10-dgx-sparks-deepseekv4 that referenced this pull request on Apr 28, 2026