diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 8806df9432..bf1f38da00 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -1307,6 +1307,10 @@ private Type GetFieldTypeInternal(_SqlMetaData metaData) Connection.CheckGetExtendedUDTInfo(metaData, false); fieldType = metaData.udt?.Type; } + else if (metaData.type == SqlDbTypeExtensions.Vector) + { + fieldType = GetVectorFieldType(metaData.scale); + } else { // For all other types, including Xml - use data in MetaType. if (metaData.cipherMD != null) @@ -1329,6 +1333,19 @@ private Type GetFieldTypeInternal(_SqlMetaData metaData) return fieldType; } +#if !NETFRAMEWORK + [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] +#endif + private static Type GetVectorFieldType(byte vectorElementType) + { + MetaType.SqlVectorElementType elementType = (MetaType.SqlVectorElementType)vectorElementType; + return elementType switch + { + MetaType.SqlVectorElementType.Float32 => typeof(SqlVector), + _ => throw SQL.VectorTypeNotSupported(elementType.ToString()), + }; + } + virtual internal int GetLocaleId(int i) { _SqlMetaData sqlMetaData = MetaData[i]; @@ -1422,6 +1439,10 @@ private Type GetProviderSpecificFieldTypeInternal(_SqlMetaData metaData) Connection.CheckGetExtendedUDTInfo(metaData, false); providerSpecificFieldType = metaData.udt?.Type; } + else if (metaData.type == SqlDbTypeExtensions.Vector) + { + providerSpecificFieldType = GetVectorFieldType(metaData.scale); + } else { // For all other types, including Xml - use data in MetaType. diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs index 2ff72bba06..0a86534112 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/VectorTest/NativeVectorFloat32Tests.cs @@ -29,19 +29,19 @@ public static IEnumerable GetVectorFloat32TestData() yield return new object[] { 3, new SqlVector(testData), testData, vectorColumnLength }; yield return new object[] { 4, new SqlVector(testData), testData, vectorColumnLength }; - // Pattern 1-4 with SqlVector(n) + // Pattern 1-4 with SqlVector(n) yield return new object[] { 1, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; yield return new object[] { 2, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; yield return new object[] { 3, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; yield return new object[] { 4, SqlVector.CreateNull(vectorColumnLength), Array.Empty(), vectorColumnLength }; - // Pattern 1-4 with DBNull + // Pattern 1-4 with DBNull yield return new object[] { 1, DBNull.Value, Array.Empty(), vectorColumnLength }; yield return new object[] { 2, DBNull.Value, Array.Empty(), vectorColumnLength }; yield return new object[] { 3, DBNull.Value, Array.Empty(), vectorColumnLength }; yield return new object[] { 4, DBNull.Value, Array.Empty(), vectorColumnLength }; - // Pattern 1-4 with SqlVector.Null + // Pattern 1-4 with SqlVector.Null yield return new object[] { 1, SqlVector.Null, Array.Empty(), vectorColumnLength }; // Following scenario is not supported in SqlClient. @@ -561,6 +561,41 @@ public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode) Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length); } + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] + public void TestGetFieldTypeReturnsSqlVectorForVectorColumn() + { + using var connection = new SqlConnection(s_connectionString); + connection.Open(); + + // Insert a row so we can query it + using (var insertCmd = new SqlCommand(s_insertCmdString, connection)) + { + var param = insertCmd.Parameters.Add(s_vectorParamName, SqlDbTypeExtensions.Vector); + param.Value = new SqlVector(VectorFloat32TestData.testData); + insertCmd.ExecuteNonQuery(); + } + + using var selectCmd = new SqlCommand(s_selectCmdString, connection); + using var reader = selectCmd.ExecuteReader(); + + // Verify GetFieldType returns SqlVector for the vector column + Assert.Equal(typeof(SqlVector), reader.GetFieldType(0)); + + // Verify GetProviderSpecificFieldType also returns SqlVector + Assert.Equal(typeof(SqlVector), reader.GetProviderSpecificFieldType(0)); + + // Verify that GetValue returns an instance consistent with GetFieldType + Assert.True(reader.Read(), "No data found in the table."); + object value = reader.GetValue(0); + Assert.IsType>(value); + Assert.Equal(VectorFloat32TestData.testData, ((SqlVector)value).Memory.ToArray()); + + // Verify GetFieldValue> returns the correct typed value + SqlVector typedValue = reader.GetFieldValue>(0); + Assert.IsType>(typedValue); + Assert.Equal(VectorFloat32TestData.testData, typedValue.Memory.ToArray()); + } + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))] public void TestInsertVectorsFloat32WithPrepare() {