From 173d4bb336eb3be9d1ddeb817d4697f0e040e89f Mon Sep 17 00:00:00 2001
From: VJHack
Date: Sat, 14 Sep 2024 11:15:51 -0500
Subject: [PATCH 1/9] added cli arg to disable context shift

---
 .pre-commit-config.yaml | 10 +++++-----
 common/arg.cpp          |  8 +++++++-
 common/common.h         |  1 +
 examples/main/main.cpp  | 34 ++++++++++++++++++++--------------
 4 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 91d79162850..84a81bb56d0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,8 +9,8 @@ repos:
     - id: end-of-file-fixer
     - id: check-yaml
     - id: check-added-large-files
-- repo: https://github.com/PyCQA/flake8
-  rev: 7.0.0
-  hooks:
-  - id: flake8
-    additional_dependencies: [flake8-no-print]
+# - repo: https://github.com/PyCQA/flake8
+#   rev: 7.0.0
+#   hooks:
+#   - id: flake8
+#     additional_dependencies: [flake8-no-print]
diff --git a/common/arg.cpp b/common/arg.cpp
index a1cd5830f93..f5faddf124f 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -697,6 +697,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.n_keep = value;
         }
     ));
+    add_opt(llama_arg(
+        {"--no-context-shift"},
+        format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
+        [](gpt_params & params) {
+            params.ctx_shift = false;
+        }
+    ));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1992,4 +1999,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
 
     return ctx_arg;
 }
-
diff --git a/common/common.h b/common/common.h
index e8025aeef57..33cb004463b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -248,6 +248,7 @@ struct gpt_params {
     bool cont_batching    = true;  // insert new sequences for decoding on-the-fly
     bool flash_attn       = false; // flash attention
     bool no_perf          = false; // disable performance metrics
+    bool ctx_shift        = true;  // context shift on infinite text generation
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool logits_all       = false; // return logits for all tokens in the batch
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index f41be53082a..836e7fcb3ee 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -579,29 +579,35 @@ int main(int argc, char ** argv) {
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
+
             if (n_past + (int) embd.size() >= n_ctx) {
-                if (params.n_predict == -2) {
-                    LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                if(!params.ctx_shift){
+                    LOG_TEE("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                     break;
-                }
+                } else {
+                    if (params.n_predict == -2) {
+                        LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                        break;
+                    }
 
-                const int n_left    = n_past - params.n_keep;
-                const int n_discard = n_left/2;
+                    const int n_left    = n_past - params.n_keep;
+                    const int n_discard = n_left/2;
 
-                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                    n_past, n_left, n_ctx, params.n_keep, n_discard);
+                    LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                        n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
-                n_past -= n_discard;
+                    n_past -= n_discard;
 
-                LOG("after swap: n_past = %d\n", n_past);
+                    LOG("after swap: n_past = %d\n", n_past);
 
-                LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+                    LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
-                LOG("clear session path\n");
-                path_session.clear();
+                    LOG("clear session path\n");
+                    path_session.clear();
+                }
             }
         } else {
             // context extension via Self-Extend

From c52b922d98ab5d23d2ffbcf877433e2f232bf5c8 Mon Sep 17 00:00:00 2001
From: VJHack
Date: Sat, 14 Sep 2024 11:16:54 -0500
Subject: [PATCH 2/9] reverted precommit

---
 .pre-commit-config.yaml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 84a81bb56d0..91d79162850 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,8 +9,8 @@ repos:
     - id: end-of-file-fixer
     - id: check-yaml
    - id: check-added-large-files
-# - repo: https://github.com/PyCQA/flake8
-#   rev: 7.0.0
-#   hooks:
-#   - id: flake8
-#     additional_dependencies: [flake8-no-print]
+- repo: https://github.com/PyCQA/flake8
+  rev: 7.0.0
+  hooks:
+  - id: flake8
+    additional_dependencies: [flake8-no-print]

From 0680710b06bb886588a48fcf436d2040610d222c Mon Sep 17 00:00:00 2001
From: VJHack
Date: Sat, 14 Sep 2024 11:30:10 -0500
Subject: [PATCH 3/9] updated README.md for main

---
 examples/main/README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/main/README.md b/examples/main/README.md
index 9396a34fa5a..8b233fd0573 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -161,6 +161,8 @@ A value of -1 will enable infinite text generation, even though we have a finite
 
 If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
 
+The `--no-context-shift` options allows you to stop the inifinite text generation once the finite context window is full.
+
 It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
 
 ### Temperature

From e244300df53844f3861861f717828807f41c21aa Mon Sep 17 00:00:00 2001
From: VJHack
Date: Sat, 14 Sep 2024 11:37:41 -0500
Subject: [PATCH 4/9] white space

---
 examples/main/main.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 836e7fcb3ee..86228524c11 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -581,7 +581,7 @@ int main(int argc, char ** argv) {
             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
 
             if (n_past + (int) embd.size() >= n_ctx) {
-                if(!params.ctx_shift){
+                if (!params.ctx_shift){
                     LOG_TEE("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                     break;
                 } else {

From cf77a846c620ac5d673592407e05554def398a73 Mon Sep 17 00:00:00 2001
From: VJHack
Date: Sun, 15 Sep 2024 09:12:24 -0500
Subject: [PATCH 5/9] allow disabling context shift in the server

---
 examples/server/server.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 14c4af3d928..d112712fe54 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1885,6 +1885,13 @@ struct server_context {
             for (server_slot & slot : slots) {
                 if (slot.ga_n == 1) {
                     if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+                        if (!params.ctx_shift){
+                            slot.release();
+                            slot.print_timings();
+                            send_final_response(slot);
+                            metrics.on_prediction(slot);
+                            continue;
+                        }
                         // Shift context
                         const int n_keep    = slot.params.n_keep + add_bos_token;
                         const int n_left    = (int) system_tokens.size() + slot.n_past - n_keep;

From 63f0fa572d7db84d0792c30844de573c1cb6f660 Mon Sep 17 00:00:00 2001
From: Vinesh Janarthanan <36610342+VJHack@users.noreply.github.com>
Date: Sun, 15 Sep 2024 20:35:01 -0500
Subject: [PATCH 6/9] Update common/arg.cpp

no-context-shift only works for main example

Co-authored-by: Georgi Gerganov
---
 common/arg.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index f5faddf124f..97391ef77f8 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -703,7 +703,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.ctx_shift = false;
         }
-    ));
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),

From f5a23928c7101009e95bd263135a8ec662e51726 Mon Sep 17 00:00:00 2001
From: VJHack
Date: Sun, 15 Sep 2024 20:57:57 -0500
Subject: [PATCH 7/9] added server example to --no-context-shift args

---
 common/arg.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 60e37a89a68..117b6a9a7cb 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -691,7 +691,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.ctx_shift = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),

From 2736688af4ffb1dac4f65f439ed4916e1c46defb Mon Sep 17 00:00:00 2001
From: VJHack
Date: Sun, 15 Sep 2024 21:26:46 -0500
Subject: [PATCH 8/9] removed server changes

---
 common/arg.cpp             | 2 +-
 examples/main/README.md    | 2 +-
 examples/server/server.cpp | 9 +--------
 3 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 117b6a9a7cb..60e37a89a68 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -691,7 +691,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.ctx_shift = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
diff --git a/examples/main/README.md b/examples/main/README.md
index 8b233fd0573..6730effdf2d 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -161,7 +161,7 @@ A value of -1 will enable infinite text generation, even though we have a finite
 
 If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
 
-The `--no-context-shift` options allows you to stop the inifinite text generation once the finite context window is full.
+The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.
 
 It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 3122b27a093..27623e4a702 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1815,13 +1815,6 @@ struct server_context {
             for (server_slot & slot : slots) {
                 if (slot.ga_n == 1) {
                     if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
-                        if (!params.ctx_shift){
-                            slot.release();
-                            slot.print_timings();
-                            send_final_response(slot);
-                            metrics.on_prediction(slot);
-                            continue;
-                        }
                         // Shift context
                         const int n_keep    = slot.params.n_keep + add_bos_token;
                         const int n_left    = (int) system_tokens.size() + slot.n_past - n_keep;
@@ -3175,4 +3168,4 @@ int main(int argc, char ** argv) {
     t.join();
 
     return 0;
-}
+}
\ No newline at end of file

From 169e8a38754cfafaaff2c9e0c74ae451b5bcf56b Mon Sep 17 00:00:00 2001
From: VJHack
Date: Sun, 15 Sep 2024 21:28:16 -0500
Subject: [PATCH 9/9] white space

---
 examples/server/server.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 27623e4a702..b5f264ff119 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3168,4 +3168,4 @@ int main(int argc, char ** argv) {
     t.join();
 
     return 0;
-}
\ No newline at end of file
+}
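
Editorial note on the net effect of the series (an illustration, not part of any patch): when the KV cache fills, the context-shift path in examples/main keeps the first n_keep tokens and discards half of the remainder, sliding the surviving tail back; --no-context-shift skips that path and stops generation instead. A minimal standalone C++ sketch of the shift arithmetic, with hypothetical values standing in for the real runtime state:

#include <cstdio>

int main() {
    // Hypothetical state (sketch only): in llama.cpp these come from the
    // runtime values of n_past and params.n_keep once the cache is full.
    int n_past = 512;       // tokens currently in the KV cache
    const int n_keep = 64;  // prompt tokens pinned at the front

    // The arithmetic from examples/main/main.cpp: discard half of the
    // movable tokens; the actual cache surgery is performed by
    // llama_kv_cache_seq_rm() and llama_kv_cache_seq_add().
    const int n_left    = n_past - n_keep;
    const int n_discard = n_left/2;

    n_past -= n_discard;

    // prints: n_left = 448, n_discard = 224, n_past after swap = 288
    printf("n_left = %d, n_discard = %d, n_past after swap = %d\n",
           n_left, n_discard, n_past);
    return 0;
}

With the flag set, none of this runs: main logs "context full and context shift is disabled => stopping" and breaks out of the generation loop. A typical invocation might look like ./llama-cli -m model.gguf -n -1 --no-context-shift (binary name assumed from the build conventions of the period; adjust to your build).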