diff --git a/.github/workflows/MainDistributionPipeline.yml b/.github/workflows/MainDistributionPipeline.yml
index 8b7380e0..899d456b 100644
--- a/.github/workflows/MainDistributionPipeline.yml
+++ b/.github/workflows/MainDistributionPipeline.yml
@@ -26,16 +26,16 @@ jobs:
     name: Build extension binaries
     uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@main
     with:
-      duckdb_version: v1.4.0
+      duckdb_version: v1.4.2
       extension_name: flock
       ci_tools_version: main
       exclude_archs: 'wasm_mvp;wasm_threads;wasm_eh'

   duckdb-stable-build:
     name: Build extension binaries
-    uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.4.0
+    uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.4.2
     with:
-      duckdb_version: v1.4.0
-      ci_tools_version: v1.4.0
+      duckdb_version: v1.4.2
+      ci_tools_version: v1.4.2
       extension_name: flock
       exclude_archs: 'wasm_mvp;wasm_threads;wasm_eh'
diff --git a/duckdb b/duckdb
index b8a06e4a..68d7555f 160000
--- a/duckdb
+++ b/duckdb
@@ -1 +1 @@
-Subproject commit b8a06e4a22672e254cd0baa68a3dbed2eb51c56e
+Subproject commit 68d7555f68bd25c1a251ccca2e6338949c33986a
diff --git a/extension-ci-tools b/extension-ci-tools
index ee7f51d0..aac96406 160000
--- a/extension-ci-tools
+++ b/extension-ci-tools
@@ -1 +1 @@
-Subproject commit ee7f51d06562bbea87d6f6f921def85557e44d18
+Subproject commit aac9640615e51d6e7e8b72d4bf023703cfd8e479
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 47843aba..2d00f9f4 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -5,6 +5,7 @@ add_subdirectory(model_manager)
 add_subdirectory(prompt_manager)
 add_subdirectory(custom_parser)
 add_subdirectory(secret_manager)
+add_subdirectory(metrics)
 set(EXTENSION_SOURCES
     ${CMAKE_CURRENT_SOURCE_DIR}/flock_extension.cpp
     ${EXTENSION_SOURCES}
diff --git a/src/core/config/config.cpp b/src/core/config/config.cpp
index 743c3690..2cab60aa 100644
--- a/src/core/config/config.cpp
+++ b/src/core/config/config.cpp
@@ -65,8 +65,6 @@ void Config::ConfigureGlobal() {
 void Config::ConfigureLocal(duckdb::DatabaseInstance& db) {
     auto con = Config::GetConnection(&db);
     ConfigureTables(con, ConfigType::LOCAL);
-    con.Query(
-            duckdb_fmt::format("ATTACH DATABASE '{}' AS flock_storage;", Config::get_global_storage_path().string()));
 }

 void Config::ConfigureTables(duckdb::Connection& con, const ConfigType type) {
@@ -89,4 +87,61 @@ void Config::Configure(duckdb::ExtensionLoader& loader) {
     }
 }

+void Config::AttachToGlobalStorage(duckdb::Connection& con, bool read_only) {
+    con.Query(duckdb_fmt::format("ATTACH DATABASE '{}' AS flock_storage {};",
+                                 Config::get_global_storage_path().string(), read_only ? "(READ_ONLY)" : ""));
+}
+
+void Config::DetachFromGlobalStorage(duckdb::Connection& con) {
+    con.Query("DETACH DATABASE flock_storage;");
+}
+
+bool Config::StorageAttachmentGuard::TryAttach(bool read_only) {
+    try {
+        Config::AttachToGlobalStorage(connection, read_only);
+        return true;
+    } catch (const std::exception&) {
+        return false;
+    }
+}
+
+bool Config::StorageAttachmentGuard::TryDetach() {
+    try {
+        Config::DetachFromGlobalStorage(connection);
+        return true;
+    } catch (const std::exception&) {
+        return false;
+    }
+}
+
+void Config::StorageAttachmentGuard::Wait(int milliseconds) {
+    auto start = std::chrono::steady_clock::now();
+    auto duration = std::chrono::milliseconds(milliseconds);
+    while (std::chrono::steady_clock::now() - start < duration) {
+        // Busy-wait until the specified duration has elapsed
+    }
+}
+
+Config::StorageAttachmentGuard::StorageAttachmentGuard(duckdb::Connection& con, bool read_only)
+    : connection(con), attached(false) {
+    for (int attempt = 0; attempt < MAX_RETRIES; ++attempt) {
+        if (TryAttach(read_only)) {
+            attached = true;
+            return;
+        }
+        Wait(RETRY_DELAY_MS);
+    }
+    Config::AttachToGlobalStorage(connection, read_only);
+    attached = true;
+}
+
+Config::StorageAttachmentGuard::~StorageAttachmentGuard() {
+    if (attached) {
+        try {
+            Config::DetachFromGlobalStorage(connection);
+        } catch (...) {
+        }
+    }
+}
+
 }// namespace flock
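Editor's note: StorageAttachmentGuard::Wait above spins the CPU for the whole retry delay. Below is a minimal sketch of the same retry cadence with a blocking sleep instead; the helper name is hypothetical and MAX_RETRIES/RETRY_DELAY_MS values are not visible in this hunk, so the parameters are placeholders.

    #include <chrono>
    #include <thread>

    // Hypothetical helper, not part of the patch: retries the attach with a
    // blocking sleep instead of a busy-wait, and reports success to the caller.
    static bool TryAttachWithBackoff(duckdb::Connection& con, bool read_only,
                                     int max_retries, int retry_delay_ms) {
        for (int attempt = 0; attempt < max_retries; ++attempt) {
            try {
                flock::Config::AttachToGlobalStorage(con, read_only);
                return true;
            } catch (const std::exception&) {
                // Another connection likely holds the storage file; yield the core.
                std::this_thread::sleep_for(std::chrono::milliseconds(retry_delay_ms));
            }
        }
        return false;
    }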
"(READ_ONLY)" : "")); +} + +void Config::DetachFromGlobalStorage(duckdb::Connection& con) { + con.Query("DETACH DATABASE flock_storage;"); +} + +bool Config::StorageAttachmentGuard::TryAttach(bool read_only) { + try { + Config::AttachToGlobalStorage(connection, read_only); + return true; + } catch (const std::exception&) { + return false; + } +} + +bool Config::StorageAttachmentGuard::TryDetach() { + try { + Config::DetachFromGlobalStorage(connection); + return true; + } catch (const std::exception&) { + return false; + } +} + +void Config::StorageAttachmentGuard::Wait(int milliseconds) { + auto start = std::chrono::steady_clock::now(); + auto duration = std::chrono::milliseconds(milliseconds); + while (std::chrono::steady_clock::now() - start < duration) { + // Busy-wait until the specified duration has elapsed + } +} + +Config::StorageAttachmentGuard::StorageAttachmentGuard(duckdb::Connection& con, bool read_only) + : connection(con), attached(false) { + for (int attempt = 0; attempt < MAX_RETRIES; ++attempt) { + if (TryAttach(read_only)) { + attached = true; + return; + } + Wait(RETRY_DELAY_MS); + } + Config::AttachToGlobalStorage(connection, read_only); + attached = true; +} + +Config::StorageAttachmentGuard::~StorageAttachmentGuard() { + if (attached) { + try { + Config::DetachFromGlobalStorage(connection); + } catch (...) { + } + } +} + }// namespace flock diff --git a/src/core/config/model.cpp b/src/core/config/model.cpp index b08e5e92..310b08f4 100644 --- a/src/core/config/model.cpp +++ b/src/core/config/model.cpp @@ -32,6 +32,8 @@ void Config::SetupDefaultModelsConfig(duckdb::Connection& con, std::string& sche "('default', 'gpt-4o-mini', 'openai'), " "('gpt-4o-mini', 'gpt-4o-mini', 'openai'), " "('gpt-4o', 'gpt-4o', 'openai'), " + "('gpt-4o-transcribe', 'gpt-4o-transcribe', 'openai')," + "('gpt-4o-mini-transcribe', 'gpt-4o-mini-transcribe', 'openai')," "('text-embedding-3-large', 'text-embedding-3-large', 'openai'), " "('text-embedding-3-small', 'text-embedding-3-small', 'openai');", schema_name, table_name)); diff --git a/src/custom_parser/query/model_parser.cpp b/src/custom_parser/query/model_parser.cpp index 89918b94..ab7ecf66 100644 --- a/src/custom_parser/query/model_parser.cpp +++ b/src/custom_parser/query/model_parser.cpp @@ -2,6 +2,7 @@ #include "flock/core/common.hpp" #include "flock/core/config.hpp" +#include "flock/custom_parser/query_parser.hpp" #include #include @@ -303,101 +304,123 @@ std::string ModelParser::ToSQL(const QueryStatement& statement) const { switch (statement.type) { case StatementType::CREATE_MODEL: { const auto& create_stmt = static_cast(statement); - auto con = Config::GetConnection(); - auto result = con.Query(duckdb_fmt::format( - " SELECT model_name" - " FROM flock_storage.flock_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE" - " WHERE model_name = '{}'" - " UNION ALL " - " SELECT model_name " - " FROM {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" - " WHERE model_name = '{}';", - create_stmt.model_name, create_stmt.catalog.empty() ? "flock_storage." 
: "", create_stmt.model_name)); - if (result->RowCount() != 0) { - throw std::runtime_error(duckdb_fmt::format("Model '{}' already exist.", create_stmt.model_name)); - } + query = ExecuteQueryWithStorage([&create_stmt](duckdb::Connection& con) { + auto result = con.Query(duckdb_fmt::format( + " SELECT model_name" + " FROM flock_storage.flock_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE" + " WHERE model_name = '{}'" + " UNION ALL " + " SELECT model_name " + " FROM flock_storage.flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}'" + " UNION ALL " + " SELECT model_name " + " FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}';", + create_stmt.model_name, create_stmt.model_name, create_stmt.model_name)); + + auto& materialized_result = result->Cast(); + if (materialized_result.RowCount() != 0) { + throw std::runtime_error(duckdb_fmt::format("Model '{}' already exist.", create_stmt.model_name)); + } - query = duckdb_fmt::format(" INSERT INTO " - " {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " - " (model_name, model, provider_name, model_args) " - " VALUES ('{}', '{}', '{}', '{}');", - create_stmt.catalog, create_stmt.model_name, create_stmt.model, - create_stmt.provider_name, create_stmt.model_args.dump()); + // Insert the new model + auto insert_query = duckdb_fmt::format(" INSERT INTO " + " {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " (model_name, model, provider_name, model_args) " + " VALUES ('{}', '{}', '{}', '{}');", + create_stmt.catalog, create_stmt.model_name, create_stmt.model, + create_stmt.provider_name, create_stmt.model_args.dump()); + con.Query(insert_query); + + return std::string("SELECT 'Model created successfully' AS status"); + }, + false); break; } case StatementType::DELETE_MODEL: { const auto& delete_stmt = static_cast(statement); - auto con = Config::GetConnection(); - - con.Query(duckdb_fmt::format(" DELETE FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " - " WHERE model_name = '{}';", - delete_stmt.model_name)); - - query = duckdb_fmt::format(" DELETE FROM " + query = ExecuteSetQuery( + duckdb_fmt::format(" DELETE FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " WHERE model_name = '{}'; " + " DELETE FROM " " flock_storage.flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " " WHERE model_name = '{}';", - delete_stmt.model_name, delete_stmt.model_name); + delete_stmt.model_name, delete_stmt.model_name), + "Model deleted successfully", + false); break; } case StatementType::UPDATE_MODEL: { const auto& update_stmt = static_cast(statement); - auto con = Config::GetConnection(); - // get the location of the model_name if local or global - auto result = con.Query( - duckdb_fmt::format(" SELECT model_name, 'global' AS scope " - " FROM flock_storage.flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" - " WHERE model_name = '{}'" - " UNION ALL " - " SELECT model_name, 'local' AS scope " - " FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" - " WHERE model_name = '{}';", - update_stmt.model_name, update_stmt.model_name, update_stmt.model_name)); + query = ExecuteQueryWithStorage([&update_stmt](duckdb::Connection& con) { + // Get the location of the model_name if local or global + auto result = con.Query( + duckdb_fmt::format(" SELECT model_name, 'global' AS scope " + " FROM flock_storage.flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}'" + " UNION ALL " + " SELECT model_name, 'local' AS 
scope " + " FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}';", + update_stmt.model_name, update_stmt.model_name, update_stmt.model_name)); + + auto& materialized_result = result->Cast(); + if (materialized_result.RowCount() == 0) { + throw std::runtime_error(duckdb_fmt::format("Model '{}' doesn't exist.", update_stmt.model_name)); + } - if (result->RowCount() == 0) { - throw std::runtime_error(duckdb_fmt::format("Model '{}' doesn't exist.", update_stmt.model_name)); - } + auto catalog = materialized_result.GetValue(1, 0).ToString() == "global" ? "flock_storage." : ""; - auto catalog = result->GetValue(1, 0).ToString() == "global" ? "flock_storage." : ""; + con.Query(duckdb_fmt::format(" UPDATE {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " SET model = '{}', provider_name = '{}', " + " model_args = '{}' WHERE model_name = '{}'; ", + catalog, update_stmt.new_model, update_stmt.provider_name, + update_stmt.new_model_args.dump(), update_stmt.model_name)); - query = duckdb_fmt::format(" UPDATE {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " - " SET model = '{}', provider_name = '{}', " - " model_args = '{}' WHERE model_name = '{}'; ", - catalog, update_stmt.new_model, update_stmt.provider_name, - update_stmt.new_model_args.dump(), update_stmt.model_name); + return std::string("SELECT 'Model updated successfully' AS status"); + }, + false); break; } case StatementType::UPDATE_MODEL_SCOPE: { const auto& update_stmt = static_cast(statement); - auto con = Config::GetConnection(); - auto result = - con.Query(duckdb_fmt::format(" SELECT model_name " - " FROM {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" - " WHERE model_name = '{}';", - update_stmt.catalog, update_stmt.model_name)); - if (result->RowCount() != 0) { - throw std::runtime_error( - duckdb_fmt::format("Model '{}' already exist in {} storage.", update_stmt.model_name, - update_stmt.catalog == "flock_storage." ? "global" : "local")); - } + query = ExecuteQueryWithStorage([&update_stmt](duckdb::Connection& con) { + auto result = con.Query(duckdb_fmt::format(" SELECT model_name " + " FROM {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " WHERE model_name = '{}';", + update_stmt.catalog, update_stmt.model_name)); + + auto& materialized_result = result->Cast(); + if (materialized_result.RowCount() != 0) { + throw std::runtime_error( + duckdb_fmt::format("Model '{}' already exist in {} storage.", update_stmt.model_name, + update_stmt.catalog == "flock_storage." ? "global" : "local")); + } - con.Query(duckdb_fmt::format("INSERT INTO {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " - "(model_name, model, provider_name, model_args) " - "SELECT model_name, model, provider_name, model_args " - "FROM {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " - "WHERE model_name = '{}'; ", - update_stmt.catalog, - update_stmt.catalog == "flock_storage." ? "" : "flock_storage.", - update_stmt.model_name)); - - query = duckdb_fmt::format("DELETE FROM {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " - "WHERE model_name = '{}'; ", - update_stmt.catalog == "flock_storage." ? 
"" : "flock_storage.", - update_stmt.model_name); + con.Query(duckdb_fmt::format("INSERT INTO {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + "(model_name, model, provider_name, model_args) " + "SELECT model_name, model, provider_name, model_args " + "FROM {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + "WHERE model_name = '{}'; ", + update_stmt.catalog, + update_stmt.catalog == "flock_storage." ? "" : "flock_storage.", + update_stmt.model_name)); + + con.Query(duckdb_fmt::format("DELETE FROM {}flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + "WHERE model_name = '{}'; ", + update_stmt.catalog == "flock_storage." ? "" : "flock_storage.", + update_stmt.model_name)); + + return std::string("SELECT 'Model scope updated successfully' AS status"); + }, + false); break; } case StatementType::GET_MODEL: { const auto& get_stmt = static_cast(statement); - query = duckdb_fmt::format("SELECT 'global' AS scope, * " + query = ExecuteGetQuery( + duckdb_fmt::format("SELECT 'global' AS scope, * " "FROM flock_storage.flock_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE " "WHERE model_name = '{}' " "UNION ALL " @@ -408,20 +431,22 @@ std::string ModelParser::ToSQL(const QueryStatement& statement) const { "SELECT 'local' AS scope, * " "FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " "WHERE model_name = '{}';", - get_stmt.model_name, get_stmt.model_name, get_stmt.model_name, get_stmt.model_name); + get_stmt.model_name, get_stmt.model_name, get_stmt.model_name, get_stmt.model_name), + true); break; } case StatementType::GET_ALL_MODEL: { - query = duckdb_fmt::format(" SELECT 'global' AS scope, * " - " FROM flock_storage.flock_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE" - " UNION ALL " - " SELECT 'global' AS scope, * " - " FROM flock_storage.flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" - " UNION ALL " - " SELECT 'local' AS scope, * " - " FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE;", - Config::get_global_storage_path().string()); + query = ExecuteGetQuery( + " SELECT 'global' AS scope, * " + " FROM flock_storage.flock_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE" + " UNION ALL " + " SELECT 'global' AS scope, * " + " FROM flock_storage.flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE" + " UNION ALL " + " SELECT 'local' AS scope, * " + " FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE;", + true); break; } default: diff --git a/src/custom_parser/query/prompt_parser.cpp b/src/custom_parser/query/prompt_parser.cpp index 1150938f..5ba21297 100644 --- a/src/custom_parser/query/prompt_parser.cpp +++ b/src/custom_parser/query/prompt_parser.cpp @@ -2,6 +2,7 @@ #include "flock/core/common.hpp" #include "flock/core/config.hpp" +#include "flock/custom_parser/query_parser.hpp" #include #include @@ -216,89 +217,113 @@ std::string PromptParser::ToSQL(const QueryStatement& statement) const { switch (statement.type) { case StatementType::CREATE_PROMPT: { const auto& create_stmt = static_cast(statement); - auto con = Config::GetConnection(); - auto result = con.Query(duckdb_fmt::format(" SELECT prompt_name " - " FROM {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" - " WHERE prompt_name = '{}';", - create_stmt.catalog.empty() ? "flock_storage." 
: "", - create_stmt.prompt_name)); - if (result->RowCount() != 0) { - throw std::runtime_error(duckdb_fmt::format("Prompt '{}' already exist.", create_stmt.prompt_name)); - } - query = duckdb_fmt::format(" INSERT INTO {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - " (prompt_name, prompt) " - " VALUES ('{}', '{}'); ", - create_stmt.catalog, create_stmt.prompt_name, create_stmt.prompt); + query = ExecuteQueryWithStorage([&create_stmt](duckdb::Connection& con) { + auto result = con.Query(duckdb_fmt::format(" SELECT prompt_name " + " FROM flock_storage.flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" + " WHERE prompt_name = '{}'" + " UNION ALL " + " SELECT prompt_name " + " FROM flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" + " WHERE prompt_name = '{}';", + create_stmt.prompt_name, create_stmt.prompt_name)); + + auto& materialized_result = result->Cast(); + if (materialized_result.RowCount() != 0) { + throw std::runtime_error(duckdb_fmt::format("Prompt '{}' already exist.", create_stmt.prompt_name)); + } + + auto insert_query = duckdb_fmt::format(" INSERT INTO {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + " (prompt_name, prompt) " + " VALUES ('{}', '{}'); ", + create_stmt.catalog, create_stmt.prompt_name, create_stmt.prompt); + con.Query(insert_query); + + return std::string("SELECT 'Prompt created successfully' AS status"); + }, + false); break; } case StatementType::DELETE_PROMPT: { const auto& delete_stmt = static_cast(statement); - auto con = Config::GetConnection(); - auto result = con.Query(duckdb_fmt::format(" DELETE FROM flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - " WHERE prompt_name = '{}'; ", - delete_stmt.prompt_name)); - - query = duckdb_fmt::format(" DELETE FROM flock_storage.flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + query = ExecuteSetQuery( + duckdb_fmt::format(" DELETE FROM flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + " WHERE prompt_name = '{}'; " + " DELETE FROM flock_storage.flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " " WHERE prompt_name = '{}'; ", - delete_stmt.prompt_name); + delete_stmt.prompt_name, delete_stmt.prompt_name), + "Prompt deleted successfully", + false); break; } case StatementType::UPDATE_PROMPT: { const auto& update_stmt = static_cast(statement); - auto con = Config::GetConnection(); - auto result = - con.Query(duckdb_fmt::format(" SELECT version, 'local' AS scope " - " FROM flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" - " WHERE prompt_name = '{}'" - " UNION ALL " - " SELECT version, 'global' AS scope " - " FROM flock_storage.flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" - " WHERE prompt_name = '{}' " - " ORDER BY version DESC;", - update_stmt.prompt_name, update_stmt.prompt_name)); - if (result->RowCount() == 0) { - throw std::runtime_error(duckdb_fmt::format("Prompt '{}' doesn't exist.", update_stmt.prompt_name)); - } - - int version = result->GetValue(0, 0) + 1; - auto catalog = result->GetValue(1, 0).ToString() == "global" ? "flock_storage." 
: ""; - query = duckdb_fmt::format(" INSERT INTO {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - " (prompt_name, prompt, version) " - " VALUES ('{}', '{}', {}); ", - catalog, update_stmt.prompt_name, update_stmt.new_prompt, version); + query = ExecuteQueryWithStorage([&update_stmt](duckdb::Connection& con) { + auto result = con.Query(duckdb_fmt::format(" SELECT version, 'local' AS scope " + " FROM flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" + " WHERE prompt_name = '{}'" + " UNION ALL " + " SELECT version, 'global' AS scope " + " FROM flock_storage.flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" + " WHERE prompt_name = '{}' " + " ORDER BY version DESC;", + update_stmt.prompt_name, update_stmt.prompt_name)); + + auto& materialized_result = result->Cast(); + if (materialized_result.RowCount() == 0) { + throw std::runtime_error(duckdb_fmt::format("Prompt '{}' doesn't exist.", update_stmt.prompt_name)); + } + + int version = materialized_result.GetValue(0, 0) + 1; + auto catalog = materialized_result.GetValue(1, 0).ToString() == "global" ? "flock_storage." : ""; + + con.Query(duckdb_fmt::format(" INSERT INTO {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + " (prompt_name, prompt, version) " + " VALUES ('{}', '{}', {}); ", + catalog, update_stmt.prompt_name, update_stmt.new_prompt, version)); + + return std::string("SELECT 'Prompt updated successfully' AS status"); + }, + false); break; } case StatementType::UPDATE_PROMPT_SCOPE: { const auto& update_stmt = static_cast(statement); - auto con = Config::GetConnection(); - auto result = con.Query(duckdb_fmt::format(" SELECT prompt_name " - " FROM {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" - " WHERE prompt_name = '{}';", - update_stmt.catalog, update_stmt.prompt_name)); - if (result->RowCount() != 0) { - throw std::runtime_error( - duckdb_fmt::format("Model '{}' already exist in {} storage.", update_stmt.prompt_name, - update_stmt.catalog == "flock_storage." ? "global" : "local")); - } - - con.Query(duckdb_fmt::format("INSERT INTO {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - "(prompt_name, prompt, updated_at, version) " - "SELECT prompt_name, prompt, updated_at, version " - "FROM {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - "WHERE prompt_name = '{}';", - update_stmt.catalog, - update_stmt.catalog == "flock_storage." ? "" : "flock_storage.", - update_stmt.prompt_name)); - - query = duckdb_fmt::format("DELETE FROM {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " - "WHERE prompt_name = '{}'; ", - update_stmt.catalog == "flock_storage." ? "" : "flock_storage.", - update_stmt.prompt_name); + query = ExecuteQueryWithStorage([&update_stmt](duckdb::Connection& con) { + auto result = con.Query(duckdb_fmt::format(" SELECT prompt_name " + " FROM {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE" + " WHERE prompt_name = '{}';", + update_stmt.catalog, update_stmt.prompt_name)); + + auto& materialized_result = result->Cast(); + if (materialized_result.RowCount() != 0) { + throw std::runtime_error( + duckdb_fmt::format("Prompt '{}' already exist in {} storage.", update_stmt.prompt_name, + update_stmt.catalog == "flock_storage." ? "global" : "local")); + } + + con.Query(duckdb_fmt::format("INSERT INTO {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + "(prompt_name, prompt, updated_at, version) " + "SELECT prompt_name, prompt, updated_at, version " + "FROM {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + "WHERE prompt_name = '{}';", + update_stmt.catalog, + update_stmt.catalog == "flock_storage." ? 
"" : "flock_storage.", + update_stmt.prompt_name)); + + con.Query(duckdb_fmt::format("DELETE FROM {}flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + "WHERE prompt_name = '{}'; ", + update_stmt.catalog == "flock_storage." ? "" : "flock_storage.", + update_stmt.prompt_name)); + + return std::string("SELECT 'Prompt scope updated successfully' AS status"); + }, + false); break; } case StatementType::GET_PROMPT: { const auto& get_stmt = static_cast(statement); - query = duckdb_fmt::format("SELECT 'global' AS scope, * " + query = ExecuteGetQuery( + duckdb_fmt::format("SELECT 'global' AS scope, * " "FROM flock_storage.flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " "WHERE prompt_name = '{}' " "UNION ALL " @@ -306,12 +331,13 @@ std::string PromptParser::ToSQL(const QueryStatement& statement) const { "FROM flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " "WHERE prompt_name = '{}' " "ORDER BY version DESC;", - get_stmt.prompt_name, get_stmt.prompt_name); - + get_stmt.prompt_name, get_stmt.prompt_name), + true); break; } case StatementType::GET_ALL_PROMPT: { - query = " SELECT 'global' as scope, t1.* " + query = ExecuteGetQuery( + " SELECT 'global' as scope, t1.* " " FROM flock_storage.flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE AS t1 " " JOIN (SELECT prompt_name, MAX(version) AS max_version " " FROM flock_storage.flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " @@ -325,7 +351,8 @@ std::string PromptParser::ToSQL(const QueryStatement& statement) const { " FROM flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " " GROUP BY prompt_name) AS t2 " " ON t1.prompt_name = t2.prompt_name " - " AND t1.version = t2.max_version; "; + " AND t1.version = t2.max_version; ", + true); break; } default: diff --git a/src/custom_parser/query_parser.cpp b/src/custom_parser/query_parser.cpp index 5eae8d6f..b8f766c9 100644 --- a/src/custom_parser/query_parser.cpp +++ b/src/custom_parser/query_parser.cpp @@ -1,12 +1,102 @@ #include "flock/custom_parser/query_parser.hpp" +#include "duckdb/main/materialized_query_result.hpp" #include "flock/core/common.hpp" +#include "flock/core/config.hpp" #include #include namespace flock { +// Format a DuckDB value for SQL (escape strings, handle NULLs) +std::string FormatValueForSQL(const duckdb::Value& value) { + if (value.IsNull()) { + return "NULL"; + } + auto str = value.ToString(); + // Escape single quotes by doubling them + std::string escaped; + escaped.reserve(str.length() + 10); + for (char c: str) { + if (c == '\'') { + escaped += "''"; + } else { + escaped += c; + } + } + return "'" + escaped + "'"; +} + +// Format query results as VALUES clause: SELECT * FROM VALUES (...) 
+// Format query results as VALUES clause: SELECT * FROM VALUES (...)
+std::string FormatResultsAsValues(duckdb::unique_ptr result) {
+    if (!result) {
+        return "SELECT * FROM (VALUES (NULL)) AS empty_result WHERE FALSE";
+    }
+
+    // Cast to MaterializedQueryResult to access GetValue and RowCount
+    auto& materialized_result = result->Cast();
+
+    if (materialized_result.RowCount() == 0) {
+        return "SELECT * FROM (VALUES (NULL)) AS empty_result WHERE FALSE";
+    }
+
+    std::ostringstream values_stream;
+    auto column_count = result->ColumnCount();
+
+    // Get column names
+    std::vector column_names;
+    column_names.reserve(column_count);
+    for (idx_t col = 0; col < column_count; col++) {
+        column_names.push_back(result->ColumnName(col));
+    }
+
+    // Format each row as VALUES tuple
+    for (idx_t row = 0; row < materialized_result.RowCount(); row++) {
+        if (row > 0) {
+            values_stream << ", ";
+        }
+        values_stream << "(";
+        for (idx_t col = 0; col < column_count; col++) {
+            if (col > 0) {
+                values_stream << ", ";
+            }
+            auto value = materialized_result.GetValue(col, row);
+            values_stream << FormatValueForSQL(value);
+        }
+        values_stream << ")";
+    }
+
+    // Build column names for the VALUES clause
+    std::ostringstream column_names_stream;
+    for (size_t i = 0; i < column_names.size(); i++) {
+        if (i > 0) {
+            column_names_stream << ", ";
+        }
+        column_names_stream << "\"" << column_names[i] << "\"";
+    }
+
+    return duckdb_fmt::format("SELECT * FROM (VALUES {}) AS result({})",
+                              values_stream.str(), column_names_stream.str());
+}
+
+// Execute a query with storage attachment and return formatted result for GET operations
+std::string ExecuteGetQuery(const std::string& query, bool read_only) {
+    auto con = Config::GetConnection();
+    Config::StorageAttachmentGuard guard(con, read_only);
+    auto result = con.Query(query);
+    return FormatResultsAsValues(std::move(result));
+}
+
+// Execute a query with storage attachment and return status message for SET operations
+std::string ExecuteSetQuery(const std::string& query, const std::string& success_message, bool read_only) {
+    auto con = Config::GetConnection();
+    Config::StorageAttachmentGuard guard(con, read_only);
+    con.Query(query);
+    return duckdb_fmt::format("SELECT '{}' AS status", success_message);
+}
+
+
 std::string QueryParser::ParseQuery(const std::string& query) {
     Tokenizer tokenizer(query);
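Editor's note: the point of FormatResultsAsValues is that ToSQL must hand DuckDB a self-contained SQL string, so rows read from flock_storage are re-materialized as literals and the attachment can be released before that string runs. For a hypothetical two-row result with columns model_name and provider, the emitted query would look like:

    // SELECT * FROM (VALUES ('gpt-4o', 'openai'), ('default', 'openai'))
    // AS result("model_name", "provider")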
Got " + + std::to_string(arguments.size())); + } +} -void AggregateFunctionBase::ValidateArguments(duckdb::Vector inputs[], idx_t input_count) { - if (input_count != 3) { - throw std::runtime_error("Expected exactly 3 arguments for aggregate function, got " + std::to_string(input_count)); +void AggregateFunctionBase::ValidateArgumentTypes( + const duckdb::vector>& arguments, + const std::string& function_name) { + if (arguments[0]->return_type.id() != duckdb::LogicalTypeId::STRUCT) { + throw duckdb::BinderException(function_name + ": First argument must be model (struct type)"); + } + if (arguments[1]->return_type.id() != duckdb::LogicalTypeId::STRUCT) { + throw duckdb::BinderException( + function_name + ": Second argument must be prompt with context_columns (struct type)"); } +} + +AggregateFunctionBase::PromptStructInfo AggregateFunctionBase::ExtractPromptStructInfo( + const duckdb::LogicalType& prompt_type) { + PromptStructInfo info{false, std::nullopt, ""}; - if (inputs[0].GetType().id() != duckdb::LogicalTypeId::STRUCT) { - throw std::runtime_error("Expected a struct type for model details"); + for (idx_t i = 0; i < duckdb::StructType::GetChildCount(prompt_type); i++) { + auto field_name = duckdb::StructType::GetChildName(prompt_type, i); + if (field_name == "context_columns") { + info.has_context_columns = true; + } else if (field_name == "prompt" || field_name == "prompt_name") { + if (!info.prompt_field_index.has_value()) { + info.prompt_field_index = i; + info.prompt_field_name = field_name; + } + } } - if (inputs[1].GetType().id() != duckdb::LogicalTypeId::STRUCT) { - throw std::runtime_error("Expected a struct type for prompt details"); + return info; +} + +void AggregateFunctionBase::ValidatePromptStructFields(const PromptStructInfo& info, + const std::string& function_name) { + if (!info.has_context_columns) { + throw duckdb::BinderException( + function_name + ": Second argument must contain 'context_columns' field"); } +} - if (inputs[2].GetType().id() != duckdb::LogicalTypeId::STRUCT) { - throw std::runtime_error("Expected a struct type for prompt inputs"); +void AggregateFunctionBase::InitializeModelJson( + duckdb::ClientContext& context, + const duckdb::unique_ptr& model_expr, + LlmFunctionBindData& bind_data) { + if (!model_expr->IsFoldable()) { + return; } + + auto model_value = duckdb::ExpressionExecutor::EvaluateScalar(context, *model_expr); + auto user_model_json = CastValueToJson(model_value); + bind_data.model_json = Model::ResolveModelDetailsToJson(user_model_json); +} + +void AggregateFunctionBase::InitializePrompt( + duckdb::ClientContext& context, + const duckdb::unique_ptr& prompt_expr, + LlmFunctionBindData& bind_data) { + nlohmann::json prompt_json; + + if (prompt_expr->IsFoldable()) { + auto prompt_value = duckdb::ExpressionExecutor::EvaluateScalar(context, *prompt_expr); + prompt_json = CastValueToJson(prompt_value); + } else if (prompt_expr->expression_class == duckdb::ExpressionClass::BOUND_FUNCTION) { + auto& func_expr = prompt_expr->Cast(); + const auto& struct_type = prompt_expr->return_type; + + for (idx_t i = 0; i < duckdb::StructType::GetChildCount(struct_type) && i < func_expr.children.size(); i++) { + auto field_name = duckdb::StructType::GetChildName(struct_type, i); + auto& child = func_expr.children[i]; + + if (field_name != "context_columns" && child->IsFoldable()) { + try { + auto field_value = duckdb::ExpressionExecutor::EvaluateScalar(context, *child); + if (field_value.type().id() == duckdb::LogicalTypeId::VARCHAR) { + 
+
+void AggregateFunctionBase::InitializePrompt(
+        duckdb::ClientContext& context,
+        const duckdb::unique_ptr& prompt_expr,
+        LlmFunctionBindData& bind_data) {
+    nlohmann::json prompt_json;
+
+    if (prompt_expr->IsFoldable()) {
+        auto prompt_value = duckdb::ExpressionExecutor::EvaluateScalar(context, *prompt_expr);
+        prompt_json = CastValueToJson(prompt_value);
+    } else if (prompt_expr->expression_class == duckdb::ExpressionClass::BOUND_FUNCTION) {
+        auto& func_expr = prompt_expr->Cast();
+        const auto& struct_type = prompt_expr->return_type;
+
+        for (idx_t i = 0; i < duckdb::StructType::GetChildCount(struct_type) && i < func_expr.children.size(); i++) {
+            auto field_name = duckdb::StructType::GetChildName(struct_type, i);
+            auto& child = func_expr.children[i];
+
+            if (field_name != "context_columns" && child->IsFoldable()) {
+                try {
+                    auto field_value = duckdb::ExpressionExecutor::EvaluateScalar(context, *child);
+                    if (field_value.type().id() == duckdb::LogicalTypeId::VARCHAR) {
+                        prompt_json[field_name] = field_value.GetValue();
+                    } else {
+                        prompt_json[field_name] = CastValueToJson(field_value);
+                    }
+                } catch (...) {
+                    // Skip fields that can't be evaluated
+                }
+            }
+        }
+    }
+
+    auto prompt_details = PromptManager::CreatePromptDetails(prompt_json);
+    bind_data.prompt = prompt_details.prompt;
+}
+
+duckdb::unique_ptr AggregateFunctionBase::ValidateAndInitializeBindData(
+        duckdb::ClientContext& context,
+        duckdb::vector>& arguments,
+        const std::string& function_name) {
+
+    ValidateArgumentCount(arguments, function_name);
+    ValidateArgumentTypes(arguments, function_name);
+
+    const auto& prompt_type = arguments[1]->return_type;
+    auto prompt_info = ExtractPromptStructInfo(prompt_type);
+    ValidatePromptStructFields(prompt_info, function_name);
+
+    auto bind_data = duckdb::make_uniq();
+
+    InitializeModelJson(context, arguments[0], *bind_data);
+    InitializePrompt(context, arguments[1], *bind_data);
+
+    return bind_data;
 }

-std::tuple
+std::tuple
 AggregateFunctionBase::CastInputsToJson(duckdb::Vector inputs[], idx_t count) {
-    auto model_details_json = CastVectorOfStructsToJson(inputs[0], 1);
     auto prompt_context_json = CastVectorOfStructsToJson(inputs[1], count);
     auto context_columns = nlohmann::json::array();
     if (prompt_context_json.contains("context_columns")) {
         context_columns = prompt_context_json["context_columns"];
         prompt_context_json.erase("context_columns");
     } else {
-        throw std::runtime_error("Expected 'context_columns' in prompt details");
+        throw std::runtime_error("Missing 'context_columns' in second argument. The prompt struct must include context_columns.");
     }
-    return std::make_tuple(model_details_json, prompt_context_json, context_columns);
+    return std::make_tuple(prompt_context_json, context_columns);
 }

 }// namespace flock
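Editor's note: LlmFunctionBindData is consumed here and in every Finalize below, but its definition lives outside these hunks. From the call sites (model_json, prompt, CreateModel()), a plausible shape is sketched below under that assumption; the real header may differ.

    struct LlmFunctionBindData : public duckdb::FunctionData {
        nlohmann::json model_json;// resolved at bind time
        std::string prompt;       // rendered prompt template

        // Mint a fresh Model per caller so threads never share one instance.
        Model CreateModel() const { return Model(model_json); }

        duckdb::unique_ptr<duckdb::FunctionData> Copy() const override {
            auto copy = duckdb::make_uniq<LlmFunctionBindData>();
            copy->model_json = model_json;
            copy->prompt = prompt;
            return std::move(copy);
        }
        bool Equals(const duckdb::FunctionData& other) const override {
            auto& o = other.Cast<LlmFunctionBindData>();
            return model_json == o.model_json && prompt == o.prompt;
        }
    };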
(*value)[idx]["data"].push_back(item_value); + } } } else { - (*value)[idx][item.key()] = item.value(); + if (!(*value)[idx].contains(item.key())) { + (*value)[idx][item.key()] = item.value(); + } } } idx++; @@ -55,11 +68,11 @@ void AggregateFunctionState::Combine(const AggregateFunctionState& source) { } void AggregateFunctionState::Destroy() { + initialized = false; if (value) { delete value; value = nullptr; } - initialized = false; } }// namespace flock diff --git a/src/functions/aggregate/llm_first_or_last/implementation.cpp b/src/functions/aggregate/llm_first_or_last/implementation.cpp index 7135359c..876e441e 100644 --- a/src/functions/aggregate/llm_first_or_last/implementation.cpp +++ b/src/functions/aggregate/llm_first_or_last/implementation.cpp @@ -1,27 +1,99 @@ +#include "flock/core/config.hpp" #include "flock/functions/aggregate/llm_first_or_last.hpp" +#include "flock/functions/llm_function_bind_data.hpp" +#include "flock/metrics/manager.hpp" + +#include +#include namespace flock { +duckdb::unique_ptr LlmFirstOrLast::Bind( + duckdb::ClientContext& context, + duckdb::AggregateFunction& function, + duckdb::vector>& arguments) { + return AggregateFunctionBase::ValidateAndInitializeBindData(context, arguments, function.name); +} + int LlmFirstOrLast::GetFirstOrLastTupleId(nlohmann::json& tuples) { - nlohmann::json data; - const auto [prompt, media_data] = PromptManager::Render(user_query, tuples, function_type, model.GetModelDetails().tuple_format); + const auto [prompt, media_data] = PromptManager::Render( + user_query, tuples, function_type, model.GetModelDetails().tuple_format); model.AddCompletionRequest(prompt, 1, OutputType::INTEGER, media_data); auto response = model.CollectCompletions()[0]; - return response["items"][0]; + + std::set valid_ids; + for (const auto& column: tuples) { + if (column.contains("name") && column["name"].is_string() && + column["name"].get() == "flock_row_id" && + column.contains("data") && column["data"].is_array()) { + for (const auto& id: column["data"]) { + if (id.is_string()) { + valid_ids.insert(id.get()); + } + } + break; + } + } + + int result_id_int = -1; + std::string result_id_str; + if (response["items"][0].is_number_integer()) { + result_id_int = response["items"][0].get(); + result_id_str = std::to_string(result_id_int); + } else if (response["items"][0].is_string()) { + result_id_str = response["items"][0].get(); + try { + result_id_int = std::stoi(result_id_str); + } catch (...) 
diff --git a/src/functions/aggregate/llm_first_or_last/implementation.cpp b/src/functions/aggregate/llm_first_or_last/implementation.cpp
index 7135359c..876e441e 100644
--- a/src/functions/aggregate/llm_first_or_last/implementation.cpp
+++ b/src/functions/aggregate/llm_first_or_last/implementation.cpp
@@ -1,27 +1,99 @@
+#include "flock/core/config.hpp"
 #include "flock/functions/aggregate/llm_first_or_last.hpp"
+#include "flock/functions/llm_function_bind_data.hpp"
+#include "flock/metrics/manager.hpp"
+
+#include
+#include

 namespace flock {

+duckdb::unique_ptr LlmFirstOrLast::Bind(
+        duckdb::ClientContext& context,
+        duckdb::AggregateFunction& function,
+        duckdb::vector>& arguments) {
+    return AggregateFunctionBase::ValidateAndInitializeBindData(context, arguments, function.name);
+}
+
 int LlmFirstOrLast::GetFirstOrLastTupleId(nlohmann::json& tuples) {
-    nlohmann::json data;
-    const auto [prompt, media_data] = PromptManager::Render(user_query, tuples, function_type, model.GetModelDetails().tuple_format);
+    const auto [prompt, media_data] = PromptManager::Render(
+            user_query, tuples, function_type, model.GetModelDetails().tuple_format);
     model.AddCompletionRequest(prompt, 1, OutputType::INTEGER, media_data);
     auto response = model.CollectCompletions()[0];
-    return response["items"][0];
+
+    std::set valid_ids;
+    for (const auto& column: tuples) {
+        if (column.contains("name") && column["name"].is_string() &&
+            column["name"].get() == "flock_row_id" &&
+            column.contains("data") && column["data"].is_array()) {
+            for (const auto& id: column["data"]) {
+                if (id.is_string()) {
+                    valid_ids.insert(id.get());
+                }
+            }
+            break;
+        }
+    }
+
+    int result_id_int = -1;
+    std::string result_id_str;
+    if (response["items"][0].is_number_integer()) {
+        result_id_int = response["items"][0].get();
+        result_id_str = std::to_string(result_id_int);
+    } else if (response["items"][0].is_string()) {
+        result_id_str = response["items"][0].get();
+        try {
+            result_id_int = std::stoi(result_id_str);
+        } catch (...) {
+            throw std::runtime_error(
+                    "Invalid LLM response: The LLM returned ID '" + result_id_str +
+                    "' which is not a valid flock_row_id.");
+        }
+    } else {
+        throw std::runtime_error(
+                "Invalid LLM response: Expected integer or string ID, got: " +
+                response["items"][0].dump());
+    }
+
+    if (valid_ids.find(result_id_str) == valid_ids.end()) {
+        throw std::runtime_error(
+                "Invalid LLM response: The LLM returned ID '" + result_id_str +
+                "' which is not a valid flock_row_id.");
+    }
+
+    return result_id_int;
 }

 nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
+    int num_tuples = static_cast(tuples[0]["data"].size());
+
+    if (num_tuples <= 1) {
+        auto result = nlohmann::json::array();
+        for (auto i = 0; i < static_cast(tuples.size()) - 1; i++) {
+            result.push_back(nlohmann::json::object());
+            for (const auto& item: tuples[i].items()) {
+                if (item.key() == "data") {
+                    result[i]["data"] = nlohmann::json::array();
+                    if (!item.value().empty()) {
+                        result[i]["data"].push_back(item.value()[0]);
+                    }
+                } else {
+                    result[i][item.key()] = item.value();
+                }
+            }
+        }
+        return result;
+    }
+
     auto batch_tuples = nlohmann::json::array();
     int start_index = 0;
-    model = Model(model_details);
-    auto batch_size = std::min(model.GetModelDetails().batch_size, static_cast(tuples[0]["data"].size()));
+    auto batch_size = std::min(model.GetModelDetails().batch_size, num_tuples);

     if (batch_size <= 0) {
         throw std::runtime_error("Batch size must be greater than zero");
     }

     do {
-
         for (auto i = 0; i < static_cast(tuples.size()); i++) {
             if (start_index == 0) {
                 batch_tuples.push_back(nlohmann::json::object());
@@ -46,7 +118,7 @@ nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
             auto result_idx = GetFirstOrLastTupleId(batch_tuples);
             batch_tuples.clear();

-            for (auto i = 0; i < static_cast(tuples.size()); i++) {
+            for (auto i = 0; i < static_cast(tuples.size()) - 1; i++) {
                 batch_tuples.push_back(nlohmann::json::object());
                 for (const auto& item: tuples[i].items()) {
                     if (item.key() == "data") {
@@ -58,7 +130,7 @@ nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
                 }
             }
         } catch (const ExceededMaxOutputTokensError&) {
-            start_index -= batch_size;// Retry the current batch with reduced size
+            start_index -= batch_size;
             batch_size = static_cast(batch_size * 0.9);
             if (batch_size <= 0) {
                 throw std::runtime_error("Batch size reduced to zero, unable to process tuples");
@@ -67,37 +139,90 @@

     } while (start_index < static_cast(tuples[0]["data"].size()));

-    batch_tuples.erase(batch_tuples.end() - 1);
-
     return batch_tuples;
 }
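Editor's note: the catch block above implements a geometric back-off on output-token overflow; isolated from the loop plumbing, the rule is simply:

    // Retry the same window at 90% of the previous size until it fits.
    static int ShrinkBatch(int batch_size) {
        int next = static_cast<int>(batch_size * 0.9);
        if (next <= 0) {
            throw std::runtime_error("Batch size reduced to zero, unable to process tuples");
        }
        return next;
    }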

 void LlmFirstOrLast::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data,
                                      duckdb::Vector& result, idx_t count, idx_t offset,
                                      AggregateFunctionType function_type) {
-    const auto states_vector = reinterpret_cast(duckdb::FlatVector::GetData(states));
+    const auto states_vector = reinterpret_cast(
+            duckdb::FlatVector::GetData(states));
+
+    FunctionType metrics_function_type =
+            (function_type == AggregateFunctionType::FIRST) ? FunctionType::LLM_FIRST : FunctionType::LLM_LAST;
+
+    auto& bind_data = aggr_input_data.bind_data->Cast();
+
+    auto temp_model = bind_data.CreateModel();
+    auto model_details_obj = temp_model.GetModelDetails();
+
+    auto db = Config::db;
+    std::vector processed_state_ids;

     for (idx_t i = 0; i < count; i++) {
-        auto idx = i + offset;
-        auto* state = states_vector[idx];
-
-        if (state && !state->value->empty()) {
-            auto tuples_with_ids = *state->value;
-            tuples_with_ids.push_back(nlohmann::json::object());
-            for (auto j = 0; j < static_cast((*state->value)[0]["data"].size()); j++) {
-                if (j == 0) {
-                    tuples_with_ids.back()["name"] = "flock_row_id";
-                    tuples_with_ids.back()["data"] = nlohmann::json::array();
+        auto result_idx = i + offset;
+        auto* state = states_vector[i];
+
+        if (!state || !state->value || state->value->empty()) {
+            result.SetValue(result_idx, nullptr);
+            continue;
+        }
+
+        int num_tuples = static_cast((*state->value)[0]["data"].size());
+
+        if (num_tuples <= 1) {
+            auto response = nlohmann::json::array();
+            for (auto k = 0; k < static_cast(state->value->size()); k++) {
+                response.push_back(nlohmann::json::object());
+                for (const auto& item: (*state->value)[k].items()) {
+                    if (item.key() == "data") {
+                        response[k]["data"] = nlohmann::json::array();
+                        if (!item.value().empty()) {
+                            response[k]["data"].push_back(item.value()[0]);
+                        }
+                    } else {
+                        response[k][item.key()] = item.value();
+                    }
                 }
-                tuples_with_ids.back()["data"].push_back(std::to_string(j));
             }
-            LlmFirstOrLast function_instance;
-            function_instance.function_type = function_type;
-            auto response = function_instance.Evaluate(tuples_with_ids);
-            result.SetValue(idx, response.dump());
-        } else {
-            result.SetValue(idx, nullptr);// Empty JSON object for null/empty states
+            result.SetValue(result_idx, response.dump());
+            continue;
         }
+
+        const void* state_id = static_cast(state);
+        processed_state_ids.push_back(state_id);
+        MetricsManager::StartInvocation(db, state_id, metrics_function_type);
+        MetricsManager::SetModelInfo(model_details_obj.model_name, model_details_obj.provider_name);
+
+        auto exec_start = std::chrono::high_resolution_clock::now();
+
+        nlohmann::json tuples_with_ids = *state->value;
+
+        tuples_with_ids.push_back({{"name", "flock_row_id"}, {"data", nlohmann::json::array()}});
+        for (int j = 0; j < num_tuples; j++) {
+            tuples_with_ids.back()["data"].push_back(std::to_string(j));
+        }
+
+        if (bind_data.prompt.empty()) {
+            throw std::runtime_error("The prompt cannot be empty");
+        }
+
+        LlmFirstOrLast function_instance;
+        function_instance.function_type = function_type;
+        function_instance.user_query = bind_data.prompt;
+        function_instance.model = bind_data.CreateModel();
+        auto response = function_instance.Evaluate(tuples_with_ids);
+
+        auto exec_end = std::chrono::high_resolution_clock::now();
+        double exec_duration_ms = std::chrono::duration(exec_end - exec_start).count();
+        MetricsManager::AddExecutionTime(exec_duration_ms);
+
+        result.SetValue(result_idx, response.dump());
+    }
+
+    if (!processed_state_ids.empty()) {
+        MetricsManager::MergeAggregateMetrics(db, processed_state_ids, metrics_function_type,
+                                              model_details_obj.model_name, model_details_obj.provider_name);
     }
 }
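Editor's note: the metrics calls above follow a per-state lifecycle; the MetricsManager signatures are inferred from these call sites only, so the snippet below is a usage outline, not documented API.

    const void* state_id = static_cast<const void*>(state);// per-state key
    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_FIRST);
    MetricsManager::SetModelInfo(details.model_name, details.provider_name);
    auto t0 = std::chrono::high_resolution_clock::now();
    // ... run the LLM evaluation for this state ...
    auto t1 = std::chrono::high_resolution_clock::now();
    MetricsManager::AddExecutionTime(
            std::chrono::duration<double, std::milli>(t1 - t0).count());
    // After the loop, fold every state's numbers into one invocation record:
    MetricsManager::MergeAggregateMetrics(db, processed_state_ids, FunctionType::LLM_FIRST,
                                          details.model_name, details.provider_name);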
diff --git a/src/functions/aggregate/llm_first_or_last/registry.cpp b/src/functions/aggregate/llm_first_or_last/registry.cpp
index a3c23598..e08f28ec 100644
--- a/src/functions/aggregate/llm_first_or_last/registry.cpp
+++ b/src/functions/aggregate/llm_first_or_last/registry.cpp
@@ -9,7 +9,7 @@ void AggregateRegistry::RegisterLlmFirst(duckdb::ExtensionLoader& loader) {
                                 duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize,
                                 LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine,
                                 LlmFirstOrLast::Finalize, LlmFirstOrLast::SimpleUpdate,
-                                nullptr, LlmFirstOrLast::Destroy));
+                                LlmFirstOrLast::Bind, LlmFirstOrLast::Destroy));
 }

 void AggregateRegistry::RegisterLlmLast(duckdb::ExtensionLoader& loader) {
@@ -18,7 +18,7 @@ void AggregateRegistry::RegisterLlmLast(duckdb::ExtensionLoader& loader) {
                                 duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize,
                                 LlmFirstOrLast::Initialize, LlmFirstOrLast::Operation, LlmFirstOrLast::Combine,
                                 LlmFirstOrLast::Finalize, LlmFirstOrLast::SimpleUpdate,
-                                nullptr, LlmFirstOrLast::Destroy));
+                                LlmFirstOrLast::Bind, LlmFirstOrLast::Destroy));
 }

 }// namespace flock
\ No newline at end of file
diff --git a/src/functions/aggregate/llm_reduce/implementation.cpp b/src/functions/aggregate/llm_reduce/implementation.cpp
index 9a7ce46d..a34fcc7e 100644
--- a/src/functions/aggregate/llm_reduce/implementation.cpp
+++ b/src/functions/aggregate/llm_reduce/implementation.cpp
@@ -1,25 +1,39 @@
+#include "flock/core/config.hpp"
 #include "flock/functions/aggregate/llm_reduce.hpp"
+#include "flock/functions/llm_function_bind_data.hpp"
+#include "flock/metrics/manager.hpp"
+
+#include

 namespace flock {

-nlohmann::json LlmReduce::ReduceBatch(nlohmann::json& tuples, const AggregateFunctionType& function_type, const nlohmann::json& summary) {
-    nlohmann::json data;
-    auto [prompt, media_data] = PromptManager::Render(user_query, tuples, function_type, model.GetModelDetails().tuple_format);
+duckdb::unique_ptr LlmReduce::Bind(
+        duckdb::ClientContext& context,
+        duckdb::AggregateFunction& function,
+        duckdb::vector>& arguments) {
+    return AggregateFunctionBase::ValidateAndInitializeBindData(context, arguments, "llm_reduce");
+}
+
+nlohmann::json LlmReduce::ReduceBatch(nlohmann::json& tuples,
+                                      const AggregateFunctionType& function_type,
+                                      const nlohmann::json& summary) {
+    auto [prompt, media_data] = PromptManager::Render(
+            user_query, tuples, function_type, model.GetModelDetails().tuple_format);
     prompt += "\n\n" + summary.dump(4);
-    OutputType output_type = OutputType::STRING;
-    model.AddCompletionRequest(prompt, 1, output_type, media_data);
+    model.AddCompletionRequest(prompt, 1, OutputType::STRING, media_data);
     auto response = model.CollectCompletions()[0];
     return response["items"][0];
-};
+}

 nlohmann::json LlmReduce::ReduceLoop(const nlohmann::json& tuples,
                                      const AggregateFunctionType& function_type) {
     auto batch_tuples = nlohmann::json::array();
     auto summary = nlohmann::json::object({{"Previous Batch Summary", ""}});

     int start_index = 0;
-    auto batch_size = std::min(model.GetModelDetails().batch_size, static_cast(tuples[0]["data"].size()));
+    int num_tuples = static_cast(tuples[0]["data"].size());
+    auto batch_size = std::min(model.GetModelDetails().batch_size, num_tuples);

     if (batch_size <= 0) {
         throw std::runtime_error("Batch size must be greater than zero");
@@ -30,10 +44,8 @@ nlohmann::json LlmReduce::ReduceLoop(const nlohmann::json& tuples,
             batch_tuples.push_back(nlohmann::json::object());
             for (const auto& item: tuples[i].items()) {
                 if (item.key() == "data") {
+                    batch_tuples[i]["data"] = nlohmann::json::array();
                     for (auto j = 0; j < batch_size && start_index + j < static_cast(item.value().size()); j++) {
-                        if (j == 0) {
-                            batch_tuples[i]["data"] = nlohmann::json::array();
-                        }
batch_tuples[i]["data"].push_back(item.value()[start_index + j]); } } else { @@ -56,7 +68,7 @@ nlohmann::json LlmReduce::ReduceLoop(const nlohmann::json& tuples, } } - } while (start_index < static_cast(tuples[0]["data"].size())); + } while (start_index < num_tuples); return summary["Previous Batch Summary"]; } @@ -64,25 +76,60 @@ nlohmann::json LlmReduce::ReduceLoop(const nlohmann::json& tuples, void LlmReduce::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, duckdb::Vector& result, idx_t count, idx_t offset, const AggregateFunctionType function_type) { - const auto states_vector = reinterpret_cast(duckdb::FlatVector::GetData(states)); + const auto states_vector = reinterpret_cast( + duckdb::FlatVector::GetData(states)); + + // Get bind data - model_json and prompt are guaranteed to be initialized + auto& bind_data = aggr_input_data.bind_data->Cast(); + // Get model details for metrics (create temp model just for details) + auto temp_model = bind_data.CreateModel(); + auto model_details_obj = temp_model.GetModelDetails(); + + auto db = Config::db; + std::vector processed_state_ids; + + // Process each state individually for (idx_t i = 0; i < count; i++) { - auto idx = i + offset; - auto* state = states_vector[idx]; - - if (state && !state->value->empty()) { - LlmReduce reduce_instance; - reduce_instance.model = Model(model_details); - auto response = reduce_instance.ReduceLoop(*state->value, function_type); - if (response.is_string()) { - result.SetValue(idx, response.get()); - } else { - result.SetValue(idx, response.dump()); - } + auto result_idx = i + offset; + auto* state = states_vector[i]; + + if (!state || !state->value || state->value->empty()) { + result.SetValue(result_idx, nullptr); + continue; + } + + // Track metrics for this state + const void* state_id = static_cast(state); + processed_state_ids.push_back(state_id); + MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_REDUCE); + MetricsManager::SetModelInfo(model_details_obj.model_name, model_details_obj.provider_name); + + auto exec_start = std::chrono::high_resolution_clock::now(); + + // Create function instance with bind data and process + // IMPORTANT: Use CreateModel() for thread-safe Model instance + LlmReduce reduce_instance; + reduce_instance.model = bind_data.CreateModel(); + reduce_instance.user_query = bind_data.prompt; + auto response = reduce_instance.ReduceLoop(*state->value, function_type); + + auto exec_end = std::chrono::high_resolution_clock::now(); + double exec_duration_ms = std::chrono::duration(exec_end - exec_start).count(); + MetricsManager::AddExecutionTime(exec_duration_ms); + + if (response.is_string()) { + result.SetValue(result_idx, response.get()); } else { - result.SetValue(idx, nullptr);// Empty result for null/empty states + result.SetValue(result_idx, response.dump()); } } + + // Merge all metrics from processed states + if (!processed_state_ids.empty()) { + MetricsManager::MergeAggregateMetrics(db, processed_state_ids, FunctionType::LLM_REDUCE, + model_details_obj.model_name, model_details_obj.provider_name); + } } }// namespace flock diff --git a/src/functions/aggregate/llm_reduce/registry.cpp b/src/functions/aggregate/llm_reduce/registry.cpp index 31658023..c3885305 100644 --- a/src/functions/aggregate/llm_reduce/registry.cpp +++ b/src/functions/aggregate/llm_reduce/registry.cpp @@ -9,7 +9,7 @@ void AggregateRegistry::RegisterLlmReduce(duckdb::ExtensionLoader& loader) { duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize, 
diff --git a/src/functions/aggregate/llm_reduce/registry.cpp b/src/functions/aggregate/llm_reduce/registry.cpp
index 31658023..c3885305 100644
--- a/src/functions/aggregate/llm_reduce/registry.cpp
+++ b/src/functions/aggregate/llm_reduce/registry.cpp
@@ -9,7 +9,7 @@ void AggregateRegistry::RegisterLlmReduce(duckdb::ExtensionLoader& loader) {
                                 duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize,
                                 LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine,
                                 LlmReduce::Finalize, LlmReduce::SimpleUpdate,
-                                nullptr, LlmReduce::Destroy));
+                                LlmReduce::Bind, LlmReduce::Destroy));
 }

 }// namespace flock
\ No newline at end of file
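Editor's note: the registry change looks small — nullptr becomes Bind in the argument list — but it is what retires the old shared statics deleted from aggregate.cpp. Every Finalize now follows the pattern below, assuming the LlmFunctionBindData sketch given earlier.

    // Bind data is immutable and shared; each state copies its own Model out
    // of it instead of mutating a global, so concurrent queries cannot race.
    void FinalizeOneState(const LlmFunctionBindData& bind_data) {
        auto model = bind_data.CreateModel();// private, thread-local instance
        // model.AddCompletionRequest(...); model.CollectCompletions(); ...
    }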
diff --git a/src/functions/aggregate/llm_rerank/implementation.cpp b/src/functions/aggregate/llm_rerank/implementation.cpp
index a5fbd789..7bdd3c4b 100644
--- a/src/functions/aggregate/llm_rerank/implementation.cpp
+++ b/src/functions/aggregate/llm_rerank/implementation.cpp
@@ -1,22 +1,113 @@
+#include "flock/core/config.hpp"
 #include "flock/functions/aggregate/llm_rerank.hpp"
+#include "flock/functions/llm_function_bind_data.hpp"
+#include "flock/metrics/manager.hpp"
+
+#include
+#include

 namespace flock {

+duckdb::unique_ptr LlmRerank::Bind(
+        duckdb::ClientContext& context,
+        duckdb::AggregateFunction& function,
+        duckdb::vector>& arguments) {
+    return AggregateFunctionBase::ValidateAndInitializeBindData(context, arguments, "llm_rerank");
+}
+
 std::vector LlmRerank::RerankBatch(const nlohmann::json& tuples) {
-    nlohmann::json data;
-    auto [prompt, media_data] =
-            PromptManager::Render(user_query, tuples, AggregateFunctionType::RERANK, model.GetModelDetails().tuple_format);
-    model.AddCompletionRequest(prompt, static_cast(tuples[0]["data"].size()), OutputType::INTEGER, media_data);
+    auto [prompt, media_data] = PromptManager::Render(
+            user_query, tuples, AggregateFunctionType::RERANK, model.GetModelDetails().tuple_format);
+
+    int num_tuples = static_cast(tuples[0]["data"].size());
+
+    model.AddCompletionRequest(prompt, num_tuples, OutputType::INTEGER, media_data);
     auto responses = model.CollectCompletions();
-    return responses[0]["items"];
-};
+
+    // Find flock_row_id column to get valid IDs
+    std::set valid_ids;
+    for (const auto& column: tuples) {
+        if (column.contains("name") && column["name"].is_string() &&
+            column["name"].get() == "flock_row_id" &&
+            column.contains("data") && column["data"].is_array()) {
+            for (const auto& id: column["data"]) {
+                if (id.is_string()) {
+                    valid_ids.insert(id.get());
+                }
+            }
+            break;
+        }
+    }
+
+    std::vector indices;
+    std::set seen_ids;
+
+    for (const auto& item: responses[0]["items"]) {
+        std::string id_str;
+        int id_int = -1;
+
+        // Handle both integer and string responses
+        if (item.is_number_integer()) {
+            id_int = item.get();
+            id_str = std::to_string(id_int);
+        } else if (item.is_string()) {
+            id_str = item.get();
+            try {
+                id_int = std::stoi(id_str);
+            } catch (...) {
+                throw std::runtime_error(
+                        "Invalid LLM response: The LLM returned ID '" + id_str +
+                        "' which is not a valid flock_row_id.");
+            }
+        } else {
+            throw std::runtime_error(
+                    "Invalid LLM response: Expected integer or string ID, got: " + item.dump());
+        }
+
+        // Validate that the ID exists in flock_row_id
+        if (valid_ids.find(id_str) == valid_ids.end()) {
+            throw std::runtime_error(
+                    "Invalid LLM response: The LLM returned ID '" + id_str +
+                    "' which is not a valid flock_row_id.");
+        }
+
+        // Check for duplicates
+        if (seen_ids.count(id_str) > 0) {
+            throw std::runtime_error(
+                    "Invalid LLM response: The LLM returned duplicate ID '" + id_str + "'.");
+        }
+        seen_ids.insert(id_str);
+        indices.push_back(id_int);
+    }
+
+    return indices;
+}

 nlohmann::json LlmRerank::SlidingWindow(nlohmann::json& tuples) {
-    const auto num_tuples = static_cast(tuples[0]["data"].size());
+    const int num_tuples = static_cast(tuples[0]["data"].size());
+
+    // If there's only 1 tuple, no need to call the LLM - just return it
+    if (num_tuples <= 1) {
+        auto result = nlohmann::json::array();
+        for (auto i = 0; i < static_cast(tuples.size()); i++) {
+            result.push_back(nlohmann::json::object());
+            for (const auto& item: tuples[i].items()) {
+                if (item.key() == "data") {
+                    result[i]["data"] = nlohmann::json::array();
+                    if (!item.value().empty()) {
+                        result[i]["data"].push_back(item.value()[0]);
+                    }
+                } else {
+                    result[i][item.key()] = item.value();
+                }
+            }
+        }
+        return result;
+    }
+
     auto final_ranked_tuples = nlohmann::json::array();
     auto carry_forward_tuples = nlohmann::json::array();
-    auto start_index = 0;
-    model = Model(model_details);
+    int start_index = 0;

     auto batch_size = static_cast(model.GetModelDetails().batch_size);
     if (batch_size == 2048) {
@@ -31,18 +122,22 @@ nlohmann::json LlmRerank::SlidingWindow(nlohmann::json& tuples) {
         auto window_tuples = carry_forward_tuples;

         // Then add new tuples up to batch_size
+        // Handle case where carry_forward_tuples is empty (first iteration)
batch_size + : (batch_size - static_cast(window_tuples[0]["data"].size())); + int end_index = std::min(start_index + remaining_space, num_tuples); + for (auto i = 0; i < static_cast(tuples.size()); i++) { if (i >= static_cast(window_tuples.size())) { window_tuples.push_back(nlohmann::json::object()); } for (const auto& item: tuples[i].items()) { if (item.key() == "data") { - for (auto j = start_index; j < end_index; j++) { - if (j == 0) { - window_tuples[i]["data"] = nlohmann::json::array(); - } + if (!window_tuples[i].contains("data")) { + window_tuples[i]["data"] = nlohmann::json::array(); + } + for (int j = start_index; j < end_index; j++) { window_tuples[i]["data"].push_back(item.value()[j]); } } else { @@ -54,23 +149,39 @@ nlohmann::json LlmRerank::SlidingWindow(nlohmann::json& tuples) { // Clear carry forward for next iteration carry_forward_tuples.clear(); + // Skip if window_tuples is empty (shouldn't happen, but safety check) + if (window_tuples.empty() || window_tuples[0]["data"].empty()) { + continue; + } + try { + // Build indexed tuples with flock_row_id auto indexed_tuples = window_tuples; - indexed_tuples.push_back(nlohmann::json::object()); - for (auto i = 0; i < static_cast(window_tuples[0]["data"].size()); i++) { - if (i == 0) { - indexed_tuples.back()["name"] = "flock_row_id"; - indexed_tuples.back()["data"] = nlohmann::json::array(); - } + indexed_tuples.push_back({{"name", "flock_row_id"}, {"data", nlohmann::json::array()}}); + for (int i = 0; i < static_cast(window_tuples[0]["data"].size()); i++) { indexed_tuples.back()["data"].push_back(std::to_string(i)); } auto ranked_indices = RerankBatch(indexed_tuples); + // Initialize final_ranked_tuples structure if needed (first time adding results) + if (final_ranked_tuples.empty() && !window_tuples.empty()) { + for (size_t i = 0; i < window_tuples.size(); i++) { + final_ranked_tuples.push_back(nlohmann::json::object()); + // Copy metadata from window_tuples + for (const auto& item: window_tuples[i].items()) { + if (item.key() != "data") { + final_ranked_tuples[i][item.key()] = item.value(); + } + } + final_ranked_tuples[i]["data"] = nlohmann::json::array(); + } + } + // Add the bottom half to final results (they won't be re-ranked) - auto half_batch = static_cast(ranked_indices.size()) / 2; - for (auto i = half_batch; i < static_cast(ranked_indices.size()); i++) { - auto idx = 0u; + int half_batch = static_cast(ranked_indices.size()) / 2; + for (int i = half_batch; i < static_cast(ranked_indices.size()); i++) { + size_t idx = 0; for (auto& column: window_tuples) { final_ranked_tuples[idx]["data"].push_back(column["data"][ranked_indices[i]]); idx++; @@ -78,8 +189,21 @@ nlohmann::json LlmRerank::SlidingWindow(nlohmann::json& tuples) { } // Carry forward top half to next batch for re-ranking - for (auto i = 0; i < half_batch; i++) { - auto idx = 0u; + // Initialize carry_forward_tuples structure if needed + if (carry_forward_tuples.empty() && !window_tuples.empty()) { + for (size_t i = 0; i < window_tuples.size(); i++) { + carry_forward_tuples.push_back(nlohmann::json::object()); + // Copy metadata from window_tuples + for (const auto& item: window_tuples[i].items()) { + if (item.key() != "data") { + carry_forward_tuples[i][item.key()] = item.value(); + } + } + carry_forward_tuples[i]["data"] = nlohmann::json::array(); + } + } + for (int i = 0; i < half_batch; i++) { + size_t idx = 0; for (auto& column: window_tuples) { carry_forward_tuples[idx]["data"].push_back(column["data"][ranked_indices[i]]); idx++; @@ -90,10 +214,10 @@ 
diff --git a/src/functions/aggregate/llm_rerank/registry.cpp b/src/functions/aggregate/llm_rerank/registry.cpp index 74739b70..5438606e 100644 --- a/src/functions/aggregate/llm_rerank/registry.cpp +++ b/src/functions/aggregate/llm_rerank/registry.cpp @@ -8,7 +8,7 @@ void AggregateRegistry::RegisterLlmRerank(duckdb::ExtensionLoader& loader) { "llm_rerank", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, duckdb::LogicalType::JSON(), duckdb::AggregateFunction::StateSize<AggregateFunctionState>, LlmRerank::Initialize, LlmRerank::Operation, LlmRerank::Combine, LlmRerank::Finalize, LlmRerank::SimpleUpdate, - nullptr, LlmRerank::Destroy)); + LlmRerank::Bind, LlmRerank::Destroy)); } }// namespace flock
diff --git a/src/functions/input_parser.cpp b/src/functions/input_parser.cpp index 9b1e23a5..09d49b91 100644 --- a/src/functions/input_parser.cpp +++ b/src/functions/input_parser.cpp @@ -4,6 +4,40 @@ namespace flock { +// Helper function to validate and clean context column, handling NULL values +static void ValidateAndCleanContextColumn(nlohmann::json& column, const std::initializer_list<const char*>& allowed_keys) { + std::string column_type = ""; + bool has_type = false; + bool has_transcription_model = false; + + for (const auto& key: allowed_keys) { + if (key != std::string("data")) { + bool key_exists = column.contains(key); + bool is_null = key_exists && column[key].get<std::string>() == "NULL"; + + if (key == std::string("type") && key_exists && !is_null) { + column_type = column[key].get<std::string>(); + has_type = true; + } else if (key == std::string("transcription_model") && key_exists && !is_null) { + has_transcription_model = true; + } else if (!key_exists || is_null) { + column.erase(key); + } + } + } + + // Validate transcription_model is only used with audio type + if (has_transcription_model && column_type != "audio") { + std::string type_display = has_type ? column_type : "tabular"; + throw std::runtime_error(duckdb_fmt::format("Argument 'transcription_model' is not supported for data type '{}'. It can only be used with type 'audio'.", type_display)); + } + + // Validate that audio type requires transcription_model + if (has_type && column_type == "audio" && !has_transcription_model) { + throw std::runtime_error("Argument 'transcription_model' is required when type is 'audio'."); + } +} + nlohmann::json CastVectorOfStructsToJson(const duckdb::Vector& struct_vector, const int size) { nlohmann::json struct_json;
@@ -20,28 +54,25 @@ nlohmann::json CastVectorOfStructsToJson(const duckdb::Vector& struct_vector, co for (auto context_column_idx = 0; context_column_idx < static_cast<int>(context_columns.size()); context_column_idx++) { auto context_column = context_columns[context_column_idx]; auto context_column_json = CastVectorOfStructsToJson(duckdb::Vector(context_column), 1); - auto allowed_keys = {"name", "data", "type", "detail"}; + auto allowed_keys = {"name", "data", "type", "detail", "transcription_model"}; for (const auto& key: context_column_json.items()) { if (std::find(std::begin(allowed_keys), std::end(allowed_keys), key.key()) == std::end(allowed_keys)) { throw std::runtime_error(duckdb_fmt::format("Unexpected key in 'context_columns': {}", key.key())); } } + auto required_keys = {"data"}; for (const auto& key: required_keys) { if (!context_column_json.contains(key) || (key != "data" && context_column_json[key].get<std::string>() == "NULL")) { throw std::runtime_error(duckdb_fmt::format("Expected 'context_columns' to contain key: {}", key)); } } + if (struct_json.contains("context_columns") && struct_json["context_columns"].size() == context_columns.size()) { struct_json["context_columns"][context_column_idx]["data"].push_back(context_column_json["data"]); } else { struct_json["context_columns"].push_back(context_column_json); - for (const auto& key: allowed_keys) { - if (key != "data" && (!struct_json["context_columns"][context_column_idx].contains(key) || - struct_json["context_columns"][context_column_idx][key].get<std::string>() == "NULL")) { - struct_json["context_columns"][context_column_idx].erase(key); - } - } + ValidateAndCleanContextColumn(struct_json["context_columns"][context_column_idx], allowed_keys); struct_json["context_columns"][context_column_idx]["data"] = nlohmann::json::array(); struct_json["context_columns"][context_column_idx]["data"].push_back(context_column_json["data"]); }
@@ -59,4 +90,36 @@ nlohmann::json CastVectorOfStructsToJson(const duckdb::Vector& struct_vector, co return struct_json; } +nlohmann::json CastValueToJson(const duckdb::Value& value) { + nlohmann::json result; + + if (value.IsNull()) { + return result; + } + + auto& value_type = value.type(); + if (value_type.id() == duckdb::LogicalTypeId::STRUCT) { + auto& children = duckdb::StructValue::GetChildren(value); + auto child_count = duckdb::StructType::GetChildCount(value_type); + + for (idx_t i = 0; i < child_count; i++) { + auto key = duckdb::StructType::GetChildName(value_type, i); + auto& child_value = children[i]; + + if (!child_value.IsNull()) { + // Recursively convert child values + if (child_value.type().id() == duckdb::LogicalTypeId::STRUCT) { + result[key] = CastValueToJson(child_value); + } else if (child_value.type().id() == duckdb::LogicalTypeId::INTEGER) { + result[key] = child_value.GetValue<int32_t>(); + } else { + result[key] = child_value.ToString(); + } + } + } + } + + return result; +} + }// namespace flock
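The two audio checks above are the core of the new context-column contract: transcription_model is rejected for non-audio columns, and audio columns must name one. A standalone restatement of just those rules with nlohmann::json, useful as a mental model (CheckAudioRules is illustrative and not part of the patch; the real helper also strips NULL or absent optional keys, and the model name is just a value registered elsewhere in this patch):

    #include <iostream>
    #include <nlohmann/json.hpp>
    #include <stdexcept>
    #include <string>

    // Simplified restatement of the two audio rules enforced above.
    static void CheckAudioRules(const nlohmann::json& column) {
        const bool has_type = column.contains("type");
        // Columns without an explicit type are treated as plain tabular data.
        const std::string type = has_type ? column.at("type").get<std::string>() : "tabular";
        const bool has_model = column.contains("transcription_model");
        if (has_model && type != "audio") {
            throw std::runtime_error("'transcription_model' is only valid for type 'audio', got '" + type + "'");
        }
        if (type == "audio" && !has_model) {
            throw std::runtime_error("type 'audio' requires 'transcription_model'");
        }
    }

    int main() {
        // Passes: audio column with a transcription model.
        CheckAudioRules({{"name", "clip"}, {"type", "audio"}, {"transcription_model", "gpt-4o-transcribe"}});
        try {
            // Fails: no type means tabular, so transcription_model is rejected.
            CheckAudioRules({{"name", "title"}, {"transcription_model", "gpt-4o-transcribe"}});
        } catch (const std::exception& e) {
            std::cout << e.what() << '\n';
        }
    }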
diff --git a/src/functions/scalar/llm_complete/implementation.cpp b/src/functions/scalar/llm_complete/implementation.cpp index f4ae1509..83f60f14 100644 --- a/src/functions/scalar/llm_complete/implementation.cpp +++ b/src/functions/scalar/llm_complete/implementation.cpp @@ -1,7 +1,20 @@ +#include "duckdb/planner/expression/bound_function_expression.hpp" #include "flock/functions/scalar/llm_complete.hpp" +#include "flock/functions/scalar/scalar.hpp" +#include "flock/metrics/manager.hpp" +#include "flock/model_manager/model.hpp" + namespace flock { +duckdb::unique_ptr<duckdb::FunctionData> LlmComplete::Bind( + duckdb::ClientContext& context, + duckdb::ScalarFunction& bound_function, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments) { + return ScalarFunctionBase::ValidateAndInitializeBindData(context, arguments, "llm_complete", false); +} + + void LlmComplete::ValidateArguments(duckdb::DataChunk& args) { if (args.ColumnCount() < 2 || args.ColumnCount() > 3) { throw std::runtime_error("Invalid number of arguments.");
@@ -21,21 +34,23 @@ void LlmComplete::ValidateArguments(duckdb::DataChunk& args) { } } -std::vector<std::string> LlmComplete::Operation(duckdb::DataChunk& args) { - // LlmComplete::ValidateArguments(args); - auto model_details_json = CastVectorOfStructsToJson(args.data[0], 1); - Model model(model_details_json); +std::vector<std::string> LlmComplete::Operation(duckdb::DataChunk& args, LlmFunctionBindData* bind_data) { + Model model = bind_data->CreateModel(); + + auto model_details = model.GetModelDetails(); + MetricsManager::SetModelInfo(model_details.model_name, model_details.provider_name); + auto prompt_context_json = CastVectorOfStructsToJson(args.data[1], args.size()); auto context_columns = nlohmann::json::array(); if (prompt_context_json.contains("context_columns")) { context_columns = prompt_context_json["context_columns"]; - prompt_context_json.erase("context_columns"); } - auto prompt_details = PromptManager::CreatePromptDetails(prompt_context_json); + + auto prompt = bind_data->prompt; std::vector<std::string> results; if (context_columns.empty()) { - auto template_str = prompt_details.prompt; + auto template_str = prompt; model.AddCompletionRequest(template_str, 1, OutputType::STRING); auto response = model.CollectCompletions()[0]["items"][0]; if (response.is_string()) {
@@ -48,7 +63,7 @@ std::vector<std::string> LlmComplete::Operation(duckdb::DataChunk& args) { return results; } - auto responses = BatchAndComplete(context_columns, prompt_details.prompt, ScalarFunctionType::COMPLETE, model); + auto responses = BatchAndComplete(context_columns, prompt, ScalarFunctionType::COMPLETE, model); results.reserve(responses.size()); for (const auto& response: responses) {
@@ -63,8 +78,18 @@ } void LlmComplete::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) { + auto& context = state.GetContext(); + auto* db = context.db.get(); + const void* invocation_id = MetricsManager::GenerateUniqueId(); + + MetricsManager::StartInvocation(db, invocation_id, FunctionType::LLM_COMPLETE); + + auto exec_start = std::chrono::high_resolution_clock::now(); - if (const auto results = LlmComplete::Operation(args); static_cast<int>(results.size()) == 1) { + auto& func_expr = state.expr.Cast<duckdb::BoundFunctionExpression>(); + auto* bind_data = &func_expr.bind_info->Cast<LlmFunctionBindData>(); + + if (const auto results = LlmComplete::Operation(args, bind_data); static_cast<int>(results.size()) == 1) { auto empty_vec = duckdb::Vector(std::string()); duckdb::UnaryExecutor::Execute<duckdb::string_t, duckdb::string_t>( empty_vec, result, args.size(),
@@ -75,6 +100,10 @@ void LlmComplete::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& stat result.SetValue(index++, duckdb::Value(res)); } } + + auto exec_end = std::chrono::high_resolution_clock::now(); + double exec_duration_ms = std::chrono::duration<double, std::milli>(exec_end - exec_start).count(); + MetricsManager::AddExecutionTime(exec_duration_ms); } }// namespace flock
diff --git a/src/functions/scalar/llm_complete/registry.cpp b/src/functions/scalar/llm_complete/registry.cpp index 492085d4..7c3544f8 100644 --- a/src/functions/scalar/llm_complete/registry.cpp +++ b/src/functions/scalar/llm_complete/registry.cpp @@ -6,7 +6,8 @@ namespace flock { void ScalarRegistry::RegisterLlmComplete(duckdb::ExtensionLoader& loader) { loader.RegisterFunction(duckdb::ScalarFunction("llm_complete", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, - duckdb::LogicalType::JSON(), LlmComplete::Execute)); + duckdb::LogicalType::JSON(), LlmComplete::Execute, + LlmComplete::Bind)); } }// namespace flock
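Each Execute wrapper in this patch brackets its work with the same wall-clock measurement before handing the result to MetricsManager::AddExecutionTime. A minimal sketch of that timing idiom (TimedMs is a hypothetical helper, not part of the patch):

    #include <chrono>
    #include <iostream>

    // Hypothetical helper mirroring the exec_start/exec_end pattern above:
    // run a callable and return the elapsed wall time in milliseconds.
    template <typename Fn>
    static double TimedMs(Fn&& fn) {
        const auto start = std::chrono::high_resolution_clock::now();
        fn();
        const auto end = std::chrono::high_resolution_clock::now();
        return std::chrono::duration<double, std::milli>(end - start).count();
    }

    int main() {
        double ms = TimedMs([] {
            volatile long sink = 0;
            for (long i = 0; i < 1000000; i++) sink += i;// stand-in workload
        });
        std::cout << "execution_time_ms=" << ms << '\n';// the value fed to AddExecutionTime
    }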
diff --git a/src/functions/scalar/llm_embedding/implementation.cpp b/src/functions/scalar/llm_embedding/implementation.cpp index 7b423257..de8b9b05 100644 --- a/src/functions/scalar/llm_embedding/implementation.cpp +++ b/src/functions/scalar/llm_embedding/implementation.cpp @@ -1,7 +1,19 @@ +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "flock/core/config.hpp" #include "flock/functions/scalar/llm_embedding.hpp" +#include "flock/metrics/manager.hpp" +#include "flock/model_manager/model.hpp" namespace flock { +duckdb::unique_ptr<duckdb::FunctionData> LlmEmbedding::Bind( + duckdb::ClientContext& context, + duckdb::ScalarFunction& bound_function, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments) { + return ScalarFunctionBase::ValidateAndInitializeBindData(context, arguments, "llm_embedding", true, false); +} + + void LlmEmbedding::ValidateArguments(duckdb::DataChunk& args) { if (args.ColumnCount() < 2 || args.ColumnCount() > 2) { throw std::runtime_error("LlmEmbedScalarParser: Invalid number of arguments.");
@@ -14,9 +26,7 @@ void LlmEmbedding::ValidateArguments(duckdb::DataChunk& args) { } } -std::vector<std::vector<duckdb::Value>> LlmEmbedding::Operation(duckdb::DataChunk& args) { - // LlmEmbedding::ValidateArguments(args); - +std::vector<std::vector<duckdb::Value>> LlmEmbedding::Operation(duckdb::DataChunk& args, LlmFunctionBindData* bind_data) { auto inputs = CastVectorOfStructsToJson(args.data[1], args.size()); for (const auto& item: inputs.items()) { if (item.key() != "context_columns") {
@@ -29,8 +39,10 @@ std::vector<std::vector<duckdb::Value>> LlmEmbedding::Operation(duckdb::DataC } } - auto model_details_json = CastVectorOfStructsToJson(args.data[0], 1); - Model model(model_details_json); + Model model = bind_data->CreateModel(); + + auto model_details = model.GetModelDetails(); + MetricsManager::SetModelInfo(model_details.model_name, model_details.provider_name); std::vector<std::string> prepared_inputs; auto num_rows = inputs["context_columns"][0]["data"].size();
@@ -71,12 +83,27 @@ std::vector<std::vector<duckdb::Value>> LlmEmbedding::Operation(duckdb::DataC } void LlmEmbedding::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) { - auto results = LlmEmbedding::Operation(args); + auto& context = state.GetContext(); + auto* db = context.db.get(); + const void* invocation_id = MetricsManager::GenerateUniqueId(); + + MetricsManager::StartInvocation(db, invocation_id, FunctionType::LLM_EMBEDDING); + + auto exec_start = std::chrono::high_resolution_clock::now(); + + auto& func_expr = state.expr.Cast<duckdb::BoundFunctionExpression>(); + auto* bind_data = &func_expr.bind_info->Cast<LlmFunctionBindData>(); + + auto results = LlmEmbedding::Operation(args, bind_data); auto index = 0; for (const auto& res: results) { result.SetValue(index++, duckdb::Value::LIST(res)); } + + auto exec_end = std::chrono::high_resolution_clock::now(); + double exec_duration_ms = std::chrono::duration<double, std::milli>(exec_end - exec_start).count(); + MetricsManager::AddExecutionTime(exec_duration_ms); } }// namespace flock
diff --git a/src/functions/scalar/llm_embedding/registry.cpp b/src/functions/scalar/llm_embedding/registry.cpp index eadba2fc..35d8829c 100644 --- a/src/functions/scalar/llm_embedding/registry.cpp +++ b/src/functions/scalar/llm_embedding/registry.cpp @@ -5,8 +5,9 @@ namespace flock { void ScalarRegistry::RegisterLlmEmbedding(duckdb::ExtensionLoader& loader) { loader.RegisterFunction( - duckdb::ScalarFunction("llm_embedding", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, duckdb::LogicalType::LIST(duckdb::LogicalType::DOUBLE), - LlmEmbedding::Execute)); + duckdb::ScalarFunction("llm_embedding", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, + duckdb::LogicalType::LIST(duckdb::LogicalType::DOUBLE), + LlmEmbedding::Execute, LlmEmbedding::Bind)); } }// namespace flock
diff --git a/src/functions/scalar/llm_filter/implementation.cpp b/src/functions/scalar/llm_filter/implementation.cpp index 67b49108..e0a4419f 100644 --- a/src/functions/scalar/llm_filter/implementation.cpp +++ b/src/functions/scalar/llm_filter/implementation.cpp @@ -1,9 +1,22 @@ +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "flock/core/config.hpp" #include "flock/functions/scalar/llm_filter.hpp" +#include "flock/functions/scalar/scalar.hpp" +#include "flock/metrics/manager.hpp" +#include "flock/model_manager/model.hpp" namespace flock { +duckdb::unique_ptr<duckdb::FunctionData> LlmFilter::Bind( + duckdb::ClientContext& context, + duckdb::ScalarFunction& bound_function, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments) { + return ScalarFunctionBase::ValidateAndInitializeBindData(context, arguments, "llm_filter", false); +} + + void LlmFilter::ValidateArguments(duckdb::DataChunk& args) { - if (args.ColumnCount() != 3) { + if (args.ColumnCount() < 2 || args.ColumnCount() > 3) { throw std::runtime_error("Invalid number of arguments."); }
@@ -14,46 +27,73 @@ void LlmFilter::ValidateArguments(duckdb::DataChunk& args) { throw std::runtime_error("Prompt details must be a struct."); } - if (args.data[2].GetType().id() != duckdb::LogicalTypeId::STRUCT) { + if (args.ColumnCount() == 3 && args.data[2].GetType().id() != duckdb::LogicalTypeId::STRUCT) { throw std::runtime_error("Inputs must be a struct."); } } -std::vector<std::string> LlmFilter::Operation(duckdb::DataChunk& args) { - // LlmFilter::ValidateArguments(args); +std::vector<std::string> LlmFilter::Operation(duckdb::DataChunk& args, LlmFunctionBindData* bind_data) { + Model model = bind_data->CreateModel(); + + auto model_details = model.GetModelDetails(); + MetricsManager::SetModelInfo(model_details.model_name, model_details.provider_name); - auto model_details_json = CastVectorOfStructsToJson(args.data[0], 1); - Model model(model_details_json); auto prompt_context_json = CastVectorOfStructsToJson(args.data[1], args.size()); auto context_columns = nlohmann::json::array(); if (prompt_context_json.contains("context_columns")) { context_columns = prompt_context_json["context_columns"]; - prompt_context_json.erase("context_columns"); } - auto prompt_details = PromptManager::CreatePromptDetails(prompt_context_json); - auto responses = BatchAndComplete(context_columns, prompt_details.prompt, ScalarFunctionType::FILTER, model); + auto prompt = bind_data->prompt; std::vector<std::string> results; - results.reserve(responses.size()); - for (const auto& response: responses) { + if (context_columns.empty()) { + auto template_str = prompt; + model.AddCompletionRequest(template_str, 1, OutputType::BOOL); + auto response = model.CollectCompletions()[0]["items"][0]; if (response.is_null()) { - results.emplace_back("true"); - continue; + results.push_back("true"); + } else { + results.push_back(response.dump()); + } + } else { + auto responses = BatchAndComplete(context_columns, prompt, ScalarFunctionType::FILTER, model); + + results.reserve(responses.size()); + for (const auto& response: responses) { + if (response.is_null()) { + results.emplace_back("true"); + continue; + } + results.push_back(response.dump()); } - results.push_back(response.dump()); } return results; } void LlmFilter::Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) { - const auto results = LlmFilter::Operation(args); + auto& context = state.GetContext(); + auto* db = context.db.get(); + const void* invocation_id = MetricsManager::GenerateUniqueId(); + + MetricsManager::StartInvocation(db, invocation_id, FunctionType::LLM_FILTER); + + auto exec_start = std::chrono::high_resolution_clock::now(); + + auto& func_expr = state.expr.Cast<duckdb::BoundFunctionExpression>(); + auto* bind_data = &func_expr.bind_info->Cast<LlmFunctionBindData>(); + + const auto results = LlmFilter::Operation(args, bind_data); auto index = 0; for (const auto& res: results) { result.SetValue(index++, duckdb::Value(res)); } + + auto exec_end = std::chrono::high_resolution_clock::now(); + double exec_duration_ms = std::chrono::duration<double, std::milli>(exec_end - exec_start).count(); + MetricsManager::AddExecutionTime(exec_duration_ms); } }// namespace flock
diff --git a/src/functions/scalar/llm_filter/registry.cpp b/src/functions/scalar/llm_filter/registry.cpp index d539dcf8..715bce04 100--- a/src/functions/scalar/llm_filter/registry.cpp +++ b/src/functions/scalar/llm_filter/registry.cpp @@ -6,7 +6,8 @@ namespace flock { void ScalarRegistry::RegisterLlmFilter(duckdb::ExtensionLoader& loader) { loader.RegisterFunction(duckdb::ScalarFunction("llm_filter", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY}, - duckdb::LogicalType::VARCHAR, LlmFilter::Execute)); + duckdb::LogicalType::VARCHAR, LlmFilter::Execute, + LlmFilter::Bind)); } }// namespace flock
diff --git a/src/functions/scalar/scalar.cpp b/src/functions/scalar/scalar.cpp index b2861327..6d022d81 100644 --- a/src/functions/scalar/scalar.cpp +++ b/src/functions/scalar/scalar.cpp @@ -1,10 +1,74 @@ #include "flock/functions/scalar/scalar.hpp" +#include "flock/model_manager/model.hpp" +#include namespace flock { +void ScalarFunctionBase::ValidateArgumentCount( + const duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments, + const std::string& function_name) { + if (arguments.size() != 2) { + throw duckdb::BinderException( + function_name + " requires 2 arguments: (1) model, (2) prompt with context_columns. Got " + + std::to_string(arguments.size())); + } +} + +void ScalarFunctionBase::ValidateArgumentTypes( + const duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments, + const std::string& function_name) { + if (arguments[0]->return_type.id() != duckdb::LogicalTypeId::STRUCT) { + throw duckdb::BinderException(function_name + ": First argument must be model (struct type)"); + } + if (arguments[1]->return_type.id() != duckdb::LogicalTypeId::STRUCT) { + throw duckdb::BinderException( + function_name + ": Second argument must be prompt with context_columns (struct type)"); + } +} + +ScalarFunctionBase::PromptStructInfo ScalarFunctionBase::ExtractPromptStructInfo( + const duckdb::LogicalType& prompt_type) { + PromptStructInfo info{false, std::nullopt, ""}; + + for (idx_t i = 0; i < duckdb::StructType::GetChildCount(prompt_type); i++) { + auto field_name = duckdb::StructType::GetChildName(prompt_type, i); + if (field_name == "context_columns") { + info.has_context_columns = true; + } else if (field_name == "prompt" || field_name == "prompt_name") { + if (!info.prompt_field_index.has_value()) { + info.prompt_field_index = i; + info.prompt_field_name = field_name; + } + } + } + + return info; +} + +void ScalarFunctionBase::ValidatePromptStructFields(const PromptStructInfo& info, + const std::string& function_name, + bool require_context_columns) { + if (require_context_columns && !info.has_context_columns) { + throw duckdb::BinderException( + function_name + ": Second argument must contain 'context_columns' field"); + } +} + +void ScalarFunctionBase::InitializeModelJson( + duckdb::ClientContext& context, + const duckdb::unique_ptr<duckdb::Expression>& model_expr, + LlmFunctionBindData& bind_data) { + if (!model_expr->IsFoldable()) { + return; + } + + auto model_value = duckdb::ExpressionExecutor::EvaluateScalar(context, *model_expr); + auto user_model_json = CastValueToJson(model_value); + bind_data.model_json = Model::ResolveModelDetailsToJson(user_model_json); +} + nlohmann::json ScalarFunctionBase::Complete(nlohmann::json& columns, const std::string& user_prompt, ScalarFunctionType function_type, Model& model) { - nlohmann::json data; const auto [prompt, media_data] = PromptManager::Render(user_prompt, columns, function_type, model.GetModelDetails().tuple_format); OutputType output_type = OutputType::STRING; if (function_type == ScalarFunctionType::FILTER) {
@@ -81,4 +145,68 @@ nlohmann::json ScalarFunctionBase::BatchAndComplete(const nlohmann::json& tuples return responses; } +void ScalarFunctionBase::InitializePrompt( + duckdb::ClientContext& context, + const duckdb::unique_ptr<duckdb::Expression>& prompt_expr, + LlmFunctionBindData& bind_data) { + nlohmann::json prompt_json; + + if (prompt_expr->IsFoldable()) { + auto prompt_value = duckdb::ExpressionExecutor::EvaluateScalar(context, *prompt_expr); + prompt_json = CastValueToJson(prompt_value); + } else if (prompt_expr->expression_class == duckdb::ExpressionClass::BOUND_FUNCTION) { + auto& func_expr = prompt_expr->Cast<duckdb::BoundFunctionExpression>(); + const auto& struct_type = prompt_expr->return_type; + + for (idx_t i = 0; i < duckdb::StructType::GetChildCount(struct_type) && i < func_expr.children.size(); i++) { + auto field_name = duckdb::StructType::GetChildName(struct_type, i); + auto& child = func_expr.children[i]; + + if (field_name != "context_columns" && child->IsFoldable()) { + try { + auto field_value = duckdb::ExpressionExecutor::EvaluateScalar(context, *child); + if (field_value.type().id() == duckdb::LogicalTypeId::VARCHAR) { + prompt_json[field_name] = field_value.GetValue<std::string>(); + } else { + prompt_json[field_name] = CastValueToJson(field_value); + } + } catch (...) { + // Skip fields that can't be evaluated + } + } + } + } + + if (prompt_json.contains("context_columns")) { + prompt_json.erase("context_columns"); + } + + auto prompt_details = PromptManager::CreatePromptDetails(prompt_json); + bind_data.prompt = prompt_details.prompt; +} + +duckdb::unique_ptr<duckdb::FunctionData> ScalarFunctionBase::ValidateAndInitializeBindData( + duckdb::ClientContext& context, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments, + const std::string& function_name, + bool require_context_columns, + bool initialize_prompt) { + + ValidateArgumentCount(arguments, function_name); + ValidateArgumentTypes(arguments, function_name); + + const auto& prompt_type = arguments[1]->return_type; + auto prompt_info = ExtractPromptStructInfo(prompt_type); + ValidatePromptStructFields(prompt_info, function_name, require_context_columns); + + auto bind_data = duckdb::make_uniq<LlmFunctionBindData>(); + + InitializeModelJson(context, arguments[0], *bind_data); + if (initialize_prompt) { + InitializePrompt(context, arguments[1], *bind_data); + } + + return bind_data; +} + }// namespace flock
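ValidateAndInitializeBindData folds the constant model and prompt arguments once per query at bind time, so the per-row Execute calls no longer re-parse them. A reduced sketch of the folding test it relies on (FoldIfConstant is illustrative; the DuckDB calls are the same ones used above):

    #include "duckdb/execution/expression_executor.hpp"
    #include "duckdb/planner/expression.hpp"

    // Illustrative only: evaluate a bind-time argument once if it is a
    // constant (foldable) expression, as InitializeModelJson and
    // InitializePrompt do for the model and prompt structs.
    static bool FoldIfConstant(duckdb::ClientContext& context,
                               const duckdb::unique_ptr<duckdb::Expression>& expr,
                               duckdb::Value& out) {
        if (!expr->IsFoldable()) {
            return false;// row-dependent: must be handled at execution time
        }
        out = duckdb::ExpressionExecutor::EvaluateScalar(context, *expr);
        return true;
    }

When the argument is not foldable, InitializePrompt falls back to walking the struct-packing function's children and folding each constant field individually, which is why per-field failures are swallowed rather than aborting the bind.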
diff --git a/src/include/filesystem.hpp b/src/include/filesystem.hpp index 5e2bde20..bb8f2c49 100644 --- a/src/include/filesystem.hpp +++ b/src/include/filesystem.hpp @@ -41,7 +41,7 @@ #endif // Not on Visual Studio. Let's use the normal version -#else // #ifdef _MSC_VER +#else// #ifdef _MSC_VER #define INCLUDE_STD_FILESYSTEM_EXPERIMENTAL 0 #endif @@ -70,4 +70,4 @@ namespace filesystem = experimental::filesystem; #include #endif -#endif // #ifndef INCLUDE_STD_FILESYSTEM_EXPERIMENTAL +#endif// #ifndef INCLUDE_STD_FILESYSTEM_EXPERIMENTAL
diff --git a/src/include/flock/core/common.hpp b/src/include/flock/core/common.hpp index bb909ea4..b9be1f50 100644 --- a/src/include/flock/core/common.hpp +++ b/src/include/flock/core/common.hpp @@ -1,6 +1,17 @@ #pragma once +// DuckDB includes #include "duckdb.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/function/scalar_function.hpp" #include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" + +// Common standard library includes +#include +#include +#include +#include +#include +#include +#include +#include
diff --git a/src/include/flock/core/config.hpp b/src/include/flock/core/config.hpp index 1f402d35..88ce1581 100644 --- a/src/include/flock/core/config.hpp +++ b/src/include/flock/core/config.hpp @@ -26,6 +26,29 @@ class Config { static std::string get_default_models_table_name(); static std::string get_user_defined_models_table_name(); static std::string get_prompts_table_name(); + static void AttachToGlobalStorage(duckdb::Connection& con, bool read_only = true); + static void DetachFromGlobalStorage(duckdb::Connection& con); + + class StorageAttachmentGuard { + public: + StorageAttachmentGuard(duckdb::Connection& con, bool read_only = true); + ~StorageAttachmentGuard(); + + StorageAttachmentGuard(const StorageAttachmentGuard&) = delete; + StorageAttachmentGuard& operator=(const StorageAttachmentGuard&) = delete; + StorageAttachmentGuard(StorageAttachmentGuard&&) = delete; + StorageAttachmentGuard& operator=(StorageAttachmentGuard&&) = delete; + + private: + duckdb::Connection& connection; + bool attached; + static constexpr int MAX_RETRIES = 10; + static constexpr int RETRY_DELAY_MS = 1000; + + bool TryAttach(bool read_only); + bool TryDetach(); + void Wait(int milliseconds); + }; private: static void SetupGlobalStorageLocation();
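StorageAttachmentGuard turns the attach/detach pair into an RAII scope with bounded retries, so connections contending for the shared storage file neither fail immediately nor leak an attachment on error paths. A sketch of the intended call shape (the function and query are illustrative, not part of the patch):

    #include "flock/core/config.hpp"
    #include <string>

    // Illustrative only: queries against the attached global storage run
    // inside the guard's scope; the destructor detaches even if Query() throws.
    std::string RunAgainstGlobalStorage(const std::string& sql) {
        auto con = flock::Config::GetConnection();
        flock::Config::StorageAttachmentGuard guard(con, /*read_only=*/true);
        auto result = con.Query(sql);
        return result->ToString();
    }

The deleted copy and move operations keep the attachment's lifetime tied to exactly one scope; pass read_only=false only for statements that must write to the global storage.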
diff --git a/src/include/flock/custom_parser/query/model_parser.hpp b/src/include/flock/custom_parser/query/model_parser.hpp index a18fedbb..bd5a6baa 100644 --- a/src/include/flock/custom_parser/query/model_parser.hpp +++ b/src/include/flock/custom_parser/query/model_parser.hpp @@ -5,10 +5,7 @@ #include "flock/custom_parser/tokenizer.hpp" #include "fmt/format.h" -#include #include -#include -#include namespace flock {
diff --git a/src/include/flock/custom_parser/query/prompt_parser.hpp b/src/include/flock/custom_parser/query/prompt_parser.hpp index b2e422ec..eeca50e0 100644 --- a/src/include/flock/custom_parser/query/prompt_parser.hpp +++ b/src/include/flock/custom_parser/query/prompt_parser.hpp @@ -5,9 +5,6 @@ #include "flock/custom_parser/tokenizer.hpp" #include "fmt/format.h" -#include -#include -#include namespace flock {
diff --git a/src/include/flock/custom_parser/query_parser.hpp b/src/include/flock/custom_parser/query_parser.hpp index bbdac178..08d57952 100644 --- a/src/include/flock/custom_parser/query_parser.hpp +++ b/src/include/flock/custom_parser/query_parser.hpp @@ -1,18 +1,30 @@ #pragma once #include "flock/core/common.hpp" +#include "flock/core/config.hpp" #include "flock/custom_parser/query/model_parser.hpp" #include "flock/custom_parser/query/prompt_parser.hpp" #include "flock/custom_parser/query_statements.hpp" #include "flock/custom_parser/tokenizer.hpp" #include "fmt/format.h" -#include -#include -#include namespace flock { +// Forward declarations for query execution utilities +std::string FormatValueForSQL(const duckdb::Value& value); +std::string FormatResultsAsValues(duckdb::unique_ptr<duckdb::MaterializedQueryResult> result); +std::string ExecuteGetQuery(const std::string& query, bool read_only); +std::string ExecuteSetQuery(const std::string& query, const std::string& success_message, bool read_only); + +// Template function for executing queries with storage attachment +template <typename Func> +std::string ExecuteQueryWithStorage(Func&& query_func, bool read_only) { + auto con = Config::GetConnection(); + Config::StorageAttachmentGuard guard(con, read_only); + return query_func(con); +} + class QueryParser { public: std::string ParseQuery(const std::string& query);
diff --git a/src/include/flock/custom_parser/query_statements.hpp b/src/include/flock/custom_parser/query_statements.hpp index 2528bf88..88446eda 100644 --- a/src/include/flock/custom_parser/query_statements.hpp +++ b/src/include/flock/custom_parser/query_statements.hpp @@ -2,10 +2,6 @@ #include "flock/core/common.hpp" -#include -#include -#include - namespace flock { // Enum to represent different statement types
diff --git a/src/include/flock/custom_parser/tokenizer.hpp b/src/include/flock/custom_parser/tokenizer.hpp index 17c7b378..c2d28647 100644 --- a/src/include/flock/custom_parser/tokenizer.hpp +++ b/src/include/flock/custom_parser/tokenizer.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "flock/core/common.hpp" namespace flock {
diff --git a/src/include/flock/functions/aggregate/aggregate.hpp b/src/include/flock/functions/aggregate/aggregate.hpp index 180a3b2c..40d34bfd 100644 --- a/src/include/flock/functions/aggregate/aggregate.hpp +++ b/src/include/flock/functions/aggregate/aggregate.hpp @@ -1,11 +1,12 @@ #pragma once -#include -#include - #include "flock/core/common.hpp" #include "flock/functions/input_parser.hpp" +#include "flock/functions/llm_function_bind_data.hpp" +#include "flock/metrics/manager.hpp" #include "flock/model_manager/model.hpp" +#include +#include namespace flock {
@@ -32,40 +33,59 @@ class AggregateFunctionState { class AggregateFunctionBase { public: Model model; - static nlohmann::json model_details; - static std::string user_query; + std::string user_query; public: explicit AggregateFunctionBase() = default; +private: + struct PromptStructInfo { + bool has_context_columns; + std::optional<idx_t> prompt_field_index; + std::string prompt_field_name; + }; + + static void ValidateArgumentCount(const duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments, + const std::string& function_name); + + static void ValidateArgumentTypes(const duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments, + const std::string& function_name); + + static PromptStructInfo ExtractPromptStructInfo(const duckdb::LogicalType& prompt_type); + + static void ValidatePromptStructFields(const PromptStructInfo& info, const std::string& function_name); + + static void InitializeModelJson(duckdb::ClientContext& context, + const duckdb::unique_ptr<duckdb::Expression>& model_expr, + LlmFunctionBindData& bind_data); + + static void InitializePrompt(duckdb::ClientContext& context, + const duckdb::unique_ptr<duckdb::Expression>& prompt_expr, + LlmFunctionBindData& bind_data); + public: - static void ValidateArguments(duckdb::Vector inputs[], idx_t input_count); - static std::tuple<nlohmann::json, nlohmann::json, nlohmann::json> + static std::tuple<nlohmann::json, nlohmann::json> CastInputsToJson(duckdb::Vector inputs[], idx_t count); + static duckdb::unique_ptr<duckdb::FunctionData> ValidateAndInitializeBindData( + duckdb::ClientContext& context, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments, + const std::string& function_name); + static bool IgnoreNull() { return true; }; + template static void Initialize(const duckdb::AggregateFunction&, duckdb::data_ptr_t state_p) { auto state = reinterpret_cast<AggregateFunctionState*>(state_p); - - // Use placement new to properly construct the AggregateFunctionState object new (state) AggregateFunctionState(); - - if (!state->initialized) { - state->Initialize(); - state->initialized = true; - } + state->Initialize(); } template static void Operation(duckdb::Vector inputs[], duckdb::AggregateInputData& aggr_input_data, idx_t input_count, duckdb::Vector& states, idx_t count) { - // ValidateArguments(inputs, input_count); - - auto [model_details_json, prompt_details, columns] = CastInputsToJson(inputs, count); - model_details = model_details_json; - user_query = PromptManager::CreatePromptDetails(prompt_details).prompt; + auto [prompt_details, columns] = CastInputsToJson(inputs, count); auto state_map_p = reinterpret_cast<AggregateFunctionState**>(duckdb::FlatVector::GetData<duckdb::data_ptr_t>(states));
@@ -94,11 +114,7 @@ class AggregateFunctionBase { template static void SimpleUpdate(duckdb::Vector inputs[], duckdb::AggregateInputData& aggr_input_data, idx_t input_count, duckdb::data_ptr_t state_p, idx_t count) { - // ValidateArguments(inputs, input_count); - - auto [model_details_json, prompt_details, tuples] = CastInputsToJson(inputs, count); - model_details = model_details_json; - user_query = PromptManager::CreatePromptDetails(prompt_details).prompt; + auto [prompt_details, tuples] = CastInputsToJson(inputs, count); if (const auto state = reinterpret_cast<AggregateFunctionState*>(state_p)) { state->Update(tuples);
@@ -131,7 +147,6 @@ class AggregateFunctionBase { auto* state = state_vector[i]; if (state) { state->Destroy(); - state->~AggregateFunctionState();// Explicitly call destructor } } }
@@ -142,13 +157,9 @@ class AggregateFunctionBase { template static void FinalizeSafe(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, duckdb::Vector& result, idx_t count, idx_t offset) { - const auto states_vector = reinterpret_cast<AggregateFunctionState**>(duckdb::FlatVector::GetData<duckdb::data_ptr_t>(states)); - for (idx_t i = 0; i < count; i++) { - auto idx = i + offset; - auto* state = states_vector[idx]; - - result.SetValue(idx, "[]");// Empty JSON array as default + auto result_idx = i + offset; + result.SetValue(result_idx, "[]"); } } };
diff --git a/src/include/flock/functions/aggregate/llm_first_or_last.hpp b/src/include/flock/functions/aggregate/llm_first_or_last.hpp index bd4cbac0..f8970db7 100644 --- a/src/include/flock/functions/aggregate/llm_first_or_last.hpp +++ b/src/include/flock/functions/aggregate/llm_first_or_last.hpp @@ -1,6 +1,7 @@ #pragma once #include "flock/functions/aggregate/aggregate.hpp" +#include "flock/functions/llm_function_bind_data.hpp" namespace flock { @@ -14,6 +15,9 @@ class LlmFirstOrLast : public AggregateFunctionBase { int GetFirstOrLastTupleId(nlohmann::json& tuples); nlohmann::json Evaluate(nlohmann::json& tuples); + static duckdb::unique_ptr<duckdb::FunctionData> Bind(duckdb::ClientContext& context, duckdb::AggregateFunction& function, duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments); + + public: static void Initialize(const duckdb::AggregateFunction& function, duckdb::data_ptr_t state_p) { AggregateFunctionBase::Initialize(function, state_p);
diff --git a/src/include/flock/functions/aggregate/llm_reduce.hpp b/src/include/flock/functions/aggregate/llm_reduce.hpp index e5fe633b..7b663eac 100644 --- a/src/include/flock/functions/aggregate/llm_reduce.hpp +++ b/src/include/flock/functions/aggregate/llm_reduce.hpp @@ -1,6 +1,7 @@ #pragma once #include "flock/functions/aggregate/aggregate.hpp" +#include "flock/functions/llm_function_bind_data.hpp" namespace flock { @@ -12,6 +13,11 @@ class LlmReduce : public AggregateFunctionBase { nlohmann::json ReduceLoop(const nlohmann::json& tuples, const AggregateFunctionType& function_type); public: + static duckdb::unique_ptr<duckdb::FunctionData> Bind( + duckdb::ClientContext& context, + duckdb::AggregateFunction& function, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments); + static void Initialize(const duckdb::AggregateFunction& function, duckdb::data_ptr_t state_p) { AggregateFunctionBase::Initialize(function, state_p); }
diff --git a/src/include/flock/functions/aggregate/llm_rerank.hpp b/src/include/flock/functions/aggregate/llm_rerank.hpp index 4ff7d137..c6d2d41d 100644 --- a/src/include/flock/functions/aggregate/llm_rerank.hpp +++ b/src/include/flock/functions/aggregate/llm_rerank.hpp @@ -1,6 +1,7 @@ #pragma once #include "flock/functions/aggregate/aggregate.hpp" +#include "flock/functions/llm_function_bind_data.hpp" namespace flock { @@ -11,6 +12,12 @@ class LlmRerank : public AggregateFunctionBase { nlohmann::json SlidingWindow(nlohmann::json& tuples); std::vector<int> RerankBatch(const nlohmann::json& tuples); +public: + static duckdb::unique_ptr<duckdb::FunctionData> Bind( + duckdb::ClientContext& context, + duckdb::AggregateFunction& function, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments); + static void Initialize(const duckdb::AggregateFunction& function, duckdb::data_ptr_t state_p) { AggregateFunctionBase::Initialize(function, state_p); }
diff --git a/src/include/flock/functions/input_parser.hpp b/src/include/flock/functions/input_parser.hpp index bc851700..04b848a3 100644 --- a/src/include/flock/functions/input_parser.hpp +++ b/src/include/flock/functions/input_parser.hpp @@ -7,5 +7,6 @@ namespace flock { nlohmann::json CastVectorOfStructsToJson(const duckdb::Vector& struct_vector, int size); +nlohmann::json CastValueToJson(const duckdb::Value& value); }// namespace flock
diff --git a/src/include/flock/functions/llm_function_bind_data.hpp b/src/include/flock/functions/llm_function_bind_data.hpp new file mode 100644 index 00000000..9c96f4b9 --- /dev/null +++ b/src/include/flock/functions/llm_function_bind_data.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include "flock/core/common.hpp" +#include "flock/model_manager/model.hpp" + +namespace flock { + +struct LlmFunctionBindData : public duckdb::FunctionData { + nlohmann::json model_json;// Store model JSON to create fresh Model instances per call + std::string prompt; + + LlmFunctionBindData() = default; + + // Create a fresh Model instance (thread-safe, each call gets its own provider) + Model CreateModel() const { + return Model(model_json); + } + + duckdb::unique_ptr<duckdb::FunctionData> Copy() const override { + auto result = duckdb::make_uniq<LlmFunctionBindData>(); + result->model_json = model_json; + result->prompt = prompt; + return std::move(result); + } + + bool Equals(const duckdb::FunctionData& other) const override { + auto& other_bind = other.Cast<LlmFunctionBindData>(); + return prompt == other_bind.prompt && model_json == other_bind.model_json; + } +}; + +}// namespace flock
diff --git a/src/include/flock/functions/scalar/llm_complete.hpp b/src/include/flock/functions/scalar/llm_complete.hpp index e682b7e3..5a5473cf 100644 --- a/src/include/flock/functions/scalar/llm_complete.hpp +++ b/src/include/flock/functions/scalar/llm_complete.hpp @@ -1,13 +1,18 @@ #pragma once +#include "flock/functions/llm_function_bind_data.hpp" #include "flock/functions/scalar/scalar.hpp" namespace flock { class LlmComplete : public ScalarFunctionBase { public: + static duckdb::unique_ptr<duckdb::FunctionData> Bind( + duckdb::ClientContext& context, + duckdb::ScalarFunction& bound_function, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments); static void ValidateArguments(duckdb::DataChunk& args); - static std::vector<std::string> Operation(duckdb::DataChunk& args); + static std::vector<std::string> Operation(duckdb::DataChunk& args, LlmFunctionBindData* bind_data); static void Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); };
diff --git a/src/include/flock/functions/scalar/llm_embedding.hpp b/src/include/flock/functions/scalar/llm_embedding.hpp index 935ded90..2608aadb 100644 --- a/src/include/flock/functions/scalar/llm_embedding.hpp +++ b/src/include/flock/functions/scalar/llm_embedding.hpp @@ -1,13 +1,18 @@ #pragma once +#include "flock/functions/llm_function_bind_data.hpp" #include "flock/functions/scalar/scalar.hpp" namespace flock { class LlmEmbedding : public ScalarFunctionBase { public: + static duckdb::unique_ptr<duckdb::FunctionData> Bind( + duckdb::ClientContext& context, + duckdb::ScalarFunction& bound_function, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments); static void ValidateArguments(duckdb::DataChunk& args); - static std::vector<std::vector<duckdb::Value>> Operation(duckdb::DataChunk& args); + static std::vector<std::vector<duckdb::Value>> Operation(duckdb::DataChunk& args, LlmFunctionBindData* bind_data); static void Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); };
diff --git a/src/include/flock/functions/scalar/llm_filter.hpp b/src/include/flock/functions/scalar/llm_filter.hpp index 4c391fdd..37490fb5 100644 --- a/src/include/flock/functions/scalar/llm_filter.hpp +++ b/src/include/flock/functions/scalar/llm_filter.hpp @@ -1,13 +1,18 @@ #pragma once +#include "flock/functions/llm_function_bind_data.hpp" #include "flock/functions/scalar/scalar.hpp" namespace flock { class LlmFilter : public ScalarFunctionBase { public: + static duckdb::unique_ptr<duckdb::FunctionData> Bind( + duckdb::ClientContext& context, + duckdb::ScalarFunction& bound_function, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments); static void ValidateArguments(duckdb::DataChunk& args); - static std::vector<std::string> Operation(duckdb::DataChunk& args); + static std::vector<std::string> Operation(duckdb::DataChunk& args, LlmFunctionBindData* bind_data); static void Execute(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); };
diff --git a/src/include/flock/functions/scalar/scalar.hpp b/src/include/flock/functions/scalar/scalar.hpp index bc8585c7..8b514698 100644 --- a/src/include/flock/functions/scalar/scalar.hpp +++ b/src/include/flock/functions/scalar/scalar.hpp @@ -1,16 +1,43 @@ #pragma once #include -#include +#include #include "flock/core/common.hpp" #include "flock/functions/input_parser.hpp" +#include "flock/functions/llm_function_bind_data.hpp" #include "flock/model_manager/model.hpp" #include "flock/prompt_manager/prompt_manager.hpp" +#include namespace flock { class ScalarFunctionBase { +private: + struct PromptStructInfo { + bool has_context_columns; + std::optional<idx_t> prompt_field_index; + std::string prompt_field_name; + }; + + static void ValidateArgumentCount(const duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments, + const std::string& function_name); + + static void ValidateArgumentTypes(const duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments, + const std::string& function_name); + + static PromptStructInfo ExtractPromptStructInfo(const duckdb::LogicalType& prompt_type); + + static void ValidatePromptStructFields(const PromptStructInfo& info, const std::string& function_name, bool require_context_columns); + + static void InitializeModelJson(duckdb::ClientContext& context, + const duckdb::unique_ptr<duckdb::Expression>& model_expr, + LlmFunctionBindData& bind_data); + + static void InitializePrompt(duckdb::ClientContext& context, + const duckdb::unique_ptr<duckdb::Expression>& prompt_expr, + LlmFunctionBindData& bind_data); + public: ScalarFunctionBase() = delete;
@@ -23,6 +50,13 @@ class ScalarFunctionBase { static nlohmann::json BatchAndComplete(const nlohmann::json& tuples, const std::string& user_prompt_name, ScalarFunctionType function_type, Model& model); + + static duckdb::unique_ptr<duckdb::FunctionData> ValidateAndInitializeBindData( + duckdb::ClientContext& context, + duckdb::vector<duckdb::unique_ptr<duckdb::Expression>>& arguments, + const std::string& function_name, + bool require_context_columns = true, + bool initialize_prompt = true); }; }// namespace flock
diff --git a/src/include/flock/metrics/base_manager.hpp b/src/include/flock/metrics/base_manager.hpp new file mode 100644 index 00000000..2dfb1fb5 --- /dev/null +++ b/src/include/flock/metrics/base_manager.hpp @@ -0,0 +1,296 @@ +#pragma once + +#include "flock/metrics/data_structures.hpp" +#include +#include +#include +#include +#include + +namespace flock { + +// Core metrics tracking functionality shared between scalar and aggregate functions +template <typename StateId> +class BaseMetricsManager { +public: + ThreadMetrics& GetThreadMetrics(const StateId& state_id) { + const auto tid = std::this_thread::get_id(); + auto& thread_map = thread_metrics_[tid]; + + auto it = thread_map.find(state_id); + if (it != thread_map.end()) { + return it->second; + } + + return thread_map[state_id]; + } + + void RegisterThread(const StateId& state_id) { + GetThreadMetrics(state_id); + } + + // Initialize metrics tracking and assign registration order + void StartInvocation(const StateId& state_id, FunctionType type) { + RegisterThread(state_id); + + const auto tid = std::this_thread::get_id(); + ThreadFunctionKey thread_function_key{tid, type}; + + if (thread_function_counters_.find(thread_function_key) == thread_function_counters_.end()) { + thread_function_counters_[thread_function_key] = 0; + } + + StateFunctionKey state_function_key{state_id, type}; + if (state_function_registration_order_.find(state_function_key) == state_function_registration_order_.end()) { + thread_function_counters_[thread_function_key]++; + state_function_registration_order_[state_function_key] = thread_function_counters_[thread_function_key]; + } + + GetThreadMetrics(state_id).GetMetrics(type); + } + + // Store model name and provider (first call wins) + void SetModelInfo(const StateId& state_id, FunctionType type, const std::string& model_name, const std::string& provider) { + auto& thread_metrics = GetThreadMetrics(state_id); + auto& metrics = thread_metrics.GetMetrics(type); + if (metrics.model_name.empty()) { + metrics.model_name = model_name; + } + if (metrics.provider.empty()) { + metrics.provider = provider; + } + } + + // Add input and output tokens (accumulative) + void UpdateTokens(const StateId& state_id, FunctionType type, int64_t input, int64_t output) { + auto& thread_metrics = GetThreadMetrics(state_id); + auto& metrics = thread_metrics.GetMetrics(type); + metrics.input_tokens += input; + metrics.output_tokens += output; + } + + // Increment API call counter + void IncrementApiCalls(const StateId& state_id, FunctionType type) { + GetThreadMetrics(state_id).GetMetrics(type).api_calls++; + } + + // Add API duration in microseconds (accumulative) + void AddApiDuration(const StateId& state_id, FunctionType type, int64_t duration_us) { + GetThreadMetrics(state_id).GetMetrics(type).api_duration_us += duration_us; + } + + // Add execution time in microseconds (accumulative) + void AddExecutionTime(const StateId& state_id, FunctionType type, int64_t duration_us) { + GetThreadMetrics(state_id).GetMetrics(type).execution_time_us += duration_us; + } + + // Get flattened metrics structure (merged across threads) + nlohmann::json GetMetrics() const { + nlohmann::json result = nlohmann::json::object(); + + struct Key { + FunctionType function_type; + size_t registration_order; + + bool operator==(const Key& other) const { + return function_type == other.function_type && registration_order == other.registration_order; + } + }; + + struct KeyHash { + size_t operator()(const Key& k) const { + return std::hash<size_t>{}(static_cast<size_t>(k.function_type)) ^ + (std::hash<size_t>{}(k.registration_order) << 1); + } + }; + + std::unordered_map<Key, FunctionMetricsData, KeyHash> merged_metrics; + + // Collect and merge metrics by (function_type, registration_order) + for (const auto& [tid, state_map]: thread_metrics_) { + for (const auto& [state_id, thread_metrics]: state_map) { + if (thread_metrics.IsEmpty()) { + continue; + } + + for (size_t i = 0; i < ThreadMetrics::NUM_FUNCTION_TYPES - 1; ++i) { + const auto function_type = static_cast<FunctionType>(i); + const auto& metrics = thread_metrics.GetMetrics(function_type); + + if (!metrics.IsEmpty()) { + StateFunctionKey state_function_key{state_id, function_type}; + auto order_it = state_function_registration_order_.find(state_function_key); + size_t registration_order = (order_it != state_function_registration_order_.end()) ? order_it->second : SIZE_MAX; + + Key key{function_type, registration_order}; + + auto& merged = merged_metrics[key]; + merged.input_tokens += metrics.input_tokens; + merged.output_tokens += metrics.output_tokens; + merged.api_calls += metrics.api_calls; + merged.api_duration_us += metrics.api_duration_us; + merged.execution_time_us += metrics.execution_time_us; + + if (merged.model_name.empty() && !metrics.model_name.empty()) { + merged.model_name = metrics.model_name; + } + if (merged.provider.empty() && !metrics.provider.empty()) { + merged.provider = metrics.provider; + } + } + } + } + } + + struct MetricEntry { + FunctionType function_type; + size_t registration_order; + FunctionMetricsData metrics; + }; + + std::vector<MetricEntry> entries; + entries.reserve(merged_metrics.size()); + + for (const auto& [key, metrics]: merged_metrics) { + entries.push_back({key.function_type, key.registration_order, metrics}); + } + + std::sort(entries.begin(), entries.end(), [](const MetricEntry& a, const MetricEntry& b) { + if (a.function_type != b.function_type) { + return a.function_type < b.function_type; + } + return a.registration_order < b.registration_order; + }); + + std::unordered_map<FunctionType, size_t> function_counters; + + for (const auto& entry: entries) { + if (function_counters.find(entry.function_type) == function_counters.end()) { + function_counters[entry.function_type] = 0; + } + + function_counters[entry.function_type]++; + const std::string key = std::string(FunctionTypeToString(entry.function_type)) + "_" + std::to_string(function_counters[entry.function_type]); + + result[key] = entry.metrics.ToJson(); + } + + return result; + } + + // Get nested metrics structure preserving thread/state info (for debugging) + nlohmann::json GetDebugMetrics() const { + nlohmann::json result; + nlohmann::json threads_json = nlohmann::json::object(); + + size_t threads_with_output = 0; + + for (const auto& [tid, state_map]: thread_metrics_) { + std::ostringstream oss; + oss << tid; + const std::string thread_id_str = oss.str(); + + nlohmann::json thread_data; + bool thread_has_output = false; + + for (const auto& [state_id, thread_metrics]: state_map) { + if (thread_metrics.IsEmpty()) { + continue; + } + + std::ostringstream state_oss; + state_oss << state_id; + const std::string state_id_str = state_oss.str(); + + nlohmann::json state_data; + + for (size_t i = 0; i < ThreadMetrics::NUM_FUNCTION_TYPES - 1; ++i) { + const auto function_type = static_cast<FunctionType>(i); + const auto& metrics = thread_metrics.GetMetrics(function_type); + + if (!metrics.IsEmpty()) { + StateFunctionKey state_function_key{state_id, function_type}; + auto order_it = state_function_registration_order_.find(state_function_key); + size_t registration_order = (order_it != state_function_registration_order_.end()) ? order_it->second : 0; + + nlohmann::json function_data = metrics.ToJson(); + function_data["registration_order"] = registration_order; + state_data[FunctionTypeToString(function_type)] = std::move(function_data); + } + } + + if (!state_data.empty()) { + thread_has_output = true; + thread_data[state_id_str] = std::move(state_data); + } + } + + if (thread_has_output) { + threads_with_output++; + threads_json[thread_id_str] = std::move(thread_data); + } + } + + result["threads"] = threads_json.empty() ? nlohmann::json::object() : std::move(threads_json); + result["thread_count"] = threads_with_output; + return result; + } + + // Clear all metrics and registration tracking + void Reset() { + thread_metrics_.clear(); + state_function_registration_order_.clear(); + thread_function_counters_.clear(); + } + +protected: + // Main storage: thread_id -> state_id -> ThreadMetrics + std::unordered_map<std::thread::id, std::unordered_map<StateId, ThreadMetrics>, ThreadIdHash> thread_metrics_; + + // Registration order tracking structures + struct ThreadFunctionKey { + std::thread::id thread_id; + FunctionType function_type; + + bool operator==(const ThreadFunctionKey& other) const { + return thread_id == other.thread_id && function_type == other.function_type; + } + }; + + struct ThreadFunctionKeyHash { + size_t operator()(const ThreadFunctionKey& k) const { + return ThreadIdHash{}(k.thread_id) ^ + (std::hash<size_t>{}(static_cast<size_t>(k.function_type)) << 1); + } + }; + + struct StateFunctionKey { + StateId state_id; + FunctionType function_type; + + bool operator==(const StateFunctionKey& other) const { + return state_id == other.state_id && function_type == other.function_type; + } + }; + + struct StateFunctionKeyHash { + size_t operator()(const StateFunctionKey& k) const { + size_t state_hash = 0; + if constexpr (std::is_pointer_v<StateId>) { + state_hash = std::hash<uintptr_t>{}(reinterpret_cast<uintptr_t>(k.state_id)); + } else { + state_hash = std::hash<StateId>{}(k.state_id); + } + return state_hash ^ (std::hash<size_t>{}(static_cast<size_t>(k.function_type)) << 1); + } + }; + + std::unordered_map<StateFunctionKey, size_t, StateFunctionKeyHash> state_function_registration_order_; + std::unordered_map<ThreadFunctionKey, size_t, ThreadFunctionKeyHash> thread_function_counters_; +}; + +}// namespace flock
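BaseMetricsManager keys everything by (thread, state, function type) and only merges when GetMetrics is called, which is why the hot path needs no locking as long as each thread touches only its own entries. A small usage sketch instantiating the template directly with int state ids (the extension instantiates it with const void* via MetricsManager; the model name and counts below are made up):

    #include "flock/metrics/base_manager.hpp"
    #include <iostream>

    int main() {
        flock::BaseMetricsManager<int> metrics;
        metrics.StartInvocation(/*state_id=*/1, flock::FunctionType::LLM_COMPLETE);
        metrics.SetModelInfo(1, flock::FunctionType::LLM_COMPLETE, "gpt-4o-mini", "openai");
        metrics.UpdateTokens(1, flock::FunctionType::LLM_COMPLETE, /*input=*/120, /*output=*/35);
        metrics.IncrementApiCalls(1, flock::FunctionType::LLM_COMPLETE);
        // Output keys are "<function>_<registration order>", e.g. "llm_complete_1".
        std::cout << metrics.GetMetrics().dump(2) << '\n';
    }

The registration-order map is what keeps the flattened keys stable: the first state a thread sees for a function type becomes llm_complete_1, the next llm_complete_2, and so on, independent of hash-map iteration order.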
nlohmann::json::object() : std::move(threads_json); + result["thread_count"] = threads_with_output; + return result; + } + + // Clear all metrics and registration tracking + void Reset() { + thread_metrics_.clear(); + state_function_registration_order_.clear(); + thread_function_counters_.clear(); + } + +protected: + // Main storage: thread_id -> state_id -> ThreadMetrics + std::unordered_map, ThreadIdHash> thread_metrics_; + + // Registration order tracking structures + struct ThreadFunctionKey { + std::thread::id thread_id; + FunctionType function_type; + + bool operator==(const ThreadFunctionKey& other) const { + return thread_id == other.thread_id && function_type == other.function_type; + } + }; + + struct ThreadFunctionKeyHash { + size_t operator()(const ThreadFunctionKey& k) const { + return ThreadIdHash{}(k.thread_id) ^ + (std::hash{}(static_cast(k.function_type)) << 1); + } + }; + + struct StateFunctionKey { + StateId state_id; + FunctionType function_type; + + bool operator==(const StateFunctionKey& other) const { + return state_id == other.state_id && function_type == other.function_type; + } + }; + + struct StateFunctionKeyHash { + size_t operator()(const StateFunctionKey& k) const { + size_t state_hash = 0; + if constexpr (std::is_pointer_v) { + state_hash = std::hash{}(reinterpret_cast(k.state_id)); + } else { + state_hash = std::hash{}(k.state_id); + } + return state_hash ^ (std::hash{}(static_cast(k.function_type)) << 1); + } + }; + + std::unordered_map state_function_registration_order_; + std::unordered_map thread_function_counters_; +}; + +}// namespace flock diff --git a/src/include/flock/metrics/data_structures.hpp b/src/include/flock/metrics/data_structures.hpp new file mode 100644 index 00000000..27741e81 --- /dev/null +++ b/src/include/flock/metrics/data_structures.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include "flock/core/common.hpp" +#include "flock/metrics/types.hpp" +#include +#include +#include +#include +#include + +namespace flock { + +// Stores aggregated metrics for a single function call +struct FunctionMetricsData { + std::string model_name; + std::string provider; + int64_t input_tokens = 0; + int64_t output_tokens = 0; + int64_t api_calls = 0; + int64_t api_duration_us = 0; + int64_t execution_time_us = 0; + + int64_t total_tokens() const noexcept { + return input_tokens + output_tokens; + } + + double api_duration_ms() const noexcept { + return api_duration_us / 1000.0; + } + + double execution_time_ms() const noexcept { + return execution_time_us / 1000.0; + } + + bool IsEmpty() const noexcept { + return input_tokens == 0 && output_tokens == 0 && api_calls == 0 && + api_duration_us == 0 && execution_time_us == 0; + } + + nlohmann::json ToJson() const { + nlohmann::json result = { + {"input_tokens", input_tokens}, + {"output_tokens", output_tokens}, + {"total_tokens", total_tokens()}, + {"api_calls", api_calls}, + {"api_duration_ms", api_duration_ms()}, + {"execution_time_ms", execution_time_ms()}}; + + if (!model_name.empty()) { + result["model_name"] = model_name; + } + if (!provider.empty()) { + result["provider"] = provider; + } + + return result; + } +}; + +// Stores metrics for all function types in a single state +class ThreadMetrics { +public: + static constexpr size_t NUM_FUNCTION_TYPES = 8; + + void Reset() noexcept { + for (auto& func_metrics: by_function_) { + func_metrics = FunctionMetricsData{}; + } + } + + FunctionMetricsData& GetMetrics(FunctionType type) { + return by_function_[FunctionTypeToIndex(type)]; + } + + const 
FunctionMetricsData& GetMetrics(FunctionType type) const noexcept { + return by_function_[FunctionTypeToIndex(type)]; + } + + bool IsEmpty() const noexcept { + for (const auto& func_metrics: by_function_) { + if (!func_metrics.IsEmpty()) { + return false; + } + } + return true; + } + +private: + FunctionMetricsData by_function_[NUM_FUNCTION_TYPES]; +}; + +struct ThreadIdHash { + size_t operator()(const std::thread::id& id) const noexcept { + return std::hash{}(id); + } +}; + +}// namespace flock diff --git a/src/include/flock/metrics/manager.hpp b/src/include/flock/metrics/manager.hpp new file mode 100644 index 00000000..cb71a350 --- /dev/null +++ b/src/include/flock/metrics/manager.hpp @@ -0,0 +1,122 @@ +#pragma once + +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/main/database.hpp" +#include "flock/metrics/base_manager.hpp" +#include "flock/metrics/types.hpp" +#include +#include +#include + +namespace flock { + +// Database-level metrics storage and unified API for scalar and aggregate functions +class MetricsManager : public BaseMetricsManager { +public: + // Get metrics manager for a database instance (creates if needed) + static MetricsManager& GetForDatabase(duckdb::DatabaseInstance* db) { + if (db == nullptr) { + throw std::runtime_error("Database instance is null"); + } + + static std::unordered_map> db_managers; + + auto it = db_managers.find(db); + if (it == db_managers.end()) { + auto manager = std::make_unique(); + auto* manager_ptr = manager.get(); + db_managers[db] = std::move(manager); + return *manager_ptr; + } + return *it->second; + } + + // Generate a unique invocation ID for scalar functions + static const void* GenerateUniqueId() { + static std::atomic counter{0}; + return reinterpret_cast(++counter); + } + + // Initialize metrics tracking (stores context for subsequent calls) + static void StartInvocation(duckdb::DatabaseInstance* db, const void* state_id, FunctionType type) { + if (db != nullptr && state_id != nullptr) { + current_db_ = db; + current_state_id_ = state_id; + current_function_type_ = type; + + auto& manager = GetForDatabase(db); + manager.RegisterThread(state_id); + manager.BaseMetricsManager::StartInvocation(state_id, type); + } + } + + // Record model name and provider + static void SetModelInfo(const std::string& model_name, const std::string& provider) { + if (current_db_ != nullptr && current_state_id_ != nullptr) { + auto& manager = GetForDatabase(current_db_); + manager.BaseMetricsManager::SetModelInfo(current_state_id_, current_function_type_, model_name, provider); + } + } + + // Record token usage (accumulative) + static void UpdateTokens(int64_t input, int64_t output) { + if (current_db_ != nullptr && current_state_id_ != nullptr) { + auto& manager = GetForDatabase(current_db_); + manager.BaseMetricsManager::UpdateTokens(current_state_id_, current_function_type_, input, output); + } + } + + // Increment API call counter + static void IncrementApiCalls() { + if (current_db_ != nullptr && current_state_id_ != nullptr) { + auto& manager = GetForDatabase(current_db_); + manager.BaseMetricsManager::IncrementApiCalls(current_state_id_, current_function_type_); + } + } + + // Record API call duration in milliseconds (accumulative) + static void AddApiDuration(double duration_ms) { + if (current_db_ != nullptr && current_state_id_ != nullptr) { + const int64_t duration_us = static_cast(duration_ms * 1000.0); + auto& manager = GetForDatabase(current_db_); + manager.BaseMetricsManager::AddApiDuration(current_state_id_, 
current_function_type_, duration_us); + } + } + + // Record execution time in milliseconds (accumulative) + static void AddExecutionTime(double duration_ms) { + if (current_db_ != nullptr && current_state_id_ != nullptr) { + const int64_t duration_us = static_cast<int64_t>(duration_ms * 1000.0); + auto& manager = GetForDatabase(current_db_); + manager.BaseMetricsManager::AddExecutionTime(current_state_id_, current_function_type_, duration_us); + } + } + + // Clear stored context (optional, auto-cleared on next StartInvocation) + static void ClearContext() { + current_db_ = nullptr; + current_state_id_ = nullptr; + current_function_type_ = FunctionType::UNKNOWN; + } + + // Merge metrics from multiple states into a single state + // This is used by aggregate functions to consolidate metrics from all processed states + static void MergeAggregateMetrics(duckdb::DatabaseInstance* db, + const std::vector<const void*>& processed_state_ids, + FunctionType function_type, + const std::string& model_name = "", + const std::string& provider = ""); + + // SQL function implementations + static void ExecuteGetMetrics(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); + static void ExecuteGetDebugMetrics(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); + static void ExecuteResetMetrics(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result); + +private: + // Thread-local storage for current metrics context + static thread_local duckdb::DatabaseInstance* current_db_; + static thread_local const void* current_state_id_; + static thread_local FunctionType current_function_type_; +}; + +}// namespace flock
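The static wrappers above stash a thread-local (db, state_id, function_type) triple so lower layers such as the request handlers can record usage without threading the context through every call. A minimal sketch of the intended sequence inside one scalar-function invocation (the surrounding function and the model/provider values are hypothetical):

void SketchScalarInvocation(duckdb::DatabaseInstance* db) {
    const void* id = flock::MetricsManager::GenerateUniqueId();
    flock::MetricsManager::StartInvocation(db, id, flock::FunctionType::LLM_COMPLETE);
    flock::MetricsManager::SetModelInfo("gpt-4o-mini", "openai");// assumed model/provider
    // ... perform the API call, then record what it cost ...
    flock::MetricsManager::UpdateTokens(/*input=*/120, /*output=*/45);
    flock::MetricsManager::IncrementApiCalls();
    flock::MetricsManager::AddApiDuration(/*duration_ms=*/250.0);
    flock::MetricsManager::ClearContext();// optional; overwritten by the next StartInvocation
}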
 diff --git a/src/include/flock/metrics/types.hpp b/src/include/flock/metrics/types.hpp new file mode 100644 index 00000000..e7c24c7a --- /dev/null +++ b/src/include/flock/metrics/types.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include <cstddef> +#include <cstdint> + +namespace flock { + +enum class FunctionType : uint8_t { + LLM_COMPLETE = 0, + LLM_FILTER = 1, + LLM_EMBEDDING = 2, + LLM_REDUCE = 3, + LLM_RERANK = 4, + LLM_FIRST = 5, + LLM_LAST = 6, + UNKNOWN = 7 +}; + +inline constexpr const char* FunctionTypeToString(FunctionType type) noexcept { + switch (type) { + case FunctionType::LLM_COMPLETE: + return "llm_complete"; + case FunctionType::LLM_FILTER: + return "llm_filter"; + case FunctionType::LLM_EMBEDDING: + return "llm_embedding"; + case FunctionType::LLM_REDUCE: + return "llm_reduce"; + case FunctionType::LLM_RERANK: + return "llm_rerank"; + case FunctionType::LLM_FIRST: + return "llm_first"; + case FunctionType::LLM_LAST: + return "llm_last"; + default: + return "unknown"; + } +} + +inline constexpr size_t FunctionTypeToIndex(FunctionType type) noexcept { + return static_cast<size_t>(type); +} + +}// namespace flock diff --git a/src/include/flock/model_manager/model.hpp b/src/include/flock/model_manager/model.hpp index bf1b2dc6..7dcc63f0 100644 --- a/src/include/flock/model_manager/model.hpp +++ b/src/include/flock/model_manager/model.hpp @@ -1,19 +1,18 @@ #pragma once #include "fmt/format.h" -#include -#include -#include +#include #include -#include #include "duckdb/main/connection.hpp" +#include "flock/core/common.hpp" #include "flock/core/config.hpp" #include "flock/model_manager/providers/adapters/azure.hpp" #include "flock/model_manager/providers/adapters/ollama.hpp" #include "flock/model_manager/providers/adapters/openai.hpp" #include "flock/model_manager/providers/handlers/ollama.hpp" #include "flock/model_manager/repository.hpp" +#include namespace flock { @@ -28,15 +27,32 @@ class Model { explicit Model() = default; void AddCompletionRequest(const std::string& prompt, const int num_output_tuples, OutputType output_type = OutputType::STRING, const nlohmann::json& media_data = nlohmann::json::object()); void AddEmbeddingRequest(const std::vector<std::string>& inputs); + void AddTranscriptionRequest(const nlohmann::json& audio_files); std::vector<nlohmann::json> CollectCompletions(const std::string& contentType = "application/json"); std::vector<nlohmann::json> CollectEmbeddings(const std::string& contentType = "application/json"); + std::vector<nlohmann::json> CollectTranscriptions(const std::string& contentType = "multipart/form-data"); ModelDetails GetModelDetails(); + nlohmann::json GetModelDetailsAsJson() const; + + // Static helper method for binders to resolve model details to JSON + static nlohmann::json ResolveModelDetailsToJson(const nlohmann::json& user_model_json); + + // Factory function type for creating mock providers + using MockProviderFactory = std::function<std::shared_ptr<IProvider>()>; + // Set a factory to create fresh mock providers (each Model gets its own instance) + static void SetMockProviderFactory(MockProviderFactory factory) { + mock_provider_factory_ = std::move(factory); + } + + // Legacy: Set a shared mock provider (for backward compatibility - less safe for parallel tests) static void SetMockProvider(const std::shared_ptr<IProvider>& mock_provider) { mock_provider_ = mock_provider; } + static void ResetMockProvider() { mock_provider_ = nullptr; + mock_provider_factory_ = nullptr; } std::shared_ptr @@ -45,6 +61,7 @@ class Model { private: ModelDetails model_details_; inline static std::shared_ptr<IProvider> mock_provider_ = nullptr; + inline static MockProviderFactory mock_provider_factory_ = nullptr; void ConstructProvider(); void LoadModelDetails(const nlohmann::json& model_json); static std::tuple<std::string, std::string, nlohmann::json> GetQueriedModel(const std::string& model_name); diff --git a/src/include/flock/model_manager/providers/adapters/azure.hpp b/src/include/flock/model_manager/providers/adapters/azure.hpp index a4493d82..0a93a231 100644 --- a/src/include/flock/model_manager/providers/adapters/azure.hpp +++ b/src/include/flock/model_manager/providers/adapters/azure.hpp @@ -15,6 +15,7 @@ class AzureProvider : public IProvider { void AddCompletionRequest(const std::string& prompt, const int num_output_tuples, OutputType output_type, const nlohmann::json& media_data) override; void AddEmbeddingRequest(const std::vector<std::string>& inputs) override; + void AddTranscriptionRequest(const nlohmann::json& audio_files) override; }; }// namespace flock diff --git a/src/include/flock/model_manager/providers/adapters/ollama.hpp b/src/include/flock/model_manager/providers/adapters/ollama.hpp index 7f0c62c8..0ed7d44d 100644 --- a/src/include/flock/model_manager/providers/adapters/ollama.hpp +++ b/src/include/flock/model_manager/providers/adapters/ollama.hpp @@ -13,6 +13,7 @@ class OllamaProvider : public IProvider { void AddCompletionRequest(const std::string& prompt, const int num_output_tuples, OutputType output_type, const nlohmann::json& media_data) override; void AddEmbeddingRequest(const std::vector<std::string>& inputs) override; + void AddTranscriptionRequest(const nlohmann::json& audio_files) override; }; }// namespace flock diff --git a/src/include/flock/model_manager/providers/adapters/openai.hpp b/src/include/flock/model_manager/providers/adapters/openai.hpp index a9d104c8..9b416c44 100644 --- a/src/include/flock/model_manager/providers/adapters/openai.hpp +++ b/src/include/flock/model_manager/providers/adapters/openai.hpp @@ -18,6 +18,7 @@ class 
OpenAIProvider : public IProvider { void AddCompletionRequest(const std::string& prompt, const int num_output_tuples, OutputType output_type, const nlohmann::json& media_data) override; void AddEmbeddingRequest(const std::vector<std::string>& inputs) override; + void AddTranscriptionRequest(const nlohmann::json& audio_files) override; }; }// namespace flock diff --git a/src/include/flock/model_manager/providers/handlers/azure.hpp b/src/include/flock/model_manager/providers/handlers/azure.hpp index 871cc758..16027c4a 100644 --- a/src/include/flock/model_manager/providers/handlers/azure.hpp +++ b/src/include/flock/model_manager/providers/handlers/azure.hpp @@ -18,7 +18,11 @@ class AzureModelManager : public BaseModelProviderHandler { AzureModelManager& operator=(AzureModelManager&&) = delete; protected: - void checkProviderSpecificResponse(const nlohmann::json& response, bool is_completion) override { + void checkProviderSpecificResponse(const nlohmann::json& response, RequestType request_type) override { + if (request_type == RequestType::Transcription) { + return;// No specific checks needed for transcriptions + } + bool is_completion = (request_type == RequestType::Completion); if (is_completion) { if (response.contains("choices") && response["choices"].is_array() && !response["choices"].empty()) { const auto& choice = response["choices"][0]; @@ -30,7 +34,6 @@ class AzureModelManager : public BaseModelProviderHandler { } } } else { - // Embedding-specific checks (if any) can be added here if (response.contains("data") && response["data"].is_array() && response["data"].empty()) { throw std::runtime_error("Azure API returned empty embedding data."); } @@ -44,6 +47,10 @@ class AzureModelManager : public BaseModelProviderHandler { return "https://" + _resource_name + ".openai.azure.com/openai/deployments/" + _deployment_model_name + "/embeddings?api-version=" + _api_version; } + std::string getTranscriptionUrl() const override { + return "https://" + _resource_name + ".openai.azure.com/openai/deployments/" + + _deployment_model_name + "/audio/transcriptions?api-version=" + _api_version; + } void prepareSessionForRequest(const std::string& url) override { _session.setUrl(url); } @@ -66,6 +73,31 @@ class AzureModelManager : public BaseModelProviderHandler { return {}; } + std::pair<int64_t, int64_t> ExtractTokenUsage(const nlohmann::json& response) const override { + int64_t input_tokens = 0; + int64_t output_tokens = 0; + if (response.contains("usage") && response["usage"].is_object()) { + const auto& usage = response["usage"]; + if (usage.contains("prompt_tokens") && usage["prompt_tokens"].is_number()) { + input_tokens = usage["prompt_tokens"].get<int64_t>(); + } + if (usage.contains("completion_tokens") && usage["completion_tokens"].is_number()) { + output_tokens = usage["completion_tokens"].get<int64_t>(); + } + } + return {input_tokens, output_tokens}; + } + + + nlohmann::json ExtractTranscriptionOutput(const nlohmann::json& response) const override { + // Transcription API returns JSON with "text" field when response_format=json + if (response.contains("text") && !response["text"].is_null()) { + return response["text"].get<std::string>(); + } + return ""; + } + + std::string _token; std::string _resource_name; std::string _deployment_model_name; diff --git a/src/include/flock/model_manager/providers/handlers/base_handler.hpp b/src/include/flock/model_manager/providers/handlers/base_handler.hpp index 998c9cf7..d4ba5f53 100644 --- a/src/include/flock/model_manager/providers/handlers/base_handler.hpp +++ 
b/src/include/flock/model_manager/providers/handlers/base_handler.hpp @@ -1,10 +1,13 @@ #pragma once +#include "flock/core/common.hpp" +#include "flock/metrics/manager.hpp" #include "flock/model_manager/providers/handlers/handler.hpp" #include "session.hpp" -#include +#include <chrono> +#include <cstdio> +#include <utility> #include -#include #include #include @@ -16,56 +19,154 @@ class BaseModelProviderHandler : public IModelProviderHandler { : _throw_exception(throw_exception) {} virtual ~BaseModelProviderHandler() = default; - // AddRequest: just add the json to the batch (type is ignored, kept for interface compatibility) - void AddRequest(const nlohmann::json& json, RequestType type = RequestType::Completion) { + void AddRequest(const nlohmann::json& json, RequestType type = RequestType::Completion) override { _request_batch.push_back(json); + _request_types.push_back(type); } - // CollectCompletions: process all as completions, then clear std::vector<nlohmann::json> CollectCompletions(const std::string& contentType = "application/json") { std::vector<nlohmann::json> completions; - if (!_request_batch.empty()) completions = ExecuteBatch(_request_batch, true, contentType, true); + if (!_request_batch.empty()) completions = ExecuteBatch(_request_batch, true, contentType, RequestType::Completion); _request_batch.clear(); return completions; } - // CollectEmbeddings: process all as embeddings, then clear std::vector<nlohmann::json> CollectEmbeddings(const std::string& contentType = "application/json") { std::vector<nlohmann::json> embeddings; - if (!_request_batch.empty()) embeddings = ExecuteBatch(_request_batch, true, contentType, false); + if (!_request_batch.empty()) embeddings = ExecuteBatch(_request_batch, true, contentType, RequestType::Embedding); _request_batch.clear(); return embeddings; } - // Unified batch implementation with customizable headers - std::vector<nlohmann::json> ExecuteBatch(const std::vector<nlohmann::json>& jsons, bool async = true, const std::string& contentType = "application/json", bool is_completion = true) { + + std::vector<nlohmann::json> CollectTranscriptions(const std::string& contentType = "multipart/form-data") override { + std::vector<nlohmann::json> transcriptions; + if (!_request_batch.empty()) { + std::vector<nlohmann::json> transcription_batch; + for (size_t i = 0; i < _request_batch.size(); ++i) { + if (_request_types[i] == RequestType::Transcription) { + transcription_batch.push_back(_request_batch[i]); + } + } + + if (!transcription_batch.empty()) { + transcriptions = ExecuteBatch(transcription_batch, true, contentType, RequestType::Transcription); + // Remove transcription requests from batch, walking backwards so the + // indices still to be visited stay valid during erasure + for (size_t i = _request_batch.size(); i > 0; --i) { + if (_request_types[i - 1] == RequestType::Transcription) { + _request_batch.erase(_request_batch.begin() + i - 1); + _request_types.erase(_request_types.begin() + i - 1); + } + } + } + } + return transcriptions; + }
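ExecuteBatch below builds transcription uploads with libcurl's MIME API. A condensed, self-contained sketch of that multipart construction (file path and model name are invented, and error handling is omitted):

CURL* easy = curl_easy_init();
curl_mime* form = curl_mime_init(easy);
curl_mimepart* part = curl_mime_addpart(form);
curl_mime_name(part, "file");
curl_mime_filedata(part, "/tmp/audio.mp3");// hypothetical local path
part = curl_mime_addpart(form);
curl_mime_name(part, "model");
curl_mime_data(part, "gpt-4o-mini-transcribe", CURL_ZERO_TERMINATED);
curl_easy_setopt(easy, CURLOPT_MIMEPOST, form);
// ... perform the request, read the response, then:
curl_mime_free(form);
curl_easy_cleanup(easy);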
+ +protected: + std::vector<nlohmann::json> ExecuteBatch(const std::vector<nlohmann::json>& jsons, bool async = true, const std::string& contentType = "application/json", RequestType request_type = RequestType::Completion) { struct CurlRequestData { std::string response; CURL* easy = nullptr; std::string payload; + curl_mime* mime_form = nullptr; + struct curl_slist* headers = nullptr; + std::string temp_file_path; + bool is_temp_file = false; }; std::vector<CurlRequestData> requests(jsons.size()); CURLM* multi_handle = curl_multi_init(); - struct curl_slist* headers = nullptr; - headers = curl_slist_append(headers, "Content-Type: application/json"); - for (const auto& h: getExtraHeaders()) { - headers = curl_slist_append(headers, h.c_str()); + + // Determine URL based on request type + std::string url; + bool is_transcription = (request_type == RequestType::Transcription); + bool is_completion = (request_type == RequestType::Completion); + if (is_transcription) { + url = getTranscriptionUrl(); + } else if (is_completion) { + url = getCompletionUrl(); + } else { + url = getEmbedUrl(); } - auto url = is_completion ? getCompletionUrl() : getEmbedUrl(); + + // Prepare all requests for (size_t i = 0; i < jsons.size(); ++i) { - requests[i].payload = jsons[i].dump(); requests[i].easy = curl_easy_init(); curl_easy_setopt(requests[i].easy, CURLOPT_URL, url.c_str()); - curl_easy_setopt(requests[i].easy, CURLOPT_HTTPHEADER, headers); - curl_easy_setopt(requests[i].easy, CURLOPT_WRITEFUNCTION, +[](char* ptr, size_t size, size_t nmemb, void* userdata) -> size_t { + + if (is_transcription) { + // Handle transcription requests (multipart/form-data) + const auto& req = jsons[i]; + if (!req.contains("file_path") || req["file_path"].is_null()) { + trigger_error("Missing or null file_path in transcription request"); + } + if (!req.contains("model") || req["model"].is_null()) { + trigger_error("Missing or null model in transcription request"); + } + auto file_path = req["file_path"].get<std::string>(); + auto model = req["model"].get<std::string>(); + auto prompt = req.contains("prompt") && !req["prompt"].is_null() ? req["prompt"].get<std::string>() : ""; + requests[i].is_temp_file = req.contains("is_temp_file") ? req["is_temp_file"].get<bool>() : false; + if (requests[i].is_temp_file) { + requests[i].temp_file_path = file_path; + } + + // Set up multipart form data + requests[i].mime_form = curl_mime_init(requests[i].easy); + curl_mimepart* field = curl_mime_addpart(requests[i].mime_form); + curl_mime_name(field, "file"); + curl_mime_filedata(field, file_path.c_str()); + + field = curl_mime_addpart(requests[i].mime_form); + curl_mime_name(field, "model"); + curl_mime_data(field, model.c_str(), CURL_ZERO_TERMINATED); + + field = curl_mime_addpart(requests[i].mime_form); + curl_mime_name(field, "response_format"); + curl_mime_data(field, "json", CURL_ZERO_TERMINATED); + + if (!prompt.empty()) { + field = curl_mime_addpart(requests[i].mime_form); + curl_mime_name(field, "prompt"); + curl_mime_data(field, prompt.c_str(), CURL_ZERO_TERMINATED); + } + + curl_easy_setopt(requests[i].easy, CURLOPT_MIMEPOST, requests[i].mime_form); + + // Set headers (kept per request so they can be freed after the transfer) + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Expect:"); + for (const auto& h: getExtraHeaders()) { + headers = curl_slist_append(headers, h.c_str()); + } + curl_easy_setopt(requests[i].easy, CURLOPT_HTTPHEADER, headers); + requests[i].headers = headers; + } else { + // Handle JSON requests (completions/embeddings) + requests[i].payload = jsons[i].dump(); + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + for (const auto& h: getExtraHeaders()) { + headers = curl_slist_append(headers, h.c_str()); + } + curl_easy_setopt(requests[i].easy, CURLOPT_HTTPHEADER, headers); + requests[i].headers = headers; + curl_easy_setopt(requests[i].easy, CURLOPT_POST, 1L); + curl_easy_setopt(requests[i].easy, CURLOPT_POSTFIELDS, requests[i].payload.c_str()); + } + + // Set response callback + curl_easy_setopt( + requests[i].easy, CURLOPT_WRITEFUNCTION, +[](char* ptr, size_t size, size_t nmemb, void* userdata) -> size_t { std::string* resp = static_cast<std::string*>(userdata); resp->append(ptr, size * nmemb); return size * nmemb; }); curl_easy_setopt(requests[i].easy, CURLOPT_WRITEDATA, &requests[i].response); - curl_easy_setopt(requests[i].easy, CURLOPT_POST, 1L); - curl_easy_setopt(requests[i].easy, CURLOPT_POSTFIELDS, 
requests[i].payload.c_str()); + curl_multi_add_handle(multi_handle, requests[i].easy); } + + auto api_start = std::chrono::high_resolution_clock::now(); + int still_running = 0; curl_multi_perform(multi_handle, &still_running); while (still_running) { @@ -73,28 +174,63 @@ class BaseModelProviderHandler : public IModelProviderHandler { curl_multi_wait(multi_handle, NULL, 0, 1000, &numfds); curl_multi_perform(multi_handle, &still_running); } + + auto api_end = std::chrono::high_resolution_clock::now(); + double api_duration_ms = std::chrono::duration<double, std::milli>(api_end - api_start).count(); + + int64_t batch_input_tokens = 0; + int64_t batch_output_tokens = 0; + std::vector<nlohmann::json> results(jsons.size()); for (size_t i = 0; i < requests.size(); ++i) { + // Clean up temp files for transcriptions + if (is_transcription && requests[i].is_temp_file && !requests[i].temp_file_path.empty()) { + std::remove(requests[i].temp_file_path.c_str()); + } + curl_easy_getinfo(requests[i].easy, CURLINFO_RESPONSE_CODE, NULL); - if (!requests[i].response.empty() && isJson(requests[i].response)) { + if (isJson(requests[i].response)) { try { nlohmann::json parsed = nlohmann::json::parse(requests[i].response); - checkResponse(parsed, is_completion); - if (is_completion) { - results[i] = ExtractCompletionOutput(parsed); - } else { - results[i] = ExtractEmbeddingVector(parsed); + checkResponse(parsed, request_type); + + // Extract token usage for completions/embeddings + if (!is_transcription) { + auto [input_tokens, output_tokens] = ExtractTokenUsage(parsed); + batch_input_tokens += input_tokens; + batch_output_tokens += output_tokens; + } + + // Let provider extract output based on request type + try { + results[i] = ExtractOutput(parsed, request_type); + } catch (const std::exception& e) { + trigger_error(std::string("Output extraction error: ") + e.what()); } } catch (const std::exception& e) { - trigger_error(std::string("JSON parse error: ") + e.what()); + trigger_error(std::string("Response processing error: ") + e.what()); } } else { - trigger_error("Empty or invalid response in batch"); + trigger_error("Invalid JSON response: " + requests[i].response); + } + + // Clean up mime form for transcriptions + if (is_transcription && requests[i].mime_form) { + curl_mime_free(requests[i].mime_form); } curl_multi_remove_handle(multi_handle, requests[i].easy); curl_easy_cleanup(requests[i].easy); + if (requests[i].headers) { + curl_slist_free_all(requests[i].headers); + } } - curl_slist_free_all(headers); + + if (!is_transcription) { + MetricsManager::UpdateTokens(batch_input_tokens, batch_output_tokens); + } + MetricsManager::AddApiDuration(api_duration_ms); + for (size_t i = 0; i < jsons.size(); ++i) { + MetricsManager::IncrementApiCalls(); + } + curl_multi_cleanup(multi_handle); return results; }
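Note that the accounting above works at batch granularity. A worked example of what gets recorded for one three-request completion batch (token numbers invented):

// Suppose the three responses report (prompt_tokens, completion_tokens) of
// (100, 20), (80, 15) and (120, 30). After the loop above:
//   batch_input_tokens = 300, batch_output_tokens = 65
//   -> UpdateTokens(300, 65) runs once for the whole batch,
//   -> IncrementApiCalls() runs three times (api_calls += 3),
//   -> AddApiDuration(api_duration_ms) records the wall-clock time of the
//      concurrent batch once, so it is not the sum of per-request latencies.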
 @@ -105,14 +241,29 @@ protected: bool _throw_exception; std::vector<nlohmann::json> _request_batch; + std::vector<RequestType> _request_types; virtual std::string getCompletionUrl() const = 0; virtual std::string getEmbedUrl() const = 0; + virtual std::string getTranscriptionUrl() const = 0; virtual void prepareSessionForRequest(const std::string& url) = 0; virtual std::vector<std::string> getExtraHeaders() const { return {}; } - virtual void checkProviderSpecificResponse(const nlohmann::json&, bool is_completion) {} + virtual void checkProviderSpecificResponse(const nlohmann::json&, RequestType request_type) {} virtual nlohmann::json ExtractCompletionOutput(const nlohmann::json&) const { return {}; } virtual nlohmann::json ExtractEmbeddingVector(const nlohmann::json&) const { return {}; } + virtual nlohmann::json ExtractTranscriptionOutput(const nlohmann::json&) const = 0; + + // Unified extraction method - delegates to specific Extract* methods based on request type + nlohmann::json ExtractOutput(const nlohmann::json& parsed, RequestType request_type) const { + if (request_type == RequestType::Completion) { + return ExtractCompletionOutput(parsed); + } else if (request_type == RequestType::Embedding) { + return ExtractEmbeddingVector(parsed); + } else { + return ExtractTranscriptionOutput(parsed); + } + } + virtual std::pair<int64_t, int64_t> ExtractTokenUsage(const nlohmann::json& response) const = 0; void trigger_error(const std::string& msg) { if (_throw_exception) { @@ -122,14 +273,14 @@ } } - void checkResponse(const nlohmann::json& json, bool is_completion) { + void checkResponse(const nlohmann::json& json, RequestType request_type) { if (json.contains("error")) { auto reason = json["error"].dump(); trigger_error(reason); std::cerr << ">> response error :\n" << json.dump(2) << "\n"; } - checkProviderSpecificResponse(json, is_completion); + checkProviderSpecificResponse(json, request_type); } bool isJson(const std::string& data) { diff --git a/src/include/flock/model_manager/providers/handlers/handler.hpp b/src/include/flock/model_manager/providers/handlers/handler.hpp index 8cf18d79..11540de5 100644 --- a/src/include/flock/model_manager/providers/handlers/handler.hpp +++ b/src/include/flock/model_manager/providers/handlers/handler.hpp @@ -1,5 +1,6 @@ #pragma once +#include "flock/core/common.hpp" #include namespace flock { @@ -7,16 +8,19 @@ namespace flock { class IModelProviderHandler { public: enum class RequestType { Completion, - Embedding }; + Embedding, + Transcription }; virtual ~IModelProviderHandler() = default; - // AddRequest: type distinguishes between completion and embedding (default: Completion) + // AddRequest: type distinguishes between completion, embedding, and transcription (default: Completion) virtual void AddRequest(const nlohmann::json& json, RequestType type = RequestType::Completion) = 0; // CollectCompletions: process all as completions, then clear virtual std::vector<nlohmann::json> CollectCompletions(const std::string& contentType = "application/json") = 0; // CollectEmbeddings: process all as embeddings, then clear virtual std::vector<nlohmann::json> CollectEmbeddings(const std::string& contentType = "application/json") = 0; + // CollectTranscriptions: process all transcriptions, then clear + virtual std::vector<nlohmann::json> CollectTranscriptions(const std::string& contentType = "multipart/form-data") = 0; }; }// namespace flock diff --git a/src/include/flock/model_manager/providers/handlers/ollama.hpp b/src/include/flock/model_manager/providers/handlers/ollama.hpp index 8bf43686..51b88e1e 100644 --- a/src/include/flock/model_manager/providers/handlers/ollama.hpp +++ b/src/include/flock/model_manager/providers/handlers/ollama.hpp @@ -4,11 +4,6 @@ #include "session.hpp" #include #include -#include -#include -#include -#include -#include namespace flock { @@ -23,8 +18,9 @@ class OllamaModelManager : public BaseModelProviderHandler { OllamaModelManager& operator=(OllamaModelManager&&) = delete; protected: - std::string getCompletionUrl() const override { return _url + "/api/generate"; } + std::string getCompletionUrl() const override { return _url + "/api/chat"; } std::string getEmbedUrl() const override { return _url + "/api/embed"; } + std::string getTranscriptionUrl() const override { return ""; } void prepareSessionForRequest(const 
std::string& url) override { _session.setUrl(url); } void setParameters(const std::string& data, const std::string& contentType = "") override { if (contentType != "multipart/form-data") { @@ -34,14 +30,19 @@ auto postRequest(const std::string& contentType) -> decltype(((Session*) nullptr)->postPrepareOllama(contentType)) override { return _session.postPrepareOllama(contentType); } - void checkProviderSpecificResponse(const nlohmann::json& response, bool is_completion) override { + void checkProviderSpecificResponse(const nlohmann::json& response, RequestType request_type) override { + if (request_type == RequestType::Transcription) { + return;// No specific checks needed for transcriptions + } + bool is_completion = (request_type == RequestType::Completion); if (is_completion) { - if ((response.contains("done_reason") && response["done_reason"] != "stop") || - (response.contains("done") && !response["done"].is_null() && response["done"].get<bool>() != true)) { + if (response.contains("done_reason") && response["done_reason"] != "stop") { throw std::runtime_error("The request was refused due to some internal error with Ollama API"); } + if (response.contains("done") && !response["done"].is_null() && !response["done"].get<bool>()) { + throw std::runtime_error("The request was not completed by Ollama API"); + } } else { - // Embedding-specific checks (if any) can be added here if (response.contains("embeddings") && (!response["embeddings"].is_array() || response["embeddings"].empty())) { throw std::runtime_error("Ollama API returned empty or invalid embedding data."); } @@ -49,10 +50,43 @@ } nlohmann::json ExtractCompletionOutput(const nlohmann::json& response) const override { - if (response.contains("response")) { - return nlohmann::json::parse(response["response"].get<std::string>()); + if (response.contains("message") && response["message"].is_object()) { + const auto& message = response["message"]; + if (message.contains("content")) { + const auto& content = message["content"]; + if (content.is_null()) { + std::cerr << "Error: Ollama API returned null content in message. Full response: " << response.dump(2) << std::endl; + throw std::runtime_error("Ollama API returned null content in message. Response: " + response.dump()); + } + if (content.is_string()) { + try { + auto parsed = nlohmann::json::parse(content.get<std::string>()); + // Validate that parsed result has expected structure for aggregate functions + if (!parsed.contains("items") || !parsed["items"].is_array()) { + std::cerr << "Warning: Parsed content does not contain 'items' array. Parsed: " << parsed.dump(2) << std::endl; + } + return parsed; + } catch (const std::exception& e) { + std::cerr << "Error: Failed to parse Ollama response content as JSON: " << e.what() << std::endl; + std::cerr << "Content was: " << content.dump() << std::endl; + throw std::runtime_error("Failed to parse Ollama response as JSON: " + std::string(e.what()) + ". Content: " + content.dump()); + } + } else { + // Content might already be a JSON object + // Validate structure + if (!content.contains("items") || !content["items"].is_array()) { + std::cerr << "Warning: Content does not contain 'items' array. Content: " << content.dump(2) << std::endl; + } + return content; + } + } else { + std::cerr << "Error: Ollama API response missing 'content' field in message. 
Full response: " << response.dump(2) << std::endl; + throw std::runtime_error("Ollama API response missing message.content field. Response: " + response.dump()); + } + } else { + std::cerr << "Error: Ollama API response missing 'message' object. Full response: " << response.dump(2) << std::endl; + throw std::runtime_error("Ollama API response missing message field. Response: " + response.dump()); } - return {}; } nlohmann::json ExtractEmbeddingVector(const nlohmann::json& response) const override { @@ -62,6 +96,24 @@ class OllamaModelManager : public BaseModelProviderHandler { return {}; } + std::pair ExtractTokenUsage(const nlohmann::json& response) const override { + int64_t input_tokens = 0; + int64_t output_tokens = 0; + if (response.contains("prompt_eval_count") && response["prompt_eval_count"].is_number()) { + input_tokens = response["prompt_eval_count"].get(); + } + if (response.contains("eval_count") && response["eval_count"].is_number()) { + output_tokens = response["eval_count"].get(); + } + return {input_tokens, output_tokens}; + } + + + nlohmann::json ExtractTranscriptionOutput(const nlohmann::json& response) const override { + throw std::runtime_error("Audio transcription is not supported for Ollama provider, use Azure or OpenAI instead."); + } + + Session _session; std::string _url; }; diff --git a/src/include/flock/model_manager/providers/handlers/openai.hpp b/src/include/flock/model_manager/providers/handlers/openai.hpp index 83c9625f..8162e341 100644 --- a/src/include/flock/model_manager/providers/handlers/openai.hpp +++ b/src/include/flock/model_manager/providers/handlers/openai.hpp @@ -3,10 +3,6 @@ #include "flock/model_manager/providers/handlers/base_handler.hpp" #include "session.hpp" #include -#include -#include -#include -#include namespace flock { @@ -39,6 +35,9 @@ class OpenAIModelManager : public BaseModelProviderHandler { std::string getEmbedUrl() const override { return _api_base_url + "embeddings"; } + std::string getTranscriptionUrl() const override { + return _api_base_url + "audio/transcriptions"; + } void prepareSessionForRequest(const std::string& url) override { _session.setUrl(url); } @@ -53,7 +52,11 @@ class OpenAIModelManager : public BaseModelProviderHandler { std::vector getExtraHeaders() const override { return {"Authorization: Bearer " + _token}; } - void checkProviderSpecificResponse(const nlohmann::json& response, bool is_completion) override { + void checkProviderSpecificResponse(const nlohmann::json& response, RequestType request_type) override { + if (request_type == RequestType::Transcription) { + return;// No specific checks needed for transcriptions + } + bool is_completion = (request_type == RequestType::Completion); if (is_completion) { if (response.contains("choices") && response["choices"].is_array() && !response["choices"].empty()) { const auto& choice = response["choices"][0]; @@ -65,7 +68,6 @@ class OpenAIModelManager : public BaseModelProviderHandler { } } } else { - // Embedding-specific checks (if any) can be added here if (response.contains("data") && response["data"].is_array() && response["data"].empty()) { throw std::runtime_error("OpenAI API returned empty embedding data."); } @@ -91,6 +93,30 @@ class OpenAIModelManager : public BaseModelProviderHandler { return results; } } + + std::pair ExtractTokenUsage(const nlohmann::json& response) const override { + int64_t input_tokens = 0; + int64_t output_tokens = 0; + if (response.contains("usage") && response["usage"].is_object()) { + const auto& usage = response["usage"]; + if 
(usage.contains("prompt_tokens") && usage["prompt_tokens"].is_number()) { + input_tokens = usage["prompt_tokens"].get(); + } + if (usage.contains("completion_tokens") && usage["completion_tokens"].is_number()) { + output_tokens = usage["completion_tokens"].get(); + } + } + return {input_tokens, output_tokens}; + } + + + nlohmann::json ExtractTranscriptionOutput(const nlohmann::json& response) const override { + // Transcription API returns JSON with "text" field when response_format=json + if (response.contains("text") && !response["text"].is_null()) { + return response["text"].get(); + } + return ""; + } }; }// namespace flock diff --git a/src/include/flock/model_manager/providers/handlers/url_handler.hpp b/src/include/flock/model_manager/providers/handlers/url_handler.hpp new file mode 100644 index 00000000..714967d5 --- /dev/null +++ b/src/include/flock/model_manager/providers/handlers/url_handler.hpp @@ -0,0 +1,249 @@ +#pragma once + +#include "flock/core/common.hpp" +#include "flock/core/config.hpp" +#include +#include +#include +#include +#include +#include +#include + +namespace flock { + +class URLHandler { +public: + // Extract file extension from URL + static std::string ExtractFileExtension(const std::string& url) { + size_t last_dot = url.find_last_of('.'); + size_t last_slash = url.find_last_of('/'); + if (last_dot != std::string::npos && (last_slash == std::string::npos || last_dot > last_slash)) { + size_t query_pos = url.find_first_of('?', last_dot); + if (query_pos != std::string::npos) { + return url.substr(last_dot, query_pos - last_dot); + } else { + return url.substr(last_dot); + } + } + return "";// No extension found + } + + // Generate a unique temporary filename with extension + static std::string GenerateTempFilename(const std::string& extension) { + // Get the flock storage directory (parent of the database file) + std::filesystem::path storage_dir = Config::get_global_storage_path().parent_path(); + + // Ensure the directory exists + if (!std::filesystem::exists(storage_dir)) { + std::filesystem::create_directories(storage_dir); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, 15); + std::ostringstream filename; + filename << "flock_"; + for (int i = 0; i < 16; ++i) { + filename << std::hex << dis(gen); + } + filename << extension; + + // Use filesystem path for proper cross-platform path handling + std::filesystem::path temp_path = storage_dir / filename.str(); + return temp_path.string(); + } + + // Check if the given path is a URL using regex + static bool IsUrl(const std::string& path) { + // Regex pattern to match URLs: http:// or https:// + static const std::regex url_pattern(R"(^https?://)"); + return std::regex_search(path, url_pattern); + } + + // Validate file exists and is not empty + static bool ValidateFile(const std::string& file_path) { + FILE* f = fopen(file_path.c_str(), "rb"); + if (!f) { + return false; + } + fseek(f, 0, SEEK_END); + long file_size = ftell(f); + fclose(f); + return file_size > 0; + } + + // Download file from URL to temporary location + // Supports http:// and https:// URLs + static std::string DownloadFileToTemp(const std::string& url) { + std::string extension = ExtractFileExtension(url); + // If no extension found, try to infer from content-type or use empty extension + std::string temp_filename = GenerateTempFilename(extension); + + // Download file using curl + CURL* curl = curl_easy_init(); + if (!curl) { + return ""; + } + + FILE* file = fopen(temp_filename.c_str(), 
"wb"); + if (!file) { + curl_easy_cleanup(curl); + return ""; + } + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt( + curl, CURLOPT_WRITEFUNCTION, +[](void* ptr, size_t size, size_t nmemb, void* stream) -> size_t { return fwrite(ptr, size, nmemb, static_cast(stream)); }); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, file); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + + CURLcode res = curl_easy_perform(curl); + fclose(file); + long response_code; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &response_code); + curl_easy_cleanup(curl); + + if (res != CURLE_OK || response_code != 200) { + std::remove(temp_filename.c_str()); + return ""; + } + + return temp_filename; + } + + // Helper struct to return file path and temp file flag + struct FilePathResult { + std::string file_path; + bool is_temp_file; + }; + + // Resolve file path: download if URL, validate, and return result + // Throws std::runtime_error if download or validation fails + static FilePathResult ResolveFilePath(const std::string& file_path_or_url) { + FilePathResult result; + + if (IsUrl(file_path_or_url)) { + result.file_path = DownloadFileToTemp(file_path_or_url); + if (result.file_path.empty()) { + throw std::runtime_error("Failed to download file: " + file_path_or_url); + } + result.is_temp_file = true; + } else { + result.file_path = file_path_or_url; + result.is_temp_file = false; + } + + if (!ValidateFile(result.file_path)) { + if (result.is_temp_file) { + std::remove(result.file_path.c_str()); + } + throw std::runtime_error("Invalid file: " + file_path_or_url); + } + + return result; + } + + // Read file contents and convert to base64 + // Returns empty string if file cannot be read + static std::string ReadFileToBase64(const std::string& file_path) { + FILE* file = fopen(file_path.c_str(), "rb"); + if (!file) { + return ""; + } + + // Get file size + fseek(file, 0, SEEK_END); + long file_size = ftell(file); + fseek(file, 0, SEEK_SET); + + if (file_size <= 0) { + fclose(file); + return ""; + } + + // Read file content + std::vector buffer(file_size); + size_t bytes_read = fread(buffer.data(), 1, file_size, file); + fclose(file); + + if (bytes_read != static_cast(file_size)) { + return ""; + } + + // Base64 encoding table + static const char base64_chars[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string result; + result.reserve(((file_size + 2) / 3) * 4); + + for (size_t i = 0; i < bytes_read; i += 3) { + unsigned int octet_a = buffer[i]; + unsigned int octet_b = (i + 1 < bytes_read) ? buffer[i + 1] : 0; + unsigned int octet_c = (i + 2 < bytes_read) ? buffer[i + 2] : 0; + + unsigned int triple = (octet_a << 16) + (octet_b << 8) + octet_c; + + result.push_back(base64_chars[(triple >> 18) & 0x3F]); + result.push_back(base64_chars[(triple >> 12) & 0x3F]); + result.push_back((i + 1 < bytes_read) ? base64_chars[(triple >> 6) & 0x3F] : '='); + result.push_back((i + 2 < bytes_read) ? 
base64_chars[triple & 0x3F] : '='); + } + + return result; + } + + // Helper struct to return base64 content and temp file flag + struct Base64Result { + std::string base64_content; + bool is_temp_file; + std::string temp_file_path; + }; + + // Resolve file path or URL, read contents and convert to base64 + // If input is URL, downloads to temp file first + // Returns base64 content and temp file info for cleanup + // Throws std::runtime_error if file cannot be processed + static Base64Result ResolveFileToBase64(const std::string& file_path_or_url) { + Base64Result result; + result.is_temp_file = false; + + std::string file_path; + if (IsUrl(file_path_or_url)) { + file_path = DownloadFileToTemp(file_path_or_url); + if (file_path.empty()) { + throw std::runtime_error("Failed to download file: " + file_path_or_url); + } + result.is_temp_file = true; + result.temp_file_path = file_path; + } else { + file_path = file_path_or_url; + } + + if (!ValidateFile(file_path)) { + if (result.is_temp_file) { + std::remove(file_path.c_str()); + } + throw std::runtime_error("Invalid file: " + file_path_or_url); + } + + result.base64_content = ReadFileToBase64(file_path); + if (result.base64_content.empty()) { + if (result.is_temp_file) { + std::remove(file_path.c_str()); + } + throw std::runtime_error("Failed to read file: " + file_path_or_url); + } + + // Cleanup temp file after reading + if (result.is_temp_file) { + std::remove(file_path.c_str()); + result.temp_file_path.clear(); + } + + return result; + } +}; + +}// namespace flock
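URLHandler centralizes path/URL handling for media inputs. A short usage sketch under the semantics above (the URL is hypothetical): ResolveFileToBase64 downloads remote inputs to a temp file, base64-encodes the bytes, and removes the temp file before returning, so callers only consume the encoded payload.

try {
    auto res = flock::URLHandler::ResolveFileToBase64("https://example.com/cat.png");
    // Any temp file has already been removed and temp_file_path cleared here;
    // only the encoded content is left to use.
    std::string data_uri = "data:image/png;base64," + res.base64_content;
} catch (const std::runtime_error& e) {
    // download, validation, or read failure
}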
 diff --git a/src/include/flock/model_manager/providers/provider.hpp b/src/include/flock/model_manager/providers/provider.hpp index 928e75de..cedc57f2 100644 --- a/src/include/flock/model_manager/providers/provider.hpp +++ b/src/include/flock/model_manager/providers/provider.hpp @@ -1,11 +1,12 @@ #pragma once #include "fmt/format.h" -#include #include +#include "flock/core/common.hpp" #include "flock/model_manager/providers/handlers/handler.hpp" #include "flock/model_manager/repository.hpp" +#include namespace flock { @@ -28,6 +29,7 @@ class IProvider { virtual void AddCompletionRequest(const std::string& prompt, const int num_output_tuples, OutputType output_type, const nlohmann::json& media_data) = 0; virtual void AddEmbeddingRequest(const std::vector<std::string>& inputs) = 0; + virtual void AddTranscriptionRequest(const nlohmann::json& audio_files) = 0; virtual std::vector<nlohmann::json> CollectCompletions(const std::string& contentType = "application/json") { return model_handler_->CollectCompletions(contentType); @@ -35,6 +37,9 @@ } virtual std::vector<nlohmann::json> CollectEmbeddings(const std::string& contentType = "application/json") { return model_handler_->CollectEmbeddings(contentType); } + virtual std::vector<nlohmann::json> CollectTranscriptions(const std::string& contentType = "multipart/form-data") { + return model_handler_->CollectTranscriptions(contentType); + } static std::string GetOutputTypeString(const OutputType output_type) { switch (output_type) { diff --git a/src/include/flock/model_manager/repository.hpp b/src/include/flock/model_manager/repository.hpp index 7efeb950..e4aa7717 100644 --- a/src/include/flock/model_manager/repository.hpp +++ b/src/include/flock/model_manager/repository.hpp @@ -1,9 +1,9 @@ #pragma once +#include "flock/core/common.hpp" #include #include #include -#include #include namespace flock { diff --git a/src/include/flock/prompt_manager/prompt_manager.hpp b/src/include/flock/prompt_manager/prompt_manager.hpp index 021991ba..9ef3cd9a 100644 --- a/src/include/flock/prompt_manager/prompt_manager.hpp +++ b/src/include/flock/prompt_manager/prompt_manager.hpp @@ -1,12 +1,12 @@ #pragma once #include -#include -#include -#include +#include "flock/core/common.hpp" #include "flock/core/config.hpp" +#include "flock/model_manager/model.hpp" #include "flock/prompt_manager/repository.hpp" +#include namespace flock { @@ -46,19 +46,40 @@ class PromptManager { static std::string ConstructInputTuples(const nlohmann::json& columns, const std::string& tuple_format = "XML"); + // Helper function to transcribe audio column and create transcription text column + static nlohmann::json TranscribeAudioColumn(const nlohmann::json& audio_column); + +public: template <typename FunctionType> static std::tuple<std::string, nlohmann::json> Render(const std::string& user_prompt, const nlohmann::json& columns, FunctionType option, const std::string& tuple_format = "XML") { - auto media_data = nlohmann::json::array(); + auto image_data = nlohmann::json::array(); auto tabular_data = nlohmann::json::array(); + for (auto i = 0; i < static_cast<int>(columns.size()); i++) { - if (columns[i].contains("type") && columns[i]["type"] == "image") { - media_data.push_back(columns[i]); + if (columns[i].contains("type")) { + auto column_type = columns[i]["type"].get<std::string>(); + if (column_type == "image") { + image_data.push_back(columns[i]); + } else if (column_type == "audio") { + // Transcribe audio and merge as tabular text data + if (columns[i].contains("transcription_model")) { + auto transcription_column = TranscribeAudioColumn(columns[i]); + tabular_data.push_back(transcription_column); + } + } else { + tabular_data.push_back(columns[i]); + } } else { + tabular_data.push_back(columns[i]); } } + // Create media_data as an object with only image array (audio is now in tabular_data) + nlohmann::json media_data; + media_data["image"] = image_data; + media_data["audio"] = nlohmann::json::array();// Empty - audio is now in tabular_data + auto prompt = PromptManager::GetTemplate(option); prompt = PromptManager::ReplaceSection(prompt, PromptSection::USER_PROMPT, user_prompt); if (!tabular_data.empty()) {
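With this change Render hands providers a media_data object keyed by modality instead of a flat array, and audio columns are transcribed up front and folded into the tabular data. A sketch of the resulting shape (the column entry is illustrative):

nlohmann::json media_data;
media_data["image"] = nlohmann::json::array();// image columns, each carrying "type" and "data"
media_data["audio"] = nlohmann::json::array();// intentionally empty: audio already became tabular text
media_data["image"].push_back({{"type", "image/png"}, {"data", {"https://example.com/a.png"}}});// hypothetical column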
 diff --git a/src/include/flock/prompt_manager/repository.hpp b/src/include/flock/prompt_manager/repository.hpp index f525e420..ba01a293 100644 --- a/src/include/flock/prompt_manager/repository.hpp +++ b/src/include/flock/prompt_manager/repository.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "flock/core/common.hpp" #include namespace flock { diff --git a/src/include/flock/registry/scalar.hpp b/src/include/flock/registry/scalar.hpp index ffdb0309..084690c1 100644 --- a/src/include/flock/registry/scalar.hpp +++ b/src/include/flock/registry/scalar.hpp @@ -17,6 +17,9 @@ class ScalarRegistry { static void RegisterFusionCombMED(duckdb::ExtensionLoader& loader); static void RegisterFusionCombMNZ(duckdb::ExtensionLoader& loader); static void RegisterFusionCombSUM(duckdb::ExtensionLoader& loader); + static void RegisterFlockGetMetrics(duckdb::ExtensionLoader& loader); + static void RegisterFlockGetDebugMetrics(duckdb::ExtensionLoader& loader); + static void RegisterFlockResetMetrics(duckdb::ExtensionLoader& loader); }; }// namespace flock diff --git a/src/include/flock/secret_manager/secret_manager.hpp b/src/include/flock/secret_manager/secret_manager.hpp index 364510a6..a852c6e9 100644 --- a/src/include/flock/secret_manager/secret_manager.hpp +++ b/src/include/flock/secret_manager/secret_manager.hpp @@ -1,7 +1,6 @@ #pragma once #include "flock/core/common.hpp" -#include namespace flock { diff --git a/src/metrics/CMakeLists.txt b/src/metrics/CMakeLists.txt new file mode 100644 index 00000000..a35c542e --- /dev/null +++ b/src/metrics/CMakeLists.txt @@ -0,0 +1,4 @@ +set(EXTENSION_SOURCES + ${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/metrics.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/registry.cpp + PARENT_SCOPE) diff --git a/src/metrics/metrics.cpp b/src/metrics/metrics.cpp new file mode 100644 index 00000000..85eeff90 --- /dev/null +++ b/src/metrics/metrics.cpp @@ -0,0 +1,122 @@ +#include "flock/metrics/data_structures.hpp" +#include "flock/metrics/manager.hpp" +#include <vector> + +namespace flock { + +// Thread-local storage definitions (must be in .cpp file) +thread_local duckdb::DatabaseInstance* MetricsManager::current_db_ = nullptr; +thread_local const void* MetricsManager::current_state_id_ = nullptr; +thread_local FunctionType MetricsManager::current_function_type_ = FunctionType::UNKNOWN; + +void MetricsManager::ExecuteGetMetrics(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) { + auto& context = state.GetContext(); + auto* db = context.db.get(); + + auto& metrics_manager = GetForDatabase(db); + auto metrics = metrics_manager.GetMetrics(); + + auto json_str = metrics.dump(); + + result.SetVectorType(duckdb::VectorType::CONSTANT_VECTOR); + auto result_data = duckdb::ConstantVector::GetData<duckdb::string_t>(result); + result_data[0] = duckdb::StringVector::AddString(result, json_str); +} + +void MetricsManager::ExecuteGetDebugMetrics(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) { + auto& context = state.GetContext(); + auto* db = context.db.get(); + + auto& metrics_manager = GetForDatabase(db); + auto metrics = metrics_manager.GetDebugMetrics(); + + auto json_str = metrics.dump(); + + result.SetVectorType(duckdb::VectorType::CONSTANT_VECTOR); + auto result_data = duckdb::ConstantVector::GetData<duckdb::string_t>(result); + result_data[0] = duckdb::StringVector::AddString(result, json_str); +} + +void MetricsManager::ExecuteResetMetrics(duckdb::DataChunk& args, duckdb::ExpressionState& state, duckdb::Vector& result) { + auto& context = state.GetContext(); + auto* db = context.db.get(); + + auto& metrics_manager = GetForDatabase(db); + metrics_manager.Reset(); + + result.SetVectorType(duckdb::VectorType::CONSTANT_VECTOR); + auto result_data = duckdb::ConstantVector::GetData<duckdb::string_t>(result); + result_data[0] = duckdb::StringVector::AddString(result, "Metrics reset successfully"); +} + +void MetricsManager::MergeAggregateMetrics(duckdb::DatabaseInstance* db, + const std::vector<const void*>& processed_state_ids, + FunctionType function_type, + const std::string& model_name, + const std::string& provider) { + if (processed_state_ids.empty() || db == nullptr) { + return; + } + + auto& manager = GetForDatabase(db); + + // Use the first state_id as the merged state_id + const void* merged_state_id = processed_state_ids[0]; + + // Start a new invocation for the merged metrics (registers the state and sets registration order) + StartInvocation(db, merged_state_id, function_type); + + // Get and merge metrics from all processed states + int64_t total_input_tokens = 0; + int64_t total_output_tokens = 0; + int64_t total_api_calls = 0; + int64_t total_api_duration_us = 0; + int64_t total_execution_time_us = 0; + std::string final_model_name = model_name; + std::string final_provider = provider; + + for (const void* state_id: processed_state_ids) { + auto& thread_metrics = manager.GetThreadMetrics(state_id); + const auto& metrics = thread_metrics.GetMetrics(function_type); + + if (!metrics.IsEmpty()) { + total_input_tokens 
+= metrics.input_tokens; + total_output_tokens += metrics.output_tokens; + total_api_calls += metrics.api_calls; + total_api_duration_us += metrics.api_duration_us; + total_execution_time_us += metrics.execution_time_us; + + // Use model info from first non-empty state if not provided + if (final_model_name.empty() && !metrics.model_name.empty()) { + final_model_name = metrics.model_name; + final_provider = metrics.provider; + } + } + } + + // Get the merged state's metrics and set aggregated values + auto& merged_thread_metrics = manager.GetThreadMetrics(merged_state_id); + auto& merged_metrics = merged_thread_metrics.GetMetrics(function_type); + + // Set the aggregated values directly + merged_metrics.input_tokens = total_input_tokens; + merged_metrics.output_tokens = total_output_tokens; + merged_metrics.api_calls = total_api_calls; + merged_metrics.api_duration_us = total_api_duration_us; + merged_metrics.execution_time_us = total_execution_time_us; + if (!final_model_name.empty()) { + merged_metrics.model_name = final_model_name; + merged_metrics.provider = final_provider; + } + + // Clean up individual state metrics (reset function_type metrics for all except the merged one) + for (size_t i = 1; i < processed_state_ids.size(); i++) { + const void* state_id = processed_state_ids[i]; + auto& thread_metrics = manager.GetThreadMetrics(state_id); + auto& metrics = thread_metrics.GetMetrics(function_type); + // Reset only the specific function_type metrics for this state + metrics = FunctionMetricsData{}; + } +} + +}// namespace flock diff --git a/src/metrics/registry.cpp b/src/metrics/registry.cpp new file mode 100644 index 00000000..bccaec0d --- /dev/null +++ b/src/metrics/registry.cpp @@ -0,0 +1,36 @@ +#include "flock/registry/registry.hpp" +#include "flock/metrics/manager.hpp" + +namespace flock { + +void ScalarRegistry::RegisterFlockGetMetrics(duckdb::ExtensionLoader& loader) { + auto function = duckdb::ScalarFunction( + "flock_get_metrics", + {}, + duckdb::LogicalType::JSON(), + MetricsManager::ExecuteGetMetrics); + function.stability = duckdb::FunctionStability::VOLATILE; + loader.RegisterFunction(function); +} + +void ScalarRegistry::RegisterFlockGetDebugMetrics(duckdb::ExtensionLoader& loader) { + auto function = duckdb::ScalarFunction( + "flock_get_debug_metrics", + {}, + duckdb::LogicalType::JSON(), + MetricsManager::ExecuteGetDebugMetrics); + function.stability = duckdb::FunctionStability::VOLATILE; + loader.RegisterFunction(function); +} + +void ScalarRegistry::RegisterFlockResetMetrics(duckdb::ExtensionLoader& loader) { + auto function = duckdb::ScalarFunction( + "flock_reset_metrics", + {}, + duckdb::LogicalType::VARCHAR, + MetricsManager::ExecuteResetMetrics); + function.stability = duckdb::FunctionStability::VOLATILE; + loader.RegisterFunction(function); +} + +}// namespace flock
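Once registered, the three scalar functions can be queried like any others; a minimal sketch through the C++ API, assuming the flock extension is loaded in the process (each is a table-less SELECT returning one row):

duckdb::DuckDB db(nullptr);
duckdb::Connection con(db);
auto metrics = con.Query("SELECT flock_get_metrics();");      // JSON summary
auto debug = con.Query("SELECT flock_get_debug_metrics();");  // more detailed debug view
auto reset = con.Query("SELECT flock_reset_metrics();");      // returns a status VARCHAR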
 diff --git a/src/model_manager/model.cpp b/src/model_manager/model.cpp index 24a64dcc..8da8d3b4 100644 --- a/src/model_manager/model.cpp +++ b/src/model_manager/model.cpp @@ -1,9 +1,13 @@ #include "flock/model_manager/model.hpp" #include "flock/secret_manager/secret_manager.hpp" +#include <regex> +#include <string> +#include <tuple> +#include <unordered_map> +#include <vector> namespace flock { -// Regular expression to match a valid Base64 string const std::regex base64_regex(R"(^[A-Za-z0-9+/=]+$)"); bool is_base64(const std::string& str) { @@ -20,26 +24,72 @@ void Model::LoadModelDetails(const nlohmann::json& model_json) { if (model_details_.model_name.empty()) { throw std::invalid_argument("`model_name` is required in model settings"); } - auto query_result = GetQueriedModel(model_details_.model_name); - model_details_.model = - model_json.contains("model") ? model_json.at("model").get<std::string>() : std::get<0>(query_result); - model_details_.provider_name = - model_json.contains("provider") ? model_json.at("provider").get<std::string>() : std::get<1>(query_result); - auto secret_name = "__default_" + model_details_.provider_name; - if (model_details_.provider_name == AZURE) - secret_name += "_llm"; - if (model_json.contains("secret_name")) { - secret_name = model_json["secret_name"].get<std::string>(); - } - model_details_.secret = SecretManager::GetSecret(secret_name); - model_details_.model_parameters = model_json.contains("model_parameters") ? nlohmann::json::parse(model_json.at("model_parameters").get<std::string>()) : std::get<2>(query_result)["model_parameters"]; - model_details_.tuple_format = - model_json.contains("tuple_format") ? model_json.at("tuple_format").get<std::string>() : std::get<2>(query_result).contains("tuple_format") ? std::get<2>(query_result).at("tuple_format").get<std::string>() - : "XML"; + bool has_resolved_details = model_json.contains("model") && + model_json.contains("provider") && + model_json.contains("secret") && + model_json.contains("tuple_format") && + model_json.contains("batch_size"); + + nlohmann::json db_model_args; + + if (has_resolved_details) { + model_details_.model = model_json.at("model").get<std::string>(); + model_details_.provider_name = model_json.at("provider").get<std::string>(); + model_details_.secret = model_json["secret"].get<std::unordered_map<std::string, std::string>>(); + model_details_.tuple_format = model_json.at("tuple_format").get<std::string>(); + model_details_.batch_size = model_json.at("batch_size").get<int>(); + + if (model_json.contains("model_parameters")) { + auto& mp = model_json.at("model_parameters"); + model_details_.model_parameters = mp.is_string() ? nlohmann::json::parse(mp.get<std::string>()) : mp; + } else { + model_details_.model_parameters = nlohmann::json::object(); + } + } else { + auto [db_model, db_provider, db_args] = GetQueriedModel(model_details_.model_name); + model_details_.model = model_json.contains("model") ? model_json.at("model").get<std::string>() : db_model; + model_details_.provider_name = model_json.contains("provider") ? model_json.at("provider").get<std::string>() : db_provider; + db_model_args = db_args; + + if (model_json.contains("secret")) { + model_details_.secret = model_json["secret"].get<std::unordered_map<std::string, std::string>>(); + } else { + auto secret_name = "__default_" + model_details_.provider_name; + if (model_details_.provider_name == AZURE) { + secret_name += "_llm"; + } + if (model_json.contains("secret_name")) { + secret_name = model_json["secret_name"].get<std::string>(); + } + model_details_.secret = SecretManager::GetSecret(secret_name); + } + + if (model_json.contains("model_parameters")) { + auto& mp = model_json.at("model_parameters"); + model_details_.model_parameters = mp.is_string() ? nlohmann::json::parse(mp.get<std::string>()) : mp; + } else if (db_model_args.contains("model_parameters")) { + model_details_.model_parameters = db_model_args["model_parameters"]; + } else { + model_details_.model_parameters = nlohmann::json::object(); + } + + if (model_json.contains("tuple_format")) { + model_details_.tuple_format = model_json.at("tuple_format").get<std::string>(); + } else if (db_model_args.contains("tuple_format")) { + model_details_.tuple_format = db_model_args.at("tuple_format").get<std::string>(); + } else { + model_details_.tuple_format = "XML"; + } - model_details_.batch_size = model_json.contains("batch_size") ? model_json.at("batch_size").get<int>() : std::get<2>(query_result).contains("batch_size") ? 
std::get<2>(query_result).at("batch_size").get<int>() - : 2048; + if (model_json.contains("batch_size")) { + model_details_.batch_size = model_json.at("batch_size").get<int>(); + } else if (db_model_args.contains("batch_size")) { + model_details_.batch_size = db_model_args.at("batch_size").get<int>(); + } else { + model_details_.batch_size = 2048; + } + } } std::tuple<std::string, std::string, nlohmann::json> Model::GetQueriedModel(const std::string& model_name) { @@ -54,6 +104,7 @@ model_name, model_name); auto con = Config::GetConnection(); + Config::StorageAttachmentGuard guard(con, true); auto query_result = con.Query(query); if (query_result->RowCount() == 0) { @@ -76,6 +127,10 @@ std::tuple<std::string, std::string, nlohmann::json> Model::GetQueriedMo } void Model::ConstructProvider() { + if (mock_provider_factory_) { + provider_ = mock_provider_factory_(); + return; + } if (mock_provider_) { provider_ = mock_provider_; return; @@ -98,6 +153,31 @@ ModelDetails Model::GetModelDetails() { return model_details_; } +nlohmann::json Model::GetModelDetailsAsJson() const { + nlohmann::json result; + result["model_name"] = model_details_.model_name; + result["model"] = model_details_.model; + result["provider"] = model_details_.provider_name; + result["tuple_format"] = model_details_.tuple_format; + result["batch_size"] = model_details_.batch_size; + result["secret"] = model_details_.secret; + if (!model_details_.model_parameters.empty()) { + result["model_parameters"] = model_details_.model_parameters; + } + return result; +} + +nlohmann::json Model::ResolveModelDetailsToJson(const nlohmann::json& user_model_json) { + Model temp_model(user_model_json); + auto resolved_json = temp_model.GetModelDetailsAsJson(); + + if (user_model_json.contains("secret_name")) { + resolved_json["secret_name"] = user_model_json["secret_name"]; + } + + return resolved_json; +} + void Model::AddCompletionRequest(const std::string& prompt, const int num_output_tuples, OutputType output_type, const nlohmann::json& media_data) { provider_->AddCompletionRequest(prompt, num_output_tuples, output_type, media_data); } @@ -106,6 +186,10 @@ void Model::AddEmbeddingRequest(const std::vector<std::string>& inputs) { provider_->AddEmbeddingRequest(inputs); } +void Model::AddTranscriptionRequest(const nlohmann::json& audio_files) { + provider_->AddTranscriptionRequest(audio_files); +} + std::vector<nlohmann::json> Model::CollectCompletions(const std::string& contentType) { return provider_->CollectCompletions(contentType); } @@ -114,4 +198,8 @@ return provider_->CollectEmbeddings(contentType); } +std::vector<nlohmann::json> Model::CollectTranscriptions(const std::string& contentType) { + return provider_->CollectTranscriptions(contentType); +} + }// namespace flock
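The factory hook makes ConstructProvider hand each Model its own mock instance, which keeps parallel tests from sharing state through the legacy single shared mock. A sketch of test wiring (MyMockProvider is a hypothetical IProvider implementation):

flock::Model::SetMockProviderFactory([]() {
    return std::make_shared<MyMockProvider>();// fresh instance per Model
});
// ... run test queries; every ConstructProvider() call gets its own mock ...
flock::Model::ResetMockProvider();// clears both the factory and the legacy shared mock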
media_data[0]["detail"].get() : "low"; - auto image_type = media_data[0]["type"].get(); - auto mime_type = std::string("image/"); - if (size_t pos = image_type.find("/"); pos != std::string::npos) { - mime_type += image_type.substr(pos + 1); - } else { - mime_type += std::string("png"); - } + // Process image columns + if (media_data.contains("image") && !media_data["image"].empty() && media_data["image"].is_array()) { + std::string detail = "low"; auto column_index = 1u; - for (const auto& column: media_data) { + for (const auto& column: media_data["image"]) { + // Process image column as before + if (column_index == 1) { + detail = column.contains("detail") ? column["detail"].get() : "low"; + } + auto image_type = column.contains("type") ? column["type"].get() : "image"; + auto mime_type = std::string("image/"); + if (size_t pos = image_type.find("/"); pos != std::string::npos) { + mime_type += image_type.substr(pos + 1); + } else { + mime_type += std::string("png"); + } message_content.push_back( {{"type", "text"}, {"text", "ATTACHMENT COLUMN"}}); auto row_index = 1u; for (const auto& image: column["data"]) { + // Skip null values + if (image.is_null()) { + continue; + } message_content.push_back( {{"type", "text"}, {"text", "ROW " + std::to_string(row_index) + " :"}}); auto image_url = std::string(); - auto image_str = image.get(); - if (is_base64(image_str)) { - image_url = duckdb_fmt::format("data:{};base64,{}", mime_type, image_str); + std::string image_str; + if (image.is_string()) { + image_str = image.get(); } else { + image_str = image.dump(); + } + + // Handle file path or URL + if (URLHandler::IsUrl(image_str)) { + // URL - send directly to API image_url = image_str; + } else { + // File path - read and convert to base64 + auto base64_result = URLHandler::ResolveFileToBase64(image_str); + image_url = duckdb_fmt::format("data:{};base64,{}", mime_type, base64_result.base64_content); } message_content.push_back( @@ -83,4 +104,19 @@ void AzureProvider::AddEmbeddingRequest(const std::vector& inputs) } } +void AzureProvider::AddTranscriptionRequest(const nlohmann::json& audio_files) { + for (const auto& audio_file: audio_files) { + auto audio_file_str = audio_file.get(); + + // Handle file download and validation + auto file_result = URLHandler::ResolveFilePath(audio_file_str); + + nlohmann::json transcription_request = { + {"file_path", file_result.file_path}, + {"model", model_details_.model}, + {"is_temp_file", file_result.is_temp_file}}; + model_handler_->AddRequest(transcription_request, IModelProviderHandler::RequestType::Transcription); + } +} + }// namespace flock \ No newline at end of file diff --git a/src/model_manager/providers/adapters/ollama.cpp b/src/model_manager/providers/adapters/ollama.cpp index 2e5cd2a1..0311fdd7 100644 --- a/src/model_manager/providers/adapters/ollama.cpp +++ b/src/model_manager/providers/adapters/ollama.cpp @@ -1,25 +1,46 @@ #include "flock/model_manager/providers/adapters/ollama.hpp" +#include "flock/model_manager/providers/handlers/url_handler.hpp" +#include "flock/model_manager/providers/provider.hpp" namespace flock { void OllamaProvider::AddCompletionRequest(const std::string& prompt, const int num_output_tuples, OutputType output_type, const nlohmann::json& media_data) { - nlohmann::json request_payload = {{"model", model_details_.model}, - {"prompt", prompt}, - {"stream", false}}; + // Build message for chat API + nlohmann::json message = {{"role", "user"}, {"content", prompt}}; + // Process image columns - images go in the message 
object as an "images" array auto images = nlohmann::json::array(); - if (!media_data.empty()) { - for (const auto& column: media_data) { - for (const auto& image: column["data"]) { - auto image_str = image.get(); - images.push_back(image_str); + if (media_data.contains("image") && !media_data["image"].empty() && media_data["image"].is_array()) { + for (const auto& column: media_data["image"]) { + if (column.contains("data") && column["data"].is_array()) { + for (const auto& image: column["data"]) { + // Skip null values + if (image.is_null()) { + continue; + } + std::string image_str; + if (image.is_string()) { + image_str = image.get(); + } else { + // Convert non-string values to string + image_str = image.dump(); + } + + // Handle file path or URL - resolve and convert to base64 + auto base64_result = URLHandler::ResolveFileToBase64(image_str); + images.push_back(base64_result.base64_content); + } } } } if (!images.empty()) { - request_payload["images"] = images; + message["images"] = images; } + nlohmann::json request_payload = {{"model", model_details_.model}, + {"messages", nlohmann::json::array({message})}, + {"stream", false}}; + if (!model_details_.model_parameters.empty()) { request_payload.update(model_details_.model_parameters); } @@ -51,4 +72,8 @@ void OllamaProvider::AddEmbeddingRequest(const std::vector& inputs) } } +void OllamaProvider::AddTranscriptionRequest(const nlohmann::json& audio_files) { + throw std::runtime_error("Audio transcription is not currently supported by Ollama."); +} + }// namespace flock \ No newline at end of file diff --git a/src/model_manager/providers/adapters/openai.cpp b/src/model_manager/providers/adapters/openai.cpp index 2b1c97d2..15907bab 100644 --- a/src/model_manager/providers/adapters/openai.cpp +++ b/src/model_manager/providers/adapters/openai.cpp @@ -1,38 +1,58 @@ #include "flock/model_manager/providers/adapters/openai.hpp" +#include "flock/model_manager/model.hpp" +#include "flock/model_manager/providers/handlers/url_handler.hpp" #include namespace flock { void OpenAIProvider::AddCompletionRequest(const std::string& prompt, const int num_output_tuples, OutputType output_type, const nlohmann::json& media_data) { - auto message_content = nlohmann::json::array(); message_content.push_back({{"type", "text"}, {"text", prompt}}); - if (!media_data.empty()) { - auto detail = media_data[0].contains("detail") ? media_data[0]["detail"].get() : "low"; - auto image_type = media_data[0]["type"].get(); - auto mime_type = std::string("image/"); - if (size_t pos = image_type.find("/"); pos != std::string::npos) { - mime_type += image_type.substr(pos + 1); - } else { - mime_type += std::string("png"); - } + // Process image columns + if (media_data.contains("image") && !media_data["image"].empty() && media_data["image"].is_array()) { + std::string detail = "low"; auto column_index = 1u; - for (const auto& column: media_data) { + for (const auto& column: media_data["image"]) { + // Process image column as before + if (column_index == 1) { + detail = column.contains("detail") ? column["detail"].get() : "low"; + } + auto image_type = column.contains("type") ? 
column["type"].get() : "image"; + auto mime_type = std::string("image/"); + if (size_t pos = image_type.find("/"); pos != std::string::npos) { + mime_type += image_type.substr(pos + 1); + } else { + mime_type += std::string("png"); + } message_content.push_back( {{"type", "text"}, {"text", "ATTACHMENT COLUMN"}}); auto row_index = 1u; for (const auto& image: column["data"]) { + // Skip null values + if (image.is_null()) { + continue; + } message_content.push_back( {{"type", "text"}, {"text", "ROW " + std::to_string(row_index) + " :"}}); auto image_url = std::string(); - auto image_str = image.get(); - if (is_base64(image_str)) { - image_url = duckdb_fmt::format("data:{};base64,{}", mime_type, image_str); + std::string image_str; + if (image.is_string()) { + image_str = image.get(); } else { + image_str = image.dump(); + } + + // Handle file path or URL + if (URLHandler::IsUrl(image_str)) { + // URL - send directly to API image_url = image_str; + } else { + // File path - read and convert to base64 + auto base64_result = URLHandler::ResolveFileToBase64(image_str); + image_url = duckdb_fmt::format("data:{};base64,{}", mime_type, base64_result.base64_content); } message_content.push_back( @@ -82,4 +102,28 @@ void OpenAIProvider::AddEmbeddingRequest(const std::vector& inputs) model_handler_->AddRequest(request_payload, IModelProviderHandler::RequestType::Embedding); } +void OpenAIProvider::AddTranscriptionRequest(const nlohmann::json& audio_files) { + for (const auto& audio_file: audio_files) { + // Skip null values + if (audio_file.is_null()) { + continue; + } + std::string audio_file_str; + if (audio_file.is_string()) { + audio_file_str = audio_file.get(); + } else { + audio_file_str = audio_file.dump(); + } + + // Handle file download and validation + auto file_result = URLHandler::ResolveFilePath(audio_file_str); + + nlohmann::json transcription_request = { + {"file_path", file_result.file_path}, + {"model", model_details_.model}, + {"is_temp_file", file_result.is_temp_file}}; + model_handler_->AddRequest(transcription_request, IModelProviderHandler::RequestType::Transcription); + } +} + }// namespace flock diff --git a/src/prompt_manager/prompt_manager.cpp b/src/prompt_manager/prompt_manager.cpp index 4d497d4a..270b3c76 100644 --- a/src/prompt_manager/prompt_manager.cpp +++ b/src/prompt_manager/prompt_manager.cpp @@ -58,7 +58,12 @@ std::string PromptManager::ConstructInputTuplesHeaderXML(const nlohmann::json& c auto header = std::string("
"); auto column_idx = 1u; for (const auto& column: columns) { - auto column_name = column.contains("name") ? column["name"].get() : "COLUMN " + std::to_string(column_idx++); + std::string column_name; + if (column.contains("name") && column["name"].is_string()) { + column_name = column["name"].get(); + } else { + column_name = "COLUMN " + std::to_string(column_idx++); + } header += "" + column_name + ""; } header += "
\n"; @@ -72,7 +77,7 @@ std::string PromptManager::ConstructInputTuplesHeaderMarkdown(const nlohmann::js auto header = std::string(" | "); auto column_idx = 1u; for (const auto& column: columns) { - if (column.contains("name")) { + if (column.contains("name") && column["name"].is_string()) { header += "COLUMN_" + column["name"].get() + " | "; } else { header += "COLUMN " + std::to_string(column_idx++) + " | "; @@ -81,7 +86,12 @@ std::string PromptManager::ConstructInputTuplesHeaderMarkdown(const nlohmann::js header += "\n | "; column_idx = 1u; for (const auto& column: columns) { - auto column_name = column.contains("name") ? column["name"].get() : "COLUMN " + std::to_string(column_idx++); + std::string column_name; + if (column.contains("name") && column["name"].is_string()) { + column_name = column["name"].get(); + } else { + column_name = "COLUMN " + std::to_string(column_idx++); + } header += std::string(column_name.length(), '-') + " | "; } header += "\n"; @@ -97,7 +107,16 @@ std::string PromptManager::ConstructInputTuplesXML(const nlohmann::json& columns for (auto i = 0; i < static_cast(columns[0]["data"].size()); i++) { tuples_str += ""; for (const auto& column: columns) { - tuples_str += "" + column["data"][i].get() + ""; + std::string value_str; + const auto& data_item = column["data"][i]; + if (data_item.is_null()) { + value_str = ""; + } else if (data_item.is_string()) { + value_str = data_item.get(); + } else { + value_str = data_item.dump(); + } + tuples_str += "" + value_str + ""; } tuples_str += "\n"; } @@ -124,7 +143,12 @@ std::string PromptManager::ConstructInputTuplesJSON(const nlohmann::json& column auto tuples_json = nlohmann::json::object(); auto column_idx = 1u; for (const auto& column: columns) { - auto column_name = column.contains("name") ? 
column["name"].get() : "COLUMN " + std::to_string(column_idx++); + std::string column_name; + if (column.contains("name") && column["name"].is_string()) { + column_name = column["name"].get(); + } else { + column_name = "COLUMN " + std::to_string(column_idx++); + } tuples_json[column_name] = column["data"]; } auto tuples_str = tuples_json.dump(4); @@ -190,6 +214,7 @@ PromptDetails PromptManager::CreatePromptDetails(const nlohmann::json& prompt_de version_where_clause, order_by_clause); error_message = duckdb_fmt::format("The provided `{}` prompt " + error_message, prompt_details.prompt_name); auto con = Config::GetConnection(); + Config::StorageAttachmentGuard guard(con, true); const auto query_result = con.Query(prompt_details_query); if (query_result->RowCount() == 0) { throw std::runtime_error(error_message); @@ -216,4 +241,38 @@ PromptDetails PromptManager::CreatePromptDetails(const nlohmann::json& prompt_de } return prompt_details; } + +nlohmann::json PromptManager::TranscribeAudioColumn(const nlohmann::json& audio_column) { + auto transcription_model_name = audio_column["transcription_model"].get(); + + // Look up the transcription model + nlohmann::json transcription_model_json; + transcription_model_json["model_name"] = transcription_model_name; + Model transcription_model(transcription_model_json); + + // Add transcription requests to batch + transcription_model.AddTranscriptionRequest(audio_column["data"]); + + // Collect transcriptions + auto transcription_results = transcription_model.CollectTranscriptions(); + + // Convert vector to nlohmann::json array + nlohmann::json transcriptions = nlohmann::json::array(); + for (const auto& result: transcription_results) { + transcriptions.push_back(result); + } + + // Create transcription column with proper naming + auto transcription_column = nlohmann::json::object(); + std::string original_name; + if (audio_column.contains("name") && audio_column["name"].is_string()) { + original_name = audio_column["name"].get(); + } + auto transcription_name = original_name.empty() ? 
"transcription" : "transcription_of_" + original_name; + transcription_column["name"] = transcription_name; + transcription_column["data"] = transcriptions; + + return transcription_column; +} + }// namespace flock diff --git a/src/registry/scalar.cpp b/src/registry/scalar.cpp index 87706f3f..1d7e181e 100644 --- a/src/registry/scalar.cpp +++ b/src/registry/scalar.cpp @@ -11,6 +11,9 @@ void ScalarRegistry::Register(duckdb::ExtensionLoader& loader) { RegisterFusionCombMED(loader); RegisterFusionCombMNZ(loader); RegisterFusionCombSUM(loader); + RegisterFlockGetMetrics(loader); + RegisterFlockGetDebugMetrics(loader); + RegisterFlockResetMetrics(loader); } }// namespace flock diff --git a/test/integration/src/integration/conftest.py b/test/integration/src/integration/conftest.py index eff66a20..2e8426af 100644 --- a/test/integration/src/integration/conftest.py +++ b/test/integration/src/integration/conftest.py @@ -3,12 +3,32 @@ import pytest from pathlib import Path from dotenv import load_dotenv -import base64 -import requests from integration.setup_test_db import setup_test_db load_dotenv() +TEST_AUDIO_FILE_PATH = Path(__file__).parent / "tests" / "flock_test_audio.mp3" + + +def get_audio_file_path(): + return str(TEST_AUDIO_FILE_PATH.resolve()) + + +def get_secrets_setup_sql(): + openai_key = os.getenv("OPENAI_API_KEY", "") + ollama_url = os.getenv("API_URL", "http://localhost:11434") + + secrets_sql = [] + + if openai_key: + secrets_sql.append(f"CREATE SECRET (TYPE OPENAI, API_KEY '{openai_key}');") + + if ollama_url: + secrets_sql.append(f"CREATE SECRET (TYPE OLLAMA, API_URL '{ollama_url}');") + + return " ".join(secrets_sql) + + @pytest.fixture(scope="session") def integration_setup(tmp_path_factory): duckdb_cli_path = os.getenv("DUCKDB_CLI_PATH", "duckdb") @@ -25,33 +45,49 @@ def integration_setup(tmp_path_factory): if os.path.exists(test_db_path): os.remove(test_db_path) -def run_cli(duckdb_cli_path, db_path, query): - return subprocess.run( + +def run_cli(duckdb_cli_path, db_path, query, with_secrets=True): + if with_secrets: + secrets_sql = get_secrets_setup_sql() + if secrets_sql: + query = f"{secrets_sql} {query}" + + result = subprocess.run( [duckdb_cli_path, db_path, "-csv", "-c", query], capture_output=True, text=True, check=False, ) + # Filter out the secret creation output (Success, true lines) from stdout + if with_secrets and result.stdout: + lines = result.stdout.split("\n") + # Remove lines that are just "Success" or "true" from secret creation + filtered_lines = [] + skip_count = 0 + for line in lines: + stripped = line.strip() + if skip_count > 0 and stripped in ("true", "false"): + skip_count -= 1 + continue + if stripped == "Success": + skip_count = 1 # Skip the next line (true/false) + continue + filtered_lines.append(line) + result = subprocess.CompletedProcess( + args=result.args, + returncode=result.returncode, + stdout="\n".join(filtered_lines), + stderr=result.stderr, + ) + + return result + + def get_image_data_for_provider(image_url, provider): """ Get image data in the appropriate format based on the provider. - OpenAI uses URLs directly, Ollama uses base64 encoding. + Now all providers support URLs directly - the C++ code handles + downloading and converting to base64 for providers that need it (Ollama). 
""" - if provider == "openai": - return image_url - elif provider == "ollama": - # Fetch the image and convert to base64 - try: - response = requests.get(image_url, timeout=10) - response.raise_for_status() - image_base64 = base64.b64encode(response.content).decode("utf-8") - return image_base64 - except Exception as e: - # Fallback to URL if fetching fails - print( - f"Warning: Failed to fetch image {image_url}: {e}. Using URL instead." - ) - return image_url - else: - return image_url + return image_url diff --git a/test/integration/src/integration/setup_test_db.py b/test/integration/src/integration/setup_test_db.py index 8f9913fe..5aa828ec 100644 --- a/test/integration/src/integration/setup_test_db.py +++ b/test/integration/src/integration/setup_test_db.py @@ -1,24 +1,16 @@ #!/usr/bin/env python3 -""" -Test Database Setup Script for FlockMTL Integration Tests - -This script creates and manages a persistent test database with pre-configured -models, prompts, and test data to reduce setup time during integration testing. -""" import os import subprocess from pathlib import Path from dotenv import load_dotenv -# Load environment variables load_dotenv() -# Configuration DUCKDB_CLI_PATH = os.getenv("DUCKDB_CLI_PATH", "duckdb") + def run_sql_command(db_path: str, sql_command: str, description: str = ""): - """Execute SQL command using DuckDB CLI.""" try: result = subprocess.run( [DUCKDB_CLI_PATH, db_path, "-c", sql_command], @@ -35,35 +27,6 @@ def run_sql_command(db_path: str, sql_command: str, description: str = ""): print(f" Error: {e.stderr}") return None -def create_base_test_secrets(db_path: str): - """Create basic test secrets for LLM functions.""" - secrets = { - "openai": (os.getenv("OPENAI_API_KEY")), - "ollama": (os.getenv("API_URL", "http://localhost:11434")) - } - - def create_openai_secret(secret_key): - return f"""CREATE PERSISTENT SECRET IF NOT EXISTS ( - TYPE OPENAI, - API_KEY '{secret_key}' - );""" - - def create_ollama_secret(secret_key): - return f"""CREATE PERSISTENT SECRET IF NOT EXISTS ( - TYPE OLLAMA, - API_URL '{secret_key}' - );""" - - print("Creating test secrets...") - for secret_name, secret_value in secrets.items(): - if secret_name == "openai": - sql = create_openai_secret(secret_value) - elif secret_name == "ollama": - sql = create_ollama_secret(secret_value) - else: - continue - run_sql_command(db_path, sql, f"Secret: {secret_name}") def setup_test_db(db_path): - - create_base_test_secrets(db_path) + pass diff --git a/test/integration/src/integration/tests/flock_test_audio.mp3 b/test/integration/src/integration/tests/flock_test_audio.mp3 new file mode 100644 index 00000000..74e8cac2 Binary files /dev/null and b/test/integration/src/integration/tests/flock_test_audio.mp3 differ diff --git a/test/integration/src/integration/tests/functions/aggregate/test_llm_first.py b/test/integration/src/integration/tests/functions/aggregate/test_llm_first.py index 6e03b5ee..01fa1ea3 100644 --- a/test/integration/src/integration/tests/functions/aggregate/test_llm_first.py +++ b/test/integration/src/integration/tests/functions/aggregate/test_llm_first.py @@ -1,10 +1,27 @@ import pytest -from integration.conftest import run_cli, get_image_data_for_provider +import json +import csv +from io import StringIO +from integration.conftest import ( + run_cli, + get_image_data_for_provider, + get_audio_file_path, +) +# Expected keywords that should appear when audio is transcribed +# Audio content: "Flock transforms DuckDB into a hybrid database and a semantic AI engine" 
+AUDIO_EXPECTED_KEYWORDS = ["flock", "duckdb", "database", "semantic", "ai", "hybrid"] -@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("llama3.2", "ollama")]) + +@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:1b", "ollama")]) def model_config(request): - """Fixture to test with different models.""" + """Fixture to test with different models for text-only tests.""" + return request.param + + +@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:4b", "ollama")]) +def model_config_image(request): + """Fixture to test with different models for image tests.""" return request.param @@ -17,7 +34,9 @@ def test_llm_first_basic_functionality(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + r = run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) + + assert r.returncode == 0, f"Query failed with error: {create_model_query} {r.stderr}" create_table_query = """ CREATE OR REPLACE TABLE candidates ( @@ -64,7 +83,7 @@ def test_llm_first_with_group_by(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE job_applications ( @@ -122,7 +141,7 @@ def test_llm_first_with_batch_processing(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE investment_options ( @@ -171,7 +190,7 @@ def test_llm_first_with_model_parameters(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE startup_pitches ( @@ -223,7 +242,7 @@ def test_llm_first_multiple_criteria(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE course_options ( @@ -273,7 +292,7 @@ def test_llm_first_empty_table(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE empty_candidates ( @@ -345,7 +364,7 @@ def test_llm_first_error_handling_empty_prompt(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE test_data ( @@ -386,7 +405,7 @@ def test_llm_first_error_handling_missing_arguments(integration_setup, model_con create_model_query = ( f"CREATE MODEL('{test_model_name}', 
'{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) # Test with only 1 argument (should fail since llm_first requires 2) query = ( @@ -412,7 +431,7 @@ def test_llm_first_with_special_characters(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE international_universities ( @@ -461,7 +480,7 @@ def _test_llm_first_performance_large_dataset(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE large_candidate_pool AS @@ -498,16 +517,16 @@ def _test_llm_first_performance_large_dataset(integration_setup, model_config): assert "category" in result.stdout.lower() -def test_llm_first_with_image_integration(integration_setup, model_config): +def test_llm_first_with_image_integration(integration_setup, model_config_image): """Test llm_first with image data integration.""" duckdb_cli_path, db_path = integration_setup - model_name, provider = model_config + model_name, provider = model_config_image test_model_name = f"test-image-first-model_{model_name}" create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE pet_images ( @@ -567,16 +586,16 @@ def test_llm_first_with_image_integration(integration_setup, model_config): assert len(result.stdout.strip().split("\n")) >= 2 -def test_llm_first_image_with_group_by(integration_setup, model_config): +def test_llm_first_image_with_group_by(integration_setup, model_config_image): """Test llm_first with images and GROUP BY clause.""" duckdb_cli_path, db_path = integration_setup - model_name, provider = model_config + model_name, provider = model_config_image test_model_name = f"test-image-group-first_{model_name}" create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE artwork_images ( @@ -641,7 +660,9 @@ def test_llm_first_image_with_group_by(integration_setup, model_config): ) result = run_cli(duckdb_cli_path, db_path, query) - assert result.returncode == 0, f"Query failed with error: {result.stderr}" + assert result.returncode == 0, ( + f"Query failed with error: {result.stdout} {result.stderr}" + ) lines = result.stdout.strip().split("\n") assert len(lines) >= 4, ( f"Expected at least 4 lines (header + 3 styles), got {len(lines)}" @@ -649,16 +670,16 @@ def test_llm_first_image_with_group_by(integration_setup, model_config): assert "most_recent_artwork" in result.stdout.lower() -def test_llm_first_image_batch_processing(integration_setup, model_config): +def test_llm_first_image_batch_processing(integration_setup, model_config_image): """Test llm_first with multiple images in batch processing.""" duckdb_cli_path, 
db_path = integration_setup - model_name, provider = model_config + model_name, provider = model_config_image test_model_name = f"test-image-batch-first_{model_name}" create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE building_images ( @@ -698,15 +719,15 @@ def test_llm_first_image_batch_processing(integration_setup, model_config): 200, 2020), (4, 'Corporate Center', '{corporate_image}', 'Miami', 180, - 2019); \ + 2019); """ run_cli(duckdb_cli_path, db_path, insert_data_query) query = ( """ - SELECT city, - llm_first( - {'model_name': '""" + SELECT city, + llm_first( + {'model_name': '""" + test_model_name + """', 'batch_size': 2}, { @@ -720,7 +741,7 @@ def test_llm_first_image_batch_processing(integration_setup, model_config): ) AS tallest_building FROM building_images GROUP BY city - ORDER BY city; \ + ORDER BY city; """ ) result = run_cli(duckdb_cli_path, db_path, query) @@ -731,3 +752,142 @@ def test_llm_first_image_batch_processing(integration_setup, model_config): f"Expected at least 4 lines (header + 3 cities), got {len(lines)}" ) assert "tallest_building" in result.stdout.lower() + + +def test_llm_first_with_audio_transcription(integration_setup, model_config): + """Test llm_first with audio transcription using OpenAI.""" + duckdb_cli_path, db_path = integration_setup + model_name, provider = model_config + + if provider != "openai": + pytest.skip("Audio transcription is only supported for OpenAI provider") + + test_model_name = f"test-audio-first_{model_name}" + create_model_query = f"CREATE MODEL('{test_model_name}', 'gpt-4o-mini', 'openai');" + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) + + transcription_model_name = f"test-transcription-first_{model_name}" + create_transcription_model_query = f"CREATE MODEL('{transcription_model_name}', 'gpt-4o-mini-transcribe', 'openai');" + run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False) + + # Get audio file path + audio_path = get_audio_file_path() + + create_table_query = """ + CREATE OR REPLACE TABLE audio_descriptions ( + id INTEGER, + audio_path VARCHAR, + topic VARCHAR + ); + """ + run_cli(duckdb_cli_path, db_path, create_table_query) + + # Both rows have the same audio about Flock/DuckDB + insert_data_query = f""" + INSERT INTO audio_descriptions + VALUES + (0, '{audio_path}', 'Database'), + (1, '{audio_path}', 'AI'); + """ + run_cli(duckdb_cli_path, db_path, insert_data_query) + + query = ( + """ + SELECT llm_first( + {'model_name': '""" + + test_model_name + + """'}, + { + 'prompt': 'Based on the audio content, which description best relates to database technology? 
Return the ID number (0 or 1) only.', + 'context_columns': [ + { + 'data': audio_path, + 'type': 'audio', + 'transcription_model': '""" + + transcription_model_name + + """' + }, + {'data': topic} + ] + } + ) AS selected_id + FROM audio_descriptions; + """ + ) + result = run_cli(duckdb_cli_path, db_path, query) + + assert result.returncode == 0, f"Query failed with error: {result.stderr}" + assert "selected_id" in result.stdout.lower() + + # Parse the JSON output to verify the returned tuple + lines = result.stdout.strip().split("\n") + assert len(lines) >= 2, "Expected at least header and one result row" + + # Parse CSV output to get the JSON result + reader = csv.DictReader(StringIO(result.stdout)) + row = next(reader, None) + assert row is not None and "selected_id" in row + + # Parse the JSON result which contains the tuple data + result_json = json.loads(row["selected_id"]) + assert isinstance(result_json, list), ( + f"Expected list of tuples, got: {type(result_json)}" + ) + assert len(result_json) > 0, "Expected at least one tuple in result" + + +def test_llm_first_audio_ollama_error(integration_setup): + """Test that Ollama provider throws error for audio transcription in llm_first.""" + duckdb_cli_path, db_path = integration_setup + + test_model_name = "test-ollama-first-audio" + create_model_query = ( + "CREATE MODEL('test-ollama-first-audio', 'gemma3:1b', 'ollama');" + ) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) + + transcription_model_name = "test-ollama-first-transcription" + create_transcription_model_query = ( + "CREATE MODEL('test-ollama-first-transcription', 'gemma3:1b', 'ollama');" + ) + run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False) + + create_table_query = """ + CREATE OR REPLACE TABLE test_audio ( + id INTEGER, + audio_url VARCHAR + ); + """ + run_cli(duckdb_cli_path, db_path, create_table_query) + + insert_data_query = """ + INSERT INTO test_audio VALUES + (1, 'https://example.com/audio1.mp3'), + (2, 'https://example.com/audio2.mp3'); + """ + run_cli(duckdb_cli_path, db_path, insert_data_query) + + query = """ + SELECT llm_first( + {'model_name': 'test-ollama-first-audio'}, + { + 'prompt': 'Select the best audio. 
Return ID only.', + 'context_columns': [ + { + 'data': audio_url, + 'type': 'audio', + 'transcription_model': 'test-ollama-first-transcription' + } + ] + } + ) AS result + FROM test_audio; + """ + result = run_cli(duckdb_cli_path, db_path, query) + + assert result.returncode != 0 + assert ( + "ollama" in result.stderr.lower() + or "transcription" in result.stderr.lower() + or "not supported" in result.stderr.lower() + ) diff --git a/test/integration/src/integration/tests/functions/aggregate/test_llm_last.py b/test/integration/src/integration/tests/functions/aggregate/test_llm_last.py index bb0892ab..8485f6a0 100644 --- a/test/integration/src/integration/tests/functions/aggregate/test_llm_last.py +++ b/test/integration/src/integration/tests/functions/aggregate/test_llm_last.py @@ -1,10 +1,27 @@ import pytest -from integration.conftest import run_cli, get_image_data_for_provider +import json +import csv +from io import StringIO +from integration.conftest import ( + run_cli, + get_image_data_for_provider, + get_audio_file_path, +) +# Expected keywords that should appear when audio is transcribed +# Audio content: "Flock transforms DuckDB into a hybrid database and a semantic AI engine" +AUDIO_EXPECTED_KEYWORDS = ["flock", "duckdb", "database", "semantic", "ai", "hybrid"] -@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("llama3.2", "ollama")]) + +@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:1b", "ollama")]) def model_config(request): - """Fixture to test with different models.""" + """Fixture to test with different models for text-only tests.""" + return request.param + + +@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:4b", "ollama")]) +def model_config_image(request): + """Fixture to test with different models for image tests.""" return request.param @@ -17,7 +34,7 @@ def test_llm_last_basic_functionality(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE products ( @@ -65,7 +82,7 @@ def test_llm_last_with_group_by(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE restaurant_reviews ( @@ -121,7 +138,7 @@ def test_llm_last_with_batch_processing(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE service_providers ( @@ -172,7 +189,7 @@ def test_llm_last_with_model_parameters(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE movie_reviews ( @@ -224,7 +241,7 @@ def test_llm_last_multiple_criteria(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, 
create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE housing_options ( @@ -275,7 +292,7 @@ def test_llm_last_empty_table(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE empty_products ( @@ -347,7 +364,7 @@ def test_llm_last_error_handling_empty_prompt(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE test_data ( @@ -388,7 +405,7 @@ def test_llm_last_error_handling_missing_arguments(integration_setup, model_conf create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) # Test with only 1 argument (should fail since llm_last requires 2) query = ( @@ -414,7 +431,7 @@ def test_llm_last_with_special_characters(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE travel_destinations ( @@ -462,7 +479,7 @@ def _test_llm_last_performance_large_dataset(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE large_product_pool AS @@ -499,15 +516,15 @@ def _test_llm_last_performance_large_dataset(integration_setup, model_config): assert "category" in result.stdout.lower() -def test_llm_last_with_image_integration(integration_setup, model_config): +def test_llm_last_with_image_integration(integration_setup, model_config_image): """Test llm_last with image data integration.""" duckdb_cli_path, db_path = integration_setup - model_name, provider = model_config + model_name, provider = model_config_image create_model_query = ( f"CREATE MODEL('test-image-last-model', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE flower_images ( @@ -564,15 +581,15 @@ def test_llm_last_with_image_integration(integration_setup, model_config): assert len(result.stdout.strip().split("\n")) >= 2 -def test_llm_last_image_with_group_by(integration_setup, model_config): +def test_llm_last_image_with_group_by(integration_setup, model_config_image): """Test llm_last with images and GROUP BY clause.""" duckdb_cli_path, db_path = integration_setup - model_name, provider = model_config + model_name, provider = model_config_image create_model_query = ( f"CREATE MODEL('test-image-group-last', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, 
db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE car_images ( @@ -639,15 +656,15 @@ def test_llm_last_image_with_group_by(integration_setup, model_config): assert "oldest_car" in result.stdout.lower() -def test_llm_last_image_batch_processing(integration_setup, model_config): +def test_llm_last_image_batch_processing(integration_setup, model_config_image): """Test llm_last with multiple images in batch processing.""" duckdb_cli_path, db_path = integration_setup - model_name, provider = model_config + model_name, provider = model_config_image create_model_query = ( f"CREATE MODEL('test-image-batch-last', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE restaurant_images ( @@ -713,3 +730,147 @@ def test_llm_last_image_batch_processing(integration_setup, model_config): f"Expected at least 4 lines (header + 3 cuisines), got {len(lines)}" ) assert "lowest_rated_restaurant" in result.stdout.lower() + + +def test_llm_last_with_audio_transcription(integration_setup, model_config): + """Test llm_last with audio transcription using OpenAI. + + The audio content says: 'Flock transforms DuckDB into a hybrid database and a semantic AI engine' + This test verifies that the audio is correctly transcribed and the LLM can reason about the content. + """ + duckdb_cli_path, db_path = integration_setup + model_name, provider = model_config + + if provider != "openai": + pytest.skip("Audio transcription is only supported for OpenAI provider") + + test_model_name = f"test-audio-last_{model_name}" + create_model_query = f"CREATE MODEL('{test_model_name}', 'gpt-4o-mini', 'openai');" + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) + + transcription_model_name = f"test-transcription-last_{model_name}" + create_transcription_model_query = f"CREATE MODEL('{transcription_model_name}', 'gpt-4o-mini-transcribe', 'openai');" + run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False) + + # Get audio file path + audio_path = get_audio_file_path() + + # Create table with topics - one about Flock/DuckDB (audio content), one unrelated + create_table_query = """ + CREATE OR REPLACE TABLE audio_topics ( + id INTEGER, + topic VARCHAR, + audio_path VARCHAR + ); + """ + run_cli(duckdb_cli_path, db_path, create_table_query) + + # Row 1 has no real audio (empty), Row 2 has the actual Flock audio + insert_data_query = f""" + INSERT INTO audio_topics + VALUES + (1, 'Weather Forecast', '{audio_path}'), + (2, 'Database Technology', '{audio_path}'); + """ + run_cli(duckdb_cli_path, db_path, insert_data_query) + + # Ask which topic is about databases based on the audio + query = ( + """ + SELECT llm_last( + {'model_name': '""" + + test_model_name + + """'}, + { + 'prompt': 'Based on the topic and audio content (if available), which entry is about databases or Flock? 
Return the topic name.', + 'context_columns': [ + {'data': topic, 'type': 'text'}, + { + 'data': audio_path, + 'type': 'audio', + 'transcription_model': '""" + + transcription_model_name + + """' + } + ] + } + ) AS selected_topic + FROM audio_topics; + """ + ) + result = run_cli(duckdb_cli_path, db_path, query) + + assert result.returncode == 0, f"Query failed with error: {result.stderr}" + + # Parse the JSON output to verify the returned tuple + lines = result.stdout.strip().split("\n") + assert len(lines) >= 2, "Expected at least header and one result row" + + # Parse CSV output to get the JSON result + reader = csv.DictReader(StringIO(result.stdout)) + row = next(reader, None) + assert row is not None and "selected_topic" in row + + # Parse the JSON result which contains the tuple data + result_json = json.loads(row["selected_topic"]) + assert isinstance(result_json, list), ( + f"Expected list of tuples, got: {type(result_json)}" + ) + assert len(result_json) > 0, "Expected at least one tuple in result" + + +def test_llm_last_audio_ollama_error(integration_setup): + """Test that Ollama provider throws error for audio transcription in llm_last.""" + duckdb_cli_path, db_path = integration_setup + + test_model_name = "test-ollama-last-audio" + create_model_query = ( + "CREATE MODEL('test-ollama-last-audio', 'gemma3:1b', 'ollama');" + ) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) + + transcription_model_name = "test-ollama-last-transcription" + create_transcription_model_query = ( + "CREATE MODEL('test-ollama-last-transcription', 'gemma3:1b', 'ollama');" + ) + run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False) + + create_table_query = """ + CREATE OR REPLACE TABLE test_audio ( + id INTEGER, + audio_url VARCHAR + ); + """ + run_cli(duckdb_cli_path, db_path, create_table_query) + + insert_data_query = """ + INSERT INTO test_audio VALUES + (1, 'https://example.com/audio1.mp3'), + (2, 'https://example.com/audio2.mp3'); + """ + run_cli(duckdb_cli_path, db_path, insert_data_query) + + query = """ + SELECT llm_last( + {'model_name': 'test-ollama-last-audio'}, + { + 'prompt': 'Select the worst audio. 
Return ID only.', + 'context_columns': [ + { + 'data': audio_url, + 'type': 'audio', + 'transcription_model': 'test-ollama-last-transcription' + } + ] + } + ) AS result + FROM test_audio; + """ + result = run_cli(duckdb_cli_path, db_path, query) + + assert result.returncode != 0 + assert ( + "ollama" in result.stderr.lower() + or "transcription" in result.stderr.lower() + or "not supported" in result.stderr.lower() + ) diff --git a/test/integration/src/integration/tests/functions/aggregate/test_llm_reduce.py b/test/integration/src/integration/tests/functions/aggregate/test_llm_reduce.py index 435d34cf..e72fe199 100644 --- a/test/integration/src/integration/tests/functions/aggregate/test_llm_reduce.py +++ b/test/integration/src/integration/tests/functions/aggregate/test_llm_reduce.py @@ -1,10 +1,24 @@ import pytest -from integration.conftest import run_cli, get_image_data_for_provider +from integration.conftest import ( + run_cli, + get_image_data_for_provider, + get_audio_file_path, +) +# Expected keywords that should appear when audio is transcribed +# Audio content: "Flock transforms DuckDB into a hybrid database and a semantic AI engine" +AUDIO_EXPECTED_KEYWORDS = ["flock", "duckdb", "database", "semantic", "ai", "hybrid"] -@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("llama3.2", "ollama")]) + +@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:1b", "ollama")]) def model_config(request): - """Fixture to test with different models.""" + """Fixture to test with different models for text-only tests.""" + return request.param + + +@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:4b", "ollama")]) +def model_config_image(request): + """Fixture to test with different models for image tests.""" return request.param @@ -17,7 +31,7 @@ def test_llm_reduce_basic_functionality(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE products ( @@ -44,7 +58,7 @@ def test_llm_reduce_basic_functionality(integration_setup, model_config): {'model_name': '""" + test_model_name + """'}, - {'prompt': 'Summarize the following product descriptions into a single comprehensive summary', 'context_columns': [{'data': description}]} + {'prompt': 'Summarize these products in exactly 5 words', 'context_columns': [{'data': description}]} ) AS product_summary FROM products; \ """ @@ -67,7 +81,7 @@ def test_llm_reduce_with_group_by(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE product_reviews ( @@ -98,7 +112,7 @@ def test_llm_reduce_with_group_by(integration_setup, model_config): {'model_name': '""" + test_model_name + """'}, - {'prompt': 'Create a brief summary of these product reviews', 'context_columns': [{'data': review_text}]} + {'prompt': 'Summarize in 3 words', 'context_columns': [{'data': review_text}]} ) AS category_summary FROM product_reviews GROUP BY product_category @@ -127,7 +141,7 @@ def test_llm_reduce_multiple_columns(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, 
db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE employee_feedback ( @@ -155,7 +169,7 @@ def test_llm_reduce_multiple_columns(integration_setup, model_config): {'model_name': '""" + test_model_name + """'}, - {'prompt': 'Summarize the team feedback and overall performance', 'context_columns': [{'data': employee_name}, {'data': feedback}, {'data': rating::VARCHAR}]} + {'prompt': 'Rate team in one word', 'context_columns': [{'data': employee_name}, {'data': feedback}, {'data': rating::VARCHAR}]} ) AS team_summary FROM employee_feedback GROUP BY department; \ @@ -177,7 +191,7 @@ def test_llm_reduce_with_batch_processing(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE articles ( @@ -204,7 +218,7 @@ def test_llm_reduce_with_batch_processing(integration_setup, model_config): {'model_name': '""" + test_model_name + """', 'batch_size': 2}, - {'prompt': 'Create a comprehensive summary of these articles', 'context_columns': [{'data': title}, {'data': content}]} + {'prompt': 'List topics in 5 words max', 'context_columns': [{'data': title}, {'data': content}]} ) AS articles_summary FROM articles; \ """ @@ -226,7 +240,7 @@ def test_llm_reduce_with_model_parameters(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE news_items ( @@ -251,7 +265,7 @@ def test_llm_reduce_with_model_parameters(integration_setup, model_config): + test_model_name + """', 'tuple_format': 'Markdown', 'model_parameters': '{"temperature": 0.1}'}, - {'prompt': 'Provide a concise summary of these news items', 'context_columns': [{'data': headline}, {'data': summary}]} + {'prompt': 'Summarize in 3 words', 'context_columns': [{'data': headline}, {'data': summary}]} ) AS news_summary FROM news_items; \ """ @@ -271,7 +285,7 @@ def test_llm_reduce_empty_table(integration_setup, model_config): create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE empty_data ( @@ -343,7 +357,7 @@ def test_llm_reduce_error_handling_empty_prompt(integration_setup, model_config) create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) create_table_query = """ CREATE OR REPLACE TABLE test_data ( @@ -384,7 +398,7 @@ def test_llm_reduce_error_handling_missing_arguments(integration_setup, model_co create_model_query = ( f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" ) - run_cli(duckdb_cli_path, db_path, create_model_query) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) # Test with only 1 argument (should fail since llm_reduce requires 2) query = ( @@ -410,7 +424,7 @@ def 
test_llm_reduce_with_special_characters(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE international_content (
@@ -434,7 +448,7 @@ def test_llm_reduce_with_special_characters(integration_setup, model_config):
         {'model_name': '"""
         + test_model_name
         + """'},
-        {'prompt': 'Summarize these international text samples', 'context_columns': [{'data': text}]}
+        {'prompt': 'Describe in 3 words', 'context_columns': [{'data': text}]}
     ) AS summary
     FROM international_content; \
     """
@@ -454,7 +468,7 @@ def test_llm_reduce_with_structured_output(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE structured_data (
@@ -519,7 +533,7 @@ def _test_llm_reduce_performance_large_dataset(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE large_dataset AS
@@ -538,7 +552,7 @@ def _test_llm_reduce_performance_large_dataset(integration_setup, model_config):
         {'model_name': '"""
         + test_model_name
         + """', 'batch_size': 10},
-        {'prompt': 'Create a comprehensive summary of all items in this category', 'context_columns': [{'data': content}]}
+        {'prompt': 'Summarize in 3 words', 'context_columns': [{'data': content}]}
     ) AS category_summary
     FROM large_dataset
     GROUP BY category
@@ -555,16 +569,16 @@ def _test_llm_reduce_performance_large_dataset(integration_setup, model_config):
     assert "category" in result.stdout.lower()
 
 
-def test_llm_reduce_with_image_integration(integration_setup, model_config):
+def test_llm_reduce_with_image_integration(integration_setup, model_config_image):
     """Test llm_reduce with image data integration."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-reduce-model_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE animal_images (
@@ -605,7 +619,7 @@ def test_llm_reduce_with_image_integration(integration_setup, model_config):
         + test_model_name
         + """'},
         {
-            'prompt': 'Summarize the next data in json do not miss any data',
+            'prompt': 'List animal names only',
             'context_columns': [
                 {'data': name},
                 {'data': image, 'type': 'image'}
@@ -623,16 +637,16 @@ def test_llm_reduce_with_image_integration(integration_setup, model_config):
     assert len(result.stdout.strip().split("\n")) >= 2
 
 
-def test_llm_reduce_image_with_group_by(integration_setup, model_config):
+def test_llm_reduce_image_with_group_by(integration_setup, model_config_image):
     """Test llm_reduce with images and GROUP BY clause."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-group-reduce_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE product_images (
@@ -685,7 +699,7 @@ def test_llm_reduce_image_with_group_by(integration_setup, model_config):
         + test_model_name
         + """'},
         {
-            'prompt': 'Analyze these product images in this category and provide a summary of their design characteristics and market positioning.',
+            'prompt': 'List product names in 5 words max',
             'context_columns': [
                 {'data': product_name},
                 {'data': image_url, 'type': 'image'},
@@ -708,16 +722,16 @@ def test_llm_reduce_image_with_group_by(integration_setup, model_config):
     assert "category_analysis" in result.stdout.lower()
 
 
-def test_llm_reduce_image_batch_processing(integration_setup, model_config):
+def test_llm_reduce_image_batch_processing(integration_setup, model_config_image):
     """Test llm_reduce with multiple images in batch processing."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-batch-reduce_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE landscape_photos (
@@ -768,7 +782,7 @@ def test_llm_reduce_image_batch_processing(integration_setup, model_config):
         + test_model_name
         + """', 'batch_size': 3},
         {
-            'prompt': 'Analyze these landscape photographs and create a comprehensive summary of the natural environments, weather conditions, and seasonal characteristics shown.',
+            'prompt': 'List locations in 5 words max',
             'context_columns': [
                 {'data': location},
                 {'data': image_url, 'type': 'image'},
@@ -785,3 +799,138 @@ def test_llm_reduce_image_batch_processing(integration_setup, model_config):
     assert result.returncode == 0, f"Query failed with error: {result.stderr}"
     assert "landscape_summary" in result.stdout.lower()
     assert len(result.stdout.strip().split("\n")) >= 2
+
+
+def test_llm_reduce_with_audio_transcription(integration_setup, model_config):
+    """Test llm_reduce with audio transcription using OpenAI.
+
+    The audio content says: 'Flock transforms DuckDB into a hybrid database and a semantic AI engine'
+    This test verifies that the audio is correctly transcribed and reduced into a summary.
+    """
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    if provider != "openai":
+        pytest.skip("Audio transcription is only supported for OpenAI provider")
+
+    test_model_name = f"test-audio-reduce_{model_name}"
+    create_model_query = f"CREATE MODEL('{test_model_name}', 'gpt-4o-mini', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    transcription_model_name = f"test-transcription-reduce_{model_name}"
+    create_transcription_model_query = f"CREATE MODEL('{transcription_model_name}', 'gpt-4o-mini-transcribe', 'openai');"
+    run_cli(
+        duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False
+    )
+
+    # Get audio file path
+    audio_path = get_audio_file_path()
+
+    # Create table with different topics and the same Flock audio
+    create_table_query = """
+    CREATE OR REPLACE TABLE audio_content (
+        id INTEGER,
+        topic VARCHAR,
+        audio_path VARCHAR
+    );
+    """
+    run_cli(duckdb_cli_path, db_path, create_table_query)
+
+    insert_data_query = f"""
+    INSERT INTO audio_content
+    VALUES
+        (1, 'Technology Overview', '{audio_path}'),
+        (2, 'Product Demo', '{audio_path}');
+    """
+    run_cli(duckdb_cli_path, db_path, insert_data_query)
+
+    query = (
+        """
+    SELECT llm_reduce(
+        {'model_name': '"""
+        + test_model_name
+        + """'},
+        {
+            'prompt': 'What product is discussed? Answer in 5 words max.',
+            'context_columns': [
+                {'data': topic, 'type': 'text'},
+                {
+                    'data': audio_path,
+                    'type': 'audio',
+                    'transcription_model': '"""
+        + transcription_model_name
+        + """'
+                }
+            ]
+        }
+    ) AS audio_summary
+    FROM audio_content;
+    """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+    # The summary should mention Flock, DuckDB, database, or related terms from the audio
+    result_lower = result.stdout.lower()
+    assert any(kw in result_lower for kw in AUDIO_EXPECTED_KEYWORDS), (
+        f"Expected summary to contain keywords from audio content {AUDIO_EXPECTED_KEYWORDS}. Got: {result.stdout}"
+    )
+
+
+def test_llm_reduce_audio_ollama_error(integration_setup):
+    """Test that Ollama provider throws error for audio transcription in llm_reduce."""
+    duckdb_cli_path, db_path = integration_setup
+
+    test_model_name = "test-ollama-reduce-audio"
+    create_model_query = (
+        "CREATE MODEL('test-ollama-reduce-audio', 'gemma3:1b', 'ollama');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    transcription_model_name = "test-ollama-reduce-transcription"
+    create_transcription_model_query = (
+        "CREATE MODEL('test-ollama-reduce-transcription', 'gemma3:1b', 'ollama');"
+    )
+    run_cli(
+        duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False
+    )
+
+    create_table_query = """
+    CREATE OR REPLACE TABLE test_audio (
+        id INTEGER,
+        audio_url VARCHAR
+    );
+    """
+    run_cli(duckdb_cli_path, db_path, create_table_query)
+
+    insert_data_query = """
+    INSERT INTO test_audio VALUES
+        (1, 'https://example.com/audio1.mp3'),
+        (2, 'https://example.com/audio2.mp3');
+    """
+    run_cli(duckdb_cli_path, db_path, insert_data_query)
+
+    query = """
+    SELECT llm_reduce(
+        {'model_name': 'test-ollama-reduce-audio'},
+        {
+            'prompt': 'Summarize this audio',
+            'context_columns': [
+                {
+                    'data': audio_url,
+                    'type': 'audio',
+                    'transcription_model': 'test-ollama-reduce-transcription'
+                }
+            ]
+        }
+    ) AS result
+    FROM test_audio;
+    """
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode != 0
+    assert (
+        "ollama" in result.stderr.lower()
+        or "transcription" in result.stderr.lower()
+        or "not supported" in result.stderr.lower()
+    )
diff --git a/test/integration/src/integration/tests/functions/aggregate/test_llm_rerank.py b/test/integration/src/integration/tests/functions/aggregate/test_llm_rerank.py
index a60d470c..6aec9a99 100644
--- a/test/integration/src/integration/tests/functions/aggregate/test_llm_rerank.py
+++ b/test/integration/src/integration/tests/functions/aggregate/test_llm_rerank.py
@@ -1,10 +1,27 @@
 import pytest
-from integration.conftest import run_cli, get_image_data_for_provider
+import json
+import csv
+from io import StringIO
+from integration.conftest import (
+    run_cli,
+    get_image_data_for_provider,
+    get_audio_file_path,
+)
 
+# Expected keywords that should appear when audio is transcribed
+# Audio content: "Flock transforms DuckDB into a hybrid database and a semantic AI engine"
+AUDIO_EXPECTED_KEYWORDS = ["flock", "duckdb", "database", "semantic", "ai", "hybrid"]
 
-@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("llama3.2", "ollama")])
+
+@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:1b", "ollama")])
 def model_config(request):
-    """Fixture to test with different models."""
+    """Fixture to test with different models for text-only tests."""
+    return request.param
+
+
+@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:4b", "ollama")])
+def model_config_image(request):
+    """Fixture to test with different models for image tests."""
     return request.param
@@ -17,7 +34,7 @@ def test_llm_rerank_basic_functionality(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE search_results (
@@ -67,7 +84,7 @@ def test_llm_rerank_with_group_by(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE product_listings (
@@ -129,7 +146,7 @@ def test_llm_rerank_with_batch_processing(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE job_candidates (
@@ -182,7 +199,7 @@ def test_llm_rerank_with_model_parameters(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE restaurant_options (
@@ -232,7 +249,7 @@ def test_llm_rerank_multiple_criteria(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE investment_funds (
@@ -282,7 +299,7 @@ def test_llm_rerank_empty_table(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE empty_items (
@@ -356,7 +373,7 @@ def test_llm_rerank_error_handling_empty_prompt(integration_setup, model_config)
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE test_data (
@@ -397,7 +414,7 @@ def test_llm_rerank_error_handling_missing_arguments(integration_setup, model_co
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     # Test with only 2 arguments (should fail since llm_rerank requires 3)
     query = (
@@ -424,7 +441,7 @@ def test_llm_rerank_with_special_characters(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE international_dishes (
@@ -470,7 +487,7 @@ def _test_llm_rerank_performance_large_dataset(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE large_search_results AS
@@ -508,16 +525,16 @@ def _test_llm_rerank_performance_large_dataset(integration_setup, model_config):
     assert "category" in result.stdout.lower()
 
 
-def test_llm_rerank_with_image_integration(integration_setup, model_config):
+def test_llm_rerank_with_image_integration(integration_setup, model_config_image):
     """Test llm_rerank with image data integration."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-rerank-model_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE fashion_images (
@@ -583,16 +600,16 @@ def test_llm_rerank_with_image_integration(integration_setup, model_config):
     assert len(result.stdout.strip().split("\n")) >= 2
 
 
-def test_llm_rerank_image_with_group_by(integration_setup, model_config):
+def test_llm_rerank_image_with_group_by(integration_setup, model_config_image):
     """Test llm_rerank with images and GROUP BY clause."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-group-rerank_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE interior_images (
@@ -665,16 +682,16 @@ def test_llm_rerank_image_with_group_by(integration_setup, model_config):
     assert "ranked_room_designs" in result.stdout.lower()
 
 
-def test_llm_rerank_image_batch_processing(integration_setup, model_config):
+def test_llm_rerank_image_batch_processing(integration_setup, model_config_image):
     """Test llm_rerank with multiple images in batch processing."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-batch-rerank_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE travel_destination_images (
@@ -744,3 +761,148 @@ def test_llm_rerank_image_batch_processing(integration_setup, model_config):
         f"Expected at least 4 lines (header + 3 countries), got {len(lines)}"
     )
     assert "ranked_destinations" in result.stdout.lower()
+
+
+def test_llm_rerank_with_audio_transcription(integration_setup, model_config):
+    """Test llm_rerank with audio transcription using OpenAI.
+
+    The audio content says: 'Flock transforms DuckDB into a hybrid database and a semantic AI engine'
+    This test verifies that the audio is correctly transcribed and used for reranking.
+    """
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    if provider != "openai":
+        pytest.skip("Audio transcription is only supported for OpenAI provider")
+
+    test_model_name = f"test-audio-rerank_{model_name}"
+    create_model_query = f"CREATE MODEL('{test_model_name}', 'gpt-4o-mini', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    transcription_model_name = f"test-transcription-rerank_{model_name}"
+    create_transcription_model_query = f"CREATE MODEL('{transcription_model_name}', 'gpt-4o-mini-transcribe', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False)
+
+    # Get audio file path
+    audio_path = get_audio_file_path()
+
+    # Create a table of topics; every row shares the same audio file
+    create_table_query = """
+    CREATE OR REPLACE TABLE audio_topics (
+        id INTEGER,
+        topic VARCHAR,
+        audio_path VARCHAR
+    );
+    """
+    run_cli(duckdb_cli_path, db_path, create_table_query)
+
+    # The audio is about Flock/DuckDB, so only the Database Technology topic matches it
+    insert_data_query = f"""
+    INSERT INTO audio_topics
+    VALUES
+        (1, 'Weather Updates', '{audio_path}'),
+        (2, 'Database Technology', '{audio_path}'),
+        (3, 'Sports News', '{audio_path}');
+    """
+    run_cli(duckdb_cli_path, db_path, insert_data_query)
+
+    # Ask to rank by relevance to databases/Flock; the Database Technology row should rank first
+    query = (
+        """
+    SELECT llm_rerank(
+        {'model_name': '"""
+        + test_model_name
+        + """'},
+        {
+            'prompt': 'Rank these entries by relevance to database technology and Flock. Return results with the most relevant first.',
+            'context_columns': [
+                {'data': topic, 'type': 'text'},
+                {
+                    'data': audio_path,
+                    'type': 'audio',
+                    'transcription_model': '"""
+        + transcription_model_name
+        + """'
+                }
+            ]
+        }
+    ) AS ranked_topics
+    FROM audio_topics;
+    """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+
+    # Parse the JSON output to verify the returned tuples
+    lines = result.stdout.strip().split("\n")
+    assert len(lines) >= 2, "Expected at least header and one result row"
+
+    # Parse CSV output to get the JSON result
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None and "ranked_topics" in row
+
+    # Parse the JSON result which contains the reranked tuples
+    result_json = json.loads(row["ranked_topics"])
+    assert isinstance(result_json, list), (
+        f"Expected list of tuples, got: {type(result_json)}"
+    )
+    assert len(result_json) > 0, "Expected at least one tuple in result"
+
+
+def test_llm_rerank_audio_ollama_error(integration_setup):
+    """Test that Ollama provider throws error for audio transcription in llm_rerank."""
+    duckdb_cli_path, db_path = integration_setup
+
+    test_model_name = "test-ollama-rerank-audio"
+    create_model_query = (
+        "CREATE MODEL('test-ollama-rerank-audio', 'gemma3:1b', 'ollama');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    transcription_model_name = "test-ollama-rerank-transcription"
+    create_transcription_model_query = (
+        "CREATE MODEL('test-ollama-rerank-transcription', 'gemma3:1b', 'ollama');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False)
+
+    create_table_query = """
+    CREATE OR REPLACE TABLE test_audio (
+        id INTEGER,
+        audio_url VARCHAR
+    );
+    """
+    run_cli(duckdb_cli_path, db_path, create_table_query)
+
+    insert_data_query = """
+    INSERT INTO test_audio VALUES
+        (1, 'https://example.com/audio1.mp3'),
+        (2, 'https://example.com/audio2.mp3');
+    """
+    run_cli(duckdb_cli_path, db_path, insert_data_query)
+
+    query = """
+    SELECT llm_rerank(
+        {'model_name': 'test-ollama-rerank-audio'},
+        {
+            'prompt': 'Rank these audio files',
+            'context_columns': [
+                {
+                    'data': audio_url,
+                    'type': 'audio',
+                    'transcription_model': 'test-ollama-rerank-transcription'
+                }
+            ]
+        }
+    ) AS result
+    FROM test_audio;
+    """
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode != 0
+    assert (
+        "ollama" in result.stderr.lower()
+        or "transcription" in result.stderr.lower()
+        or "not supported" in result.stderr.lower()
+    )
diff --git a/test/integration/src/integration/tests/functions/scalar/test_llm_complete.py b/test/integration/src/integration/tests/functions/scalar/test_llm_complete.py
index 8afd3df4..c0107836 100644
--- a/test/integration/src/integration/tests/functions/scalar/test_llm_complete.py
+++ b/test/integration/src/integration/tests/functions/scalar/test_llm_complete.py
@@ -1,10 +1,24 @@
 import pytest
-from integration.conftest import run_cli, get_image_data_for_provider
+from integration.conftest import (
+    run_cli,
+    get_image_data_for_provider,
+    get_audio_file_path,
+)
 
+# Expected keywords that should appear when audio is transcribed
+# Audio content: "Flock transforms DuckDB into a hybrid database and a semantic AI engine"
+AUDIO_EXPECTED_KEYWORDS = ["flock", "duckdb", "database", "semantic", "ai", "hybrid"]
 
-@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("llama3.2", "ollama")])
+
+@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:1b", "ollama")])
 def model_config(request):
-    """Fixture to test with different models."""
+    """Fixture to test with different models for text-only tests."""
+    return request.param
+
+
+@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:4b", "ollama")])
+def model_config_image(request):
+    """Fixture to test with different models for image tests."""
     return request.param
@@ -16,7 +30,7 @@ def test_llm_complete_basic_functionality(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     query = (
         """
@@ -42,7 +56,7 @@ def test_llm_complete_with_input_columns(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE countries (
@@ -89,7 +103,7 @@ def test_llm_complete_batch_processing(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE product_reviews (
@@ -153,7 +167,7 @@ def test_llm_complete_error_handling_empty_prompt(integration_setup, model_confi
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     query = (
         """
@@ -178,7 +192,7 @@ def test_llm_complete_with_special_characters(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE special_text (
@@ -223,7 +237,7 @@ def test_llm_complete_with_model_params(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     query = (
         """
@@ -251,7 +265,7 @@ def test_llm_complete_with_structured_output_without_table(
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     response_format = ""
     if provider == "openai":
@@ -323,7 +337,7 @@ def test_llm_complete_with_structured_output_with_table(
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE countries (
@@ -411,7 +425,7 @@ def _llm_complete_performance_large_dataset(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE large_dataset AS
@@ -444,16 +458,16 @@ def _llm_complete_performance_large_dataset(integration_setup, model_config):
     )
 
 
-def test_llm_complete_with_image_integration(integration_setup, model_config):
+def test_llm_complete_with_image_integration(integration_setup, model_config_image):
     """Test llm_complete with image data integration."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-model_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE animal_images (
@@ -514,16 +528,16 @@ def test_llm_complete_with_image_integration(integration_setup, model_config):
     assert len(result.stdout.strip().split("\n")) >= 2
 
 
-def test_llm_complete_image_batch_processing(integration_setup, model_config):
+def test_llm_complete_image_batch_processing(integration_setup, model_config_image):
     """Test llm_complete with multiple images in batch processing."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-batch-model_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE product_images (
@@ -589,16 +603,16 @@ def test_llm_complete_image_batch_processing(integration_setup, model_config):
     assert "product_analysis" in result.stdout.lower()
 
 
-def test_llm_complete_image_with_text_context(integration_setup, model_config):
+def test_llm_complete_image_with_text_context(integration_setup, model_config_image):
     """Test llm_complete with both image and text context."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-text-model_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE landscape_photos (
@@ -662,3 +676,270 @@ def test_llm_complete_image_with_text_context(integration_setup, model_config):
     assert result.returncode == 0, f"Query failed with error: {result.stderr}"
     assert "atmosphere_description" in result.stdout.lower()
     assert len(result.stdout.strip().split("\n")) >= 2
+
+
+def test_llm_complete_with_audio_transcription(integration_setup, model_config):
+    """Test llm_complete with audio transcription using OpenAI."""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    # Skip if not OpenAI (only OpenAI supports transcription currently)
+    if provider != "openai":
+        pytest.skip("Audio transcription is only supported for OpenAI provider")
+
+    # Create main completion model
+    test_model_name = f"test-audio-complete_{model_name}"
+    create_model_query = f"CREATE MODEL('{test_model_name}', 'gpt-4o-mini', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Create transcription model
+    transcription_model_name = f"test-transcription-model_{model_name}"
+    create_transcription_model_query = f"CREATE MODEL('{transcription_model_name}', 'gpt-4o-mini-transcribe', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False)
+
+    # Get audio file path
+    audio_path = get_audio_file_path()
+
+    # Test with audio file path using VALUES
+    query = (
+        """
+    SELECT llm_complete(
+        {'model_name': '"""
+        + test_model_name
+        + """'},
+        {
+            'prompt': 'What product or technology is mentioned in this audio? Provide a brief answer.',
+            'context_columns': [
+                {
+                    'data': audio_path,
+                    'type': 'audio',
+                    'transcription_model': '"""
+        + transcription_model_name
+        + """'
+                }
+            ]
+        }
+    ) AS audio_summary
+    FROM VALUES ('"""
+        + audio_path
+        + """') AS tbl(audio_path);
+    """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+    assert "audio_summary" in result.stdout.lower()
+    # Verify the response is based on the audio content
+    output_lower = result.stdout.lower()
+    assert any(keyword in output_lower for keyword in AUDIO_EXPECTED_KEYWORDS), (
+        f"Expected response to contain at least one of {AUDIO_EXPECTED_KEYWORDS}, got: {result.stdout}"
+    )
+
+
+def test_llm_complete_with_audio_and_text(integration_setup, model_config):
+    """Test llm_complete with both audio and text context columns."""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    if provider != "openai":
+        pytest.skip("Audio transcription is only supported for OpenAI provider")
+
+    test_model_name = f"test-audio-text_{model_name}"
+    create_model_query = f"CREATE MODEL('{test_model_name}', 'gpt-4o-mini', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    transcription_model_name = f"test-transcription_{model_name}"
+    create_transcription_model_query = f"CREATE MODEL('{transcription_model_name}', 'gpt-4o-mini-transcribe', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False)
+
+    # Get audio file path
+    audio_path = get_audio_file_path()
+
+    # Test with audio file path using VALUES - combining text context with audio
+    query = (
+        """
+    SELECT llm_complete(
+        {'model_name': '"""
+        + test_model_name
+        + """'},
+        {
+            'prompt': 'Given the category {category}, describe how the technology mentioned in the audio fits into this category.',
+            'context_columns': [
+                {'data': category_name, 'name': 'category'},
+                {
+                    'data': audio_path,
+                    'type': 'audio',
+                    'transcription_model': '"""
+        + transcription_model_name
+        + """'
+                }
+            ]
+        }
+    ) AS tech_description
+    FROM VALUES ('Database Technology', '"""
+        + audio_path
+        + """') AS tbl(category_name, audio_path);
+    """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+    assert "tech_description" in result.stdout.lower()
+    # Verify the response mentions something from the audio content
+    output_lower = result.stdout.lower()
+    assert any(keyword in output_lower for keyword in AUDIO_EXPECTED_KEYWORDS), (
+        f"Expected response to contain at least one of {AUDIO_EXPECTED_KEYWORDS}, got: {result.stdout}"
+    )
+
+
+def test_llm_complete_audio_missing_transcription_model(integration_setup):
+    """Test that audio type requires transcription_model."""
+    duckdb_cli_path, db_path = integration_setup
+
+    test_model_name = "test-audio-error"
+    create_model_query = f"CREATE MODEL('{test_model_name}', 'gpt-4o-mini', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Get audio file path
+    audio_path = get_audio_file_path()
+
+    query = (
+        """
+    SELECT llm_complete(
+        {'model_name': '"""
+        + test_model_name
+        + """'},
+        {
+            'prompt': 'Summarize this audio',
+            'context_columns': [
+                {
+                    'data': audio_path,
+                    'type': 'audio'
+                }
+            ]
+        }
+    ) AS result
+    FROM VALUES ('"""
+        + audio_path
+        + """') AS tbl(audio_path);
+    """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    # Should fail because transcription_model is required for audio type
+    assert result.returncode != 0
+    assert (
+        "transcription_model" in result.stderr.lower()
+        or "required" in result.stderr.lower()
+    )
+
+
+def test_llm_complete_audio_ollama_error(integration_setup):
+    """Test that Ollama provider throws error for audio transcription."""
+    duckdb_cli_path, db_path = integration_setup
+
+    create_model_query = "CREATE MODEL('test-ollama-audio', 'gemma3:1b', 'ollama');"
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    create_transcription_model_query = (
+        "CREATE MODEL('test-ollama-transcription', 'gemma3:1b', 'ollama');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False)
+
+    # Get audio file path
+    audio_path = get_audio_file_path()
+
+    query = f"""
+    SELECT llm_complete(
+        {{'model_name': 'test-ollama-audio'}},
+        {{
+            'prompt': 'Summarize this audio',
+            'context_columns': [
+                {{
+                    'data': audio_path,
+                    'type': 'audio',
+                    'transcription_model': 'test-ollama-transcription'
+                }}
+            ]
+        }}
+    ) AS result
+    FROM VALUES ('{audio_path}') AS tbl(audio_path);
+    """
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    # Should fail because Ollama doesn't support transcription
+    assert result.returncode != 0
+    assert (
+        "ollama" in result.stderr.lower()
+        or "transcription" in result.stderr.lower()
+        or "not supported" in result.stderr.lower()
+    )
+
+
+def test_llm_complete_audio_batch_processing(integration_setup, model_config):
+    """Test batch processing with multiple audio files."""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    if provider != "openai":
+        pytest.skip("Audio transcription is only supported for OpenAI provider")
+
+    test_model_name = f"test-audio-batch_{model_name}"
+    create_model_query = f"CREATE MODEL('{test_model_name}', 'gpt-4o-mini', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    transcription_model_name = f"test-transcription-batch_{model_name}"
+    create_transcription_model_query = f"CREATE MODEL('{transcription_model_name}', 'gpt-4o-mini-transcribe', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False)
+
+    # Get audio file path
+    audio_path = get_audio_file_path()
+
+    create_table_query = """
+    CREATE OR REPLACE TABLE audio_clips (
+        id INTEGER,
+        audio_path VARCHAR,
+        product_name VARCHAR
+    );
+    """
+    run_cli(duckdb_cli_path, db_path, create_table_query)
+
+    insert_data_query = f"""
+    INSERT INTO audio_clips
+    VALUES
+        (1, '{audio_path}', 'Headphones'),
+        (2, '{audio_path}', 'Speaker'),
+        (3, '{audio_path}', 'Microphone');
+    """
+    run_cli(duckdb_cli_path, db_path, insert_data_query)
+
+    query = (
+        """
+    SELECT product_name,
+        llm_complete(
+            {'model_name': '"""
+        + test_model_name
+        + """'},
+            {
+                'prompt': 'Based on the product {product} and its audio, write a short description.',
+                'context_columns': [
+                    {'data': product_name, 'name': 'product'},
+                    {
+                        'data': audio_path,
+                        'type': 'audio',
+                        'transcription_model': '"""
+        + transcription_model_name
+        + """'
+                    }
+                ]
+            }
+        ) AS description
+    FROM audio_clips
+    WHERE id <= 2;
+    """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+    lines = result.stdout.strip().split("\n")
+    assert len(lines) >= 3  # Header + at least 2 data rows
diff --git a/test/integration/src/integration/tests/functions/scalar/test_llm_embedding.py b/test/integration/src/integration/tests/functions/scalar/test_llm_embedding.py
index 5169d2d0..dc8d7e3e 100644
--- a/test/integration/src/integration/tests/functions/scalar/test_llm_embedding.py
+++ b/test/integration/src/integration/tests/functions/scalar/test_llm_embedding.py
@@ -18,7 +18,7 @@ def test_llm_embedding_basic_functionality(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     query = (
         """
@@ -47,7 +47,7 @@ def test_llm_embedding_with_multiple_text_fields(integration_setup, model_config
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     query = (
         """
@@ -75,7 +75,7 @@ def test_llm_embedding_with_input_columns(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE documents (
@@ -130,7 +130,7 @@ def test_llm_embedding_batch_processing(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE product_descriptions (
@@ -200,7 +200,7 @@ def test_llm_embedding_error_handling_empty_text(integration_setup, model_config
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     query = (
         """
@@ -227,7 +227,7 @@ def test_llm_embedding_with_special_characters(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE special_text (
@@ -272,7 +272,7 @@ def test_llm_embedding_with_model_params(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     query = (
         """
@@ -299,7 +299,7 @@ def test_llm_embedding_document_similarity_use_case(integration_setup, model_con
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE knowledge_base (
@@ -356,7 +356,7 @@ def test_llm_embedding_concatenated_fields(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE products (
@@ -410,7 +410,7 @@ def _llm_embedding_performance_large_dataset(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE large_text_dataset AS
@@ -456,7 +456,7 @@ def test_llm_embedding_error_handling_malformed_input(integration_setup, model_c
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     # Test with missing required arguments
     query = """
diff --git a/test/integration/src/integration/tests/functions/scalar/test_llm_filter.py b/test/integration/src/integration/tests/functions/scalar/test_llm_filter.py
index 93b6930c..ba1ed273 100644
--- a/test/integration/src/integration/tests/functions/scalar/test_llm_filter.py
+++ b/test/integration/src/integration/tests/functions/scalar/test_llm_filter.py
@@ -1,10 +1,24 @@
 import pytest
-from integration.conftest import run_cli, get_image_data_for_provider
+from integration.conftest import (
+    run_cli,
+    get_image_data_for_provider,
+    get_audio_file_path,
+)
 
+# Expected keywords that should appear when audio is transcribed
+# Audio content: "Flock transforms DuckDB into a hybrid database and a semantic AI engine"
+AUDIO_EXPECTED_KEYWORDS = ["flock", "duckdb", "database", "semantic", "ai", "hybrid"]
 
-@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("llama3.2", "ollama")])
+
+@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:1b", "ollama")])
 def model_config(request):
-    """Fixture to test with different models."""
+    """Fixture to test with different models for text-only tests."""
+    return request.param
+
+
+@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:4b", "ollama")])
+def model_config_image(request):
+    """Fixture to test with different models for image tests."""
     return request.param
@@ -16,7 +30,7 @@ def test_llm_filter_basic_functionality(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE test_data (
@@ -41,7 +55,7 @@ def test_llm_filter_basic_functionality(integration_setup, model_config):
         {'model_name': '"""
         + test_model_name
         + """'},
-        {'prompt': 'Is this text positive? Answer true or false.', 'context_columns': [{'data': text}]}
+        {'prompt': 'Is this text positive?', 'context_columns': [{'data': text}]}
     ) AS is_positive
     FROM test_data
     WHERE id = 1;
@@ -54,6 +68,33 @@ def test_llm_filter_basic_functionality(integration_setup, model_config):
     assert "is_positive" in result.stdout.lower()
 
 
+def test_llm_filter_without_context_columns(integration_setup, model_config):
+    """Test llm_filter without context_columns parameter."""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    test_model_name = f"test-filter-no-context_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    query = (
+        """
+    SELECT llm_filter(
+        {'model_name': '"""
+        + test_model_name
+        + """'},
+        {'prompt': 'Is Paris the best capital in the world?'}
+    ) AS filter_result;
+    """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+    assert "true" in result.stdout.lower() or "false" in result.stdout.lower()
+
+
 def test_llm_filter_batch_processing(integration_setup, model_config):
     duckdb_cli_path, db_path = integration_setup
     model_name, provider = model_config
@@ -62,7 +103,7 @@ def test_llm_filter_batch_processing(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE test_items (
@@ -92,7 +133,7 @@ def test_llm_filter_batch_processing(integration_setup, model_config):
         {'model_name': '"""
         + test_model_name
         + """', 'batch_size': 2},
-        {'prompt': 'Is this item technology-related? Answer true or false.', 'context_columns': [{'data': text}]}
+        {'prompt': 'Is this item technology-related?', 'context_columns': [{'data': text}]}
     ) AS is_tech
     FROM test_items;
     """
@@ -102,7 +143,7 @@ def test_llm_filter_batch_processing(integration_setup, model_config):
     assert result.returncode == 0, f"Query failed with error: {result.stderr}"
     lines = result.stdout.strip().split("\n")
     assert len(lines) >= 6, f"Expected at least 6 lines, got {len(lines)}"
-    assert "true" in result.stdout.lower() and "false" in result.stdout.lower()
+    assert "true" in result.stdout.lower() or "false" in result.stdout.lower()
 
 
 def test_llm_filter_error_handling_invalid_model(integration_setup):
@@ -146,7 +187,7 @@ def test_llm_filter_error_handling_empty_prompt(integration_setup, model_config)
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE test_data (
@@ -186,7 +227,7 @@ def test_llm_filter_with_special_characters(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE special_text (
@@ -212,7 +253,7 @@ def test_llm_filter_with_special_characters(integration_setup, model_config):
         {'model_name': '"""
         + test_model_name
         + """'},
-        {'prompt': 'Does this text contain non-ASCII characters? Answer true or false.', 'context_columns': [{'data': text}]}
+        {'prompt': 'Does this text contain non-ASCII characters?', 'context_columns': [{'data': text}]}
     ) AS has_unicode
     FROM special_text
     WHERE id = 1;
@@ -232,7 +273,7 @@ def test_llm_filter_with_model_params(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE test_data (
@@ -256,7 +297,7 @@ def test_llm_filter_with_model_params(integration_setup, model_config):
         {'model_name': '"""
         + test_model_name
         + """', 'tuple_format': 'Markdown', 'batch_size': 1, 'model_parameters': '{"temperature": 0}'},
-        {'prompt': 'Is this text expressing positive sentiment? Answer true or false only.', 'context_columns': [{'data': text}]}
+        {'prompt': 'Is this text expressing positive sentiment?', 'context_columns': [{'data': text}]}
     ) AS is_positive
     FROM test_data;
     """
@@ -275,7 +316,7 @@ def test_llm_filter_with_structured_output(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE items (
@@ -340,7 +381,7 @@ def test_llm_filter_error_handling_missing_arguments(integration_setup, model_co
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     # Test with only 1 argument (should fail since llm_filter requires 2)
     query = (
@@ -366,7 +407,7 @@ def _test_llm_filter_performance_large_dataset(integration_setup, model_config):
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE large_content AS
@@ -385,7 +426,7 @@ def _test_llm_filter_performance_large_dataset(integration_setup, model_config):
         {'model_name': '"""
         + test_model_name
         + """', 'batch_size': 5},
-        {'prompt': 'Does this content contain the word "item"? Answer true or false.', 'context_columns': [{'data': content}]}
+        {'prompt': 'Does this content contain the word "item"?', 'context_columns': [{'data': content}]}
     ) AS filter_result
     FROM large_content
     LIMIT 10;
@@ -401,16 +442,16 @@ def _test_llm_filter_performance_large_dataset(integration_setup, model_config):
     assert "true" in result.stdout.lower() or "false" in result.stdout.lower()
 
 
-def test_llm_filter_with_image_integration(integration_setup, model_config):
+def test_llm_filter_with_image_integration(integration_setup, model_config_image):
     """Test llm_filter with image data integration."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-filter-model_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE vehicle_images (
@@ -453,7 +494,7 @@ def test_llm_filter_with_image_integration(integration_setup, model_config):
         + test_model_name
         + """'},
         {
-            'prompt': 'Is this image showing a motorized vehicle? Answer true or false.',
+            'prompt': 'Is this image showing a motorized vehicle?',
             'context_columns': [
                 {'data': vehicle_type},
                 {'data': image_url, 'type': 'image'}
@@ -471,16 +512,16 @@ def test_llm_filter_with_image_integration(integration_setup, model_config):
     assert len(result.stdout.strip().split("\n")) >= 2
 
 
-def test_llm_filter_image_batch_processing(integration_setup, model_config):
+def test_llm_filter_image_batch_processing(integration_setup, model_config_image):
     """Test llm_filter with multiple images in batch processing."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-batch-filter_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE food_images (
@@ -524,7 +565,7 @@ def test_llm_filter_image_batch_processing(integration_setup, model_config):
         + test_model_name
         + """'},
         {
-            'prompt': 'Does this food image look appetizing and well-presented? Answer true or false.',
+            'prompt': 'Does this food image look appetizing and well-presented?',
             'context_columns': [
                 {'data': food_name},
                 {'data': image_url, 'type': 'image'}
@@ -545,16 +586,16 @@ def test_llm_filter_image_batch_processing(integration_setup, model_config):
     assert "is_appetizing" in result.stdout.lower()
 
 
-def test_llm_filter_image_with_text_context(integration_setup, model_config):
+def test_llm_filter_image_with_text_context(integration_setup, model_config_image):
     """Test llm_filter with both image and text context."""
     duckdb_cli_path, db_path = integration_setup
-    model_name, provider = model_config
+    model_name, provider = model_config_image
 
     test_model_name = f"test-image-text-filter_{model_name}"
     create_model_query = (
         f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
     )
-    run_cli(duckdb_cli_path, db_path, create_model_query)
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
 
     create_table_query = """
     CREATE OR REPLACE TABLE clothing_images (
@@ -600,7 +641,7 @@ def test_llm_filter_image_with_text_context(integration_setup, model_config):
         + test_model_name
         + """'},
         {
-            'prompt': 'Based on the image and the season/price information, is this clothing item appropriate for its intended season and price range? Answer true or false.',
+            'prompt': 'Based on the image and the season/price information, is this clothing item appropriate for its intended season and price range?',
             'context_columns': [
                 {'data': item_name},
                 {'data': image_url, 'type': 'image'},
@@ -618,3 +659,107 @@ def test_llm_filter_image_with_text_context(integration_setup, model_config):
     assert result.returncode == 0, f"Query failed with error: {result.stderr}"
     assert "is_appropriate" in result.stdout.lower()
     assert len(result.stdout.strip().split("\n")) >= 2
+
+
+def test_llm_filter_with_audio_transcription(integration_setup, model_config):
+    """Test llm_filter with audio transcription using OpenAI.
+
+    The audio content says: 'Flock transforms DuckDB into a hybrid database and a semantic AI engine'
+    This test verifies that the audio is correctly transcribed and filtered.
+    """
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    if provider != "openai":
+        pytest.skip("Audio transcription is only supported for OpenAI provider")
+
+    test_model_name = f"test-audio-filter_{model_name}"
+    create_model_query = f"CREATE MODEL('{test_model_name}', 'gpt-4o-mini', 'openai');"
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    transcription_model_name = f"test-transcription-filter_{model_name}"
+    create_transcription_model_query = f"CREATE MODEL('{transcription_model_name}', 'gpt-4o-mini-transcribe', 'openai');"
+    run_cli(
+        duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False
+    )
+
+    # Get audio file path
+    audio_path = get_audio_file_path()
+
+    # Test with audio file path - the audio actually mentions DuckDB/Flock
+    query = (
+        """
+    SELECT llm_filter(
+        {'model_name': '"""
+        + test_model_name
+        + """'},
+        {
+            'prompt': 'Does this audio mention DuckDB or databases?',
+            'context_columns': [
+                {
+                    'data': audio_path,
+                    'type': 'audio',
+                    'transcription_model': '"""
+        + transcription_model_name
+        + """'
+                }
+            ]
+        }
+    ) AS mentions_database
+    FROM VALUES ('"""
+        + audio_path
+        + """') AS tbl(audio_path);
+    """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+    # The audio mentions DuckDB, so the filter should return true
+    result_lower = result.stdout.lower()
+    assert "true" in result_lower, (
+        f"Expected 'true' since audio mentions DuckDB. Got: {result.stdout}"
+    )
+
+
+def test_llm_filter_audio_ollama_error(integration_setup):
+    """Test that Ollama provider throws error for audio transcription in llm_filter."""
+    duckdb_cli_path, db_path = integration_setup
+
+    test_model_name = "test-ollama-filter-audio"
+    create_model_query = (
+        "CREATE MODEL('test-ollama-filter-audio', 'gemma3:1b', 'ollama');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    transcription_model_name = "test-ollama-filter-transcription"
+    create_transcription_model_query = (
+        "CREATE MODEL('test-ollama-filter-transcription', 'gemma3:1b', 'ollama');"
+    )
+    run_cli(
+        duckdb_cli_path, db_path, create_transcription_model_query, with_secrets=False
+    )
+
+    query = """
+    SELECT llm_filter(
+        {'model_name': 'test-ollama-filter-audio'},
+        {
+            'prompt': 'Is the sentiment positive?',
+            'context_columns': [
+                {
+                    'data': audio_url,
+                    'type': 'audio',
+                    'transcription_model': 'test-ollama-filter-transcription'
+                }
+            ]
+        }
+    ) AS result
+    FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);
+    """
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode != 0
+    assert (
+        "ollama" in result.stderr.lower()
+        or "transcription" in result.stderr.lower()
+        or "not supported" in result.stderr.lower()
+    )
diff --git a/test/integration/src/integration/tests/metrics/test_metrics.py b/test/integration/src/integration/tests/metrics/test_metrics.py
new file mode 100644
index 00000000..9535a5cf
--- /dev/null
+++ b/test/integration/src/integration/tests/metrics/test_metrics.py
@@ -0,0 +1,854 @@
+import pytest
+import json
+import csv
+from io import StringIO
+from integration.conftest import run_cli
+
+
+def get_json_from_csv_output(stdout, column_name="metrics"):
+    """Extract JSON value from DuckDB CSV output."""
+    reader = csv.DictReader(StringIO(stdout))
+    row = next(reader, None)
+    if row and column_name in row:
+        return json.loads(row[column_name])
+    return None
+
+
+@pytest.fixture(params=[("gpt-4o-mini", "openai"), ("gemma3:1b", "ollama")]) +def model_config(request): + return request.param + + +# ============================================================================ +# Basic Metrics API Tests +# ============================================================================ + + +def test_flock_get_metrics_returns_json(integration_setup): + duckdb_cli_path, db_path = integration_setup + query = "SELECT flock_get_metrics() AS metrics;" + result = run_cli(duckdb_cli_path, db_path, query) + + assert result.returncode == 0, f"Query failed with error: {result.stderr}" + + metrics = get_json_from_csv_output(result.stdout) + assert metrics is not None, "No JSON found in output" + + # Check new structure - should be a flat object + assert isinstance(metrics, dict) + assert len(metrics) == 0 # Initially empty + + +def test_flock_reset_metrics(integration_setup): + duckdb_cli_path, db_path = integration_setup + query = "SELECT flock_reset_metrics() AS result;" + result = run_cli(duckdb_cli_path, db_path, query) + + assert result.returncode == 0, f"Query failed with error: {result.stderr}" + assert "reset" in result.stdout.lower() + + +# ============================================================================ +# Scalar Function Metrics Tests +# ============================================================================ + + +def test_metrics_after_llm_complete(integration_setup, model_config): + duckdb_cli_path, db_path = integration_setup + model_name, provider = model_config + + run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();") + + test_model_name = f"test-metrics-model_{model_name}" + create_model_query = ( + f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');" + ) + run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False) + + # Call llm_complete and get_metrics in the same query + query = ( + """ + SELECT + llm_complete( + {'model_name': '""" + + test_model_name + + """'}, + {'prompt': 'Answer with one number: What is 2+2?'} + ) AS result, + flock_get_metrics() AS metrics; + """ + ) + result = run_cli(duckdb_cli_path, db_path, query) + + assert result.returncode == 0, f"Query failed with error: {result.stderr}" + + # Parse CSV output to get metrics + reader = csv.DictReader(StringIO(result.stdout)) + row = next(reader, None) + assert row is not None, "No data returned from query" + assert "metrics" in row, "Metrics column not found in output" + + metrics = json.loads(row["metrics"]) + + # Check that metrics were recorded - should be a flat object with keys like "llm_complete_1" + assert isinstance(metrics, dict) + assert len(metrics) > 0 + + # Check that we have llm_complete_1 with proper structure + assert "llm_complete_1" in metrics, ( + f"Expected llm_complete_1 in metrics, got: {list(metrics.keys())}" + ) + llm_complete_1 = metrics["llm_complete_1"] + + assert "api_calls" in llm_complete_1 + assert llm_complete_1["api_calls"] > 0 + assert "input_tokens" in llm_complete_1 + assert "output_tokens" in llm_complete_1 + assert "total_tokens" in llm_complete_1 + assert "api_duration_ms" in llm_complete_1 + assert "execution_time_ms" in llm_complete_1 + assert "model_name" in llm_complete_1 + assert llm_complete_1["model_name"] == test_model_name + assert "provider" in llm_complete_1 + assert llm_complete_1["provider"] == provider + + +def test_metrics_reset_clears_counters(integration_setup, model_config): + duckdb_cli_path, db_path = integration_setup + model_name, provider = model_config + + 
+
+
+# ============================================================================
+# Scalar Function Metrics Tests
+# ============================================================================
+
+
+def test_metrics_after_llm_complete(integration_setup, model_config):
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-metrics-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Call llm_complete and get_metrics in the same query
+    query = (
+        """
+        SELECT
+            llm_complete(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'Answer with one number: What is 2+2?'}
+            ) AS result,
+            flock_get_metrics() AS metrics;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+
+    # Parse CSV output to get metrics
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None, "No data returned from query"
+    assert "metrics" in row, "Metrics column not found in output"
+
+    metrics = json.loads(row["metrics"])
+
+    # Check that metrics were recorded - should be a flat object with keys like "llm_complete_1"
+    assert isinstance(metrics, dict)
+    assert len(metrics) > 0
+
+    # Check that we have llm_complete_1 with proper structure
+    assert "llm_complete_1" in metrics, (
+        f"Expected llm_complete_1 in metrics, got: {list(metrics.keys())}"
+    )
+    llm_complete_1 = metrics["llm_complete_1"]
+
+    assert "api_calls" in llm_complete_1
+    assert llm_complete_1["api_calls"] > 0
+    assert "input_tokens" in llm_complete_1
+    assert "output_tokens" in llm_complete_1
+    assert "total_tokens" in llm_complete_1
+    assert "api_duration_ms" in llm_complete_1
+    assert "execution_time_ms" in llm_complete_1
+    assert "model_name" in llm_complete_1
+    assert llm_complete_1["model_name"] == test_model_name
+    assert "provider" in llm_complete_1
+    assert llm_complete_1["provider"] == provider
+
+
+def test_metrics_reset_clears_counters(integration_setup, model_config):
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    test_model_name = f"test-reset-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # First query: execute llm_complete and get metrics in the same query
+    query1 = (
+        """
+        SELECT
+            llm_complete(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'Say one word: hello'}
+            ) AS result,
+            flock_get_metrics() AS metrics;
+        """
+    )
+    result1 = run_cli(duckdb_cli_path, db_path, query1)
+    assert result1.returncode == 0
+
+    # Parse metrics from first query to verify they exist
+    reader1 = csv.DictReader(StringIO(result1.stdout))
+    row1 = next(reader1, None)
+    assert row1 is not None and "metrics" in row1
+    metrics1 = json.loads(row1["metrics"])
+    assert len(metrics1) > 0, "Metrics should be recorded before reset"
+    assert "llm_complete_1" in metrics1, "Should have llm_complete_1 after first call"
+
+    # Second query: reset metrics and get metrics in the same query
+    query2 = (
+        "SELECT flock_reset_metrics() AS reset_result, flock_get_metrics() AS metrics;"
+    )
+    result2 = run_cli(duckdb_cli_path, db_path, query2)
+    assert result2.returncode == 0
+
+    # Parse metrics from second query to verify they're cleared
+    reader2 = csv.DictReader(StringIO(result2.stdout))
+    row2 = next(reader2, None)
+    assert row2 is not None and "metrics" in row2
+    metrics2 = json.loads(row2["metrics"])
+
+    # After reset, should be empty
+    assert isinstance(metrics2, dict)
+    assert len(metrics2) == 0
+
+
+def test_sequential_numbering_multiple_calls(integration_setup, model_config):
+    """Test that multiple calls of the same function get sequential numbering"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-sequential-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Make three calls to llm_complete in the same query
+    query = (
+        """
+        SELECT
+            llm_complete(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'Say: one'}
+            ) AS result1,
+            llm_complete(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'Say: two'}
+            ) AS result2,
+            llm_complete(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'Say: three'}
+            ) AS result3,
+            flock_get_metrics() AS metrics;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+
+    # Parse CSV output to get metrics
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None, "No data returned from query"
+    assert "metrics" in row, "Metrics column not found in output"
+
+    metrics = json.loads(row["metrics"])
+
+    # Should have llm_complete_1, llm_complete_2, llm_complete_3
+    assert isinstance(metrics, dict)
+    assert len(metrics) >= 3, (
+        f"Expected at least 3 metrics, got {len(metrics)}: {list(metrics.keys())}"
+    )
+
+    # Check that we have sequential numbering
+    found_keys = [key for key in metrics.keys() if key.startswith("llm_complete_")]
+    assert len(found_keys) >= 3, (
+        f"Expected at least 3 llm_complete entries, got: {found_keys}"
+    )
+
+    # Verify each has the expected structure
+    for key in found_keys:
+        assert "api_calls" in metrics[key]
+        assert "input_tokens" in metrics[key]
+        assert "output_tokens" in metrics[key]
+        assert metrics[key]["api_calls"] == 1
+
+
+# ============================================================================
+# Debug Metrics Tests
+# ============================================================================
+
+
+def test_flock_get_debug_metrics_returns_nested_structure(
+    integration_setup, model_config
+):
+    """Test that flock_get_debug_metrics returns the nested structure"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-debug-metrics-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Call llm_complete and get debug metrics
+    query = (
+        """
+        SELECT
+            llm_complete(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'Answer with one number: What is 2+2?'}
+            ) AS result,
+            flock_get_debug_metrics() AS debug_metrics;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+
+    # Parse CSV output to get debug metrics
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None, "No data returned from query"
+    assert "debug_metrics" in row, "Debug metrics column not found in output"
+
+    debug_metrics = json.loads(row["debug_metrics"])
+
+    # Check nested structure
+    assert isinstance(debug_metrics, dict)
+    assert "threads" in debug_metrics
+    assert "thread_count" in debug_metrics
+    assert isinstance(debug_metrics["threads"], dict)
+    assert debug_metrics["thread_count"] > 0
+
+    # Check that threads contain state data
+    found_llm_complete = False
+    for thread_id, thread_data in debug_metrics["threads"].items():
+        assert isinstance(thread_data, dict)
+        for state_id, state_data in thread_data.items():
+            assert isinstance(state_data, dict)
+            if "llm_complete" in state_data:
+                llm_complete_data = state_data["llm_complete"]
+                assert "registration_order" in llm_complete_data
+                assert "api_calls" in llm_complete_data
+                assert "input_tokens" in llm_complete_data
+                assert "output_tokens" in llm_complete_data
+                found_llm_complete = True
+
+    assert found_llm_complete, "llm_complete not found in debug metrics"
+
+
+def test_debug_metrics_registration_order(integration_setup, model_config):
+    """Test that debug metrics include registration_order"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-reg-order-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Make multiple calls
+    query = (
+        """
+        SELECT
+            llm_complete(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'Say: one'}
+            ) AS result1,
+            llm_complete(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'Say: two'}
+            ) AS result2,
+            flock_get_debug_metrics() AS debug_metrics;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0
+
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None and "debug_metrics" in row
+
+    debug_metrics = json.loads(row["debug_metrics"])
+
+    # Check registration orders
+    registration_orders = []
+    for thread_id, thread_data in debug_metrics["threads"].items():
+        for state_id, state_data in thread_data.items():
+            if "llm_complete" in state_data:
+                reg_order = state_data["llm_complete"]["registration_order"]
+                registration_orders.append(reg_order)
+
+    # Should have at least one registration order
+    assert len(registration_orders) > 0
+    # Registration orders should be positive integers
+    for order in registration_orders:
+        assert isinstance(order, int)
+        assert order > 0
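The nested debug payload exercised above (threads, then states, then per-function entries) flattens naturally for inspection; a short sketch, assuming only the fields these tests assert on:

def flatten_debug_metrics(debug_metrics: dict) -> list:
    # One row per (thread, state, function) entry, keeping the counters.
    rows = []
    for thread_id, states in debug_metrics.get("threads", {}).items():
        for state_id, functions in states.items():
            for func_name, data in functions.items():
                rows.append({
                    "thread": thread_id,
                    "state": state_id,
                    "function": func_name,
                    "registration_order": data.get("registration_order"),
                    "api_calls": data.get("api_calls"),
                    "input_tokens": data.get("input_tokens"),
                    "output_tokens": data.get("output_tokens"),
                })
    return rows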
+
+
+# ============================================================================
+# Aggregate Function Metrics Tests
+# ============================================================================
+
+
+def test_aggregate_function_metrics_tracking(integration_setup, model_config):
+    """Test that aggregate functions track metrics correctly"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-aggregate-metrics-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Call llm_reduce and get metrics
+    query = (
+        """
+        SELECT
+            category,
+            llm_reduce(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'One word summary:', 'context_columns': [{'data': description}]}
+            ) AS summary,
+            flock_get_metrics() AS metrics
+        FROM VALUES
+            ('Electronics', 'High-performance laptop'),
+            ('Electronics', 'Latest smartphone')
+            AS t(category, description)
+        GROUP BY category;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+
+    # Parse CSV output
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None, "No data returned from query"
+    assert "metrics" in row, "Metrics column not found"
+
+    metrics = json.loads(row["metrics"])
+
+    # Check that metrics were recorded
+    assert isinstance(metrics, dict)
+    assert len(metrics) > 0
+
+    # Check for llm_reduce metrics
+    found_reduce = False
+    for key in metrics.keys():
+        if key.startswith("llm_reduce_"):
+            reduce_metrics = metrics[key]
+            assert "api_calls" in reduce_metrics
+            assert "input_tokens" in reduce_metrics
+            assert "output_tokens" in reduce_metrics
+            assert "total_tokens" in reduce_metrics
+            assert "api_duration_ms" in reduce_metrics
+            assert "execution_time_ms" in reduce_metrics
+            assert "model_name" in reduce_metrics
+            assert reduce_metrics["model_name"] == test_model_name
+            assert "provider" in reduce_metrics
+            assert reduce_metrics["provider"] == provider
+            found_reduce = True
+            break
+
+    assert found_reduce, f"llm_reduce metrics not found in: {list(metrics.keys())}"
+
+
+def test_aggregate_function_metrics_merging_with_group_by(
+    integration_setup, model_config
+):
+    """Test that metrics from multiple states in a single aggregate call are merged into one entry"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-merge-metrics-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Call llm_reduce with GROUP BY that will process multiple states
+    # This should result in multiple states being processed, but only ONE merged metrics entry
+    query = (
+        """
+        SELECT
+            category,
+            llm_reduce(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'One word summary:', 'context_columns': [{'data': description}]}
+            ) AS summary,
+            flock_get_metrics() AS metrics
+        FROM VALUES
+            ('Electronics', 'High-performance laptop'),
+            ('Electronics', 'Latest smartphone'),
+            ('Electronics', 'Gaming console')
+            AS t(category, description)
+        GROUP BY category;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+
+    # Parse CSV output
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None, "No data returned from query"
+    assert "metrics" in row, "Metrics column not found"
+
+    metrics = json.loads(row["metrics"])
+
+    # Check that metrics were recorded
+    assert isinstance(metrics, dict)
+    assert len(metrics) > 0
+
+    # Check for llm_reduce metrics - should have ONLY ONE entry (merged)
+    found_reduce_keys = [key for key in metrics.keys() if key.startswith("llm_reduce_")]
+    assert len(found_reduce_keys) == 1, (
+        f"Expected exactly 1 llm_reduce metrics entry (merged), got {len(found_reduce_keys)}: {found_reduce_keys}"
+    )
+
+    # Verify the merged metrics have the expected structure
+    reduce_metrics = metrics[found_reduce_keys[0]]
+    assert "api_calls" in reduce_metrics
+    assert "input_tokens" in reduce_metrics
+    assert "output_tokens" in reduce_metrics
+    assert "total_tokens" in reduce_metrics
+    assert "api_duration_ms" in reduce_metrics
+    assert "execution_time_ms" in reduce_metrics
+    assert "model_name" in reduce_metrics
+    assert reduce_metrics["model_name"] == test_model_name
+    assert "provider" in reduce_metrics
+    assert reduce_metrics["provider"] == provider
+
+
+def test_aggregate_function_metrics_merging_multiple_groups(
+    integration_setup, model_config
+):
+    """Test that each GROUP BY group produces one merged metrics entry"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-merge-groups-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Call llm_reduce with multiple GROUP BY groups
+    # Each group should produce ONE merged metrics entry
+    query = (
+        """
+        SELECT
+            category,
+            llm_reduce(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'One word summary:', 'context_columns': [{'data': description}]}
+            ) AS summary,
+            flock_get_metrics() AS metrics
+        FROM VALUES
+            ('Electronics', 'High-performance laptop'),
+            ('Electronics', 'Latest smartphone'),
+            ('Clothing', 'Comfortable jacket'),
+            ('Clothing', 'Perfect fit jeans')
+            AS t(category, description)
+        GROUP BY category;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+
+    # Parse CSV output - should have 2 rows (one per category)
+    reader = csv.DictReader(StringIO(result.stdout))
+    rows = list(reader)
+    assert len(rows) == 2, f"Expected 2 rows (one per category), got {len(rows)}"
+
+    # Check metrics from the last row (should have both groups merged)
+    metrics = json.loads(rows[-1]["metrics"])
+
+    # Each group merges its per-state metrics into a single entry, so the
+    # last row should expose at least one merged llm_reduce entry
+    found_reduce_keys = [key for key in metrics.keys() if key.startswith("llm_reduce_")]
+    assert len(found_reduce_keys) >= 1, (
+        f"Expected at least 1 llm_reduce metrics entry, got {len(found_reduce_keys)}: {found_reduce_keys}"
+    )
+
+
+def test_multiple_aggregate_functions_sequential_numbering(
+    integration_setup, model_config
+):
+    """Test that multiple aggregate function calls get sequential numbering"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-sequential-aggregate-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    # Call llm_reduce twice in the same query
+    query = (
+        """
+        SELECT
+            category,
+            llm_reduce(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'One word 1:', 'context_columns': [{'data': description}]}
+            ) AS summary1,
+            llm_reduce(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'One word 2:', 'context_columns': [{'data': description}]}
+            ) AS summary2,
+            flock_get_metrics() AS metrics
+        FROM VALUES
+            ('Electronics', 'High-performance laptop')
+            AS t(category, description)
+        GROUP BY category;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0, f"Query failed with error: {result.stderr}"
+
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None and "metrics" in row
+
+    metrics = json.loads(row["metrics"])
+
+    # Should have llm_reduce_1 and llm_reduce_2
+    found_keys = [key for key in metrics.keys() if key.startswith("llm_reduce_")]
+    assert len(found_keys) >= 2, (
+        f"Expected at least 2 llm_reduce entries, got: {found_keys}"
+    )
+
+    # Verify sequential numbering
+    numbers = []
+    for key in found_keys:
+        # Extract number from key like "llm_reduce_1"
+        num = int(key.split("_")[-1])
+        numbers.append(num)
+
+    numbers.sort()
+    # Should have sequential numbers starting from 1
+    assert numbers[0] == 1, f"First number should be 1, got {numbers}"
+
+
+def test_aggregate_function_debug_metrics(integration_setup, model_config):
+    """Test debug metrics for aggregate functions"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-debug-aggregate-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    query = (
+        """
+        SELECT
+            category,
+            llm_reduce(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'One word summary:', 'context_columns': [{'data': description}]}
+            ) AS summary,
+            flock_get_debug_metrics() AS debug_metrics
+        FROM VALUES
+            ('Electronics', 'High-performance laptop')
+            AS t(category, description)
+        GROUP BY category;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0
+
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None and "debug_metrics" in row
+
+    debug_metrics = json.loads(row["debug_metrics"])
+
+    # Check nested structure
+    assert isinstance(debug_metrics, dict)
+    assert "threads" in debug_metrics
+    assert "thread_count" in debug_metrics
+
+    # Check that llm_reduce appears in debug metrics
+    found_llm_reduce = False
+    for thread_id, thread_data in debug_metrics["threads"].items():
+        for state_id, state_data in thread_data.items():
+            if "llm_reduce" in state_data:
+                reduce_data = state_data["llm_reduce"]
+                assert "registration_order" in reduce_data
+                assert "api_calls" in reduce_data
+                assert "input_tokens" in reduce_data
+                assert "output_tokens" in reduce_data
+                found_llm_reduce = True
+
+    assert found_llm_reduce, "llm_reduce not found in debug metrics"
+
+
+def test_llm_rerank_metrics(integration_setup, model_config):
+    """Test metrics for llm_rerank aggregate function"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-rerank-metrics-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    query = (
+        """
+        SELECT
+            llm_rerank(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'One word rank:', 'context_columns': [{'data': description}]}
+            ) AS ranked,
+            flock_get_metrics() AS metrics
+        FROM VALUES
+            ('Product 1'),
+            ('Product 2'),
+            ('Product 3')
+            AS t(description);
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0
+
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None and "metrics" in row
+
+    metrics = json.loads(row["metrics"])
+
+    # Check for llm_rerank metrics
+    found_rerank = False
+    for key in metrics.keys():
+        if key.startswith("llm_rerank_"):
+            rerank_metrics = metrics[key]
+            assert "api_calls" in rerank_metrics
+            assert "input_tokens" in rerank_metrics
+            assert "output_tokens" in rerank_metrics
+            found_rerank = True
+            break
+
+    assert found_rerank, f"llm_rerank metrics not found in: {list(metrics.keys())}"
+
+
+def test_llm_first_metrics(integration_setup, model_config):
+    """Test metrics for llm_first aggregate function"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-first-metrics-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    query = (
+        """
+        SELECT
+            category,
+            llm_first(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'One word:', 'context_columns': [{'data': description}]}
+            ) AS first_item,
+            flock_get_metrics() AS metrics
+        FROM VALUES
+            ('Electronics', 'Product 1'),
+            ('Electronics', 'Product 2')
+            AS t(category, description)
+        GROUP BY category;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0
+
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None and "metrics" in row
+
+    metrics = json.loads(row["metrics"])
+
+    # Check for llm_first metrics
+    found_first = False
+    for key in metrics.keys():
+        if key.startswith("llm_first_"):
+            first_metrics = metrics[key]
+            assert "api_calls" in first_metrics
+            found_first = True
+            break
+
+    assert found_first, f"llm_first metrics not found in: {list(metrics.keys())}"
+
+
+# ============================================================================
+# Mixed Scalar and Aggregate Tests
+# ============================================================================
+
+
+def test_mixed_scalar_and_aggregate_metrics(integration_setup, model_config):
+    """Test that both scalar and aggregate functions are tracked separately"""
+    duckdb_cli_path, db_path = integration_setup
+    model_name, provider = model_config
+
+    run_cli(duckdb_cli_path, db_path, "SELECT flock_reset_metrics();")
+
+    test_model_name = f"test-mixed-metrics-model_{model_name}"
+    create_model_query = (
+        f"CREATE MODEL('{test_model_name}', '{model_name}', '{provider}');"
+    )
+    run_cli(duckdb_cli_path, db_path, create_model_query, with_secrets=False)
+
+    query = (
+        """
+        SELECT
+            llm_complete(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'Say: hi'}
+            ) AS scalar_result,
+            (SELECT llm_reduce(
+                {'model_name': '"""
+        + test_model_name
+        + """'},
+                {'prompt': 'One word summary:', 'context_columns': [{'data': description}]}
+            ) FROM VALUES ('Test description') AS t(description)) AS aggregate_result,
+            flock_get_metrics() AS metrics;
+        """
+    )
+    result = run_cli(duckdb_cli_path, db_path, query)
+
+    assert result.returncode == 0
+
+    reader = csv.DictReader(StringIO(result.stdout))
+    row = next(reader, None)
+    assert row is not None and "metrics" in row
+
+    metrics = json.loads(row["metrics"])
+
+    # Should have both scalar and aggregate metrics
+    has_scalar = any(key.startswith("llm_complete_") for key in metrics.keys())
+    has_aggregate = any(key.startswith("llm_reduce_") for key in metrics.keys())
+
+    assert has_scalar, "Scalar function metrics not found"
+    assert has_aggregate, "Aggregate function metrics not found"
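Client code can roll the flat payload up without knowing which functions ran; a small sketch that totals usage per model, assuming only the keys the tests above assert (api_calls, input_tokens, output_tokens, model_name):

from collections import defaultdict

def summarize_metrics(metrics: dict) -> dict:
    # Sum API calls and token counters per model across all function entries.
    totals = defaultdict(lambda: {"api_calls": 0, "input_tokens": 0, "output_tokens": 0})
    for entry in metrics.values():
        bucket = totals[entry.get("model_name", "unknown")]
        bucket["api_calls"] += entry.get("api_calls", 0)
        bucket["input_tokens"] += entry.get("input_tokens", 0)
        bucket["output_tokens"] += entry.get("output_tokens", 0)
    return dict(totals)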
diff --git a/test/integration/src/integration/tests/model_parser/test_model_parser.py b/test/integration/src/integration/tests/model_parser/test_model_parser.py
index c6b47374..3bf06de9 100644
--- a/test/integration/src/integration/tests/model_parser/test_model_parser.py
+++ b/test/integration/src/integration/tests/model_parser/test_model_parser.py
@@ -4,9 +4,9 @@ def test_create_and_get_model(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('test-model', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET MODEL 'test-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "test-model" in result.stdout
     assert "gpt-4o" in result.stdout
     assert "openai" in result.stdout
@@ -16,23 +16,23 @@ def test_create_get_delete_global_model(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE GLOBAL MODEL('global-test-model', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET MODEL 'global-test-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "global-test-model" in result.stdout
     assert "gpt-4" in result.stdout
     assert "openai" in result.stdout
     assert "global" in result.stdout
     delete_query = "DELETE MODEL 'global-test-model';"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)


 def test_create_local_model_explicit(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE LOCAL MODEL('local-test-model', 'llama2', 'ollama');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET MODEL 'local-test-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "local-test-model" in result.stdout
     assert "llama2" in result.stdout
     assert "ollama" in result.stdout
@@ -42,9 +42,9 @@ def test_create_model_with_args(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('model-with-args', 'gpt-4o', 'openai', '{\"batch_size\": 10, \"tuple_format\": \"csv\"}');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET MODEL 'model-with-args';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "model-with-args" in result.stdout
     assert "gpt-4o" in result.stdout
     assert "openai" in result.stdout
@@ -53,22 +53,22 @@ def test_delete_model(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('delete-test-model', 'gpt-4o', 'azure');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     delete_query = "DELETE MODEL 'delete-test-model';"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)
     get_query = "GET MODEL 'delete-test-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "delete-test-model" not in result.stdout or result.stdout.strip() == ""


 def test_update_model_content(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('update-test-model', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     update_query = "UPDATE MODEL('update-test-model', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, update_query)
+    run_cli(duckdb_cli_path, db_path, update_query, with_secrets=False)
     get_query = "GET MODEL 'update-test-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "update-test-model" in result.stdout
     assert "gpt-4" in result.stdout
     assert "openai" in result.stdout
@@ -77,11 +77,11 @@ def test_update_model_with_args(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('update-args-model', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     update_query = "UPDATE MODEL('update-args-model', 'gpt-4o', 'openai', '{\"batch_size\": 5, \"model_parameters\": {\"temperature\": 0.7}}');"
-    run_cli(duckdb_cli_path, db_path, update_query)
+    run_cli(duckdb_cli_path, db_path, update_query, with_secrets=False)
     get_query = "GET MODEL 'update-args-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "update-args-model" in result.stdout
     assert "gpt-4" in result.stdout
@@ -89,25 +89,25 @@ def test_update_model_scope_to_global(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE LOCAL MODEL('scope-test-model', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     update_query = "UPDATE MODEL 'scope-test-model' TO GLOBAL;"
-    run_cli(duckdb_cli_path, db_path, update_query)
+    run_cli(duckdb_cli_path, db_path, update_query, with_secrets=False)
     get_query = "GET MODEL 'scope-test-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "scope-test-model" in result.stdout
     assert "global" in result.stdout
     delete_query = "DELETE MODEL 'scope-test-model';"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)


 def test_update_model_scope_to_local(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE GLOBAL MODEL('scope-test-model-2', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     update_query = "UPDATE MODEL 'scope-test-model-2' TO LOCAL;"
-    run_cli(duckdb_cli_path, db_path, update_query)
+    run_cli(duckdb_cli_path, db_path, update_query, with_secrets=False)
     get_query = "GET MODEL 'scope-test-model-2';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "scope-test-model-2" in result.stdout
     assert "local" in result.stdout
@@ -116,24 +116,24 @@ def test_get_all_models(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query1 = "CREATE MODEL('model1', 'gpt-4o', 'openai');"
     create_query2 = "CREATE GLOBAL MODEL('model2', 'llama2', 'ollama');"
-    run_cli(duckdb_cli_path, db_path, create_query1)
-    run_cli(duckdb_cli_path, db_path, create_query2)
+    run_cli(duckdb_cli_path, db_path, create_query1, with_secrets=False)
+    run_cli(duckdb_cli_path, db_path, create_query2, with_secrets=False)
     get_all_query = "GET MODELS;"
-    result = run_cli(duckdb_cli_path, db_path, get_all_query)
+    result = run_cli(duckdb_cli_path, db_path, get_all_query, with_secrets=False)
     assert "model1" in result.stdout
     assert "model2" in result.stdout
     assert "gpt-4o" in result.stdout
     assert "llama2" in result.stdout
     delete_query = "DELETE MODEL 'model2';"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)


 def test_create_model_duplicate_error(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('duplicate-model', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     duplicate_query = "CREATE MODEL('duplicate-model', 'gpt-4o', 'openai');"
-    result = run_cli(duckdb_cli_path, db_path, duplicate_query)
+    result = run_cli(duckdb_cli_path, db_path, duplicate_query, with_secrets=False)
     assert result.returncode != 0 or "already exist" in result.stderr
@@ -141,17 +141,17 @@ def test_create_model_invalid_syntax(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     # Missing opening parenthesis
     invalid_query1 = "CREATE MODEL 'test', 'gpt-4o', 'openai');"
-    result1 = run_cli(duckdb_cli_path, db_path, invalid_query1)
+    result1 = run_cli(duckdb_cli_path, db_path, invalid_query1, with_secrets=False)
     assert result1.returncode != 0

     # Missing comma between parameters
     invalid_query2 = "CREATE MODEL('test' 'gpt-4o' 'openai');"
-    result2 = run_cli(duckdb_cli_path, db_path, invalid_query2)
+    result2 = run_cli(duckdb_cli_path, db_path, invalid_query2, with_secrets=False)
     assert result2.returncode != 0

     # Missing closing parenthesis
     invalid_query3 = "CREATE MODEL('test', 'gpt-4o', 'openai';"
-    result3 = run_cli(duckdb_cli_path, db_path, invalid_query3)
+    result3 = run_cli(duckdb_cli_path, db_path, invalid_query3, with_secrets=False)
     assert result3.returncode != 0
@@ -159,33 +159,33 @@ def test_create_model_invalid_json_args(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     # Invalid JSON format
     invalid_query1 = "CREATE MODEL('test-model', 'gpt-4o', 'openai', '{invalid json}');"
-    result1 = run_cli(duckdb_cli_path, db_path, invalid_query1)
+    result1 = run_cli(duckdb_cli_path, db_path, invalid_query1, with_secrets=False)
     assert result1.returncode != 0

     # Invalid parameter in JSON
     invalid_query2 = "CREATE MODEL('test-model', 'gpt-4o', 'openai', '{\"invalid_param\": \"value\"}');"
-    result2 = run_cli(duckdb_cli_path, db_path, invalid_query2)
+    result2 = run_cli(duckdb_cli_path, db_path, invalid_query2, with_secrets=False)
     assert result2.returncode != 0


 def test_delete_nonexistent_model(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     delete_query = "DELETE MODEL 'nonexistent-model';"
-    result = run_cli(duckdb_cli_path, db_path, delete_query)
+    result = run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)
     assert result.returncode == 0


 def test_update_nonexistent_model_error(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     update_query = "UPDATE MODEL('nonexistent-model', 'gpt-4o', 'openai');"
-    result = run_cli(duckdb_cli_path, db_path, update_query)
+    result = run_cli(duckdb_cli_path, db_path, update_query, with_secrets=False)
     assert result.returncode != 0 or "doesn't exist" in result.stderr


 def test_get_nonexistent_model(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     get_query = "GET MODEL 'nonexistent-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert result.returncode == 0
     assert "nonexistent-model" not in result.stdout or result.stdout.strip() == ""
@@ -193,37 +193,37 @@ def test_empty_model_name_error(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     invalid_query = "CREATE MODEL('', 'gpt-4o', 'openai');"
-    result = run_cli(duckdb_cli_path, db_path, invalid_query)
+    result = run_cli(duckdb_cli_path, db_path, invalid_query, with_secrets=False)
     assert result.returncode != 0


 def test_empty_model_value_error(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     invalid_query = "CREATE MODEL('test-model', '', 'openai');"
-    result = run_cli(duckdb_cli_path, db_path, invalid_query)
+    result = run_cli(duckdb_cli_path, db_path, invalid_query, with_secrets=False)
     assert result.returncode != 0


 def test_empty_provider_name_error(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     invalid_query = "CREATE MODEL('test-model', 'gpt-4o', '');"
-    result = run_cli(duckdb_cli_path, db_path, invalid_query)
+    result = run_cli(duckdb_cli_path, db_path, invalid_query, with_secrets=False)
     assert result.returncode != 0


 def test_get_model_vs_get_models(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('test-get-model', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)

     # Test GET MODEL (singular)
     get_single_query = "GET MODEL 'test-get-model';"
-    result_single = run_cli(duckdb_cli_path, db_path, get_single_query)
+    result_single = run_cli(duckdb_cli_path, db_path, get_single_query, with_secrets=False)
     assert "test-get-model" in result_single.stdout

     # Test GET MODELS (plural) - should get all models
     get_all_query = "GET MODELS;"
-    result_all = run_cli(duckdb_cli_path, db_path, get_all_query)
+    result_all = run_cli(duckdb_cli_path, db_path, get_all_query, with_secrets=False)
     assert "test-get-model" in result_all.stdout
@@ -232,11 +232,11 @@ def test_model_args_allowed_parameters(integration_setup):

     # Test valid parameters: tuple_format, batch_size, model_parameters
     valid_query = 'CREATE MODEL(\'valid-args-model\', \'gpt-4o\', \'openai\', \'{"tuple_format": "json", "batch_size": 5, "model_parameters": {"temperature": 0.8}}\');'
-    result = run_cli(duckdb_cli_path, db_path, valid_query)
+    result = run_cli(duckdb_cli_path, db_path, valid_query, with_secrets=False)
     assert result.returncode == 0

     get_query = "GET MODEL 'valid-args-model';"
-    get_result = run_cli(duckdb_cli_path, db_path, get_query)
+    get_result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "valid-args-model" in get_result.stdout
@@ -248,12 +248,12 @@ def test_multiple_providers(integration_setup):
     azure_query = "CREATE MODEL('azure-model', 'gpt-4o', 'azure');"
     ollama_query = "CREATE MODEL('ollama-model', 'llama2', 'ollama');"

-    run_cli(duckdb_cli_path, db_path, openai_query)
-    run_cli(duckdb_cli_path, db_path, azure_query)
-    run_cli(duckdb_cli_path, db_path, ollama_query)
+    run_cli(duckdb_cli_path, db_path, openai_query, with_secrets=False)
+    run_cli(duckdb_cli_path, db_path, azure_query, with_secrets=False)
+    run_cli(duckdb_cli_path, db_path, ollama_query, with_secrets=False)

     get_all_query = "GET MODELS;"
-    result = run_cli(duckdb_cli_path, db_path, get_all_query)
+    result = run_cli(duckdb_cli_path, db_path, get_all_query, with_secrets=False)
     assert "openai-model" in result.stdout
     assert "azure-model" in result.stdout
     assert "ollama-model" in result.stdout
@@ -266,9 +266,9 @@ def test_create_model_without_semicolon(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('no-semicolon-model', 'gpt-4o', 'openai')"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET MODEL 'no-semicolon-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "no-semicolon-model" in result.stdout
@@ -277,9 +277,9 @@ def test_create_model_with_comment(integration_setup):
     create_query = (
         "CREATE MODEL('comment-model', 'gpt-4o', 'openai'); -- This is a comment"
     )
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET MODEL 'comment-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "comment-model" in result.stdout
@@ -287,27 +287,27 @@ def test_create_model_with_comment_before(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = """-- Create a test model
     CREATE MODEL('comment-before-model', 'gpt-4o', 'openai');"""
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET MODEL 'comment-before-model';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "comment-before-model" in result.stdout


 def test_delete_model_without_semicolon(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('delete-no-semi', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     delete_query = "DELETE MODEL 'delete-no-semi'"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)
     get_query = "GET MODEL 'delete-no-semi';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "delete-no-semi" not in result.stdout or result.stdout.strip() == ""


 def test_get_models_without_semicolon(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE MODEL('get-no-semi', 'gpt-4o', 'openai');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET MODELS"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "get-no-semi" in result.stdout
diff --git a/test/integration/src/integration/tests/prompt_parser/test_prompt_parser.py b/test/integration/src/integration/tests/prompt_parser/test_prompt_parser.py
index b41d9783..5926a76e 100644
--- a/test/integration/src/integration/tests/prompt_parser/test_prompt_parser.py
+++ b/test/integration/src/integration/tests/prompt_parser/test_prompt_parser.py
@@ -4,9 +4,9 @@ def test_create_and_get_prompt(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE PROMPT('test-prompt', 'Test prompt content');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET PROMPT 'test-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "test-prompt" in result.stdout
     assert "Test prompt content" in result.stdout
     assert "local" in result.stdout
@@ -15,22 +15,22 @@ def test_create_get_delete_global_prompt(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE GLOBAL PROMPT('global-test-prompt', 'Global test content');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET PROMPT 'global-test-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "global-test-prompt" in result.stdout
     assert "Global test content" in result.stdout
     assert "global" in result.stdout
     delete_query = "DELETE PROMPT 'global-test-prompt';"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)


 def test_create_local_prompt_explicit(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE LOCAL PROMPT('local-test-prompt', 'Local test content');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET PROMPT 'local-test-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "local-test-prompt" in result.stdout
     assert "Local test content" in result.stdout
     assert "local" in result.stdout
@@ -39,22 +39,22 @@ def test_delete_prompt(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE PROMPT('delete-test-prompt', 'To be deleted');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     delete_query = "DELETE PROMPT 'delete-test-prompt';"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)
     get_query = "GET PROMPT 'delete-test-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "delete-test-prompt" not in result.stdout or result.stdout.strip() == ""


 def test_update_prompt_content(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE PROMPT('update-test-prompt', 'Original content');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     update_query = "UPDATE PROMPT('update-test-prompt', 'Updated content');"
-    run_cli(duckdb_cli_path, db_path, update_query)
+    run_cli(duckdb_cli_path, db_path, update_query, with_secrets=False)
     get_query = "GET PROMPT 'update-test-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "update-test-prompt" in result.stdout
     assert "Updated content" in result.stdout
@@ -62,25 +62,25 @@ def test_update_prompt_scope_to_global(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE LOCAL PROMPT('scope-test-prompt', 'Test content');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     update_query = "UPDATE PROMPT 'scope-test-prompt' TO GLOBAL;"
-    run_cli(duckdb_cli_path, db_path, update_query)
+    run_cli(duckdb_cli_path, db_path, update_query, with_secrets=False)
     get_query = "GET PROMPT 'scope-test-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "scope-test-prompt" in result.stdout
     assert "global" in result.stdout
     delete_query = "DELETE PROMPT 'scope-test-prompt';"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)


 def test_update_prompt_scope_to_local(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE GLOBAL PROMPT('scope-test-prompt-2', 'Test content');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     update_query = "UPDATE PROMPT 'scope-test-prompt-2' TO LOCAL;"
-    run_cli(duckdb_cli_path, db_path, update_query)
+    run_cli(duckdb_cli_path, db_path, update_query, with_secrets=False)
     get_query = "GET PROMPT 'scope-test-prompt-2';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "scope-test-prompt-2" in result.stdout
     assert "local" in result.stdout
@@ -89,58 +89,58 @@ def test_get_all_prompts(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query1 = "CREATE PROMPT('prompt1', 'Content 1');"
     create_query2 = "CREATE GLOBAL PROMPT('prompt2', 'Content 2');"
-    run_cli(duckdb_cli_path, db_path, create_query1)
-    run_cli(duckdb_cli_path, db_path, create_query2)
+    run_cli(duckdb_cli_path, db_path, create_query1, with_secrets=False)
+    run_cli(duckdb_cli_path, db_path, create_query2, with_secrets=False)
     get_all_query = "GET PROMPTS;"
-    result = run_cli(duckdb_cli_path, db_path, get_all_query)
+    result = run_cli(duckdb_cli_path, db_path, get_all_query, with_secrets=False)
     assert "prompt1" in result.stdout
     assert "prompt2" in result.stdout
     assert "Content 1" in result.stdout
     assert "Content 2" in result.stdout
     delete_query = "DELETE PROMPT 'prompt2';"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)


 def test_create_prompt_duplicate_error(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE PROMPT('duplicate-prompt', 'Original');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     duplicate_query = "CREATE PROMPT('duplicate-prompt', 'Duplicate');"
-    result = run_cli(duckdb_cli_path, db_path, duplicate_query)
+    result = run_cli(duckdb_cli_path, db_path, duplicate_query, with_secrets=False)
     assert result.returncode != 0 or "already exist" in result.stderr


 def test_create_prompt_invalid_syntax(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     invalid_query1 = "CREATE PROMPT 'test', 'content');"
-    result1 = run_cli(duckdb_cli_path, db_path, invalid_query1)
+    result1 = run_cli(duckdb_cli_path, db_path, invalid_query1, with_secrets=False)
     assert result1.returncode != 0

     invalid_query2 = "CREATE PROMPT('test' 'content');"
-    result2 = run_cli(duckdb_cli_path, db_path, invalid_query2)
+    result2 = run_cli(duckdb_cli_path, db_path, invalid_query2, with_secrets=False)
     assert result2.returncode != 0

     invalid_query3 = "CREATE PROMPT('test', 'content';"
-    result3 = run_cli(duckdb_cli_path, db_path, invalid_query3)
+    result3 = run_cli(duckdb_cli_path, db_path, invalid_query3, with_secrets=False)
     assert result3.returncode != 0


 def test_delete_nonexistent_prompt(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     delete_query = "DELETE PROMPT 'nonexistent-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, delete_query)
+    result = run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)
     assert result.returncode == 0


 def test_update_nonexistent_prompt_error(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     update_query = "UPDATE PROMPT('nonexistent-prompt', 'New content');"
-    result = run_cli(duckdb_cli_path, db_path, update_query)
+    result = run_cli(duckdb_cli_path, db_path, update_query, with_secrets=False)
     assert result.returncode != 0 or "doesn't exist" in result.stderr


 def test_get_nonexistent_prompt(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     get_query = "GET PROMPT 'nonexistent-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert result.returncode == 0
     assert "nonexistent-prompt" not in result.stdout or result.stdout.strip() == ""
@@ -148,14 +148,14 @@ def test_empty_prompt_name_error(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     invalid_query = "CREATE PROMPT('', 'content');"
-    result = run_cli(duckdb_cli_path, db_path, invalid_query)
+    result = run_cli(duckdb_cli_path, db_path, invalid_query, with_secrets=False)
     assert result.returncode != 0


 def test_empty_prompt_content_error(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     invalid_query = "CREATE PROMPT('test', '');"
-    result = run_cli(duckdb_cli_path, db_path, invalid_query)
+    result = run_cli(duckdb_cli_path, db_path, invalid_query, with_secrets=False)
     assert result.returncode != 0
@@ -163,9 +163,9 @@ def test_create_prompt_without_semicolon(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE PROMPT('no-semi-prompt', 'Test content')"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET PROMPT 'no-semi-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "no-semi-prompt" in result.stdout
@@ -174,9 +174,9 @@ def test_create_prompt_with_comment(integration_setup):
     create_query = (
         "CREATE PROMPT('comment-prompt', 'Test content'); -- This is a comment"
     )
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET PROMPT 'comment-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "comment-prompt" in result.stdout
@@ -184,27 +184,27 @@ def test_create_prompt_with_comment_before(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = """-- Create a test prompt
     CREATE PROMPT('comment-before-prompt', 'Test content');"""
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET PROMPT 'comment-before-prompt';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "comment-before-prompt" in result.stdout


 def test_delete_prompt_without_semicolon(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE PROMPT('delete-no-semi', 'Test content');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     delete_query = "DELETE PROMPT 'delete-no-semi'"
-    run_cli(duckdb_cli_path, db_path, delete_query)
+    run_cli(duckdb_cli_path, db_path, delete_query, with_secrets=False)
     get_query = "GET PROMPT 'delete-no-semi';"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "delete-no-semi" not in result.stdout or result.stdout.strip() == ""


 def test_get_prompts_without_semicolon(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE PROMPT('get-no-semi', 'Test content');"
-    run_cli(duckdb_cli_path, db_path, create_query)
+    run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     get_query = "GET PROMPTS"
-    result = run_cli(duckdb_cli_path, db_path, get_query)
+    result = run_cli(duckdb_cli_path, db_path, get_query, with_secrets=False)
     assert "get-no-semi" in result.stdout
diff --git a/test/integration/src/integration/tests/secret_manager/test_secret_manager.py b/test/integration/src/integration/tests/secret_manager/test_secret_manager.py
index 4cb1f97c..73c3e06d 100644
--- a/test/integration/src/integration/tests/secret_manager/test_secret_manager.py
+++ b/test/integration/src/integration/tests/secret_manager/test_secret_manager.py
@@ -8,7 +8,7 @@ def test_create_openai_secret(integration_setup):
     create_query = (
         f"CREATE SECRET {secret_name} (TYPE OPENAI, API_KEY 'test-api-key-123');"
     )
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode == 0
@@ -16,7 +16,7 @@ def test_create_openai_secret_with_base_url(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     secret_name = "test_openai_secret_with_url"
     create_query = f"CREATE SECRET {secret_name} (TYPE OPENAI, API_KEY 'test-api-key-123', BASE_URL 'https://api.custom.com');"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode == 0
@@ -24,7 +24,7 @@ def test_create_azure_secret(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     secret_name = "test_azure_secret"
     create_query = f"CREATE SECRET {secret_name} (TYPE AZURE_LLM, API_KEY 'test-azure-key', RESOURCE_NAME 'test-resource', API_VERSION '2023-05-15');"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode == 0
@@ -34,7 +34,7 @@ def test_create_ollama_secret(integration_setup):
     create_query = (
         f"CREATE SECRET {secret_name} (TYPE OLLAMA, API_URL 'http://localhost:11434');"
     )
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode == 0
@@ -44,7 +44,7 @@ def test_create_openai_secret_missing_required_field(integration_setup):
     create_query = (
         f"CREATE SECRET {secret_name} (TYPE OPENAI, BASE_URL 'https://api.openai.com');"
     )
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode != 0
@@ -52,13 +52,13 @@ def test_create_azure_secret_missing_required_fields(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     secret_name = "test_azure_invalid"
     create_query = f"CREATE SECRET {secret_name} (TYPE AZURE_LLM, RESOURCE_NAME 'test-resource', API_VERSION '2023-05-15');"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode != 0

     create_query = f"CREATE SECRET {secret_name} (TYPE AZURE_LLM, API_KEY 'test-key', API_VERSION '2023-05-15');"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode != 0

     create_query = f"CREATE SECRET {secret_name} (TYPE AZURE_LLM, API_KEY 'test-key', RESOURCE_NAME 'test-resource');"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode != 0
@@ -66,7 +66,7 @@ def test_create_ollama_secret_missing_required_field(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     secret_name = "test_ollama_invalid"
     create_query = f"CREATE SECRET {secret_name} (TYPE OLLAMA);"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode != 0
@@ -76,14 +76,14 @@ def test_create_secret_with_unsupported_type(integration_setup):
     create_query = (
         f"CREATE SECRET {secret_name} (TYPE UNSUPPORTED_TYPE, API_KEY 'test-key');"
     )
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode != 0


 def test_create_secret_empty_name(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     create_query = "CREATE SECRET '' (TYPE OPENAI, API_KEY 'test-key');"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode != 0
@@ -91,7 +91,7 @@ def test_create_secret_empty_api_key(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     secret_name = "test_empty_key_secret"
     create_query = f"CREATE SECRET {secret_name} (TYPE OPENAI, API_KEY '');"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode != 0
@@ -102,17 +102,17 @@ def test_multiple_secrets_different_types(integration_setup):
     create_openai = (
         f"CREATE SECRET {openai_secret} (TYPE OPENAI, API_KEY 'openai-key');"
     )
-    result = run_cli(duckdb_cli_path, db_path, create_openai)
+    result = run_cli(duckdb_cli_path, db_path, create_openai, with_secrets=False)
     assert result.returncode == 0
     secrets.append(openai_secret)

     azure_secret = "test_multi_azure"
     create_azure = f"CREATE SECRET {azure_secret} (TYPE AZURE_LLM, API_KEY 'azure-key', RESOURCE_NAME 'resource', API_VERSION '2023-05-15');"
-    result = run_cli(duckdb_cli_path, db_path, create_azure)
+    result = run_cli(duckdb_cli_path, db_path, create_azure, with_secrets=False)
     assert result.returncode == 0
     secrets.append(azure_secret)

     ollama_secret = "test_multi_ollama"
     create_ollama = f"CREATE SECRET {ollama_secret} (TYPE OLLAMA, API_URL 'http://localhost:11434');"
-    result = run_cli(duckdb_cli_path, db_path, create_ollama)
+    result = run_cli(duckdb_cli_path, db_path, create_ollama, with_secrets=False)
     assert result.returncode == 0
     secrets.append(ollama_secret)
@@ -121,7 +121,7 @@ def test_secret_scope_handling(integration_setup):
     duckdb_cli_path, db_path = integration_setup
     secret_name = "test_scope_secret"
     create_query = f"CREATE SECRET {secret_name} (TYPE OPENAI, API_KEY 'test-key');"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode == 0
@@ -147,20 +147,20 @@ def test_create_secrets_parametrized(integration_setup, provider_type, required_
         [f"{key} '{value}'" for key, value in required_fields.items()]
     )
     create_query = f"CREATE SECRET {secret_name} (TYPE {provider_type}, {fields_str});"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode == 0


 def test_persistent_secret_lifecycle(integration_setup):
     duckdb_cli_path, db_path = integration_setup
-    secret_name = "test_openai_secret"
-
+    secret_name = "test_persistent_lifecycle_secret"
+
     create_query = f"CREATE PERSISTENT SECRET {secret_name} (TYPE OPENAI, API_KEY 'test-persistent-key');"
-    result = run_cli(duckdb_cli_path, db_path, create_query)
+    result = run_cli(duckdb_cli_path, db_path, create_query, with_secrets=False)
     assert result.returncode == 0

     check_query = f"SELECT name, type, persistent FROM duckdb_secrets() WHERE name = '{secret_name}';"
-    result = run_cli(duckdb_cli_path, db_path, check_query)
+    result = run_cli(duckdb_cli_path, db_path, check_query, with_secrets=False)
     assert secret_name in result.stdout
     assert "OPENAI" in result.stdout or "openai" in result.stdout
     assert (
@@ -170,9 +170,9 @@
     )

     drop_query = f"DROP PERSISTENT SECRET {secret_name};"
-    result = run_cli(duckdb_cli_path, db_path, drop_query)
+    result = run_cli(duckdb_cli_path, db_path, drop_query, with_secrets=False)
     assert result.returncode == 0

     check_query = f"SELECT name FROM duckdb_secrets() WHERE name = '{secret_name}';"
-    result = run_cli(duckdb_cli_path, db_path, check_query)
+    result = run_cli(duckdb_cli_path, db_path, check_query, with_secrets=False)
     assert secret_name not in result.stdout
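Every integration test above threads with_secrets=False through run_cli, whose definition lives in the suite's conftest and is not part of this diff; a hypothetical sketch of the assumed shape, for orientation only (the flag name is real, everything else here is illustrative):

import subprocess

def run_cli(duckdb_cli_path, db_path, query, with_secrets=True):
    # Hypothetical: when with_secrets is True, provider secrets are created
    # before the query so LLM calls can authenticate; parser and secret
    # manager tests opt out because they never reach a provider.
    statements = []
    if with_secrets:
        statements.append("CREATE SECRET (TYPE OPENAI, API_KEY 'test-key');")
    statements.append(query)
    return subprocess.run(
        [duckdb_cli_path, "-csv", db_path],
        input="\n".join(statements),
        capture_output=True,
        text=True,
    )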
diff --git a/test/unit/functions/aggregate/llm_first.cpp b/test/unit/functions/aggregate/llm_first.cpp index 61f0ba89..e9f8132d 100644 --- a/test/unit/functions/aggregate/llm_first.cpp +++ b/test/unit/functions/aggregate/llm_first.cpp @@ -5,13 +5,11 @@ namespace flock { class LLMFirstTest : public LLMAggregateTestBase { protected: - // The LLM response (for mocking) static constexpr const char* LLM_RESPONSE = R"({"items":[0]})"; - // The expected function output (selected data) - static constexpr const char* EXPECTED_RESPONSE = R"([{"data":["High-performance running shoes with advanced cushioning"]}])"; + static constexpr const char* EXPECTED_RESPONSE_SINGLE = R"([{"data":["High-performance running shoes with advanced cushioning"]}])"; std::string GetExpectedResponse() const override { - return EXPECTED_RESPONSE; + return EXPECTED_RESPONSE_SINGLE; } nlohmann::json GetExpectedJsonResponse() const override { @@ -39,8 +37,30 @@ class LLMFirstTest : public LLMAggregateTestBase { } }; -// Test llm_first with SQL queries without GROUP BY - new API -TEST_F(LLMFirstTest, LLMFirstWithoutGroupBy) { +// Test 1-tuple case: no LLM call needed, returns the single tuple directly +TEST_F(LLMFirstTest, SingleTupleNoLLMCall) { + // No mock expectations - LLM should NOT be called for single tuple + auto con = Config::GetConnection(); + + const auto results = con.Query( + "SELECT llm_first(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Select the first product', 'context_columns': [{'data': description}]}" + ") AS first_product FROM VALUES " + "('High-performance running shoes with advanced cushioning') AS products(description);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); + + nlohmann::json parsed = nlohmann::json::parse(results->GetValue(0, 0).GetValue<std::string>()); + EXPECT_EQ(parsed.size(), 1); + EXPECT_TRUE(parsed[0].contains("data")); + EXPECT_EQ(parsed[0]["data"].size(), 1); + EXPECT_EQ(parsed[0]["data"][0], "High-performance running shoes with advanced cushioning"); +} + +// Test multiple tuples without GROUP BY: LLM is called once +TEST_F(LLMFirstTest, MultipleTuplesWithoutGroupBy) { EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) .Times(1); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) @@ -49,43 +69,72 @@ TEST_F(LLMFirstTest, LLMFirstWithoutGroupBy) { auto con = Config::GetConnection(); const auto results = con.Query( - "SELECT " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'What is the most relevant detail for these products, based on their names and descriptions?', 'context_columns': [{'data': description}]}" - ") AS first_product_feature FROM VALUES " - "('High-performance running shoes with advanced cushioning'), " - "('Wireless noise-cancelling headphones for immersive audio'), " - "('Smart fitness tracker with heart rate monitoring') AS products(description);"); - + "SELECT llm_first(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'What is the most relevant product?', 'context_columns': [{'data': description}]}" + ") AS first_product FROM VALUES " + "('High-performance running shoes with advanced cushioning'), " + "('Wireless noise-cancelling headphones for immersive audio'), " + "('Smart fitness tracker with heart rate monitoring') AS products(description);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); ASSERT_EQ(results->RowCount(), 1); ASSERT_EQ(results->GetValue(0, 0).GetValue<std::string>(), GetExpectedResponse()); } -// Test llm_first with SQL queries with GROUP BY - new API -TEST_F(LLMFirstTest, LLMFirstWithGroupBy) { +// Test GROUP BY with multiple tuples per group: LLM is called for each group +TEST_F(LLMFirstTest, GroupByWithMultipleTuplesPerGroup) { EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(3); + .Times(2); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(3) + .Times(2) .WillRepeatedly(::testing::Return(std::vector<nlohmann::json>{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); const auto results = con.Query( - "SELECT category, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'What is the most relevant detail for these products, based on their names and descriptions?', 'context_columns': [{'data': description}]}" - ") AS first_feature FROM VALUES " - "('electronics', 'High-performance running shoes with advanced cushioning'), " - "('audio', 'Wireless noise-cancelling headphones for immersive audio'), " - "('fitness', 'Smart fitness tracker with heart rate 
monitoring') " - "AS products(category, description) GROUP BY category;"); + "SELECT category, llm_first(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Select the most relevant product', 'context_columns': [{'data': description}]}" + ") AS first_product FROM VALUES " + "('footwear', 'Running shoes with cushioning'), " + "('footwear', 'Business shoes for professionals'), " + "('electronics', 'Wireless headphones'), " + "('electronics', 'Smart fitness tracker') " + "AS products(category, description) GROUP BY category;"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 2); + for (idx_t i = 0; i < results->RowCount(); i++) { + EXPECT_NO_THROW({ + nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); + EXPECT_TRUE(parsed[0].contains("data")); + }); + } +} +// Test GROUP BY with single tuple per group: no LLM calls needed +TEST_F(LLMFirstTest, GroupByWithSingleTuplePerGroup) { + // No mock expectations - LLM should NOT be called when each group has only 1 tuple + auto con = Config::GetConnection(); + + const auto results = con.Query( + "SELECT category, llm_first(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Select the most relevant product', 'context_columns': [{'data': description}]}" + ") AS first_product FROM VALUES " + "('footwear', 'Running shoes with cushioning'), " + "('electronics', 'Wireless headphones'), " + "('fitness', 'Smart fitness tracker') " + "AS products(category, description) GROUP BY category;"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); ASSERT_EQ(results->RowCount(), 3); for (idx_t i = 0; i < results->RowCount(); i++) { EXPECT_NO_THROW({ nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); EXPECT_TRUE(parsed[0].contains("data")); + EXPECT_EQ(parsed[0]["data"].size(), 1); }); } } @@ -96,68 +145,62 @@ TEST_F(LLMFirstTest, ValidateArguments) { } // Test operation with invalid arguments -TEST_F(LLMFirstTest, Operation_InvalidArguments_ThrowsException) { +TEST_F(LLMFirstTest, InvalidArguments) { TestOperationInvalidArguments(); } -// Test operation with multiple input scenarios - new API -TEST_F(LLMFirstTest, Operation_MultipleInputs_ProcessesCorrectly) { - const nlohmann::json expected_response = GetExpectedJsonResponse(); +// Test with audio transcription +TEST_F(LLMFirstTest, AudioTranscription) { + const nlohmann::json expected_transcription1 = nlohmann::json::parse(R"({"text": "First audio candidate"})"); + const nlohmann::json expected_transcription2 = nlohmann::json::parse(R"({"text": "Second audio candidate"})"); + const nlohmann::json expected_complete_response = GetExpectedJsonResponse(); + + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data")) + .WillOnce(::testing::Return(std::vector{expected_transcription1, expected_transcription2})); EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(3); + .Times(1); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(3) - .WillRepeatedly(::testing::Return(std::vector{expected_response})); + .WillOnce(::testing::Return(std::vector{expected_complete_response})); auto con = Config::GetConnection(); - const auto results = con.Query( - "SELECT category, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'What is the most relevant product information?', 
'context_columns': [{'data': description}]}" - ") AS first_relevant_info FROM VALUES " - "('electronics', 'High-performance running shoes with advanced cushioning'), " - "('audio', 'Wireless noise-cancelling headphones for immersive audio'), " - "('fitness', 'Smart fitness tracker with heart rate monitoring') " - "AS products(category, description) GROUP BY category;"); - - ASSERT_EQ(results->RowCount(), 3); - for (idx_t i = 0; i < results->RowCount(); i++) { - EXPECT_NO_THROW({ - nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); - EXPECT_TRUE(parsed[0].contains("data")); - }); - } + "SELECT llm_first(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Select the best audio candidate', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gpt-4o-transcribe'}" + "]}) AS result FROM VALUES " + "('https://example.com/audio1.mp3'), " + "('https://example.com/audio2.mp3') AS tbl(audio_url);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); } -// Test large input set processing - new API -TEST_F(LLMFirstTest, Operation_LargeInputSet_ProcessesCorrectly) { - constexpr size_t input_count = 100; - const nlohmann::json expected_response = PrepareExpectedResponseForLargeInput(input_count); - - EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(100); - EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(100) - .WillRepeatedly(::testing::Return(std::vector{expected_response})); - +// Test audio transcription error handling for Ollama +TEST_F(LLMFirstTest, AudioTranscriptionOllamaError) { auto con = Config::GetConnection(); + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); const auto results = con.Query( - "SELECT id, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Select the first relevant product based on relevance', 'context_columns': [{'data': 'Product description ' || id::TEXT}]}" - ") AS first_relevant FROM range(" + - std::to_string(input_count) + ") AS t(id) GROUP BY id;"); - - ASSERT_EQ(results->RowCount(), 100); - for (idx_t i = 0; i < results->RowCount(); i++) { - EXPECT_NO_THROW({ - nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); - EXPECT_TRUE(parsed[0].contains("data")); - }); - } + "SELECT llm_first(" + "{'model_name': 'gemma3:4b'}, " + "{'prompt': 'Select the best audio', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gemma3:4b'}" + "]}) AS result FROM VALUES " + "('https://example.com/audio1.mp3'), " + "('https://example.com/audio2.mp3') AS tbl(audio_url);"); + + ASSERT_TRUE(results->HasError()); } }// namespace flock diff --git a/test/unit/functions/aggregate/llm_last.cpp b/test/unit/functions/aggregate/llm_last.cpp index bd09d450..f3f515d8 100644 --- a/test/unit/functions/aggregate/llm_last.cpp +++ b/test/unit/functions/aggregate/llm_last.cpp @@ -5,13 +5,11 @@ namespace flock { class LLMLastTest : public LLMAggregateTestBase { protected: - // The LLM response (for mocking) - for llm_last, it should select the last index - static constexpr const char* LLM_RESPONSE = R"({"items":[0]})"; - // The expected function output (selected data from the last position) - static constexpr const char* EXPECTED_RESPONSE = 
R"([{"data":["High-performance running shoes with advanced cushioning"]}])"; + static constexpr const char* LLM_RESPONSE = R"({"items":[2]})"; + static constexpr const char* EXPECTED_RESPONSE_SINGLE = R"([{"data":["Smart fitness tracker with heart rate monitoring"]}])"; std::string GetExpectedResponse() const override { - return EXPECTED_RESPONSE; + return EXPECTED_RESPONSE_SINGLE; } nlohmann::json GetExpectedJsonResponse() const override { @@ -27,11 +25,11 @@ class LLMLastTest : public LLMAggregateTestBase { } nlohmann::json PrepareExpectedResponseForBatch(const std::vector& responses) const override { - return nlohmann::json{{"selected", static_cast(responses.size() - 1)}}; + return nlohmann::json{{"items", {static_cast(responses.size() - 1)}}}; } nlohmann::json PrepareExpectedResponseForLargeInput(size_t input_count) const override { - return nlohmann::json{{"selected", static_cast(input_count - 1)}}; + return nlohmann::json{{"items", {static_cast(input_count - 1)}}}; } std::string FormatExpectedResult(const nlohmann::json& response) const override { @@ -39,8 +37,29 @@ class LLMLastTest : public LLMAggregateTestBase { } }; -// Test llm_last with SQL queries without GROUP BY - new API -TEST_F(LLMLastTest, LLMLastWithoutGroupBy) { +// Test 1-tuple case: no LLM call needed, returns the single tuple directly +TEST_F(LLMLastTest, SingleTupleNoLLMCall) { + auto con = Config::GetConnection(); + + const auto results = con.Query( + "SELECT llm_last(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Select the last product', 'context_columns': [{'data': description}]}" + ") AS last_product FROM VALUES " + "('High-performance running shoes with advanced cushioning') AS products(description);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); + + nlohmann::json parsed = nlohmann::json::parse(results->GetValue(0, 0).GetValue()); + EXPECT_EQ(parsed.size(), 1); + EXPECT_TRUE(parsed[0].contains("data")); + EXPECT_EQ(parsed[0]["data"].size(), 1); + EXPECT_EQ(parsed[0]["data"][0], "High-performance running shoes with advanced cushioning"); +} + +// Test multiple tuples without GROUP BY: LLM is called once +TEST_F(LLMLastTest, MultipleTuplesWithoutGroupBy) { EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) .Times(1); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) @@ -49,43 +68,73 @@ TEST_F(LLMLastTest, LLMLastWithoutGroupBy) { auto con = Config::GetConnection(); const auto results = con.Query( - "SELECT " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'What is the least relevant detail for these products, based on their names and descriptions?', 'context_columns': [{'data': description}]}" - ") AS last_product_feature FROM VALUES " - "('High-performance running shoes with advanced cushioning'), " - "('Wireless noise-cancelling headphones for immersive audio'), " - "('Smart fitness tracker with heart rate monitoring') AS products(description);"); - + "SELECT llm_last(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'What is the least relevant product?', 'context_columns': [{'data': description}]}" + ") AS last_product FROM VALUES " + "('High-performance running shoes with advanced cushioning'), " + "('Wireless noise-cancelling headphones for immersive audio'), " + "('Smart fitness tracker with heart rate monitoring') AS products(description);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); 
ASSERT_EQ(results->RowCount(), 1); ASSERT_EQ(results->GetValue(0, 0).GetValue(), GetExpectedResponse()); } -// Test llm_last with SQL queries with GROUP BY - new API -TEST_F(LLMLastTest, LLMLastWithGroupBy) { +// Test GROUP BY with multiple tuples per group: LLM is called for each group +TEST_F(LLMLastTest, GroupByWithMultipleTuplesPerGroup) { + nlohmann::json response_index_1 = nlohmann::json{{"items", {1}}}; + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(3); + .Times(2); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(3) - .WillRepeatedly(::testing::Return(std::vector{GetExpectedJsonResponse()})); + .Times(2) + .WillRepeatedly(::testing::Return(std::vector{response_index_1})); auto con = Config::GetConnection(); const auto results = con.Query( - "SELECT category, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'What is the least relevant detail for these products, based on their names and descriptions?', 'context_columns': [{'data': description}]}" - ") AS last_feature FROM VALUES " - "('electronics', 'High-performance running shoes with advanced cushioning'), " - "('audio', 'Wireless noise-cancelling headphones for immersive audio'), " - "('fitness', 'Smart fitness tracker with heart rate monitoring') " - "AS products(category, description) GROUP BY category;"); + "SELECT category, llm_last(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Select the least relevant product', 'context_columns': [{'data': description}]}" + ") AS last_product FROM VALUES " + "('footwear', 'Running shoes with cushioning'), " + "('footwear', 'Business shoes for professionals'), " + "('electronics', 'Wireless headphones'), " + "('electronics', 'Smart fitness tracker') " + "AS products(category, description) GROUP BY category;"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 2); + for (idx_t i = 0; i < results->RowCount(); i++) { + EXPECT_NO_THROW({ + nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); + EXPECT_TRUE(parsed[0].contains("data")); + }); + } +} +// Test GROUP BY with single tuple per group: no LLM calls needed +TEST_F(LLMLastTest, GroupByWithSingleTuplePerGroup) { + auto con = Config::GetConnection(); + + const auto results = con.Query( + "SELECT category, llm_last(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Select the least relevant product', 'context_columns': [{'data': description}]}" + ") AS last_product FROM VALUES " + "('footwear', 'Running shoes with cushioning'), " + "('electronics', 'Wireless headphones'), " + "('fitness', 'Smart fitness tracker') " + "AS products(category, description) GROUP BY category;"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); ASSERT_EQ(results->RowCount(), 3); for (idx_t i = 0; i < results->RowCount(); i++) { EXPECT_NO_THROW({ nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); EXPECT_TRUE(parsed[0].contains("data")); + EXPECT_EQ(parsed[0]["data"].size(), 1); }); } } @@ -96,68 +145,62 @@ TEST_F(LLMLastTest, ValidateArguments) { } // Test operation with invalid arguments -TEST_F(LLMLastTest, Operation_InvalidArguments_ThrowsException) { +TEST_F(LLMLastTest, InvalidArguments) { TestOperationInvalidArguments(); } -// Test operation with multiple input scenarios - new API -TEST_F(LLMLastTest, Operation_MultipleInputs_ProcessesCorrectly) { - const nlohmann::json expected_response = 
GetExpectedJsonResponse(); +// Test with audio transcription +TEST_F(LLMLastTest, AudioTranscription) { + const nlohmann::json expected_transcription1 = nlohmann::json::parse(R"({"text": "First audio candidate"})"); + const nlohmann::json expected_transcription2 = nlohmann::json::parse(R"({"text": "Last audio candidate"})"); + nlohmann::json response_index_1 = nlohmann::json{{"items", {1}}}; + + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data")) + .WillOnce(::testing::Return(std::vector{expected_transcription1, expected_transcription2})); EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(3); + .Times(1); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(3) - .WillRepeatedly(::testing::Return(std::vector{expected_response})); + .WillOnce(::testing::Return(std::vector{response_index_1})); auto con = Config::GetConnection(); - const auto results = con.Query( - "SELECT category, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'What is the least relevant product information?', 'context_columns': [{'data': description}]}" - ") AS last_relevant_info FROM VALUES " - "('electronics', 'High-performance running shoes with advanced cushioning'), " - "('audio', 'Wireless noise-cancelling headphones for immersive audio'), " - "('fitness', 'Smart fitness tracker with heart rate monitoring') " - "AS products(category, description) GROUP BY category;"); - - ASSERT_EQ(results->RowCount(), 3); - for (idx_t i = 0; i < results->RowCount(); i++) { - EXPECT_NO_THROW({ - nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); - EXPECT_TRUE(parsed[0].contains("data")); - }); - } + "SELECT llm_last(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Select the worst audio candidate', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gpt-4o-transcribe'}" + "]}) AS result FROM VALUES " + "('https://example.com/audio1.mp3'), " + "('https://example.com/audio2.mp3') AS tbl(audio_url);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); } -// Test large input set processing - new API -TEST_F(LLMLastTest, Operation_LargeInputSet_ProcessesCorrectly) { - constexpr size_t input_count = 100; - const nlohmann::json expected_response = GetExpectedJsonResponse(); - - EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(100); - EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(100) - .WillRepeatedly(::testing::Return(std::vector{expected_response})); - +// Test audio transcription error handling for Ollama +TEST_F(LLMLastTest, AudioTranscriptionOllamaError) { auto con = Config::GetConnection(); + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); const auto results = con.Query( - "SELECT id, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Select the last relevant product based on relevance', 'context_columns': [{'data': 'Product description ' || id::TEXT}]}" - ") AS last_relevant FROM range(" + - std::to_string(input_count) + ") AS t(id) GROUP BY id;"); - - ASSERT_EQ(results->RowCount(), 100); - for (idx_t i = 0; i < results->RowCount(); i++) { 
- EXPECT_NO_THROW({ - nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); - EXPECT_TRUE(parsed[0].contains("data")); - }); - } + "SELECT llm_last(" + "{'model_name': 'gemma3:4b'}, " + "{'prompt': 'Select the worst audio', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gemma3:4b'}" + "]}) AS result FROM VALUES " + "('https://example.com/audio1.mp3'), " + "('https://example.com/audio2.mp3') AS tbl(audio_url);"); + + ASSERT_TRUE(results->HasError()); } }// namespace flock diff --git a/test/unit/functions/aggregate/llm_reduce.cpp b/test/unit/functions/aggregate/llm_reduce.cpp index f7c425a9..50e27a4c 100644 --- a/test/unit/functions/aggregate/llm_reduce.cpp +++ b/test/unit/functions/aggregate/llm_reduce.cpp @@ -5,7 +5,7 @@ namespace flock { class LLMReduceTest : public LLMAggregateTestBase { protected: - static constexpr const char* EXPECTED_RESPONSE = "A comprehensive summary of running shoes, wireless headphones, and smart watches, featuring advanced technology and user-friendly designs for active lifestyles."; + static constexpr const char* EXPECTED_RESPONSE = "A comprehensive summary of products."; std::string GetExpectedResponse() const override { return EXPECTED_RESPONSE; @@ -39,8 +39,8 @@ class LLMReduceTest : public LLMAggregateTestBase { } }; -// Test llm_reduce with SQL queries without GROUP BY - new API -TEST_F(LLMReduceTest, LLMReduceWithoutGroupBy) { +// Test single tuple: LLM is still called for reduce (to summarize) +TEST_F(LLMReduceTest, SingleTupleWithLLMCall) { EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) .Times(1); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) @@ -49,20 +49,69 @@ TEST_F(LLMReduceTest, LLMReduceWithoutGroupBy) { auto con = Config::GetConnection(); const auto results = con.Query( - "SELECT " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Summarize the following product descriptions', 'context_columns': [{'data': description}]}" - ") AS product_summary FROM VALUES " - "('High-performance running shoes with advanced cushioning'), " - "('Wireless noise-cancelling headphones for immersive audio'), " - "('Smart fitness tracker with heart rate monitoring') AS products(description);"); + "SELECT llm_reduce(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Summarize the following product descriptions', 'context_columns': [{'data': description}]}" + ") AS product_summary FROM VALUES " + "('High-performance running shoes with advanced cushioning') AS products(description);"); + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); ASSERT_EQ(results->RowCount(), 1); ASSERT_EQ(results->GetValue(0, 0).GetValue(), GetExpectedResponse()); } -// Test llm_reduce with SQL queries with GROUP BY - new API -TEST_F(LLMReduceTest, LLMReduceWithGroupBy) { +// Test multiple tuples without GROUP BY: LLM is called once +TEST_F(LLMReduceTest, MultipleTuplesWithoutGroupBy) { + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{GetExpectedJsonResponse()})); + + auto con = Config::GetConnection(); + + const auto results = con.Query( + "SELECT llm_reduce(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Summarize the following product descriptions', 'context_columns': [{'data': description}]}" + ") AS product_summary 
FROM VALUES " + "('High-performance running shoes with advanced cushioning'), " + "('Wireless noise-cancelling headphones for immersive audio'), " + "('Smart fitness tracker with heart rate monitoring') AS products(description);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); + ASSERT_EQ(results->GetValue(0, 0).GetValue(), GetExpectedResponse()); +} + +// Test GROUP BY with multiple tuples per group: LLM is called for each group +TEST_F(LLMReduceTest, GroupByWithMultipleTuplesPerGroup) { + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) + .Times(2); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .Times(2) + .WillRepeatedly(::testing::Return(std::vector{GetExpectedJsonResponse()})); + + auto con = Config::GetConnection(); + + const auto results = con.Query( + "SELECT category, llm_reduce(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Summarize the following product descriptions', 'context_columns': [{'data': description}]}" + ") AS description_summary FROM VALUES " + "('footwear', 'Running shoes with cushioning'), " + "('footwear', 'Business shoes for professionals'), " + "('electronics', 'Wireless headphones'), " + "('electronics', 'Smart fitness tracker') " + "AS products(category, description) GROUP BY category;"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 2); + ASSERT_EQ(results->GetValue(1, 0).GetValue(), GetExpectedResponse()); + ASSERT_EQ(results->GetValue(1, 1).GetValue(), GetExpectedResponse()); +} + +// Test GROUP BY with single tuple per group: LLM is still called (reduce always calls LLM) +TEST_F(LLMReduceTest, GroupByWithSingleTuplePerGroup) { EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) .Times(3); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) @@ -72,15 +121,16 @@ TEST_F(LLMReduceTest, LLMReduceWithGroupBy) { auto con = Config::GetConnection(); const auto results = con.Query( - "SELECT category, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Summarize the following product descriptions', 'context_columns': [{'data': description}]}" - ") AS description_summary FROM VALUES " - "('electronics', 'High-performance running shoes with advanced cushioning'), " - "('audio', 'Wireless noise-cancelling headphones for immersive audio'), " - "('fitness', 'Smart fitness tracker with heart rate monitoring') " - "AS products(category, description) GROUP BY category;"); - + "SELECT category, llm_reduce(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Summarize the following product descriptions', 'context_columns': [{'data': description}]}" + ") AS description_summary FROM VALUES " + "('electronics', 'Running shoes with advanced cushioning'), " + "('audio', 'Wireless noise-cancelling headphones'), " + "('fitness', 'Smart fitness tracker with heart rate monitoring') " + "AS products(category, description) GROUP BY category;"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); ASSERT_EQ(results->RowCount(), 3); ASSERT_EQ(results->GetValue(1, 0).GetValue(), GetExpectedResponse()); ASSERT_EQ(results->GetValue(1, 1).GetValue(), GetExpectedResponse()); @@ -93,62 +143,86 @@ TEST_F(LLMReduceTest, ValidateArguments) { } // Test operation with invalid arguments -TEST_F(LLMReduceTest, Operation_InvalidArguments_ThrowsException) { +TEST_F(LLMReduceTest, 
InvalidArguments) { TestOperationInvalidArguments(); } -// Test operation with multiple input scenarios - new API -TEST_F(LLMReduceTest, Operation_MultipleInputs_ProcessesCorrectly) { - const nlohmann::json expected_response = GetExpectedJsonResponse(); +// Test with audio transcription +TEST_F(LLMReduceTest, AudioTranscription) { + const nlohmann::json expected_transcription = nlohmann::json::parse(R"({"text": "This is a transcribed audio summary"})"); + + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data")) + .WillOnce(::testing::Return(std::vector{expected_transcription})); EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(3); + .Times(1); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(3) - .WillRepeatedly(::testing::Return(std::vector{expected_response})); + .WillOnce(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); - const auto results = con.Query( - "SELECT name, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Summarize the following product information', 'context_columns': [{'data': name}, {'data': description}]}" - ") AS comprehensive_summary FROM VALUES " - "('Running Shoes', 'High-performance running shoes with advanced cushioning'), " - "('Headphones', 'Wireless noise-cancelling headphones for immersive audio'), " - "('Fitness Tracker', 'Smart fitness tracker with heart rate monitoring') " - "AS products(name, description) GROUP BY name;"); - - ASSERT_EQ(results->RowCount(), 3); - ASSERT_EQ(results->GetValue(1, 0).GetValue(), GetExpectedResponse()); - ASSERT_EQ(results->GetValue(1, 1).GetValue(), GetExpectedResponse()); - ASSERT_EQ(results->GetValue(1, 2).GetValue(), GetExpectedResponse()); + "SELECT llm_reduce(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Summarize the following audio content', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gpt-4o-transcribe'}" + "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); } -// Test large input set processing - new API -TEST_F(LLMReduceTest, Operation_LargeInputSet_ProcessesCorrectly) { - constexpr size_t input_count = 100; - const nlohmann::json expected_response = PrepareExpectedResponseForLargeInput(input_count); +// Test with audio and text columns +TEST_F(LLMReduceTest, AudioAndTextColumns) { + const nlohmann::json expected_transcription = nlohmann::json::parse(R"({"text": "Product audio review"})"); + + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data")) + .WillOnce(::testing::Return(std::vector{expected_transcription})); EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(100); + .Times(1); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(100) - .WillRepeatedly(::testing::Return(std::vector{expected_response})); + .WillOnce(::testing::Return(std::vector{GetExpectedJsonResponse()})); auto con = Config::GetConnection(); + const auto results = con.Query( + "SELECT llm_reduce(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Summarize the product reviews', " + "'context_columns': [" + 
"{'data': text_review, 'name': 'text_review'}, " + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gpt-4o-transcribe'}" + "]}) AS result FROM VALUES ('Great product', 'https://example.com/audio.mp3') AS tbl(text_review, audio_url);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); +} + +// Test audio transcription error handling for Ollama +TEST_F(LLMReduceTest, AudioTranscriptionOllamaError) { + auto con = Config::GetConnection(); + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); const auto results = con.Query( - "SELECT id, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Summarize all product descriptions', 'context_columns': [{'data': 'Product description ' || id::TEXT}]}" - ") AS large_summary FROM range(" + - std::to_string(input_count) + ") AS t(id) GROUP BY id;"); - - ASSERT_EQ(results->RowCount(), 100); - for (size_t i = 0; i < input_count; i++) { - ASSERT_EQ(results->GetValue(1, i).GetValue(), FormatExpectedResult(expected_response)); - } + "SELECT llm_reduce(" + "{'model_name': 'gemma3:4b'}, " + "{'prompt': 'Summarize this audio', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gemma3:4b'}" + "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); + + ASSERT_TRUE(results->HasError()); } }// namespace flock diff --git a/test/unit/functions/aggregate/llm_reduce_json.cpp b/test/unit/functions/aggregate/llm_reduce_json.cpp index 28c43f3b..cfef2ad2 100644 --- a/test/unit/functions/aggregate/llm_reduce_json.cpp +++ b/test/unit/functions/aggregate/llm_reduce_json.cpp @@ -5,7 +5,7 @@ namespace flock { class LLMReduceJsonTest : public LLMAggregateTestBase { protected: - static constexpr const char* EXPECTED_JSON_RESPONSE = R"({"items": [{"summary": "A comprehensive summary of running shoes, wireless headphones, and smart watches, featuring advanced technology and user-friendly designs for active lifestyles."}]})"; + static constexpr const char* EXPECTED_JSON_RESPONSE = R"({"items": [{"summary": "A comprehensive summary of some products"}]})"; std::string GetExpectedResponse() const override { return EXPECTED_JSON_RESPONSE; @@ -145,7 +145,7 @@ TEST_F(LLMReduceJsonTest, Operation_LargeInputSet_ProcessesCorrectly) { const auto results = con.Query( "SELECT id, " + GetFunctionName() + "(" "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Create a JSON summary of all product descriptions with summary, total_items, and status fields', 'context_columns': [{'data': id::TEXT}, {'data': 'Product description ' || id::TEXT}]}" + "{'prompt': 'Create a JSON summary of all product descriptions with summary, total_items, and status fields', 'context_columns': [{'data': id::VARCHAR}, {'data': 'Product description ' || id::VARCHAR}]}" ") AS large_json_summary FROM range(" + std::to_string(input_count) + ") AS t(id) GROUP BY id;"); diff --git a/test/unit/functions/aggregate/llm_rerank.cpp b/test/unit/functions/aggregate/llm_rerank.cpp index d2ddaf7a..7edbd556 100644 --- a/test/unit/functions/aggregate/llm_rerank.cpp +++ b/test/unit/functions/aggregate/llm_rerank.cpp @@ -6,18 +6,15 @@ namespace flock { class LLMRerankTest : public LLMAggregateTestBase { protected: - // The LLM response (for mocking) - returns ranking indices - static constexpr const char* LLM_RESPONSE_WITHOUT_GROUP_BY = 
R"({"items":[0, 1, 2]})"; - static constexpr const char* LLM_RESPONSE_WITH_GROUP_BY = R"({"items":[0]})"; - // The expected function output (reranked data as JSON array) - static constexpr const char* EXPECTED_RESPONSE = R"([{"product_description":"High-performance running shoes with advanced cushioning"},{"product_description":"Professional business shoes"},{"product_description":"Casual sneakers for everyday wear"}])"; + static constexpr const char* LLM_RESPONSE = R"({"items":[0, 1, 2]})"; + static constexpr const char* EXPECTED_RESPONSE_SINGLE = R"([{"data":["High-performance running shoes with advanced cushioning"]}])"; std::string GetExpectedResponse() const override { - return EXPECTED_RESPONSE; + return EXPECTED_RESPONSE_SINGLE; } nlohmann::json GetExpectedJsonResponse() const override { - return nlohmann::json::parse(LLM_RESPONSE_WITHOUT_GROUP_BY); + return nlohmann::json::parse(LLM_RESPONSE); } std::string GetFunctionName() const override { @@ -45,8 +42,29 @@ class LLMRerankTest : public LLMAggregateTestBase { } }; -// Test llm_rerank with SQL queries without GROUP BY - new API -TEST_F(LLMRerankTest, LLMRerankWithoutGroupBy) { +// Test 1-tuple case: no LLM call needed, returns the single tuple directly +TEST_F(LLMRerankTest, SingleTupleNoLLMCall) { + auto con = Config::GetConnection(); + + const auto results = con.Query( + "SELECT llm_rerank(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Rank these products', 'context_columns': [{'data': description}]}" + ") AS reranked_products FROM VALUES " + "('High-performance running shoes with advanced cushioning') AS products(description);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); + + nlohmann::json parsed = nlohmann::json::parse(results->GetValue(0, 0).GetValue()); + EXPECT_EQ(parsed.size(), 1); + EXPECT_TRUE(parsed[0].contains("data")); + EXPECT_EQ(parsed[0]["data"].size(), 1); + EXPECT_EQ(parsed[0]["data"][0], "High-performance running shoes with advanced cushioning"); +} + +// Test multiple tuples without GROUP BY: LLM is called once +TEST_F(LLMRerankTest, MultipleTuplesWithoutGroupBy) { EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) .Times(1); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) @@ -55,14 +73,15 @@ TEST_F(LLMRerankTest, LLMRerankWithoutGroupBy) { auto con = Config::GetConnection(); const auto results = con.Query( - "SELECT " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Rank these products by their relevance and quality based on descriptions', 'context_columns': [{'data': description}]}" - ") AS reranked_products FROM VALUES " - "('High-performance running shoes with advanced cushioning'), " - "('Professional business shoes'), " - "('Casual sneakers for everyday wear') AS products(description);"); - + "SELECT llm_rerank(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Rank these products by relevance', 'context_columns': [{'data': description}]}" + ") AS reranked_products FROM VALUES " + "('High-performance running shoes with advanced cushioning'), " + "('Professional business shoes'), " + "('Casual sneakers for everyday wear') AS products(description);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); ASSERT_EQ(results->RowCount(), 1); EXPECT_NO_THROW({ nlohmann::json parsed = nlohmann::json::parse(results->GetValue(0, 0).GetValue()); @@ -72,31 +91,59 @@ TEST_F(LLMRerankTest, LLMRerankWithoutGroupBy) { }); } -// 
Test llm_rerank with SQL queries with GROUP BY - new API -TEST_F(LLMRerankTest, LLMRerankWithGroupBy) { +// Test GROUP BY with multiple tuples per group: LLM is called for each group +TEST_F(LLMRerankTest, GroupByWithMultipleTuplesPerGroup) { + nlohmann::json response_2_items = nlohmann::json{{"items", {1, 0}}}; + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(3); + .Times(2); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(3) - .WillRepeatedly(::testing::Return(std::vector{nlohmann::json::parse(LLM_RESPONSE_WITH_GROUP_BY)})); + .Times(2) + .WillRepeatedly(::testing::Return(std::vector{response_2_items})); auto con = Config::GetConnection(); const auto results = con.Query( - "SELECT category, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Rank these products by their relevance and quality based on descriptions', 'context_columns': [{'data': description}]}" - ") AS reranked_products FROM VALUES " - "('electronics', 'High-performance running shoes with advanced cushioning'), " - "('audio', 'Professional business shoes'), " - "('fitness', 'Casual sneakers for everyday wear') " - "AS products(category, description) GROUP BY category;"); + "SELECT category, llm_rerank(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Rank these products by relevance', 'context_columns': [{'data': description}]}" + ") AS reranked_products FROM VALUES " + "('footwear', 'Running shoes with cushioning'), " + "('footwear', 'Business shoes for professionals'), " + "('electronics', 'Wireless headphones'), " + "('electronics', 'Smart fitness tracker') " + "AS products(category, description) GROUP BY category;"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 2); + for (idx_t i = 0; i < results->RowCount(); i++) { + EXPECT_NO_THROW({ + nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); + EXPECT_TRUE(parsed[0].contains("data")); + EXPECT_EQ(parsed[0]["data"].size(), 2); + }); + } +} + +// Test GROUP BY with single tuple per group: no LLM calls needed +TEST_F(LLMRerankTest, GroupByWithSingleTuplePerGroup) { + auto con = Config::GetConnection(); + const auto results = con.Query( + "SELECT category, llm_rerank(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Rank these products by relevance', 'context_columns': [{'data': description}]}" + ") AS reranked_products FROM VALUES " + "('footwear', 'Running shoes with cushioning'), " + "('electronics', 'Wireless headphones'), " + "('fitness', 'Smart fitness tracker') " + "AS products(category, description) GROUP BY category;"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); ASSERT_EQ(results->RowCount(), 3); for (idx_t i = 0; i < results->RowCount(); i++) { EXPECT_NO_THROW({ nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); - EXPECT_EQ(parsed.size(), 1); EXPECT_TRUE(parsed[0].contains("data")); EXPECT_EQ(parsed[0]["data"].size(), 1); }); @@ -109,74 +156,62 @@ TEST_F(LLMRerankTest, ValidateArguments) { } // Test operation with invalid arguments -TEST_F(LLMRerankTest, Operation_InvalidArguments_ThrowsException) { +TEST_F(LLMRerankTest, InvalidArguments) { TestOperationInvalidArguments(); } -// Test operation with multiple input scenarios - new API -TEST_F(LLMRerankTest, Operation_MultipleInputs_ProcessesCorrectly) { - const nlohmann::json expected_response = 
nlohmann::json::parse(LLM_RESPONSE_WITH_GROUP_BY); +// Test with audio transcription +TEST_F(LLMRerankTest, AudioTranscription) { + const nlohmann::json expected_transcription1 = nlohmann::json::parse(R"({"text": "First audio candidate"})"); + const nlohmann::json expected_transcription2 = nlohmann::json::parse(R"({"text": "Second audio candidate"})"); + nlohmann::json response_2_items = nlohmann::json{{"items", {1, 0}}}; + + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data")) + .WillOnce(::testing::Return(std::vector{expected_transcription1, expected_transcription2})); EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(3); + .Times(1); EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(3) - .WillRepeatedly(::testing::Return(std::vector{expected_response})); + .WillOnce(::testing::Return(std::vector{response_2_items})); auto con = Config::GetConnection(); - const auto results = con.Query( - "SELECT category, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Rank products by relevance to customer preferences', 'context_columns': [{'data': id::TEXT}, {'data': description}]}" - ") AS reranked_products FROM VALUES " - "('electronics', 1, 'High-performance running shoes with advanced cushioning'), " - "('audio', 2, 'Professional business shoes'), " - "('fitness', 3, 'Casual sneakers for everyday wear') " - "AS products(category, id, description) GROUP BY category;"); - - ASSERT_EQ(results->RowCount(), 3); - for (idx_t i = 0; i < results->RowCount(); i++) { - EXPECT_NO_THROW({ - nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue()); - EXPECT_EQ(parsed.size(), 2); - EXPECT_TRUE(parsed[0].contains("data")); - EXPECT_EQ(parsed[0]["data"].size(), 1); - }); - } + "SELECT llm_rerank(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Rank these audio candidates from best to worst', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gpt-4o-transcribe'}" + "]}) AS result FROM VALUES " + "('https://example.com/audio1.mp3'), " + "('https://example.com/audio2.mp3') AS tbl(audio_url);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); } -// Test large input set processing - new API -TEST_F(LLMRerankTest, Operation_LargeInputSet_ProcessesCorrectly) { - constexpr size_t input_count = 100; - const nlohmann::json expected_response = nlohmann::json::parse(LLM_RESPONSE_WITH_GROUP_BY); - - EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .Times(100); - EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) - .Times(100) - .WillRepeatedly(::testing::Return(std::vector{expected_response})); - +// Test audio transcription error handling for Ollama +TEST_F(LLMRerankTest, AudioTranscriptionOllamaError) { auto con = Config::GetConnection(); + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); const auto results = con.Query( - "SELECT id, " + GetFunctionName() + "(" - "{'model_name': 'gpt-4o'}, " - "{'prompt': 'Rerank products by relevance and importance', 'context_columns': [{'data': id::TEXT}, {'data': 'Product description ' || id::TEXT}]}" - ") AS reranked_products FROM range(" 
+ - std::to_string(input_count) + ") AS t(id) GROUP BY id;"); - - ASSERT_EQ(results->RowCount(), 100); - for (idx_t i = 0; i < results->RowCount(); i++) { - EXPECT_NO_THROW({ - nlohmann::json parsed = nlohmann::json::parse(results->GetValue(1, i).GetValue<std::string>()); - EXPECT_EQ(parsed.size(), 2); - EXPECT_TRUE(parsed[0].contains("data")); - EXPECT_EQ(parsed[0]["data"].size(), 1); - }); - } - - ::testing::Mock::AllowLeak(mock_provider.get()); + "SELECT llm_rerank(" + "{'model_name': 'gemma3:4b'}, " + "{'prompt': 'Rank these audio files', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gemma3:4b'}" + "]}) AS result FROM VALUES " + "('https://example.com/audio1.mp3'), " + "('https://example.com/audio2.mp3') AS tbl(audio_url);"); + + ASSERT_TRUE(results->HasError()); } }// namespace flock diff --git a/test/unit/functions/mock_provider.hpp b/test/unit/functions/mock_provider.hpp index d53c90e9..4a0b6f8c 100644 --- a/test/unit/functions/mock_provider.hpp +++ b/test/unit/functions/mock_provider.hpp @@ -10,8 +10,10 @@ class MockProvider : public IProvider { MOCK_METHOD(void, AddCompletionRequest, (const std::string& prompt, const int num_output_tuples, OutputType output_type, const nlohmann::json& media_data), (override)); MOCK_METHOD(void, AddEmbeddingRequest, (const std::vector<std::string>& inputs), (override)); + MOCK_METHOD(void, AddTranscriptionRequest, (const nlohmann::json& audio_files), (override)); MOCK_METHOD(std::vector<nlohmann::json>, CollectCompletions, (const std::string& contentType), (override)); MOCK_METHOD(std::vector<nlohmann::json>, CollectEmbeddings, (const std::string& contentType), (override)); + MOCK_METHOD(std::vector<nlohmann::json>, CollectTranscriptions, (const std::string& contentType), (override)); }; }// namespace flock
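The two transcription hooks added here mirror the existing request/collect pairs on the provider interface. Judging by the expectations in the aggregate tests above (a single AddTranscriptionRequest even when two audio URLs are involved, and one CollectTranscriptions call returning a transcript per file), requests appear to be batched and then drained in one round trip. A short GoogleMock sketch of that protocol, using a trimmed-down MiniProvider stand-in instead of the real MockProvider:

#include <gmock/gmock.h>
#include <nlohmann/json.hpp>
#include <string>
#include <vector>

// Hypothetical provider exposing only the transcription pair, for illustration.
struct MiniProvider {
    MOCK_METHOD(void, AddTranscriptionRequest, (const nlohmann::json& audio_files));
    MOCK_METHOD(std::vector<nlohmann::json>, CollectTranscriptions, (const std::string& contentType));
};

TEST(TranscriptionProtocolSketch, QueueThenCollect) {
    MiniProvider provider;
    EXPECT_CALL(provider, AddTranscriptionRequest(::testing::_)).Times(1);
    EXPECT_CALL(provider, CollectTranscriptions("multipart/form-data"))
        .WillOnce(::testing::Return(std::vector<nlohmann::json>{
                nlohmann::json::parse(R"({"text": "first transcript"})"),
                nlohmann::json::parse(R"({"text": "second transcript"})")}));

    // One batched request covering both files, then a single drain.
    provider.AddTranscriptionRequest(nlohmann::json::array(
            {"https://example.com/a.mp3", "https://example.com/b.mp3"}));
    const auto transcripts = provider.CollectTranscriptions("multipart/form-data");
    EXPECT_EQ(transcripts.size(), 2u);
}

The "multipart/form-data" literal matched here is the same content type the tests in this PR pin for the transcription path.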
diff --git a/test/unit/functions/scalar/llm_complete.cpp b/test/unit/functions/scalar/llm_complete.cpp index ccabea7f..c2936597 100644 --- a/test/unit/functions/scalar/llm_complete.cpp +++ b/test/unit/functions/scalar/llm_complete.cpp @@ -145,7 +145,7 @@ TEST_F(LLMCompleteTest, Operation_LargeInputSet_ProcessesCorrectly) { auto query = "SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'}, " + "{'prompt': 'Summarize the following text', " + - " 'context_columns': [{'data': 'Input text ' || i::TEXT}]}) AS result " + + " 'context_columns': [{'data': 'Input text ' || i::VARCHAR}]}) AS result " + "FROM range(" + std::to_string(input_count) + ") AS t(i);"; const auto results = con.Query(query); @@ -161,4 +161,105 @@ } } +// Test llm_complete with audio transcription +TEST_F(LLMCompleteTest, LLMCompleteWithAudioTranscription) { + const nlohmann::json expected_transcription = "{\"text\": \"This is a transcribed audio\"}"; + const nlohmann::json expected_complete_response = {{"items", {"Based on the transcription: This is a transcribed audio"}}}; + + // Mock transcription model + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data")) + .WillOnce(::testing::Return(std::vector<nlohmann::json>{expected_transcription})); + + // Mock completion model + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector<nlohmann::json>{expected_complete_response})); + + auto con = Config::GetConnection(); + const auto results = con.Query( + "SELECT llm_complete(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Summarize this audio', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gpt-4o-transcribe'}" + "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); +} + +// Test llm_complete with audio and text columns +TEST_F(LLMCompleteTest, LLMCompleteWithAudioAndText) { + const nlohmann::json expected_transcription = "{\"text\": \"Product audio description\"}"; + const nlohmann::json expected_complete_response = {{"items", {"Combined response"}}}; + + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data")) + .WillOnce(::testing::Return(std::vector<nlohmann::json>{expected_transcription})); + + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector<nlohmann::json>{expected_complete_response})); + + auto con = Config::GetConnection(); + const auto results = con.Query( + "SELECT llm_complete(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Describe this product', " + "'context_columns': [" + "{'data': product, 'name': 'product'}, " + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gpt-4o-transcribe'}" + "]}) AS result FROM VALUES ('Wireless Headphones', 'https://example.com/audio.mp3') AS tbl(product, audio_url);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); +} + +// Test audio transcription error handling +TEST_F(LLMCompleteTest, LLMCompleteAudioTranscriptionError) { + auto con = Config::GetConnection(); + // Mock transcription model to throw error (simulating Ollama behavior) + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); + + // Test with Ollama which doesn't support transcription + const auto results = con.Query( + "SELECT llm_complete(" + "{'model_name': 'gemma3:4b'}, " + "{'prompt': 'Summarize this audio', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gemma3:4b'}" + "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); + + // Should fail because Ollama doesn't support transcription + ASSERT_TRUE(results->HasError()); +} + +// Test audio transcription with missing transcription_model +TEST_F(LLMCompleteTest, LLMCompleteAudioMissingTranscriptionModel) { + auto con = Config::GetConnection(); + const auto results = con.Query( + "SELECT llm_complete(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Summarize this audio', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio'}" + "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); + + // Should fail because transcription_model is required for audio type + ASSERT_TRUE(results->HasError()); +} + }// namespace flock \ No newline at end of file diff --git a/test/unit/functions/scalar/llm_embedding.cpp b/test/unit/functions/scalar/llm_embedding.cpp index c4852dd5..b5d17e7c 100644 --- a/test/unit/functions/scalar/llm_embedding.cpp +++ b/test/unit/functions/scalar/llm_embedding.cpp @@ -143,7 +143,7 @@ TEST_F(LLMEmbeddingTest, 
Operation_LargeInputSet_ProcessesCorrectly) { .WillOnce(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); - const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'text-embedding-3-small'}, {'context_columns': [{'data': content}]}) AS embedding FROM range(" + std::to_string(input_count) + ") AS t(i), unnest(['Document content number ' || i::TEXT]) as tbl(content);"); + const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'text-embedding-3-small'}, {'context_columns': [{'data': content}]}) AS embedding FROM range(" + std::to_string(input_count) + ") AS t(i), unnest(['Document content number ' || i::VARCHAR]) as tbl(content);"); ASSERT_TRUE(!results->HasError()) << "Query failed: " << results->GetError(); ASSERT_EQ(results->RowCount(), input_count); diff --git a/test/unit/functions/scalar/llm_filter.cpp b/test/unit/functions/scalar/llm_filter.cpp index f7029ec8..66d1fbb0 100644 --- a/test/unit/functions/scalar/llm_filter.cpp +++ b/test/unit/functions/scalar/llm_filter.cpp @@ -62,6 +62,19 @@ TEST_F(LLMFilterTest, LLMFilterBasicUsage) { ASSERT_EQ(results->GetValue(0, 0).GetValue(), "true"); } +TEST_F(LLMFilterTest, LLMFilterWithoutContextColumns) { + const nlohmann::json expected_response = {{"items", {true}}}; + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_response})); + + auto con = Config::GetConnection(); + const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'}, {'prompt': 'Are you a Robot?'}) AS filter_result;"); + ASSERT_EQ(results->RowCount(), 1); + ASSERT_EQ(results->GetValue(0, 0).GetValue(), "true"); +} + TEST_F(LLMFilterTest, LLMFilterWithMultipleRows) { const nlohmann::json expected_response = {{"items", {true}}}; EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) @@ -105,7 +118,7 @@ TEST_F(LLMFilterTest, Operation_LargeInputSet_ProcessesCorrectly) { .WillOnce(::testing::Return(std::vector{expected_response})); auto con = Config::GetConnection(); - const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'}, {'prompt': 'Is this content spam?', 'context_columns': [{'data': content}]}) AS result FROM range(" + std::to_string(input_count) + ") AS t(i), unnest(['Content item ' || i::TEXT]) as tbl(content);"); + const auto results = con.Query("SELECT " + GetFunctionName() + "({'model_name': 'gpt-4o'}, {'prompt': 'Is this content spam?', 'context_columns': [{'data': content}]}) AS result FROM range(" + std::to_string(input_count) + ") AS t(i), unnest(['Content item ' || i::VARCHAR]) as tbl(content);"); ASSERT_TRUE(!results->HasError()) << "Query failed: " << results->GetError(); ASSERT_EQ(results->RowCount(), input_count); @@ -117,4 +130,90 @@ TEST_F(LLMFilterTest, Operation_LargeInputSet_ProcessesCorrectly) { } } +// Test llm_filter with audio transcription +TEST_F(LLMFilterTest, LLMFilterWithAudioTranscription) { + const nlohmann::json expected_transcription = "{\"text\": \"This audio contains positive sentiment\"}"; + const nlohmann::json expected_complete_response = {{"items", {true}}}; + + // Mock transcription model + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data")) + 
.WillOnce(::testing::Return(std::vector{expected_transcription})); + + // Mock completion model + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_complete_response})); + + auto con = Config::GetConnection(); + const auto results = con.Query( + "SELECT llm_filter(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Is the sentiment in this audio positive?', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gpt-4o-transcribe'}" + "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); +} + +// Test llm_filter with audio and text columns +TEST_F(LLMFilterTest, LLMFilterWithAudioAndText) { + const nlohmann::json expected_transcription = "{\"text\": \"Product review audio\"}"; + const nlohmann::json expected_complete_response = {{"items", {true}}}; + + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data")) + .WillOnce(::testing::Return(std::vector{expected_transcription})); + + EXPECT_CALL(*mock_provider, AddCompletionRequest(::testing::_, ::testing::_, ::testing::_, ::testing::_)) + .Times(1); + EXPECT_CALL(*mock_provider, CollectCompletions(::testing::_)) + .WillOnce(::testing::Return(std::vector{expected_complete_response})); + + auto con = Config::GetConnection(); + const auto results = con.Query( + "SELECT llm_filter(" + "{'model_name': 'gpt-4o'}, " + "{'prompt': 'Is this product review positive?', " + "'context_columns': [" + "{'data': text_review, 'name': 'text_review'}, " + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gpt-4o-transcribe'}" + "]}) AS result FROM VALUES ('Great product', 'https://example.com/audio.mp3') AS tbl(text_review, audio_url);"); + + ASSERT_FALSE(results->HasError()) << "Query failed: " << results->GetError(); + ASSERT_EQ(results->RowCount(), 1); +} + +// Test audio transcription error handling for Ollama +TEST_F(LLMFilterTest, LLMFilterAudioTranscriptionOllamaError) { + auto con = Config::GetConnection(); + + // Mock transcription model to throw error (simulating Ollama behavior) + EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) + .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); + + // Test with Ollama which doesn't support transcription + const auto results = con.Query( + "SELECT llm_filter(" + "{'model_name': 'gemma3:4b'}, " + "{'prompt': 'Is the sentiment positive?', " + "'context_columns': [" + "{'data': audio_url, " + "'type': 'audio', " + "'transcription_model': 'gemma3:4b'}" + "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); + + // Should fail because Ollama doesn't support transcription + ASSERT_TRUE(results->HasError()); +} + }// namespace flock diff --git a/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp b/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp index 487ab323..6aaba0fa 100644 --- a/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp +++ b/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp @@ -12,6 +12,9 @@ void LLMFunctionTestBase::SetUp() { con.Query(" CREATE 
diff --git a/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp b/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp
index 487ab323..6aaba0fa 100644
--- a/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp
+++ b/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp
@@ -12,6 +12,9 @@ void LLMFunctionTestBase::SetUp() {
     con.Query(" CREATE SECRET ("
               " TYPE OPENAI,"
               " API_KEY 'your-api-key');");
+    con.Query(" CREATE SECRET ("
+              " TYPE OLLAMA,"
+              " API_URL '127.0.0.1:11434');");
 
     mock_provider = std::make_shared<MockProvider>(ModelDetails{});
     Model::SetMockProvider(mock_provider);
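The expectations set on mock_provider across these files imply the surface the mock must expose. One plausible shape for the gmock double, inferred purely from the arities exercised in the tests (a four-argument completion request, content-type-keyed collectors); the class name here is a stand-in, and flock's real MockProvider presumably derives from its provider base class:

    #include <gmock/gmock.h>
    #include <string>
    #include <vector>
    #include <nlohmann/json.hpp>

    // Sketch only: parameter types are assumptions, not taken from this patch.
    class MockProviderSketch {
    public:
        using json = nlohmann::json;
        MOCK_METHOD(void, AddCompletionRequest, (const json&, const json&, const json&, const json&));
        MOCK_METHOD(std::vector<json>, CollectCompletions, (const std::string&));
        MOCK_METHOD(void, AddEmbeddingRequest, (const json&));
        MOCK_METHOD(std::vector<json>, CollectEmbeddings, (const std::string&));
        MOCK_METHOD(void, AddTranscriptionRequest, (const json&));
        MOCK_METHOD(std::vector<json>, CollectTranscriptions, (const std::string&));
    };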
diff --git a/test/unit/functions/scalar/metrics_test.cpp b/test/unit/functions/scalar/metrics_test.cpp
new file mode 100644
index 00000000..1e233c2e
--- /dev/null
+++ b/test/unit/functions/scalar/metrics_test.cpp
@@ -0,0 +1,545 @@
+#include "flock/core/config.hpp"
+#include "flock/metrics/manager.hpp"
+#include <gtest/gtest.h>
+
+namespace flock {
+
+class MetricsTest : public ::testing::Test {
+protected:
+    void SetUp() override {
+        auto con = Config::GetConnection();
+        // Reset metrics before each test to ensure clean state
+        auto& manager = MetricsManager::GetForDatabase(GetDatabase());
+        manager.Reset();
+    }
+
+    duckdb::DatabaseInstance* GetDatabase() {
+        return Config::db;
+    }
+
+    MetricsManager& GetMetricsManager() {
+        return MetricsManager::GetForDatabase(GetDatabase());
+    }
+};
+
+TEST_F(MetricsTest, InitialMetricsAreZero) {
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    EXPECT_TRUE(metrics.is_object());
+    EXPECT_TRUE(metrics.empty());
+}
+
+TEST_F(MetricsTest, UpdateTokensForLlmComplete) {
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0x1234);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_COMPLETE);
+    MetricsManager::UpdateTokens(100, 50);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    bool found = false;
+    for (const auto& [key, value]: metrics.items()) {
+        if (key.find("llm_complete_") == 0) {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 100);
+            EXPECT_EQ(value["output_tokens"].get<int64_t>(), 50);
+            EXPECT_EQ(value["total_tokens"].get<int64_t>(), 150);
+            found = true;
+            break;
+        }
+    }
+    EXPECT_TRUE(found);
+}
+
+TEST_F(MetricsTest, TracksDifferentFunctionsSeparately) {
+    auto* db = GetDatabase();
+    const void* state_id1 = reinterpret_cast<const void*>(0x1234);
+    const void* state_id2 = reinterpret_cast<const void*>(0x5678);
+
+    MetricsManager::StartInvocation(db, state_id1, FunctionType::LLM_COMPLETE);
+    MetricsManager::UpdateTokens(100, 50);
+    MetricsManager::AddExecutionTime(1000.0);
+
+    MetricsManager::StartInvocation(db, state_id2, FunctionType::LLM_FILTER);
+    MetricsManager::UpdateTokens(200, 100);
+    MetricsManager::AddExecutionTime(2000.0);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    bool found_complete = false;
+    bool found_filter = false;
+    int64_t total_input = 0;
+    int64_t total_output = 0;
+
+    for (const auto& [key, value]: metrics.items()) {
+        if (key.find("llm_complete_") == 0) {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 100);
+            total_input += value["input_tokens"].get<int64_t>();
+            total_output += value["output_tokens"].get<int64_t>();
+            found_complete = true;
+        } else if (key.find("llm_filter_") == 0) {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 200);
+            total_input += value["input_tokens"].get<int64_t>();
+            total_output += value["output_tokens"].get<int64_t>();
+            found_filter = true;
+        }
+    }
+
+    EXPECT_TRUE(found_complete);
+    EXPECT_TRUE(found_filter);
+    EXPECT_EQ(total_input, 300);
+    EXPECT_EQ(total_output, 150);
+}
+
+TEST_F(MetricsTest, IncrementApiCalls) {
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0x1234);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_COMPLETE);
+    MetricsManager::IncrementApiCalls();
+    MetricsManager::IncrementApiCalls();
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_FILTER);
+    MetricsManager::IncrementApiCalls();
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    int64_t total_api_calls = 0;
+    int64_t complete_calls = 0;
+    int64_t filter_calls = 0;
+
+    for (const auto& [key, value]: metrics.items()) {
+        if (key.find("llm_complete_") == 0) {
+            complete_calls = value["api_calls"].get<int64_t>();
+            total_api_calls += complete_calls;
+        } else if (key.find("llm_filter_") == 0) {
+            filter_calls = value["api_calls"].get<int64_t>();
+            total_api_calls += filter_calls;
+        }
+    }
+
+    EXPECT_EQ(total_api_calls, 3);
+    EXPECT_EQ(complete_calls, 2);
+    EXPECT_EQ(filter_calls, 1);
+}
+
+TEST_F(MetricsTest, AddApiDuration) {
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0x1234);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_COMPLETE);
+    MetricsManager::AddApiDuration(100.5);
+    MetricsManager::AddApiDuration(200.25);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    bool found = false;
+    for (const auto& [key, value]: metrics.items()) {
+        if (key.find("llm_complete_") == 0) {
+            EXPECT_NEAR(value["api_duration_ms"].get<double>(), 300.75, 0.01);
+            found = true;
+            break;
+        }
+    }
+    EXPECT_TRUE(found);
+}
+
+TEST_F(MetricsTest, AddExecutionTime) {
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0x1234);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_COMPLETE);
+    MetricsManager::AddExecutionTime(150.0);
+    MetricsManager::AddExecutionTime(250.0);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    bool found = false;
+    for (const auto& [key, value]: metrics.items()) {
+        if (key.find("llm_complete_") == 0) {
+            EXPECT_NEAR(value["execution_time_ms"].get<double>(), 400.0, 0.01);
+            found = true;
+            break;
+        }
+    }
+    EXPECT_TRUE(found);
+}
+
+TEST_F(MetricsTest, ResetClearsAllMetrics) {
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0x1234);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_COMPLETE);
+    MetricsManager::UpdateTokens(100, 50);
+    MetricsManager::IncrementApiCalls();
+    MetricsManager::AddApiDuration(100.0);
+    MetricsManager::AddExecutionTime(150.0);
+
+    auto& manager = GetMetricsManager();
+    manager.Reset();
+
+    auto metrics = manager.GetMetrics();
+    EXPECT_TRUE(metrics.is_object());
+    EXPECT_TRUE(metrics.empty());
+}
+
+TEST_F(MetricsTest, SqlFunctionFlockGetMetrics) {
+    auto con = Config::GetConnection();
+    auto results = con.Query("SELECT flock_get_metrics() AS metrics;");
+
+    ASSERT_FALSE(results->HasError()) << results->GetError();
+    ASSERT_EQ(results->RowCount(), 1);
+
+    auto json_str = results->GetValue(0, 0).GetValue<std::string>();
+    auto metrics = nlohmann::json::parse(json_str);
+
+    EXPECT_TRUE(metrics.is_object());
+}
+
+TEST_F(MetricsTest, SqlFunctionFlockResetMetrics) {
+    auto con = Config::GetConnection();
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0x1234);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_COMPLETE);
+    MetricsManager::UpdateTokens(100, 50);
+    MetricsManager::IncrementApiCalls();
+
+    auto results = con.Query("SELECT flock_reset_metrics() AS result;");
+
+    ASSERT_FALSE(results->HasError()) << results->GetError();
+    ASSERT_EQ(results->RowCount(), 1);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+    EXPECT_TRUE(metrics.is_object());
+    EXPECT_TRUE(metrics.empty());
+}
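For orientation, the round trip these two SQL functions give an end user looks roughly like the sketch below, written in the tests' own idiom; the llm_complete call text is illustrative, while the function names and key scheme follow the assertions above:

    #include "flock/core/config.hpp"
    #include <nlohmann/json.hpp>

    // Usage sketch: exercise a flock function, then read and reset the counters.
    void MetricsRoundTripExample() {
        auto con = flock::Config::GetConnection();
        con.Query("SELECT llm_complete({'model_name': 'gpt-4o'}, {'prompt': 'Say hi'});");

        auto result = con.Query("SELECT flock_get_metrics();");
        auto metrics = nlohmann::json::parse(result->GetValue(0, 0).GetValue<std::string>());
        // e.g. metrics["llm_complete_1"]["total_tokens"], per the key scheme tested above.

        // Clear all counters between measured workloads.
        con.Query("SELECT flock_reset_metrics();");
    }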
+
+TEST_F(MetricsTest, SequentialNumberingForMultipleCalls) {
+    auto* db = GetDatabase();
+    const void* state_id1 = reinterpret_cast<const void*>(0x1111);
+    const void* state_id2 = reinterpret_cast<const void*>(0x2222);
+    const void* state_id3 = reinterpret_cast<const void*>(0x3333);
+
+    MetricsManager::StartInvocation(db, state_id1, FunctionType::LLM_FILTER);
+    MetricsManager::UpdateTokens(10, 5);
+
+    MetricsManager::StartInvocation(db, state_id2, FunctionType::LLM_FILTER);
+    MetricsManager::UpdateTokens(20, 10);
+
+    MetricsManager::StartInvocation(db, state_id3, FunctionType::LLM_FILTER);
+    MetricsManager::UpdateTokens(30, 15);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    bool found_1 = false, found_2 = false, found_3 = false;
+    for (const auto& [key, value]: metrics.items()) {
+        if (key == "llm_filter_1") {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 10);
+            found_1 = true;
+        } else if (key == "llm_filter_2") {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 20);
+            found_2 = true;
+        } else if (key == "llm_filter_3") {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 30);
+            found_3 = true;
+        }
+    }
+
+    EXPECT_TRUE(found_1) << "llm_filter_1 not found";
+    EXPECT_TRUE(found_2) << "llm_filter_2 not found";
+    EXPECT_TRUE(found_3) << "llm_filter_3 not found";
+}
+
+TEST_F(MetricsTest, DebugMetricsReturnsNestedStructure) {
+    auto* db = GetDatabase();
+    const void* state_id1 = reinterpret_cast<const void*>(0x1111);
+    const void* state_id2 = reinterpret_cast<const void*>(0x2222);
+
+    MetricsManager::StartInvocation(db, state_id1, FunctionType::LLM_COMPLETE);
+    MetricsManager::UpdateTokens(100, 50);
+    MetricsManager::SetModelInfo("gpt-4o", "openai");
+
+    MetricsManager::StartInvocation(db, state_id2, FunctionType::LLM_FILTER);
+    MetricsManager::UpdateTokens(200, 100);
+    MetricsManager::SetModelInfo("gpt-4o", "openai");
+
+    auto& manager = GetMetricsManager();
+    auto debug_metrics = manager.GetDebugMetrics();
+
+    EXPECT_TRUE(debug_metrics.is_object());
+    EXPECT_TRUE(debug_metrics.contains("threads"));
+    EXPECT_TRUE(debug_metrics.contains("thread_count"));
+    EXPECT_GE(debug_metrics["thread_count"].get<int64_t>(), 1);
+
+    auto threads = debug_metrics["threads"];
+    EXPECT_TRUE(threads.is_object());
+}
+
+TEST_F(MetricsTest, DebugMetricsContainsRegistrationOrder) {
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0x1234);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_COMPLETE);
+    MetricsManager::UpdateTokens(100, 50);
+
+    auto& manager = GetMetricsManager();
+    auto debug_metrics = manager.GetDebugMetrics();
+
+    bool found_registration_order = false;
+    for (const auto& [thread_id, thread_data]: debug_metrics["threads"].items()) {
+        for (const auto& [state_id_str, state_data]: thread_data.items()) {
+            if (state_data.contains("llm_complete")) {
+                EXPECT_TRUE(state_data["llm_complete"].contains("registration_order"));
+                EXPECT_GT(state_data["llm_complete"]["registration_order"].get<int64_t>(), 0);
+                found_registration_order = true;
+            }
+        }
+    }
+    EXPECT_TRUE(found_registration_order);
+}
+
+TEST_F(MetricsTest, SqlFunctionFlockGetDebugMetrics) {
+    auto con = Config::GetConnection();
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0x1234);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_COMPLETE);
+    MetricsManager::UpdateTokens(100, 50);
+
+    auto results = con.Query("SELECT flock_get_debug_metrics() AS debug_metrics;");
+
+    ASSERT_FALSE(results->HasError()) << results->GetError();
+    ASSERT_EQ(results->RowCount(), 1);
+
+    auto json_str = results->GetValue(0, 0).GetValue<std::string>();
+    auto debug_metrics = nlohmann::json::parse(json_str);
+
+    EXPECT_TRUE(debug_metrics.is_object());
+    EXPECT_TRUE(debug_metrics.contains("threads"));
+    EXPECT_TRUE(debug_metrics.contains("thread_count"));
+}
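Schematically, the nested payload these debug tests traverse has the shape below; every identifier and number is illustrative, not captured output, with only the key names taken from the assertions:

    #include <nlohmann/json.hpp>

    // threads -> thread id -> state id -> function name -> counters
    nlohmann::json example_debug_metrics = {
            {"thread_count", 1},
            {"threads",
             {{"thread-140243",
               {{"0x1234",
                 {{"llm_complete",
                   {{"input_tokens", 100},
                    {"output_tokens", 50},
                    {"registration_order", 1}}}}}}}}}};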
+
+TEST_F(MetricsTest, AggregateFunctionMetricsTracking) {
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0xAAAA);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_REDUCE);
+    MetricsManager::UpdateTokens(500, 200);
+    MetricsManager::SetModelInfo("gpt-4o", "openai");
+    MetricsManager::IncrementApiCalls();
+    MetricsManager::AddApiDuration(2000.0);
+    MetricsManager::AddExecutionTime(2500.0);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    bool found = false;
+    for (const auto& [key, value]: metrics.items()) {
+        if (key.find("llm_reduce_") == 0) {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 500);
+            EXPECT_EQ(value["output_tokens"].get<int64_t>(), 200);
+            EXPECT_EQ(value["total_tokens"].get<int64_t>(), 700);
+            EXPECT_EQ(value["api_calls"].get<int64_t>(), 1);
+            EXPECT_NEAR(value["api_duration_ms"].get<double>(), 2000.0, 0.01);
+            EXPECT_NEAR(value["execution_time_ms"].get<double>(), 2500.0, 0.01);
+            EXPECT_EQ(value["model_name"].get<std::string>(), "gpt-4o");
+            EXPECT_EQ(value["provider"].get<std::string>(), "openai");
+            found = true;
+            break;
+        }
+    }
+    EXPECT_TRUE(found);
+}
+
+TEST_F(MetricsTest, MultipleAggregateFunctionsSequentialNumbering) {
+    auto* db = GetDatabase();
+    const void* state_id1 = reinterpret_cast<const void*>(0xBBBB);
+    const void* state_id2 = reinterpret_cast<const void*>(0xCCCC);
+
+    MetricsManager::StartInvocation(db, state_id1, FunctionType::LLM_REDUCE);
+    MetricsManager::UpdateTokens(100, 50);
+
+    MetricsManager::StartInvocation(db, state_id2, FunctionType::LLM_REDUCE);
+    MetricsManager::UpdateTokens(200, 100);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    bool found_1 = false, found_2 = false;
+    for (const auto& [key, value]: metrics.items()) {
+        if (key == "llm_reduce_1") {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 100);
+            found_1 = true;
+        } else if (key == "llm_reduce_2") {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 200);
+            found_2 = true;
+        }
+    }
+
+    EXPECT_TRUE(found_1) << "llm_reduce_1 not found";
+    EXPECT_TRUE(found_2) << "llm_reduce_2 not found";
+}
+
+TEST_F(MetricsTest, AggregateFunctionMetricsMerging) {
+    auto* db = GetDatabase();
+    const void* state_id1 = reinterpret_cast<const void*>(0xAAAA);
+    const void* state_id2 = reinterpret_cast<const void*>(0xBBBB);
+    const void* state_id3 = reinterpret_cast<const void*>(0xCCCC);
+
+    // Simulate multiple states being processed in a single aggregate call
+    // Each state tracks its own metrics
+    MetricsManager::StartInvocation(db, state_id1, FunctionType::LLM_REDUCE);
+    MetricsManager::SetModelInfo("gpt-4o", "openai");
+    MetricsManager::UpdateTokens(100, 50);
+    MetricsManager::IncrementApiCalls();
+    MetricsManager::AddApiDuration(100.0);
+    MetricsManager::AddExecutionTime(150.0);
+
+    MetricsManager::StartInvocation(db, state_id2, FunctionType::LLM_REDUCE);
+    MetricsManager::SetModelInfo("gpt-4o", "openai");
+    MetricsManager::UpdateTokens(200, 100);
+    MetricsManager::IncrementApiCalls();
+    MetricsManager::AddApiDuration(200.0);
+    MetricsManager::AddExecutionTime(250.0);
+
+    MetricsManager::StartInvocation(db, state_id3, FunctionType::LLM_REDUCE);
+    MetricsManager::SetModelInfo("gpt-4o", "openai");
+    MetricsManager::UpdateTokens(150, 75);
+    MetricsManager::IncrementApiCalls();
+    MetricsManager::AddApiDuration(150.0);
+    MetricsManager::AddExecutionTime(200.0);
+
+    // Now merge all metrics into the first state
+    std::vector<const void*> processed_state_ids = {state_id1, state_id2, state_id3};
+    MetricsManager::MergeAggregateMetrics(db, processed_state_ids, FunctionType::LLM_REDUCE, "gpt-4o", "openai");
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    // Should have exactly ONE llm_reduce entry (merged)
+    int reduce_count = 0;
+    int64_t total_input_tokens = 0;
+    int64_t total_output_tokens = 0;
+    int64_t total_api_calls = 0;
+    double total_api_duration = 0.0;
+    double total_execution_time = 0.0;
+
+    for (const auto& [key, value]: metrics.items()) {
+        if (key.find("llm_reduce_") == 0) {
+            reduce_count++;
+            total_input_tokens += value["input_tokens"].get<int64_t>();
+            total_output_tokens += value["output_tokens"].get<int64_t>();
+            total_api_calls += value["api_calls"].get<int64_t>();
+            total_api_duration += value["api_duration_ms"].get<double>();
+            total_execution_time += value["execution_time_ms"].get<double>();
+        }
+    }
+
+    // Should have exactly one merged entry
+    EXPECT_EQ(reduce_count, 1) << "Expected exactly 1 merged llm_reduce metrics entry";
+
+    // Verify merged values are the sum of all states
+    EXPECT_EQ(total_input_tokens, 450) << "Merged input tokens should be sum of all states (100+200+150)";
+    EXPECT_EQ(total_output_tokens, 225) << "Merged output tokens should be sum of all states (50+100+75)";
+    EXPECT_EQ(total_api_calls, 3) << "Merged API calls should be sum of all states (1+1+1)";
+    EXPECT_NEAR(total_api_duration, 450.0, 0.01) << "Merged API duration should be sum of all states (100+200+150)";
+    EXPECT_NEAR(total_execution_time, 600.0, 0.01) << "Merged execution time should be sum of all states (150+250+200)";
+}
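The merging test fixes the contract precisely: every numeric counter across the per-state entries is summed into a single entry. A minimal sketch of that fold over JSON entries, under the field names asserted above (the real internal representation is not shown in this patch):

    #include <cstdint>
    #include <vector>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    // Sum all per-state counters into one merged metrics entry.
    json MergeEntries(const std::vector<json>& per_state_entries) {
        json merged = {{"input_tokens", 0}, {"output_tokens", 0}, {"api_calls", 0},
                       {"api_duration_ms", 0.0}, {"execution_time_ms", 0.0}};
        for (const auto& entry: per_state_entries) {
            for (const auto* field: {"input_tokens", "output_tokens", "api_calls"}) {
                merged[field] = merged[field].get<int64_t>() + entry[field].get<int64_t>();
            }
            for (const auto* field: {"api_duration_ms", "execution_time_ms"}) {
                merged[field] = merged[field].get<double>() + entry[field].get<double>();
            }
        }
        merged["total_tokens"] = merged["input_tokens"].get<int64_t>() + merged["output_tokens"].get<int64_t>();
        return merged;
    }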
+
+TEST_F(MetricsTest, AggregateFunctionDebugMetrics) {
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0xDDDD);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_RERANK);
+    MetricsManager::UpdateTokens(300, 150);
+    MetricsManager::SetModelInfo("gpt-4o", "openai");
+
+    auto& manager = GetMetricsManager();
+    auto debug_metrics = manager.GetDebugMetrics();
+
+    bool found_rerank = false;
+    for (const auto& [thread_id, thread_data]: debug_metrics["threads"].items()) {
+        for (const auto& [state_id_str, state_data]: thread_data.items()) {
+            if (state_data.contains("llm_rerank")) {
+                EXPECT_EQ(state_data["llm_rerank"]["input_tokens"].get<int64_t>(), 300);
+                EXPECT_EQ(state_data["llm_rerank"]["output_tokens"].get<int64_t>(), 150);
+                EXPECT_TRUE(state_data["llm_rerank"].contains("registration_order"));
+                found_rerank = true;
+            }
+        }
+    }
+    EXPECT_TRUE(found_rerank);
+}
+
+TEST_F(MetricsTest, MixedScalarAndAggregateMetrics) {
+    auto* db = GetDatabase();
+    const void* scalar_state = reinterpret_cast<const void*>(0xEEEE);
+    const void* aggregate_state = reinterpret_cast<const void*>(0xFFFF);
+
+    MetricsManager::StartInvocation(db, scalar_state, FunctionType::LLM_COMPLETE);
+    MetricsManager::UpdateTokens(100, 50);
+
+    MetricsManager::StartInvocation(db, aggregate_state, FunctionType::LLM_REDUCE);
+    MetricsManager::UpdateTokens(200, 100);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    bool found_complete = false, found_reduce = false;
+    for (const auto& [key, value]: metrics.items()) {
+        if (key.find("llm_complete_") == 0) {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 100);
+            found_complete = true;
+        } else if (key.find("llm_reduce_") == 0) {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 200);
+            found_reduce = true;
+        }
+    }
+
+    EXPECT_TRUE(found_complete);
+    EXPECT_TRUE(found_reduce);
+}
+
+TEST_F(MetricsTest, EmbeddingMetricsTracking) {
+    auto* db = GetDatabase();
+    const void* state_id = reinterpret_cast<const void*>(0xABCD);
+
+    MetricsManager::StartInvocation(db, state_id, FunctionType::LLM_EMBEDDING);
+    MetricsManager::SetModelInfo("text-embedding-3-small", "openai");
+    // For embeddings, typically only input tokens are used (no output tokens)
+    MetricsManager::UpdateTokens(150, 0);
+    MetricsManager::IncrementApiCalls();
+    MetricsManager::AddApiDuration(250.0);
+    MetricsManager::AddExecutionTime(300.0);
+
+    auto& manager = GetMetricsManager();
+    auto metrics = manager.GetMetrics();
+
+    bool found = false;
+    for (const auto& [key, value]: metrics.items()) {
+        if (key.find("llm_embedding_") == 0) {
+            EXPECT_EQ(value["input_tokens"].get<int64_t>(), 150);
+            EXPECT_EQ(value["output_tokens"].get<int64_t>(), 0);
+            EXPECT_EQ(value["total_tokens"].get<int64_t>(), 150);
+            EXPECT_EQ(value["api_calls"].get<int64_t>(), 1);
+            EXPECT_NEAR(value["api_duration_ms"].get<double>(), 250.0, 0.01);
+            EXPECT_NEAR(value["execution_time_ms"].get<double>(), 300.0, 0.01);
+            EXPECT_EQ(value["model_name"].get<std::string>(), "text-embedding-3-small");
+            EXPECT_EQ(value["provider"].get<std::string>(), "openai");
+            found = true;
+            break;
+        }
+    }
+    EXPECT_TRUE(found);
+}
+
+}// namespace flock
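Across all of these tests, a single metrics entry carries the same fixed field set. One plausible C++ mirror of that record (a sketch only; the real internal type is not shown in this patch, and the field names simply track the JSON keys asserted in metrics_test.cpp):

    #include <cstdint>
    #include <string>

    struct MetricsEntry {
        int64_t input_tokens = 0;
        int64_t output_tokens = 0;
        int64_t total_tokens = 0;       // maintained as input + output
        int64_t api_calls = 0;
        double api_duration_ms = 0.0;
        double execution_time_ms = 0.0;
        std::string model_name;         // e.g. "gpt-4o"
        std::string provider;           // e.g. "openai"
        int64_t registration_order = 0; // 1-based; drives the _1/_2 key suffixes
    };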
diff --git a/test/unit/model_manager/model_manager_test.cpp b/test/unit/model_manager/model_manager_test.cpp
index 8d328671..e67cde06 100644
--- a/test/unit/model_manager/model_manager_test.cpp
+++ b/test/unit/model_manager/model_manager_test.cpp
@@ -99,7 +99,7 @@ TEST_F(ModelManagerTest, ProviderSelection) {
     });
     // Test Ollama provider
     json ollama_config = {
-            {"model_name", "llama3"}};
+            {"model_name", "gemma3:4b"}};
     EXPECT_NO_THROW({
         Model ollama_model(ollama_config);
         EXPECT_EQ(ollama_model.GetModelDetails().provider_name, "ollama");
diff --git a/test/unit/model_manager/model_providers_test.cpp b/test/unit/model_manager/model_providers_test.cpp
index a4a377c7..1baab5e2 100644
--- a/test/unit/model_manager/model_providers_test.cpp
+++ b/test/unit/model_manager/model_providers_test.cpp
@@ -101,7 +101,7 @@ TEST(ModelProvidersTest, AzureProviderTest) {
 TEST(ModelProvidersTest, OllamaProviderTest) {
     ModelDetails model_details;
     model_details.model_name = "test_model";
-    model_details.model = "llama3";
+    model_details.model = "gemma3:4b";
     model_details.provider_name = "ollama";
     model_details.model_parameters = {{"temperature", 0.7}};
     model_details.secret = {{"api_url", "http://localhost:11434"}};
@@ -137,6 +137,69 @@ TEST(ModelProvidersTest, OllamaProviderTest) {
     auto embedding_results = mock_provider.CollectEmbeddings("application/json");
     ASSERT_EQ(embedding_results.size(), 1);
     EXPECT_EQ(embedding_results[0], expected_embedding_response);
+
+    // Set up mock behavior for AddTranscriptionRequest and CollectTranscriptions
+    const json audio_files = json::array({"https://example.com/audio.mp3"});
+    const json expected_transcription_response = {{"text", "This is a test transcription"}};
+
+    EXPECT_CALL(mock_provider, AddTranscriptionRequest(audio_files))
+            .Times(1);
+    EXPECT_CALL(mock_provider, CollectTranscriptions("multipart/form-data"))
+            .WillOnce(::testing::Return(std::vector<json>{expected_transcription_response}));
+
+    // Test the mocked transcription methods
+    mock_provider.AddTranscriptionRequest(audio_files);
+    auto transcription_results = mock_provider.CollectTranscriptions("multipart/form-data");
+    ASSERT_EQ(transcription_results.size(), 1);
+    EXPECT_EQ(transcription_results[0], expected_transcription_response);
+}
+
+// Test Ollama provider transcription error
+TEST(ModelProvidersTest, OllamaProviderTranscriptionError) {
+    ModelDetails model_details;
+    model_details.model_name = "test_model";
+    model_details.model = "gemma3:4b";
+    model_details.provider_name = "ollama";
+    model_details.model_parameters = {{"temperature", 0.7}};
+    model_details.secret = {{"api_url", "http://localhost:11434"}};
+
+    OllamaProvider provider(model_details);
+    const json audio_files = json::array({"https://example.com/audio.mp3"});
+
+    // Ollama should throw an error when transcription is requested
+    EXPECT_THROW(provider.AddTranscriptionRequest(audio_files), std::runtime_error);
+}
+
+// Test transcription with multiple audio files
+TEST(ModelProvidersTest, TranscriptionWithMultipleFiles) {
+    ModelDetails model_details;
+    model_details.model_name = "test_model";
+    model_details.model = "gpt-4o-transcribe";
+    model_details.provider_name = "openai";
+    model_details.model_parameters = {};
+    model_details.secret = {{"api_key", "test_api_key"}};
+
+    MockProvider mock_provider(model_details);
+
+    const json audio_files = json::array({"https://example.com/audio1.mp3",
+                                          "https://example.com/audio2.mp3",
+                                          "https://example.com/audio3.mp3"});
+    const std::vector<json> expected_transcription_responses = {
+            {{"text", "First transcription"}},
+            {{"text", "Second transcription"}},
+            {{"text", "Third transcription"}}};
+
+    EXPECT_CALL(mock_provider, AddTranscriptionRequest(audio_files))
+            .Times(1);
+    EXPECT_CALL(mock_provider, CollectTranscriptions("multipart/form-data"))
+            .WillOnce(::testing::Return(expected_transcription_responses));
+
+    mock_provider.AddTranscriptionRequest(audio_files);
+    auto transcription_results = mock_provider.CollectTranscriptions("multipart/form-data");
+    ASSERT_EQ(transcription_results.size(), 3);
+    EXPECT_EQ(transcription_results[0], expected_transcription_responses[0]);
+    EXPECT_EQ(transcription_results[1], expected_transcription_responses[1]);
+    EXPECT_EQ(transcription_results[2], expected_transcription_responses[2]);
 }
 
 }// namespace flock
\ No newline at end of file
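The Ollama error test encodes a deliberate design choice: rather than silently skipping audio columns, the provider fails fast. A sketch of that guard as a standalone function (in flock it lives on OllamaProvider; the error text matches the expectation in LLMFilterAudioTranscriptionOllamaError):

    #include <stdexcept>
    #include <nlohmann/json.hpp>

    // Fail fast instead of silently dropping audio columns.
    void AddTranscriptionRequestUnsupported(const nlohmann::json& /*audio_files*/) {
        throw std::runtime_error("Audio transcription is not currently supported by Ollama.");
    }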
diff --git a/test/unit/prompt_manager/prompt_manager_test.cpp b/test/unit/prompt_manager/prompt_manager_test.cpp
index b2154ecf..9aff5918 100644
--- a/test/unit/prompt_manager/prompt_manager_test.cpp
+++ b/test/unit/prompt_manager/prompt_manager_test.cpp
@@ -1,6 +1,11 @@
+#include "../functions/mock_provider.hpp"
+#include "flock/core/config.hpp"
+#include "flock/model_manager/model.hpp"
 #include "flock/prompt_manager/prompt_manager.hpp"
 #include "nlohmann/json.hpp"
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include <memory>
 #include <string>
 
 namespace flock {
@@ -247,4 +252,177 @@ TEST(PromptManager, CreatePromptDetailsOnlyPromptName) {
     EXPECT_EQ(version, 6);
 }
 
+// Test fixture for TranscribeAudioColumn tests
+class TranscribeAudioColumnTest : public ::testing::Test {
+protected:
+    void SetUp() override {
+        auto con = Config::GetConnection();
+        con.Query(" CREATE SECRET ("
+                  " TYPE OPENAI,"
+                  " API_KEY 'your-api-key');");
+        con.Query(" CREATE SECRET ("
+                  " TYPE OLLAMA,"
+                  " API_URL '127.0.0.1:11434');");
+
+        mock_provider = std::make_shared<MockProvider>(ModelDetails{});
+        Model::SetMockProvider(mock_provider);
+    }
+
+    void TearDown() override {
+        Model::ResetMockProvider();
+        mock_provider = nullptr;
+    }
+
+    std::shared_ptr<MockProvider> mock_provider;
+};
+
+// Test TranscribeAudioColumn with named column
+TEST_F(TranscribeAudioColumnTest, TranscribeAudioColumnWithName) {
+    json audio_column = {
+            {"name", "audio_review"},
+            {"type", "audio"},
+            {"transcription_model", "gpt-4o-transcribe"},
+            {"data", {"https://example.com/audio1.mp3", "https://example.com/audio2.mp3"}}};
+
+    json expected_transcription1 = "{\"text\": \"This is the first transcription\"}";
+    json expected_transcription2 = "{\"text\": \"This is the second transcription\"}";
+
+    EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
+            .Times(1);
+    EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data"))
+            .WillOnce(::testing::Return(std::vector<json>{expected_transcription1, expected_transcription2}));
+
+    auto result = PromptManager::TranscribeAudioColumn(audio_column);
+
+    EXPECT_TRUE(result.contains("name"));
+    EXPECT_EQ(result["name"], "transcription_of_audio_review");
+    EXPECT_TRUE(result.contains("data"));
+    EXPECT_TRUE(result["data"].is_array());
+    EXPECT_EQ(result["data"].size(), 2);
+    EXPECT_EQ(result["data"][0], expected_transcription1);
+    EXPECT_EQ(result["data"][1], expected_transcription2);
+}
+
+// Test TranscribeAudioColumn without name
+TEST_F(TranscribeAudioColumnTest, TranscribeAudioColumnWithoutName) {
+    json audio_column = {
+            {"type", "audio"},
+            {"transcription_model", "gpt-4o-transcribe"},
+            {"data", {"https://example.com/audio.mp3"}}};
+
+    json expected_transcription = "{\"text\": \"Transcribed audio content\"}";
+
+    EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
+            .Times(1);
+    EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data"))
+            .WillOnce(::testing::Return(std::vector<json>{expected_transcription}));
+
+    auto result = PromptManager::TranscribeAudioColumn(audio_column);
+
+    EXPECT_TRUE(result.contains("name"));
+    EXPECT_EQ(result["name"], "transcription");
+    EXPECT_TRUE(result.contains("data"));
+    EXPECT_TRUE(result["data"].is_array());
+    EXPECT_EQ(result["data"].size(), 1);
+    EXPECT_EQ(result["data"][0], expected_transcription);
+}
+
+// Test TranscribeAudioColumn with empty name
+TEST_F(TranscribeAudioColumnTest, TranscribeAudioColumnWithEmptyName) {
+    json audio_column = {
+            {"name", ""},
+            {"type", "audio"},
+            {"transcription_model", "gpt-4o-transcribe"},
+            {"data", {"https://example.com/audio.mp3"}}};
+
+    json expected_transcription = "{\"text\": \"Transcribed content\"}";
+
+    EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
+            .Times(1);
+    EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data"))
+            .WillOnce(::testing::Return(std::vector<json>{expected_transcription}));
+
+    auto result = PromptManager::TranscribeAudioColumn(audio_column);
+
+    EXPECT_TRUE(result.contains("name"));
+    EXPECT_EQ(result["name"], "transcription");
+    EXPECT_TRUE(result.contains("data"));
+    EXPECT_EQ(result["data"].size(), 1);
+}
+
+// Test TranscribeAudioColumn with single audio file
+TEST_F(TranscribeAudioColumnTest, TranscribeAudioColumnSingleFile) {
+    json audio_column = {
+            {"name", "podcast"},
+            {"type", "audio"},
+            {"transcription_model", "gpt-4o-transcribe"},
+            {"data", {"https://example.com/podcast.mp3"}}};
+
+    json expected_transcription = "{\"text\": \"Podcast transcription\"}";
+
+    EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
+            .Times(1);
+    EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data"))
+            .WillOnce(::testing::Return(std::vector<json>{expected_transcription}));
+
+    auto result = PromptManager::TranscribeAudioColumn(audio_column);
+
+    EXPECT_EQ(result["name"], "transcription_of_podcast");
+    EXPECT_EQ(result["data"].size(), 1);
+    EXPECT_EQ(result["data"][0], expected_transcription);
+}
+
+// Test TranscribeAudioColumn with multiple audio files
+TEST_F(TranscribeAudioColumnTest, TranscribeAudioColumnMultipleFiles) {
+    json audio_column = {
+            {"name", "interviews"},
+            {"type", "audio"},
+            {"transcription_model", "gpt-4o-transcribe"},
+            {"data", {"https://example.com/interview1.mp3", "https://example.com/interview2.mp3", "https://example.com/interview3.mp3"}}};
+
+    json expected_transcription1 = "{\"text\": \"First interview\"}";
+    json expected_transcription2 = "{\"text\": \"Second interview\"}";
+    json expected_transcription3 = "{\"text\": \"Third interview\"}";
+
+    EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
+            .Times(1);
+    EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data"))
+            .WillOnce(::testing::Return(std::vector<json>{expected_transcription1, expected_transcription2, expected_transcription3}));
+
+    auto result = PromptManager::TranscribeAudioColumn(audio_column);
+
+    EXPECT_EQ(result["name"], "transcription_of_interviews");
+    EXPECT_EQ(result["data"].size(), 3);
+    EXPECT_EQ(result["data"][0], expected_transcription1);
+    EXPECT_EQ(result["data"][1], expected_transcription2);
+    EXPECT_EQ(result["data"][2], expected_transcription3);
+}
+
+// Test TranscribeAudioColumn output format (JSON array)
+TEST_F(TranscribeAudioColumnTest, TranscribeAudioColumnOutputFormat) {
+    json audio_column = {
+            {"name", "test_audio"},
+            {"type", "audio"},
+            {"transcription_model", "gpt-4o-transcribe"},
+            {"data", {"https://example.com/audio.mp3"}}};
+
+    json expected_transcription = "{\"text\": \"Test transcription\"}";
+
+    EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
+            .Times(1);
+    EXPECT_CALL(*mock_provider, CollectTranscriptions("multipart/form-data"))
+            .WillOnce(::testing::Return(std::vector<json>{expected_transcription}));
+
+    auto result = PromptManager::TranscribeAudioColumn(audio_column);
+
+    // Verify the result is a proper JSON object with name and data fields
+    EXPECT_TRUE(result.is_object());
+    EXPECT_TRUE(result.contains("name"));
+    EXPECT_TRUE(result.contains("data"));
+    EXPECT_TRUE(result["data"].is_array());
+
+    // Verify data contains the transcription results
+    EXPECT_EQ(result["data"][0], expected_transcription);
+}
+
 }// namespace flock
\ No newline at end of file
diff --git a/test/unit/unit_test.db b/test/unit/unit_test.db
index 770a7caf..f850e710 100644
Binary files a/test/unit/unit_test.db and b/test/unit/unit_test.db differ