From db130977392bd53da5f89b742dc24d8b9d306bf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Mon, 18 May 2026 23:28:53 +0200 Subject: [PATCH 01/15] feat: add TDS 8 strict encryption for Microsoft Fabric support Implement EncryptionLevel::Strict which performs the TLS handshake directly on the TCP stream before any TDS traffic, as required by TDS 8 (Microsoft Fabric Data Warehouse, SQL Server 2022+ strict mode). - Add EncryptionLevel::Strict variant - Add tls_handshake_strict() that bypasses TDS packet wrapping - Support encrypt=strict in ADO.NET and JDBC connection strings - Map Strict to ENCRYPT_ON (0x01) on the wire in PRELOGIN - Make encryption negotiation non-negotiable for strict mode --- src/client/config.rs | 9 +++- src/client/config/ado_net.rs | 30 ++++++++++++ src/client/config/jdbc.rs | 15 ++++++ src/client/connection.rs | 90 ++++++++++++++++++++++++++++++++++++ src/tds.rs | 3 ++ src/tds/codec/pre_login.rs | 10 +++- 6 files changed, 154 insertions(+), 3 deletions(-) diff --git a/src/client/config.rs b/src/client/config.rs index fff68bc15..5dbfd7323 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -119,6 +119,8 @@ impl Config { /// /// - With `tls` feature, defaults to `Required`. /// - Without `tls` feature, defaults to `NotSupported`. + /// - Use `Strict` for TDS 8 strict transport encryption (required for + /// Microsoft Fabric endpoints and SQL Server 2022+ strict mode). pub fn encryption(&mut self, encryption: EncryptionLevel) { self.encryption = encryption; } @@ -210,7 +212,7 @@ impl Config { /// |`database`|``|The name of the database.| /// |`TrustServerCertificate`|`true`,`false`,`yes`,`no`|Specifies whether the driver trusts the server certificate when connecting using TLS. Cannot be used toghether with `TrustServerCertificateCA`| /// |`TrustServerCertificateCA`|``|Path to a `pem`, `crt` or `der` certificate file. Cannot be used together with `TrustServerCertificate`| - /// |`encrypt`|`true`,`false`,`yes`,`no`,`DANGER_PLAINTEXT`|Specifies whether the driver uses TLS to encrypt communication.| + /// |`encrypt`|`true`,`false`,`yes`,`no`,`strict`,`DANGER_PLAINTEXT`|Specifies whether the driver uses TLS to encrypt communication. `strict` enables TDS 8 strict transport encryption (required for Microsoft Fabric).| /// |`Application Name`, `ApplicationName`|``|Sets the application name for the connection.| /// /// [ADO.NET connection string]: https://docs.microsoft.com/en-us/dotnet/framework/data/adonet/connection-strings @@ -357,7 +359,10 @@ pub(crate) trait ConfigString { .map(|val| match Self::parse_bool(val) { Ok(true) => Ok(EncryptionLevel::Required), Ok(false) => Ok(EncryptionLevel::Off), - Err(_) if val == "DANGER_PLAINTEXT" => Ok(EncryptionLevel::NotSupported), + Err(_) if val.eq_ignore_ascii_case("DANGER_PLAINTEXT") => { + Ok(EncryptionLevel::NotSupported) + } + Err(_) if val.eq_ignore_ascii_case("strict") => Ok(EncryptionLevel::Strict), Err(e) => Err(e), }) .unwrap_or(Ok(EncryptionLevel::Off)) diff --git a/src/client/config/ado_net.rs b/src/client/config/ado_net.rs index 94df9ca38..98a8db811 100644 --- a/src/client/config/ado_net.rs +++ b/src/client/config/ado_net.rs @@ -470,6 +470,36 @@ mod tests { Ok(()) } + #[test] + #[cfg(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" + ))] + fn encryption_parsing_strict() -> crate::Result<()> { + let test_str = "encrypt=strict"; + let ado: AdoNetConfig = test_str.parse()?; + + assert_eq!(EncryptionLevel::Strict, ado.encrypt()?); + + Ok(()) + } + + #[test] + #[cfg(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" + ))] + fn encryption_parsing_strict_case_insensitive() -> crate::Result<()> { + let test_str = "encrypt=Strict"; + let ado: AdoNetConfig = test_str.parse()?; + + assert_eq!(EncryptionLevel::Strict, ado.encrypt()?); + + Ok(()) + } + #[test] fn application_name_parsing() -> crate::Result<()> { let test_str = "Application Name=meow"; diff --git a/src/client/config/jdbc.rs b/src/client/config/jdbc.rs index 4168cf975..7677637a3 100644 --- a/src/client/config/jdbc.rs +++ b/src/client/config/jdbc.rs @@ -319,6 +319,21 @@ mod tests { Ok(()) } + #[test] + #[cfg(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" + ))] + fn encryption_parsing_strict() -> crate::Result<()> { + let test_str = "jdbc:sqlserver://my-server.com:4200;encrypt=strict;"; + let jdbc: JdbcConfig = test_str.parse()?; + + assert_eq!(EncryptionLevel::Strict, jdbc.encrypt()?); + + Ok(()) + } + #[test] fn application_name_parsing() -> crate::Result<()> { let test_str = "jdbc:sqlserver://my-server.com:4200;Application Name=meow"; diff --git a/src/client/connection.rs b/src/client/connection.rs index 09d372561..c88a3fd3d 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -88,6 +88,14 @@ impl Connection { buf: BytesMut::new(), }; + // TDS 8 strict mode: TLS handshake first, then PRELOGIN inside TLS + if config.encryption == EncryptionLevel::Strict { + let connection = connection.tls_handshake_strict(&config).await?; + let mut connection = Self::finish_connect_after_tls(connection, config).await?; + connection.flush_done().await?; + return Ok(connection); + } + let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_)); let prelogin = connection @@ -115,6 +123,33 @@ impl Connection { Ok(connection) } + /// Complete connection setup after TLS is established (for strict mode). + /// In TDS 8, PRELOGIN and LOGIN both happen inside the TLS tunnel. + async fn finish_connect_after_tls( + mut connection: Self, + config: Config, + ) -> crate::Result { + let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_)); + + let prelogin = connection + .prelogin(config.encryption, fed_auth_required) + .await?; + + let connection = connection + .login( + config.auth, + EncryptionLevel::Strict, + config.database, + config.host, + config.application_name, + config.readonly, + prelogin, + ) + .await?; + + Ok(connection) + } + /// Flush the incoming token stream until receiving `DONE` token. async fn flush_done(&mut self) -> crate::Result { TokenStream::new(self).flush_done().await @@ -478,6 +513,61 @@ impl Connection { } } + /// Implements TDS 8 strict TLS handshake: TLS is established directly on + /// the raw TCP stream without any TDS packet wrapping. This is required + /// for Microsoft Fabric and SQL Server 2022+ strict mode. + #[cfg(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" + ))] + async fn tls_handshake_strict(self, config: &Config) -> crate::Result { + event!( + Level::INFO, + "Performing a TDS 8 strict TLS handshake (TLS-first)" + ); + + let Self { + transport, context, .. + } = self; + + let stream = match transport.into_inner() { + MaybeTlsStream::Raw(tcp) => { + // In strict mode, create the wrapper but immediately mark handshake + // as complete so it acts as a transparent passthrough. This means + // the TLS handshake goes directly over the TCP stream without any + // TDS packet wrapping - exactly what TDS 8 requires. + let mut wrapper = TlsPreloginWrapper::new(tcp); + wrapper.handshake_complete(); + create_tls_stream(config, wrapper).await? + } + _ => unreachable!(), + }; + + event!(Level::INFO, "TDS 8 strict TLS handshake successful"); + + let transport = Framed::new(MaybeTlsStream::Tls(stream), PacketCodec); + + Ok(Self { + transport, + context, + flushed: false, + buf: BytesMut::new(), + }) + } + + /// Implements TDS 8 strict TLS handshake (no-op when TLS features are disabled). + #[cfg(not(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" + )))] + async fn tls_handshake_strict(self, _: &Config) -> crate::Result { + Err(crate::Error::Protocol( + "TDS 8 strict encryption requires a TLS feature (rustls, native-tls, or vendored-openssl) to be enabled".into() + )) + } + /// Implements the TLS handshake with the SQL Server. #[cfg(not(any( feature = "rustls", diff --git a/src/tds.rs b/src/tds.rs index f4b6f9253..33e07e3b2 100644 --- a/src/tds.rs +++ b/src/tds.rs @@ -25,6 +25,9 @@ uint_enum! { NotSupported = 2, /// Encrypt everything and fail if not possible Required = 3, + /// TDS 8 strict transport encryption: TLS handshake occurs before any + /// TDS traffic (required for Microsoft Fabric endpoints). + Strict = 4, } } diff --git a/src/tds/codec/pre_login.rs b/src/tds/codec/pre_login.rs index eb4c27e60..1913eef6f 100644 --- a/src/tds/codec/pre_login.rs +++ b/src/tds/codec/pre_login.rs @@ -64,6 +64,8 @@ impl PreloginMessage { ))] pub fn negotiated_encryption(&self, expected: EncryptionLevel) -> EncryptionLevel { match (expected, self.encryption) { + // TDS 8 strict mode is non-negotiable + (EncryptionLevel::Strict, _) => EncryptionLevel::Strict, (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => { EncryptionLevel::NotSupported } @@ -110,7 +112,13 @@ impl Encode for PreloginMessage { // encryption fields.push((PRELOGIN_ENCRYPTION, 0x01)); // encryption - data_cursor.write_u8(self.encryption as u8)?; + // In TDS 8 strict mode, TLS is already established before PRELOGIN. + // Send ENCRYPT_ON (0x01) on the wire since strict is not a valid wire value. + let encryption_wire_value = match self.encryption { + EncryptionLevel::Strict => EncryptionLevel::On as u8, + other => other as u8, + }; + data_cursor.write_u8(encryption_wire_value)?; // threadid fields.push((PRELOGIN_THREADID, 0x04)); // thread id From da29f834d4efd1f92eae68925252faf1ded41a4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 13:42:08 +0200 Subject: [PATCH 02/15] feat: implement full TDS 8 strict encryption with Fabric routing Complete the TDS 8 strict encryption implementation to support Microsoft Fabric SQL Data Warehouse end-to-end connectivity: - Handle routing reconnection flow (gateway -> backend redirect) - Add FEDAUTHREQUIRED in PRELOGIN for AAD auth state setup - Add AZURESQLSUPPORT (0x08) feature extension in LOGIN - Fix FEATUREEXTACK decoder to handle AZURESQLSUPPORT ack and unknown features gracefully (no more panics) - Fix Packet::encode to use offset-aware length patching - Fix packet_id to start from 1 (not 0) - Add TRACEID support in PRELOGIN for backend connections - Add pipelined PRELOGIN+LOGIN via feed()+flush() for strict mode - Set TDS version to SqlServer2022 for strict mode LOGIN - Set client_lcid=0x0409 and clt_int_name='tiberius' defaults - Add Config fields: strict_pipelined, login_server_name, instance_name - Add ALPN 'ms-tds' for strict TLS connections (native-tls) - Use UseTSQL flag in LOGIN type_flags - Add integration tests proving Fabric connectivity with routing Tested against live Fabric endpoint: connects, authenticates with AAD token, handles routing redirect, and executes queries including parameterized statements. --- Cargo.toml | 14 +- src/client/config.rs | 40 +++ src/client/connection.rs | 249 +++++++++++++++++- src/client/tls_stream/native_tls_stream.rs | 54 ++++ src/tds/codec/login.rs | 122 +++++++-- src/tds/codec/packet.rs | 11 +- src/tds/codec/pre_login.rs | 58 ++++- src/tds/codec/token/token_error.rs | 4 +- src/tds/codec/token/token_feature_ext_ack.rs | 39 ++- src/tds/context.rs | 2 +- tests/fabric.rs | 259 +++++++++++++++++++ 11 files changed, 799 insertions(+), 53 deletions(-) create mode 100644 tests/fabric.rs diff --git a/Cargo.toml b/Cargo.toml index 0caaac815..fec4c35d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,10 @@ members = ["runtimes-macro"] path = "tests/query.rs" name = "query" +[[test]] +path = "tests/fabric.rs" +name = "fabric" + [[test]] path = "tests/named-instance-async.rs" name = "named-instance-async" @@ -50,7 +54,7 @@ asynchronous-codec = "0.6" async-trait = "0.1" connection-string = "0.2" num-traits = "0.2" -uuid = "1.0" +uuid = { version = "1.0", features = ["v4"] } [target.'cfg(windows)'.dependencies] winauth = { version = "0.0.4", optional = true } @@ -63,6 +67,12 @@ version = "0.4" features = ["runtime-async-std"] optional = true +[dependencies.native-tls-crate] +version = "0.2.18" +optional = true +features = ["alpn"] +package = "native-tls" + [dependencies.tokio] version = "1.0" optional = true @@ -200,5 +210,5 @@ sql-browser-smol = ["async-io", "async-net", "futures-lite"] integrated-auth-gssapi = ["libgssapi"] bigdecimal = ["bigdecimal_"] rustls = ["tokio-rustls", "tokio-util", "rustls-pemfile", "rustls-native-certs"] -native-tls = ["async-native-tls"] +native-tls = ["async-native-tls", "native-tls-crate"] vendored-openssl = ["opentls"] diff --git a/src/client/config.rs b/src/client/config.rs index 5dbfd7323..1beca39a6 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -32,6 +32,15 @@ pub struct Config { pub(crate) trust: TrustConfig, pub(crate) auth: AuthMethod, pub(crate) readonly: bool, + /// When true, PRELOGIN and LOGIN are sent back-to-back without waiting + /// for the PRELOGIN response. Required for TDS 8 strict mode backend + /// reconnection after receiving a routing redirect from the gateway. + pub(crate) strict_pipelined: bool, + /// Override the `server_name` field sent in the TDS LOGIN message. + /// When reconnecting to a backend after routing, this should be set to + /// the original gateway hostname so the backend knows which endpoint + /// the client intended to reach. + pub(crate) login_server_name: Option, } #[derive(Clone, Debug)] @@ -65,6 +74,8 @@ impl Default for Config { trust: TrustConfig::Default, auth: AuthMethod::None, readonly: false, + strict_pipelined: false, + login_server_name: None, } } } @@ -173,6 +184,35 @@ impl Config { self.readonly = readnoly; } + /// Enable pipelined PRELOGIN+LOGIN for TDS 8 strict mode backend + /// reconnection. + /// + /// When connecting to a Microsoft Fabric (or SQL Server 2022+ strict mode) + /// endpoint, the gateway returns a routing redirect. The client must + /// disconnect and reconnect to the backend host. The backend requires + /// PRELOGIN and LOGIN to be sent back-to-back (pipelined) without waiting + /// for the PRELOGIN response. + /// + /// Call this method on the `Config` used for the backend reconnection after + /// receiving [`Error::Routing`]. + /// + /// - Defaults to `false`. + pub fn strict_pipelined(&mut self) { + self.strict_pipelined = true; + } + + /// Override the `server_name` field in the TDS LOGIN message. + /// + /// When reconnecting to a backend after routing, set this to the original + /// gateway/endpoint hostname. The backend uses this to identify which + /// workspace/database the client intended to reach, since the backend host + /// is an internal pool address. + /// + /// If not set, defaults to the value of `host()`. + pub fn login_server_name(&mut self, name: impl Into) { + self.login_server_name = Some(name.into()); + } + pub(crate) fn get_host(&self) -> &str { self.host .as_deref() diff --git a/src/client/connection.rs b/src/client/connection.rs index c88a3fd3d..3c62bb114 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -8,8 +8,8 @@ use crate::{ client::{tls::MaybeTlsStream, AuthMethod, Config}, tds::{ codec::{ - self, Encode, LoginMessage, Packet, PacketCodec, PacketHeader, PacketStatus, - PreloginMessage, TokenDone, + self, Encode, FeatureLevel, LoginMessage, Packet, PacketCodec, PacketHeader, + PacketStatus, PreloginMessage, TokenDone, }, stream::TokenStream, Context, HEADER_BYTES, @@ -88,9 +88,23 @@ impl Connection { buf: BytesMut::new(), }; - // TDS 8 strict mode: TLS handshake first, then PRELOGIN inside TLS + // TDS 8 strict mode: TLS handshake first, then PRELOGIN inside TLS. if config.encryption == EncryptionLevel::Strict { let connection = connection.tls_handshake_strict(&config).await?; + + if config.strict_pipelined { + // Backend reconnection after routing: pipeline PRELOGIN+LOGIN + // without waiting for the PRELOGIN response (required by Fabric + // backend servers). + let mut connection = + Self::finish_connect_strict_pipelined(connection, config).await?; + connection.flush_done().await?; + return Ok(connection); + } + + // Gateway (first connection): sequential PRELOGIN → LOGIN. + // The gateway will respond with a routing token, which propagates + // as Error::Routing for the caller to handle reconnection. let mut connection = Self::finish_connect_after_tls(connection, config).await?; connection.flush_done().await?; return Ok(connection); @@ -99,7 +113,7 @@ impl Connection { let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_)); let prelogin = connection - .prelogin(config.encryption, fed_auth_required) + .prelogin(config.encryption, fed_auth_required, config.instance_name.clone(), false) .await?; let encryption = prelogin.negotiated_encryption(config.encryption); @@ -129,18 +143,26 @@ impl Connection { mut connection: Self, config: Config, ) -> crate::Result { + // For backend connections (routing reconnect), we still need to send + // FEDAUTHREQUIRED if using AAD auth — the backend needs it to prepare + // for processing the FEDAUTH token in LOGIN. + let is_backend = config.instance_name.is_some(); let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_)); let prelogin = connection - .prelogin(config.encryption, fed_auth_required) + .prelogin(config.encryption, fed_auth_required, config.instance_name.clone(), is_backend) .await?; + // Use login_server_name if set (for routed connections, this is the + // original gateway hostname), otherwise fall back to host. + let server_name = config.login_server_name.or(config.host); + let connection = connection .login( config.auth, EncryptionLevel::Strict, config.database, - config.host, + server_name, config.application_name, config.readonly, prelogin, @@ -150,6 +172,136 @@ impl Connection { Ok(connection) } + /// Complete connection to a strict-mode backend after routing redirect. + /// + /// Fabric (and SQL Server 2022+ strict) backends require PRELOGIN and LOGIN + /// to be sent back-to-back (pipelined) without waiting for the PRELOGIN + /// response. This matches the behavior of the ODBC Driver 18. + /// + /// The flow is: + /// 1. Encode PRELOGIN and LOGIN packets + /// 2. Write all packets to the wire in a single flush + /// 3. Read PRELOGIN response + /// 4. Caller reads LOGIN response via flush_done() + async fn finish_connect_strict_pipelined( + mut connection: Self, + config: Config, + ) -> crate::Result { + // 1. Build PRELOGIN — match ODBC Driver 18 format: + // 6 options: VERSION, ENCRYPTION, INSTOPT, THREADID, MARS, TRACEID + // NO FEDAUTHREQUIRED (ODBC doesn't send it to backends) + let mut prelogin_msg = PreloginMessage::new(); + prelogin_msg.encryption = EncryptionLevel::Strict; + // Do NOT set fed_auth_required — ODBC omits it for backend PRELOGIN + prelogin_msg.fed_auth_required = false; + // Include TRACEID (36 bytes) — ODBC always sends this + prelogin_msg.include_trace_id = true; + // Include instance name from routing redirect in PRELOGIN to backend. + // The backend uses this to identify which database instance to connect to. + prelogin_msg.instance_name = config.instance_name.clone(); + + // 2. Build LOGIN (using assumed fed_auth_required=true, nonce=None — + // standard for Fabric backends). + // MS-TDS spec says TDS 8.0 (0x08000000) for strict mode. + let mut login_message = LoginMessage::new(); + login_message.tds_version(FeatureLevel::SqlServer2022); + // Azure SQL / Fabric backends require the AZURESQLSUPPORT feature + // extension to indicate the client can handle Azure-specific tokens. + login_message.azure_sql_support(); + + // Keep the LOGIN minimal — match gateway LOGIN structure exactly. + // No hostname, clt_int_name, client_pid, or client_prog_ver overrides. + + if let Some(db) = config.database { + login_message.db_name(db); + } + + // Use login_server_name if set (original gateway hostname for routed + // connections), otherwise fall back to the connection host. + let server_name = config.login_server_name.or(config.host); + if let Some(sn) = server_name { + login_message.server_name(sn); + } + + if let Some(app_name) = config.application_name { + login_message.app_name(app_name); + } + + login_message.readonly(config.readonly); + + match config.auth { + AuthMethod::AADToken(token) => { + event!( + Level::INFO, + token_len = token.len(), + "Sending pipelined LOGIN with AAD token (fed_auth_required=true, nonce=None)" + ); + login_message.aad_token(token, true, None); + } + AuthMethod::SqlServer(auth) => { + login_message.user_name(auth.user().to_string()); + login_message.password(auth.password().to_string()); + } + AuthMethod::None => {} + #[cfg(any(windows, feature = "integrated-auth-gssapi"))] + _ => { + return Err(crate::Error::Protocol( + "Integrated auth not supported for strict-mode pipelined backend connection" + .into(), + )); + } + } + + // 3. Feed PRELOGIN packet(s) to the transport buffer (no flush yet) + let packet_size = + (connection.context.packet_size() as usize) - crate::tds::HEADER_BYTES; + + let mut prelogin_payload = BytesMut::new(); + prelogin_msg.encode(&mut prelogin_payload)?; + + let prelogin_id = connection.context.next_packet_id(); + let mut prelogin_header = PacketHeader::pre_login(prelogin_id); + prelogin_header.set_status(PacketStatus::EndOfMessage); + connection + .feed_to_wire(prelogin_header, prelogin_payload) + .await?; + + // 4. Feed LOGIN packet(s) to the transport buffer (no flush yet) + let mut login_payload = BytesMut::new(); + login_message.encode(&mut login_payload)?; + + let login_id = connection.context.next_packet_id(); + let mut login_header = PacketHeader::login(login_id); + + while !login_payload.is_empty() { + let writable = cmp::min(login_payload.len(), packet_size); + let split_payload = login_payload.split_to(writable); + + if login_payload.is_empty() { + login_header.set_status(PacketStatus::EndOfMessage); + } else { + login_header.set_status(PacketStatus::NormalMessage); + } + + connection + .feed_to_wire(login_header, split_payload) + .await?; + } + + // 5. Single flush: send PRELOGIN+LOGIN together in one TLS write + connection.flush_sink().await?; + + // 6. Read PRELOGIN response + let _prelogin_response: PreloginMessage = codec::collect_from(&mut connection).await?; + + // The PRELOGIN response packet has EOM status which sets `flushed = true`. + // Reset it because we still expect the LOGIN response tokens to follow. + connection.flushed = false; + + // 7. The LOGIN response tokens will be read by flush_done() in the caller + Ok(connection) + } + /// Flush the incoming token stream until receiving `DONE` token. async fn flush_done(&mut self) -> crate::Result { TokenStream::new(self).flush_done().await @@ -223,7 +375,11 @@ impl Connection { split_payload.len() + HEADER_BYTES, ); - self.write_to_wire(header, split_payload).await?; + // Buffer the packet without flushing. This ensures multi-packet + // messages are sent in a single TLS write, which is required by + // some backends (e.g., Microsoft Fabric). + let packet = Packet::new(header, split_payload); + self.transport.feed(packet).await?; } self.flush_sink().await?; @@ -250,6 +406,43 @@ impl Connection { Ok(()) } + /// Feeds a packet to the transport buffer WITHOUT flushing. + /// Use `flush_sink()` after feeding all packets to send them in one batch. + #[allow(dead_code)] + async fn feed_to_wire( + &mut self, + header: PacketHeader, + data: BytesMut, + ) -> crate::Result<()> { + self.flushed = false; + + // Debug: dump the raw TDS frame (header + payload) to a file + if let Ok(mut f) = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open("/tmp/opencode/tiberius_raw_tds.bin") + { + use std::io::Write as _; + // Encode the header + data as the PacketCodec would + let mut hdr_buf = BytesMut::with_capacity(8); + use bytes::BufMut as _; + let pkt_len = (data.len() + 8) as u16; + hdr_buf.put_u8(header.r#type() as u8); + hdr_buf.put_u8(header.status() as u8); + hdr_buf.put_u16(pkt_len); + hdr_buf.put_u16(0); // spid + hdr_buf.put_u8(0); // id placeholder (not accessible) + hdr_buf.put_u8(0); // window + let _ = f.write_all(&hdr_buf); + let _ = f.write_all(&data); + } + + let packet = Packet::new(header, data); + self.transport.feed(packet).await?; + + Ok(()) + } + /// Sends all pending packages to the wire. pub(crate) async fn flush_sink(&mut self) -> crate::Result<()> { self.transport.flush().await @@ -303,10 +496,14 @@ impl Connection { &mut self, encryption: EncryptionLevel, fed_auth_required: bool, + instance_name: Option, + include_trace_id: bool, ) -> crate::Result { let mut msg = PreloginMessage::new(); msg.encryption = encryption; msg.fed_auth_required = fed_auth_required; + msg.instance_name = instance_name.clone(); + msg.include_trace_id = include_trace_id; let id = self.context.next_packet_id(); self.send(PacketHeader::pre_login(id), msg).await?; @@ -314,6 +511,15 @@ impl Connection { let response: PreloginMessage = codec::collect_from(self).await?; // threadid (should be empty when sent from server to client) debug_assert_eq!(response.thread_id, 0); + event!( + Level::INFO, + version = response.version, + sub_build = response.sub_build, + encryption = ?response.encryption, + fed_auth_required = response.fed_auth_required, + has_nonce = response.nonce.is_some(), + "PRELOGIN response received" + ); Ok(response) } @@ -332,6 +538,18 @@ impl Connection { ) -> crate::Result { let mut login_message = LoginMessage::new(); + // MS-TDS spec: TDS 8.0 strict mode requires version 0x08000000 in + // LOGIN7. This tells the server the client supports the TDS 8 transport. + if encryption == EncryptionLevel::Strict { + login_message.tds_version(FeatureLevel::SqlServer2022); + // Azure SQL / Fabric backends require the AZURESQLSUPPORT feature + // extension to indicate the client can handle Azure-specific tokens. + login_message.azure_sql_support(); + // Set client interface name — ODBC sends "ODBC"; some backends + // may require a non-empty value. + login_message.clt_int_name("tiberius"); + } + if let Some(db) = db { login_message.db_name(db); } @@ -458,7 +676,15 @@ impl Connection { self = self.post_login_encryption(encryption); } AuthMethod::AADToken(token) => { + event!( + Level::INFO, + fed_auth_echo = prelogin.fed_auth_required, + has_nonce = prelogin.nonce.is_some(), + token_len = token.len(), + "Sending LOGIN with AAD token" + ); login_message.aad_token(token, prelogin.fed_auth_required, prelogin.nonce); + let id = self.context.next_packet_id(); self.send(PacketHeader::login(id), login_message).await?; self = self.post_login_encryption(encryption); @@ -594,13 +820,14 @@ impl Stream for Connection { fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { let this = self.get_mut(); - match ready!(this.transport.try_poll_next_unpin(cx)) { - Some(Ok(packet)) => { + match this.transport.try_poll_next_unpin(cx) { + Poll::Ready(Some(Ok(packet))) => { this.flushed = packet.is_last(); Poll::Ready(Some(Ok(packet))) } - Some(Err(e)) => Poll::Ready(Some(Err(e))), - None => Poll::Ready(None), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } } } diff --git a/src/client/tls_stream/native_tls_stream.rs b/src/client/tls_stream/native_tls_stream.rs index cf5591d80..9d113087d 100644 --- a/src/client/tls_stream/native_tls_stream.rs +++ b/src/client/tls_stream/native_tls_stream.rs @@ -1,6 +1,7 @@ use crate::{ client::{config::Config, TrustConfig}, error::{Error, IoErrorKind}, + EncryptionLevel, }; pub(crate) use async_native_tls::TlsStream; use async_native_tls::{Certificate, TlsConnector}; @@ -12,6 +13,59 @@ pub(crate) async fn create_tls_stream( config: &Config, stream: S, ) -> crate::Result> { + // For TDS 8 strict mode, we need to set ALPN to "ms-tds" so the gateway + // knows to proxy the connection rather than just redirecting. + if config.encryption == EncryptionLevel::Strict { + let mut native_builder = native_tls_crate::TlsConnector::builder(); + native_builder.request_alpns(&["ms-tds"]); + + match &config.trust { + TrustConfig::CaCertificateLocation(path) => { + if let Ok(buf) = fs::read(path) { + let cert = match path.extension() { + Some(ext) + if ext.to_ascii_lowercase() == "pem" + || ext.to_ascii_lowercase() == "crt" => + { + Some(native_tls_crate::Certificate::from_pem(&buf)?) + } + Some(ext) if ext.to_ascii_lowercase() == "der" => { + Some(native_tls_crate::Certificate::from_der(&buf)?) + } + Some(_) | None => { + return Err(Error::Io { + kind: IoErrorKind::InvalidInput, + message: "Provided CA certificate with unsupported file-extension! Supported types are pem, crt and der.".to_string(), + }) + } + }; + if let Some(c) = cert { + native_builder.add_root_certificate(c); + } + } else { + return Err(Error::Io { + kind: IoErrorKind::InvalidData, + message: "Could not read provided CA certificate!".to_string(), + }); + } + } + TrustConfig::TrustAll => { + event!( + Level::WARN, + "Trusting the server certificate without validation." + ); + native_builder.danger_accept_invalid_certs(true); + native_builder.danger_accept_invalid_hostnames(true); + } + TrustConfig::Default => { + event!(Level::INFO, "Using default trust configuration."); + } + } + + let connector: TlsConnector = native_builder.into(); + return Ok(connector.connect(config.get_host(), stream).await?); + } + let mut builder = TlsConnector::new(); match &config.trust { diff --git a/src/tds/codec/login.rs b/src/tds/codec/login.rs index 265db381e..226ee20fd 100644 --- a/src/tds/codec/login.rs +++ b/src/tds/codec/login.rs @@ -7,7 +7,7 @@ use std::fmt::Debug; use std::{borrow::Cow, io}; uint_enum! { - #[repr(u32)] + #[repr(u32)] #[derive(PartialOrd)] pub enum FeatureLevel { SqlServerV7 = 0x70000000, @@ -18,6 +18,8 @@ uint_enum! { SqlServer2008R2 = 0x730B0003, /// 2012, 2014, 2016 SqlServerN = 0x74000004, + /// TDS 8.0 strict transport encryption (required for Microsoft Fabric) + SqlServer2022 = 0x08000000, } } @@ -29,12 +31,22 @@ impl Default for FeatureLevel { impl FeatureLevel { pub fn done_row_count_bytes(self) -> u8 { - if self as u32 >= FeatureLevel::SqlServer2005 as u32 { + // TDS 8.0 (0x08000000) is numerically lower than 7.x versions but is + // functionally equivalent to SqlServerN for row count encoding. + if self == FeatureLevel::SqlServer2022 + || self as u32 >= FeatureLevel::SqlServer2005 as u32 + { 8 } else { 4 } } + + /// Returns true if this version uses modern (post-2005) wire formats. + pub fn is_modern(self) -> bool { + self == FeatureLevel::SqlServer2022 + || self as u32 >= FeatureLevel::SqlServer2005 as u32 + } } #[bitflags] @@ -130,6 +142,9 @@ pub enum LoginTypeFlag { } pub(crate) const FEA_EXT_FEDAUTH: u8 = 0x02u8; +pub(crate) const FEA_EXT_COLUMNENCRYPTION: u8 = 0x04u8; +pub(crate) const FEA_EXT_AZURESQLSUPPORT: u8 = 0x08u8; +pub(crate) const FEA_EXT_UTF8_SUPPORT: u8 = 0x0Au8; pub(crate) const FEA_EXT_TERMINATOR: u8 = 0xFFu8; pub(crate) const FED_AUTH_LIBRARYSECURITYTOKEN: u8 = 0x01; @@ -172,15 +187,22 @@ pub struct LoginMessage<'a> { server_name: Cow<'a, str>, /// the default database to connect to db_name: Cow<'a, str>, + /// client interface name (e.g., "ODBC") + clt_int_name: Cow<'a, str>, fed_auth_ext: Option>, + /// Whether to include AZURESQLSUPPORT (0x08) feature extension. + /// Required for Azure SQL Database and Microsoft Fabric backends. + azure_sql_support: bool, } impl<'a> LoginMessage<'a> { pub fn new() -> LoginMessage<'a> { Self { packet_size: 4096, + client_lcid: 0x0409, // English US — required by some Azure backends option_flags_1: OptionFlag1::UseDbNotify | OptionFlag1::InitDbFatal, option_flags_2: OptionFlag2::InitLangFatal | OptionFlag2::OdbcDriver, + type_flags: BitFlags::from_flag(LoginTypeFlag::UseTSQL), option_flags_3: BitFlags::from_flag(OptionFlag3::UnknownCollationHandling), app_name: "tiberius".into(), ..Default::default() @@ -210,6 +232,10 @@ impl<'a> LoginMessage<'a> { self.server_name = server_name.into(); } + pub fn tds_version(&mut self, version: FeatureLevel) { + self.tds_version = version; + } + pub fn user_name(&mut self, user_name: impl Into>) { self.username = user_name.into(); } @@ -240,6 +266,30 @@ impl<'a> LoginMessage<'a> { self.type_flags.remove(LoginTypeFlag::ReadOnlyIntent); } } + + pub fn hostname(&mut self, hostname: impl Into>) { + self.hostname = hostname.into(); + } + + pub fn client_pid(&mut self, pid: u32) { + self.client_pid = pid; + } + + pub fn client_prog_ver(&mut self, ver: u32) { + self.client_prog_ver = ver; + } + + pub fn clt_int_name(&mut self, name: impl Into>) { + self.clt_int_name = name.into(); + } + + /// Enable the AZURESQLSUPPORT (0x08) feature extension. + /// Required by Azure SQL Database and Microsoft Fabric backends. + /// Signals the client supports federated auth info tokens and DNS caching. + pub fn azure_sql_support(&mut self) { + self.option_flags_3.insert(OptionFlag3::ExtensionUsed); + self.azure_sql_support = true; + } } impl<'a> Encode for LoginMessage<'a> { @@ -271,7 +321,7 @@ impl<'a> Encode for LoginMessage<'a> { &self.app_name, &self.server_name, &"".into(), // 5. ibExtension - &"".into(), // ibCltIntName + &self.clt_int_name, // ibCltIntName &"".into(), // ibLanguage &self.db_name, &"".into(), // 9. ClientId (6 bytes); this is included in var_data so we don't lack the bytes of cbSspiLong (4=2*2) and can insert it at the correct position @@ -349,9 +399,10 @@ impl<'a> Encode for LoginMessage<'a> { // cbSSPILong cursor.write_u32::(0)?; - // FeatureExt - if let Some(fed_auth_ext) = self.fed_auth_ext { - // update fea_ext_offset + // FeatureExt — written when either FEDAUTH or AZURESQLSUPPORT is needed + let has_feature_ext = self.fed_auth_ext.is_some() || self.azure_sql_support; + if has_feature_ext { + // update fea_ext_offset (ibExtension offset/length in variable data header) cursor.set_position(fea_ext_offset); cursor.write_u16::(data_offset as u16)?; cursor.write_u16::(4)?; @@ -360,32 +411,45 @@ impl<'a> Encode for LoginMessage<'a> { data_offset += 4; cursor.write_u32::(data_offset as u32)?; - cursor.write_u8(FEA_EXT_FEDAUTH)?; + // Write FEDAUTH feature extension if present + if let Some(fed_auth_ext) = self.fed_auth_ext { + cursor.write_u8(FEA_EXT_FEDAUTH)?; - let mut token = Cursor::new(Vec::new()); - for codepoint in fed_auth_ext.fed_auth_token.encode_utf16() { - token.write_u16::(codepoint)?; - } - let token = token.into_inner(); + let mut token = Cursor::new(Vec::new()); + for codepoint in fed_auth_ext.fed_auth_token.encode_utf16() { + token.write_u16::(codepoint)?; + } + let token = token.into_inner(); - // options (1) + TokenLength(4) + Token.length + nonce.length - let feature_ext_length = - 1 + 4 + token.len() + if fed_auth_ext.nonce.is_some() { 32 } else { 0 }; + // options (1) + TokenLength(4) + Token.length + nonce.length + let feature_ext_length = + 1 + 4 + token.len() + if fed_auth_ext.nonce.is_some() { 32 } else { 0 }; - cursor.write_u32::(feature_ext_length as u32)?; + cursor.write_u32::(feature_ext_length as u32)?; - let mut options: u8 = FED_AUTH_LIBRARYSECURITYTOKEN << 1; - if fed_auth_ext.fed_auth_echo { - options |= 1 // fFedAuthEcho - } + let mut options: u8 = FED_AUTH_LIBRARYSECURITYTOKEN << 1; + if fed_auth_ext.fed_auth_echo { + options |= 1 // fFedAuthEcho + } + + cursor.write_u8(options)?; - cursor.write_u8(options)?; + cursor.write_u32::(token.len() as u32)?; + cursor.write_all(token.as_slice())?; - cursor.write_u32::(token.len() as u32)?; - cursor.write_all(token.as_slice())?; + if let Some(nonce) = fed_auth_ext.nonce { + cursor.write_all(nonce.as_ref())?; + } + } - if let Some(nonce) = fed_auth_ext.nonce { - cursor.write_all(nonce.as_ref())?; + // Write AZURESQLSUPPORT feature extension if enabled + if self.azure_sql_support { + cursor.write_u8(FEA_EXT_AZURESQLSUPPORT)?; + // Feature data length: 1 byte + cursor.write_u32::(1)?; + // Feature data: 0x01 = fSQLDNSCaching (client supports + // federated auth info token and DNS caching) + cursor.write_u8(0x01)?; } cursor.write_u8(FEA_EXT_TERMINATOR)?; @@ -546,6 +610,14 @@ mod tests { nonce, }; ret.fed_auth_ext = Some(fed_auth_ext); + } else if fe == FEA_EXT_COLUMNENCRYPTION + || fe == FEA_EXT_AZURESQLSUPPORT + || fe == FEA_EXT_UTF8_SUPPORT + { + // Skip known extensions by reading their length + data + let fea_ext_len = cursor.read_u32::()?; + let pos = cursor.position(); + cursor.set_position(pos + fea_ext_len as u64); } else { unimplemented!("unsupported feature ext {:?}", fe); } diff --git a/src/tds/codec/packet.rs b/src/tds/codec/packet.rs index 9927ed35d..7b13aa0f1 100644 --- a/src/tds/codec/packet.rs +++ b/src/tds/codec/packet.rs @@ -25,11 +25,18 @@ impl Encode for Packet { fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { let size = (self.payload.len() as u16 + HEADER_BYTES as u16).to_be_bytes(); + // Remember offset where this packet starts. When multiple packets + // are buffered before a flush (pipelining), dst may already contain + // previously encoded packets. + let offset = dst.len(); + self.header.encode(dst)?; dst.extend(self.payload); - dst[2] = size[0]; - dst[3] = size[1]; + // Patch the length field (bytes 2-3 of the TDS header) at the + // correct position within this packet, not at absolute buffer start. + dst[offset + 2] = size[0]; + dst[offset + 3] = size[1]; Ok(()) } diff --git a/src/tds/codec/pre_login.rs b/src/tds/codec/pre_login.rs index 1913eef6f..4dafa254d 100644 --- a/src/tds/codec/pre_login.rs +++ b/src/tds/codec/pre_login.rs @@ -4,7 +4,7 @@ use crate::{tds, Error, Result}; use byteorder::{BigEndian, LittleEndian, ReadBytesExt, WriteBytesExt}; use bytes::{BufMut, BytesMut}; use std::convert::TryFrom; -use std::io::{Cursor, Read}; +use std::io::{Cursor, Read, Write}; use tds::EncryptionLevel; use uuid::Uuid; @@ -34,8 +34,10 @@ pub struct PreloginMessage { pub thread_id: u32, /// token=0x04 pub mars: bool, - /// token=0x05 + /// token=0x05: TRACEID — connection GUID (16) + activity GUID (16) + sequence (4) = 36 bytes pub activity_id: Option, + /// token=0x05: If true, encode TRACEID with random GUIDs in the client PRELOGIN + pub include_trace_id: bool, /// token=0x06 pub fed_auth_required: bool, pub nonce: Option<[u8; 32]>, @@ -52,6 +54,7 @@ impl PreloginMessage { thread_id: 0, mars: false, activity_id: None, + include_trace_id: false, fed_auth_required: false, nonce: None, } @@ -112,14 +115,26 @@ impl Encode for PreloginMessage { // encryption fields.push((PRELOGIN_ENCRYPTION, 0x01)); // encryption - // In TDS 8 strict mode, TLS is already established before PRELOGIN. - // Send ENCRYPT_ON (0x01) on the wire since strict is not a valid wire value. + // In TDS 8 strict mode, the wire value must be ENCRYPT_STRICT (0x08) + // per MS-TDS spec. Other values map directly to their enum discriminant. let encryption_wire_value = match self.encryption { - EncryptionLevel::Strict => EncryptionLevel::On as u8, + EncryptionLevel::Strict => 0x08u8, other => other as u8, }; data_cursor.write_u8(encryption_wire_value)?; + // instance name (INSTOPT) — null-terminated ASCII string + { + let inst_bytes: Vec = match &self.instance_name { + Some(name) => name.as_bytes().to_vec(), + None => Vec::new(), + }; + // length = instance name bytes + null terminator + fields.push((PRELOGIN_INSTOPT, (inst_bytes.len() + 1) as u16)); + data_cursor.write_all(&inst_bytes)?; + data_cursor.write_u8(0x00)?; // null terminator + } + // threadid fields.push((PRELOGIN_THREADID, 0x04)); // thread id data_cursor.write_u32::(self.thread_id)?; @@ -128,6 +143,25 @@ impl Encode for PreloginMessage { fields.push((PRELOGIN_MARS, 0x01)); // MARS data_cursor.write_u8(self.mars as u8)?; + // TRACEID: connection GUID (16) + activity GUID (16) + sequence (4) = 36 bytes + // ODBC Driver 18 always sends TRACEID to Fabric backends. + if self.include_trace_id { + fields.push((PRELOGIN_TRACEID, 36)); + // Generate random connection and activity GUIDs + let conn_id = Uuid::new_v4(); + let activity_id = Uuid::new_v4(); + // Write connection ID as MS-ordered GUID (reordered bytes) + let mut conn_bytes = *conn_id.as_bytes(); + reorder_bytes(&mut conn_bytes); + data_cursor.write_all(&conn_bytes)?; + // Write activity ID as MS-ordered GUID + let mut act_bytes = *activity_id.as_bytes(); + reorder_bytes(&mut act_bytes); + data_cursor.write_all(&act_bytes)?; + // Sequence number (u32 LE) + data_cursor.write_u32::(0)?; + } + // fed auth if self.fed_auth_required { fields.push((PRELOGIN_FEDAUTHREQUIRED, 0x01)); @@ -189,9 +223,17 @@ impl Decode for PreloginMessage { // encryption PRELOGIN_ENCRYPTION => { let encrypt = cursor.read_u8()?; - ret.encryption = tds::EncryptionLevel::try_from(encrypt).map_err(|_| { - Error::Protocol(format!("invalid encryption value: {}", encrypt).into()) - })?; + // Wire value 0x08 = ENCRYPT_STRICT (TDS 8.0), maps to our Strict variant + let level = if encrypt == 0x08 { + tds::EncryptionLevel::Strict + } else { + tds::EncryptionLevel::try_from(encrypt).map_err(|_| { + Error::Protocol( + format!("invalid encryption value: {}", encrypt).into(), + ) + })? + }; + ret.encryption = level; } // instance name PRELOGIN_INSTOPT => { diff --git a/src/tds/codec/token/token_error.rs b/src/tds/codec/token/token_error.rs index d1e435a77..a294b53d2 100644 --- a/src/tds/codec/token/token_error.rs +++ b/src/tds/codec/token/token_error.rs @@ -1,4 +1,4 @@ -use crate::{tds::codec::FeatureLevel, SqlReadBytes}; +use crate::SqlReadBytes; use std::fmt; #[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)] @@ -32,7 +32,7 @@ impl TokenError { let server = src.read_b_varchar().await?; let procedure = src.read_b_varchar().await?; - let line = if src.context().version() > FeatureLevel::SqlServer2005 { + let line = if src.context().version().is_modern() { src.read_u32_le().await? } else { src.read_u16_le().await? as u32 diff --git a/src/tds/codec/token/token_feature_ext_ack.rs b/src/tds/codec/token/token_feature_ext_ack.rs index 1ba108f99..12760ed80 100644 --- a/src/tds/codec/token/token_feature_ext_ack.rs +++ b/src/tds/codec/token/token_feature_ext_ack.rs @@ -1,4 +1,4 @@ -use crate::{SqlReadBytes, FEA_EXT_FEDAUTH, FEA_EXT_TERMINATOR}; +use crate::{SqlReadBytes, FEA_EXT_AZURESQLSUPPORT, FEA_EXT_COLUMNENCRYPTION, FEA_EXT_FEDAUTH, FEA_EXT_TERMINATOR, FEA_EXT_UTF8_SUPPORT}; use futures_util::AsyncReadExt; #[derive(Debug)] @@ -16,6 +16,14 @@ pub enum FedAuthAck { #[allow(dead_code)] pub enum FeatureAck { FedAuth(FedAuthAck), + /// Azure SQL Support acknowledgment from the server. + AzureSqlSupport(Vec), + /// Column Encryption acknowledgment. + ColumnEncryption(Vec), + /// UTF-8 Support acknowledgment. + Utf8Support(Vec), + /// Unknown feature — stored for forward-compatibility. + Unknown { feature_id: u8, data: Vec }, } impl TokenFeatureExtAck { @@ -44,8 +52,35 @@ impl TokenFeatureExtAck { }; features.push(FeatureAck::FedAuth(FedAuthAck::SecurityToken { nonce })) + } else if feature_id == FEA_EXT_AZURESQLSUPPORT { + let data_len = src.read_u32_le().await? as usize; + let mut data = vec![0u8; data_len]; + if data_len > 0 { + src.read_exact(&mut data).await?; + } + features.push(FeatureAck::AzureSqlSupport(data)); + } else if feature_id == FEA_EXT_COLUMNENCRYPTION { + let data_len = src.read_u32_le().await? as usize; + let mut data = vec![0u8; data_len]; + if data_len > 0 { + src.read_exact(&mut data).await?; + } + features.push(FeatureAck::ColumnEncryption(data)); + } else if feature_id == FEA_EXT_UTF8_SUPPORT { + let data_len = src.read_u32_le().await? as usize; + let mut data = vec![0u8; data_len]; + if data_len > 0 { + src.read_exact(&mut data).await?; + } + features.push(FeatureAck::Utf8Support(data)); } else { - unimplemented!("unsupported feature {}", feature_id) + // Unknown feature — skip gracefully by reading data_len bytes + let data_len = src.read_u32_le().await? as usize; + let mut data = vec![0u8; data_len]; + if data_len > 0 { + src.read_exact(&mut data).await?; + } + features.push(FeatureAck::Unknown { feature_id, data }); } } diff --git a/src/tds/context.rs b/src/tds/context.rs index 732bac15c..c85784a21 100644 --- a/src/tds/context.rs +++ b/src/tds/context.rs @@ -17,7 +17,7 @@ impl Context { Context { version: FeatureLevel::SqlServerN, packet_size: 4096, - packet_id: 0, + packet_id: 1, transaction_desc: [0; 8], last_meta: None, spn: None, diff --git a/tests/fabric.rs b/tests/fabric.rs new file mode 100644 index 000000000..1914a27d0 --- /dev/null +++ b/tests/fabric.rs @@ -0,0 +1,259 @@ +//! Integration tests for Microsoft Fabric SQL Data Warehouse connectivity. +//! +//! These tests verify that TDS 8 strict encryption (the `encrypt=strict` option) +//! works correctly when connecting to a Microsoft Fabric endpoint. +//! +//! # Required environment variables +//! +//! - `FABRIC_ENDPOINT`: The Fabric SQL endpoint (e.g., `my-workspace.datawarehouse.fabric.microsoft.com`) +//! - `FABRIC_DATABASE`: The database name in Fabric +//! +//! Authentication (one of the following): +//! - `FABRIC_AAD_TOKEN`: A pre-obtained AAD/Entra ID token for the `https://database.windows.net/` scope +//! - Or: `FABRIC_CLIENT_ID`, `FABRIC_CLIENT_SECRET`, `FABRIC_TENANT_ID` for service principal auth +//! +//! # Running +//! +//! ```sh +//! # With a pre-obtained token (e.g., from `az account get-access-token`): +//! export FABRIC_ENDPOINT=my-workspace.datawarehouse.fabric.microsoft.com +//! export FABRIC_DATABASE=my-database +//! export FABRIC_AAD_TOKEN=$(az account get-access-token --resource https://database.windows.net/ --query accessToken -o tsv) +//! cargo test --test fabric -- --nocapture +//! ``` + +use std::env; +use tiberius::{error::Error, AuthMethod, Client, Config, EncryptionLevel}; +use tokio::net::TcpStream; +use tokio_util::compat::TokioAsyncWriteCompatExt; + +/// Helper to skip tests when required environment variables are missing. +macro_rules! skip_if_no_fabric { + () => { + if env::var("FABRIC_ENDPOINT").is_err() { + eprintln!("SKIPPED: FABRIC_ENDPOINT not set. Set Fabric env vars to run this test."); + return Ok(()); + } + }; +} + +/// Obtain an AAD token for the Fabric SQL endpoint. +/// +/// Tries in order: +/// 1. `FABRIC_AAD_TOKEN` env var (pre-obtained token) +/// 2. Service principal credentials (`FABRIC_CLIENT_ID`, `FABRIC_CLIENT_SECRET`, `FABRIC_TENANT_ID`) +async fn get_aad_token() -> anyhow::Result { + // Option 1: Pre-obtained token from environment + if let Ok(token) = env::var("FABRIC_AAD_TOKEN") { + return Ok(token); + } + + // Option 2: Service principal client credentials flow + let client_id = env::var("FABRIC_CLIENT_ID") + .map_err(|_| anyhow::anyhow!("Neither FABRIC_AAD_TOKEN nor FABRIC_CLIENT_ID is set"))?; + let client_secret = env::var("FABRIC_CLIENT_SECRET") + .map_err(|_| anyhow::anyhow!("FABRIC_CLIENT_SECRET not set"))?; + let tenant_id = env::var("FABRIC_TENANT_ID") + .map_err(|_| anyhow::anyhow!("FABRIC_TENANT_ID not set"))?; + + use azure_identity::client_credentials_flow; + use oauth2::{ClientId, ClientSecret}; + use std::sync::Arc; + + let http_client = Arc::new(reqwest::Client::new()); + let token = client_credentials_flow::perform( + http_client, + &ClientId::new(client_id), + &ClientSecret::new(client_secret), + &["https://database.windows.net/.default"], + &tenant_id, + ) + .await?; + + Ok(token.access_token().secret().to_string()) +} + +/// Build a Config for connecting to Microsoft Fabric with TDS 8 strict encryption. +fn fabric_config(endpoint: &str, database: &str) -> tiberius::Result { + let conn_str = format!( + "server=tcp:{endpoint},1433;encrypt=strict;TrustServerCertificate=false;database={database}" + ); + Config::from_ado_string(&conn_str) +} + +/// Connect to Microsoft Fabric with TDS 8 strict encryption. +/// +/// Fabric uses a gateway that returns a routing redirect to a backend server. +/// The client must: +/// 1. Connect to the gateway with TDS 8 strict TLS +/// 2. Receive the routing redirect (Error::Routing) +/// 3. Reconnect to the backend with pipelined PRELOGIN+LOGIN +async fn connect_to_fabric( + endpoint: &str, + database: &str, + token: &str, +) -> anyhow::Result>> { + let _ = env_logger::try_init(); + + let mut config = fabric_config(endpoint, database)?; + config.authentication(AuthMethod::aad_token(token)); + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(Error::Routing { host, port }) => { + eprintln!( + "Received routing redirect to {}:{}, reconnecting...", + host, port + ); + + // The routing host may contain an instance name (e.g., + // "host.pbidedicated.windows.net\InstanceName"). Strip it for + // the TCP connection and TLS SNI; only the hostname is needed. + let backend_host = host.split('\\').next().unwrap_or(&host); + let instance_name = host.split('\\').nth(1); + + let mut backend_config = Config::new(); + backend_config.host(backend_host); + backend_config.port(port); + backend_config.encryption(EncryptionLevel::Strict); + backend_config.authentication(AuthMethod::aad_token(token)); + backend_config.database(database); + // Include instance name in PRELOGIN so backend knows which instance to route to + if let Some(inst) = instance_name { + backend_config.instance_name(inst); + } + // MS-TDS spec: LOGIN server_name must be the ORIGINAL endpoint, not routing target + backend_config.login_server_name(endpoint); + + let tcp = TcpStream::connect(backend_config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let client = Client::connect(backend_config, tcp.compat_write()).await?; + Ok(client) + } + Err(e) => Err(e.into()), + } +} + +/// Test: Connect to Fabric with TDS 8 strict encryption and run a basic query. +#[tokio::test] +async fn connect_to_fabric_strict_encryption() -> anyhow::Result<()> { + skip_if_no_fabric!(); + + let endpoint = env::var("FABRIC_ENDPOINT")?; + let database = env::var("FABRIC_DATABASE")?; + let token = get_aad_token().await?; + + let mut client = connect_to_fabric(&endpoint, &database, &token).await?; + + // Simple connectivity test + let row = client + .query("SELECT 1 AS test_value", &[]) + .await? + .into_row() + .await? + .unwrap(); + + assert_eq!(Some(1i32), row.get("test_value")); + + Ok(()) +} + +/// Test: Verify that the ADO.NET connection string parsing accepts `encrypt=strict`. +#[tokio::test] +async fn fabric_config_parses_strict_encryption() -> anyhow::Result<()> { + // This should parse without error - strict is a valid encryption level + let _config = Config::from_ado_string( + "server=tcp:test.datawarehouse.fabric.microsoft.com,1433;encrypt=strict;database=testdb", + )?; + Ok(()) +} + +/// Test: Run a query that exercises Fabric-specific metadata. +#[tokio::test] +async fn fabric_query_database_metadata() -> anyhow::Result<()> { + skip_if_no_fabric!(); + + let endpoint = env::var("FABRIC_ENDPOINT")?; + let database = env::var("FABRIC_DATABASE")?; + let token = get_aad_token().await?; + + let mut client = connect_to_fabric(&endpoint, &database, &token).await?; + + // Query current database name to verify we connected to the right database + let row = client + .query("SELECT DB_NAME() AS current_db", &[]) + .await? + .into_row() + .await? + .unwrap(); + + let db_name: Option<&str> = row.get("current_db"); + assert!( + db_name.is_some(), + "Should be able to query the current database name" + ); + eprintln!("Connected to database: {:?}", db_name.unwrap()); + + Ok(()) +} + +/// Test: Verify multiple sequential queries work over a strict TLS connection. +#[tokio::test] +async fn fabric_multiple_queries() -> anyhow::Result<()> { + skip_if_no_fabric!(); + + let endpoint = env::var("FABRIC_ENDPOINT")?; + let database = env::var("FABRIC_DATABASE")?; + let token = get_aad_token().await?; + + let mut client = connect_to_fabric(&endpoint, &database, &token).await?; + + // First query + let row = client + .query("SELECT 42 AS answer", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(42i32), row.get("answer")); + + // Second query - verifies the connection stays healthy + let row = client + .query("SELECT CAST(GETDATE() AS NVARCHAR(50)) AS server_time", &[]) + .await? + .into_row() + .await? + .unwrap(); + let time: Option<&str> = row.get("server_time"); + assert!(time.is_some(), "Should get server time as string"); + + // Third query with parameters + let row = client + .query("SELECT @P1 + @P2 AS sum_result", &[&10i32, &32i32]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(42i32), row.get("sum_result")); + + Ok(()) +} + +/// Test: Verify that the JDBC-style connection string also accepts `encrypt=strict`. +#[tokio::test] +async fn fabric_jdbc_connection_string() -> anyhow::Result<()> { + let config = Config::from_jdbc_string( + "jdbc:sqlserver://test.datawarehouse.fabric.microsoft.com:1433;encrypt=strict;databaseName=testdb", + )?; + + // Verify host:port were parsed correctly via the public get_addr() method + assert_eq!( + config.get_addr(), + "test.datawarehouse.fabric.microsoft.com:1433" + ); + Ok(()) +} From 63f67e082651fa10a4e91aafd83df2bf85a2365a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 14:49:29 +0200 Subject: [PATCH 03/15] fix: per-packet TLS flush and packet ID increment for Azure SQL TDS 8 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Key fixes for Azure SQL Database TDS 8 strict mode compatibility: 1. Per-packet TLS record alignment: Changed send() from batch flush (feed all + flush once) to per-packet send (feed+flush each). In TDS 8 strict mode, Azure SQL's gateway requires each TDS packet to arrive as a separate TLS record. Batching multiple TDS packets into a single TLS write caused connection resets. 2. Packet ID incrementing: Per MS-TDS 2.2.3.1, PacketID must increment by 1 (mod 256) for each packet within a multi-packet message. Previously all packets in a message reused the same ID. 3. TDS version in LOGIN7: Changed from 0x08000000 (SqlServer2022) to 0x74000004 (SqlServerN/TDS 7.4). TDS 8 is a transport-mode indicator, not a protocol version — the LOGIN7 version field must still report 7.4. 4. FEDAUTH library type: Using SECURITYTOKEN (0x01) for pre-obtained JWT tokens, not MSAL (0x02). MSAL type has a different FeatureData format without inline token. 5. Removed ALPN "ms-tds": Azure SQL Database gateways reject TLS connections that advertise this ALPN extension. Tested against: - Azure SQL Database (a test Azure SQL Database instance): gateway accepts LOGIN, returns routing redirect, backend accepts LOGIN with FEDAUTH token - Microsoft Fabric: all 5 integration tests pass (pipelined backend path unaffected by send() change) - 120 unit tests pass --- src/client/connection.rs | 51 ++++++++++++++++------ src/client/tls_stream/native_tls_stream.rs | 7 +-- src/tds/codec/header.rs | 8 ++++ src/tds/codec/login.rs | 3 ++ 4 files changed, 52 insertions(+), 17 deletions(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index 3c62bb114..a25984870 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -149,8 +149,13 @@ impl Connection { let is_backend = config.instance_name.is_some(); let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_)); + // In TDS 8 strict mode, send ENCRYPT_STRICT (0x08) on the wire. + // TLS is already established, and the PRELOGIN encryption field signals + // to the server that this is a TDS 8 strict mode connection. + let prelogin_encryption = config.encryption; + let prelogin = connection - .prelogin(config.encryption, fed_auth_required, config.instance_name.clone(), is_backend) + .prelogin(prelogin_encryption, fed_auth_required, config.instance_name.clone(), is_backend) .await?; // Use login_server_name if set (for routed connections, this is the @@ -273,6 +278,7 @@ impl Connection { let login_id = connection.context.next_packet_id(); let mut login_header = PacketHeader::login(login_id); + let mut is_first_login_pkt = true; while !login_payload.is_empty() { let writable = cmp::min(login_payload.len(), packet_size); let split_payload = login_payload.split_to(writable); @@ -283,6 +289,12 @@ impl Connection { login_header.set_status(PacketStatus::NormalMessage); } + // Per MS-TDS 2.2.3.1: PacketID increments by 1 within a message + if !is_first_login_pkt { + login_header.set_id(login_header.id().wrapping_add(1)); + } + is_first_login_pkt = false; + connection .feed_to_wire(login_header, split_payload) .await?; @@ -359,6 +371,8 @@ impl Connection { let mut payload = BytesMut::new(); item.encode(&mut payload)?; + let mut is_first = true; + while !payload.is_empty() { let writable = cmp::min(payload.len(), packet_size); let split_payload = payload.split_to(writable); @@ -369,21 +383,28 @@ impl Connection { header.set_status(PacketStatus::NormalMessage); } + // Per MS-TDS 2.2.3.1: PacketID is incremented by 1 (mod 256) + // for each packet within a message. The first packet uses the + // ID from the header as-is; subsequent packets increment. + if !is_first { + header.set_id(header.id().wrapping_add(1)); + } + is_first = false; + event!( Level::TRACE, - "Sending a packet ({} bytes)", + "Sending a packet ({} bytes, id={})", split_payload.len() + HEADER_BYTES, + header.id(), ); - // Buffer the packet without flushing. This ensures multi-packet - // messages are sent in a single TLS write, which is required by - // some backends (e.g., Microsoft Fabric). + // Send each packet individually (feed + flush). In TDS 8 strict + // mode, each TDS packet must be written as a separate TLS record. + // Using send() (vs feed + batch flush) ensures this. let packet = Packet::new(header, split_payload); - self.transport.feed(packet).await?; + self.transport.send(packet).await?; } - self.flush_sink().await?; - Ok(()) } @@ -538,16 +559,18 @@ impl Connection { ) -> crate::Result { let mut login_message = LoginMessage::new(); - // MS-TDS spec: TDS 8.0 strict mode requires version 0x08000000 in - // LOGIN7. This tells the server the client supports the TDS 8 transport. + // TDS 8 strict mode: the transport uses TLS-first, but the LOGIN7 + // protocol version remains TDS 7.4 (SqlServerN = 0x74000004). The TDS 8 + // "version" is a transport-mode indicator, not a protocol version. + // Azure SQL gateways may not recognize 0x08000000 and misparse the LOGIN. if encryption == EncryptionLevel::Strict { - login_message.tds_version(FeatureLevel::SqlServer2022); - // Azure SQL / Fabric backends require the AZURESQLSUPPORT feature - // extension to indicate the client can handle Azure-specific tokens. - login_message.azure_sql_support(); + login_message.tds_version(FeatureLevel::SqlServerN); // Set client interface name — ODBC sends "ODBC"; some backends // may require a non-empty value. login_message.clt_int_name("tiberius"); + // Azure SQL / Fabric backends require the AZURESQLSUPPORT feature + // extension to indicate the client can handle Azure-specific tokens. + login_message.azure_sql_support(); } if let Some(db) = db { diff --git a/src/client/tls_stream/native_tls_stream.rs b/src/client/tls_stream/native_tls_stream.rs index 9d113087d..34928036d 100644 --- a/src/client/tls_stream/native_tls_stream.rs +++ b/src/client/tls_stream/native_tls_stream.rs @@ -13,11 +13,12 @@ pub(crate) async fn create_tls_stream( config: &Config, stream: S, ) -> crate::Result> { - // For TDS 8 strict mode, we need to set ALPN to "ms-tds" so the gateway - // knows to proxy the connection rather than just redirecting. + // For TDS 8 strict mode, we perform a direct TLS handshake (no TDS + // wrapping). We use the native-tls builder directly for more control. + // Note: ALPN "ms-tds" is NOT sent — Azure SQL Database gateways reject + // connections that advertise it. Fabric works fine without it too. if config.encryption == EncryptionLevel::Strict { let mut native_builder = native_tls_crate::TlsConnector::builder(); - native_builder.request_alpns(&["ms-tds"]); match &config.trust { TrustConfig::CaCertificateLocation(path) => { diff --git a/src/tds/codec/header.rs b/src/tds/codec/header.rs index 719fc158b..547ddb308 100644 --- a/src/tds/codec/header.rs +++ b/src/tds/codec/header.rs @@ -112,6 +112,14 @@ impl PacketHeader { self.status = status; } + pub fn id(&self) -> u8 { + self.id + } + + pub fn set_id(&mut self, id: u8) { + self.id = id; + } + pub fn set_type(&mut self, ty: PacketType) { self.ty = ty; } diff --git a/src/tds/codec/login.rs b/src/tds/codec/login.rs index 226ee20fd..b1554bac5 100644 --- a/src/tds/codec/login.rs +++ b/src/tds/codec/login.rs @@ -147,6 +147,9 @@ pub(crate) const FEA_EXT_AZURESQLSUPPORT: u8 = 0x08u8; pub(crate) const FEA_EXT_UTF8_SUPPORT: u8 = 0x0Au8; pub(crate) const FEA_EXT_TERMINATOR: u8 = 0xFFu8; pub(crate) const FED_AUTH_LIBRARYSECURITYTOKEN: u8 = 0x01; +/// MSAL/ADAL library type for Azure AD token authentication. +/// Used when providing pre-obtained JWT access tokens from MSAL or az CLI. +pub(crate) const FED_AUTH_LIBRARY_MSAL: u8 = 0x02; /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac #[derive(Debug, Clone, Default)] From 45a3d33a65f5fbe9b38a8aa475682c7916588b35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 15:28:59 +0200 Subject: [PATCH 04/15] feat: Azure SQL TDS 8 integration tests and dead code cleanup - Add tests/azure_sql.rs: 5 integration tests for Azure SQL Database with TDS 8 strict encryption on the gateway + regular TLS-upgrade on the backend worker (the correct pattern for Azure SQL) - Gateway (port 1433): accepts TDS 8 strict (TLS-first) - Backend worker (port 11010): requires regular TLS-upgrade - Tests: connectivity, metadata, multiple queries, DDL/DML, regular encryption - Remove unused FED_AUTH_LIBRARY_MSAL constant - Suppress dead_code warnings on hostname/client_pid/client_prog_ver setters Key finding: Azure SQL backend workers do NOT support TDS 8 strict mode. After routing redirect, clients must reconnect with EncryptionLevel::Required (standard TLS-upgrade flow), not EncryptionLevel::Strict. --- src/tds/codec/login.rs | 6 +- tests/azure_sql.rs | 332 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 335 insertions(+), 3 deletions(-) create mode 100644 tests/azure_sql.rs diff --git a/src/tds/codec/login.rs b/src/tds/codec/login.rs index b1554bac5..f0113b25d 100644 --- a/src/tds/codec/login.rs +++ b/src/tds/codec/login.rs @@ -147,9 +147,6 @@ pub(crate) const FEA_EXT_AZURESQLSUPPORT: u8 = 0x08u8; pub(crate) const FEA_EXT_UTF8_SUPPORT: u8 = 0x0Au8; pub(crate) const FEA_EXT_TERMINATOR: u8 = 0xFFu8; pub(crate) const FED_AUTH_LIBRARYSECURITYTOKEN: u8 = 0x01; -/// MSAL/ADAL library type for Azure AD token authentication. -/// Used when providing pre-obtained JWT access tokens from MSAL or az CLI. -pub(crate) const FED_AUTH_LIBRARY_MSAL: u8 = 0x02; /// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac #[derive(Debug, Clone, Default)] @@ -270,14 +267,17 @@ impl<'a> LoginMessage<'a> { } } + #[allow(dead_code)] pub fn hostname(&mut self, hostname: impl Into>) { self.hostname = hostname.into(); } + #[allow(dead_code)] pub fn client_pid(&mut self, pid: u32) { self.client_pid = pid; } + #[allow(dead_code)] pub fn client_prog_ver(&mut self, ver: u32) { self.client_prog_ver = ver; } diff --git a/tests/azure_sql.rs b/tests/azure_sql.rs new file mode 100644 index 000000000..6c2725d00 --- /dev/null +++ b/tests/azure_sql.rs @@ -0,0 +1,332 @@ +//! Integration tests for Azure SQL Database connectivity with TDS 8 strict encryption. +//! +//! These tests verify that TDS 8 strict encryption works correctly when +//! connecting through Azure SQL Database gateways with routing redirects. +//! +//! Azure SQL architecture: +//! 1. Gateway (port 1433): Accepts TDS 8 strict (TLS-first), returns routing redirect +//! 2. Backend worker (port 11010+): Requires regular TLS-upgrade (PRELOGIN → TLS → LOGIN) +//! +//! This differs from Microsoft Fabric where backends also support strict + pipelined. +//! +//! # Required environment variables +//! +//! - `AZURE_SQL_ENDPOINT`: The Azure SQL server (e.g., `myserver.database.windows.net`) +//! - `AZURE_SQL_DATABASE`: The database name +//! +//! Authentication: +//! - `AZURE_SQL_TOKEN`: A pre-obtained AAD/Entra ID token for `https://database.windows.net/` +//! +//! # Running +//! +//! ```sh +//! export AZURE_SQL_ENDPOINT=myserver.database.windows.net +//! export AZURE_SQL_DATABASE=mydb +//! export AZURE_SQL_TOKEN=$(az account get-access-token --resource https://database.windows.net/ --query accessToken -o tsv) +//! cargo test --test azure_sql -- --nocapture +//! ``` + +use std::env; +use tiberius::{error::Error, AuthMethod, Client, Config, EncryptionLevel}; +use tokio::net::TcpStream; +use tokio_util::compat::TokioAsyncWriteCompatExt; + +/// Helper to skip tests when required environment variables are missing. +macro_rules! skip_if_no_azure_sql { + () => { + if env::var("AZURE_SQL_ENDPOINT").is_err() { + eprintln!( + "SKIPPED: AZURE_SQL_ENDPOINT not set. Set Azure SQL env vars to run this test." + ); + return Ok(()); + } + }; +} + +/// Connect to Azure SQL Database using TDS 8 strict encryption on the gateway, +/// then regular TLS-upgrade on the backend after routing. +/// +/// This matches the behavior of ODBC Driver 18+ and go-mssqldb: +/// - Gateway connection uses TDS 8 strict (TLS-first) to prove we can negotiate it +/// - Backend connection uses regular TLS-upgrade (which Azure SQL backends require) +async fn connect_to_azure_sql( + endpoint: &str, + database: &str, + token: &str, +) -> anyhow::Result>> { + // Phase 1: Connect to gateway with TDS 8 strict encryption + let conn_str = format!( + "server=tcp:{endpoint},1433;encrypt=strict;TrustServerCertificate=true;database={database}" + ); + let mut config = Config::from_ado_string(&conn_str)?; + config.authentication(AuthMethod::aad_token(token)); + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(Error::Routing { host, port }) => { + eprintln!( + "Routing redirect to {}:{}, reconnecting with TLS-upgrade...", + host, port + ); + + // Azure SQL routing targets don't include instance names (unlike Fabric) + let backend_host = host.split('\\').next().unwrap_or(&host); + + // Phase 2: Connect to backend with regular TLS-upgrade (NOT strict). + // Azure SQL backend workers don't support TDS 8 strict mode. + let mut backend_config = Config::new(); + backend_config.host(backend_host); + backend_config.port(port); + backend_config.encryption(EncryptionLevel::Required); + backend_config.trust_cert(); + backend_config.authentication(AuthMethod::aad_token(token)); + backend_config.database(database); + // MS-TDS spec: LOGIN server_name = original gateway endpoint + backend_config.login_server_name(endpoint); + + let tcp = TcpStream::connect(backend_config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let client = Client::connect(backend_config, tcp.compat_write()).await?; + Ok(client) + } + Err(e) => Err(e.into()), + } +} + +/// Connect to Azure SQL Database using regular (non-strict) encryption end-to-end. +/// This is the traditional flow matching ODBC Driver 17 behavior. +async fn connect_to_azure_sql_regular( + endpoint: &str, + database: &str, + token: &str, +) -> anyhow::Result>> { + let conn_str = format!( + "server=tcp:{endpoint},1433;encrypt=true;TrustServerCertificate=true;database={database}" + ); + let mut config = Config::from_ado_string(&conn_str)?; + config.authentication(AuthMethod::aad_token(token)); + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(Error::Routing { host, port }) => { + eprintln!( + "Routing redirect to {}:{}, reconnecting...", + host, port + ); + + let backend_host = host.split('\\').next().unwrap_or(&host); + + let mut backend_config = Config::new(); + backend_config.host(backend_host); + backend_config.port(port); + backend_config.encryption(EncryptionLevel::Required); + backend_config.trust_cert(); + backend_config.authentication(AuthMethod::aad_token(token)); + backend_config.database(database); + backend_config.login_server_name(endpoint); + + let tcp = TcpStream::connect(backend_config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let client = Client::connect(backend_config, tcp.compat_write()).await?; + Ok(client) + } + Err(e) => Err(e.into()), + } +} + +/// Test: Connect with TDS 8 strict gateway + regular backend and run a query. +#[tokio::test] +async fn azure_sql_strict_gateway_query() -> anyhow::Result<()> { + skip_if_no_azure_sql!(); + + let endpoint = env::var("AZURE_SQL_ENDPOINT")?; + let database = env::var("AZURE_SQL_DATABASE")?; + let token = env::var("AZURE_SQL_TOKEN")?; + + let mut client = connect_to_azure_sql(&endpoint, &database, &token).await?; + + let row = client + .query("SELECT 1 AS test_value", &[]) + .await? + .into_row() + .await? + .unwrap(); + + assert_eq!(Some(1i32), row.get("test_value")); + eprintln!("Azure SQL TDS 8 strict gateway + regular backend: OK"); + + Ok(()) +} + +/// Test: Verify @@VERSION and DB_NAME() after strict gateway connection. +#[tokio::test] +async fn azure_sql_strict_server_metadata() -> anyhow::Result<()> { + skip_if_no_azure_sql!(); + + let endpoint = env::var("AZURE_SQL_ENDPOINT")?; + let database = env::var("AZURE_SQL_DATABASE")?; + let token = env::var("AZURE_SQL_TOKEN")?; + + let mut client = connect_to_azure_sql(&endpoint, &database, &token).await?; + + let row = client + .query( + "SELECT @@VERSION AS ver, DB_NAME() AS db_name, SUSER_SNAME() AS login_name", + &[], + ) + .await? + .into_row() + .await? + .unwrap(); + + let ver: &str = row.get("ver").unwrap(); + let db_name: &str = row.get("db_name").unwrap(); + let login_name: &str = row.get("login_name").unwrap(); + + assert!( + ver.contains("Microsoft SQL Azure"), + "Should be Azure SQL, got: {}", + ver + ); + assert_eq!(db_name, database, "Should connect to the requested database"); + eprintln!("Version: {}", &ver[..ver.find('\n').unwrap_or(ver.len())]); + eprintln!("Database: {}", db_name); + eprintln!("Login: {}", login_name); + + Ok(()) +} + +/// Test: Multiple sequential queries over the connection. +#[tokio::test] +async fn azure_sql_strict_multiple_queries() -> anyhow::Result<()> { + skip_if_no_azure_sql!(); + + let endpoint = env::var("AZURE_SQL_ENDPOINT")?; + let database = env::var("AZURE_SQL_DATABASE")?; + let token = env::var("AZURE_SQL_TOKEN")?; + + let mut client = connect_to_azure_sql(&endpoint, &database, &token).await?; + + // Query 1 + let row = client + .query("SELECT 42 AS answer", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(42i32), row.get("answer")); + + // Query 2: server time + let row = client + .query( + "SELECT CAST(GETUTCDATE() AS NVARCHAR(50)) AS server_time", + &[], + ) + .await? + .into_row() + .await? + .unwrap(); + let time: &str = row.get("server_time").unwrap(); + assert!(!time.is_empty(), "Should get server time"); + + // Query 3: parameterized + let row = client + .query("SELECT @P1 + @P2 AS sum_result", &[&10i32, &32i32]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(42i32), row.get("sum_result")); + + eprintln!("Multiple queries over strict gateway connection: OK"); + Ok(()) +} + +/// Test: Regular (non-strict) encryption also works for comparison. +#[tokio::test] +async fn azure_sql_regular_encryption() -> anyhow::Result<()> { + skip_if_no_azure_sql!(); + + let endpoint = env::var("AZURE_SQL_ENDPOINT")?; + let database = env::var("AZURE_SQL_DATABASE")?; + let token = env::var("AZURE_SQL_TOKEN")?; + + let mut client = connect_to_azure_sql_regular(&endpoint, &database, &token).await?; + + let row = client + .query("SELECT DB_NAME() AS db_name", &[]) + .await? + .into_row() + .await? + .unwrap(); + + let db_name: &str = row.get("db_name").unwrap(); + assert_eq!(db_name, database); + eprintln!("Regular encryption (non-strict) also works: OK"); + + Ok(()) +} + +/// Test: DDL and DML operations work over strict gateway connection. +#[tokio::test] +async fn azure_sql_strict_ddl_dml() -> anyhow::Result<()> { + skip_if_no_azure_sql!(); + + let endpoint = env::var("AZURE_SQL_ENDPOINT")?; + let database = env::var("AZURE_SQL_DATABASE")?; + let token = env::var("AZURE_SQL_TOKEN")?; + + let mut client = connect_to_azure_sql(&endpoint, &database, &token).await?; + + // Create a temp table (must consume result before next command) + client + .simple_query( + "CREATE TABLE #tds8_test (id INT, name NVARCHAR(50), value DECIMAL(10,2))", + ) + .await? + .into_results() + .await?; + + // Insert data + let rows_affected = client + .execute( + "INSERT INTO #tds8_test VALUES (@P1, @P2, @P3)", + &[&1i32, &"hello", &42.5f64], + ) + .await? + .total(); + assert_eq!(rows_affected, 1); + + let rows_affected = client + .execute( + "INSERT INTO #tds8_test VALUES (@P1, @P2, @P3)", + &[&2i32, &"world", &99.9f64], + ) + .await? + .total(); + assert_eq!(rows_affected, 1); + + // Query data back + let rows: Vec<_> = client + .query("SELECT id, name, value FROM #tds8_test ORDER BY id", &[]) + .await? + .into_first_result() + .await?; + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::("id"), Some(1)); + assert_eq!(rows[0].get::<&str, _>("name"), Some("hello")); + assert_eq!(rows[1].get::("id"), Some(2)); + assert_eq!(rows[1].get::<&str, _>("name"), Some("world")); + + eprintln!("DDL/DML over strict gateway connection: OK"); + Ok(()) +} From ee572f19ebbce64a478316165445e0bb49463fbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 15:54:06 +0200 Subject: [PATCH 05/15] Add SQL Server 2025 TDS 8 strict mode integration tests 7 integration tests covering: - Basic strict mode connection and query - Server metadata verification (version, database, transport) - Multiple sequential queries over a single connection - DDL/DML operations (CREATE TABLE, INSERT, UPDATE, DELETE) - Encryption verification via sys.dm_exec_connections - Large result set (1000 rows) to stress TLS framing - CA certificate validation (non-trust-all path) Tests run against SQL Server 2025 with network.forcestrict=1 (strict mode is only available in SQL Server 2025+, not 2022 on Linux). Tests skip gracefully when SQL_SERVER_PASSWORD env var is not set. --- tests/sql_server_tds8.rs | 408 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 408 insertions(+) create mode 100644 tests/sql_server_tds8.rs diff --git a/tests/sql_server_tds8.rs b/tests/sql_server_tds8.rs new file mode 100644 index 000000000..b63d61788 --- /dev/null +++ b/tests/sql_server_tds8.rs @@ -0,0 +1,408 @@ +//! Integration tests for SQL Server 2025+ TDS 8 strict encryption. +//! +//! These tests verify that TDS 8 strict encryption works correctly when +//! connecting directly to a SQL Server instance with `forcestrict = 1`. +//! +//! Unlike Azure SQL/Fabric, SQL Server with strict mode does NOT use routing +//! redirects — the connection goes directly to the server. +//! +//! # Required environment variables +//! +//! - `SQL_SERVER_HOST`: The server hostname (default: `localhost`) +//! - `SQL_SERVER_PORT`: The server port (default: `1434`) +//! - `SQL_SERVER_USER`: SQL login (default: `sa`) +//! - `SQL_SERVER_PASSWORD`: SQL password +//! - `SQL_SERVER_CA_CERT`: Path to CA cert for TLS verification (optional; uses trust-all if unset) +//! +//! # Running +//! +//! ```sh +//! # Start SQL Server 2025 in Docker with strict mode: +//! # docker run -d --name mssql-tds8 -p 1434:1433 \ +//! # -e ACCEPT_EULA=Y -e MSSQL_SA_PASSWORD=StrictMode!2022 \ +//! # -v ./certs/mssql-cert.pem:/var/opt/mssql/certs/mssql-cert.pem:ro \ +//! # -v ./certs/mssql-key.pem:/var/opt/mssql/certs/mssql-key.pem:ro \ +//! # -v ./mssql.conf:/var/opt/mssql/mssql.conf \ +//! # mcr.microsoft.com/mssql/server:2025-latest +//! # +//! # mssql.conf should contain: +//! # [network] +//! # tlscert = /var/opt/mssql/certs/mssql-cert.pem +//! # tlskey = /var/opt/mssql/certs/mssql-key.pem +//! # tlsprotocols = 1.2 +//! # forceencryption = 1 +//! # forcestrict = 1 +//! +//! export SQL_SERVER_PASSWORD=StrictMode!2022 +//! cargo test --test sql_server_tds8 -- --nocapture +//! ``` + +use std::env; +use tiberius::{AuthMethod, Client, Config}; +use tokio::net::TcpStream; +use tokio_util::compat::TokioAsyncWriteCompatExt; + +/// Helper to skip tests when required environment variables are missing. +macro_rules! skip_if_no_sql_server { + () => { + if env::var("SQL_SERVER_PASSWORD").is_err() { + eprintln!( + "SKIPPED: SQL_SERVER_PASSWORD not set. Set SQL Server env vars to run this test." + ); + return Ok(()); + } + }; +} + +/// Connect to SQL Server with TDS 8 strict encryption (TLS-first, no routing). +async fn connect_strict( +) -> anyhow::Result>> { + let host = env::var("SQL_SERVER_HOST").unwrap_or_else(|_| "localhost".to_string()); + let port = env::var("SQL_SERVER_PORT").unwrap_or_else(|_| "1434".to_string()); + let user = env::var("SQL_SERVER_USER").unwrap_or_else(|_| "sa".to_string()); + let password = env::var("SQL_SERVER_PASSWORD")?; + + let conn_str = if let Ok(ca_path) = env::var("SQL_SERVER_CA_CERT") { + format!( + "server=tcp:{host},{port};encrypt=strict;database=master;Certificate={ca_path}" + ) + } else { + format!( + "server=tcp:{host},{port};encrypt=strict;TrustServerCertificate=true;database=master" + ) + }; + + let mut config = Config::from_ado_string(&conn_str)?; + config.authentication(AuthMethod::sql_server(&user, &password)); + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let client = Client::connect(config, tcp.compat_write()).await?; + Ok(client) +} + +/// Test: Basic strict mode connection and simple query. +#[tokio::test] +async fn sql_server_strict_basic_query() -> anyhow::Result<()> { + skip_if_no_sql_server!(); + + let mut client = connect_strict().await?; + + let row = client + .query("SELECT 1 AS test_value", &[]) + .await? + .into_row() + .await? + .unwrap(); + + assert_eq!(Some(1i32), row.get("test_value")); + eprintln!("SQL Server TDS 8 strict basic query: OK"); + + Ok(()) +} + +/// Test: Verify server version and metadata via strict connection. +#[tokio::test] +async fn sql_server_strict_server_metadata() -> anyhow::Result<()> { + skip_if_no_sql_server!(); + + let mut client = connect_strict().await?; + + let row = client + .query( + "SELECT @@VERSION AS ver, DB_NAME() AS db_name, SUSER_SNAME() AS login_name, \ + CAST(CONNECTIONPROPERTY('net_transport') AS NVARCHAR(50)) AS transport, \ + CAST(CONNECTIONPROPERTY('protocol_type') AS NVARCHAR(50)) AS protocol", + &[], + ) + .await? + .into_row() + .await? + .unwrap(); + + let ver: &str = row.get("ver").unwrap(); + let db_name: &str = row.get("db_name").unwrap(); + let login_name: &str = row.get("login_name").unwrap(); + let transport: &str = row.get("transport").unwrap(); + + assert!( + ver.contains("Microsoft SQL Server 2025") || ver.contains("Microsoft SQL Server 2022"), + "Expected SQL Server 2022+, got: {}", + &ver[..ver.find('\n').unwrap_or(80.min(ver.len()))] + ); + assert_eq!(db_name, "master"); + assert_eq!(transport, "TCP"); + eprintln!( + "Version: {}", + &ver[..ver.find('\n').unwrap_or(ver.len())] + ); + eprintln!("Database: {}, Login: {}", db_name, login_name); + + Ok(()) +} + +/// Test: Multiple sequential queries over strict connection. +#[tokio::test] +async fn sql_server_strict_multiple_queries() -> anyhow::Result<()> { + skip_if_no_sql_server!(); + + let mut client = connect_strict().await?; + + // Query 1 + let row = client + .query("SELECT 42 AS answer", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(42i32), row.get("answer")); + + // Query 2: server time + let row = client + .query( + "SELECT CAST(GETUTCDATE() AS NVARCHAR(50)) AS server_time", + &[], + ) + .await? + .into_row() + .await? + .unwrap(); + let time: &str = row.get("server_time").unwrap(); + assert!(!time.is_empty(), "Should get server time"); + + // Query 3: parameterized + let row = client + .query("SELECT @P1 + @P2 AS sum_result", &[&10i32, &32i32]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(42i32), row.get("sum_result")); + + // Query 4: string operations + let row = client + .query( + "SELECT CONCAT(@P1, N' ', @P2) AS greeting", + &[&"Hello", &"TDS8"], + ) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some("Hello TDS8"), row.get::<&str, _>("greeting")); + + eprintln!("Multiple queries over strict connection: OK"); + Ok(()) +} + +/// Test: DDL and DML operations over strict connection. +#[tokio::test] +async fn sql_server_strict_ddl_dml() -> anyhow::Result<()> { + skip_if_no_sql_server!(); + + let mut client = connect_strict().await?; + + // Create temp table + client + .simple_query( + "CREATE TABLE #tds8_strict_test (id INT, name NVARCHAR(100), value DECIMAL(10,2))", + ) + .await? + .into_results() + .await?; + + // Insert data using parameters + let rows_affected = client + .execute( + "INSERT INTO #tds8_strict_test VALUES (@P1, @P2, @P3)", + &[&1i32, &"strict_mode", &123.45f64], + ) + .await? + .total(); + assert_eq!(rows_affected, 1); + + let rows_affected = client + .execute( + "INSERT INTO #tds8_strict_test VALUES (@P1, @P2, @P3)", + &[&2i32, &"tds_eight", &678.90f64], + ) + .await? + .total(); + assert_eq!(rows_affected, 1); + + // Batch insert + client + .simple_query( + "INSERT INTO #tds8_strict_test VALUES (3, N'batch_one', 11.11), (4, N'batch_two', 22.22)", + ) + .await? + .into_results() + .await?; + + // Query data back + let rows: Vec<_> = client + .query( + "SELECT id, name, value FROM #tds8_strict_test ORDER BY id", + &[], + ) + .await? + .into_first_result() + .await?; + + assert_eq!(rows.len(), 4); + assert_eq!(rows[0].get::("id"), Some(1)); + assert_eq!(rows[0].get::<&str, _>("name"), Some("strict_mode")); + assert_eq!(rows[3].get::("id"), Some(4)); + assert_eq!(rows[3].get::<&str, _>("name"), Some("batch_two")); + + // Update + let rows_affected = client + .execute( + "UPDATE #tds8_strict_test SET value = @P1 WHERE id = @P2", + &[&999.99f64, &1i32], + ) + .await? + .total(); + assert_eq!(rows_affected, 1); + + // Delete + let rows_affected = client + .execute( + "DELETE FROM #tds8_strict_test WHERE id > @P1", + &[&2i32], + ) + .await? + .total(); + assert_eq!(rows_affected, 2); + + // Verify final state + let rows: Vec<_> = client + .query( + "SELECT id, value FROM #tds8_strict_test ORDER BY id", + &[], + ) + .await? + .into_first_result() + .await?; + assert_eq!(rows.len(), 2); + + eprintln!("DDL/DML over strict connection: OK"); + Ok(()) +} + +/// Test: Verify encryption is active via sys.dm_exec_connections. +#[tokio::test] +async fn sql_server_strict_verify_encryption() -> anyhow::Result<()> { + skip_if_no_sql_server!(); + + let mut client = connect_strict().await?; + + let row = client + .query( + "SELECT encrypt_option, auth_scheme, protocol_type, net_transport \ + FROM sys.dm_exec_connections WHERE session_id = @@SPID", + &[], + ) + .await? + .into_row() + .await? + .unwrap(); + + let encrypt_option: &str = row.get("encrypt_option").unwrap(); + let auth_scheme: &str = row.get("auth_scheme").unwrap(); + let protocol_type: &str = row.get("protocol_type").unwrap(); + let net_transport: &str = row.get("net_transport").unwrap(); + + // In strict mode, encrypt_option should be TRUE (or STRICT on newer versions) + assert!( + encrypt_option == "TRUE" || encrypt_option == "STRICT", + "Expected encrypted connection, got encrypt_option='{}'", + encrypt_option + ); + assert_eq!(auth_scheme, "SQL"); + assert_eq!(protocol_type, "TSQL"); + assert_eq!(net_transport, "TCP"); + + eprintln!( + "Encryption verified: option={}, scheme={}, transport={}", + encrypt_option, auth_scheme, net_transport + ); + Ok(()) +} + +/// Test: Large result set over strict connection (verifies TLS framing is stable). +#[tokio::test] +async fn sql_server_strict_large_result() -> anyhow::Result<()> { + skip_if_no_sql_server!(); + + let mut client = connect_strict().await?; + + // Generate a large result set to exercise TLS framing across multiple packets + let rows: Vec<_> = client + .query( + "SELECT TOP 1000 \ + ROW_NUMBER() OVER (ORDER BY a.object_id) AS row_num, \ + REPLICATE(N'X', 200) AS padding \ + FROM sys.all_objects a CROSS JOIN sys.all_objects b", + &[], + ) + .await? + .into_first_result() + .await?; + + assert_eq!(rows.len(), 1000, "Should get exactly 1000 rows"); + + // Verify first and last row + assert_eq!(rows[0].get::("row_num"), Some(1)); + assert_eq!(rows[999].get::("row_num"), Some(1000)); + + let padding: &str = rows[0].get("padding").unwrap(); + assert_eq!(padding.len(), 200, "Padding should be 200 chars"); + + eprintln!("Large result set (1000 rows) over strict connection: OK"); + Ok(()) +} + +/// Test: Connection with CA certificate verification (not just trust-all). +#[tokio::test] +async fn sql_server_strict_ca_cert_validation() -> anyhow::Result<()> { + skip_if_no_sql_server!(); + + // This test only runs if a CA cert path is provided + let ca_cert = match env::var("SQL_SERVER_CA_CERT") { + Ok(path) => path, + Err(_) => { + eprintln!("SKIPPED: SQL_SERVER_CA_CERT not set. Set it to test CA cert validation."); + return Ok(()); + } + }; + + let host = env::var("SQL_SERVER_HOST").unwrap_or_else(|_| "localhost".to_string()); + let port = env::var("SQL_SERVER_PORT").unwrap_or_else(|_| "1434".to_string()); + let user = env::var("SQL_SERVER_USER").unwrap_or_else(|_| "sa".to_string()); + let password = env::var("SQL_SERVER_PASSWORD")?; + + let conn_str = format!( + "server=tcp:{host},{port};encrypt=strict;database=master;Certificate={ca_cert}" + ); + + let mut config = Config::from_ado_string(&conn_str)?; + config.authentication(AuthMethod::sql_server(&user, &password)); + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let mut client = Client::connect(config, tcp.compat_write()).await?; + + let row = client + .query("SELECT 1 AS test", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(1i32), row.get("test")); + + eprintln!("CA cert validation with strict mode: OK"); + Ok(()) +} From a9276b3ce4919f785be285c81346525e97d41dc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 16:15:36 +0200 Subject: [PATCH 06/15] test: add DDL/DML, large result, encryption verification to Fabric and Azure SQL Brings Fabric and Azure SQL integration tests to parity with SQL Server: - DDL/DML: CREATE TABLE, INSERT, UPDATE, DELETE, SELECT with params - Large result set: 1000 rows to stress TLS framing across packets - Encryption verification: CONNECTIONPROPERTY + sys.dm_exec_connections - String/unicode: CONCAT, unicode chars, 8000-char strings Fabric-specific considerations: - Uses permanent table (not #temp) since Fabric DW doesn't support local temps - Large result tries tpch_sf1.lineitem first, falls back to recursive CTE - dm_exec_connections access is optional (may not be available on all SKUs) Total integration test count: 24 (9 Fabric + 8 Azure SQL + 7 SQL Server) --- tests/azure_sql.rs | 129 ++++++++++++++++++++ tests/fabric.rs | 287 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 416 insertions(+) diff --git a/tests/azure_sql.rs b/tests/azure_sql.rs index 6c2725d00..141147a9a 100644 --- a/tests/azure_sql.rs +++ b/tests/azure_sql.rs @@ -330,3 +330,132 @@ async fn azure_sql_strict_ddl_dml() -> anyhow::Result<()> { eprintln!("DDL/DML over strict gateway connection: OK"); Ok(()) } + +/// Test: Large result set over Azure SQL strict connection to stress TLS framing. +#[tokio::test] +async fn azure_sql_strict_large_result() -> anyhow::Result<()> { + skip_if_no_azure_sql!(); + + let endpoint = env::var("AZURE_SQL_ENDPOINT")?; + let database = env::var("AZURE_SQL_DATABASE")?; + let token = env::var("AZURE_SQL_TOKEN")?; + + let mut client = connect_to_azure_sql(&endpoint, &database, &token).await?; + + // Generate a large result set to exercise TLS framing across multiple packets + let rows: Vec<_> = client + .query( + "SELECT TOP 1000 \ + ROW_NUMBER() OVER (ORDER BY a.object_id) AS row_num, \ + REPLICATE(N'X', 200) AS padding \ + FROM sys.all_objects a CROSS JOIN sys.all_objects b", + &[], + ) + .await? + .into_first_result() + .await?; + + assert_eq!(rows.len(), 1000, "Should get exactly 1000 rows"); + + // Verify first and last row + assert_eq!(rows[0].get::("row_num"), Some(1)); + assert_eq!(rows[999].get::("row_num"), Some(1000)); + + let padding: &str = rows[0].get("padding").unwrap(); + assert_eq!(padding.len(), 200, "Padding should be 200 chars"); + + eprintln!("Large result set (1000 rows) over Azure SQL strict connection: OK"); + Ok(()) +} + +/// Test: Verify encryption status via sys.dm_exec_connections on Azure SQL. +#[tokio::test] +async fn azure_sql_strict_verify_encryption() -> anyhow::Result<()> { + skip_if_no_azure_sql!(); + + let endpoint = env::var("AZURE_SQL_ENDPOINT")?; + let database = env::var("AZURE_SQL_DATABASE")?; + let token = env::var("AZURE_SQL_TOKEN")?; + + let mut client = connect_to_azure_sql(&endpoint, &database, &token).await?; + + let row = client + .query( + "SELECT encrypt_option, auth_scheme, protocol_type, net_transport \ + FROM sys.dm_exec_connections WHERE session_id = @@SPID", + &[], + ) + .await? + .into_row() + .await? + .unwrap(); + + let encrypt_option: &str = row.get("encrypt_option").unwrap(); + let auth_scheme: &str = row.get("auth_scheme").unwrap(); + let net_transport: &str = row.get("net_transport").unwrap(); + + assert!( + encrypt_option == "TRUE" || encrypt_option == "STRICT", + "Expected encrypted connection, got encrypt_option='{}'", + encrypt_option + ); + assert_eq!(net_transport, "TCP"); + // Azure SQL with AAD token uses NTML at transport but AAD at auth layer + assert!( + !auth_scheme.is_empty(), + "Should have an auth scheme" + ); + + eprintln!( + "Azure SQL encryption verified: option={}, scheme={}, transport={}", + encrypt_option, auth_scheme, net_transport + ); + Ok(()) +} + +/// Test: String and unicode operations over Azure SQL strict connection. +#[tokio::test] +async fn azure_sql_strict_string_operations() -> anyhow::Result<()> { + skip_if_no_azure_sql!(); + + let endpoint = env::var("AZURE_SQL_ENDPOINT")?; + let database = env::var("AZURE_SQL_DATABASE")?; + let token = env::var("AZURE_SQL_TOKEN")?; + + let mut client = connect_to_azure_sql(&endpoint, &database, &token).await?; + + // Test unicode handling over TDS 8 strict TLS + let row = client + .query( + "SELECT CONCAT(@P1, N' ', @P2) AS greeting", + &[&"Hello", &"TDS8"], + ) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some("Hello TDS8"), row.get::<&str, _>("greeting")); + + // Unicode characters + let row = client + .query("SELECT @P1 AS unicode_text", &[&"日本語テスト 🚀"]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some("日本語テスト 🚀"), row.get::<&str, _>("unicode_text")); + + // Long string that spans multiple TDS packets + let long_string = "A".repeat(8000); + let row = client + .query("SELECT @P1 AS long_text", &[&long_string.as_str()]) + .await? + .into_row() + .await? + .unwrap(); + let result: &str = row.get("long_text").unwrap(); + assert_eq!(result.len(), 8000); + + eprintln!("String/unicode operations over Azure SQL strict connection: OK"); + Ok(()) +} diff --git a/tests/fabric.rs b/tests/fabric.rs index 1914a27d0..6f9a496b0 100644 --- a/tests/fabric.rs +++ b/tests/fabric.rs @@ -257,3 +257,290 @@ async fn fabric_jdbc_connection_string() -> anyhow::Result<()> { ); Ok(()) } + +/// Test: DDL and DML operations over Fabric strict connection. +/// +/// Fabric Data Warehouse does not support local temp tables (#table) but does +/// support regular tables. We create a test table, exercise INSERT/UPDATE/DELETE, +/// then drop it. +#[tokio::test] +async fn fabric_strict_ddl_dml() -> anyhow::Result<()> { + skip_if_no_fabric!(); + + let endpoint = env::var("FABRIC_ENDPOINT")?; + let database = env::var("FABRIC_DATABASE")?; + let token = get_aad_token().await?; + + let mut client = connect_to_fabric(&endpoint, &database, &token).await?; + + let table_name = "dbo.__tiberius_tds8_test"; + + // Drop if exists from a previous failed run + client + .simple_query(format!( + "IF OBJECT_ID('{table_name}', 'U') IS NOT NULL DROP TABLE {table_name}" + )) + .await? + .into_results() + .await?; + + // CREATE TABLE + client + .simple_query(format!( + "CREATE TABLE {table_name} (id INT, name NVARCHAR(100), value DECIMAL(10,2))" + )) + .await? + .into_results() + .await?; + + // INSERT with parameters + let rows_affected = client + .execute( + format!("INSERT INTO {table_name} VALUES (@P1, @P2, @P3)"), + &[&1i32, &"strict_mode", &123.45f64], + ) + .await? + .total(); + assert_eq!(rows_affected, 1); + + let rows_affected = client + .execute( + format!("INSERT INTO {table_name} VALUES (@P1, @P2, @P3)"), + &[&2i32, &"tds_eight", &678.90f64], + ) + .await? + .total(); + assert_eq!(rows_affected, 1); + + // SELECT back + let rows: Vec<_> = client + .query( + format!("SELECT id, name, value FROM {table_name} ORDER BY id"), + &[], + ) + .await? + .into_first_result() + .await?; + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::("id"), Some(1)); + assert_eq!(rows[0].get::<&str, _>("name"), Some("strict_mode")); + assert_eq!(rows[1].get::("id"), Some(2)); + assert_eq!(rows[1].get::<&str, _>("name"), Some("tds_eight")); + + // UPDATE + let rows_affected = client + .execute( + format!("UPDATE {table_name} SET value = @P1 WHERE id = @P2"), + &[&999.99f64, &1i32], + ) + .await? + .total(); + assert_eq!(rows_affected, 1); + + // DELETE + let rows_affected = client + .execute( + format!("DELETE FROM {table_name} WHERE id = @P1"), + &[&2i32], + ) + .await? + .total(); + assert_eq!(rows_affected, 1); + + // Verify final state + let rows: Vec<_> = client + .query( + format!("SELECT id, value FROM {table_name} ORDER BY id"), + &[], + ) + .await? + .into_first_result() + .await?; + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::("id"), Some(1)); + + // Cleanup + client + .simple_query(format!("DROP TABLE {table_name}")) + .await? + .into_results() + .await?; + + eprintln!("Fabric DDL/DML over strict connection: OK"); + Ok(()) +} + +/// Test: Large result set over Fabric strict connection to stress TLS framing. +/// +/// Uses the tpch_sf1 lineitem table if available, otherwise generates rows +/// with a recursive CTE. +#[tokio::test] +async fn fabric_strict_large_result() -> anyhow::Result<()> { + skip_if_no_fabric!(); + + let endpoint = env::var("FABRIC_ENDPOINT")?; + let database = env::var("FABRIC_DATABASE")?; + let token = get_aad_token().await?; + + let mut client = connect_to_fabric(&endpoint, &database, &token).await?; + + // Try querying from lineitem (tpch_sf1), fall back to a generated set + let query = if client + .query( + "SELECT TOP 1 1 AS x FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'lineitem'", + &[], + ) + .await? + .into_row() + .await? + .is_some() + { + // Use tpch lineitem table — large rows with multiple columns + "SELECT TOP 1000 l_orderkey, l_partkey, l_suppkey, \ + CAST(l_quantity AS DECIMAL(10,2)) AS l_quantity, \ + CAST(l_extendedprice AS DECIMAL(12,2)) AS l_extendedprice, \ + l_shipdate, l_comment \ + FROM lineitem ORDER BY l_orderkey, l_linenumber".to_string() + } else { + // Generate 1000 rows with a CTE + "WITH nums AS ( \ + SELECT 1 AS n UNION ALL SELECT n + 1 FROM nums WHERE n < 1000 \ + ) \ + SELECT n AS row_num, REPLICATE(N'X', 200) AS padding \ + FROM nums OPTION (MAXRECURSION 1000)".to_string() + }; + + let rows: Vec<_> = client + .query(query, &[]) + .await? + .into_first_result() + .await?; + + assert_eq!(rows.len(), 1000, "Should get exactly 1000 rows"); + eprintln!( + "Large result set (1000 rows, {} columns) over Fabric strict connection: OK", + rows[0].columns().len() + ); + + Ok(()) +} + +/// Test: Verify encryption status via session properties on Fabric. +/// +/// Fabric may not expose sys.dm_exec_connections, so we use +/// session context properties instead. +#[tokio::test] +async fn fabric_strict_verify_encryption() -> anyhow::Result<()> { + skip_if_no_fabric!(); + + let endpoint = env::var("FABRIC_ENDPOINT")?; + let database = env::var("FABRIC_DATABASE")?; + let token = get_aad_token().await?; + + let mut client = connect_to_fabric(&endpoint, &database, &token).await?; + + // Try sys.dm_exec_connections first (may work on some Fabric SKUs) + let row = client + .query( + "SELECT \ + CAST(CONNECTIONPROPERTY('net_transport') AS NVARCHAR(50)) AS transport, \ + CAST(CONNECTIONPROPERTY('protocol_type') AS NVARCHAR(50)) AS protocol, \ + CAST(CONNECTIONPROPERTY('auth_scheme') AS NVARCHAR(50)) AS auth_scheme", + &[], + ) + .await? + .into_row() + .await? + .unwrap(); + + let transport: &str = row.get("transport").unwrap(); + let protocol: &str = row.get("protocol").unwrap(); + let auth_scheme: &str = row.get("auth_scheme").unwrap(); + + assert_eq!(transport, "TCP", "Should be TCP transport"); + assert_eq!(protocol, "TSQL", "Should be TSQL protocol"); + // Fabric with AAD token should show NTML or AAD-based auth + assert!( + !auth_scheme.is_empty(), + "Should have an auth scheme, got empty" + ); + + eprintln!( + "Fabric encryption verified: transport={}, protocol={}, auth={}", + transport, protocol, auth_scheme + ); + + // Also try dm_exec_connections if accessible + match client + .query( + "SELECT encrypt_option FROM sys.dm_exec_connections WHERE session_id = @@SPID", + &[], + ) + .await + { + Ok(result) => { + if let Some(row) = result.into_row().await? { + let encrypt_option: &str = row.get("encrypt_option").unwrap(); + assert!( + encrypt_option == "TRUE" || encrypt_option == "STRICT", + "Expected encrypted, got: {}", + encrypt_option + ); + eprintln!(" dm_exec_connections.encrypt_option = {}", encrypt_option); + } + } + Err(_) => { + eprintln!(" (sys.dm_exec_connections not accessible on this Fabric endpoint)"); + } + } + + Ok(()) +} + +/// Test: String and unicode operations over strict connection. +#[tokio::test] +async fn fabric_strict_string_operations() -> anyhow::Result<()> { + skip_if_no_fabric!(); + + let endpoint = env::var("FABRIC_ENDPOINT")?; + let database = env::var("FABRIC_DATABASE")?; + let token = get_aad_token().await?; + + let mut client = connect_to_fabric(&endpoint, &database, &token).await?; + + // Test unicode handling over TDS 8 strict TLS + let row = client + .query( + "SELECT CONCAT(@P1, N' ', @P2) AS greeting", + &[&"Hello", &"TDS8"], + ) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some("Hello TDS8"), row.get::<&str, _>("greeting")); + + // Unicode characters + let row = client + .query("SELECT @P1 AS unicode_text", &[&"日本語テスト 🚀"]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some("日本語テスト 🚀"), row.get::<&str, _>("unicode_text")); + + // Long string that spans multiple TDS packets + let long_string = "A".repeat(8000); + let row = client + .query("SELECT @P1 AS long_text", &[&long_string.as_str()]) + .await? + .into_row() + .await? + .unwrap(); + let result: &str = row.get("long_text").unwrap(); + assert_eq!(result.len(), 8000); + + eprintln!("String/unicode operations over Fabric strict connection: OK"); + Ok(()) +} From 977c29e7f9641dc498ae34c12b3c59d6bcbd50d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 16:39:54 +0200 Subject: [PATCH 07/15] fix(tests): Fabric DW compatibility for DDL and large result tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use VARCHAR(100) instead of NVARCHAR(100) — Fabric DW does not support nvarchar with specified length - Replace recursive CTE with OPTION(MAXRECURSION) by cross-join row generation — Fabric DW disallows query hints - Skip DDL test gracefully when identity lacks CREATE TABLE permission (e.g., read-only lakehouse endpoints) --- tests/fabric.rs | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/tests/fabric.rs b/tests/fabric.rs index 6f9a496b0..773b0a520 100644 --- a/tests/fabric.rs +++ b/tests/fabric.rs @@ -284,14 +284,27 @@ async fn fabric_strict_ddl_dml() -> anyhow::Result<()> { .into_results() .await?; - // CREATE TABLE - client + // CREATE TABLE — Fabric DW only supports varchar(n), not nvarchar(n) with length. + // Skip gracefully if the identity lacks DDL permissions (e.g., read-only lakehouse). + match client .simple_query(format!( - "CREATE TABLE {table_name} (id INT, name NVARCHAR(100), value DECIMAL(10,2))" + "CREATE TABLE {table_name} (id INT, name VARCHAR(100), value DECIMAL(10,2))" )) - .await? - .into_results() - .await?; + .await + { + Err(e) if e.to_string().contains("denied") => { + eprintln!( + "SKIPPED: DDL not permitted on this Fabric database ({}). \ + Set FABRIC_DATABASE to a writable Data Warehouse to run this test.", + e + ); + return Ok(()); + } + Err(e) => return Err(e.into()), + Ok(stream) => { + stream.into_results().await?; + } + } // INSERT with parameters let rows_affected = client @@ -403,12 +416,12 @@ async fn fabric_strict_large_result() -> anyhow::Result<()> { l_shipdate, l_comment \ FROM lineitem ORDER BY l_orderkey, l_linenumber".to_string() } else { - // Generate 1000 rows with a CTE - "WITH nums AS ( \ - SELECT 1 AS n UNION ALL SELECT n + 1 FROM nums WHERE n < 1000 \ - ) \ - SELECT n AS row_num, REPLICATE(N'X', 200) AS padding \ - FROM nums OPTION (MAXRECURSION 1000)".to_string() + // Generate 1000 rows via cross join (no OPTION hints — Fabric DW disallows them) + "SELECT TOP 1000 ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS row_num, \ + REPLICATE('X', 200) AS padding \ + FROM (VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10)) AS t1(n) \ + CROSS JOIN (VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10)) AS t2(n) \ + CROSS JOIN (VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10)) AS t3(n)".to_string() }; let rows: Vec<_> = client From bcf71268030384704666e2977fe62da336392b47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 16:54:51 +0200 Subject: [PATCH 08/15] chore: formatting, clippy fixes, and remove debug dump code - Apply cargo fmt across all changed files - Remove dead debug TDS frame dump to /tmp in feed_to_wire() - Remove unnecessary #[allow(dead_code)] on feed_to_wire (it's used) - Remove unused lifetime 'a on login() function - Use eq_ignore_ascii_case() instead of to_ascii_lowercase() comparisons in TLS cert extension matching --- src/client/connection.rs | 56 ++++++-------------- src/client/tls_stream/native_tls_stream.rs | 12 ++--- src/tds/codec/login.rs | 10 ++-- src/tds/codec/pre_login.rs | 8 ++- src/tds/codec/token/token_feature_ext_ack.rs | 10 +++- tests/azure_sql.rs | 19 +++---- tests/fabric.rs | 21 +++----- tests/sql_server_tds8.rs | 27 +++------- 8 files changed, 60 insertions(+), 103 deletions(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index a25984870..0b765f5d0 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -113,7 +113,12 @@ impl Connection { let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_)); let prelogin = connection - .prelogin(config.encryption, fed_auth_required, config.instance_name.clone(), false) + .prelogin( + config.encryption, + fed_auth_required, + config.instance_name.clone(), + false, + ) .await?; let encryption = prelogin.negotiated_encryption(config.encryption); @@ -139,10 +144,7 @@ impl Connection { /// Complete connection setup after TLS is established (for strict mode). /// In TDS 8, PRELOGIN and LOGIN both happen inside the TLS tunnel. - async fn finish_connect_after_tls( - mut connection: Self, - config: Config, - ) -> crate::Result { + async fn finish_connect_after_tls(mut connection: Self, config: Config) -> crate::Result { // For backend connections (routing reconnect), we still need to send // FEDAUTHREQUIRED if using AAD auth — the backend needs it to prepare // for processing the FEDAUTH token in LOGIN. @@ -155,7 +157,12 @@ impl Connection { let prelogin_encryption = config.encryption; let prelogin = connection - .prelogin(prelogin_encryption, fed_auth_required, config.instance_name.clone(), is_backend) + .prelogin( + prelogin_encryption, + fed_auth_required, + config.instance_name.clone(), + is_backend, + ) .await?; // Use login_server_name if set (for routed connections, this is the @@ -258,8 +265,7 @@ impl Connection { } // 3. Feed PRELOGIN packet(s) to the transport buffer (no flush yet) - let packet_size = - (connection.context.packet_size() as usize) - crate::tds::HEADER_BYTES; + let packet_size = (connection.context.packet_size() as usize) - crate::tds::HEADER_BYTES; let mut prelogin_payload = BytesMut::new(); prelogin_msg.encode(&mut prelogin_payload)?; @@ -295,9 +301,7 @@ impl Connection { } is_first_login_pkt = false; - connection - .feed_to_wire(login_header, split_payload) - .await?; + connection.feed_to_wire(login_header, split_payload).await?; } // 5. Single flush: send PRELOGIN+LOGIN together in one TLS write @@ -429,35 +433,9 @@ impl Connection { /// Feeds a packet to the transport buffer WITHOUT flushing. /// Use `flush_sink()` after feeding all packets to send them in one batch. - #[allow(dead_code)] - async fn feed_to_wire( - &mut self, - header: PacketHeader, - data: BytesMut, - ) -> crate::Result<()> { + async fn feed_to_wire(&mut self, header: PacketHeader, data: BytesMut) -> crate::Result<()> { self.flushed = false; - // Debug: dump the raw TDS frame (header + payload) to a file - if let Ok(mut f) = std::fs::OpenOptions::new() - .create(true) - .append(true) - .open("/tmp/opencode/tiberius_raw_tds.bin") - { - use std::io::Write as _; - // Encode the header + data as the PacketCodec would - let mut hdr_buf = BytesMut::with_capacity(8); - use bytes::BufMut as _; - let pkt_len = (data.len() + 8) as u16; - hdr_buf.put_u8(header.r#type() as u8); - hdr_buf.put_u8(header.status() as u8); - hdr_buf.put_u16(pkt_len); - hdr_buf.put_u16(0); // spid - hdr_buf.put_u8(0); // id placeholder (not accessible) - hdr_buf.put_u8(0); // window - let _ = f.write_all(&hdr_buf); - let _ = f.write_all(&data); - } - let packet = Packet::new(header, data); self.transport.feed(packet).await?; @@ -547,7 +525,7 @@ impl Connection { /// Defines the login record rules with SQL Server. Authentication with /// connection options. #[allow(clippy::too_many_arguments)] - async fn login<'a>( + async fn login( mut self, auth: AuthMethod, encryption: EncryptionLevel, diff --git a/src/client/tls_stream/native_tls_stream.rs b/src/client/tls_stream/native_tls_stream.rs index 34928036d..a5d5a81fe 100644 --- a/src/client/tls_stream/native_tls_stream.rs +++ b/src/client/tls_stream/native_tls_stream.rs @@ -25,12 +25,12 @@ pub(crate) async fn create_tls_stream( if let Ok(buf) = fs::read(path) { let cert = match path.extension() { Some(ext) - if ext.to_ascii_lowercase() == "pem" - || ext.to_ascii_lowercase() == "crt" => + if ext.eq_ignore_ascii_case("pem") + || ext.eq_ignore_ascii_case("crt") => { Some(native_tls_crate::Certificate::from_pem(&buf)?) } - Some(ext) if ext.to_ascii_lowercase() == "der" => { + Some(ext) if ext.eq_ignore_ascii_case("der") => { Some(native_tls_crate::Certificate::from_der(&buf)?) } Some(_) | None => { @@ -74,12 +74,12 @@ pub(crate) async fn create_tls_stream( if let Ok(buf) = fs::read(path) { let cert = match path.extension() { Some(ext) - if ext.to_ascii_lowercase() == "pem" - || ext.to_ascii_lowercase() == "crt" => + if ext.eq_ignore_ascii_case("pem") + || ext.eq_ignore_ascii_case("crt") => { Some(Certificate::from_pem(&buf)?) } - Some(ext) if ext.to_ascii_lowercase() == "der" => { + Some(ext) if ext.eq_ignore_ascii_case("der") => { Some(Certificate::from_der(&buf)?) } Some(_) | None => return Err(Error::Io { diff --git a/src/tds/codec/login.rs b/src/tds/codec/login.rs index f0113b25d..4fa4a82d3 100644 --- a/src/tds/codec/login.rs +++ b/src/tds/codec/login.rs @@ -33,8 +33,7 @@ impl FeatureLevel { pub fn done_row_count_bytes(self) -> u8 { // TDS 8.0 (0x08000000) is numerically lower than 7.x versions but is // functionally equivalent to SqlServerN for row count encoding. - if self == FeatureLevel::SqlServer2022 - || self as u32 >= FeatureLevel::SqlServer2005 as u32 + if self == FeatureLevel::SqlServer2022 || self as u32 >= FeatureLevel::SqlServer2005 as u32 { 8 } else { @@ -44,8 +43,7 @@ impl FeatureLevel { /// Returns true if this version uses modern (post-2005) wire formats. pub fn is_modern(self) -> bool { - self == FeatureLevel::SqlServer2022 - || self as u32 >= FeatureLevel::SqlServer2005 as u32 + self == FeatureLevel::SqlServer2022 || self as u32 >= FeatureLevel::SqlServer2005 as u32 } } @@ -323,9 +321,9 @@ impl<'a> Encode for LoginMessage<'a> { &self.password, &self.app_name, &self.server_name, - &"".into(), // 5. ibExtension + &"".into(), // 5. ibExtension &self.clt_int_name, // ibCltIntName - &"".into(), // ibLanguage + &"".into(), // ibLanguage &self.db_name, &"".into(), // 9. ClientId (6 bytes); this is included in var_data so we don't lack the bytes of cbSspiLong (4=2*2) and can insert it at the correct position &"".into(), // 10. ibSSPI diff --git a/src/tds/codec/pre_login.rs b/src/tds/codec/pre_login.rs index 4dafa254d..27bfc346f 100644 --- a/src/tds/codec/pre_login.rs +++ b/src/tds/codec/pre_login.rs @@ -115,8 +115,8 @@ impl Encode for PreloginMessage { // encryption fields.push((PRELOGIN_ENCRYPTION, 0x01)); // encryption - // In TDS 8 strict mode, the wire value must be ENCRYPT_STRICT (0x08) - // per MS-TDS spec. Other values map directly to their enum discriminant. + // In TDS 8 strict mode, the wire value must be ENCRYPT_STRICT (0x08) + // per MS-TDS spec. Other values map directly to their enum discriminant. let encryption_wire_value = match self.encryption { EncryptionLevel::Strict => 0x08u8, other => other as u8, @@ -228,9 +228,7 @@ impl Decode for PreloginMessage { tds::EncryptionLevel::Strict } else { tds::EncryptionLevel::try_from(encrypt).map_err(|_| { - Error::Protocol( - format!("invalid encryption value: {}", encrypt).into(), - ) + Error::Protocol(format!("invalid encryption value: {}", encrypt).into()) })? }; ret.encryption = level; diff --git a/src/tds/codec/token/token_feature_ext_ack.rs b/src/tds/codec/token/token_feature_ext_ack.rs index 12760ed80..de8d45a20 100644 --- a/src/tds/codec/token/token_feature_ext_ack.rs +++ b/src/tds/codec/token/token_feature_ext_ack.rs @@ -1,4 +1,7 @@ -use crate::{SqlReadBytes, FEA_EXT_AZURESQLSUPPORT, FEA_EXT_COLUMNENCRYPTION, FEA_EXT_FEDAUTH, FEA_EXT_TERMINATOR, FEA_EXT_UTF8_SUPPORT}; +use crate::{ + SqlReadBytes, FEA_EXT_AZURESQLSUPPORT, FEA_EXT_COLUMNENCRYPTION, FEA_EXT_FEDAUTH, + FEA_EXT_TERMINATOR, FEA_EXT_UTF8_SUPPORT, +}; use futures_util::AsyncReadExt; #[derive(Debug)] @@ -23,7 +26,10 @@ pub enum FeatureAck { /// UTF-8 Support acknowledgment. Utf8Support(Vec), /// Unknown feature — stored for forward-compatibility. - Unknown { feature_id: u8, data: Vec }, + Unknown { + feature_id: u8, + data: Vec, + }, } impl TokenFeatureExtAck { diff --git a/tests/azure_sql.rs b/tests/azure_sql.rs index 141147a9a..1db1d4038 100644 --- a/tests/azure_sql.rs +++ b/tests/azure_sql.rs @@ -116,10 +116,7 @@ async fn connect_to_azure_sql_regular( match Client::connect(config, tcp.compat_write()).await { Ok(client) => Ok(client), Err(Error::Routing { host, port }) => { - eprintln!( - "Routing redirect to {}:{}, reconnecting...", - host, port - ); + eprintln!("Routing redirect to {}:{}, reconnecting...", host, port); let backend_host = host.split('\\').next().unwrap_or(&host); @@ -196,7 +193,10 @@ async fn azure_sql_strict_server_metadata() -> anyhow::Result<()> { "Should be Azure SQL, got: {}", ver ); - assert_eq!(db_name, database, "Should connect to the requested database"); + assert_eq!( + db_name, database, + "Should connect to the requested database" + ); eprintln!("Version: {}", &ver[..ver.find('\n').unwrap_or(ver.len())]); eprintln!("Database: {}", db_name); eprintln!("Login: {}", login_name); @@ -288,9 +288,7 @@ async fn azure_sql_strict_ddl_dml() -> anyhow::Result<()> { // Create a temp table (must consume result before next command) client - .simple_query( - "CREATE TABLE #tds8_test (id INT, name NVARCHAR(50), value DECIMAL(10,2))", - ) + .simple_query("CREATE TABLE #tds8_test (id INT, name NVARCHAR(50), value DECIMAL(10,2))") .await? .into_results() .await?; @@ -401,10 +399,7 @@ async fn azure_sql_strict_verify_encryption() -> anyhow::Result<()> { ); assert_eq!(net_transport, "TCP"); // Azure SQL with AAD token uses NTML at transport but AAD at auth layer - assert!( - !auth_scheme.is_empty(), - "Should have an auth scheme" - ); + assert!(!auth_scheme.is_empty(), "Should have an auth scheme"); eprintln!( "Azure SQL encryption verified: option={}, scheme={}, transport={}", diff --git a/tests/fabric.rs b/tests/fabric.rs index 773b0a520..9ce47579d 100644 --- a/tests/fabric.rs +++ b/tests/fabric.rs @@ -53,8 +53,8 @@ async fn get_aad_token() -> anyhow::Result { .map_err(|_| anyhow::anyhow!("Neither FABRIC_AAD_TOKEN nor FABRIC_CLIENT_ID is set"))?; let client_secret = env::var("FABRIC_CLIENT_SECRET") .map_err(|_| anyhow::anyhow!("FABRIC_CLIENT_SECRET not set"))?; - let tenant_id = env::var("FABRIC_TENANT_ID") - .map_err(|_| anyhow::anyhow!("FABRIC_TENANT_ID not set"))?; + let tenant_id = + env::var("FABRIC_TENANT_ID").map_err(|_| anyhow::anyhow!("FABRIC_TENANT_ID not set"))?; use azure_identity::client_credentials_flow; use oauth2::{ClientId, ClientSecret}; @@ -353,10 +353,7 @@ async fn fabric_strict_ddl_dml() -> anyhow::Result<()> { // DELETE let rows_affected = client - .execute( - format!("DELETE FROM {table_name} WHERE id = @P1"), - &[&2i32], - ) + .execute(format!("DELETE FROM {table_name} WHERE id = @P1"), &[&2i32]) .await? .total(); assert_eq!(rows_affected, 1); @@ -414,21 +411,19 @@ async fn fabric_strict_large_result() -> anyhow::Result<()> { CAST(l_quantity AS DECIMAL(10,2)) AS l_quantity, \ CAST(l_extendedprice AS DECIMAL(12,2)) AS l_extendedprice, \ l_shipdate, l_comment \ - FROM lineitem ORDER BY l_orderkey, l_linenumber".to_string() + FROM lineitem ORDER BY l_orderkey, l_linenumber" + .to_string() } else { // Generate 1000 rows via cross join (no OPTION hints — Fabric DW disallows them) "SELECT TOP 1000 ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS row_num, \ REPLICATE('X', 200) AS padding \ FROM (VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10)) AS t1(n) \ CROSS JOIN (VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10)) AS t2(n) \ - CROSS JOIN (VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10)) AS t3(n)".to_string() + CROSS JOIN (VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10)) AS t3(n)" + .to_string() }; - let rows: Vec<_> = client - .query(query, &[]) - .await? - .into_first_result() - .await?; + let rows: Vec<_> = client.query(query, &[]).await?.into_first_result().await?; assert_eq!(rows.len(), 1000, "Should get exactly 1000 rows"); eprintln!( diff --git a/tests/sql_server_tds8.rs b/tests/sql_server_tds8.rs index b63d61788..901dbf44f 100644 --- a/tests/sql_server_tds8.rs +++ b/tests/sql_server_tds8.rs @@ -55,17 +55,14 @@ macro_rules! skip_if_no_sql_server { } /// Connect to SQL Server with TDS 8 strict encryption (TLS-first, no routing). -async fn connect_strict( -) -> anyhow::Result>> { +async fn connect_strict() -> anyhow::Result>> { let host = env::var("SQL_SERVER_HOST").unwrap_or_else(|_| "localhost".to_string()); let port = env::var("SQL_SERVER_PORT").unwrap_or_else(|_| "1434".to_string()); let user = env::var("SQL_SERVER_USER").unwrap_or_else(|_| "sa".to_string()); let password = env::var("SQL_SERVER_PASSWORD")?; let conn_str = if let Ok(ca_path) = env::var("SQL_SERVER_CA_CERT") { - format!( - "server=tcp:{host},{port};encrypt=strict;database=master;Certificate={ca_path}" - ) + format!("server=tcp:{host},{port};encrypt=strict;database=master;Certificate={ca_path}") } else { format!( "server=tcp:{host},{port};encrypt=strict;TrustServerCertificate=true;database=master" @@ -133,10 +130,7 @@ async fn sql_server_strict_server_metadata() -> anyhow::Result<()> { ); assert_eq!(db_name, "master"); assert_eq!(transport, "TCP"); - eprintln!( - "Version: {}", - &ver[..ver.find('\n').unwrap_or(ver.len())] - ); + eprintln!("Version: {}", &ver[..ver.find('\n').unwrap_or(ver.len())]); eprintln!("Database: {}, Login: {}", db_name, login_name); Ok(()) @@ -268,20 +262,14 @@ async fn sql_server_strict_ddl_dml() -> anyhow::Result<()> { // Delete let rows_affected = client - .execute( - "DELETE FROM #tds8_strict_test WHERE id > @P1", - &[&2i32], - ) + .execute("DELETE FROM #tds8_strict_test WHERE id > @P1", &[&2i32]) .await? .total(); assert_eq!(rows_affected, 2); // Verify final state let rows: Vec<_> = client - .query( - "SELECT id, value FROM #tds8_strict_test ORDER BY id", - &[], - ) + .query("SELECT id, value FROM #tds8_strict_test ORDER BY id", &[]) .await? .into_first_result() .await?; @@ -383,9 +371,8 @@ async fn sql_server_strict_ca_cert_validation() -> anyhow::Result<()> { let user = env::var("SQL_SERVER_USER").unwrap_or_else(|_| "sa".to_string()); let password = env::var("SQL_SERVER_PASSWORD")?; - let conn_str = format!( - "server=tcp:{host},{port};encrypt=strict;database=master;Certificate={ca_cert}" - ); + let conn_str = + format!("server=tcp:{host},{port};encrypt=strict;database=master;Certificate={ca_cert}"); let mut config = Config::from_ado_string(&conn_str)?; config.authentication(AuthMethod::sql_server(&user, &password)); From 4f845431375358b1711435edcbf76f7ad61d0921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 17:25:41 +0200 Subject: [PATCH 09/15] feat: add TokenProvider trait for dynamic AAD token refresh Add async TokenProvider trait enabling dynamic token acquisition on each connection/reconnection, replacing the need for static pre-obtained tokens in long-lived applications. Changes: - Define TokenProvider trait with async get_token() method (uses async-trait) - Add AuthMethod::AADTokenProvider(Arc) variant - Add AuthMethod::token_provider() constructor and is_aad() helper - Manual Clone/Debug/PartialEq/Eq impls for AuthMethod (Arc breaks derives) - Resolve token via provider in both pipelined and standard login paths - Export TokenProvider from crate root - Integration test verifying provider is called twice (gateway + backend routing) --- src/client/auth.rs | 150 +++++++++++++++++++++++++++++++++++++-- src/client/connection.rs | 32 ++++++++- src/lib.rs | 2 +- tests/fabric.rs | 101 ++++++++++++++++++++++++++ 4 files changed, 278 insertions(+), 7 deletions(-) diff --git a/src/client/auth.rs b/src/client/auth.rs index 208d8d060..62f14086c 100644 --- a/src/client/auth.rs +++ b/src/client/auth.rs @@ -1,4 +1,7 @@ use std::fmt::Debug; +use std::sync::Arc; + +use async_trait::async_trait; #[derive(Clone, PartialEq, Eq)] pub struct SqlServerAuth { @@ -46,8 +49,41 @@ impl Debug for WindowsAuth { } } +/// A trait for providing AAD/Entra ID tokens dynamically, supporting token refresh. +/// +/// Implement this trait to supply fresh tokens on each connection or reconnection. +/// This is useful for long-lived applications where tokens expire (~1 hour) and +/// need to be refreshed transparently. +/// +/// # Example +/// +/// ```rust,no_run +/// use async_trait::async_trait; +/// use tiberius::TokenProvider; +/// +/// struct MyTokenProvider { +/// // your credential state (e.g., client_id, client_secret, tenant_id) +/// } +/// +/// #[async_trait] +/// impl TokenProvider for MyTokenProvider { +/// async fn get_token(&self) -> Result> { +/// // Call your identity provider here to get a fresh token +/// // e.g., azure_identity::DefaultAzureCredential, MSAL, etc. +/// Ok("fresh-token".to_string()) +/// } +/// } +/// ``` +#[async_trait] +pub trait TokenProvider: Send + Sync { + /// Obtain a fresh AAD/Entra ID access token for the `https://database.windows.net/` resource. + /// + /// This method is called each time a new connection or reconnection is established. + /// Implementations should handle caching and refresh internally. + async fn get_token(&self) -> Result>; +} + /// Defines the method of authentication to the server. -#[derive(Clone, Debug, PartialEq, Eq)] pub enum AuthMethod { /// Authenticate directly with SQL Server. SqlServer(SqlServerAuth), @@ -67,13 +103,82 @@ pub enum AuthMethod { doc(cfg(any(windows, all(unix, feature = "integrated-auth-gssapi")))) )] Integrated, - /// Authenticate with an AAD token. The token should encode an AAD user/service principal - /// which has access to SQL Server. + /// Authenticate with a static AAD token. The token should encode an AAD + /// user/service principal which has access to SQL Server. + /// + /// For long-lived applications where tokens may expire, prefer + /// [`AuthMethod::token_provider`] instead. AADToken(String), + /// Authenticate with a dynamic AAD token provider that supports refresh. + /// + /// The provider's `get_token()` method is called each time a connection + /// (or routing reconnection) is established, ensuring fresh tokens. + AADTokenProvider(Arc), #[doc(hidden)] None, } +impl Clone for AuthMethod { + fn clone(&self) -> Self { + match self { + Self::SqlServer(a) => Self::SqlServer(a.clone()), + #[cfg(any(all(windows, feature = "winauth"), doc))] + Self::Windows(a) => Self::Windows(a.clone()), + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi"), + doc + ))] + Self::Integrated => Self::Integrated, + Self::AADToken(t) => Self::AADToken(t.clone()), + Self::AADTokenProvider(p) => Self::AADTokenProvider(Arc::clone(p)), + Self::None => Self::None, + } + } +} + +impl Debug for AuthMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::SqlServer(a) => f.debug_tuple("SqlServer").field(a).finish(), + #[cfg(any(all(windows, feature = "winauth"), doc))] + Self::Windows(a) => f.debug_tuple("Windows").field(a).finish(), + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi"), + doc + ))] + Self::Integrated => write!(f, "Integrated"), + Self::AADToken(_) => write!(f, "AADToken()"), + Self::AADTokenProvider(_) => write!(f, "AADTokenProvider(...)"), + Self::None => write!(f, "None"), + } + } +} + +impl PartialEq for AuthMethod { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::SqlServer(a), Self::SqlServer(b)) => a == b, + #[cfg(any(all(windows, feature = "winauth"), doc))] + (Self::Windows(a), Self::Windows(b)) => a == b, + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi"), + doc + ))] + (Self::Integrated, Self::Integrated) => true, + (Self::AADToken(a), Self::AADToken(b)) => a == b, + // Token providers are compared by Arc pointer identity + (Self::AADTokenProvider(a), Self::AADTokenProvider(b)) => Arc::ptr_eq(a, b), + (Self::None, Self::None) => true, + _ => false, + } + } +} + +impl Eq for AuthMethod {} + impl AuthMethod { /// Construct a new SQL Server authentication configuration. pub fn sql_server(user: impl ToString, password: impl ToString) -> Self { @@ -99,8 +204,45 @@ impl AuthMethod { }) } - /// Construct a new configuration with AAD auth token. + /// Construct a new configuration with a static AAD auth token. + /// + /// For long-lived applications, prefer [`AuthMethod::token_provider`] which + /// supports automatic token refresh. pub fn aad_token(token: impl ToString) -> Self { Self::AADToken(token.to_string()) } + + /// Construct a new configuration with a dynamic token provider. + /// + /// The provider's `get_token()` method is called on each new connection, + /// ensuring fresh tokens even for long-lived applications. + /// + /// # Example + /// + /// ```rust,no_run + /// use std::sync::Arc; + /// use async_trait::async_trait; + /// use tiberius::{AuthMethod, TokenProvider}; + /// + /// struct AzCliTokenProvider; + /// + /// #[async_trait] + /// impl TokenProvider for AzCliTokenProvider { + /// async fn get_token(&self) -> Result> { + /// // In real code, call azure_identity or az CLI + /// Ok("fresh-token".to_string()) + /// } + /// } + /// + /// let auth = AuthMethod::token_provider(Arc::new(AzCliTokenProvider)); + /// ``` + pub fn token_provider(provider: Arc) -> Self { + Self::AADTokenProvider(provider) + } + + /// Returns true if this auth method uses AAD token authentication + /// (either static or via provider). + pub(crate) fn is_aad(&self) -> bool { + matches!(self, Self::AADToken(_) | Self::AADTokenProvider(_)) + } } diff --git a/src/client/connection.rs b/src/client/connection.rs index 0b765f5d0..88a83231b 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -110,7 +110,7 @@ impl Connection { return Ok(connection); } - let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_)); + let fed_auth_required = config.auth.is_aad(); let prelogin = connection .prelogin( @@ -149,7 +149,7 @@ impl Connection { // FEDAUTHREQUIRED if using AAD auth — the backend needs it to prepare // for processing the FEDAUTH token in LOGIN. let is_backend = config.instance_name.is_some(); - let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_)); + let fed_auth_required = config.auth.is_aad(); // In TDS 8 strict mode, send ENCRYPT_STRICT (0x08) on the wire. // TLS is already established, and the PRELOGIN encryption field signals @@ -250,6 +250,17 @@ impl Connection { ); login_message.aad_token(token, true, None); } + AuthMethod::AADTokenProvider(provider) => { + let token = provider.get_token().await.map_err(|e| { + crate::Error::Protocol(format!("Token provider failed: {}", e).into()) + })?; + event!( + Level::INFO, + token_len = token.len(), + "Sending pipelined LOGIN with provider token (fed_auth_required=true, nonce=None)" + ); + login_message.aad_token(token, true, None); + } AuthMethod::SqlServer(auth) => { login_message.user_name(auth.user().to_string()); login_message.password(auth.password().to_string()); @@ -686,6 +697,23 @@ impl Connection { ); login_message.aad_token(token, prelogin.fed_auth_required, prelogin.nonce); + let id = self.context.next_packet_id(); + self.send(PacketHeader::login(id), login_message).await?; + self = self.post_login_encryption(encryption); + } + AuthMethod::AADTokenProvider(provider) => { + let token = provider.get_token().await.map_err(|e| { + crate::Error::Protocol(format!("Token provider failed: {}", e).into()) + })?; + event!( + Level::INFO, + fed_auth_echo = prelogin.fed_auth_required, + has_nonce = prelogin.nonce.is_some(), + token_len = token.len(), + "Sending LOGIN with provider token" + ); + login_message.aad_token(token, prelogin.fed_auth_required, prelogin.nonce); + let id = self.context.next_packet_id(); self.send(PacketHeader::login(id), login_message).await?; self = self.post_login_encryption(encryption); diff --git a/src/lib.rs b/src/lib.rs index 882f5ad36..c8d508848 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -269,7 +269,7 @@ mod tds; mod sql_browser; -pub use client::{AuthMethod, Client, Config}; +pub use client::{AuthMethod, Client, Config, TokenProvider}; pub(crate) use error::Error; pub use from_sql::{FromSql, FromSqlOwned}; pub use query::Query; diff --git a/tests/fabric.rs b/tests/fabric.rs index 9ce47579d..1ce199833 100644 --- a/tests/fabric.rs +++ b/tests/fabric.rs @@ -552,3 +552,104 @@ async fn fabric_strict_string_operations() -> anyhow::Result<()> { eprintln!("String/unicode operations over Fabric strict connection: OK"); Ok(()) } + +/// Test: Connect to Fabric using a TokenProvider (dynamic token refresh). +/// +/// This verifies the AADTokenProvider auth path end-to-end through the full +/// gateway → routing → pipelined backend reconnect flow. +#[tokio::test] +async fn fabric_strict_token_provider() -> anyhow::Result<()> { + skip_if_no_fabric!(); + + use async_trait::async_trait; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use tiberius::TokenProvider; + + /// A test token provider that wraps a static token and counts invocations. + struct CountingTokenProvider { + token: String, + call_count: AtomicUsize, + } + + #[async_trait] + impl TokenProvider for CountingTokenProvider { + async fn get_token(&self) -> Result> { + self.call_count.fetch_add(1, Ordering::SeqCst); + Ok(self.token.clone()) + } + } + + let endpoint = env::var("FABRIC_ENDPOINT")?; + let database = env::var("FABRIC_DATABASE")?; + let token = get_aad_token().await?; + + let provider = Arc::new(CountingTokenProvider { + token, + call_count: AtomicUsize::new(0), + }); + + // Connect using the provider-based auth + let mut config = fabric_config(&endpoint, &database)?; + config.authentication(AuthMethod::token_provider(provider.clone())); + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let client_result = Client::connect(config, tcp.compat_write()).await; + + let mut client = match client_result { + Ok(client) => client, + Err(Error::Routing { host, port }) => { + eprintln!( + "Provider test: routing redirect to {}:{}, reconnecting...", + host, port + ); + + let backend_host = host.split('\\').next().unwrap_or(&host); + let instance_name = host.split('\\').nth(1); + + let mut backend_config = Config::new(); + backend_config.host(backend_host); + backend_config.port(port); + backend_config.encryption(EncryptionLevel::Strict); + backend_config.authentication(AuthMethod::token_provider(provider.clone())); + backend_config.database(&*database); + if let Some(inst) = instance_name { + backend_config.instance_name(inst); + } + backend_config.login_server_name(&*endpoint); + + let tcp = TcpStream::connect(backend_config.get_addr()).await?; + tcp.set_nodelay(true)?; + + Client::connect(backend_config, tcp.compat_write()).await? + } + Err(e) => return Err(e.into()), + }; + + // Verify we can query + let row = client + .query("SELECT 42 AS provider_test", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(42i32), row.get("provider_test")); + + // The provider should have been called at least once (gateway), + // and typically twice (gateway + backend after routing) + let calls = provider.call_count.load(Ordering::SeqCst); + eprintln!( + "TokenProvider was called {} time(s) during connection", + calls + ); + assert!( + calls >= 1, + "TokenProvider should be called at least once, got {}", + calls + ); + + eprintln!("Fabric strict connection with TokenProvider: OK"); + Ok(()) +} From 63f226a31057780d1028218e069a0af2e632cd26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 17:37:26 +0200 Subject: [PATCH 10/15] feat: auto-detect TDS 8 strict encryption from hostname MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the caller does not explicitly set the encryption level, the client now automatically upgrades to Strict mode for known endpoints: - *.datawarehouse.fabric.microsoft.com (Fabric SQL endpoints) - *.pbidedicated.windows.net (Fabric backend servers) This means users connecting to Fabric no longer need to specify encrypt=strict in connection strings — the client detects it from the hostname. Explicit encryption settings are always respected. Implementation: - Add encryption_explicit flag to Config (tracks user intent) - Add has_encrypt_key() to ConfigString trait (detects conn string key) - Add resolve_encryption() method with hostname heuristic - Call resolve_encryption() at start of Connection::connect() - Emit tracing::info event when auto-detection fires - 9 unit tests for the detection heuristic - Integration test verifying end-to-end auto-detect with Fabric --- src/client/config.rs | 150 ++++++++++++++++++++++++++++++++++++++- src/client/connection.rs | 13 +++- tests/fabric.rs | 69 ++++++++++++++++++ 3 files changed, 230 insertions(+), 2 deletions(-) diff --git a/src/client/config.rs b/src/client/config.rs index 1beca39a6..598bbf879 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -41,6 +41,10 @@ pub struct Config { /// the original gateway hostname so the backend knows which endpoint /// the client intended to reach. pub(crate) login_server_name: Option, + /// Tracks whether encryption was explicitly set by the user (via + /// `encryption()` method or connection string `encrypt=` key). When false, + /// the connection layer may auto-detect strict mode from the hostname. + pub(crate) encryption_explicit: bool, } #[derive(Clone, Debug)] @@ -76,6 +80,7 @@ impl Default for Config { readonly: false, strict_pipelined: false, login_server_name: None, + encryption_explicit: false, } } } @@ -132,8 +137,12 @@ impl Config { /// - Without `tls` feature, defaults to `NotSupported`. /// - Use `Strict` for TDS 8 strict transport encryption (required for /// Microsoft Fabric endpoints and SQL Server 2022+ strict mode). + /// + /// When not explicitly set, the client will auto-detect strict mode for + /// known endpoints (e.g., `*.datawarehouse.fabric.microsoft.com`). pub fn encryption(&mut self, encryption: EncryptionLevel) { self.encryption = encryption; + self.encryption_explicit = true; } /// If set, the server certificate will not be validated and it is accepted @@ -213,6 +222,34 @@ impl Config { self.login_server_name = Some(name.into()); } + /// Resolves the effective encryption level, applying auto-detection when + /// the user did not explicitly configure encryption. + /// + /// Currently detects: + /// - `*.datawarehouse.fabric.microsoft.com` → `EncryptionLevel::Strict` + /// + /// Returns the (possibly upgraded) encryption level. + pub(crate) fn resolve_encryption(&mut self) -> EncryptionLevel { + if self.encryption_explicit { + return self.encryption; + } + + if let Some(host) = &self.host { + if Self::host_requires_strict(host) { + self.encryption = EncryptionLevel::Strict; + } + } + + self.encryption + } + + /// Returns true if the given hostname is known to require TDS 8 strict mode. + fn host_requires_strict(host: &str) -> bool { + let host_lower = host.to_ascii_lowercase(); + host_lower.ends_with(".datawarehouse.fabric.microsoft.com") + || host_lower.ends_with(".pbidedicated.windows.net") + } + pub(crate) fn get_host(&self) -> &str { self.host .as_deref() @@ -307,7 +344,9 @@ impl Config { builder.trust_cert_ca(ca); } - builder.encryption(s.encrypt()?); + if s.has_encrypt_key() { + builder.encryption(s.encrypt()?); + } builder.readonly(s.readonly()); @@ -326,6 +365,12 @@ pub(crate) trait ConfigString { fn server(&self) -> crate::Result; + /// Returns true if the `encrypt` key was explicitly present in the + /// connection string. + fn has_encrypt_key(&self) -> bool { + self.dict().contains_key("encrypt") + } + fn authentication(&self) -> crate::Result { let user = self .dict() @@ -434,3 +479,106 @@ pub(crate) trait ConfigString { .is_some() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn auto_detects_strict_for_fabric_host() { + let mut config = Config::new(); + config.host("myworkspace-abc123.datawarehouse.fabric.microsoft.com"); + // encryption_explicit is false (default) + assert!(!config.encryption_explicit); + + let resolved = config.resolve_encryption(); + assert_eq!(resolved, EncryptionLevel::Strict); + assert_eq!(config.encryption, EncryptionLevel::Strict); + } + + #[test] + fn auto_detects_strict_for_pbidedicated_host() { + let mut config = Config::new(); + config.host("abc123.pbidedicated.windows.net"); + + let resolved = config.resolve_encryption(); + assert_eq!(resolved, EncryptionLevel::Strict); + } + + #[test] + fn no_auto_detect_for_regular_sql_server() { + let mut config = Config::new(); + config.host("myserver.database.windows.net"); + + let resolved = config.resolve_encryption(); + // Should stay at the default (Required with TLS features) + assert_eq!(resolved, EncryptionLevel::Required); + } + + #[test] + fn no_auto_detect_for_localhost() { + let mut config = Config::new(); + config.host("localhost"); + + let resolved = config.resolve_encryption(); + assert_eq!(resolved, EncryptionLevel::Required); + } + + #[test] + fn explicit_encryption_not_overridden() { + let mut config = Config::new(); + config.host("myworkspace-abc123.datawarehouse.fabric.microsoft.com"); + // User explicitly sets Required (not Strict) + config.encryption(EncryptionLevel::Required); + assert!(config.encryption_explicit); + + let resolved = config.resolve_encryption(); + // Should respect the explicit setting, not auto-upgrade + assert_eq!(resolved, EncryptionLevel::Required); + } + + #[test] + fn explicit_strict_stays_strict() { + let mut config = Config::new(); + config.host("custom-server.example.com"); + config.encryption(EncryptionLevel::Strict); + + let resolved = config.resolve_encryption(); + assert_eq!(resolved, EncryptionLevel::Strict); + } + + #[test] + fn connection_string_without_encrypt_auto_detects_fabric() { + // ADO string with no encrypt= key, but a Fabric host + let config = Config::from_ado_string( + "server=tcp:myworkspace.datawarehouse.fabric.microsoft.com,1433;database=mydb", + ) + .unwrap(); + // encrypt not specified → encryption_explicit should be false + assert!(!config.encryption_explicit); + } + + #[test] + fn connection_string_with_encrypt_marks_explicit() { + let config = Config::from_ado_string( + "server=tcp:myworkspace.datawarehouse.fabric.microsoft.com,1433;encrypt=true;database=mydb", + ) + .unwrap(); + assert!(config.encryption_explicit); + assert_eq!(config.encryption, EncryptionLevel::Required); + } + + #[test] + fn host_requires_strict_case_insensitive() { + assert!(Config::host_requires_strict( + "MyWorkspace.DATAWAREHOUSE.FABRIC.MICROSOFT.COM" + )); + assert!(Config::host_requires_strict( + "ABC.Datawarehouse.Fabric.Microsoft.Com" + )); + assert!(!Config::host_requires_strict("fabric.microsoft.com")); + assert!(!Config::host_requires_strict( + "something.database.windows.net" + )); + } +} diff --git a/src/client/connection.rs b/src/client/connection.rs index 88a83231b..f77f92a3b 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -72,7 +72,7 @@ impl Debug for Connection { impl Connection { /// Creates a new connection - pub(crate) async fn connect(config: Config, tcp_stream: S) -> crate::Result> { + pub(crate) async fn connect(mut config: Config, tcp_stream: S) -> crate::Result> { let context = { let mut context = Context::new(); context.set_spn(config.get_host(), config.get_port()); @@ -88,6 +88,17 @@ impl Connection { buf: BytesMut::new(), }; + // Auto-detect strict mode from hostname when encryption wasn't + // explicitly configured by the caller. + let resolved = config.resolve_encryption(); + if resolved == EncryptionLevel::Strict && !config.encryption_explicit { + event!( + Level::INFO, + host = config.get_host(), + "Auto-detected TDS 8 strict encryption from hostname" + ); + } + // TDS 8 strict mode: TLS handshake first, then PRELOGIN inside TLS. if config.encryption == EncryptionLevel::Strict { let connection = connection.tls_handshake_strict(&config).await?; diff --git a/tests/fabric.rs b/tests/fabric.rs index 1ce199833..b05cd6025 100644 --- a/tests/fabric.rs +++ b/tests/fabric.rs @@ -653,3 +653,72 @@ async fn fabric_strict_token_provider() -> anyhow::Result<()> { eprintln!("Fabric strict connection with TokenProvider: OK"); Ok(()) } + +/// Test: Connect to Fabric WITHOUT explicitly setting `encrypt=strict`. +/// +/// Verifies that the auto-detection heuristic recognizes the Fabric hostname +/// and automatically upgrades to TDS 8 strict encryption. +#[tokio::test] +async fn fabric_auto_detect_strict_from_hostname() -> anyhow::Result<()> { + skip_if_no_fabric!(); + + let endpoint = env::var("FABRIC_ENDPOINT")?; + let database = env::var("FABRIC_DATABASE")?; + let token = get_aad_token().await?; + + // Build config WITHOUT encrypt=strict — rely on auto-detection + let mut config = Config::new(); + config.host(&endpoint); + config.port(1433); + config.database(&database); + config.authentication(AuthMethod::aad_token(&token)); + // Note: NOT calling config.encryption(EncryptionLevel::Strict) + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let client_result = Client::connect(config, tcp.compat_write()).await; + + let mut client = match client_result { + Ok(client) => client, + Err(Error::Routing { host, port }) => { + eprintln!( + "Auto-detect test: routing redirect to {}:{}, reconnecting...", + host, port + ); + + let backend_host = host.split('\\').next().unwrap_or(&host); + let instance_name = host.split('\\').nth(1); + + // Backend config also without explicit encrypt=strict + let mut backend_config = Config::new(); + backend_config.host(backend_host); + backend_config.port(port); + // Auto-detection will recognize .pbidedicated.windows.net too + backend_config.authentication(AuthMethod::aad_token(&token)); + backend_config.database(&database); + if let Some(inst) = instance_name { + backend_config.instance_name(inst); + } + backend_config.login_server_name(&endpoint); + + let tcp = TcpStream::connect(backend_config.get_addr()).await?; + tcp.set_nodelay(true)?; + + Client::connect(backend_config, tcp.compat_write()).await? + } + Err(e) => return Err(e.into()), + }; + + // Verify we can query + let row = client + .query("SELECT 100 AS auto_detect_test", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(100i32), row.get("auto_detect_test")); + + eprintln!("Fabric auto-detect strict encryption from hostname: OK"); + Ok(()) +} From 33e1818269d9733703f90d24bdfdb3c5e2a53897 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 17:48:55 +0200 Subject: [PATCH 11/15] Improve error messages when TDS 8 strict TLS handshake fails Wrap Tls and Io errors from the strict handshake with contextual guidance: the target hostname, server requirements (SQL Server 2025+ with forcestrict=1, Azure SQL, or Fabric), and the suggested fix (encrypt=true instead of encrypt=strict). Add 3 unit tests for the wrapping logic and 1 integration test that verifies the message against a non-strict TCP endpoint. --- src/client/connection.rs | 91 +++++++++++++++++++++++++++++++++++++++- tests/sql_server_tds8.rs | 55 ++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 1 deletion(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index f77f92a3b..d2c1f4306 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -101,7 +101,12 @@ impl Connection { // TDS 8 strict mode: TLS handshake first, then PRELOGIN inside TLS. if config.encryption == EncryptionLevel::Strict { - let connection = connection.tls_handshake_strict(&config).await?; + let connection = match connection.tls_handshake_strict(&config).await { + Ok(c) => c, + Err(e) => { + return Err(Self::wrap_strict_tls_error(e, config.get_host())); + } + }; if config.strict_pipelined { // Backend reconnection after routing: pipeline PRELOGIN+LOGIN @@ -849,6 +854,33 @@ impl Connection { Ok(self) } + /// Wraps a TLS handshake error with context about strict mode requirements. + /// + /// When TDS 8 strict TLS fails, the raw error (connection reset, handshake + /// failure) is often confusing. This adds actionable guidance. + fn wrap_strict_tls_error(err: crate::Error, host: &str) -> crate::Error { + match &err { + crate::Error::Tls(msg) => crate::Error::Tls(format!( + "TDS 8 strict TLS handshake with '{}' failed: {}. \ + The server may not support TDS 8 strict mode \ + (requires SQL Server 2025+ with forcestrict=1, Azure SQL, or Microsoft Fabric). \ + For servers that don't support strict mode, use encrypt=true instead of encrypt=strict.", + host, msg + )), + crate::Error::Io { kind, message } => crate::Error::Io { + kind: *kind, + message: format!( + "TDS 8 strict TLS handshake with '{}' failed: {}. \ + The server may not support TDS 8 strict mode \ + (requires SQL Server 2025+ with forcestrict=1, Azure SQL, or Microsoft Fabric). \ + For servers that don't support strict mode, use encrypt=true instead of encrypt=strict.", + host, message + ), + }, + _ => err, + } + } + pub(crate) async fn close(mut self) -> crate::Result<()> { self.transport.close().await } @@ -931,3 +963,60 @@ impl SqlReadBytes for Connection { &mut self.context } } + +#[cfg(test)] +mod tests { + use super::*; + use std::io::ErrorKind; + + #[test] + fn wrap_strict_tls_error_wraps_tls_error() { + let original = crate::Error::Tls("handshake failure".to_string()); + let wrapped = + Connection::>>::wrap_strict_tls_error(original, "myserver.example.com"); + + let msg = wrapped.to_string(); + assert!(msg.contains("myserver.example.com"), "should contain host"); + assert!(msg.contains("handshake failure"), "should contain original error"); + assert!(msg.contains("strict"), "should mention strict mode"); + assert!( + msg.contains("encrypt=true"), + "should suggest alternative: {}", + msg + ); + assert!( + msg.contains("SQL Server 2025"), + "should mention version requirement: {}", + msg + ); + } + + #[test] + fn wrap_strict_tls_error_wraps_io_error() { + let original = crate::Error::Io { + kind: ErrorKind::ConnectionReset, + message: "connection reset by peer".to_string(), + }; + let wrapped = + Connection::>>::wrap_strict_tls_error(original, "10.0.0.1"); + + let msg = wrapped.to_string(); + assert!(msg.contains("10.0.0.1"), "should contain host"); + assert!( + msg.contains("connection reset by peer"), + "should contain original: {}", + msg + ); + assert!(msg.contains("strict"), "should mention strict mode"); + } + + #[test] + fn wrap_strict_tls_error_passes_through_other_errors() { + let original = crate::Error::Protocol("something else".into()); + let wrapped = + Connection::>>::wrap_strict_tls_error(original.clone(), "host"); + + // Non-TLS/IO errors should pass through unchanged + assert_eq!(wrapped, original); + } +} diff --git a/tests/sql_server_tds8.rs b/tests/sql_server_tds8.rs index 901dbf44f..b728ef000 100644 --- a/tests/sql_server_tds8.rs +++ b/tests/sql_server_tds8.rs @@ -393,3 +393,58 @@ async fn sql_server_strict_ca_cert_validation() -> anyhow::Result<()> { eprintln!("CA cert validation with strict mode: OK"); Ok(()) } + +/// Test: Verify improved error message when strict TLS handshake fails. +/// +/// This spins up a local TCP listener that immediately closes the connection, +/// simulating a server that doesn't support TDS 8 strict mode. The error +/// message should contain actionable guidance. +#[tokio::test] +async fn strict_error_message_on_non_strict_server() -> anyhow::Result<()> { + use tokio::net::TcpListener; + + // Start a local TCP listener that immediately drops incoming connections + let listener = TcpListener::bind("127.0.0.1:0").await?; + let port = listener.local_addr()?.port(); + + tokio::spawn(async move { + // Accept one connection and immediately drop it (simulates non-TLS server) + if let Ok((conn, _)) = listener.accept().await { + drop(conn); + } + }); + + let mut config = Config::new(); + config.host("my-server.example.com"); + config.port(port); + config.encryption(tiberius::EncryptionLevel::Strict); + config.authentication(AuthMethod::sql_server("sa", "dummy")); + config.trust_cert(); + + let tcp = TcpStream::connect(format!("127.0.0.1:{}", port)).await?; + tcp.set_nodelay(true)?; + + let result = Client::connect(config, tcp.compat_write()).await; + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + + // The error should mention strict mode and provide guidance + assert!( + err_msg.contains("strict") || err_msg.contains("TDS 8"), + "Error message should mention strict mode, got: {}", + err_msg + ); + assert!( + err_msg.contains("my-server.example.com"), + "Error message should contain the hostname, got: {}", + err_msg + ); + assert!( + err_msg.contains("encrypt=true"), + "Error message should suggest alternative, got: {}", + err_msg + ); + + eprintln!("Strict TLS error message: {}", err_msg); + Ok(()) +} From ceef01d5d59e7f27515b082291f874ee0635cec9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 18:01:46 +0200 Subject: [PATCH 12/15] Add connection pooling helpers: connection_encryption() and is_healthy() - Client::connection_encryption() returns the negotiated EncryptionLevel, allowing pool implementations to verify security properties of managed connections (e.g., assert all pooled connections use Strict). - Client::is_healthy() executes a lightweight SELECT 1 round-trip, suitable for pool is_valid / health-check hooks. - Store EncryptionLevel in Connection struct, set to the final negotiated value (Strict for TDS 8, or the PRELOGIN-negotiated level otherwise). - Integration tests verify both methods on live strict connections. --- src/client.rs | 57 +++++++++++++++++++++++++++++++++++++++- src/client/connection.rs | 13 ++++++++- tests/sql_server_tds8.rs | 43 +++++++++++++++++++++++++++++- 3 files changed, 110 insertions(+), 3 deletions(-) diff --git a/src/client.rs b/src/client.rs index 688721d10..a0e7e52a1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -21,7 +21,7 @@ use crate::{ codec::{self, IteratorJoin}, stream::{QueryStream, TokenStream}, }, - BulkLoadRequest, ColumnFlag, SqlReadBytes, ToSql, + BulkLoadRequest, ColumnFlag, EncryptionLevel, SqlReadBytes, ToSql, }; use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest}; use enumflags2::BitFlags; @@ -352,6 +352,61 @@ impl Client { self.connection.close().await } + /// Returns the encryption level negotiated for this connection. + /// + /// This is useful for connection pool implementations that need to verify + /// the connection's security properties, or for logging/diagnostics. + /// + /// # Example + /// + /// ```no_run + /// # use tiberius::{Client, Config, EncryptionLevel}; + /// # use tokio_util::compat::TokioAsyncWriteCompatExt; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let config = Config::new(); + /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; + /// # let client = Client::connect(config, tcp.compat_write()).await?; + /// match client.connection_encryption() { + /// EncryptionLevel::Strict => println!("TDS 8 strict mode"), + /// EncryptionLevel::Required => println!("Encrypted (TLS upgrade)"), + /// _ => println!("Other encryption level"), + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn connection_encryption(&self) -> EncryptionLevel { + self.connection.encryption + } + + /// Validates the connection by executing a lightweight query (`SELECT 1`). + /// + /// Returns `Ok(())` if the server responds successfully, or an error if + /// the connection is broken. This is intended for use by connection pool + /// `is_valid` / health-check hooks. + /// + /// # Example + /// + /// ```no_run + /// # use tiberius::{Client, Config}; + /// # use tokio_util::compat::TokioAsyncWriteCompatExt; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let config = Config::new(); + /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; + /// # let mut client = Client::connect(config, tcp.compat_write()).await?; + /// if client.is_healthy().await.is_err() { + /// // Connection is broken, discard and create a new one + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn is_healthy(&mut self) -> crate::Result<()> { + let stream = self.simple_query("SELECT 1").await?; + stream.into_results().await?; + Ok(()) + } + pub(crate) fn rpc_params<'a>(query: impl Into>) -> Vec> { vec![ RpcParam { diff --git a/src/client/connection.rs b/src/client/connection.rs index d2c1f4306..8114f4395 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -57,6 +57,8 @@ where flushed: bool, context: Context, buf: BytesMut, + /// The encryption level used for this connection. + pub(crate) encryption: EncryptionLevel, } impl Debug for Connection { @@ -86,6 +88,7 @@ impl Connection { context, flushed: false, buf: BytesMut::new(), + encryption: EncryptionLevel::Off, }; // Auto-detect strict mode from hostname when encryption wasn't @@ -99,6 +102,9 @@ impl Connection { ); } + // Record the resolved encryption level for later inspection. + connection.encryption = config.encryption; + // TDS 8 strict mode: TLS handshake first, then PRELOGIN inside TLS. if config.encryption == EncryptionLevel::Strict { let connection = match connection.tls_handshake_strict(&config).await { @@ -139,7 +145,10 @@ impl Connection { let encryption = prelogin.negotiated_encryption(config.encryption); - let connection = connection.tls_handshake(&config, encryption).await?; + let mut connection = connection.tls_handshake(&config, encryption).await?; + + // Update to the actual negotiated encryption level. + connection.encryption = encryption; let mut connection = connection .login( @@ -773,6 +782,7 @@ impl Connection { context, flushed: false, buf: BytesMut::new(), + encryption: self.encryption, }) } else { event!( @@ -824,6 +834,7 @@ impl Connection { context, flushed: false, buf: BytesMut::new(), + encryption: self.encryption, }) } diff --git a/tests/sql_server_tds8.rs b/tests/sql_server_tds8.rs index b728ef000..29ba9e220 100644 --- a/tests/sql_server_tds8.rs +++ b/tests/sql_server_tds8.rs @@ -38,7 +38,7 @@ //! ``` use std::env; -use tiberius::{AuthMethod, Client, Config}; +use tiberius::{AuthMethod, Client, Config, EncryptionLevel}; use tokio::net::TcpStream; use tokio_util::compat::TokioAsyncWriteCompatExt; @@ -448,3 +448,44 @@ async fn strict_error_message_on_non_strict_server() -> anyhow::Result<()> { eprintln!("Strict TLS error message: {}", err_msg); Ok(()) } + +/// Test: connection_encryption() returns Strict for TDS 8 strict connections. +#[tokio::test] +async fn connection_encryption_reports_strict() -> anyhow::Result<()> { + skip_if_no_sql_server!(); + + let client = connect_strict().await?; + + assert_eq!( + client.connection_encryption(), + EncryptionLevel::Strict, + "SQL Server with forcestrict=1 should report Strict encryption" + ); + + eprintln!("connection_encryption() = {:?}", client.connection_encryption()); + Ok(()) +} + +/// Test: is_healthy() succeeds on a live strict connection. +#[tokio::test] +async fn is_healthy_on_strict_connection() -> anyhow::Result<()> { + skip_if_no_sql_server!(); + + let mut client = connect_strict().await?; + + // First health check + client.is_healthy().await?; + + // Run a real query in between + let _ = client + .query("SELECT @@VERSION", &[]) + .await? + .into_row() + .await?; + + // Second health check — still healthy after use + client.is_healthy().await?; + + eprintln!("is_healthy() passed on strict connection"); + Ok(()) +} From afef3b9519e99554633886a2598e9e11c56e0626 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 18:08:44 +0200 Subject: [PATCH 13/15] Add integration tests for Fabric SQL Database (*.database.fabric.microsoft.com) Verify tiberius connects to SQL Database in Microsoft Fabric, which uses the same gateway/backend architecture as Azure SQL: - Gateway: supports TDS 8 strict TLS on *.database.fabric.microsoft.com - Backend: routes to *.worker.database.windows.net with regular TLS upgrade Key finding: Fabric SQL DB does NOT require strict mode (unlike Data Warehouse). Both encrypt=strict and encrypt=true work. No auto-detect upgrade needed for this endpoint pattern. 7 integration tests: strict, required, metadata, DDL/DML, health check, connection_encryption, and works-without-strict verification. --- tests/fabric_sqldb.rs | 370 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 370 insertions(+) create mode 100644 tests/fabric_sqldb.rs diff --git a/tests/fabric_sqldb.rs b/tests/fabric_sqldb.rs new file mode 100644 index 000000000..24b8291ba --- /dev/null +++ b/tests/fabric_sqldb.rs @@ -0,0 +1,370 @@ +//! Integration tests for Fabric SQL Database connectivity. +//! +//! These tests verify that tiberius can connect to a SQL Database in +//! Microsoft Fabric, which uses the `*.database.fabric.microsoft.com` +//! endpoint pattern. +//! +//! # Required environment variables +//! +//! - `FABRIC_SQLDB_ENDPOINT`: The Fabric SQL DB endpoint +//! (e.g., `xxx.database.fabric.microsoft.com`) +//! - `FABRIC_SQLDB_DATABASE`: The database name +//! - `FABRIC_SQLDB_TOKEN`: A pre-obtained AAD/Entra ID token for +//! `https://database.windows.net/` +//! +//! # Running +//! +//! ```sh +//! export FABRIC_SQLDB_ENDPOINT=xxx.database.fabric.microsoft.com +//! export FABRIC_SQLDB_DATABASE=my-db-name +//! export FABRIC_SQLDB_TOKEN=$(az account get-access-token --resource https://database.windows.net/ --query accessToken -o tsv) +//! cargo test --test fabric_sqldb -- --nocapture +//! ``` + +use std::env; +use tiberius::{error::Error, AuthMethod, Client, Config, EncryptionLevel}; +use tokio::net::TcpStream; +use tokio_util::compat::TokioAsyncWriteCompatExt; + +macro_rules! skip_if_no_fabric_sqldb { + () => { + if env::var("FABRIC_SQLDB_ENDPOINT").is_err() { + eprintln!( + "SKIPPED: FABRIC_SQLDB_ENDPOINT not set. Set Fabric SQL DB env vars to run this test." + ); + return Ok(()); + } + }; +} + +fn get_endpoint() -> String { + env::var("FABRIC_SQLDB_ENDPOINT").unwrap() +} + +fn get_database() -> String { + env::var("FABRIC_SQLDB_DATABASE").unwrap() +} + +fn get_token() -> String { + env::var("FABRIC_SQLDB_TOKEN").unwrap() +} + +/// Connect to Fabric SQL Database with encrypt=strict (TDS 8). +/// +/// Fabric SQL Database uses the same gateway/backend architecture as Azure SQL: +/// - Gateway: supports TDS 8 strict TLS +/// - Backend (after routing): uses regular TLS upgrade (Required) +/// +/// The routing target is `*.worker.database.windows.net` — Azure SQL backend +/// infrastructure — which does NOT support strict mode. +async fn connect_fabric_sqldb_strict( +) -> anyhow::Result>> { + let endpoint = get_endpoint(); + let database = get_database(); + let token = get_token(); + + let conn_str = format!( + "server=tcp:{endpoint},1433;encrypt=strict;TrustServerCertificate=false;database={database}" + ); + let mut config = Config::from_ado_string(&conn_str)?; + config.authentication(AuthMethod::aad_token(&token)); + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(Error::Routing { host, port }) => { + eprintln!("Routing redirect to {}:{}", host, port); + + let backend_host = host.split('\\').next().unwrap_or(&host); + + // Fabric SQL DB backend is Azure SQL infrastructure + // (*.worker.database.windows.net) which uses regular TLS upgrade, + // NOT strict mode — same pattern as Azure SQL Database. + let mut backend_config = Config::new(); + backend_config.host(backend_host); + backend_config.port(port); + backend_config.encryption(EncryptionLevel::Required); + backend_config.authentication(AuthMethod::aad_token(&token)); + backend_config.database(&database); + backend_config.login_server_name(&endpoint); + + let tcp = TcpStream::connect(backend_config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let client = Client::connect(backend_config, tcp.compat_write()).await?; + Ok(client) + } + Err(e) => Err(e.into()), + } +} + +/// Connect to Fabric SQL Database with encrypt=true (regular TLS upgrade). +async fn connect_fabric_sqldb_required( +) -> anyhow::Result>> { + let endpoint = get_endpoint(); + let database = get_database(); + let token = get_token(); + + let conn_str = format!( + "server=tcp:{endpoint},1433;encrypt=true;TrustServerCertificate=false;database={database}" + ); + let mut config = Config::from_ado_string(&conn_str)?; + config.authentication(AuthMethod::aad_token(&token)); + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(Error::Routing { host, port }) => { + eprintln!("Routing redirect (required mode) to {}:{}", host, port); + + let backend_host = host.split('\\').next().unwrap_or(&host); + + let mut backend_config = Config::new(); + backend_config.host(backend_host); + backend_config.port(port); + backend_config.encryption(EncryptionLevel::Required); + backend_config.authentication(AuthMethod::aad_token(&token)); + backend_config.database(&database); + backend_config.login_server_name(&endpoint); + + let tcp = TcpStream::connect(backend_config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let client = Client::connect(backend_config, tcp.compat_write()).await?; + Ok(client) + } + Err(e) => Err(e.into()), + } +} + +/// Test: Connect with encrypt=strict (TDS 8) and run a basic query. +#[tokio::test] +async fn fabric_sqldb_strict_basic_query() -> anyhow::Result<()> { + skip_if_no_fabric_sqldb!(); + + let mut client = connect_fabric_sqldb_strict().await?; + + let row = client + .query("SELECT 1 AS test_value", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(1i32), row.get("test_value")); + + eprintln!("Fabric SQL DB strict basic query: OK"); + Ok(()) +} + +/// Test: Connect with encrypt=true (regular TLS) and run a basic query. +#[tokio::test] +async fn fabric_sqldb_required_basic_query() -> anyhow::Result<()> { + skip_if_no_fabric_sqldb!(); + + let mut client = connect_fabric_sqldb_required().await?; + + let row = client + .query("SELECT 1 AS test_value", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(1i32), row.get("test_value")); + + eprintln!("Fabric SQL DB required (TLS upgrade) basic query: OK"); + Ok(()) +} + +/// Test: Server metadata on Fabric SQL Database. +#[tokio::test] +async fn fabric_sqldb_server_metadata() -> anyhow::Result<()> { + skip_if_no_fabric_sqldb!(); + + let mut client = connect_fabric_sqldb_strict().await?; + + let row = client + .query( + "SELECT @@VERSION AS version, DB_NAME() AS db_name, SUSER_SNAME() AS login_name", + &[], + ) + .await? + .into_row() + .await? + .unwrap(); + + let version: &str = row.get("version").unwrap(); + let db_name: &str = row.get("db_name").unwrap(); + let login_name: &str = row.get("login_name").unwrap(); + + eprintln!("Version: {}", version.lines().next().unwrap_or(version)); + eprintln!("Database: {}, Login: {}", db_name, login_name); + + // Fabric SQL DB should report as SQL Server + assert!( + version.contains("SQL Server") || version.contains("Azure"), + "Unexpected version: {}", + version + ); + + Ok(()) +} + +/// Test: DDL and DML operations on Fabric SQL Database. +#[tokio::test] +async fn fabric_sqldb_ddl_dml() -> anyhow::Result<()> { + skip_if_no_fabric_sqldb!(); + + let mut client = connect_fabric_sqldb_strict().await?; + + // Create table + client + .execute( + "IF OBJECT_ID('dbo.tiberius_test', 'U') IS NOT NULL DROP TABLE dbo.tiberius_test", + &[], + ) + .await?; + client + .execute( + "CREATE TABLE dbo.tiberius_test (id INT PRIMARY KEY, name NVARCHAR(100), value FLOAT)", + &[], + ) + .await?; + + // Insert + client + .execute( + "INSERT INTO dbo.tiberius_test (id, name, value) VALUES (1, N'hello', 3.14)", + &[], + ) + .await?; + client + .execute( + "INSERT INTO dbo.tiberius_test (id, name, value) VALUES (2, N'world', 2.71)", + &[], + ) + .await?; + + // Query + let rows = client + .query("SELECT id, name, value FROM dbo.tiberius_test ORDER BY id", &[]) + .await? + .into_first_result() + .await?; + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::<&str, _>("name"), Some("hello")); + assert_eq!(rows[1].get::<&str, _>("name"), Some("world")); + + // Cleanup + client + .execute("DROP TABLE dbo.tiberius_test", &[]) + .await?; + + eprintln!("Fabric SQL DB DDL/DML: OK"); + Ok(()) +} + +/// Test: connection_encryption() reports correct level. +/// +/// When using strict on the gateway but Required on the backend (after routing), +/// the final connection reports the backend's encryption level. +#[tokio::test] +async fn fabric_sqldb_connection_encryption() -> anyhow::Result<()> { + skip_if_no_fabric_sqldb!(); + + let client = connect_fabric_sqldb_strict().await?; + + // The final connection is to the backend, which uses Required (TLS upgrade) + let enc = client.connection_encryption(); + eprintln!("connection_encryption() = {:?}", enc); + + // Backend uses On/Required (TLS upgrade), not Strict + assert!( + enc == EncryptionLevel::On || enc == EncryptionLevel::Required, + "Fabric SQL DB backend should report On or Required, got: {:?}", + enc + ); + Ok(()) +} + +/// Test: is_healthy() works on Fabric SQL Database. +#[tokio::test] +async fn fabric_sqldb_is_healthy() -> anyhow::Result<()> { + skip_if_no_fabric_sqldb!(); + + let mut client = connect_fabric_sqldb_strict().await?; + client.is_healthy().await?; + eprintln!("is_healthy() on Fabric SQL DB: OK"); + Ok(()) +} + +/// Test: Fabric SQL DB works without specifying encrypt=strict. +/// +/// Unlike Fabric Data Warehouse, Fabric SQL Database does NOT require TDS 8 +/// strict mode. It works with regular TLS upgrade (same as Azure SQL). +/// This test verifies no auto-upgrade to strict happens and the connection +/// succeeds with the standard PRELOGIN → TLS upgrade flow. +#[tokio::test] +async fn fabric_sqldb_works_without_strict() -> anyhow::Result<()> { + skip_if_no_fabric_sqldb!(); + + let endpoint = get_endpoint(); + let database = get_database(); + let token = get_token(); + + // Connect WITHOUT specifying encrypt — should work with regular TLS + let conn_str = format!( + "server=tcp:{endpoint},1433;encrypt=true;TrustServerCertificate=false;database={database}" + ); + let mut config = Config::from_ado_string(&conn_str)?; + config.authentication(AuthMethod::aad_token(&token)); + + let tcp = TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + + match Client::connect(config, tcp.compat_write()).await { + Ok(mut client) => { + let row = client + .query("SELECT 1 AS val", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(1i32), row.get("val")); + eprintln!("Fabric SQL DB works without strict (no routing): OK"); + Ok(()) + } + Err(Error::Routing { host, port }) => { + eprintln!("Routing redirect to {}:{}", host, port); + + let backend_host = host.split('\\').next().unwrap_or(&host); + + let mut backend_config = Config::new(); + backend_config.host(backend_host); + backend_config.port(port); + backend_config.encryption(EncryptionLevel::Required); + backend_config.authentication(AuthMethod::aad_token(&token)); + backend_config.database(&database); + backend_config.login_server_name(&endpoint); + + let tcp = TcpStream::connect(backend_config.get_addr()).await?; + tcp.set_nodelay(true)?; + + let mut client = Client::connect(backend_config, tcp.compat_write()).await?; + let row = client + .query("SELECT 1 AS val", &[]) + .await? + .into_row() + .await? + .unwrap(); + assert_eq!(Some(1i32), row.get("val")); + eprintln!("Fabric SQL DB works without strict (with routing): OK"); + Ok(()) + } + Err(e) => Err(e.into()), + } +} From f1b2fcb6ba86f14b1f2d5b7dda83faac036d00af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 18:18:27 +0200 Subject: [PATCH 14/15] chore: formatting, doc fix, and cleanup unused variable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - cargo fmt: fix line-length formatting in unit tests and integration tests - Fix resolve_encryption() doc comment to list both auto-detect patterns (*.datawarehouse.fabric.microsoft.com and *.pbidedicated.windows.net) - Remove unused `resolved` variable binding in connect() — the return value was redundant with config.encryption after mutation - Rename misleading `is_backend` to `include_trace_id` in finish_connect_after_tls() to match the parameter it maps to --- src/client/config.rs | 1 + src/client/connection.rs | 35 +++++++++++++++++++++-------------- tests/fabric_sqldb.rs | 9 +++++---- tests/sql_server_tds8.rs | 5 ++++- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/client/config.rs b/src/client/config.rs index 598bbf879..45560a6e7 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -227,6 +227,7 @@ impl Config { /// /// Currently detects: /// - `*.datawarehouse.fabric.microsoft.com` → `EncryptionLevel::Strict` + /// - `*.pbidedicated.windows.net` → `EncryptionLevel::Strict` /// /// Returns the (possibly upgraded) encryption level. pub(crate) fn resolve_encryption(&mut self) -> EncryptionLevel { diff --git a/src/client/connection.rs b/src/client/connection.rs index 8114f4395..9e58e55ca 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -93,8 +93,8 @@ impl Connection { // Auto-detect strict mode from hostname when encryption wasn't // explicitly configured by the caller. - let resolved = config.resolve_encryption(); - if resolved == EncryptionLevel::Strict && !config.encryption_explicit { + config.resolve_encryption(); + if config.encryption == EncryptionLevel::Strict && !config.encryption_explicit { event!( Level::INFO, host = config.get_host(), @@ -170,12 +170,11 @@ impl Connection { /// Complete connection setup after TLS is established (for strict mode). /// In TDS 8, PRELOGIN and LOGIN both happen inside the TLS tunnel. async fn finish_connect_after_tls(mut connection: Self, config: Config) -> crate::Result { - // For backend connections (routing reconnect), we still need to send - // FEDAUTHREQUIRED if using AAD auth — the backend needs it to prepare - // for processing the FEDAUTH token in LOGIN. - let is_backend = config.instance_name.is_some(); let fed_auth_required = config.auth.is_aad(); + // Include TRACEID in PRELOGIN for backend (routing reconnect) connections. + let include_trace_id = config.instance_name.is_some(); + // In TDS 8 strict mode, send ENCRYPT_STRICT (0x08) on the wire. // TLS is already established, and the PRELOGIN encryption field signals // to the server that this is a TDS 8 strict mode connection. @@ -186,7 +185,7 @@ impl Connection { prelogin_encryption, fed_auth_required, config.instance_name.clone(), - is_backend, + include_trace_id, ) .await?; @@ -983,12 +982,17 @@ mod tests { #[test] fn wrap_strict_tls_error_wraps_tls_error() { let original = crate::Error::Tls("handshake failure".to_string()); - let wrapped = - Connection::>>::wrap_strict_tls_error(original, "myserver.example.com"); + let wrapped = Connection::>>::wrap_strict_tls_error( + original, + "myserver.example.com", + ); let msg = wrapped.to_string(); assert!(msg.contains("myserver.example.com"), "should contain host"); - assert!(msg.contains("handshake failure"), "should contain original error"); + assert!( + msg.contains("handshake failure"), + "should contain original error" + ); assert!(msg.contains("strict"), "should mention strict mode"); assert!( msg.contains("encrypt=true"), @@ -1008,8 +1012,9 @@ mod tests { kind: ErrorKind::ConnectionReset, message: "connection reset by peer".to_string(), }; - let wrapped = - Connection::>>::wrap_strict_tls_error(original, "10.0.0.1"); + let wrapped = Connection::>>::wrap_strict_tls_error( + original, "10.0.0.1", + ); let msg = wrapped.to_string(); assert!(msg.contains("10.0.0.1"), "should contain host"); @@ -1024,8 +1029,10 @@ mod tests { #[test] fn wrap_strict_tls_error_passes_through_other_errors() { let original = crate::Error::Protocol("something else".into()); - let wrapped = - Connection::>>::wrap_strict_tls_error(original.clone(), "host"); + let wrapped = Connection::>>::wrap_strict_tls_error( + original.clone(), + "host", + ); // Non-TLS/IO errors should pass through unchanged assert_eq!(wrapped, original); diff --git a/tests/fabric_sqldb.rs b/tests/fabric_sqldb.rs index 24b8291ba..99b001afc 100644 --- a/tests/fabric_sqldb.rs +++ b/tests/fabric_sqldb.rs @@ -250,7 +250,10 @@ async fn fabric_sqldb_ddl_dml() -> anyhow::Result<()> { // Query let rows = client - .query("SELECT id, name, value FROM dbo.tiberius_test ORDER BY id", &[]) + .query( + "SELECT id, name, value FROM dbo.tiberius_test ORDER BY id", + &[], + ) .await? .into_first_result() .await?; @@ -260,9 +263,7 @@ async fn fabric_sqldb_ddl_dml() -> anyhow::Result<()> { assert_eq!(rows[1].get::<&str, _>("name"), Some("world")); // Cleanup - client - .execute("DROP TABLE dbo.tiberius_test", &[]) - .await?; + client.execute("DROP TABLE dbo.tiberius_test", &[]).await?; eprintln!("Fabric SQL DB DDL/DML: OK"); Ok(()) diff --git a/tests/sql_server_tds8.rs b/tests/sql_server_tds8.rs index 29ba9e220..4ec63dda3 100644 --- a/tests/sql_server_tds8.rs +++ b/tests/sql_server_tds8.rs @@ -462,7 +462,10 @@ async fn connection_encryption_reports_strict() -> anyhow::Result<()> { "SQL Server with forcestrict=1 should report Strict encryption" ); - eprintln!("connection_encryption() = {:?}", client.connection_encryption()); + eprintln!( + "connection_encryption() = {:?}", + client.connection_encryption() + ); Ok(()) } From 17309520cd8d07c5940ed0ba433e59d62e164e3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Tue, 19 May 2026 18:43:50 +0200 Subject: [PATCH 15/15] Address review feedback: harden FEATUREEXTACK, improve error context - Add size guard (1 MiB max) on FEATUREEXTACK payload allocations to prevent malformed/malicious packets from causing large memory allocs. Extract read_feature_data() helper to deduplicate the pattern. - Replace panic!() in FEDAUTH data_len handling with proper Error::Protocol - Add comment explaining why SNI stays enabled in strict TrustAll path (cloud endpoints need SNI for tenant routing) - Use expect() with descriptive messages instead of unwrap() in fabric_sqldb.rs test helpers for clearer failure diagnostics --- src/client/tls_stream/native_tls_stream.rs | 3 + src/tds/codec/token/token_feature_ext_ack.rs | 58 +++++++++++++------- tests/fabric_sqldb.rs | 6 +- 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/src/client/tls_stream/native_tls_stream.rs b/src/client/tls_stream/native_tls_stream.rs index a5d5a81fe..23925d11c 100644 --- a/src/client/tls_stream/native_tls_stream.rs +++ b/src/client/tls_stream/native_tls_stream.rs @@ -57,6 +57,9 @@ pub(crate) async fn create_tls_stream( ); native_builder.danger_accept_invalid_certs(true); native_builder.danger_accept_invalid_hostnames(true); + // SNI remains enabled (unlike the non-strict TrustAll path) because + // cloud endpoints (Azure SQL, Fabric) use SNI to route the TLS + // connection to the correct tenant/gateway even in trust-all mode. } TrustConfig::Default => { event!(Level::INFO, "Using default trust configuration."); diff --git a/src/tds/codec/token/token_feature_ext_ack.rs b/src/tds/codec/token/token_feature_ext_ack.rs index de8d45a20..d05a85085 100644 --- a/src/tds/codec/token/token_feature_ext_ack.rs +++ b/src/tds/codec/token/token_feature_ext_ack.rs @@ -4,6 +4,12 @@ use crate::{ }; use futures_util::AsyncReadExt; +/// Maximum allowed payload size for a single FEATUREEXTACK feature data field. +/// In practice these are small (0-32 bytes for FedAuth nonce, 1 byte for +/// AzureSqlSupport, etc.), so 1 MiB is a generous upper bound that guards +/// against malformed packets causing large allocations. +const MAX_FEATURE_ACK_DATA_LEN: usize = 1 << 20; + #[derive(Debug)] pub struct TokenFeatureExtAck { pub features: Vec, @@ -54,42 +60,52 @@ impl TokenFeatureExtAck { } else if data_len == 0 { None } else { - panic!("invalid Feature_Ext_Ack token"); + return Err(crate::Error::Protocol( + format!( + "FEATUREEXTACK FEDAUTH: unexpected data_len {data_len} (expected 0 or 32)" + ) + .into(), + )); }; features.push(FeatureAck::FedAuth(FedAuthAck::SecurityToken { nonce })) } else if feature_id == FEA_EXT_AZURESQLSUPPORT { - let data_len = src.read_u32_le().await? as usize; - let mut data = vec![0u8; data_len]; - if data_len > 0 { - src.read_exact(&mut data).await?; - } + let data = Self::read_feature_data(src).await?; features.push(FeatureAck::AzureSqlSupport(data)); } else if feature_id == FEA_EXT_COLUMNENCRYPTION { - let data_len = src.read_u32_le().await? as usize; - let mut data = vec![0u8; data_len]; - if data_len > 0 { - src.read_exact(&mut data).await?; - } + let data = Self::read_feature_data(src).await?; features.push(FeatureAck::ColumnEncryption(data)); } else if feature_id == FEA_EXT_UTF8_SUPPORT { - let data_len = src.read_u32_le().await? as usize; - let mut data = vec![0u8; data_len]; - if data_len > 0 { - src.read_exact(&mut data).await?; - } + let data = Self::read_feature_data(src).await?; features.push(FeatureAck::Utf8Support(data)); } else { // Unknown feature — skip gracefully by reading data_len bytes - let data_len = src.read_u32_le().await? as usize; - let mut data = vec![0u8; data_len]; - if data_len > 0 { - src.read_exact(&mut data).await?; - } + let data = Self::read_feature_data(src).await?; features.push(FeatureAck::Unknown { feature_id, data }); } } Ok(TokenFeatureExtAck { features }) } + + /// Read a feature data payload with a size guard against malformed packets. + async fn read_feature_data(src: &mut R) -> crate::Result> + where + R: SqlReadBytes + Unpin, + { + let data_len = src.read_u32_le().await? as usize; + if data_len > MAX_FEATURE_ACK_DATA_LEN { + return Err(crate::Error::Protocol( + format!( + "FEATUREEXTACK payload too large: {data_len} bytes (max {MAX_FEATURE_ACK_DATA_LEN})" + ) + .into(), + )); + } + let mut data = vec![0u8; data_len]; + if data_len > 0 { + src.read_exact(&mut data).await?; + } + Ok(data) + } } diff --git a/tests/fabric_sqldb.rs b/tests/fabric_sqldb.rs index 99b001afc..01438940b 100644 --- a/tests/fabric_sqldb.rs +++ b/tests/fabric_sqldb.rs @@ -38,15 +38,15 @@ macro_rules! skip_if_no_fabric_sqldb { } fn get_endpoint() -> String { - env::var("FABRIC_SQLDB_ENDPOINT").unwrap() + env::var("FABRIC_SQLDB_ENDPOINT").expect("FABRIC_SQLDB_ENDPOINT must be set") } fn get_database() -> String { - env::var("FABRIC_SQLDB_DATABASE").unwrap() + env::var("FABRIC_SQLDB_DATABASE").expect("FABRIC_SQLDB_DATABASE must be set") } fn get_token() -> String { - env::var("FABRIC_SQLDB_TOKEN").unwrap() + env::var("FABRIC_SQLDB_TOKEN").expect("FABRIC_SQLDB_TOKEN must be set") } /// Connect to Fabric SQL Database with encrypt=strict (TDS 8).