diff --git a/src/libraries/System.Data.Common/ref/System.Data.Common.cs b/src/libraries/System.Data.Common/ref/System.Data.Common.cs index eb588309501c3b..c5a8f8dbbb19bd 100644 --- a/src/libraries/System.Data.Common/ref/System.Data.Common.cs +++ b/src/libraries/System.Data.Common/ref/System.Data.Common.cs @@ -1933,6 +1933,9 @@ public virtual void EnlistTransaction(System.Transactions.Transaction? transacti public virtual System.Data.DataTable GetSchema() { throw null; } public virtual System.Data.DataTable GetSchema(string collectionName) { throw null; } public virtual System.Data.DataTable GetSchema(string collectionName, string?[] restrictionValues) { throw null; } + public virtual System.Threading.Tasks.Task GetSchemaAsync(System.Threading.CancellationToken cancellationToken = default) { throw null; } + public virtual System.Threading.Tasks.Task GetSchemaAsync(string collectionName, System.Threading.CancellationToken cancellationToken = default) { throw null; } + public virtual System.Threading.Tasks.Task GetSchemaAsync(string collectionName, string?[] restrictionValues, System.Threading.CancellationToken cancellationToken = default) { throw null; } protected virtual void OnStateChange(System.Data.StateChangeEventArgs stateChange) { } public abstract void Open(); public System.Threading.Tasks.Task OpenAsync() { throw null; } @@ -2111,6 +2114,8 @@ protected virtual void Dispose(bool disposing) { } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] public virtual int GetProviderSpecificValues(object[] values) { throw null; } public virtual System.Data.DataTable GetSchemaTable() { throw null; } + public virtual System.Threading.Tasks.Task GetSchemaTableAsync(System.Threading.CancellationToken cancellationToken = default) { throw null; } + public virtual System.Threading.Tasks.Task> GetColumnSchemaAsync(System.Threading.CancellationToken cancellationToken = default) { throw null; } public virtual System.IO.Stream GetStream(int ordinal) { throw null; } public abstract string GetString(int ordinal); public virtual System.IO.TextReader GetTextReader(int ordinal) { throw null; } diff --git a/src/libraries/System.Data.Common/src/System/Data/Common/DbConnection.cs b/src/libraries/System.Data.Common/src/System/Data/Common/DbConnection.cs index 130611be629a1b..7b99a8dfbff029 100644 --- a/src/libraries/System.Data.Common/src/System/Data/Common/DbConnection.cs +++ b/src/libraries/System.Data.Common/src/System/Data/Common/DbConnection.cs @@ -143,21 +143,166 @@ public virtual void EnlistTransaction(System.Transactions.Transaction? transacti // these need to be here so that GetSchema is visible when programming to a dbConnection object. // they are overridden by the real implementations in DbConnectionBase + + /// + /// Returns schema information for the data source of this . + /// + /// A that contains schema information. + /// + /// If the connection is associated with a transaction, executing calls may cause + /// some providers to throw an exception. + /// public virtual DataTable GetSchema() { throw ADP.NotSupported(); } + /// + /// Returns schema information for the data source of this using the specified + /// string for the schema name. + /// + /// Specifies the name of the schema to return. + /// A that contains schema information. + /// + /// is specified as . + /// + /// + /// If the connection is associated with a transaction, executing calls may cause + /// some providers to throw an exception. + /// public virtual DataTable GetSchema(string collectionName) { throw ADP.NotSupported(); } + /// + /// Returns schema information for the data source of this using the specified + /// string for the schema name and the specified string array for the restriction values. + /// + /// Specifies the name of the schema to return. + /// Specifies a set of restriction values for the requested schema. + /// A that contains schema information. + /// + /// is specified as . + /// + /// + /// + /// The parameter can supply n depth of values, which are specified by the + /// restrictions collection for a specific collection. In order to set values on a given restriction, and not + /// set the values of other restrictions, you need to set the preceding restrictions to null and then put the + /// appropriate value in for the restriction that you would like to specify a value for. + /// + /// + /// An example of this is the "Tables" collection. If the "Tables" collection has three restrictions (database, + /// owner, and table name) and you want to get back only the tables associated with the owner "Carl", you must + /// pass in the following values at least: null, "Carl". If a restriction value is not passed in, the default + /// values are used for that restriction. This is the same mapping as passing in null, which is different from + /// passing in an empty string for the parameter value. In that case, the empty string ("") is considered to be + /// the value for the specified parameter. + /// + /// + /// If the connection is associated with a transaction, executing + /// calls may cause some providers to throw an exception. + /// + /// public virtual DataTable GetSchema(string collectionName, string?[] restrictionValues) { throw ADP.NotSupported(); } + /// + /// This is the asynchronous version of . + /// Providers should override with an appropriate implementation. + /// The cancellation token can optionally be honored. + /// The default implementation invokes the synchronous call and returns a completed + /// task. + /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken. + /// Exceptions thrown by will be communicated via the returned Task Exception + /// property. + /// + /// The cancellation instruction. + /// A task representing the asynchronous operation. + public virtual Task GetSchemaAsync(CancellationToken cancellationToken = default) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + try + { + return Task.FromResult(GetSchema()); + } + catch (Exception e) + { + return Task.FromException(e); + } + } + + /// + /// This is the asynchronous version of . + /// Providers should override with an appropriate implementation. + /// The cancellation token can optionally be honored. + /// The default implementation invokes the synchronous call and returns a + /// completed task. + /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken. + /// Exceptions thrown by will be communicated via the returned Task Exception + /// property. + /// + /// Specifies the name of the schema to return. + /// The cancellation instruction. + /// A task representing the asynchronous operation. + public virtual Task GetSchemaAsync( + string collectionName, + CancellationToken cancellationToken = default) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + try + { + return Task.FromResult(GetSchema(collectionName)); + } + catch (Exception e) + { + return Task.FromException(e); + } + } + + /// + /// This is the asynchronous version of . + /// Providers should override with an appropriate implementation. + /// The cancellation token can optionally be honored. + /// The default implementation invokes the synchronous call and + /// returns a completed task. + /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken. + /// Exceptions thrown by will be communicated via the returned Task + /// Exception property. + /// + /// Specifies the name of the schema to return. + /// Specifies a set of restriction values for the requested schema. + /// The cancellation instruction. + /// A task representing the asynchronous operation. + public virtual Task GetSchemaAsync(string collectionName, string?[] restrictionValues, + CancellationToken cancellationToken = default) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + try + { + return Task.FromResult(GetSchema(collectionName, restrictionValues)); + } + catch (Exception e) + { + return Task.FromException(e); + } + } + protected virtual void OnStateChange(StateChangeEventArgs stateChange) { if (_suppressStateChangeForReconnection) diff --git a/src/libraries/System.Data.Common/src/System/Data/Common/DbDataReader.cs b/src/libraries/System.Data.Common/src/System/Data/Common/DbDataReader.cs index 2c1e02878151a8..4753058f841d9f 100644 --- a/src/libraries/System.Data.Common/src/System/Data/Common/DbDataReader.cs +++ b/src/libraries/System.Data.Common/src/System/Data/Common/DbDataReader.cs @@ -3,6 +3,7 @@ #nullable enable using System.Collections; +using System.Collections.ObjectModel; using System.ComponentModel; using System.IO; using System.Threading.Tasks; @@ -73,11 +74,77 @@ public virtual ValueTask DisposeAsync() public abstract int GetOrdinal(string name); + /// + /// Returns a that describes the column metadata of the >. + /// + /// A that describes the column metadata. + /// The is closed. + /// The column index is out of range. + /// .NET Core only: This member is not supported. public virtual DataTable GetSchemaTable() { throw new NotSupportedException(); } + /// + /// This is the asynchronous version of . + /// Providers should override with an appropriate implementation. + /// The cancellation token can optionally be honored. + /// The default implementation invokes the synchronous call and + /// returns a completed task. + /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken. + /// Exceptions thrown by will be communicated via the returned Task + /// Exception property. + /// + /// The cancellation instruction. + /// A task representing the asynchronous operation. + public virtual Task GetSchemaTableAsync(CancellationToken cancellationToken = default) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + try + { + return Task.FromResult(GetSchemaTable()); + } + catch (Exception e) + { + return Task.FromException(e); + } + } + + /// + /// This is the asynchronous version of . + /// Providers should override with an appropriate implementation. + /// The cancellation token can optionally be honored. + /// The default implementation invokes the synchronous + /// call and returns a completed task. + /// The default implementation will return a cancelled task if passed an already cancelled cancellationToken. + /// Exceptions thrown by will be + /// communicated via the returned Task Exception property. + /// + /// The cancellation instruction. + /// A task representing the asynchronous operation. + public virtual Task> GetColumnSchemaAsync( + CancellationToken cancellationToken = default) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled>(cancellationToken); + } + + try + { + return Task.FromResult(this.GetColumnSchema()); + } + catch (Exception e) + { + return Task.FromException>(e); + } + } + public abstract bool GetBoolean(int ordinal); public abstract byte GetByte(int ordinal); diff --git a/src/libraries/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs b/src/libraries/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs index 89840aceb8b4da..6a6fd2929a1f24 100644 --- a/src/libraries/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs +++ b/src/libraries/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs @@ -1,7 +1,12 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +#nullable enable + +using System.Diagnostics.CodeAnalysis; using System.Reflection; +using System.Threading; +using System.Threading.Tasks; using Xunit; namespace System.Data.Common.Tests @@ -10,124 +15,70 @@ public class DbConnectionTests { private static volatile bool _wasFinalized; - private class FinalizingConnection : DbConnection + private class MockDbConnection : DbConnection { - public static void CreateAndRelease() + [AllowNull] + public override string ConnectionString { - new FinalizingConnection(); + get => throw new NotImplementedException(); + set => throw new NotImplementedException(); } + public override string Database => throw new NotImplementedException(); + public override string DataSource => throw new NotImplementedException(); + public override string ServerVersion => throw new NotImplementedException(); + public override ConnectionState State => throw new NotImplementedException(); + public override void ChangeDatabase(string databaseName) => throw new NotImplementedException(); + public override void Close() => throw new NotImplementedException(); + public override void Open() => throw new NotImplementedException(); + protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw new NotImplementedException(); + protected override DbCommand CreateDbCommand() => throw new NotImplementedException(); + } + + private class FinalizingConnection : MockDbConnection + { + public static void CreateAndRelease() => new FinalizingConnection(); + protected override void Dispose(bool disposing) { if (!disposing) _wasFinalized = true; base.Dispose(disposing); } - - public override string ConnectionString - { - get - { - throw new NotImplementedException(); - } - - set - { - throw new NotImplementedException(); - } - } - - public override string Database - { - get - { - throw new NotImplementedException(); - } - } - - public override string DataSource - { - get - { - throw new NotImplementedException(); - } - } - - public override string ServerVersion - { - get - { - throw new NotImplementedException(); - } - } - - public override ConnectionState State - { - get - { - throw new NotImplementedException(); - } - } - - public override void ChangeDatabase(string databaseName) - { - throw new NotImplementedException(); - } - - public override void Close() - { - throw new NotImplementedException(); - } - - public override void Open() - { - throw new NotImplementedException(); - } - - protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) - { - throw new NotImplementedException(); - } - - protected override DbCommand CreateDbCommand() - { - throw new NotImplementedException(); - } } - private class DbProviderFactoryConnection : DbConnection + private class GetSchemaConnection : MockDbConnection { - protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) + public override DataTable GetSchema() { - throw new NotImplementedException(); + var table = new DataTable(); + table.Columns.Add(new DataColumn("CollectionName", typeof(string))); + table.Columns.Add(new DataColumn("WithRestrictions", typeof(bool))); + table.Rows.Add("Default", false); + return table; } - public override void ChangeDatabase(string databaseName) + public override DataTable GetSchema(string collectionName) { - throw new NotImplementedException(); + var table = new DataTable(); + table.Columns.Add(new DataColumn("CollectionName", typeof(string))); + table.Columns.Add(new DataColumn("WithRestrictions", typeof(bool))); + table.Rows.Add(collectionName, false); + return table; } - public override void Close() + public override DataTable GetSchema(string collectionName, string?[] restrictionValues) { - throw new NotImplementedException(); - } - - public override void Open() - { - throw new NotImplementedException(); - } - - public override string ConnectionString { get; set; } - public override string Database { get; } - public override ConnectionState State { get; } - public override string DataSource { get; } - public override string ServerVersion { get; } - - protected override DbCommand CreateDbCommand() - { - throw new NotImplementedException(); + var table = new DataTable(); + table.Columns.Add(new DataColumn("CollectionName", typeof(string))); + table.Columns.Add(new DataColumn("WithRestrictions", typeof(bool))); + table.Rows.Add(collectionName, true); + return table; } + } + private class DbProviderFactoryConnection : MockDbConnection + { protected override DbProviderFactory DbProviderFactory => TestDbProviderFactory.Instance; } @@ -150,12 +101,48 @@ public void CanBeFinalized() public void ProviderFactoryTest() { DbProviderFactoryConnection con = new DbProviderFactoryConnection(); - PropertyInfo providerFactoryProperty = con.GetType().GetProperty("ProviderFactory", BindingFlags.NonPublic | BindingFlags.Instance); + PropertyInfo providerFactoryProperty = con.GetType().GetProperty("ProviderFactory", BindingFlags.NonPublic | BindingFlags.Instance)!; Assert.NotNull(providerFactoryProperty); - DbProviderFactory factory = providerFactoryProperty.GetValue(con) as DbProviderFactory; + DbProviderFactory? factory = providerFactoryProperty.GetValue(con) as DbProviderFactory; Assert.NotNull(factory); - Assert.Same(typeof(TestDbProviderFactory), factory.GetType()); + Assert.Same(typeof(TestDbProviderFactory), factory!.GetType()); Assert.Same(TestDbProviderFactory.Instance, factory); } + + [Fact] + public void GetSchemaAsync_with_cancelled_token() + { + var conn = new MockDbConnection(); + Assert.ThrowsAsync(async () => await conn.GetSchemaAsync(new CancellationToken(true))); + Assert.ThrowsAsync(async () => await conn.GetSchemaAsync("MetaDataCollections", new CancellationToken(true))); + Assert.ThrowsAsync(async () => await conn.GetSchemaAsync("MetaDataCollections", new string[0], new CancellationToken(true))); + } + + [Fact] + public void GetSchemaAsync_with_exception() + { + var conn = new MockDbConnection(); + Assert.ThrowsAsync(async () => await conn.GetSchemaAsync()); + Assert.ThrowsAsync(async () => await conn.GetSchemaAsync("MetaDataCollections")); + Assert.ThrowsAsync(async () => await conn.GetSchemaAsync("MetaDataCollections", new string[0])); + } + + [Fact] + public async Task GetSchemaAsync_calls_GetSchema() + { + var conn = new GetSchemaConnection(); + + var row = (await conn.GetSchemaAsync()).Rows[0]; + Assert.Equal("Default", row["CollectionName"]); + Assert.Equal(false, row["WithRestrictions"]); + + row = (await conn.GetSchemaAsync("MetaDataCollections")).Rows[0]; + Assert.Equal("MetaDataCollections", row["CollectionName"]); + Assert.Equal(false, row["WithRestrictions"]); + + row = (await conn.GetSchemaAsync("MetaDataCollections", new string?[0])).Rows[0]; + Assert.Equal("MetaDataCollections", row["CollectionName"]); + Assert.Equal(true, row["WithRestrictions"]); + } } } diff --git a/src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderMock.cs b/src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderMock.cs index 00981a8d94eb36..0735e90df42a0d 100644 --- a/src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderMock.cs +++ b/src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderMock.cs @@ -24,52 +24,32 @@ // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. // +#nullable enable +using System.Collections; +using System.Linq; using System.Data.Common; -namespace System.Data.Tests.Common +namespace System.Data.Common.Tests { internal class DbDataReaderMock : DbDataReader { - private int _currentRowIndex = -1; - private DataTable _testDataTable; + protected int _currentRowIndex = -1; + protected DataTable _testDataTable; public DbDataReaderMock() - { - _testDataTable = new DataTable(); - } + => _testDataTable = new DataTable(); public DbDataReaderMock(DataTable testData) - { - _testDataTable = testData ?? throw new ArgumentNullException(nameof(testData)); - } - - public override void Close() - { - _testDataTable.Clear(); - } - - public override int Depth - { - get { throw new NotImplementedException(); } - } - - public override int FieldCount - { - get { throw new NotImplementedException(); } - } - - public override bool GetBoolean(int ordinal) - { - return (bool)GetValue(ordinal); - } + => _testDataTable = testData ?? throw new ArgumentNullException(nameof(testData)); - public override byte GetByte(int ordinal) - { - return (byte)GetValue(ordinal); - } + public override void Close() => _testDataTable.Clear(); + public override int Depth => throw new NotImplementedException(); + public override int FieldCount => throw new NotImplementedException(); + public override bool GetBoolean(int ordinal) => (bool)GetValue(ordinal); + public override byte GetByte(int ordinal) => (byte)GetValue(ordinal); - public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) + public override long GetBytes(int ordinal, long dataOffset, byte[]? buffer, int bufferOffset, int length) { object value = GetValue(ordinal); if (value == DBNull.Value) @@ -78,17 +58,19 @@ public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int b } byte[] data = (byte[])value; + if (buffer is null) + { + return data.Length; + } + long bytesToRead = Math.Min(data.Length - dataOffset, length); Buffer.BlockCopy(data, (int)dataOffset, buffer, bufferOffset, (int)bytesToRead); return bytesToRead; } - public override char GetChar(int ordinal) - { - return (char)GetValue(ordinal); - } + public override char GetChar(int ordinal) => (char)GetValue(ordinal); - public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) + public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) { object value = GetValue(ordinal); if (value == DBNull.Value) @@ -96,71 +78,28 @@ public override long GetChars(int ordinal, long dataOffset, char[] buffer, int b return 0; } - char[] data = value.ToString().ToCharArray(); + char[] data = value.ToString()!.ToCharArray(); + if (buffer is null) + { + return data.Length; + } long bytesToRead = Math.Min(data.Length - dataOffset, length); Array.Copy(data, dataOffset, buffer, bufferOffset, bytesToRead); return bytesToRead; } - public override string GetDataTypeName(int ordinal) - { - throw new NotImplementedException(); - } - - public override DateTime GetDateTime(int ordinal) - { - return (DateTime)GetValue(ordinal); - } - - public override decimal GetDecimal(int ordinal) - { - return (decimal)GetValue(ordinal); - } - - public override double GetDouble(int ordinal) - { - return (double)GetValue(ordinal); - } - - public override global::System.Collections.IEnumerator GetEnumerator() - { - throw new NotImplementedException(); - } - - public override Type GetFieldType(int ordinal) - { - throw new NotImplementedException(); - } - - public override float GetFloat(int ordinal) - { - return (float)GetValue(ordinal); - } - - public override Guid GetGuid(int ordinal) - { - return (Guid)GetValue(ordinal); - } - - public override short GetInt16(int ordinal) - { - return (short)GetValue(ordinal); - } - - public override int GetInt32(int ordinal) - { - return (int)GetValue(ordinal); - } - - public override long GetInt64(int ordinal) - { - return (long)GetValue(ordinal); - } - - public override string GetName(int ordinal) - { - return _testDataTable.Columns[ordinal].ColumnName; - } + public override string GetDataTypeName(int ordinal) => throw new NotImplementedException(); + public override DateTime GetDateTime(int ordinal) => (DateTime)GetValue(ordinal); + public override decimal GetDecimal(int ordinal) => (decimal)GetValue(ordinal); + public override double GetDouble(int ordinal) => (double)GetValue(ordinal); + public override IEnumerator GetEnumerator() => throw new NotImplementedException(); + public override Type GetFieldType(int ordinal) => throw new NotImplementedException(); + public override float GetFloat(int ordinal) => (float)GetValue(ordinal); + public override Guid GetGuid(int ordinal) => (Guid)GetValue(ordinal); + public override short GetInt16(int ordinal) => (short)GetValue(ordinal); + public override int GetInt32(int ordinal) => (int)GetValue(ordinal); + public override long GetInt64(int ordinal) => (long)GetValue(ordinal); + public override string GetName(int ordinal) => _testDataTable.Columns[ordinal].ColumnName; public override int GetOrdinal(string name) { @@ -178,45 +117,13 @@ public override int GetOrdinal(string name) return -1; } - public override DataTable GetSchemaTable() - { - throw new NotImplementedException(); - } - - public override string GetString(int ordinal) - { - return (string)_testDataTable.Rows[_currentRowIndex][ordinal]; - } - - public override object GetValue(int ordinal) - { - return _testDataTable.Rows[_currentRowIndex][ordinal]; - } - - public override int GetValues(object[] values) - { - throw new NotImplementedException(); - } - - public override bool HasRows - { - get { throw new NotImplementedException(); } - } - - public override bool IsClosed - { - get { throw new NotImplementedException(); } - } - - public override bool IsDBNull(int ordinal) - { - return _testDataTable.Rows[_currentRowIndex][ordinal] == DBNull.Value; - } - - public override bool NextResult() - { - throw new NotImplementedException(); - } + public override string GetString(int ordinal) => (string)_testDataTable.Rows[_currentRowIndex][ordinal]; + public override object GetValue(int ordinal) => _testDataTable.Rows[_currentRowIndex][ordinal]; + public override int GetValues(object[] values) => throw new NotImplementedException(); + public override bool HasRows => throw new NotImplementedException(); + public override bool IsClosed => throw new NotImplementedException(); + public override bool IsDBNull(int ordinal) => _testDataTable.Rows[_currentRowIndex][ordinal] == DBNull.Value; + public override bool NextResult() => throw new NotImplementedException(); public override bool Read() { @@ -224,19 +131,30 @@ public override bool Read() return _currentRowIndex < _testDataTable.Rows.Count; } - public override int RecordsAffected - { - get { throw new NotImplementedException(); } - } + public override int RecordsAffected => throw new NotImplementedException(); + public override object this[string name] => throw new NotImplementedException(); + public override object this[int ordinal] => throw new NotImplementedException(); + } - public override object this[string name] - { - get { throw new NotImplementedException(); } - } + internal class SchemaDbDataReaderMock : DbDataReaderMock + { + public SchemaDbDataReaderMock(DataTable testData) : base(testData) {} - public override object this[int ordinal] + public override DataTable GetSchemaTable() { - get { throw new NotImplementedException(); } + var table = new DataTable("SchemaTable"); + table.Columns.Add("ColumnName", typeof(string)); + table.Columns.Add("DataType", typeof(Type)); + + foreach (var column in _testDataTable.Columns.Cast()) + { + var row = table.NewRow(); + row["ColumnName"] = column.ColumnName; + row["DataType"] = column.DataType; + table.Rows.Add(row); + } + + return table; } } } diff --git a/src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderTest.cs b/src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderTest.cs index 04e54ce43d6dc9..be9460d31e94db 100644 --- a/src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderTest.cs +++ b/src/libraries/System.Data.Common/tests/System/Data/Common/DbDataReaderTest.cs @@ -21,12 +21,15 @@ // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +#nullable enable + +using System.Linq; using System.IO; using System.Threading; using System.Threading.Tasks; using Xunit; -namespace System.Data.Tests.Common +namespace System.Data.Common.Tests { public class DbDataReaderTest { @@ -499,6 +502,46 @@ public Task IsDbNullAsyncByColumnNameCanceledTest() return Assert.ThrowsAsync(() => _dataReader.IsDBNullAsync("dbnull_col", new CancellationToken(true))); } + [Fact] + public void GetSchemaTableAsync_with_cancelled_token() + => Assert.ThrowsAsync(async () => await new DbDataReaderMock().GetSchemaTableAsync(new CancellationToken(true))); + + [Fact] + public void GetSchemaTableAsync_with_exception() + => Assert.ThrowsAsync(async () => await new DbDataReaderMock().GetSchemaTableAsync()); + + [Fact] + public async Task GetSchemaTableAsync_calls_GetSchemaTable() + { + var readerTable = new DataTable(); + readerTable.Columns.Add("text_col", typeof(string)); + + var table = await new SchemaDbDataReaderMock(readerTable).GetSchemaTableAsync(); + + var textColRow = table.Rows.Cast().Single(); + Assert.Equal("text_col", textColRow["ColumnName"]); + Assert.Same(typeof(string), textColRow["DataType"]); + } + + [Fact] + public void GetColumnSchemaAsync_with_cancelled_token() + => Assert.ThrowsAsync(async () => await new DbDataReaderMock().GetColumnSchemaAsync(new CancellationToken(true))); + + [Fact] + public void GetColumnSchemaAsync_with_exception() + => Assert.ThrowsAsync(async () => await new DbDataReaderMock().GetColumnSchemaAsync()); + + [Fact] + public async Task GetColumnSchemaAsync_calls_GetSchemaTable() + { + var readerTable = new DataTable(); + readerTable.Columns.Add("text_col", typeof(string)); + + var column = (await new SchemaDbDataReaderMock(readerTable).GetColumnSchemaAsync()).Single(); + Assert.Equal("text_col", column.ColumnName); + Assert.Same(typeof(string), column.DataType); + } + private void SkipRows(int rowsToSkip) { var i = 0;