
Add missing inference support for GPTNeoXForCausalLM (Pythia and GPT-NeoX base models)#7461

Merged
fairydreaming merged 7 commits into ggml-org:master from fairydreaming:gpt-neox
May 23, 2024
Conversation

@fairydreaming
Collaborator

This pull request adds the missing pieces to support inference for GPT-NeoX-based models such as GPT-NeoX itself and the Pythia family. Fixes #742. It also adds model types for all Pythia model sizes.
The added use_par_res hparams field corresponds to the use_parallel_residual parameter from config.json.
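
For readers unfamiliar with the flag: use_parallel_residual selects between the parallel and sequential residual formulations of a transformer block. A minimal sketch (illustrative names only, not llama.cpp identifiers):

```python
def neox_block(x, attn, mlp, ln1, ln2, use_parallel_residual=True):
    if use_parallel_residual:
        # GPT-NeoX / Pythia default: attention and MLP both read the block input
        return x + attn(ln1(x)) + mlp(ln2(x))
    # sequential (GPT-2 style) residual: the MLP sees the attention output
    h = x + attn(ln1(x))
    return h + mlp(ln2(h))
```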

@github-actions bot added the python (python script changes) label May 22, 2024
@ggerganov
Member

Tested with https://huggingface.co/EleutherAI/pythia-1.4b/tree/main

Seems to work. PPL on wiki.test is 12.8692 +/- 0.09260:

./perplexity -m models/pythia-1b/ggml-model-f16.gguf -f build/wikitext-2-raw/wiki.test.raw

I guess that's normal for a 1.4B model that is a year old. Thanks for implementing this!

@cebtenzzre linked an issue May 22, 2024 that may be closed by this pull request
@fairydreaming
Collaborator Author

fairydreaming commented May 22, 2024

It seems that the perplexity is a little higher than with the HF transformers implementation because there are differences in tokenization output between llama.cpp and GPTNeoXTokenizerFast.
Edit: It looks like there were differences in the dataset files I used to measure perplexity for transformers and llama.cpp; I will have to recheck.
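
One way to check this is to tokenize the same file with both stacks and diff the ids. A hedged sketch, assuming wiki.test.raw.tokcpp already holds one llama.cpp token id per line (however it was produced, e.g. with llama.cpp's tokenize example):

```python
from transformers import AutoTokenizer

# GPTNeoXTokenizerFast; the model id and file paths are assumptions
tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")

text = open("wikitext-2-raw/wiki.test.raw", encoding="utf-8").read()
hf_ids = tok(text)["input_ids"]

# llama.cpp token ids, one per line, produced separately
cpp_ids = [int(l) for l in open("wikitext-2-raw/wiki.test.raw.tokcpp") if l.strip()]

diffs = [i for i, (a, b) in enumerate(zip(hf_ids, cpp_ids)) if a != b]
print(f"{len(diffs)} mismatches out of {len(hf_ids)} tokens; first few: {diffs[:4]}")
```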

@github-actions
Contributor

github-actions Bot commented May 22, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 537 iterations 🚀

Details:
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=8705.59ms p(95)=22299.26ms fails=, finish reason: stop=480 truncated=57
  • Prompt processing (pp): avg=103.68tk/s p(95)=455.27tk/s
  • Token generation (tg): avg=31.86tk/s p(95)=47.22tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=gpt-neox commit=7e171de882ca16fbd75f72d7d1dd4afef75c04d6

[Benchmark charts omitted: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, and requests_processing for bench-server-baseline on Standard_NC4as_T4_v3 (duration=10m, 537 iterations)]

@ggerganov
Member

The tokenization differences on wiki.test are minimal and related to a slightly different way of handling added tokens:

diff ./build/wikitext-2-raw/wiki.test.raw.tok ./build/wikitext-2-raw/wiki.test.raw.tokcpp
245413,245414c245413,245414
< 50276
< 6285
---
> 209
> 20589
245440,245441c245440,245441
< 50276
< 6285
---
> 209
> 20589
246660,246661c246660,246661
< 50276
< 6285
---
> 209
> 20589
246687,246688c246687,246688
< 50276
< 6285
---
> 209
> 20589

Likely the perplexity computation used in HF transformers differs from llama.cpp's (i.e. different context size, strided evaluation, etc.).
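
For comparison, a sketch of the strided sliding-window evaluation from the HF perplexity guide (https://huggingface.co/docs/transformers/perplexity); the window/stride values here are assumptions, and llama.cpp instead scores non-overlapping chunks, so the two numbers are not directly comparable:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/pythia-1.4b"
model = AutoModelForCausalLM.from_pretrained(model_id)
tok = AutoTokenizer.from_pretrained(model_id)

text = open("wikitext-2-raw/wiki.test.raw", encoding="utf-8").read()
ids = tok(text, return_tensors="pt").input_ids

max_len, stride = 2048, 512      # assumed values
nlls, prev_end = [], 0
for begin in range(0, ids.size(1), stride):
    end = min(begin + max_len, ids.size(1))
    trg_len = end - prev_end     # number of new tokens scored this step
    input_ids = ids[:, begin:end]
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100   # mask the overlapping context
    with torch.no_grad():
        loss = model(input_ids, labels=target_ids).loss
    nlls.append(loss * trg_len)
    prev_end = end
    if end == ids.size(1):
        break

print("ppl:", torch.exp(torch.stack(nlls).sum() / end).item())
```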

For Pythia 2.8b I get PPL 10.9294 +/- 0.07654

@ggerganov
Member

> Edit: It looks like there were differences in the dataset files I used to measure perplexity for transformers and llama.cpp; I will have to recheck.

Feel free to merge this when ready - I think it works

@mofosyne added the model (Model specific) and Review Complexity : Medium (Generally require more time to grok but manageable by beginner to medium expertise level) labels May 23, 2024
@fairydreaming merged commit 9b82476 into ggml-org:master May 23, 2024
@felladrin
Contributor

Thank you for this, @fairydreaming! I have wanted it for so long!
And thanks @ggerganov for reviewing it so quickly!

Seunghhon pushed a commit to Seunghhon/llama.cpp that referenced this pull request Apr 26, 2026
Add missing inference support for GPTNeoXForCausalLM (Pythia and GPT-NeoX base models) (ggml-org#7461)

* convert-hf : add conversion of bloom-style qkv tensor to gpt-style qkv (code borrowed from BloomModel)

* llama : add inference support for LLM_ARCH_GPTNEOX

* llama : add model types for every Pythia variant and GPT-NeoX

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
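
For reference, the bloom-style to gpt-style qkv conversion from the first commit amounts to de-interleaving the per-head (q, k, v) rows of the fused query_key_value tensor. A numpy sketch under that assumption (not the PR's literal code; names are illustrative):

```python
import numpy as np

def deinterleave_qkv(qkv: np.ndarray, n_head: int, n_embd: int) -> np.ndarray:
    """Convert a fused [3*n_embd, n_embd] query_key_value weight whose rows are
    interleaved per head as (q, k, v) into concatenated [Wq; Wk; Wv] rows."""
    head_dim = n_embd // n_head
    w = qkv.reshape(n_head, 3, head_dim, n_embd)
    q = w[:, 0, :, :].reshape(n_embd, n_embd)
    k = w[:, 1, :, :].reshape(n_embd, n_embd)
    v = w[:, 2, :, :].reshape(n_embd, n_embd)
    return np.concatenate([q, k, v], axis=0)
```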

Labels

model (Model specific) · python (python script changes) · Review Complexity : Medium (Generally require more time to grok but manageable by beginner to medium expertise level)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

GPT-NeoX has only minimal inference support
Pythia Support?

5 participants