diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs index b96e6f654d..e0490ebdf5 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -266,7 +266,7 @@ private SqlConnection(SqlConnection connection) internal static bool TryGetSystemColumnEncryptionKeyStoreProvider(string keyStoreName, out SqlColumnEncryptionKeyStoreProvider provider) { - return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider); + return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider); } /// @@ -1332,7 +1332,7 @@ public override void ChangeDatabase(string database) SqlStatistics statistics = null; RepairInnerConnection(); SqlClientEventSource.Log.TryCorrelationTraceEvent("SqlConnection.ChangeDatabase | API | Correlation | Object Id {0}, Activity Id {1}, Database {2}", ObjectID, ActivityCorrelator.Current, database); - + try { statistics = SqlStatistics.StartTimer(Statistics); @@ -1408,7 +1408,7 @@ public override void Close() SqlStatistics statistics = null; Exception e = null; - + try { statistics = SqlStatistics.StartTimer(Statistics); @@ -1901,7 +1901,7 @@ internal void Abort(Exception e) } /// - public override Task OpenAsync(CancellationToken cancellationToken) + public override Task OpenAsync(CancellationToken cancellationToken) => OpenAsync(SqlConnectionOverrides.None, cancellationToken); /// @@ -2224,7 +2224,18 @@ private bool TryOpenInner(TaskCompletionSource retry) } // does not require GC.KeepAlive(this) because of ReRegisterForFinalize below. - var tdsInnerConnection = (SqlConnectionInternal)InnerConnection; + // Capture InnerConnection once into a local to avoid a TOCTOU race: another thread + // concurrently calling Open() on the same SqlConnection instance can change + // _innerConnection to DbConnectionClosedConnecting between the TryOpenConnection() + // call above and the cast below. Without this local capture the second read of + // InnerConnection may return DbConnectionClosedConnecting, which is not assignable + // to SqlConnectionInternal and would produce an opaque InvalidCastException. + // See GitHub issue #3314. + var innerConnection = InnerConnection; + if (innerConnection is not SqlConnectionInternal tdsInnerConnection) + { + throw ADP.ConnectionAlreadyOpen(innerConnection.State); + } Debug.Assert(tdsInnerConnection.Parser != null, "Where's the parser?"); diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/SqlConnectionConcurrentOpenTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/SqlConnectionConcurrentOpenTests.cs new file mode 100644 index 0000000000..c38a30d317 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/SqlConnectionConcurrentOpenTests.cs @@ -0,0 +1,113 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.Data.ProviderBase; +using Microsoft.Data.SqlClient.Connection; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.Microsoft.Data.SqlClient +{ + /// + /// Regression tests for GitHub issue #3314. + /// + /// Root cause: TryOpenInner() read InnerConnection twice - once for TryOpenConnection() and + /// again for the cast to SqlConnectionInternal. Between those two reads another thread could + /// change _innerConnection to DbConnectionClosedConnecting, which is not assignable to + /// SqlConnectionInternal, causing an opaque InvalidCastException. + /// + /// Fix: InnerConnection is now captured into a local variable once; if it is not a + /// SqlConnectionInternal an InvalidOperationException with a descriptive message is thrown + /// instead of an InvalidCastException. + /// + public class SqlConnectionConcurrentOpenTests + { + private static readonly MethodInfo s_tryOpenInner = typeof(SqlConnection) + .GetMethod("TryOpenInner", BindingFlags.Instance | BindingFlags.NonPublic)!; + + private static DbConnectionInternal GetConnectingSingleton() + { + return DbConnectionClosedConnecting.SingletonInstance; + } + + private static void ForceInnerConnection(SqlConnection connection, DbConnectionInternal innerConnectionValue) + { + connection.SetInnerConnectionTo(innerConnectionValue); + } + + private static bool InvokeTryOpenInner(SqlConnection connection, TaskCompletionSource retry) + { + try + { + return (bool)s_tryOpenInner.Invoke(connection, [retry])!; + } + catch (TargetInvocationException tie) when (tie.InnerException != null) + { + throw tie.InnerException; + } + } + + [Fact] + public void InnerConnection_DbConnectionClosedConnecting_IsNotAssignableToSqlConnectionInternal() + { + DbConnectionInternal connectingSingleton = GetConnectingSingleton(); + + Assert.False( + connectingSingleton is SqlConnectionInternal, + "DbConnectionClosedConnecting must NOT be assignable to SqlConnectionInternal. " + + "If it were, the race condition in #3314 would not manifest."); + } + + [Fact] + public void InnerConnection_InConnectingState_ReportsConnectingState() + { + DbConnectionInternal connectingSingleton = GetConnectingSingleton(); + + var connection = new SqlConnection("Data Source=localhost"); + ForceInnerConnection(connection, connectingSingleton); + + Assert.Equal(ConnectionState.Connecting, connection.State); + } + + [Fact] + public void Open_WhenAlreadyConnecting_ThrowsInvalidOperation() + { + DbConnectionInternal connectingSingleton = GetConnectingSingleton(); + + var connection = new SqlConnection("Data Source=localhost"); + ForceInnerConnection(connection, connectingSingleton); + + Assert.Throws(() => connection.Open()); + } + + [Fact] + public void TryOpenInner_WhenInnerConnectionRacesToNonSqlConnectionInternalState_ThrowsInvalidOperation_NotInvalidCast() + { + DbConnectionInternal initialConnectingState = GetConnectingSingleton(); + DbConnectionInternal racedNonSqlConnectionInternalState = DbConnectionOpenBusy.SingletonInstance; + + var connection = new SqlConnection("Data Source=localhost"); + ForceInnerConnection(connection, initialConnectingState); + + TaskCompletionSource completedRetry = new(); + completedRetry.SetResult(racedNonSqlConnectionInternalState); + + Exception ex = Assert.ThrowsAny(() => + { + InvokeTryOpenInner(connection, completedRetry); + }); + + Assert.True( + ex is InvalidOperationException, + $"Expected InvalidOperationException but got {ex.GetType().Name}: {ex.Message}. " + + "The fix for #3314 must throw InvalidOperationException (not InvalidCastException) " + + "when _innerConnection races to a non-SqlConnectionInternal state inside TryOpenInner."); + + Assert.Contains("connection", ex.Message, StringComparison.OrdinalIgnoreCase); + } + } +} \ No newline at end of file