diff --git a/apps/llm/app/multimodal_llm/index.tsx b/apps/llm/app/multimodal_llm/index.tsx index 87c2ad6870..b7d6859ede 100644 --- a/apps/llm/app/multimodal_llm/index.tsx +++ b/apps/llm/app/multimodal_llm/index.tsx @@ -14,7 +14,7 @@ import { import { launchImageLibrary } from 'react-native-image-picker'; import { useIsFocused } from '@react-navigation/native'; import { useSafeAreaInsets } from 'react-native-safe-area-context'; -import { useLLM, LFM2_VL_1_6B_QUANTIZED } from 'react-native-executorch'; +import { useLLM, LFM2_5_VL_1_6B_QUANTIZED } from 'react-native-executorch'; import SendIcon from '../../assets/icons/send_icon.svg'; import PauseIcon from '../../assets/icons/pause_icon.svg'; import ColorPalette from '../../colors'; @@ -50,7 +50,7 @@ function MultimodalLLMScreen() { const [error, setError] = useState(null); const vlm = useLLM({ - model: LFM2_VL_1_6B_QUANTIZED, + model: LFM2_5_VL_1_6B_QUANTIZED, }); const tokenCount = vlm.isReady ? vlm.getGeneratedTokenCount() : 0; const { stats, onMessageSend } = useLLMStats( diff --git a/docs/docs/03-hooks/01-natural-language-processing/useLLM.md b/docs/docs/03-hooks/01-natural-language-processing/useLLM.md index a75d2b0b12..7e6e4b17c2 100644 --- a/docs/docs/03-hooks/01-natural-language-processing/useLLM.md +++ b/docs/docs/03-hooks/01-natural-language-processing/useLLM.md @@ -211,7 +211,15 @@ To configure model (i.e. change system prompt, load initial conversation history - [`temperature`](../../06-api-reference/interfaces/GenerationConfig.md#temperature) - Scales output logits by the inverse of temperature. Controls the randomness / creativity of text generation. - - [`topp`](../../06-api-reference/interfaces/GenerationConfig.md#topp) - Only samples from the smallest set of tokens whose cumulative probability exceeds topp. + - [`topP`](../../06-api-reference/interfaces/GenerationConfig.md#topp) - Only samples from the smallest set of tokens whose cumulative probability exceeds topP. Range `[0, 1]`. Values of `0` or `1` disable top-p filtering. + + - [`minP`](../../06-api-reference/interfaces/GenerationConfig.md#minp) - Minimum-probability threshold applied after softmax: tokens whose probability is below `minP * max_prob` are excluded from sampling. Range `[0, 1]`. Default `0` disables the filter. Stacks with `topP` when both are set. + + - [`repetitionPenalty`](../../06-api-reference/interfaces/GenerationConfig.md#repetitionpenalty) - Multiplicative penalty applied to logits of tokens that already appeared in the prompt or the generated text. Values greater than `1` discourage repetition; default `1` disables the penalty. + +:::info[Built-in models ship with sampling defaults] +Model presets expose an optional [`generationConfig`](../../06-api-reference/interfaces/LLMProps.md) on the `model` prop. Whenever the upstream model card publishes recommended values (currently Qwen3 and LFM2-VL) the preset carries them and `useLLM` applies them automatically before `isReady` flips — you don't need to call `configure` just to get sensible defaults. Any fields you then pass to `configure` still override on a per-field basis. +::: ### Model configuration example @@ -282,7 +290,9 @@ useEffect(() => { outputTokenBatchSize: 15, batchTimeInterval: 100, temperature: 0.7, - topp: 0.9, + topP: 0.9, + minP: 0.05, + repetitionPenalty: 1.05, }, }); }, [configure]); @@ -491,9 +501,9 @@ Some models support multimodal input — text and images together. 
To use them, ### Loading a VLM ```tsx -import { useLLM, LFM2_VL_1_6B_QUANTIZED } from 'react-native-executorch'; +import { useLLM, LFM2_5_VL_1_6B_QUANTIZED } from 'react-native-executorch'; -const llm = useLLM({ model: LFM2_VL_1_6B_QUANTIZED }); +const llm = useLLM({ model: LFM2_5_VL_1_6B_QUANTIZED }); ``` The `capabilities` field is already set on the model constant. You can also construct the model object explicitly: @@ -514,7 +524,7 @@ Passing `capabilities` unlocks the typed `media` argument on `sendMessage`. ### Sending a message with an image ```tsx -const llm = useLLM({ model: LFM2_VL_1_6B_QUANTIZED }); +const llm = useLLM({ model: LFM2_5_VL_1_6B_QUANTIZED }); const send = () => { llm.sendMessage('What is in this image?', { @@ -537,7 +547,7 @@ The `imagePath` should be a local file path on the device. You can also use `generate` directly by setting `mediaPath` on user messages: ```tsx -const llm = useLLM({ model: LFM2_VL_1_6B_QUANTIZED }); +const llm = useLLM({ model: LFM2_5_VL_1_6B_QUANTIZED }); const handleGenerate = async () => { const chat: Message[] = [ diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md index 5dd32551a6..c1cf24a9fc 100644 --- a/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md +++ b/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md @@ -107,17 +107,25 @@ To configure model (i.e. change system prompt, load initial conversation history - [`temperature`](../../06-api-reference/interfaces/GenerationConfig.md#temperature) - Scales output logits by the inverse of temperature. Controls the randomness / creativity of text generation. - - [`topp`](../../06-api-reference/interfaces/GenerationConfig.md#topp) - Only samples from the smallest set of tokens whose cumulative probability exceeds topp. + - [`topP`](../../06-api-reference/interfaces/GenerationConfig.md#topp) - Only samples from the smallest set of tokens whose cumulative probability exceeds topP. Range `[0, 1]`. Values of `0` or `1` disable top-p filtering. + + - [`minP`](../../06-api-reference/interfaces/GenerationConfig.md#minp) - Minimum-probability threshold applied after softmax: tokens whose probability is below `minP * max_prob` are excluded from sampling. Range `[0, 1]`. Default `0` disables the filter. Stacks with `topP` when both are set. + + - [`repetitionPenalty`](../../06-api-reference/interfaces/GenerationConfig.md#repetitionpenalty) - Multiplicative penalty applied to logits of tokens that already appeared in the prompt or the generated text. Values greater than `1` discourage repetition; default `1` disables the penalty. + +:::info[Built-in models ship with sampling defaults] +Model presets expose an optional `generationConfig` that `LLMModule.fromModelName` applies automatically when available — for Qwen3 and LFM2-VL this means the model-card recommended sampling settings are in effect without any explicit `configure` call. Any fields you pass to `configure` still override on a per-field basis. +::: ## Vision-Language Models (VLM) Some models support multimodal input — text and images together. 
To use them, pass `capabilities` in the model object when calling [`fromModelName`](../../06-api-reference/classes/LLMModule.md#frommodelname): ```typescript -import { LLMModule, LFM2_VL_1_6B_QUANTIZED } from 'react-native-executorch'; +import { LLMModule, LFM2_5_VL_1_6B_QUANTIZED } from 'react-native-executorch'; const llm = await LLMModule.fromModelName( - LFM2_VL_1_6B_QUANTIZED, + LFM2_5_VL_1_6B_QUANTIZED, undefined, (token) => console.log(token) ); diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index 04b0accd34..d311a3de78 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -140,6 +140,15 @@ template class ModelHostObject : public JsiHostObject { synchronousHostFunction<&Model::setTopp>, "setTopp")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::setMinP>, + "setMinP")); + + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, + synchronousHostFunction<&Model::setRepetitionPenalty>, + "setRepetitionPenalty")); + addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, synchronousHostFunction<&Model::getMaxContextLength>, diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index 10aff5cc77..7e0fa4b26e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -250,6 +250,30 @@ void LLM::setTopp(float topp) { runner_->set_topp(topp); } +void LLM::setMinP(float minP) { + if (!runner_ || !runner_->is_loaded()) { + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, + "Can't configure a model that's not loaded"); + } + if (minP < 0.0f || minP > 1.0f) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, + "Min-p must be between 0.0 and 1.0"); + } + runner_->set_min_p(minP); +} + +void LLM::setRepetitionPenalty(float repetitionPenalty) { + if (!runner_ || !runner_->is_loaded()) { + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, + "Can't configure a model that's not loaded"); + } + if (repetitionPenalty < 0.0f) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, + "Repetition penalty must be non-negative"); + } + runner_->set_repetition_penalty(repetitionPenalty); +} + int32_t LLM::getMaxContextLength() const { if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError( diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h index 5c9bc258d7..222b5bc62f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h @@ -38,6 +38,8 @@ class LLM : public BaseModel { void setCountInterval(size_t countInterval); void setTemperature(float temperature); void setTopp(float topp); + void setMinP(float minP); + void setRepetitionPenalty(float repetitionPenalty); void setTimeInterval(size_t timeInterval); int32_t getMaxContextLength() const; diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index 46f361753e..8286518217 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -151,6 +151,12 @@ add_rn_test(RunnerTests unit/RunnerTest.cpp integration/stubs/jsi_stubs.cpp LIBS tokenizers_deps ) +add_rn_test(SamplerTests unit/SamplerTest.cpp + SOURCES + ${COMMON_DIR}/runner/sampler.cpp + ${COMMON_DIR}/runner/arange_util.cpp + LIBS +) add_rn_test(LogTests unit/LogTest.cpp) add_rn_test(FileUtilsTest unit/FileUtilsTest.cpp) add_rn_test(ImageProcessingTest unit/ImageProcessingTest.cpp diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp index cad365cc61..ae0a11e777 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp @@ -110,6 +110,31 @@ TEST_F(LLMTest, SetToppInvalidThrows) { EXPECT_THROW(model.setTopp(1.1f), RnExecutorchError); } +TEST_F(LLMTest, SetMinP) { + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); + EXPECT_NO_THROW(model.setMinP(0.0f)); + EXPECT_NO_THROW(model.setMinP(0.15f)); + EXPECT_NO_THROW(model.setMinP(1.0f)); +} + +TEST_F(LLMTest, SetMinPInvalidThrows) { + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); + EXPECT_THROW(model.setMinP(-0.1f), RnExecutorchError); + EXPECT_THROW(model.setMinP(1.1f), RnExecutorchError); +} + +TEST_F(LLMTest, SetRepetitionPenalty) { + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); + EXPECT_NO_THROW(model.setRepetitionPenalty(1.0f)); + EXPECT_NO_THROW(model.setRepetitionPenalty(1.05f)); + EXPECT_NO_THROW(model.setRepetitionPenalty(2.0f)); +} + +TEST_F(LLMTest, SetRepetitionPenaltyInvalidThrows) { + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); + EXPECT_THROW(model.setRepetitionPenalty(-0.1f), RnExecutorchError); +} + TEST_F(LLMTest, SetCountInterval) { LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_NO_THROW(model.setCountInterval(5)); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/StubRunner.h b/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/StubRunner.h index 65f1dde66b..023d6cf080 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/StubRunner.h +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/StubRunner.h @@ -18,10 +18,6 @@ class StubRunner : public ::executorch::extension::llm::BaseLLMRunner { return ::executorch::runtime::Error::Ok; } void stop_impl() override {} - void set_temperature_impl(float t) override { last_temp_ = t; } - void set_topp_impl(float) override {} - void set_count_interval_impl(size_t) override {} - void set_time_interval_impl(size_t) override {} int32_t resolve_max(int32_t prompt, int32_t seq_len, int32_t ctx_len, int32_t max_new = -1) const { @@ -29,5 +25,4 @@ class StubRunner : public ::executorch::extension::llm::BaseLLMRunner { } bool loaded_ = false; - float last_temp_ = -1.f; }; diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/RunnerTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/RunnerTest.cpp index d7bd344049..77d0d4ae8d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/unit/RunnerTest.cpp +++ 
b/packages/react-native-executorch/common/rnexecutorch/tests/unit/RunnerTest.cpp @@ -62,11 +62,10 @@ TEST(MultimodalInputTest, EmptyStringIsStillText) { // BaseLLMRunner via StubRunner // ============================================================================ -TEST(BaseLLMRunnerTest, SetTemperatureUpdatesConfigAndCallsImpl) { +TEST(BaseLLMRunnerTest, SetTemperatureUpdatesConfig) { StubRunner runner(nullptr, "dummy"); runner.set_temperature(0.42f); EXPECT_FLOAT_EQ(runner.config_.temperature, 0.42f); - EXPECT_FLOAT_EQ(runner.last_temp_, 0.42f); } TEST(BaseLLMRunnerTest, SetToppUpdatesConfig) { @@ -89,3 +88,15 @@ TEST(BaseLLMRunnerTest, GenerateEmptyStringReturnsError) { auto err = runner.generate("", {}, {}, {}); EXPECT_NE(err, ::executorch::runtime::Error::Ok); } + +TEST(BaseLLMRunnerTest, SetMinPUpdatesConfig) { + StubRunner runner(nullptr, "dummy"); + runner.set_min_p(0.15f); + EXPECT_FLOAT_EQ(runner.config_.min_p, 0.15f); +} + +TEST(BaseLLMRunnerTest, SetRepetitionPenaltyUpdatesConfig) { + StubRunner runner(nullptr, "dummy"); + runner.set_repetition_penalty(1.05f); + EXPECT_FLOAT_EQ(runner.config_.repetition_penalty, 1.05f); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp new file mode 100644 index 0000000000..4295f16232 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp @@ -0,0 +1,85 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +using namespace executorch::extension::llm; + +// Helper: run sampler N times, count how often each index is picked. +template +std::vector sampleMany(Sampler &s, std::vector logits, + const std::vector &recent, int n) { + std::vector counts(logits.size(), 0); + for (int i = 0; i < n; ++i) { + std::vector copy = logits; + counts[s.sample(copy.data(), recent)]++; + } + return counts; +} + +// 1. Repetition penalty on positive logit: token 0 should be sampled less. +TEST(SamplerTest, RepetitionPenaltyReducesPositiveLogit) { + Sampler s(2, 1.0f, 1.0f, 0, 0.0f, 1.3f); + std::vector logits = {1.0f, 1.0f}; + std::vector recent = {0}; + auto counts = sampleMany(s, logits, recent, 2000); + EXPECT_LT(counts[0], 1200); +} + +// 2. Repetition penalty on negative logit: penalised token should appear even +// less. +TEST(SamplerTest, RepetitionPenaltyMultipliesNegativeLogit) { + Sampler s(2, 1.0f, 1.0f, 0, 0.0f, 1.5f); + std::vector logits = {0.0f, -1.0f}; + std::vector recent = {1}; + auto counts = sampleMany(s, logits, recent, 2000); + EXPECT_LT(counts[1], 200); +} + +// 3. No recent tokens — penalty has no effect. +TEST(SamplerTest, RepetitionPenaltyNoRecentTokensHasNoEffect) { + Sampler baseline(2, 1.0f, 1.0f, 0, 0.0f, 1.0f); + Sampler penalised(2, 1.0f, 1.0f, 0, 0.0f, 2.0f); + std::vector logits_b = {1.0f, 1.0f}; + std::vector logits_p = {1.0f, 1.0f}; + std::vector recent = {}; + auto cb = sampleMany(baseline, logits_b, recent, 2000); + auto cp = sampleMany(penalised, logits_p, recent, 2000); + EXPECT_NEAR(cb[0], cp[0], 300); +} + +// 4. Min-p truncation: token with very low probability is excluded. 
+TEST(SamplerTest, MinPFiltersTailTokens) { + Sampler s(3, 1.0f, 1.0f, 0, 0.1f, 1.0f); + std::vector logits = {5.0f, -5.0f, -5.0f}; + std::vector recent = {}; + auto counts = sampleMany(s, logits, recent, 1000); + EXPECT_EQ(counts[1], 0); + EXPECT_EQ(counts[2], 0); + EXPECT_EQ(counts[0], 1000); +} + +// 5. Min-p = 0 disables filtering. +TEST(SamplerTest, MinPZeroDisablesFiltering) { + Sampler s(3, 0.0f, 1.0f, 0, 0.0f, 1.0f); + std::vector logits = {1.0f, -1000.0f, -1000.0f}; + std::vector recent = {}; + EXPECT_EQ(s.sample(logits.data(), recent), 0); +} + +// 6. Min-p + top-p stacked. +TEST(SamplerTest, MinPAndToppStack) { + Sampler s(4, 1.0f, 0.5f, 0, 0.2f, 1.0f); + std::vector logits = {5.0f, 2.0f, -2.0f, -5.0f}; + std::vector recent = {}; + auto counts = sampleMany(s, logits, recent, 2000); + EXPECT_EQ(counts[2], 0); + EXPECT_EQ(counts[3], 0); +} diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.cpp b/packages/react-native-executorch/common/runner/base_llm_runner.cpp index d38a41b76b..a021040807 100644 --- a/packages/react-native-executorch/common/runner/base_llm_runner.cpp +++ b/packages/react-native-executorch/common/runner/base_llm_runner.cpp @@ -139,20 +139,22 @@ int32_t BaseLLMRunner::get_max_context_length() const { void BaseLLMRunner::set_temperature(float temperature) noexcept { config_.temperature = temperature; - set_temperature_impl(temperature); } -void BaseLLMRunner::set_topp(float topp) noexcept { - config_.topp = topp; - set_topp_impl(topp); +void BaseLLMRunner::set_topp(float topp) noexcept { config_.topp = topp; } + +void BaseLLMRunner::set_min_p(float min_p) noexcept { config_.min_p = min_p; } + +void BaseLLMRunner::set_repetition_penalty(float repetition_penalty) noexcept { + config_.repetition_penalty = repetition_penalty; } void BaseLLMRunner::set_count_interval(size_t count_interval) { - set_count_interval_impl(count_interval); + config_.output_token_batch_size = count_interval; } void BaseLLMRunner::set_time_interval(size_t time_interval) { - set_time_interval_impl(time_interval); + config_.batch_time_interval_ms = time_interval; } int32_t BaseLLMRunner::resolve_max_new_tokens(int32_t num_prompt_tokens, diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.h b/packages/react-native-executorch/common/runner/base_llm_runner.h index 3924aa3d7a..9710f5ae70 100644 --- a/packages/react-native-executorch/common/runner/base_llm_runner.h +++ b/packages/react-native-executorch/common/runner/base_llm_runner.h @@ -53,6 +53,8 @@ class BaseLLMRunner { void set_temperature(float temperature) noexcept; void set_topp(float topp) noexcept; + void set_min_p(float min_p) noexcept; + void set_repetition_penalty(float repetition_penalty) noexcept; void set_count_interval(size_t count_interval); void set_time_interval(size_t time_interval); @@ -65,10 +67,12 @@ class BaseLLMRunner { protected: virtual ::executorch::runtime::Error load_subcomponents() = 0; virtual void stop_impl() = 0; - virtual void set_temperature_impl(float temperature) = 0; - virtual void set_topp_impl(float topp) = 0; - virtual void set_count_interval_impl(size_t count_interval) = 0; - virtual void set_time_interval_impl(size_t time_interval) = 0; + // Sampling values and token-batching intervals live entirely in `config_`. 
+ // The TextDecoderRunner / TextTokenGenerator shared by both TextRunner and + // MultimodalRunner are constructed with a const reference to `config_` + // and read those fields on every iteration, so writes via the public + // set_* methods on BaseLLMRunner take effect immediately with no virtual + // dispatch needed. int32_t resolve_max_new_tokens(int32_t num_prompt_tokens, int32_t max_seq_len, int32_t max_context_len, diff --git a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp index 4f664a173e..de3e196c1f 100644 --- a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp +++ b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp @@ -77,15 +77,23 @@ Result VisionEncoder::getInputShape() const { std::vector VisionEncoder::preprocessImage(const std::string &path, const ImageShape &targetShape) const { - cv::Mat mat = rnexecutorch::image_processing::readImage(path); - cv::resize(mat, mat, cv::Size(targetShape.width, targetShape.height)); - cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB); + // The bundled vision-encoder PTEs (e.g. LFM2.5-VL) bake rescale + normalize + // into the exported graph, so we hand raw 0-255 float pixel values to the + // module. Adding rescale / normalize here would double-apply the transform + // and destroy the input distribution. We reuse `resizePadded` for the + // aspect-ratio-preserving letterbox (it picks the pad colour from the + // source image corners, which blends better than a flat gray), then + // convert BGR->RGB and repack the raw pixels into CHW float. + cv::Mat src = rnexecutorch::image_processing::readImage(path); + cv::Mat canvas = rnexecutorch::image_processing::resizePadded( + src, cv::Size(targetShape.width, targetShape.height)); + cv::cvtColor(canvas, canvas, cv::COLOR_BGR2RGB); const int32_t pixelCount = targetShape.height * targetShape.width; std::vector chw(targetShape.channels * pixelCount); for (int32_t i = 0; i < pixelCount; ++i) { cv::Vec3b px = - mat.at(i / targetShape.width, i % targetShape.width); + canvas.at(i / targetShape.width, i % targetShape.width); for (int32_t c = 0; c < targetShape.channels; ++c) { chw[c * pixelCount + i] = static_cast(px[c]); } diff --git a/packages/react-native-executorch/common/runner/irunner.h b/packages/react-native-executorch/common/runner/irunner.h index 8dedc6687e..54b14c354f 100644 --- a/packages/react-native-executorch/common/runner/irunner.h +++ b/packages/react-native-executorch/common/runner/irunner.h @@ -58,6 +58,21 @@ struct GenerationConfig { // = more deterministic, higher = more diverse generations. float topp = -1.F; + // Minimum probability threshold: tokens with prob < min_p * max_prob are + // excluded. 0.0 disables min_p filtering. + float min_p = 0.0f; + + // Multiplicative penalty applied to logits of recently generated tokens. + // Values > 1.0 discourage repetition. 1.0 disables the penalty. + float repetition_penalty = 1.0f; + + // Token-batching parameters for the streaming token callback. The + // generator flushes a batch when either `output_token_batch_size` tokens + // have accumulated or `batch_time_interval_ms` milliseconds have elapsed + // since the last flush, whichever comes first. + size_t output_token_batch_size = 10; + size_t batch_time_interval_ms = 120; + // Enable dynamic input shapes (if implemented) or not // Impacts the prefill phase and causes TextPrefiller to pass all the tokens // at once if set to true. 
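The new `GenerationConfig` fields above drive every sampling step. As a rough illustration only (not part of the patch), here is a TypeScript sketch of how `min_p` and `repetition_penalty` reshape the logits before sampling, mirroring the `apply_repetition_penalty` and `apply_min_p` helpers introduced in `sampler.h` further down. Plain `number[]` arrays stand in for the tensor data, and a non-zero temperature is assumed; temperature 0 means greedy argmax in the runner and skips these filters entirely.

```ts
// Hypothetical standalone sketch of the sampling filters added in this patch.
// It follows Sampler::apply_repetition_penalty / apply_min_p from sampler.h,
// but operates on plain arrays instead of the logits buffer.
function filterLogits(
  logits: number[],
  recentTokens: number[],
  temperature: number,
  minP: number,
  repetitionPenalty: number
): number[] {
  const out = [...logits];

  // 1. Repetition penalty: shrink logits of tokens already seen in the prompt
  //    or in generated text. Positive logits are divided, negative ones
  //    multiplied, so the penalty always pushes towards "less likely".
  if (repetitionPenalty !== 1) {
    for (const id of recentTokens) {
      out[id] =
        out[id] > 0 ? out[id] / repetitionPenalty : out[id] * repetitionPenalty;
    }
  }

  // 2. Temperature scaling, then softmax (assumes temperature > 0).
  const scaled = out.map((l) => l / temperature);
  const max = Math.max(...scaled);
  const exps = scaled.map((l) => Math.exp(l - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  let probs = exps.map((e) => e / sum);

  // 3. Min-p truncation: drop tokens whose probability falls below
  //    minP * max_prob, then renormalise the survivors.
  if (minP > 0) {
    const threshold = minP * Math.max(...probs);
    probs = probs.map((p) => (p < threshold ? 0 : p));
    const kept = probs.reduce((a, b) => a + b, 0);
    probs = probs.map((p) => p / kept);
  }

  // The runner then samples from this distribution, optionally applying
  // top-p truncation on top of it.
  return probs;
}
```
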
diff --git a/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h index e6fa4e7eae..071b193539 100644 --- a/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h +++ b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h @@ -16,8 +16,9 @@ namespace executorch::extension::llm { class MultimodalDecoderRunner : public TextDecoderRunner { public: - explicit MultimodalDecoderRunner(Module &module, IOManager *io_manager) - : TextDecoderRunner(module, io_manager) {} + explicit MultimodalDecoderRunner(Module &module, IOManager *io_manager, + const GenerationConfig &config) + : TextDecoderRunner(module, io_manager, config) {} inline ::executorch::runtime::Result<::executorch::aten::Tensor> step(TensorPtr &tokens, int64_t start_pos) override { diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.cpp b/packages/react-native-executorch/common/runner/multimodal_runner.cpp index c81d15760e..767fef9f38 100644 --- a/packages/react-native-executorch/common/runner/multimodal_runner.cpp +++ b/packages/react-native-executorch/common/runner/multimodal_runner.cpp @@ -47,8 +47,8 @@ Error MultimodalRunner::load_subcomponents() { Stats *stats_ptr = &stats_; - mm_decoder_runner_ = - std::make_unique(*module_, io_manager_.get()); + mm_decoder_runner_ = std::make_unique( + *module_, io_manager_.get(), config_); IEncoder *image_encoder = nullptr; auto enc_it = encoders_.find(MultimodalType::Image); if (enc_it != encoders_.end()) { @@ -58,7 +58,7 @@ Error MultimodalRunner::load_subcomponents() { *module_, *mm_decoder_runner_, *tokenizer_, image_encoder); mm_token_generator_ = std::make_unique( tokenizer_.get(), mm_decoder_runner_.get(), /*use_kv_cache=*/true, - std::move(eos_ids_), stats_ptr); + std::move(eos_ids_), stats_ptr, config_); ET_CHECK_OK_OR_RETURN_ERROR(mm_prefiller_->load()); ET_CHECK_OK_OR_RETURN_ERROR(mm_token_generator_->load()); @@ -106,7 +106,7 @@ Error MultimodalRunner::generate_internal( auto generate_result = mm_token_generator_->generate( seed_tokens, pos_, static_cast(std::max(0, resolved_max_new - 1)), - config_.temperature, config_.topp, wrapped_callback); + wrapped_callback); if (!generate_result.ok()) return generate_result.error(); @@ -125,16 +125,4 @@ void MultimodalRunner::stop_impl() { } } -void MultimodalRunner::set_count_interval_impl(size_t count_interval) { - if (mm_token_generator_) { - mm_token_generator_->set_count_interval(count_interval); - } -} - -void MultimodalRunner::set_time_interval_impl(size_t time_interval) { - if (mm_token_generator_) { - mm_token_generator_->set_time_interval(time_interval); - } -} - } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.h b/packages/react-native-executorch/common/runner/multimodal_runner.h index 3c31c0165f..d24e0b40c2 100644 --- a/packages/react-native-executorch/common/runner/multimodal_runner.h +++ b/packages/react-native-executorch/common/runner/multimodal_runner.h @@ -30,10 +30,6 @@ class MultimodalRunner : public BaseLLMRunner { protected: ::executorch::runtime::Error load_subcomponents() override; void stop_impl() override; - void set_temperature_impl(float) override {} - void set_topp_impl(float) override {} - void set_count_interval_impl(size_t count_interval) override; - void set_time_interval_impl(size_t time_interval) override; private: std::map> encoders_; diff --git 
a/packages/react-native-executorch/common/runner/sampler.cpp b/packages/react-native-executorch/common/runner/sampler.cpp index a5fadef93c..26c75d4dd5 100644 --- a/packages/react-native-executorch/common/runner/sampler.cpp +++ b/packages/react-native-executorch/common/runner/sampler.cpp @@ -35,6 +35,7 @@ #include "sampler.h" #include #include +#include namespace executorch { namespace extension { @@ -119,16 +120,16 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) { return probindex[last_idx].index; // in case of rounding errors } -Sampler::Sampler(int vocab_size, float temperature, float topp, - unsigned long long rng_seed) +Sampler::Sampler(int32_t vocab_size, float temperature, float topp, + unsigned long long rng_seed, float min_p, + float repetition_penalty) : vocab_size_(vocab_size), inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f), - topp_(topp), rng_state_(rng_seed) {} + topp_(topp), min_p_(min_p), repetition_penalty_(repetition_penalty), + rng_state_(rng_seed) {} Sampler::Sampler(int vocab_size, float temperature, float topp) - : vocab_size_(vocab_size), - inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f), - topp_(topp), rng_state_(std::time(nullptr)) {} + : Sampler(vocab_size, temperature, topp, std::time(nullptr), 0.0f, 1.0f) {} template static void softmax(T *x, int size) { // find max value (for numerical stability) @@ -162,22 +163,25 @@ static float random_f32(unsigned long long *state) { // random float32 in [0,1) return (random_u32(state) >> 8) / 16777216.0f; } -template int32_t Sampler::sample(T *logits) { +template +int32_t Sampler::sample(T *logits, const std::vector &recent_tokens) { // sample the token given the logits and some hyperparameters int next; if (inv_temperature_ == 0.0f) { // greedy argmax sampling: take the token with the highest probability next = sample_argmax(logits); } else { - // apply the temperature to the logits - for (int q = 0; q < vocab_size_; q++) { - logits[q] *= inv_temperature_; - } - // apply softmax to the logits to get the probabilities for next token + // 1. apply repetition penalty to raw logits (pre-softmax) + apply_repetition_penalty(logits, vocab_size_, recent_tokens); + // 2. apply the temperature to the logits + apply_temperature(logits, vocab_size_); + // 3. apply softmax to the logits to get the probabilities for next token softmax(logits, vocab_size_); + // 4. apply min_p truncation + apply_min_p(logits, vocab_size_); // flip a (float) coin (this is our source of entropy for sampling) float coin = random_f32(&rng_state_); - // we sample from this distribution to get the next token + // 5. 
we sample from this distribution to get the next token if (topp_ <= 0 || topp_ >= 1) { // simply sample from the predicted probability distribution next = sample_mult(logits, coin); @@ -189,6 +193,10 @@ template int32_t Sampler::sample(T *logits) { return next; } +template int32_t Sampler::sample(T *logits) { + return sample(logits, {}); +} + template int32_t Sampler::sample(float *logits); template int32_t Sampler::sample(uint16_t *logits); template int32_t @@ -196,6 +204,17 @@ Sampler::sample(executorch::aten::Half *logits); template int32_t Sampler::sample(executorch::aten::BFloat16 *logits); +template int32_t Sampler::sample(float *logits, + const std::vector &); +template int32_t Sampler::sample(uint16_t *logits, + const std::vector &); +template int32_t +Sampler::sample(executorch::aten::Half *logits, + const std::vector &); +template int32_t +Sampler::sample(executorch::aten::BFloat16 *logits, + const std::vector &); + } // namespace llm } // namespace extension } // namespace executorch diff --git a/packages/react-native-executorch/common/runner/sampler.h b/packages/react-native-executorch/common/runner/sampler.h index a46a5ed12c..16811297ef 100644 --- a/packages/react-native-executorch/common/runner/sampler.h +++ b/packages/react-native-executorch/common/runner/sampler.h @@ -8,12 +8,15 @@ #pragma once +#include #include #include #include #include #include #include +#include +#include #ifdef USE_ATEN_LIB #include #endif @@ -36,22 +39,77 @@ template struct ProbIndex { class Sampler { public: Sampler(int32_t vocab_size, float temperature, float topp, - unsigned long long rng_seed); + unsigned long long rng_seed, float min_p = 0.0f, + float repetition_penalty = 1.0f); Sampler(int32_t vocab_size, float temperature, float topp); template int32_t sample(T *logits); + template + int32_t sample(T *logits, const std::vector &recent_tokens); + private: template int32_t sample_topp(T *probabilities, float coin); template int32_t sample_mult(T *probabilities, float coin); template int32_t sample_argmax(T *probabilities); + template + inline void apply_temperature(T *logits, int32_t vocab_size) { + for (std::size_t i = 0; std::cmp_less(i, vocab_size); ++i) { + logits[i] = + static_cast(static_cast(logits[i]) * inv_temperature_); + } + } + + template + inline void + apply_repetition_penalty(T *logits, int32_t vocab_size, + const std::vector &recent_tokens) { + if (repetition_penalty_ == 1.0f || recent_tokens.empty()) + return; + for (uint64_t id : recent_tokens) { + if (!std::cmp_less(id, vocab_size)) { + continue; + } + T &val = logits[id]; + if (val > T(0)) { + val = static_cast(static_cast(val) / repetition_penalty_); + } else { + val = static_cast(static_cast(val) * repetition_penalty_); + } + } + } + + template + inline void apply_min_p(T *probabilities, int32_t vocab_size) { + if (min_p_ <= 0.0f) { + return; + } + T max_prob = *std::max_element(probabilities, probabilities + vocab_size); + T threshold = static_cast(min_p_ * static_cast(max_prob)); + T sum = T(0); + for (std::size_t i = 0; std::cmp_less(i, vocab_size); ++i) { + if (probabilities[i] < threshold) { + probabilities[i] = T(0); + } else { + sum += probabilities[i]; + } + } + if (sum > T(0)) { + for (std::size_t i = 0; std::cmp_less(i, vocab_size); ++i) { + probabilities[i] /= sum; + } + } + } + private: int32_t vocab_size_; // reciprocal of temperature, or 0 if temperature == 0. 
float inv_temperature_; float topp_; + float min_p_; + float repetition_penalty_; unsigned long long rng_state_; }; diff --git a/packages/react-native-executorch/common/runner/text_decoder_runner.cpp b/packages/react-native-executorch/common/runner/text_decoder_runner.cpp index fdd9e4489c..e67d3e41fb 100644 --- a/packages/react-native-executorch/common/runner/text_decoder_runner.cpp +++ b/packages/react-native-executorch/common/runner/text_decoder_runner.cpp @@ -10,6 +10,7 @@ #include "text_decoder_runner.h" #include "arange_util.h" +#include "irunner.h" #include "stats.h" #include @@ -22,9 +23,8 @@ namespace llm { // and a ~5% improvement on Galaxy S22 by switching to // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors. TextDecoderRunner::TextDecoderRunner(Module &module, IOManager *io_manager, - float temperature, float topp) - : module_(&module), io_manager_(io_manager), temperature_(temperature), - topp_(topp) {} + const GenerationConfig &config) + : module_(&module), io_manager_(io_manager), config_(config) {} // This function is functional, meaning it shouldn't modify any state of the // input. It should be safe to call multiple times with the same inputs. The @@ -82,6 +82,34 @@ TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) { } } +int32_t TextDecoderRunner::logits_to_token( + const executorch::aten::Tensor &logits_tensor, + const std::vector &recent_tokens) { + int32_t result = 0; + + struct { + [[noreturn]] void fail(torch::executor::Error) { + ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token"); + } + } ctx; + + ET_SWITCH_FOUR_TYPES( + Float, Half, BFloat16, UInt16, logits_tensor.scalar_type(), ctx, + "logits_to_token", CTYPE, [&]() { + auto *logits = logits_tensor.mutable_data_ptr(); + ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1); + if (logits_tensor.dim() == 3) { + auto num_tokens = logits_tensor.size(1); + logits += (num_tokens - 1) * vocab_size; + } + Sampler sampler(vocab_size, config_.temperature, config_.topp, + static_cast(std::time(nullptr)), + config_.min_p, config_.repetition_penalty); + result = sampler.sample(logits, recent_tokens); + }); + return result; +} + } // namespace llm } // namespace extension } // namespace executorch diff --git a/packages/react-native-executorch/common/runner/text_decoder_runner.h b/packages/react-native-executorch/common/runner/text_decoder_runner.h index bd318da526..bffc254bd6 100644 --- a/packages/react-native-executorch/common/runner/text_decoder_runner.h +++ b/packages/react-native-executorch/common/runner/text_decoder_runner.h @@ -17,10 +17,15 @@ namespace executorch { namespace extension { namespace llm { +// Forward declaration to avoid the include chain +// irunner.h -> stats.h -> util.h -> text_prefiller.h -> text_decoder_runner.h +// which would re-enter this header before TextDecoderRunner is defined. +struct GenerationConfig; + class TextDecoderRunner { public: explicit TextDecoderRunner(Module &module, IOManager *io_manager, - float temperature = 0.8F, float topp = 0.9F); + const GenerationConfig &config); virtual ~TextDecoderRunner() = default; @@ -50,61 +55,23 @@ class TextDecoderRunner { return module_->is_method_loaded("forward"); } - virtual void set_temperature(float temperature) noexcept { - temperature_ = temperature; - } - - virtual void set_topp(float topp) noexcept { topp_ = topp; } - inline void stop() { should_stop_ = true; } /** - * Sample the next token from the logits tensor. - * @param logits_tensor The logits tensor. 
- * @param temperature The temperature parameter used to control randomness in - * sampling. - * @return The next token. + * Sample the next token from the logits tensor using the sampling + * parameters in `config_`. `recent_tokens` is the prompt-plus-generated + * window used by the repetition penalty. Defined out-of-line in the cpp + * so this header doesn't need a complete `GenerationConfig` type. */ - inline int32_t logits_to_token(const executorch::aten::Tensor &logits_tensor, - float temperature = -1.F, float topp = -1.F) { - int32_t result = 0; - - temperature = temperature < 0.F ? temperature_ : temperature; - topp = topp < 0.F ? topp_ : topp; - - // Create a minimal context for error handling in ET_SWITCH - struct { - [[noreturn]] void fail(torch::executor::Error /* error */) { - ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token"); - } - } ctx; - - ET_SWITCH_FOUR_TYPES( - Float, Half, BFloat16, UInt16, logits_tensor.scalar_type(), ctx, - "logits_to_token", CTYPE, [&]() { - // If the logit_tensor rank is 3, the shape is [batch, seq_length, - // vocab_size], get the last logits, sample and return. Else the model - // outputs the last logit, directly sample and return. - auto *logits = logits_tensor.mutable_data_ptr(); - ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1); - if (logits_tensor.dim() == 3) { - auto num_tokens = logits_tensor.size(1); - logits += (num_tokens - 1) * vocab_size; - } - // @lint-ignore CLANGTIDY facebook-hte-Deprecated - Sampler sampler(vocab_size, temperature, topp); - result = sampler.sample(logits); - }); - return result; - } + int32_t logits_to_token(const executorch::aten::Tensor &logits_tensor, + const std::vector &recent_tokens = {}); protected: // Non-owning. The runner (BaseLLMRunner) owns the Module and outlives this. 
Module *module_; IOManager *io_manager_; bool should_stop_{false}; - float temperature_; - float topp_; + const GenerationConfig &config_; }; } // namespace llm diff --git a/packages/react-native-executorch/common/runner/text_runner.cpp b/packages/react-native-executorch/common/runner/text_runner.cpp index 6a0e202084..5a75e00b4a 100644 --- a/packages/react-native-executorch/common/runner/text_runner.cpp +++ b/packages/react-native-executorch/common/runner/text_runner.cpp @@ -27,13 +27,13 @@ Error TextRunner::load_subcomponents() { Stats *stats_ptr = &stats_; text_decoder_runner_ = std::make_unique( - *module_, io_manager_.get(), config_.temperature, config_.topp); + *module_, io_manager_.get(), config_); text_prefiller_ = std::make_unique( text_decoder_runner_.get(), config_.enable_kv_cache, config_.enable_dynamic_shape, config_.max_seq_len); text_token_generator_ = std::make_unique( tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache, - std::move(eos_ids_), stats_ptr); + std::move(eos_ids_), stats_ptr, config_); return Error::Ok; } @@ -112,8 +112,7 @@ Error TextRunner::generate_internal( prompt_tokens.push_back(cur_token); int64_t num_generated = ET_UNWRAP(text_token_generator_->generate( - prompt_tokens, pos_, max_new_tokens - 1, config_.temperature, - config_.topp, wrapped_callback)); + prompt_tokens, pos_, max_new_tokens - 1, wrapped_callback)); pos_ += num_generated; stats_.inference_end_ms = time_in_ms(); @@ -128,26 +127,4 @@ void TextRunner::stop_impl() { text_token_generator_->stop(); } -void TextRunner::set_temperature_impl(float temperature) { - if (text_decoder_runner_) - text_decoder_runner_->set_temperature(temperature); -} - -void TextRunner::set_topp_impl(float topp) { - if (text_decoder_runner_) - text_decoder_runner_->set_topp(topp); -} - -void TextRunner::set_count_interval_impl(size_t count_interval) { - if (text_token_generator_) { - text_token_generator_->set_count_interval(count_interval); - } -} - -void TextRunner::set_time_interval_impl(size_t time_interval) { - if (text_token_generator_) { - text_token_generator_->set_time_interval(time_interval); - } -} - } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/text_runner.h b/packages/react-native-executorch/common/runner/text_runner.h index 4fce0e8150..096d3498d7 100644 --- a/packages/react-native-executorch/common/runner/text_runner.h +++ b/packages/react-native-executorch/common/runner/text_runner.h @@ -24,10 +24,6 @@ class TextRunner : public BaseLLMRunner { protected: ::executorch::runtime::Error load_subcomponents() override; void stop_impl() override; - void set_temperature_impl(float temperature) override; - void set_topp_impl(float topp) override; - void set_count_interval_impl(size_t count_interval) override; - void set_time_interval_impl(size_t time_interval) override; private: std::unique_ptr text_decoder_runner_; diff --git a/packages/react-native-executorch/common/runner/text_token_generator.h b/packages/react-native-executorch/common/runner/text_token_generator.h index 024cce456e..7ecf6177a9 100644 --- a/packages/react-native-executorch/common/runner/text_token_generator.h +++ b/packages/react-native-executorch/common/runner/text_token_generator.h @@ -9,6 +9,7 @@ // Generate tokens in a loop. 
#pragma once +#include "irunner.h" #include "stats.h" #include "text_decoder_runner.h" #include "util.h" @@ -22,13 +23,18 @@ namespace llm { class TextTokenGenerator { public: + // Holds a const reference to the owning runner's GenerationConfig and + // reads `output_token_batch_size` / `batch_time_interval_ms` from it on + // every iteration of the generation loop, so external updates take effect + // mid-stream without any sync setter on this class. TextTokenGenerator(tokenizers::HFTokenizer *tokenizer, TextDecoderRunner *text_decoder_runner, bool use_kv_cache, std::unique_ptr> &&eos_ids, - Stats *stats) + Stats *stats, const GenerationConfig &config) : tokenizer_(tokenizer), text_decoder_runner_(text_decoder_runner), eos_ids_(std::move(eos_ids)), use_kv_cache_(use_kv_cache), - timestamp_(std::chrono::high_resolution_clock::now()), stats_(stats) {} + timestamp_(std::chrono::high_resolution_clock::now()), stats_(stats), + config_(config) {} virtual ~TextTokenGenerator() = default; @@ -48,7 +54,6 @@ class TextTokenGenerator { */ inline ::executorch::runtime::Result generate( std::vector tokens, int64_t start_pos, uint64_t max_new_tokens, - float temperature, float topp, const std::function &token_callback = {}) { ET_CHECK_MSG(!tokens.empty(), "Token generation loop shouldn't take empty tokens"); @@ -80,6 +85,7 @@ class TextTokenGenerator { auto tokens_managed = from_blob(token_data.data(), token_shape, executorch::aten::ScalarType::Long); + std::vector generated_tokens(tokens.begin(), tokens.end()); should_stop_ = false; timestamp_ = std::chrono::high_resolution_clock::now(); @@ -94,11 +100,12 @@ class TextTokenGenerator { prev_token = cur_token; stats_->on_sampling_begin(); - cur_token = text_decoder_runner_->logits_to_token(logits_tensor, - temperature, topp); + cur_token = + text_decoder_runner_->logits_to_token(logits_tensor, generated_tokens); stats_->on_sampling_end(); pos++; + generated_tokens.push_back(cur_token); if (use_kv_cache_) { // update the token tensor. token_data will not be empty. @@ -128,9 +135,10 @@ class TextTokenGenerator { std::string cache_decoded = decodeResult.get(); const auto timeIntervalElapsed = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - timestamp_) > - time_interval_; - const auto countIntervalElapsed = token_cache.size() > count_interval_; + std::chrono::high_resolution_clock::now() - timestamp_) + .count() > static_cast(config_.batch_time_interval_ms); + const auto countIntervalElapsed = + token_cache.size() > config_.output_token_batch_size; const auto eos_reached = eos_ids_->contains(cur_token); if (!cache_decoded.ends_with("�") && @@ -177,14 +185,6 @@ class TextTokenGenerator { return text_decoder_runner_->is_method_loaded(); } - void set_count_interval(size_t count_interval) { - count_interval_ = count_interval; - } - - void set_time_interval(size_t time_interval) { - time_interval_ = std::chrono::milliseconds(time_interval); - } - private: /** * Note: TextTokenGenerator does not own the tokenizer_ and @@ -199,12 +199,14 @@ class TextTokenGenerator { // state machine bool should_stop_ = false; - size_t count_interval_{10}; - std::chrono::milliseconds time_interval_{120}; std::chrono::high_resolution_clock::time_point timestamp_; // stats Stats *stats_; + + // Reference to the owning runner's GenerationConfig. Token-batching + // intervals are read fresh on every iteration. 
+ const GenerationConfig &config_; }; } // namespace llm diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index ac00820c5d..7a939e9f55 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -83,6 +83,14 @@ const QWEN3_4B_QUANTIZED_MODEL = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/qwen-3-4B/ const QWEN3_TOKENIZER = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/tokenizer.json`; const QWEN3_TOKENIZER_CONFIG = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/tokenizer_config.json`; +// Qwen3's published generation_config.json recommends temperature=0.6 and +// top_p=0.95. We propagate those to every Qwen3 preset so model quality is +// reasonable out of the box; users can override via `configure()`. +const QWEN3_GENERATION_CONFIG = { + temperature: 0.6, + topP: 0.95, +} as const; + /** * @category Models - LLM */ @@ -91,6 +99,7 @@ export const QWEN3_0_6B = { modelSource: QWEN3_0_6B_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, + generationConfig: QWEN3_GENERATION_CONFIG, } as const; /** @@ -101,6 +110,7 @@ export const QWEN3_0_6B_QUANTIZED = { modelSource: QWEN3_0_6B_QUANTIZED_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, + generationConfig: QWEN3_GENERATION_CONFIG, } as const; /** @@ -111,6 +121,7 @@ export const QWEN3_1_7B = { modelSource: QWEN3_1_7B_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, + generationConfig: QWEN3_GENERATION_CONFIG, } as const; /** @@ -121,6 +132,7 @@ export const QWEN3_1_7B_QUANTIZED = { modelSource: QWEN3_1_7B_QUANTIZED_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, + generationConfig: QWEN3_GENERATION_CONFIG, } as const; /** @@ -131,6 +143,7 @@ export const QWEN3_4B = { modelSource: QWEN3_4B_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, + generationConfig: QWEN3_GENERATION_CONFIG, } as const; /** @@ -141,6 +154,7 @@ export const QWEN3_4B_QUANTIZED = { modelSource: QWEN3_4B_QUANTIZED_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, + generationConfig: QWEN3_GENERATION_CONFIG, } as const; // HAMMER 2.1 @@ -500,25 +514,52 @@ const LFM2_VL_450M_TOKENIZER_CONFIG = `${URL_PREFIX}-lfm-2.5/${VERSION_TAG}/lfm2 /** * @category Models - VLM */ -export const LFM2_VL_1_6B_QUANTIZED = { +// LiquidAI's LFM2-VL model card recommends the following sampling settings. +// Without them the model often produces generic / repetitive responses. 
+const LFM2_5_VL_GENERATION_CONFIG = { + temperature: 0.1, + minP: 0.15, + repetitionPenalty: 1.05, +} as const; + +/** + * @category Models - VLM + */ +export const LFM2_5_VL_1_6B_QUANTIZED = { modelName: 'lfm2.5-vl-1.6b-quantized', capabilities: ['vision'], modelSource: LFM2_VL_1_6B_QUANTIZED_MODEL, tokenizerSource: LFM2_VL_1_6B_TOKENIZER, tokenizerConfigSource: LFM2_VL_1_6B_TOKENIZER_CONFIG, + generationConfig: LFM2_5_VL_GENERATION_CONFIG, } as const; /** * @category Models - VLM */ -export const LFM2_VL_450M_QUANTIZED = { +export const LFM2_5_VL_450M_QUANTIZED = { modelName: 'lfm2.5-vl-450m-quantized', capabilities: ['vision'], modelSource: LFM2_VL_450M_QUANTIZED_MODEL, tokenizerSource: LFM2_VL_450M_TOKENIZER, tokenizerConfigSource: LFM2_VL_450M_TOKENIZER_CONFIG, + generationConfig: LFM2_5_VL_GENERATION_CONFIG, } as const; +/** + * @deprecated Use `LFM2_5_VL_1_6B_QUANTIZED` instead — the model is from the + * LFM2.5 family. This alias will be removed in a future major release. + * @category Models - VLM + */ +export const LFM2_VL_1_6B_QUANTIZED = LFM2_5_VL_1_6B_QUANTIZED; + +/** + * @deprecated Use `LFM2_5_VL_450M_QUANTIZED` instead — the model is from the + * LFM2.5 family. This alias will be removed in a future major release. + * @category Models - VLM + */ +export const LFM2_VL_450M_QUANTIZED = LFM2_5_VL_450M_QUANTIZED; + // Classification const EFFICIENTNET_V2_S_MODEL = Platform.OS === `ios` @@ -1185,6 +1226,8 @@ export const MODEL_REGISTRY = { LFM2_5_350M_QUANTIZED, LFM2_5_1_2B_INSTRUCT, LFM2_5_1_2B_INSTRUCT_QUANTIZED, + LFM2_5_VL_1_6B_QUANTIZED, + LFM2_5_VL_450M_QUANTIZED, LFM2_VL_1_6B_QUANTIZED, LFM2_VL_450M_QUANTIZED, BIELIK_V3_0_1_5B, diff --git a/packages/react-native-executorch/src/controllers/LLMController.ts b/packages/react-native-executorch/src/controllers/LLMController.ts index d9d5593ae3..696825eee3 100644 --- a/packages/react-native-executorch/src/controllers/LLMController.ts +++ b/packages/react-native-executorch/src/controllers/LLMController.ts @@ -76,12 +76,14 @@ export class LLMController { tokenizerSource, tokenizerConfigSource, capabilities, + defaultGenerationConfig, onDownloadProgressCallback, }: { modelSource: ResourceSource; tokenizerSource: ResourceSource; tokenizerConfigSource: ResourceSource; capabilities?: readonly LLMCapability[]; + defaultGenerationConfig?: GenerationConfig; onDownloadProgressCallback?: (downloadProgress: number) => void; }) { // reset inner state when loading new model @@ -130,6 +132,12 @@ export class LLMController { tokenizerPath, capabilities ?? [] ); + if (defaultGenerationConfig) { + // Apply model-specific recommended sampling defaults before flipping + // isReady so callers that react to it see the right config on first + // send. User-provided `configure()` calls still override these. 
+ this.applyGenerationConfig(defaultGenerationConfig); + } this.isReadyCallback(true); this.onToken = (data: string) => { if (!data) { @@ -166,28 +174,58 @@ export class LLMController { this.chatConfig = { ...DEFAULT_CHAT_CONFIG, ...chatConfig }; this.toolsConfig = toolsConfig; - if (generationConfig?.outputTokenBatchSize) { + if (generationConfig) { + this.applyGenerationConfig(generationConfig); + } + + // reset inner state when loading new configuration + this.messageHistoryCallback(this.chatConfig.initialMessageHistory); + this.isGeneratingCallback(false); + } + + private applyGenerationConfig(generationConfig: GenerationConfig) { + if (generationConfig.outputTokenBatchSize) { this.nativeModule.setCountInterval(generationConfig.outputTokenBatchSize); } - if (generationConfig?.batchTimeInterval) { + if (generationConfig.batchTimeInterval) { this.nativeModule.setTimeInterval(generationConfig.batchTimeInterval); } - if (generationConfig?.temperature) { + if (generationConfig.temperature !== undefined) { this.nativeModule.setTemperature(generationConfig.temperature); } - if (generationConfig?.topp) { - if (generationConfig.topp < 0 || generationConfig.topp > 1) { + // `topp` is the legacy spelling kept for backwards compatibility — `topP` + // wins when both are set so callers migrating to the new name don't get + // surprised by stale values. Reading the deprecated alias is intentional. + const topP = generationConfig.topP ?? generationConfig.topp; + if (topP !== undefined) { + if (topP < 0 || topP > 1) { throw new RnExecutorchError( RnExecutorchErrorCode.InvalidConfig, 'Top P has to be in range [0, 1]' ); } - this.nativeModule.setTopp(generationConfig.topp); + this.nativeModule.setTopp(topP); + } + if (generationConfig.minP !== undefined) { + if (generationConfig.minP < 0 || generationConfig.minP > 1) { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidConfig, + 'Min P has to be in range [0, 1]' + ); + } + this.nativeModule.setMinP(generationConfig.minP); + } + if (generationConfig.repetitionPenalty !== undefined) { + if (generationConfig.repetitionPenalty < 0) { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidConfig, + 'Repetition penalty must be non-negative' + ); + } + this.nativeModule.setRepetitionPenalty( + generationConfig.repetitionPenalty + ); } - - // reset inner state when loading new configuration - this.messageHistoryCallback(this.chatConfig.initialMessageHistory); - this.isGeneratingCallback(false); } private getImageToken(): string { diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts index feb7a931e3..027e237997 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts @@ -62,6 +62,7 @@ export function useLLM({ tokenizerSource: model.tokenizerSource, tokenizerConfigSource: model.tokenizerConfigSource!, capabilities: model.capabilities, + defaultGenerationConfig: model.generationConfig, onDownloadProgressCallback: setDownloadProgress, }); } catch (e) { diff --git a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts index 9aff4a4067..6254775c15 100644 --- a/packages/react-native-executorch/src/types/llm.ts +++ b/packages/react-native-executorch/src/types/llm.ts @@ -91,6 +91,13 @@ export interface LLMProps { * Example: `['vision']` enables 
`sendMessage(text, { imagePath })`. */ capabilities?: readonly LLMCapability[]; + /** + * Recommended default generation settings, typically copied from the + * upstream `generation_config.json` or the model card. Applied automatically + * after the native module loads and before any user `configure()` call, + * so callers only need to override the values they want to change. + */ + generationConfig?: GenerationConfig; }; /** * Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. @@ -252,6 +259,10 @@ export interface LLMConfig { * `temperature` - Scales output logits by the inverse of temperature. Controls the randomness / creativity of text generation. * * `topp` - Only samples from the smallest set of tokens whose cumulative probability exceeds topp. + * + * `minP` - Minimum probability threshold: tokens with prob < minP * max_prob are excluded. 0 disables filtering. + * + * `repetitionPenalty` - Multiplicative penalty applied to logits of recently generated tokens. Values > 1 discourage repetition. 1 disables the penalty. */ generationConfig?: GenerationConfig; } @@ -329,13 +340,20 @@ export interface ToolsConfig { * Object configuring generation settings. * @category Types * @property {number} [temperature] - Scales output logits by the inverse of temperature. Controls the randomness / creativity of text generation. - * @property {number} [topp] - Only samples from the smallest set of tokens whose cumulative probability exceeds topp. + * @property {number} [topP] - Only samples from the smallest set of tokens whose cumulative probability exceeds topP. + * @property {number} [topp] - **Deprecated.** Use `topP` instead. + * @property {number} [minP] - Minimum probability threshold: tokens with prob < minP * max_prob are excluded. 0 disables filtering. + * @property {number} [repetitionPenalty] - Multiplicative penalty applied to logits of recently generated tokens. Values > 1 discourage repetition. 1 disables the penalty. * @property {number} [outputTokenBatchSize] - Soft upper limit on the number of tokens in each token batch (in certain cases there can be more tokens in given batch, i.e. when the batch would end with special emoji join character). * @property {number} [batchTimeInterval] - Upper limit on the time interval between consecutive token batches. */ export interface GenerationConfig { temperature?: number; + topP?: number; + /** @deprecated Use `topP` instead. */ topp?: number; + minP?: number; + repetitionPenalty?: number; outputTokenBatchSize?: number; batchTimeInterval?: number; }
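
For reference, a usage sketch of the renamed sampling options from the JavaScript side. This is an assumption-laden example rather than part of the patch: it presumes `configure` is exposed on the hook result, as in the configuration example in `useLLM.md` above. The preset already applies the LFM2.5-VL model-card defaults (temperature 0.1, minP 0.15, repetitionPenalty 1.05), so only the fields being overridden need to be passed.

```tsx
import { useEffect } from 'react';
import { useLLM, LFM2_5_VL_1_6B_QUANTIZED } from 'react-native-executorch';

function Chat() {
  // The preset carries the recommended sampling defaults; configure()
  // only needs to override the values you want to change.
  const llm = useLLM({ model: LFM2_5_VL_1_6B_QUANTIZED });

  useEffect(() => {
    if (!llm.isReady) return;
    llm.configure({
      generationConfig: {
        topP: 0.9, // new camelCase name; the legacy `topp` alias still works
        repetitionPenalty: 1.1,
      },
    });
  }, [llm.isReady]);

  // ... render chat UI, call llm.sendMessage(...) as usual
  return null;
}
```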