Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ members = ["runtimes-macro"]
path = "tests/query.rs"
name = "query"

[[test]]
path = "tests/fabric.rs"
name = "fabric"

[[test]]
path = "tests/named-instance-async.rs"
name = "named-instance-async"
Expand Down Expand Up @@ -50,7 +54,7 @@ asynchronous-codec = "0.6"
async-trait = "0.1"
connection-string = "0.2"
num-traits = "0.2"
uuid = "1.0"
uuid = { version = "1.0", features = ["v4"] }

[target.'cfg(windows)'.dependencies]
winauth = { version = "0.0.4", optional = true }
Expand All @@ -63,6 +67,12 @@ version = "0.4"
features = ["runtime-async-std"]
optional = true

[dependencies.native-tls-crate]
version = "0.2.18"
optional = true
features = ["alpn"]
package = "native-tls"

[dependencies.tokio]
version = "1.0"
optional = true
Expand Down Expand Up @@ -200,5 +210,5 @@ sql-browser-smol = ["async-io", "async-net", "futures-lite"]
integrated-auth-gssapi = ["libgssapi"]
bigdecimal = ["bigdecimal_"]
rustls = ["tokio-rustls", "tokio-util", "rustls-pemfile", "rustls-native-certs"]
native-tls = ["async-native-tls"]
native-tls = ["async-native-tls", "native-tls-crate"]
vendored-openssl = ["opentls"]
57 changes: 56 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
codec::{self, IteratorJoin},
stream::{QueryStream, TokenStream},
},
BulkLoadRequest, ColumnFlag, SqlReadBytes, ToSql,
BulkLoadRequest, ColumnFlag, EncryptionLevel, SqlReadBytes, ToSql,
};
use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest};
use enumflags2::BitFlags;
Expand Down Expand Up @@ -352,6 +352,61 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
self.connection.close().await
}

/// Returns the encryption level negotiated for this connection.
///
/// This is useful for connection pool implementations that need to verify
/// the connection's security properties, or for logging/diagnostics.
///
/// # Example
///
/// ```no_run
/// # use tiberius::{Client, Config, EncryptionLevel};
/// # use tokio_util::compat::TokioAsyncWriteCompatExt;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let config = Config::new();
/// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
/// # let client = Client::connect(config, tcp.compat_write()).await?;
/// match client.connection_encryption() {
/// EncryptionLevel::Strict => println!("TDS 8 strict mode"),
/// EncryptionLevel::Required => println!("Encrypted (TLS upgrade)"),
/// _ => println!("Other encryption level"),
/// }
/// # Ok(())
/// # }
/// ```
pub fn connection_encryption(&self) -> EncryptionLevel {
self.connection.encryption
}

/// Validates the connection by executing a lightweight query (`SELECT 1`).
///
/// Returns `Ok(())` if the server responds successfully, or an error if
/// the connection is broken. This is intended for use by connection pool
/// `is_valid` / health-check hooks.
///
/// # Example
///
/// ```no_run
/// # use tiberius::{Client, Config};
/// # use tokio_util::compat::TokioAsyncWriteCompatExt;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let config = Config::new();
/// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
/// # let mut client = Client::connect(config, tcp.compat_write()).await?;
/// if client.is_healthy().await.is_err() {
/// // Connection is broken, discard and create a new one
/// }
/// # Ok(())
/// # }
/// ```
pub async fn is_healthy(&mut self) -> crate::Result<()> {
let stream = self.simple_query("SELECT 1").await?;
stream.into_results().await?;
Ok(())
}

pub(crate) fn rpc_params<'a>(query: impl Into<Cow<'a, str>>) -> Vec<RpcParam<'a>> {
vec![
RpcParam {
Expand Down
150 changes: 146 additions & 4 deletions src/client/auth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;

#[derive(Clone, PartialEq, Eq)]
pub struct SqlServerAuth {
Expand Down Expand Up @@ -46,8 +49,41 @@ impl Debug for WindowsAuth {
}
}

/// A trait for providing AAD/Entra ID tokens dynamically, supporting token refresh.
///
/// Implement this trait to supply fresh tokens on each connection or reconnection.
/// This is useful for long-lived applications where tokens expire (~1 hour) and
/// need to be refreshed transparently.
///
/// # Example
///
/// ```rust,no_run
/// use async_trait::async_trait;
/// use tiberius::TokenProvider;
///
/// struct MyTokenProvider {
/// // your credential state (e.g., client_id, client_secret, tenant_id)
/// }
///
/// #[async_trait]
/// impl TokenProvider for MyTokenProvider {
/// async fn get_token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
/// // Call your identity provider here to get a fresh token
/// // e.g., azure_identity::DefaultAzureCredential, MSAL, etc.
/// Ok("fresh-token".to_string())
/// }
/// }
/// ```
#[async_trait]
pub trait TokenProvider: Send + Sync {
/// Obtain a fresh AAD/Entra ID access token for the `https://database.windows.net/` resource.
///
/// This method is called each time a new connection or reconnection is established.
/// Implementations should handle caching and refresh internally.
async fn get_token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
}

/// Defines the method of authentication to the server.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AuthMethod {
/// Authenticate directly with SQL Server.
SqlServer(SqlServerAuth),
Expand All @@ -67,13 +103,82 @@ pub enum AuthMethod {
doc(cfg(any(windows, all(unix, feature = "integrated-auth-gssapi"))))
)]
Integrated,
/// Authenticate with an AAD token. The token should encode an AAD user/service principal
/// which has access to SQL Server.
/// Authenticate with a static AAD token. The token should encode an AAD
/// user/service principal which has access to SQL Server.
///
/// For long-lived applications where tokens may expire, prefer
/// [`AuthMethod::token_provider`] instead.
AADToken(String),
/// Authenticate with a dynamic AAD token provider that supports refresh.
///
/// The provider's `get_token()` method is called each time a connection
/// (or routing reconnection) is established, ensuring fresh tokens.
AADTokenProvider(Arc<dyn TokenProvider>),
#[doc(hidden)]
None,
}

impl Clone for AuthMethod {
fn clone(&self) -> Self {
match self {
Self::SqlServer(a) => Self::SqlServer(a.clone()),
#[cfg(any(all(windows, feature = "winauth"), doc))]
Self::Windows(a) => Self::Windows(a.clone()),
#[cfg(any(
all(windows, feature = "winauth"),
all(unix, feature = "integrated-auth-gssapi"),
doc
))]
Self::Integrated => Self::Integrated,
Self::AADToken(t) => Self::AADToken(t.clone()),
Self::AADTokenProvider(p) => Self::AADTokenProvider(Arc::clone(p)),
Self::None => Self::None,
}
}
}

impl Debug for AuthMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SqlServer(a) => f.debug_tuple("SqlServer").field(a).finish(),
#[cfg(any(all(windows, feature = "winauth"), doc))]
Self::Windows(a) => f.debug_tuple("Windows").field(a).finish(),
#[cfg(any(
all(windows, feature = "winauth"),
all(unix, feature = "integrated-auth-gssapi"),
doc
))]
Self::Integrated => write!(f, "Integrated"),
Self::AADToken(_) => write!(f, "AADToken(<HIDDEN>)"),
Self::AADTokenProvider(_) => write!(f, "AADTokenProvider(...)"),
Self::None => write!(f, "None"),
}
}
}

impl PartialEq for AuthMethod {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::SqlServer(a), Self::SqlServer(b)) => a == b,
#[cfg(any(all(windows, feature = "winauth"), doc))]
(Self::Windows(a), Self::Windows(b)) => a == b,
#[cfg(any(
all(windows, feature = "winauth"),
all(unix, feature = "integrated-auth-gssapi"),
doc
))]
(Self::Integrated, Self::Integrated) => true,
(Self::AADToken(a), Self::AADToken(b)) => a == b,
// Token providers are compared by Arc pointer identity
(Self::AADTokenProvider(a), Self::AADTokenProvider(b)) => Arc::ptr_eq(a, b),
(Self::None, Self::None) => true,
_ => false,
}
}
}

impl Eq for AuthMethod {}

impl AuthMethod {
/// Construct a new SQL Server authentication configuration.
pub fn sql_server(user: impl ToString, password: impl ToString) -> Self {
Expand All @@ -99,8 +204,45 @@ impl AuthMethod {
})
}

/// Construct a new configuration with AAD auth token.
/// Construct a new configuration with a static AAD auth token.
///
/// For long-lived applications, prefer [`AuthMethod::token_provider`] which
/// supports automatic token refresh.
pub fn aad_token(token: impl ToString) -> Self {
Self::AADToken(token.to_string())
}

/// Construct a new configuration with a dynamic token provider.
///
/// The provider's `get_token()` method is called on each new connection,
/// ensuring fresh tokens even for long-lived applications.
///
/// # Example
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use async_trait::async_trait;
/// use tiberius::{AuthMethod, TokenProvider};
///
/// struct AzCliTokenProvider;
///
/// #[async_trait]
/// impl TokenProvider for AzCliTokenProvider {
/// async fn get_token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
/// // In real code, call azure_identity or az CLI
/// Ok("fresh-token".to_string())
/// }
/// }
///
/// let auth = AuthMethod::token_provider(Arc::new(AzCliTokenProvider));
/// ```
pub fn token_provider(provider: Arc<dyn TokenProvider>) -> Self {
Self::AADTokenProvider(provider)
}

/// Returns true if this auth method uses AAD token authentication
/// (either static or via provider).
pub(crate) fn is_aad(&self) -> bool {
matches!(self, Self::AADToken(_) | Self::AADTokenProvider(_))
}
}
Loading