From b743c1380ac0023832e46a990b80a42a0107c03d Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Mon, 5 May 2025 16:28:15 -0700 Subject: [PATCH 1/2] Add input size check and unit test --- extension/module/module.cpp | 6 ++++++ extension/module/test/module_test.cpp | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index ec01323edc7..2b213e78cfc 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -240,6 +240,12 @@ runtime::Result> Module::execute( auto& method = methods_.at(method_name).method; auto& inputs = methods_.at(method_name).inputs; + ET_CHECK_OR_RETURN_ERROR( + input_values.size() <= inputs.size(), + InvalidArgument, + "input size: %zu does not match method input size: %zu", + input_values.size(), + inputs.size()); for (size_t i = 0; i < input_values.size(); ++i) { if (!input_values[i].isNone()) { inputs[i] = input_values[i]; diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index 38d739767a9..18f22a69ee6 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -216,6 +216,16 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) { EXPECT_NE(result.error(), Error::Ok); } +TEST_F(ModuleTest, TestExecuteWithTooManyInputs) { + Module module(model_path_); + + auto tensor = make_tensor_ptr({2, 2}, {1.f, 2.f, 3.f, 4.f}); + + const auto result = module.execute("forward", {tensor, tensor, 1.0, 1.0}); + + EXPECT_NE(result.error(), Error::Ok); +} + TEST_F(ModuleTest, TestGet) { Module module(model_path_); From a7941df95560c92a69f5b565f2f7c63334b10999 Mon Sep 17 00:00:00 2001 From: Zuby Afzal Date: Wed, 14 May 2025 19:12:21 -0700 Subject: [PATCH 2/2] Apply lintrunner formatting --- extension/module/module.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 2b213e78cfc..015d59d2462 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -241,11 +241,11 @@ runtime::Result> Module::execute( auto& inputs = methods_.at(method_name).inputs; ET_CHECK_OR_RETURN_ERROR( - input_values.size() <= inputs.size(), - InvalidArgument, - "input size: %zu does not match method input size: %zu", - input_values.size(), - inputs.size()); + input_values.size() <= inputs.size(), + InvalidArgument, + "input size: %zu does not match method input size: %zu", + input_values.size(), + inputs.size()); for (size_t i = 0; i < input_values.size(); ++i) { if (!input_values[i].isNone()) { inputs[i] = input_values[i];