Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ private SqlConnection(SqlConnection connection)

internal static bool TryGetSystemColumnEncryptionKeyStoreProvider(string keyStoreName, out SqlColumnEncryptionKeyStoreProvider provider)
{
return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider);
return s_systemColumnEncryptionKeyStoreProviders.TryGetValue(keyStoreName, out provider);
}

/// <summary>
Expand Down Expand Up @@ -1332,7 +1332,7 @@ public override void ChangeDatabase(string database)
SqlStatistics statistics = null;
RepairInnerConnection();
SqlClientEventSource.Log.TryCorrelationTraceEvent("SqlConnection.ChangeDatabase | API | Correlation | Object Id {0}, Activity Id {1}, Database {2}", ObjectID, ActivityCorrelator.Current, database);

try
{
statistics = SqlStatistics.StartTimer(Statistics);
Expand Down Expand Up @@ -1408,7 +1408,7 @@ public override void Close()

SqlStatistics statistics = null;
Exception e = null;

try
{
statistics = SqlStatistics.StartTimer(Statistics);
Expand Down Expand Up @@ -1901,7 +1901,7 @@ internal void Abort(Exception e)
}

/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/OpenAsync/*' />
public override Task OpenAsync(CancellationToken cancellationToken)
public override Task OpenAsync(CancellationToken cancellationToken)
=> OpenAsync(SqlConnectionOverrides.None, cancellationToken);

/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/OpenAsyncWithOverrides/*' />
Expand Down Expand Up @@ -2224,7 +2224,18 @@ private bool TryOpenInner(TaskCompletionSource<DbConnectionInternal> retry)
}
// does not require GC.KeepAlive(this) because of ReRegisterForFinalize below.

var tdsInnerConnection = (SqlConnectionInternal)InnerConnection;
// Capture InnerConnection once into a local to avoid a TOCTOU race: another thread
// concurrently calling Open() on the same SqlConnection instance can change
// _innerConnection to DbConnectionClosedConnecting between the TryOpenConnection()
// call above and the cast below. Without this local capture the second read of
// InnerConnection may return DbConnectionClosedConnecting, which is not assignable
// to SqlConnectionInternal and would produce an opaque InvalidCastException.
// See GitHub issue #3314.
var innerConnection = InnerConnection;
if (innerConnection is not SqlConnectionInternal tdsInnerConnection)
{
throw ADP.ConnectionAlreadyOpen(innerConnection.State);
}

Debug.Assert(tdsInnerConnection.Parser != null, "Where's the parser?");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Data;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.Data.ProviderBase;
using Microsoft.Data.SqlClient.Connection;
using Xunit;

namespace Microsoft.Data.SqlClient.UnitTests.Microsoft.Data.SqlClient
{
/// <summary>
/// Regression tests for GitHub issue #3314.
///
/// Root cause: TryOpenInner() read InnerConnection twice - once for TryOpenConnection() and
/// again for the cast to SqlConnectionInternal. Between those two reads another thread could
/// change _innerConnection to DbConnectionClosedConnecting, which is not assignable to
/// SqlConnectionInternal, causing an opaque InvalidCastException.
///
/// Fix: InnerConnection is now captured into a local variable once; if it is not a
/// SqlConnectionInternal an InvalidOperationException with a descriptive message is thrown
/// instead of an InvalidCastException.
/// </summary>
public class SqlConnectionConcurrentOpenTests
{
private static readonly MethodInfo s_tryOpenInner = typeof(SqlConnection)
.GetMethod("TryOpenInner", BindingFlags.Instance | BindingFlags.NonPublic)!;

private static DbConnectionInternal GetConnectingSingleton()
{
return DbConnectionClosedConnecting.SingletonInstance;
}

private static void ForceInnerConnection(SqlConnection connection, DbConnectionInternal innerConnectionValue)
{
connection.SetInnerConnectionTo(innerConnectionValue);
}

private static bool InvokeTryOpenInner(SqlConnection connection, TaskCompletionSource<DbConnectionInternal> retry)
{
try
{
return (bool)s_tryOpenInner.Invoke(connection, [retry])!;
}
catch (TargetInvocationException tie) when (tie.InnerException != null)
{
throw tie.InnerException;
}
}

[Fact]
public void InnerConnection_DbConnectionClosedConnecting_IsNotAssignableToSqlConnectionInternal()
{
DbConnectionInternal connectingSingleton = GetConnectingSingleton();

Assert.False(
connectingSingleton is SqlConnectionInternal,
"DbConnectionClosedConnecting must NOT be assignable to SqlConnectionInternal. " +
"If it were, the race condition in #3314 would not manifest.");
}

[Fact]
public void InnerConnection_InConnectingState_ReportsConnectingState()
{
DbConnectionInternal connectingSingleton = GetConnectingSingleton();

var connection = new SqlConnection("Data Source=localhost");
ForceInnerConnection(connection, connectingSingleton);

Assert.Equal(ConnectionState.Connecting, connection.State);
}

[Fact]
public void Open_WhenAlreadyConnecting_ThrowsInvalidOperation()
{
DbConnectionInternal connectingSingleton = GetConnectingSingleton();

var connection = new SqlConnection("Data Source=localhost");
ForceInnerConnection(connection, connectingSingleton);

Assert.Throws<InvalidOperationException>(() => connection.Open());
}

[Fact]
public void TryOpenInner_WhenInnerConnectionRacesToNonSqlConnectionInternalState_ThrowsInvalidOperation_NotInvalidCast()
{
DbConnectionInternal initialConnectingState = GetConnectingSingleton();
DbConnectionInternal racedNonSqlConnectionInternalState = DbConnectionOpenBusy.SingletonInstance;

Comment thread
paulmedynski marked this conversation as resolved.
var connection = new SqlConnection("Data Source=localhost");
ForceInnerConnection(connection, initialConnectingState);

TaskCompletionSource<DbConnectionInternal> completedRetry = new();
completedRetry.SetResult(racedNonSqlConnectionInternalState);

Exception ex = Assert.ThrowsAny<Exception>(() =>
{
InvokeTryOpenInner(connection, completedRetry);
});

Assert.True(
ex is InvalidOperationException,
$"Expected InvalidOperationException but got {ex.GetType().Name}: {ex.Message}. " +
"The fix for #3314 must throw InvalidOperationException (not InvalidCastException) " +
"when _innerConnection races to a non-SqlConnectionInternal state inside TryOpenInner.");

Assert.Contains("connection", ex.Message, StringComparison.OrdinalIgnoreCase);
}
}
}
Loading