diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index bb9571042c3..aed25b4748c 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -310,7 +310,7 @@ jobs: ARROW_DATASET: ON ARROW_FLIGHT: ON ARROW_FLIGHT_SQL: ON - ARROW_FLIGHT_SQL_ODBC: ON + ARROW_FLIGHT_SQL_ODBC: OFF ARROW_GANDIVA: ON ARROW_GCS: ON ARROW_HDFS: OFF @@ -389,10 +389,6 @@ jobs: PIPX_BASE_PYTHON: ${{ steps.python-install.outputs.python-path }} run: | ci/scripts/install_gcs_testbench.sh default - - name: Register Flight SQL ODBC Driver - shell: cmd - run: | - call "cpp\src\arrow\flight\sql\odbc\tests\install_odbc.cmd" ${{ github.workspace }}\build\cpp\%ARROW_BUILD_TYPE%\libarrow_flight_sql_odbc.dll - name: Test shell: msys2 {0} run: | diff --git a/cpp/src/arrow/flight/sql/odbc/odbc.def b/cpp/src/arrow/flight/sql/odbc/odbc.def index a8191ff662b..8ba5b3fff78 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc.def +++ b/cpp/src/arrow/flight/sql/odbc/odbc.def @@ -17,8 +17,7 @@ LIBRARY arrow_flight_sql_odbc EXPORTS - ; GH-46574 TODO enable DSN window - ; ConfigDSNW + ConfigDSNW SQLAllocConnect SQLAllocEnv SQLAllocHandle diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index 01780f0efe2..a028e063b34 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -31,6 +31,11 @@ #include "arrow/flight/sql/odbc/odbc_impl/spi/connection.h" #include "arrow/util/logging.h" +#if defined _WIN32 +// For displaying DSN Window +# include "arrow/flight/sql/odbc/odbc_impl/system_dsn.h" +#endif // defined(_WIN32) + namespace arrow::flight::sql::odbc { SQLRETURN SQLAllocHandle(SQLSMALLINT type, SQLHANDLE parent, SQLHANDLE* result) { ARROW_LOG(DEBUG) << "SQLAllocHandle called with type: " << type @@ -718,8 +723,30 @@ SQLRETURN SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER value_ptr, ARROW_LOG(DEBUG) << "SQLSetConnectAttrW called with conn: " << conn << ", attr: " << attr << ", value_ptr: " << value_ptr << ", value_len: " << value_len; - // GH-47708 TODO: Implement SQLSetConnectAttr - return SQL_INVALID_HANDLE; + // GH-47708 TODO: Add tests for SQLSetConnectAttr + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + const bool is_unicode = true; + ODBCConnection* connection = reinterpret_cast(conn); + connection->SetConnectAttr(attr, value_ptr, value_len, is_unicode); + return SQL_SUCCESS; + }); +} + +// Load properties from the given DSN. The properties loaded do _not_ overwrite existing +// entries in the properties. +void LoadPropertiesFromDSN(const std::string& dsn, + Connection::ConnPropertyMap& properties) { + arrow::flight::sql::odbc::config::Configuration config; + config.LoadDsn(dsn); + Connection::ConnPropertyMap dsn_properties = config.GetProperties(); + for (auto& [key, value] : dsn_properties) { + auto prop_iter = properties.find(key); + if (prop_iter == properties.end()) { + properties.emplace(std::make_pair(std::move(key), std::move(value))); + } + } } SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle, @@ -740,13 +767,73 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle, << out_connection_string_buffer_len << ", out_connection_string_len: " << static_cast(out_connection_string_len) << ", driver_completion: " << driver_completion; + // GH-46449 TODO: Implement FILEDSN and SAVEFILE keywords according to the spec // GH-46560 TODO: Copy connection string properly in SQLDriverConnect according to the // spec - // GH-46574 TODO: Implement SQLDriverConnect - return SQL_INVALID_HANDLE; + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + ODBCConnection* connection = reinterpret_cast(conn); + std::string connection_string = + ODBC::SqlWcharToString(in_connection_string, in_connection_string_len); + Connection::ConnPropertyMap properties; + std::string dsn_value = ""; + std::optional dsn = ODBCConnection::GetDsnIfExists(connection_string); + if (dsn.has_value()) { + dsn_value = dsn.value(); + LoadPropertiesFromDSN(dsn_value, properties); + } + ODBCConnection::GetPropertiesFromConnString(connection_string, properties); + + std::vector missing_properties; + + // GH-46448 TODO: Implement SQL_DRIVER_COMPLETE_REQUIRED in SQLDriverConnect according + // to the spec +#if defined _WIN32 + // Load the DSN window according to driver_completion + if (driver_completion == SQL_DRIVER_PROMPT) { + // Load DSN window before first attempt to connect + arrow::flight::sql::odbc::config::Configuration config; + if (!DisplayConnectionWindow(window_handle, config, properties)) { + return static_cast(SQL_NO_DATA); + } + connection->Connect(dsn_value, properties, missing_properties); + } else if (driver_completion == SQL_DRIVER_COMPLETE || + driver_completion == SQL_DRIVER_COMPLETE_REQUIRED) { + try { + connection->Connect(dsn_value, properties, missing_properties); + } catch (const DriverException&) { + // If first connection fails due to missing attributes, load + // the DSN window and try to connect again + if (!missing_properties.empty()) { + arrow::flight::sql::odbc::config::Configuration config; + missing_properties.clear(); + + if (!DisplayConnectionWindow(window_handle, config, properties)) { + return static_cast(SQL_NO_DATA); + } + connection->Connect(dsn_value, properties, missing_properties); + } else { + throw; + } + } + } else { + // Default case: attempt connection without showing DSN window + connection->Connect(dsn_value, properties, missing_properties); + } +#else + // Attempt connection without loading DSN window on macOS/Linux + connection->Connect(dsn, properties, missing_properties); +#endif + // Copy connection string to out_connection_string after connection attempt + return ODBC::GetStringAttribute(true, connection_string, false, out_connection_string, + out_connection_string_buffer_len, + out_connection_string_len, + connection->GetDiagnostics()); + }); } SQLRETURN SQLConnect(SQLHDBC conn, SQLWCHAR* dsn_name, SQLSMALLINT dsn_name_len, @@ -759,14 +846,48 @@ SQLRETURN SQLConnect(SQLHDBC conn, SQLWCHAR* dsn_name, SQLSMALLINT dsn_name_len, << ", user_name_len: " << user_name_len << ", password: " << static_cast(password) << ", password_len: " << password_len; - // GH-46574 TODO: Implement SQLConnect - return SQL_INVALID_HANDLE; + + using ODBC::ODBCConnection; + + using ODBC::SqlWcharToString; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + ODBCConnection* connection = reinterpret_cast(conn); + std::string dsn = SqlWcharToString(dsn_name, dsn_name_len); + + Configuration config; + config.LoadDsn(dsn); + + if (user_name) { + std::string uid = SqlWcharToString(user_name, user_name_len); + config.Emplace(FlightSqlConnection::UID, std::move(uid)); + } + + if (password) { + std::string pwd = SqlWcharToString(password, password_len); + config.Emplace(FlightSqlConnection::PWD, std::move(pwd)); + } + + std::vector missing_properties; + + connection->Connect(dsn, config.GetProperties(), missing_properties); + + return SQL_SUCCESS; + }); } SQLRETURN SQLDisconnect(SQLHDBC conn) { ARROW_LOG(DEBUG) << "SQLDisconnect called with conn: " << conn; - // GH-46574 TODO: Implement SQLDisconnect - return SQL_INVALID_HANDLE; + + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + ODBCConnection* connection = reinterpret_cast(conn); + + connection->Disconnect(); + + return SQL_SUCCESS; + }); } SQLRETURN SQLGetInfo(SQLHDBC conn, SQLUSMALLINT info_type, SQLPOINTER info_value_ptr, @@ -776,8 +897,24 @@ SQLRETURN SQLGetInfo(SQLHDBC conn, SQLUSMALLINT info_type, SQLPOINTER info_value << ", info_value_ptr: " << info_value_ptr << ", buf_len: " << buf_len << ", string_length_ptr: " << static_cast(string_length_ptr); - // GH-47709 TODO: Implement SQLGetInfo - return SQL_INVALID_HANDLE; + + // GH-47709 TODO: Update SQLGetInfo implementation and add tests for SQLGetInfo + using ODBC::ODBCConnection; + + return ODBCConnection::ExecuteWithDiagnostics(conn, SQL_ERROR, [=]() { + ODBCConnection* connection = reinterpret_cast(conn); + + // Set character type to be Unicode by default + const bool is_unicode = true; + + if (!info_value_ptr && !string_length_ptr) { + return static_cast(SQL_ERROR); + } + + connection->GetInfo(info_type, info_value_ptr, buf_len, string_length_ptr, + is_unicode); + return static_cast(SQL_SUCCESS); + }); } SQLRETURN SQLGetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER value_ptr, diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt index b232577ee37..8f09fccd71d 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt @@ -124,7 +124,9 @@ if(WIN32) ui/dsn_configuration_window.h ui/window.cc ui/window.h - system_dsn.cc) + win_system_dsn.cc + system_dsn.cc + system_dsn.h) endif() target_link_libraries(arrow_odbc_spi_impl PUBLIC arrow_flight_sql_shared diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h index 7baea759ede..8c5eae59f7e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h @@ -17,17 +17,17 @@ #pragma once -#include -#include -#include #include #include #include #include #include +#include "arrow/flight/sql/odbc/odbc_impl/diagnostics.h" +#include "arrow/flight/sql/odbc/odbc_impl/encoding_utils.h" +#include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" -#include - +// GH-48083 TODO: replace `namespace ODBC` with `namespace arrow::flight::sql::odbc` namespace ODBC { using arrow::flight::sql::odbc::Diagnostics; @@ -48,12 +48,12 @@ inline void GetAttribute(T attribute_value, SQLPOINTER output, O output_size, } template -inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER output, +inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER output, O output_size, O* output_len_ptr) { if (output) { size_t output_len_before_null = std::min(static_cast(attribute_value.size()), static_cast(output_size - 1)); - memcpy(output, attribute_value.c_str(), output_len_before_null); + std::memcpy(output, attribute_value.data(), output_len_before_null); reinterpret_cast(output)[output_len_before_null] = '\0'; } @@ -68,7 +68,7 @@ inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER } template -inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER output, +inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER output, O output_size, O* output_len_ptr, Diagnostics& diagnostics) { SQLRETURN result = @@ -80,7 +80,7 @@ inline SQLRETURN GetAttributeUTF8(const std::string& attribute_value, SQLPOINTER } template -inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value, +inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value, bool is_length_in_bytes, SQLPOINTER output, O output_size, O* output_len_ptr) { size_t length = ConvertToSqlWChar( @@ -104,7 +104,7 @@ inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value, } template -inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value, +inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value, bool is_length_in_bytes, SQLPOINTER output, O output_size, O* output_len_ptr, Diagnostics& diagnostics) { @@ -117,7 +117,7 @@ inline SQLRETURN GetAttributeSQLWCHAR(const std::string& attribute_value, } template -inline SQLRETURN GetStringAttribute(bool is_unicode, const std::string& attribute_value, +inline SQLRETURN GetStringAttribute(bool is_unicode, std::string_view attribute_value, bool is_length_in_bytes, SQLPOINTER output, O output_size, O* output_len_ptr, Diagnostics& diagnostics) { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.cc index cdb889f0567..df61f1247c7 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.cc @@ -35,7 +35,7 @@ static const char DEFAULT_USE_CERT_STORE[] = TRUE_STR; static const char DEFAULT_DISABLE_CERT_VERIFICATION[] = FALSE_STR; namespace { -std::string ReadDsnString(const std::string& dsn, const std::string_view& key, +std::string ReadDsnString(const std::string& dsn, std::string_view key, const std::string& dflt = "") { CONVERT_WIDE_STR(const std::wstring wdsn, dsn); CONVERT_WIDE_STR(const std::wstring wkey, key); @@ -150,11 +150,11 @@ void Configuration::LoadDsn(const std::string& dsn) { void Configuration::Clear() { this->properties_.clear(); } -bool Configuration::IsSet(const std::string_view& key) const { +bool Configuration::IsSet(std::string_view key) const { return 0 != this->properties_.count(key); } -const std::string& Configuration::Get(const std::string_view& key) const { +const std::string& Configuration::Get(std::string_view key) const { const auto itr = this->properties_.find(key); if (itr == this->properties_.cend()) { static const std::string empty(""); @@ -163,15 +163,22 @@ const std::string& Configuration::Get(const std::string_view& key) const { return itr->second; } -void Configuration::Set(const std::string_view& key, const std::wstring& wvalue) { +void Configuration::Set(std::string_view key, const std::wstring& wvalue) { CONVERT_UTF8_STR(const std::string value, wvalue); Set(key, value); } -void Configuration::Set(const std::string_view& key, const std::string& value) { +void Configuration::Set(std::string_view key, const std::string& value) { const std::string copy = boost::trim_copy(value); if (!copy.empty()) { - this->properties_[key] = value; + this->properties_[std::string(key)] = value; + } +} + +void Configuration::Emplace(std::string_view key, std::string&& value) { + const std::string copy = boost::trim_copy(value); + if (!copy.empty()) { + this->properties_.emplace(std::make_pair(key, std::move(value))); } } @@ -182,7 +189,7 @@ const Connection::ConnPropertyMap& Configuration::GetProperties() const { std::vector Configuration::GetCustomKeys() const { Connection::ConnPropertyMap copy_props(properties_); for (auto& key : FlightSqlConnection::ALL_KEYS) { - copy_props.erase(key); + copy_props.erase(std::string(key)); } std::vector keys; boost::copy(copy_props | boost::adaptors::map_keys, std::back_inserter(keys)); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.h index 77d07b1420a..0390a57e52f 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/config/configuration.h @@ -46,22 +46,15 @@ class Configuration { */ ~Configuration(); - /** - * Convert configure to connect string. - * - * @return Connect string. - */ - std::string ToConnectString() const; - void LoadDefaults(); void LoadDsn(const std::string& dsn); void Clear(); - bool IsSet(const std::string_view& key) const; - const std::string& Get(const std::string_view& key) const; - void Set(const std::string_view& key, const std::wstring& wvalue); - void Set(const std::string_view& key, const std::string& value); - + bool IsSet(std::string_view key) const; + const std::string& Get(std::string_view key) const; + void Set(std::string_view key, const std::wstring& wvalue); + void Set(std::string_view key, const std::string& value); + void Emplace(std::string_view key, std::string&& value); /** * Get properties map. */ diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h index a5cc3a6f4c8..66e5c3bf0d8 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h @@ -16,7 +16,6 @@ // under the License. #pragma once - #include "arrow/flight/sql/odbc/odbc_impl/encoding.h" #include "arrow/flight/sql/odbc/odbc_impl/platform.h" @@ -40,15 +39,15 @@ using arrow::flight::sql::odbc::WcsToUtf8; // Return the number of bytes required for the conversion. template -inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, +inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer, SQLLEN buffer_size_in_bytes) { thread_local std::vector wstr; Utf8ToWcs(str.data(), str.size(), &wstr); SQLLEN value_length_in_bytes = wstr.size(); if (buffer) { - memcpy(buffer, wstr.data(), - std::min(static_cast(wstr.size()), buffer_size_in_bytes)); + std::memcpy(buffer, wstr.data(), + std::min(static_cast(wstr.size()), buffer_size_in_bytes)); // Write a NUL terminator if (buffer_size_in_bytes >= @@ -67,7 +66,7 @@ inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, return value_length_in_bytes; } -inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, +inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer, SQLLEN buffer_size_in_bytes) { switch (GetSqlWCharSize()) { case sizeof(char16_t): @@ -101,4 +100,22 @@ inline std::string SqlWcharToString(SQLWCHAR* wchar_msg, SQLINTEGER msg_len = SQ return std::string(utf8_str.begin(), utf8_str.end()); } +inline std::string SqlStringToString(const unsigned char* sql_str, + int32_t sql_str_len = SQL_NTS) { + std::string res; + + const char* sql_str_c = reinterpret_cast(sql_str); + + if (!sql_str) { + return res; + } + + if (sql_str_len == SQL_NTS) { + res.assign(sql_str_c); + } else if (sql_str_len > 0) { + res.assign(sql_str_c, sql_str_len); + } + + return res; +} } // namespace ODBC diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc index 479a72f3fea..e18a58d069f 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc @@ -99,7 +99,14 @@ inline std::string GetCerts() { return ""; } #endif -const std::set BUILT_IN_PROPERTIES = { +// Case insensitive comparator that takes string_view +struct CaseInsensitiveComparatorStrView { + bool operator()(std::string_view s1, std::string_view s2) const { + return boost::lexicographical_compare(s1, s2, boost::is_iless()); + } +}; + +const std::set BUILT_IN_PROPERTIES = { FlightSqlConnection::HOST, FlightSqlConnection::PORT, FlightSqlConnection::USER, @@ -116,7 +123,7 @@ const std::set BUILT_IN_PROPERTIES FlightSqlConnection::USE_WIDE_CHAR}; Connection::ConnPropertyMap::const_iterator TrackMissingRequiredProperty( - const std::string_view& property, const Connection::ConnPropertyMap& properties, + std::string_view property, const Connection::ConnPropertyMap& properties, std::vector& missing_attr) { auto prop_iter = properties.find(property); if (properties.end() == prop_iter) { @@ -138,7 +145,7 @@ std::shared_ptr LoadFlightSslConfigs( .value_or(SYSTEM_TRUST_STORE_DEFAULT); auto trusted_certs_iterator = - conn_property_map.find(FlightSqlConnection::TRUSTED_CERTS); + conn_property_map.find(std::string(FlightSqlConnection::TRUSTED_CERTS)); auto trusted_certs = trusted_certs_iterator != conn_property_map.end() ? trusted_certs_iterator->second : ""; @@ -161,6 +168,8 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties, std::unique_ptr flight_client; ThrowIfNotOK(FlightClient::Connect(location, client_options).Value(&flight_client)); + PopulateMetadataSettings(properties); + PopulateCallOptions(properties); std::unique_ptr auth_method = FlightSqlAuthMethod::FromProperties(flight_client, properties); @@ -175,9 +184,6 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties, info_.SetProperty(SQL_USER_NAME, auth_method->GetUser()); attribute_[CONNECTION_DEAD] = static_cast(SQL_FALSE); - - PopulateMetadataSettings(properties); - PopulateCallOptions(properties); } catch (...) { attribute_[CONNECTION_DEAD] = static_cast(SQL_TRUE); sql_client_.reset(); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection_test.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection_test.cc index a42d0198527..9c9b0f8f3c1 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection_test.cc @@ -33,7 +33,7 @@ TEST(AttributeTests, SetAndGetAttribute) { EXPECT_TRUE(first_value); - EXPECT_EQ(boost::get(*first_value), static_cast(200)); + EXPECT_EQ(static_cast(200), boost::get(*first_value)); connection.SetAttribute(Connection::CONNECTION_TIMEOUT, static_cast(300)); @@ -41,7 +41,7 @@ TEST(AttributeTests, SetAndGetAttribute) { connection.GetAttribute(Connection::CONNECTION_TIMEOUT); EXPECT_TRUE(change_value); - EXPECT_EQ(boost::get(*change_value), static_cast(300)); + EXPECT_EQ(static_cast(300), boost::get(*change_value)); connection.Close(); } @@ -65,10 +65,10 @@ TEST(MetadataSettingsTest, StringColumnLengthTest) { const int32_t expected_string_column_length = 100000; const Connection::ConnPropertyMap properties = { - {FlightSqlConnection::HOST, std::string("localhost")}, // expect not used - {FlightSqlConnection::PORT, std::string("32010")}, // expect not used - {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, // expect not used - {FlightSqlConnection::STRING_COLUMN_LENGTH, + {std::string(FlightSqlConnection::HOST), "localhost"}, // expect not used + {std::string(FlightSqlConnection::PORT), "32010"}, // expect not used + {std::string(FlightSqlConnection::USE_ENCRYPTION), "false"}, // expect not used + {std::string(FlightSqlConnection::STRING_COLUMN_LENGTH), std::to_string(expected_string_column_length)}, }; @@ -86,10 +86,10 @@ TEST(MetadataSettingsTest, UseWideCharTest) { connection.SetClosed(false); const Connection::ConnPropertyMap properties1 = { - {FlightSqlConnection::USE_WIDE_CHAR, std::string("true")}, + {std::string(FlightSqlConnection::USE_WIDE_CHAR), "true"}, }; const Connection::ConnPropertyMap properties2 = { - {FlightSqlConnection::USE_WIDE_CHAR, std::string("false")}, + {std::string(FlightSqlConnection::USE_WIDE_CHAR), "false"}, }; EXPECT_EQ(true, connection.GetUseWideChar(properties1)); @@ -101,9 +101,9 @@ TEST(MetadataSettingsTest, UseWideCharTest) { TEST(BuildLocationTests, ForTcp) { std::vector missing_attr; Connection::ConnPropertyMap properties = { - {FlightSqlConnection::HOST, std::string("localhost")}, - {FlightSqlConnection::PORT, std::string("32010")}, - {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, + {std::string(FlightSqlConnection::HOST), "localhost"}, + {std::string(FlightSqlConnection::PORT), "32010"}, + {std::string(FlightSqlConnection::USE_ENCRYPTION), "false"}, }; const std::shared_ptr& ssl_config = @@ -113,8 +113,8 @@ TEST(BuildLocationTests, ForTcp) { FlightSqlConnection::BuildLocation(properties, missing_attr, ssl_config); const Location& actual_location2 = FlightSqlConnection::BuildLocation( { - {FlightSqlConnection::HOST, std::string("localhost")}, - {FlightSqlConnection::PORT, std::string("32011")}, + {std::string(FlightSqlConnection::HOST), "localhost"}, + {std::string(FlightSqlConnection::PORT), "32011"}, }, missing_attr, ssl_config); @@ -127,9 +127,9 @@ TEST(BuildLocationTests, ForTcp) { TEST(BuildLocationTests, ForTls) { std::vector missing_attr; Connection::ConnPropertyMap properties = { - {FlightSqlConnection::HOST, std::string("localhost")}, - {FlightSqlConnection::PORT, std::string("32010")}, - {FlightSqlConnection::USE_ENCRYPTION, std::string("1")}, + {std::string(FlightSqlConnection::HOST), "localhost"}, + {std::string(FlightSqlConnection::PORT), "32010"}, + {std::string(FlightSqlConnection::USE_ENCRYPTION), "1"}, }; const std::shared_ptr& ssl_config = @@ -139,9 +139,9 @@ TEST(BuildLocationTests, ForTls) { FlightSqlConnection::BuildLocation(properties, missing_attr, ssl_config); Connection::ConnPropertyMap second_properties = { - {FlightSqlConnection::HOST, std::string("localhost")}, - {FlightSqlConnection::PORT, std::string("32011")}, - {FlightSqlConnection::USE_ENCRYPTION, std::string("1")}, + {std::string(FlightSqlConnection::HOST), "localhost"}, + {std::string(FlightSqlConnection::PORT), "32011"}, + {std::string(FlightSqlConnection::USE_ENCRYPTION), "1"}, }; const std::shared_ptr& second_ssl_config = diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/main.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/main.cc index 8f649311e9d..3336e0160e1 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/main.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/main.cc @@ -43,7 +43,7 @@ using arrow::flight::sql::odbc::Statement; void TestBindColumn(const std::shared_ptr& connection) { const std::shared_ptr& statement = connection->CreateStatement(); - statement->Execute("SELECT IncidntNum, Category FROM \"@dremio\".Test LIMIT 10"); + statement->Execute("SELECT IncidntNum, Category FROM \"@apache\".Test LIMIT 10"); const std::shared_ptr& result_set = statement->GetResultSet(); @@ -105,7 +105,7 @@ void TestBindColumnBigInt(const std::shared_ptr& connection) { " SELECT CONVERT_TO_INTEGER(IncidntNum, 1, 1, 0) AS IncidntNum, " "Category\n" " FROM (\n" - " SELECT IncidntNum, Category FROM \"@dremio\".Test LIMIT 10\n" + " SELECT IncidntNum, Category FROM \"@apache\".Test LIMIT 10\n" " ) nested_0\n" ") nested_0"); @@ -202,11 +202,11 @@ int main() { driver.CreateConnection(arrow::flight::sql::odbc::OdbcVersion::V_3); Connection::ConnPropertyMap properties = { - {FlightSqlConnection::HOST, std::string("automaster.drem.io")}, - {FlightSqlConnection::PORT, std::string("32010")}, - {FlightSqlConnection::USER, std::string("dremio")}, - {FlightSqlConnection::PASSWORD, std::string("dremio123")}, - {FlightSqlConnection::USE_ENCRYPTION, std::string("false")}, + {std::string(FlightSqlConnection::HOST), "automaster.apache"}, + {std::string(FlightSqlConnection::PORT), "32010"}, + {std::string(FlightSqlConnection::USER), "apache"}, + {std::string(FlightSqlConnection::PASSWORD), "apache123"}, + {std::string(FlightSqlConnection::USE_ENCRYPTION), "false"}, }; std::vector missing_attr; connection->Connect(properties, missing_attr); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.cc index c0a55840d56..ead2beada4b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.cc @@ -53,57 +53,7 @@ namespace { // characters such as semi-colons and equals signs. NOTE: This can be optimized to be // built statically. const boost::xpressive::sregex CONNECTION_STR_REGEX( - boost::xpressive::sregex::compile("([^=;]+)=({.+}|[^=;]+|[^;])")); - -// Load properties from the given DSN. The properties loaded do _not_ overwrite existing -// entries in the properties. -void loadPropertiesFromDSN(const std::string& dsn, - Connection::ConnPropertyMap& properties) { - const size_t BUFFER_SIZE = 1024 * 10; - std::vector output_buffer; - output_buffer.resize(BUFFER_SIZE, '\0'); - SQLSetConfigMode(ODBC_BOTH_DSN); - - CONVERT_WIDE_STR(const std::wstring wdsn, dsn); - - SQLGetPrivateProfileString(wdsn.c_str(), NULL, L"", &output_buffer[0], BUFFER_SIZE, - L"odbc.ini"); - - // The output buffer holds the list of keys in a series of NUL-terminated strings. - // The series is terminated with an empty string (eg a NUL-terminator terminating the - // last key followed by a NUL terminator after). - std::vector keys; - size_t pos = 0; - while (pos < BUFFER_SIZE) { - std::wstring wkey(&output_buffer[pos]); - if (wkey.empty()) { - break; - } - size_t len = wkey.size(); - - // Skip over Driver or DSN keys. - if (!boost::iequals(wkey, L"DSN") && !boost::iequals(wkey, L"Driver")) { - keys.emplace_back(std::move(wkey)); - } - pos += len + 1; - } - - for (auto& wkey : keys) { - output_buffer.clear(); - output_buffer.resize(BUFFER_SIZE, '\0'); - SQLGetPrivateProfileString(wdsn.c_str(), wkey.data(), L"", &output_buffer[0], - BUFFER_SIZE, L"odbc.ini"); - - std::wstring wvalue = std::wstring(&output_buffer[0]); - CONVERT_UTF8_STR(const std::string value, wvalue); - CONVERT_UTF8_STR(const std::string key, std::wstring(wkey)); - auto propIter = properties.find(key); - if (propIter == properties.end()) { - properties.emplace(std::make_pair(std::move(key), std::move(value))); - } - } -} - + boost::xpressive::sregex::compile("([^=;]+)=({.+}|[^;]+|[^;])")); } // namespace // Public @@ -734,39 +684,43 @@ void ODBCConnection::DropDescriptor(ODBCDescriptor* desc) { // Public Static // =================================================================================== -std::string ODBCConnection::GetPropertiesFromConnString( +std::optional ODBCConnection::GetDsnIfExists(const std::string& conn_str) { + const int groups[] = {1, 2}; // CONNECTION_STR_REGEX has two groups. key: 1, value: 2 + boost::xpressive::sregex_token_iterator regex_iter(conn_str.begin(), conn_str.end(), + CONNECTION_STR_REGEX, groups), + end; + + // First key in connection string should be either dsn or driver + auto it = regex_iter; + std::string key = *regex_iter; + std::string value = *++regex_iter; + + // Strip wrapping curly braces. + if (value.size() >= 2 && value[0] == '{' && value[value.size() - 1] == '}') { + value = value.substr(1, value.size() - 2); + } + + if (boost::iequals(key, "DSN")) { + return value; + } else if (boost::iequals(key, "Driver")) { + return std::nullopt; + } else { + throw DriverException( + "Connection string is faulty. The first key should be DSN or Driver.", "HY000"); + } +} + +void ODBCConnection::GetPropertiesFromConnString( const std::string& conn_str, Connection::ConnPropertyMap& properties) { const int groups[] = {1, 2}; // CONNECTION_STR_REGEX has two groups. key: 1, value: 2 boost::xpressive::sregex_token_iterator regex_iter(conn_str.begin(), conn_str.end(), CONNECTION_STR_REGEX, groups), end; - bool is_dsn_first = false; - bool is_driver_first = false; - std::string dsn; for (auto it = regex_iter; end != regex_iter; ++regex_iter) { std::string key = *regex_iter; std::string value = *++regex_iter; - // If the DSN shows up before driver key, load settings from the DSN. - // Only load values from the DSN once regardless of how many times the DSN - // key shows up. - if (boost::iequals(key, "DSN")) { - if (!is_driver_first) { - if (!is_dsn_first) { - is_dsn_first = true; - loadPropertiesFromDSN(value, properties); - dsn.swap(value); - } - } - continue; - } else if (boost::iequals(key, "Driver")) { - if (!is_dsn_first) { - is_driver_first = true; - } - continue; - } - // Strip wrapping curly braces. if (value.size() >= 2 && value[0] == '{' && value[value.size() - 1] == '}') { value = value.substr(1, value.size() - 2); @@ -776,5 +730,4 @@ std::string ODBCConnection::GetPropertiesFromConnString( // including over entries in the DSN. properties[key] = std::move(value); } - return dsn; } diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.h index 8157c2f5f94..2e5ab57ad49 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_connection.h @@ -23,6 +23,7 @@ #include #include #include +#include #include namespace ODBC { @@ -75,8 +76,11 @@ class ODBCConnection : public ODBCHandle { inline bool IsOdbc2Connection() const { return is_2x_connection_; } - /// @return the DSN or empty string if Driver was used. - static std::string GetPropertiesFromConnString( + /// \return an optional DSN + static std::optional GetDsnIfExists(const std::string& conn_str); + + /// Read properties from connection string, but does not read values from DSN + static void GetPropertiesFromConnString( const std::string& conn_str, arrow::flight::sql::odbc::Connection::ConnPropertyMap& properties); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/connection.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/connection.h index e24af6c3dd7..7a8243e7859 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/connection.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/connection.h @@ -32,13 +32,15 @@ namespace arrow::flight::sql::odbc { /// \brief Case insensitive comparator struct CaseInsensitiveComparator { - bool operator()(const std::string_view& s1, const std::string_view& s2) const { + using is_transparent = std::true_type; + + bool operator()(std::string_view s1, std::string_view s2) const { return boost::lexicographical_compare(s1, s2, boost::is_iless()); } }; // PropertyMap is case-insensitive for keys. -typedef std::map PropertyMap; +typedef std::map PropertyMap; class Statement; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc index 75501ac8dd4..468f05e4cf4 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc @@ -17,18 +17,11 @@ #include "arrow/flight/sql/odbc/odbc_impl/system_dsn.h" -// platform.h includes windows.h, so it needs to be included -// before winuser.h -#include "arrow/flight/sql/odbc/odbc_impl/platform.h" - -#include -#include #include "arrow/flight/sql/odbc/odbc_impl/config/configuration.h" -#include "arrow/flight/sql/odbc/odbc_impl/config/connection_string_parser.h" -#include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" #include "arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h" #include "arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.h" #include "arrow/flight/sql/odbc/odbc_impl/ui/window.h" +#include "arrow/flight/sql/odbc/odbc_impl/util.h" #include "arrow/result.h" #include "arrow/util/utf8.h" @@ -38,41 +31,6 @@ namespace arrow::flight::sql::odbc { using config::Configuration; -using config::ConnectionStringParser; -using config::DsnConfigurationWindow; -using config::Result; -using config::Window; - -bool DisplayConnectionWindow(void* window_parent, Configuration& config) { - HWND hwnd_parent = (HWND)window_parent; - - if (!hwnd_parent) return true; - - try { - Window parent(hwnd_parent); - DsnConfigurationWindow window(&parent, config); - - window.Create(); - - window.Show(); - window.Update(); - - return ProcessMessages(window) == Result::OK; - } catch (const DriverException& err) { - std::stringstream buf; - buf << "SQL State: " << err.GetSqlState() << ", Message: " << err.GetMessageText() - << ", Code: " << err.GetNativeError(); - std::wstring wmessage = - arrow::util::UTF8ToWideString(buf.str()).ValueOr(L"Error during load DSN"); - MessageBox(NULL, wmessage.c_str(), L"Error!", MB_ICONEXCLAMATION | MB_OK); - - std::wstring wmessage_text = arrow::util::UTF8ToWideString(err.GetMessageText()) - .ValueOr(L"Error during load DSN"); - SQLPostInstallerError(err.GetNativeError(), wmessage_text.c_str()); - } - - return false; -} void PostError(DWORD error_code, LPCWSTR error_msg) { MessageBox(NULL, error_msg, L"Error!", MB_ICONEXCLAMATION | MB_OK); @@ -138,7 +96,7 @@ bool RegisterDsn(const Configuration& config, LPCWSTR driver) { const auto& map = config.GetProperties(); for (auto it = map.begin(); it != map.end(); ++it) { - const std::string_view& key = it->first; + std::string_view key = it->first; if (boost::iequals(FlightSqlConnection::DSN, key) || boost::iequals(FlightSqlConnection::DRIVER, key)) { continue; @@ -167,77 +125,4 @@ bool RegisterDsn(const Configuration& config, LPCWSTR driver) { return true; } - -BOOL INSTAPI ConfigDSNW(HWND hwnd_parent, WORD req, LPCWSTR wdriver, - LPCWSTR wattributes) { - Configuration config; - ConnectionStringParser parser(config); - - auto attributes_result = arrow::util::WideStringToUTF8(std::wstring(wattributes)); - if (!attributes_result.status().ok()) { - PostArrowUtilError(attributes_result.status()); - return FALSE; - } - std::string attributes = attributes_result.ValueOrDie(); - - parser.ParseConfigAttributes(attributes.c_str()); - - switch (req) { - case ODBC_ADD_DSN: { - config.LoadDefaults(); - if (!DisplayConnectionWindow(hwnd_parent, config) || !RegisterDsn(config, wdriver)) - return FALSE; - - break; - } - - case ODBC_CONFIG_DSN: { - const std::string& dsn = config.Get(FlightSqlConnection::DSN); - auto wdsn_result = arrow::util::UTF8ToWideString(dsn); - if (!wdsn_result.status().ok()) { - PostArrowUtilError(wdsn_result.status()); - return FALSE; - } - std::wstring wdsn = wdsn_result.ValueOrDie(); - if (!SQLValidDSN(wdsn.c_str())) return FALSE; - - Configuration loaded(config); - try { - loaded.LoadDsn(dsn); - } catch (const DriverException& err) { - std::string error_msg = err.GetMessageText(); - std::wstring werror_msg = - arrow::util::UTF8ToWideString(error_msg).ValueOr(L"Error during DSN load"); - - PostError(err.GetNativeError(), werror_msg.c_str()); - return FALSE; - } - - if (!DisplayConnectionWindow(hwnd_parent, loaded) || !UnregisterDsn(wdsn.c_str()) || - !RegisterDsn(loaded, wdriver)) - return FALSE; - - break; - } - - case ODBC_REMOVE_DSN: { - const std::string& dsn = config.Get(FlightSqlConnection::DSN); - auto wdsn_result = arrow::util::UTF8ToWideString(dsn); - if (!wdsn_result.status().ok()) { - PostArrowUtilError(wdsn_result.status()); - return FALSE; - } - std::wstring wdsn = wdsn_result.ValueOrDie(); - if (!SQLValidDSN(wdsn.c_str()) || !UnregisterDsn(wdsn)) return FALSE; - - break; - } - - default: - return FALSE; - } - - return TRUE; -} - } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h index 32d17af6753..5d23c3dfcaf 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h @@ -19,6 +19,7 @@ #include "arrow/flight/sql/odbc/odbc_impl/platform.h" #include "arrow/flight/sql/odbc/odbc_impl/config/configuration.h" +#include "arrow/status.h" namespace arrow::flight::sql::odbc { @@ -65,4 +66,7 @@ bool RegisterDsn(const Configuration& config, LPCWSTR driver); */ bool UnregisterDsn(const std::wstring& dsn); +void PostError(DWORD error_code, LPCWSTR error_msg); + +void PostArrowUtilError(arrow::Status error_status); } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.cc index 3f49690daad..0432836a16f 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.cc @@ -18,9 +18,9 @@ #include "arrow/result.h" #include "arrow/util/utf8.h" -#include "arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.h" - #include "arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h" +#include "arrow/flight/sql/odbc/odbc_impl/ui/add_property_window.h" +#include "arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.h" #include "arrow/flight/sql/odbc/odbc_impl/util.h" #include @@ -30,16 +30,13 @@ #include #include -#include "arrow/flight/sql/odbc/odbc_impl/ui/add_property_window.h" - #define COMMON_TAB 0 #define ADVANCED_TAB 1 namespace arrow::flight::sql::odbc { namespace { std::string TestConnection(const config::Configuration& config) { - std::unique_ptr flight_sql_conn( - new FlightSqlConnection(OdbcVersion::V_3)); + std::unique_ptr flight_sql_conn(new FlightSqlConnection(V_3)); std::vector missing_properties; flight_sql_conn->Connect(config.GetProperties(), missing_properties); @@ -250,6 +247,7 @@ int DsnConfigurationWindow::CreateEncryptionSettingsGroup(int pos_x, int pos_y, std::string val = config_.Get(FlightSqlConnection::USE_ENCRYPTION); + // Enable encryption default value is true const bool enable_encryption = util::AsBool(val).value_or(true); labels_.push_back(CreateLabel(label_pos_x, row_pos, LABEL_WIDTH, ROW_HEIGHT, L"Use Encryption:", ChildId::ENABLE_ENCRYPTION_LABEL)); @@ -275,6 +273,7 @@ int DsnConfigurationWindow::CreateEncryptionSettingsGroup(int pos_x, int pos_y, val = config_.Get(FlightSqlConnection::USE_SYSTEM_TRUST_STORE).c_str(); + // System trust store default value is true const bool use_system_cert_store = util::AsBool(val).value_or(true); labels_.push_back(CreateLabel(label_pos_x, row_pos, LABEL_WIDTH, 2 * ROW_HEIGHT, L"Use System Certificate Store:", diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc index df6aff9cfa7..59ee7dda565 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.cc @@ -1108,8 +1108,8 @@ boost::optional AsBool(const std::string& value) { } boost::optional AsBool(const Connection::ConnPropertyMap& conn_property_map, - const std::string_view& property_name) { - auto extracted_property = conn_property_map.find(std::string(property_name)); + std::string_view property_name) { + auto extracted_property = conn_property_map.find(property_name); if (extracted_property != conn_property_map.end()) { return AsBool(extracted_property->second); @@ -1120,8 +1120,8 @@ boost::optional AsBool(const Connection::ConnPropertyMap& conn_property_ma boost::optional AsInt32(int32_t min_value, const Connection::ConnPropertyMap& conn_property_map, - const std::string_view& property_name) { - auto extracted_property = conn_property_map.find(std::string(property_name)); + std::string_view property_name) { + auto extracted_property = conn_property_map.find(property_name); if (extracted_property != conn_property_map.end()) { const int32_t string_column_length = std::stoi(extracted_property->second); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h index 8197f741d1e..c17e77e7de8 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/util.h @@ -144,7 +144,7 @@ boost::optional AsBool(const std::string& value); /// \param property_name the name of the property that will be looked up. /// \return the parsed valued. boost::optional AsBool(const Connection::ConnPropertyMap& conn_property_map, - const std::string_view& property_name); + std::string_view property_name); /// Looks up for a value inside the ConnPropertyMap and then try to parse it. /// In case it does not find or it cannot parse, the default value will be returned. @@ -156,7 +156,7 @@ boost::optional AsBool(const Connection::ConnPropertyMap& conn_property_ma /// std::out_of_range exception from std::stoi boost::optional AsInt32(int32_t min_value, const Connection::ConnPropertyMap& conn_property_map, - const std::string_view& property_name); + std::string_view property_name); } // namespace util } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/win_system_dsn.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/win_system_dsn.cc new file mode 100644 index 00000000000..2ea9a2451c2 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/win_system_dsn.cc @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/sql/odbc/odbc_impl/system_dsn.h" + +// platform.h includes windows.h, so it needs to be included +// before winuser.h +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include + +#include "arrow/result.h" +#include "arrow/util/utf8.h" + +#include "arrow/flight/sql/odbc/odbc_impl/config/configuration.h" +#include "arrow/flight/sql/odbc/odbc_impl/config/connection_string_parser.h" +#include "arrow/flight/sql/odbc/odbc_impl/exceptions.h" +#include "arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h" +#include "arrow/flight/sql/odbc/odbc_impl/ui/dsn_configuration_window.h" +#include "arrow/flight/sql/odbc/odbc_impl/ui/window.h" +#include "arrow/util/logging.h" + +#include +#include +#include +#include + +namespace arrow::flight::sql::odbc { +using config::Configuration; +using config::ConnectionStringParser; +using config::DsnConfigurationWindow; +using config::Result; +using config::Window; +bool DisplayConnectionWindow(void* window_parent, Configuration& config) { + HWND hwnd_parent = (HWND)window_parent; + + if (!hwnd_parent) return true; + + try { + Window parent(hwnd_parent); + DsnConfigurationWindow window(&parent, config); + + window.Create(); + + window.Show(); + window.Update(); + + return ProcessMessages(window) == Result::OK; + } catch (const DriverException& err) { + std::stringstream buf; + buf << "SQL State: " << err.GetSqlState() << ", Message: " << err.GetMessageText() + << ", Code: " << err.GetNativeError(); + std::wstring wmessage = + arrow::util::UTF8ToWideString(buf.str()).ValueOr(L"Error during load DSN"); + MessageBox(NULL, wmessage.c_str(), L"Error!", MB_ICONEXCLAMATION | MB_OK); + + std::wstring wmessage_text = arrow::util::UTF8ToWideString(err.GetMessageText()) + .ValueOr(L"Error during load DSN"); + SQLPostInstallerError(err.GetNativeError(), wmessage_text.c_str()); + } + + return false; +} + +bool DisplayConnectionWindow(void* window_parent, Configuration& config, + Connection::ConnPropertyMap& properties) { + for (const auto& [key, value] : properties) { + config.Set(key, value); + } + + if (DisplayConnectionWindow(window_parent, config)) { + properties = config.GetProperties(); + return true; + } else { + ARROW_LOG(INFO) << "Dialog is cancelled by user"; + return false; + } +} +} // namespace arrow::flight::sql::odbc + +BOOL INSTAPI ConfigDSNW(HWND hwnd_parent, WORD req, LPCWSTR wdriver, + LPCWSTR wattributes) { + using arrow::flight::sql::odbc::DisplayConnectionWindow; + using arrow::flight::sql::odbc::DriverException; + using arrow::flight::sql::odbc::FlightSqlConnection; + using arrow::flight::sql::odbc::PostArrowUtilError; + using arrow::flight::sql::odbc::PostError; + using arrow::flight::sql::odbc::RegisterDsn; + using arrow::flight::sql::odbc::UnregisterDsn; + using arrow::flight::sql::odbc::config::Configuration; + using arrow::flight::sql::odbc::config::ConnectionStringParser; + + Configuration config; + ConnectionStringParser parser(config); + + auto attributes_result = arrow::util::WideStringToUTF8(std::wstring(wattributes)); + if (!attributes_result.status().ok()) { + PostArrowUtilError(attributes_result.status()); + return FALSE; + } + std::string attributes = attributes_result.ValueOrDie(); + + parser.ParseConfigAttributes(attributes.c_str()); + + switch (req) { + case ODBC_ADD_DSN: { + config.LoadDefaults(); + if (!DisplayConnectionWindow(hwnd_parent, config) || !RegisterDsn(config, wdriver)) + return FALSE; + + break; + } + + case ODBC_CONFIG_DSN: { + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + auto wdsn_result = arrow::util::UTF8ToWideString(dsn); + if (!wdsn_result.status().ok()) { + PostArrowUtilError(wdsn_result.status()); + return FALSE; + } + std::wstring wdsn = wdsn_result.ValueOrDie(); + if (!SQLValidDSN(wdsn.c_str())) return FALSE; + + Configuration loaded(config); + try { + loaded.LoadDsn(dsn); + } catch (const DriverException& err) { + std::string error_msg = err.GetMessageText(); + std::wstring werror_msg = + arrow::util::UTF8ToWideString(error_msg).ValueOr(L"Error during DSN load"); + + PostError(err.GetNativeError(), werror_msg.c_str()); + return FALSE; + } + + if (!DisplayConnectionWindow(hwnd_parent, loaded) || !UnregisterDsn(wdsn.c_str()) || + !RegisterDsn(loaded, wdriver)) + return FALSE; + + break; + } + + case ODBC_REMOVE_DSN: { + const std::string& dsn = config.Get(FlightSqlConnection::DSN); + auto wdsn_result = arrow::util::UTF8ToWideString(dsn); + if (!wdsn_result.status().ok()) { + PostArrowUtilError(wdsn_result.status()); + return FALSE; + } + std::wstring wdsn = wdsn_result.ValueOrDie(); + if (!SQLValidDSN(wdsn.c_str()) || !UnregisterDsn(wdsn)) return FALSE; + + break; + } + + default: + return FALSE; + } + + return TRUE; +} diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc index c5646b42bef..531250b69b8 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc @@ -14,6 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + #include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" #include "arrow/flight/sql/odbc/odbc_impl/platform.h" @@ -29,34 +30,31 @@ namespace arrow::flight::sql::odbc { template class ConnectionTest : public T {}; -// GH-46574 TODO: add remote server test cases using `ConnectionRemoteTest` -class ConnectionRemoteTest : public FlightSQLODBCRemoteTestBase {}; -using TestTypes = ::testing::Types; +using TestTypes = + ::testing::Types; TYPED_TEST_SUITE(ConnectionTest, TestTypes); -TEST(SQLAllocHandle, TestSQLAllocHandleEnv) { - SQLHENV env; - - // Allocate an environment handle - ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env)); - - ASSERT_NE(env, nullptr); - - // Free an environment handle - ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env)); -} +template +class ConnectionHandleTest : public T {}; -TEST(SQLAllocEnv, TestSQLAllocEnv) { - SQLHENV env; +class ConnectionRemoteTest : public FlightSQLOdbcHandleRemoteTestBase {}; +using TestTypesHandle = + ::testing::Types; +TYPED_TEST_SUITE(ConnectionHandleTest, TestTypesHandle); +TEST(ODBCHandles, TestSQLAllocAndFreeEnv) { // Allocate an environment handle + SQLHENV env; ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env)); - // Free an environment handle + // Check for valid handle + ASSERT_NE(nullptr, env); + + // Free environment handle ASSERT_EQ(SQL_SUCCESS, SQLFreeEnv(env)); } -TEST(SQLAllocHandle, TestSQLAllocHandleConnect) { +TEST(ODBCHandles, TestSQLAllocAndFreeHandleConnect) { SQLHENV env; SQLHDBC conn; @@ -66,14 +64,17 @@ TEST(SQLAllocHandle, TestSQLAllocHandleConnect) { // Allocate a connection using alloc handle ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DBC, env, &conn)); - // Free a connection handle + // Check for valid handle + ASSERT_NE(nullptr, conn); + + // Free the created connection using free handle ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn)); - // Free an environment handle + // Free environment handle ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env)); } -TEST(SQLAllocConnect, TestSQLAllocHandleConnect) { +TEST(ODBCHandles, TestSQLAllocAndFreeConnect) { SQLHENV env; SQLHDBC conn; @@ -83,14 +84,17 @@ TEST(SQLAllocConnect, TestSQLAllocHandleConnect) { // Allocate a connection using alloc handle ASSERT_EQ(SQL_SUCCESS, SQLAllocConnect(env, &conn)); - // Free a connection handle + // Check for valid handle + ASSERT_NE(nullptr, conn); + + // Free the created connection using free connect ASSERT_EQ(SQL_SUCCESS, SQLFreeConnect(conn)); - // Free an environment handle + // Free environment handle ASSERT_EQ(SQL_SUCCESS, SQLFreeEnv(env)); } -TEST(SQLFreeHandle, TestFreeNullHandles) { +TEST(ODBCHandles, TestFreeNullHandles) { SQLHENV env = NULL; SQLHDBC conn = NULL; SQLHSTMT stmt = NULL; @@ -108,7 +112,6 @@ TEST(SQLFreeHandle, TestFreeNullHandles) { TEST(SQLGetEnvAttr, TestSQLGetEnvAttrODBCVersion) { SQLHENV env; - SQLINTEGER version; // Allocate an environment handle @@ -118,43 +121,37 @@ TEST(SQLGetEnvAttr, TestSQLGetEnvAttrODBCVersion) { ASSERT_EQ(SQL_OV_ODBC2, version); + // Free environment handle ASSERT_EQ(SQL_SUCCESS, SQLFreeEnv(env)); } TEST(SQLSetEnvAttr, TestSQLSetEnvAttrODBCVersionValid) { - SQLHENV env; - // Allocate an environment handle + SQLHENV env; ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env)); - // Attempt to set to supported version + // Attempt to set to unsupported version ASSERT_EQ(SQL_SUCCESS, SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast(SQL_OV_ODBC2), 0)); - SQLINTEGER version; - // Check ODBC version is set - ASSERT_EQ(SQL_SUCCESS, SQLGetEnvAttr(env, SQL_ATTR_ODBC_VERSION, &version, 0, 0)); - - ASSERT_EQ(SQL_OV_ODBC2, version); - + // Free environment handle ASSERT_EQ(SQL_SUCCESS, SQLFreeEnv(env)); } TEST(SQLSetEnvAttr, TestSQLSetEnvAttrODBCVersionInvalid) { - SQLHENV env; - // Allocate an environment handle + SQLHENV env; ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env)); // Attempt to set to unsupported version ASSERT_EQ(SQL_ERROR, SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast(1), 0)); + // Free environment handle ASSERT_EQ(SQL_SUCCESS, SQLFreeEnv(env)); } -// GH-46574 TODO: enable TestSQLGetEnvAttrOutputNTS which requires connection support -TYPED_TEST(ConnectionTest, DISABLED_TestSQLGetEnvAttrOutputNTS) { +TYPED_TEST(ConnectionTest, TestSQLGetEnvAttrOutputNTS) { SQLINTEGER output_nts; ASSERT_EQ(SQL_SUCCESS, @@ -183,41 +180,292 @@ TYPED_TEST(ConnectionTest, DISABLED_TestSQLGetEnvAttrNullValuePointer) { } TEST(SQLSetEnvAttr, TestSQLSetEnvAttrOutputNTSValid) { - SQLHENV env; - // Allocate an environment handle + SQLHENV env; ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env)); // Attempt to set to output nts to supported version ASSERT_EQ(SQL_SUCCESS, SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, reinterpret_cast(SQL_TRUE), 0)); + // Free environment handle ASSERT_EQ(SQL_SUCCESS, SQLFreeEnv(env)); } TEST(SQLSetEnvAttr, TestSQLSetEnvAttrOutputNTSInvalid) { - SQLHENV env; - // Allocate an environment handle + SQLHENV env; ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env)); // Attempt to set to output nts to unsupported false ASSERT_EQ(SQL_ERROR, SQLSetEnvAttr(env, SQL_ATTR_OUTPUT_NTS, reinterpret_cast(SQL_FALSE), 0)); + // Free environment handle ASSERT_EQ(SQL_SUCCESS, SQLFreeEnv(env)); } TEST(SQLSetEnvAttr, TestSQLSetEnvAttrNullValuePointer) { - SQLHENV env; - // Allocate an environment handle + SQLHENV env; ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env)); // Attempt to set using bad data pointer ASSERT_EQ(SQL_ERROR, SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, nullptr, 0)); + // Free environment handle ASSERT_EQ(SQL_SUCCESS, SQLFreeEnv(env)); } +TYPED_TEST(ConnectionHandleTest, TestSQLDriverConnect) { + // Connect string + std::string connect_str = this->GetConnectionString(); + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR out_str[kOdbcBufferSize] = L""; + SQLSMALLINT out_str_len; + + // Connecting to ODBC server. + ASSERT_EQ(SQL_SUCCESS, + SQLDriverConnect(this->conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), out_str, + kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); + + // Check that out_str has same content as connect_str + std::string out_connection_string = ODBC::SqlWcharToString(out_str, out_str_len); + Connection::ConnPropertyMap out_properties; + Connection::ConnPropertyMap in_properties; + ODBC::ODBCConnection::GetPropertiesFromConnString(out_connection_string, + out_properties); + ODBC::ODBCConnection::GetPropertiesFromConnString(connect_str, in_properties); + ASSERT_TRUE(CompareConnPropertyMap(out_properties, in_properties)); + + // Disconnect from ODBC + ASSERT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); +} + +#if defined _WIN32 +TYPED_TEST(ConnectionHandleTest, TestSQLDriverConnectDsn) { + // Connect string + std::string connect_str = this->GetConnectionString(); + + // Write connection string content into a DSN, + // must succeed before continuing + ASSERT_TRUE(WriteDSN(connect_str)); + + std::string dsn(kTestDsn); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + + // Update connection string to use DSN to connect + connect_str = std::string("DSN=") + std::string(kTestDsn) + + std::string(";driver={Apache Arrow Flight SQL ODBC Driver};"); + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR out_str[kOdbcBufferSize] = L""; + SQLSMALLINT out_str_len; + + // Connecting to ODBC server. + ASSERT_EQ(SQL_SUCCESS, + SQLDriverConnect(this->conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), out_str, + kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); + + // Remove DSN + ASSERT_TRUE(UnregisterDsn(wdsn)); + + // Disconnect from ODBC + ASSERT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); +} + +TYPED_TEST(ConnectionHandleTest, TestSQLConnect) { + // Connect string + std::string connect_str = this->GetConnectionString(); + + // Write connection string content into a DSN, + // must succeed before continuing + std::string uid(""), pwd(""); + ASSERT_TRUE(WriteDSN(connect_str)); + + std::string dsn(kTestDsn); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + ASSERT_OK_AND_ASSIGN(std::wstring wuid, arrow::util::UTF8ToWideString(uid)); + ASSERT_OK_AND_ASSIGN(std::wstring wpwd, arrow::util::UTF8ToWideString(pwd)); + std::vector dsn0(wdsn.begin(), wdsn.end()); + std::vector uid0(wuid.begin(), wuid.end()); + std::vector pwd0(wpwd.begin(), wpwd.end()); + + // Connecting to ODBC server. Empty uid and pwd should be ignored. + ASSERT_EQ(SQL_SUCCESS, + SQLConnect(this->conn, dsn0.data(), static_cast(dsn0.size()), + uid0.data(), static_cast(uid0.size()), pwd0.data(), + static_cast(pwd0.size()))) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); + + // Remove DSN + ASSERT_TRUE(UnregisterDsn(wdsn)); + + // Disconnect from ODBC + ASSERT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); +} + +TEST_F(ConnectionRemoteTest, TestSQLConnectInputUidPwd) { + // Connect string + std::string connect_str = GetConnectionString(); + + // Retrieve valid uid and pwd, assumes TEST_CONNECT_STR contains uid and pwd + Connection::ConnPropertyMap properties; + ODBC::ODBCConnection::GetPropertiesFromConnString(connect_str, properties); + std::string uid_key("uid"); + std::string pwd_key("pwd"); + std::string uid = properties[uid_key]; + std::string pwd = properties[pwd_key]; + + // Write connection string content without uid and pwd into a DSN, + // must succeed before continuing + properties.erase(uid_key); + properties.erase(pwd_key); + ASSERT_TRUE(WriteDSN(properties)); + + std::string dsn(kTestDsn); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + ASSERT_OK_AND_ASSIGN(std::wstring wuid, arrow::util::UTF8ToWideString(uid)); + ASSERT_OK_AND_ASSIGN(std::wstring wpwd, arrow::util::UTF8ToWideString(pwd)); + std::vector dsn0(wdsn.begin(), wdsn.end()); + std::vector uid0(wuid.begin(), wuid.end()); + std::vector pwd0(wpwd.begin(), wpwd.end()); + + // Connecting to ODBC server. + ASSERT_EQ(SQL_SUCCESS, + SQLConnect(this->conn, dsn0.data(), static_cast(dsn0.size()), + uid0.data(), static_cast(uid0.size()), pwd0.data(), + static_cast(pwd0.size()))) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); + + // Remove DSN + ASSERT_TRUE(UnregisterDsn(wdsn)); + + // Disconnect from ODBC + ASSERT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); +} + +TEST_F(ConnectionRemoteTest, TestSQLConnectInvalidUid) { + // Connect string + std::string connect_str = GetConnectionString(); + + // Retrieve valid uid and pwd, assumes TEST_CONNECT_STR contains uid and pwd + Connection::ConnPropertyMap properties; + ODBC::ODBCConnection::GetPropertiesFromConnString(connect_str, properties); + std::string uid = properties["uid"]; + std::string pwd = properties["pwd"]; + + // Append invalid uid to connection string + connect_str += "uid=non_existent_id;"; + + // Write connection string content into a DSN, + // must succeed before continuing + ASSERT_TRUE(WriteDSN(connect_str)); + + std::string dsn(kTestDsn); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + ASSERT_OK_AND_ASSIGN(std::wstring wuid, arrow::util::UTF8ToWideString(uid)); + ASSERT_OK_AND_ASSIGN(std::wstring wpwd, arrow::util::UTF8ToWideString(pwd)); + std::vector dsn0(wdsn.begin(), wdsn.end()); + std::vector uid0(wuid.begin(), wuid.end()); + std::vector pwd0(wpwd.begin(), wpwd.end()); + + // Connecting to ODBC server. + // UID specified in DSN will take precedence, + // so connection still fails despite passing valid uid in SQLConnect call + ASSERT_EQ(SQL_ERROR, + SQLConnect(this->conn, dsn0.data(), static_cast(dsn0.size()), + uid0.data(), static_cast(uid0.size()), pwd0.data(), + static_cast(pwd0.size()))); + + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorState28000); + + // Remove DSN + ASSERT_TRUE(UnregisterDsn(wdsn)); +} + +TEST_F(ConnectionRemoteTest, TestSQLConnectDSNPrecedence) { + // Connect string + std::string connect_str = GetConnectionString(); + + // Write connection string content into a DSN, + // must succeed before continuing + + // Pass incorrect uid and password to SQLConnect, they will be ignored. + // Assumes TEST_CONNECT_STR contains uid and pwd + std::string uid("non_existent_id"), pwd("non_existent_password"); + ASSERT_TRUE(WriteDSN(connect_str)); + + std::string dsn(kTestDsn); + ASSERT_OK_AND_ASSIGN(std::wstring wdsn, arrow::util::UTF8ToWideString(dsn)); + ASSERT_OK_AND_ASSIGN(std::wstring wuid, arrow::util::UTF8ToWideString(uid)); + ASSERT_OK_AND_ASSIGN(std::wstring wpwd, arrow::util::UTF8ToWideString(pwd)); + std::vector dsn0(wdsn.begin(), wdsn.end()); + std::vector uid0(wuid.begin(), wuid.end()); + std::vector pwd0(wpwd.begin(), wpwd.end()); + + // Connecting to ODBC server. + ASSERT_EQ(SQL_SUCCESS, + SQLConnect(this->conn, dsn0.data(), static_cast(dsn0.size()), + uid0.data(), static_cast(uid0.size()), pwd0.data(), + static_cast(pwd0.size()))) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); + + // Remove DSN + ASSERT_TRUE(UnregisterDsn(wdsn)); + + // Disconnect from ODBC + ASSERT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); +} + +#endif + +TEST_F(ConnectionRemoteTest, TestSQLDriverConnectInvalidUid) { + // Invalid connect string + std::string connect_str = GetInvalidConnectionString(); + + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR out_str[kOdbcBufferSize]; + SQLSMALLINT out_str_len; + + // Connecting to ODBC server. + ASSERT_EQ(SQL_ERROR, + SQLDriverConnect(this->conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), out_str, + kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)); + + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorState28000); + + std::string out_connection_string = ODBC::SqlWcharToString(out_str, out_str_len); + ASSERT_TRUE(out_connection_string.empty()); +} + +TYPED_TEST(ConnectionHandleTest, TestSQLDisconnectWithoutConnection) { + // Attempt to disconnect without a connection, expect to fail + ASSERT_EQ(SQL_ERROR, SQLDisconnect(this->conn)); + + // Expect ODBC driver manager to return error state + VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorState08003); +} + +TYPED_TEST(ConnectionTest, TestConnect) { + // Verifies connect and disconnect works on its own +} } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc index fccb5525759..eb6c60b9762 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc @@ -28,7 +28,7 @@ namespace arrow::flight::sql::odbc { -void FlightSQLODBCRemoteTestBase::AllocEnvConnHandles(SQLINTEGER odbc_ver) { +void ODBCRemoteTestBase::AllocEnvConnHandles(SQLINTEGER odbc_ver) { // Allocate an environment handle ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env)); @@ -41,13 +41,13 @@ void FlightSQLODBCRemoteTestBase::AllocEnvConnHandles(SQLINTEGER odbc_ver) { ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DBC, env, &conn)); } -void FlightSQLODBCRemoteTestBase::Connect(SQLINTEGER odbc_ver) { +void ODBCRemoteTestBase::Connect(SQLINTEGER odbc_ver) { ASSERT_NO_FATAL_FAILURE(AllocEnvConnHandles(odbc_ver)); std::string connect_str = GetConnectionString(); ASSERT_NO_FATAL_FAILURE(ConnectWithString(connect_str)); } -void FlightSQLODBCRemoteTestBase::ConnectWithString(std::string connect_str) { +void ODBCRemoteTestBase::ConnectWithString(std::string connect_str) { // Connect string std::vector connect_str0(connect_str.begin(), connect_str.end()); @@ -61,18 +61,22 @@ void FlightSQLODBCRemoteTestBase::ConnectWithString(std::string connect_str) { kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)) << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); - // Allocate a statement using alloc handle - ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, conn, &stmt)); + // GH-47710: TODO Allocate a statement using alloc handle + // ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, conn, &stmt)); } -void FlightSQLODBCRemoteTestBase::Disconnect() { - // Close statement - EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt)); +void ODBCRemoteTestBase::Disconnect() { + // GH-47710: TODO Close statement + // EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt)); // Disconnect from ODBC EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(conn)) << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); + FreeEnvConnHandles(); +} + +void ODBCRemoteTestBase::FreeEnvConnHandles() { // Free connection handle EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn)); @@ -80,20 +84,20 @@ void FlightSQLODBCRemoteTestBase::Disconnect() { EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env)); } -std::string FlightSQLODBCRemoteTestBase::GetConnectionString() { +std::string ODBCRemoteTestBase::GetConnectionString() { std::string connect_str = arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOrDie(); return connect_str; } -std::string FlightSQLODBCRemoteTestBase::GetInvalidConnectionString() { +std::string ODBCRemoteTestBase::GetInvalidConnectionString() { std::string connect_str = GetConnectionString(); // Append invalid uid to connection string connect_str += std::string("uid=non_existent_id;"); return connect_str; } -std::wstring FlightSQLODBCRemoteTestBase::GetQueryAllDataTypes() { +std::wstring ODBCRemoteTestBase::GetQueryAllDataTypes() { std::wstring wsql = LR"( SELECT -- Numeric types @@ -144,10 +148,18 @@ std::wstring FlightSQLODBCRemoteTestBase::GetQueryAllDataTypes() { return wsql; } -void FlightSQLODBCRemoteTestBase::SetUp() { +void ODBCRemoteTestBase::SetUp() { if (arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOr("").empty()) { + skipping_test_ = true; GTEST_SKIP() << "Skipping test: kTestConnectStr not set"; } +} + +void FlightSQLODBCRemoteTestBase::SetUp() { + ODBCRemoteTestBase::SetUp(); + if (skipping_test_) { + return; + } this->Connect(); connected_ = true; @@ -161,14 +173,32 @@ void FlightSQLODBCRemoteTestBase::TearDown() { } void FlightSQLOdbcV2RemoteTestBase::SetUp() { - if (arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOr("").empty()) { - GTEST_SKIP() << "Skipping test: kTestConnectStr not set"; + ODBCRemoteTestBase::SetUp(); + if (skipping_test_) { + return; } this->Connect(SQL_OV_ODBC2); connected_ = true; } +void FlightSQLOdbcHandleRemoteTestBase::SetUp() { + ODBCRemoteTestBase::SetUp(); + if (skipping_test_) { + return; + } + + this->AllocEnvConnHandles(); + allocated_ = true; +} + +void FlightSQLOdbcHandleRemoteTestBase::TearDown() { + if (allocated_) { + this->FreeEnvConnHandles(); + allocated_ = false; + } +} + std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers) { // Lambda function to compare characters without case sensitivity. auto char_compare = [](const char& char1, const char& char2) { @@ -209,7 +239,7 @@ Status MockServerMiddlewareFactory::StartCall( return Status::OK(); } -std::string FlightSQLODBCMockTestBase::GetConnectionString() { +std::string ODBCMockTestBase::GetConnectionString() { std::string connect_str( "driver={Apache Arrow Flight SQL ODBC Driver};HOST=localhost;port=" + std::to_string(port) + ";token=" + std::string(kTestToken) + @@ -217,14 +247,14 @@ std::string FlightSQLODBCMockTestBase::GetConnectionString() { return connect_str; } -std::string FlightSQLODBCMockTestBase::GetInvalidConnectionString() { +std::string ODBCMockTestBase::GetInvalidConnectionString() { std::string connect_str = GetConnectionString(); // Append invalid token to connection string connect_str += std::string("token=invalid_token;"); return connect_str; } -std::wstring FlightSQLODBCMockTestBase::GetQueryAllDataTypes() { +std::wstring ODBCMockTestBase::GetQueryAllDataTypes() { std::wstring wsql = LR"( SELECT -- Numeric types @@ -273,7 +303,7 @@ std::wstring FlightSQLODBCMockTestBase::GetQueryAllDataTypes() { return wsql; } -void FlightSQLODBCMockTestBase::CreateTestTables() { +void ODBCMockTestBase::CreateTestTables() { ASSERT_OK(server_->ExecuteSql(R"( CREATE TABLE TestTable ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -286,7 +316,7 @@ void FlightSQLODBCMockTestBase::CreateTestTables() { )")); } -void FlightSQLODBCMockTestBase::CreateTableAllDataType() { +void ODBCMockTestBase::CreateTableAllDataType() { // Limitation on mock SQLite server: // Only int64, float64, binary, and utf8 Arrow Types are supported by // SQLiteFlightSqlServer::Impl::DoGetTables @@ -308,7 +338,7 @@ void FlightSQLODBCMockTestBase::CreateTableAllDataType() { )")); } -void FlightSQLODBCMockTestBase::CreateUnicodeTable() { +void ODBCMockTestBase::CreateUnicodeTable() { std::string unicode_sql = arrow::util::WideStringToUTF8( LR"( CREATE TABLE 数据( @@ -322,7 +352,7 @@ void FlightSQLODBCMockTestBase::CreateUnicodeTable() { ASSERT_OK(server_->ExecuteSql(unicode_sql)); } -void FlightSQLODBCMockTestBase::Initialize() { +void ODBCMockTestBase::SetUp() { ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("0.0.0.0", 0)); arrow::flight::FlightServerOptions options(location); options.auth_handler = std::make_unique(); @@ -338,25 +368,40 @@ void FlightSQLODBCMockTestBase::Initialize() { } void FlightSQLODBCMockTestBase::SetUp() { - this->Initialize(); + ODBCMockTestBase::SetUp(); this->Connect(); connected_ = true; } +void ODBCMockTestBase::TearDown() { + ASSERT_OK(server_->Shutdown()); + ASSERT_OK(server_->Wait()); +} + void FlightSQLODBCMockTestBase::TearDown() { if (connected_) { this->Disconnect(); connected_ = false; } - ASSERT_OK(server_->Shutdown()); + ODBCMockTestBase::TearDown(); } void FlightSQLOdbcV2MockTestBase::SetUp() { - this->Initialize(); + ODBCMockTestBase::SetUp(); this->Connect(SQL_OV_ODBC2); connected_ = true; } +void FlightSQLOdbcHandleMockTestBase::SetUp() { + ODBCMockTestBase::SetUp(); + this->AllocEnvConnHandles(); +} + +void FlightSQLOdbcHandleMockTestBase::TearDown() { + this->FreeEnvConnHandles(); + ODBCMockTestBase::TearDown(); +} + bool CompareConnPropertyMap(Connection::ConnPropertyMap map1, Connection::ConnPropertyMap map2) { if (map1.size() != map2.size()) return false; @@ -411,7 +456,7 @@ std::string GetOdbcErrorMessage(SQLSMALLINT handle_type, SQLHANDLE handle) { return res; } -// TODO: once RegisterDsn is implemented in Mac and Linux, the following can be +// GH-47822 TODO: once RegisterDsn is implemented in Mac and Linux, the following can be // re-enabled. #if defined _WIN32 bool WriteDSN(std::string connection_str) { diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h index e35e6c38f85..e043a459f0a 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h @@ -42,14 +42,16 @@ static constexpr std::string_view kTestDsn = "Apache Arrow Flight SQL Test DSN"; namespace arrow::flight::sql::odbc { /// \brief Base test fixture for running tests against a remote server. -/// Each test file running remote server tests should define a -/// fixture inheriting from this base fixture. /// The connection string for connecting to this server is defined /// in the ARROW_FLIGHT_SQL_ODBC_CONN environment variable. -class FlightSQLODBCRemoteTestBase : public ::testing::Test { +/// Note that this fixture does not handle the driver's connection/disconnection +/// during SetUp/Teardown. +class ODBCRemoteTestBase : public ::testing::Test { public: /// \brief Allocate environment and connection handles void AllocEnvConnHandles(SQLINTEGER odbc_ver = SQL_OV_ODBC3); + /// \brief Free environment and connection handles + void FreeEnvConnHandles(); /// \brief Connect to Arrow Flight SQL server using connection string defined in /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN", allocate statement handle. /// Connects using ODBC Ver 3 by default @@ -75,6 +77,18 @@ class FlightSQLODBCRemoteTestBase : public ::testing::Test { /** ODBC Statement. */ SQLHSTMT stmt = 0; + protected: + void SetUp() override; + + bool skipping_test_ = false; +}; + +/// \brief Base test fixture for running tests against a remote server. +/// Each test file running remote server tests should define a +/// fixture inheriting from this base fixture. +/// The connection string for connecting to this server is defined +/// in the ARROW_FLIGHT_SQL_ODBC_CONN environment variable. +class FlightSQLODBCRemoteTestBase : public ODBCRemoteTestBase { protected: void SetUp() override; @@ -91,6 +105,14 @@ class FlightSQLOdbcV2RemoteTestBase : public FlightSQLODBCRemoteTestBase { void SetUp() override; }; +class FlightSQLOdbcHandleRemoteTestBase : public FlightSQLODBCRemoteTestBase { + protected: + void SetUp() override; + void TearDown() override; + + bool allocated_ = false; +}; + static constexpr std::string_view kAuthorizationHeader = "authorization"; static constexpr std::string_view kBearerPrefix = "Bearer "; static constexpr std::string_view kTestToken = "t0k3n"; @@ -129,9 +151,7 @@ class MockServerMiddlewareFactory : public ServerMiddlewareFactory { }; /// \brief Base test fixture for running tests against a mock server. -/// Each test file running mock server tests should define a -/// fixture inheriting from this base fixture. -class FlightSQLODBCMockTestBase : public FlightSQLODBCRemoteTestBase { +class ODBCMockTestBase : public FlightSQLODBCRemoteTestBase { // Sets up a mock server for each test case public: /// \brief Get connection string for mock server @@ -152,16 +172,23 @@ class FlightSQLODBCMockTestBase : public FlightSQLODBCRemoteTestBase { int port; protected: - void Initialize(); - void SetUp() override; void TearDown() override; - private: std::shared_ptr server_; }; +/// \brief Base test fixture for running tests against a mock server. +/// Each test file running mock server tests should define a +/// fixture inheriting from this base fixture. +class FlightSQLODBCMockTestBase : public ODBCMockTestBase { + protected: + void SetUp() override; + + void TearDown() override; +}; + /// \brief Base test fixture for running ODBC V2 tests against a mock server. /// Each test file running mock server ODBC V2 tests should define a /// fixture inheriting from this base fixture. @@ -170,6 +197,12 @@ class FlightSQLOdbcV2MockTestBase : public FlightSQLODBCMockTestBase { void SetUp() override; }; +class FlightSQLOdbcHandleMockTestBase : public FlightSQLODBCMockTestBase { + protected: + void SetUp() override; + void TearDown() override; +}; + /** ODBC read buffer size. */ static constexpr int kOdbcBufferSize = 1024;