Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,10 @@ private Type GetFieldTypeInternal(_SqlMetaData metaData)
Connection.CheckGetExtendedUDTInfo(metaData, false);
fieldType = metaData.udt?.Type;
}
else if (metaData.type == SqlDbTypeExtensions.Vector)
{
fieldType = GetVectorFieldType(metaData.scale);
}
else
{ // For all other types, including Xml - use data in MetaType.
if (metaData.cipherMD != null)
Expand All @@ -1329,6 +1333,19 @@ private Type GetFieldTypeInternal(_SqlMetaData metaData)
return fieldType;
}

#if !NETFRAMEWORK
[return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)]
#endif
private static Type GetVectorFieldType(byte vectorElementType)
{
MetaType.SqlVectorElementType elementType = (MetaType.SqlVectorElementType)vectorElementType;
return elementType switch
{
MetaType.SqlVectorElementType.Float32 => typeof(SqlVector<float>),
_ => throw SQL.VectorTypeNotSupported(elementType.ToString()),
};
}

virtual internal int GetLocaleId(int i)
{
_SqlMetaData sqlMetaData = MetaData[i];
Expand Down Expand Up @@ -1422,6 +1439,10 @@ private Type GetProviderSpecificFieldTypeInternal(_SqlMetaData metaData)
Connection.CheckGetExtendedUDTInfo(metaData, false);
providerSpecificFieldType = metaData.udt?.Type;
}
else if (metaData.type == SqlDbTypeExtensions.Vector)
{
providerSpecificFieldType = GetVectorFieldType(metaData.scale);
}
else
{
// For all other types, including Xml - use data in MetaType.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ public static IEnumerable<object[]> GetVectorFloat32TestData()
yield return new object[] { 3, new SqlVector<float>(testData), testData, vectorColumnLength };
yield return new object[] { 4, new SqlVector<float>(testData), testData, vectorColumnLength };

// Pattern 1-4 with SqlVector<float>(n)
// Pattern 1-4 with SqlVector<float>(n)
yield return new object[] { 1, SqlVector<float>.CreateNull(vectorColumnLength), Array.Empty<float>(), vectorColumnLength };
yield return new object[] { 2, SqlVector<float>.CreateNull(vectorColumnLength), Array.Empty<float>(), vectorColumnLength };
yield return new object[] { 3, SqlVector<float>.CreateNull(vectorColumnLength), Array.Empty<float>(), vectorColumnLength };
yield return new object[] { 4, SqlVector<float>.CreateNull(vectorColumnLength), Array.Empty<float>(), vectorColumnLength };

// Pattern 1-4 with DBNull
// Pattern 1-4 with DBNull
yield return new object[] { 1, DBNull.Value, Array.Empty<float>(), vectorColumnLength };
yield return new object[] { 2, DBNull.Value, Array.Empty<float>(), vectorColumnLength };
yield return new object[] { 3, DBNull.Value, Array.Empty<float>(), vectorColumnLength };
yield return new object[] { 4, DBNull.Value, Array.Empty<float>(), vectorColumnLength };

// Pattern 1-4 with SqlVector<float>.Null
// Pattern 1-4 with SqlVector<float>.Null
yield return new object[] { 1, SqlVector<float>.Null, Array.Empty<float>(), vectorColumnLength };

// Following scenario is not supported in SqlClient.
Expand Down Expand Up @@ -561,6 +561,41 @@ public async Task TestBulkCopyFromSqlTableAsync(int bulkCopySourceMode)
Assert.Equal(VectorFloat32TestData.testData.Length, vector.Length);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))]
public void TestGetFieldTypeReturnsSqlVectorForVectorColumn()
{
using var connection = new SqlConnection(s_connectionString);
connection.Open();

// Insert a row so we can query it
using (var insertCmd = new SqlCommand(s_insertCmdString, connection))
{
var param = insertCmd.Parameters.Add(s_vectorParamName, SqlDbTypeExtensions.Vector);
param.Value = new SqlVector<float>(VectorFloat32TestData.testData);
insertCmd.ExecuteNonQuery();
}

using var selectCmd = new SqlCommand(s_selectCmdString, connection);
using var reader = selectCmd.ExecuteReader();

// Verify GetFieldType returns SqlVector<float> for the vector column
Assert.Equal(typeof(SqlVector<float>), reader.GetFieldType(0));

// Verify GetProviderSpecificFieldType also returns SqlVector<float>
Assert.Equal(typeof(SqlVector<float>), reader.GetProviderSpecificFieldType(0));

// Verify that GetValue returns an instance consistent with GetFieldType
Assert.True(reader.Read(), "No data found in the table.");
object value = reader.GetValue(0);
Assert.IsType<SqlVector<float>>(value);
Assert.Equal(VectorFloat32TestData.testData, ((SqlVector<float>)value).Memory.ToArray());

// Verify GetFieldValue<SqlVector<float>> returns the correct typed value
SqlVector<float> typedValue = reader.GetFieldValue<SqlVector<float>>(0);
Assert.IsType<SqlVector<float>>(typedValue);
Assert.Equal(VectorFloat32TestData.testData, typedValue.Memory.ToArray());
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsSqlVectorSupported))]
public void TestInsertVectorsFloat32WithPrepare()
{
Expand Down
Loading