From 21082b89dc6c5b5444a51315d67989dd97173d5c Mon Sep 17 00:00:00 2001 From: Milder Hernandez Cagua Date: Wed, 31 Jul 2024 10:19:39 -0700 Subject: [PATCH 1/6] Update VectorStoreRecordCollection --- ...reAISearchVectorStoreRecordCollection.java | 15 ++- .../RedisVectorStoreRecordCollection.java | 10 +- .../data/VectorStoreRecordCollection.java | 11 +- .../VolatileVectorStoreRecordCollection.java | 5 +- .../VectorStoreRecordDefinition.java | 117 ++++++++---------- 5 files changed, 77 insertions(+), 81 deletions(-) diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java index 9576b122..70d517ac 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java @@ -90,12 +90,15 @@ public AzureAISearchVectorStoreRecordCollection( : options.getRecordDefinition(); // Validate supported types - VectorStoreRecordDefinition.validateSupportedKeyTypes(this.options.getRecordClass(), - this.recordDefinition, supportedKeyTypes); - VectorStoreRecordDefinition.validateSupportedDataTypes(this.options.getRecordClass(), - this.recordDefinition, supportedDataTypes); - VectorStoreRecordDefinition.validateSupportedVectorTypes(this.options.getRecordClass(), - this.recordDefinition, supportedVectorTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + Collections.singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), + supportedKeyTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getDataDeclaredFields(this.options.getRecordClass()), 
+ supportedDataTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), + supportedVectorTypes); // Add non-vector fields to the list nonVectorFields.add(this.recordDefinition.getKeyField().getName()); diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java index 8783320b..e8d9e1db 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java @@ -81,10 +81,12 @@ public RedisVectorStoreRecordCollection( } // Validate supported types - VectorStoreRecordDefinition.validateSupportedKeyTypes(options.getRecordClass(), - recordDefinition, supportedKeyTypes); - VectorStoreRecordDefinition.validateSupportedVectorTypes(options.getRecordClass(), - recordDefinition, supportedVectorTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + Collections.singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), + supportedKeyTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), + supportedVectorTypes); // If mapper is not provided, set a default one if (options.getVectorStoreRecordMapper() == null) { diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java index 867cbf16..1466ac35 100644 --- 
a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java @@ -9,41 +9,40 @@ import java.util.List; public interface VectorStoreRecordCollection { - /** * Gets the name of the collection. * * @return The name of the collection. */ - public String getCollectionName(); + String getCollectionName(); /** * Checks if the collection exists in the store. * * @return A Mono emitting a boolean indicating if the collection exists. */ - public Mono collectionExistsAsync(); + Mono collectionExistsAsync(); /** * Creates the collection in the store. * * @return A Mono representing the completion of the creation operation. */ - public Mono createCollectionAsync(); + Mono createCollectionAsync(); /** * Creates the collection in the store if it does not exist. * * @return A Mono representing the completion of the creation operation. */ - public Mono createCollectionIfNotExistsAsync(); + Mono createCollectionIfNotExistsAsync(); /** * Deletes the collection from the store. * * @return A Mono representing the completion of the deletion operation. */ - public Mono deleteCollectionAsync(); + Mono deleteCollectionAsync(); /** * Gets a record from the store. 
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java index e675d4cd..8c1270ba 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java @@ -43,8 +43,9 @@ public VolatileVectorStoreRecordCollection(String collectionName, } // Validate the key type - VectorStoreRecordDefinition.validateSupportedKeyTypes(options.getRecordClass(), - recordDefinition, supportedKeyTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + Collections.singletonList(recordDefinition.getKeyDeclaredField(options.getRecordClass())), + supportedKeyTypes); } VolatileVectorStoreRecordCollection(String collectionName, diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java index a1914d2c..e36a258b 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java @@ -5,13 +5,12 @@ import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Set; import 
java.util.stream.Collectors; /** @@ -50,6 +49,50 @@ public List getAllFields() { return fields; } + public List getNonVectorFields() { + List fields = new ArrayList<>(); + fields.add(keyField); + fields.addAll(dataFields); + return fields; + } + + private List getDeclaredFields(Class recordClass, List fields, String fieldType) { + List declaredFields = new ArrayList<>(); + for (VectorStoreRecordField field : fields) { + try { + Field declaredField = recordClass.getDeclaredField(field.getName()); + declaredFields.add(declaredField); + } catch (NoSuchFieldException e) { + throw new IllegalArgumentException( + String.format("%s field not found in record class: %s", fieldType, field.getName())); + } + } + return declaredFields; + } + + public Field getKeyDeclaredField(Class recordClass) { + try { + return recordClass.getDeclaredField(keyField.getName()); + } catch (NoSuchFieldException e) { + throw new IllegalArgumentException( + "Key field not found in record class: " + keyField.getName()); + } + } + + public List getDataDeclaredFields(Class recordClass) { + return getDeclaredFields( + recordClass, + dataFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), + "Data"); + } + + public List getVectorDeclaredFields(Class recordClass) { + return getDeclaredFields( + recordClass, + vectorFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), + "Vector"); + } + private VectorStoreRecordDefinition( VectorStoreRecordKeyField keyField, List dataFields, @@ -148,71 +191,19 @@ public static VectorStoreRecordDefinition fromRecordClass(Class recordClass) return checkFields(keyFields, dataFields, vectorFields); } - private static String getSupportedTypesString(@Nullable HashSet> types) { - if (types == null || types.isEmpty()) { - return ""; - } - return types.stream().map(Class::getName).collect(Collectors.joining(", ")); - } - - public static void validateSupportedKeyTypes(@Nonnull Class recordClass, - @Nonnull 
VectorStoreRecordDefinition recordDefinition, - @Nonnull HashSet> supportedTypes) { - String supportedTypesString = getSupportedTypesString(supportedTypes); - - try { - Field declaredField = recordClass.getDeclaredField(recordDefinition.keyField.getName()); + public static void validateSupportedTypes(List declaredFields, Set> supportedTypes) { + Set> unsupportedTypes = new HashSet<>(); + for (Field declaredField : declaredFields) { if (!supportedTypes.contains(declaredField.getType())) { - throw new IllegalArgumentException( - "Unsupported key field type: " + declaredField.getType().getName() - + ". Supported types are: " + supportedTypesString); - } - } catch (NoSuchFieldException e) { - throw new IllegalArgumentException( - "Key field not found in record class: " + recordDefinition.keyField.getName()); - } - } - - public static void validateSupportedDataTypes(@Nonnull Class recordClass, - @Nonnull VectorStoreRecordDefinition recordDefinition, - @Nonnull HashSet> supportedTypes) { - String supportedTypesString = getSupportedTypesString(supportedTypes); - - for (VectorStoreRecordDataField field : recordDefinition.dataFields) { - try { - Field declaredField = recordClass.getDeclaredField(field.getName()); - - if (!supportedTypes.contains(declaredField.getType())) { - throw new IllegalArgumentException( - "Unsupported data field type: " + declaredField.getType().getName() - + ". 
Supported types are: " + supportedTypesString); - } - } catch (NoSuchFieldException e) { - throw new IllegalArgumentException( - "Data field not found in record class: " + field.getName()); + unsupportedTypes.add(declaredField.getType()); } } - } - - public static void validateSupportedVectorTypes(@Nonnull Class recordClass, - @Nonnull VectorStoreRecordDefinition recordDefinition, - @Nonnull HashSet> supportedTypes) { - String supportedTypesString = getSupportedTypesString(supportedTypes); - - for (VectorStoreRecordVectorField field : recordDefinition.vectorFields) { - try { - Field declaredField = recordClass.getDeclaredField(field.getName()); - - if (!supportedTypes.contains(declaredField.getType())) { - throw new IllegalArgumentException( - "Unsupported vector field type: " + declaredField.getType().getName() - + ". Supported types are: " + supportedTypesString); - } - } catch (NoSuchFieldException e) { - throw new IllegalArgumentException( - "Vector field not found in record class: " + field.getName()); - } + if (!unsupportedTypes.isEmpty()) { + throw new IllegalArgumentException( + String.format("Unsupported field types found in record class: %s. 
Supported types: %s", + unsupportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")), + supportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")))); } } } From 4d17c1f5bb0ca5209afa64bc5d190e6981c75995 Mon Sep 17 00:00:00 2001 From: Milder Hernandez Cagua Date: Thu, 1 Aug 2024 00:07:02 -0700 Subject: [PATCH 2/6] Add JDBC Vector Store --- .../JDBCVectorStoreRecordCollectionTest.java | 248 +++++++++++++ .../memory/jdbc/JDBCVectorStoreTest.java | 64 ++++ .../connectors/data/jdbc/JDBCVectorStore.java | 171 +++++++++ .../JDBCVectorStoreDefaultQueryProvider.java | 304 ++++++++++++++++ .../data/jdbc/JDBCVectorStoreOptions.java | 101 ++++++ .../jdbc/JDBCVectorStoreQueryProvider.java | 120 +++++++ .../jdbc/JDBCVectorStoreRecordCollection.java | 331 ++++++++++++++++++ ...DBCVectorStoreRecordCollectionFactory.java | 19 + ...DBCVectorStoreRecordCollectionOptions.java | 128 +++++++ .../jdbc/JDBCVectorStoreRecordMapper.java | 150 ++++++++ .../jdbc/MySQLVectorStoreQueryProvider.java | 104 ++++++ .../connectors/data/jdbc/SQLVectorStore.java | 15 + .../jdbc/SQLVectorStoreRecordCollection.java | 14 + 13 files changed, 1769 insertions(+) create mode 100644 api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java create mode 100644 api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java create mode 100644 
semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java create mode 100644 semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java new file mode 100644 index 00000000..efbe9638 --- /dev/null +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java @@ -0,0 +1,248 @@ +package com.microsoft.semantickernel.tests.connectors.memory.jdbc; + +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollection; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions; +import 
com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.tests.connectors.memory.Hotel; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import javax.annotation.Nonnull; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +@Testcontainers +public class JDBCVectorStoreRecordCollectionTest { + @Container + private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:5.7.34"); + private static final String MYSQL_USER = "test"; + private static final String MYSQL_PASSWORD = "test"; + private static Connection connection; + @BeforeAll + static void setup() throws SQLException { + connection = DriverManager.getConnection(CONTAINER.getJdbcUrl(), MYSQL_USER, MYSQL_PASSWORD); + } + + private JDBCVectorStoreRecordCollection buildRecordCollection(@Nonnull String collectionName) { + JDBCVectorStoreRecordCollection recordCollection = new JDBCVectorStoreRecordCollection<>( + connection, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(Hotel.class) + .withQueryProvider(MySQLVectorStoreQueryProvider.builder() + .withConnection(connection) + .build()) + .build()); + + recordCollection.prepareAsync().block(); + recordCollection.createCollectionIfNotExistsAsync().block(); + return recordCollection; + } + + @Test + public void buildRecordCollection() { + assertNotNull(buildRecordCollection("buildTest")); + 
} + + private List getHotels() { + return List.of( + new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), + new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f), 3.0), + new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0), + new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), + new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0) + ); + } + + @Test + public void upsertAndGetRecordAsync() { + String collectionName = "upsertAndGetRecordAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + for (Hotel hotel : hotels) { + recordStore.upsertAsync(hotel, null).block(); + } + + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + } + } + + @Test + public void getBatchAsync() { + String collectionName = "getBatchAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + for (Hotel hotel : hotels) { + recordStore.upsertAsync(hotel, null).block(); + } + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordStore.getBatchAsync(keys, null).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + } + + @Test + public void upsertBatchAndGetBatchAsync() { + String collectionName = "upsertBatchAndGetBatchAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List 
retrievedHotels = recordStore.getBatchAsync(keys, null).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + } + + @Test + public void insertAndReplaceAsync() { + String collectionName = "insertAndReplaceAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + recordStore.upsertBatchAsync(hotels, null).block(); + recordStore.upsertBatchAsync(hotels, null).block(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordStore.getBatchAsync(keys, null).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + } + + @Test + public void deleteRecordAsync() { + String collectionName = "deleteRecordAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + for (Hotel hotel : hotels) { + recordStore.deleteAsync(hotel.getId(), null).block(); + Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block(); + assertNull(retrievedHotel); + } + } + + @Test + public void deleteBatchAsync() { + String collectionName = "deleteBatchAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + recordStore.deleteBatchAsync(keys, null).block(); + + for (String key : keys) { + Hotel retrievedHotel = recordStore.getAsync(key, null).block(); + assertNull(retrievedHotel); + } + } + + @Test + public void getWithNoVectors() { + String collectionName = "getWithNoVectors"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + 
+ List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + GetRecordOptions options = GetRecordOptions.builder() + .includeVectors(false) + .build(); + + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + assertNull(retrievedHotel.getDescriptionEmbedding()); + } + + options = GetRecordOptions.builder() + .includeVectors(true) + .build(); + + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + assertNotNull(retrievedHotel.getDescriptionEmbedding()); + } + } + + @Test + public void getBatchWithNoVectors() { + String collectionName = "getBatchWithNoVectors"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + GetRecordOptions options = GetRecordOptions.builder() + .includeVectors(false) + .build(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordStore.getBatchAsync(keys, options).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + + for (Hotel hotel : retrievedHotels) { + assertNull(hotel.getDescriptionEmbedding()); + } + + options = GetRecordOptions.builder() + .includeVectors(true) + .build(); + + retrievedHotels = recordStore.getBatchAsync(keys, options).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + + for (Hotel hotel : retrievedHotels) { + assertNotNull(hotel.getDescriptionEmbedding()); + } + } +} diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java 
b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java new file mode 100644 index 00000000..0ebeed42 --- /dev/null +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java @@ -0,0 +1,64 @@ +package com.microsoft.semantickernel.tests.connectors.memory.jdbc; + +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions; +import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.tests.connectors.memory.Hotel; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Testcontainers +public class JDBCVectorStoreTest { + @Container + private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:5.7.34"); + private static final String MYSQL_USER = "test"; + private static final String MYSQL_PASSWORD = "test"; + private static Connection connection; + + @BeforeAll + static void setup() throws SQLException { + connection = DriverManager.getConnection(CONTAINER.getJdbcUrl(), MYSQL_USER, MYSQL_PASSWORD); + } + + @Test + public void getCollectionNamesAsync() { + JDBCVectorStoreOptions options = JDBCVectorStoreOptions.builder() + .withQueryProvider(MySQLVectorStoreQueryProvider.builder() + .withConnection(connection) + .build()) + .build(); + + 
JDBCVectorStore vectorStore = JDBCVectorStore.builder() + .withConnection(connection) + .withOptions(options) + .build(); + + vectorStore.getCollectionNamesAsync().block(); + + List collectionNames = Arrays.asList("collection1", "collection2", "collection3"); + + for (String collectionName : collectionNames) { + vectorStore.getCollection(collectionName, Hotel.class, null).createCollectionAsync().block(); + } + + List retrievedCollectionNames = vectorStore.getCollectionNamesAsync().block(); + assertNotNull(retrievedCollectionNames); + assertEquals(collectionNames.size(), retrievedCollectionNames.size()); + for (String collectionName : collectionNames) { + assertTrue(retrievedCollectionNames.contains(collectionName)); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java new file mode 100644 index 00000000..df2155a0 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java @@ -0,0 +1,171 @@ +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.exceptions.SKException; +import reactor.core.publisher.Mono; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +/** + * A JDBC vector store. + */ +public class JDBCVectorStore implements SQLVectorStore> { + private final Connection connection; + private final JDBCVectorStoreOptions options; + private final JDBCVectorStoreQueryProvider queryProvider; + + /** + * Creates a new instance of the {@link JDBCVectorStore}. 
+ * If using this constructor, call {@link #prepareAsync()} before using the vector store. + * + * @param connection the connection + * @param options the options + */ + public JDBCVectorStore(@Nonnull Connection connection, @Nullable JDBCVectorStoreOptions options) { + this.connection = connection; + this.options = options; + + if (this.options != null && this.options.getQueryProvider() != null) { + this.queryProvider = this.options.getQueryProvider(); + } else { + this.queryProvider = new JDBCVectorStoreDefaultQueryProvider(connection); + } + } + + /** + * Creates a new builder for the vector store. + * + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets a collection from the vector store. + * + * @param collectionName The name of the collection. + * @param recordClass The class type of the record. + * @param recordDefinition The record definition. + * @return The collection. + */ + @Override + public JDBCVectorStoreRecordCollection getCollection( + @Nonnull String collectionName, + @Nonnull Class recordClass, + @Nullable VectorStoreRecordDefinition recordDefinition) { + + if (this.options != null && this.options.getVectorStoreRecordCollectionFactory() != null) { + return this.options.getVectorStoreRecordCollectionFactory() + .createVectorStoreRecordCollection( + connection, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(recordClass) + .withRecordDefinition(recordDefinition) + .withQueryProvider(this.queryProvider) + .build()); + } + + return new JDBCVectorStoreRecordCollection<>( + connection, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(recordClass) + .withRecordDefinition(recordDefinition) + .withQueryProvider(this.queryProvider) + .build()); + } + + /** + * Gets the names of all collections in the vector store. + * + * @return A list of collection names. 
+ */ + @Override + public Mono> getCollectionNamesAsync() { + return Mono.fromCallable(() -> { + List collectionNames = new ArrayList<>(); + try { + ResultSet resultSet = queryProvider.getCollectionNames(); + while (resultSet.next()) { + collectionNames.add(resultSet.getString(1)); + } + + return collectionNames; + } catch (SQLException e) { + throw new SKException("Failed to get collection names.", e); + } + }); + } + + @Override + public Mono prepareAsync() { + return Mono.fromRunnable(() -> { + try { + queryProvider.prepareVectorStore(); + } catch (SQLException e) { + throw new SKException("Failed to prepare vector store.", e); + } + }); + } + + /** + * Builder for creating a {@link JDBCVectorStore}. + */ + public static class Builder { + private Connection connection; + private JDBCVectorStoreOptions options; + + /** + * Sets the connection. + * + * @param connection the connection + * @return the builder + */ + public Builder withConnection(Connection connection) { + this.connection = connection; + return this; + } + + /** + * Sets the options. + * + * @param options the options + * @return the builder + */ + public Builder withOptions(JDBCVectorStoreOptions options) { + this.options = options; + return this; + } + + /** + * Builds the {@link JDBCVectorStore}. + * + * @return the {@link JDBCVectorStore} + */ + public JDBCVectorStore build() { + return buildAsync().block(); + } + + /** + * Builds the {@link JDBCVectorStore} asynchronously. 
+ * + * @return the {@link Mono} with the {@link JDBCVectorStore} + */ + public Mono buildAsync() { + if (connection == null) { + throw new IllegalArgumentException("connection is required"); + } + + JDBCVectorStore vectorStore = new JDBCVectorStore(connection, options); + return vectorStore.prepareAsync().thenReturn(vectorStore); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java new file mode 100644 index 00000000..ceb2c544 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java @@ -0,0 +1,304 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.exceptions.SKException; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import javax.annotation.Nonnull; +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.time.OffsetDateTime; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class JDBCVectorStoreDefaultQueryProvider + implements JDBCVectorStoreQueryProvider { + private static final Map, 
String> supportedKeyTypes; + private static final Map, String> supportedDataTypes; + private static final Map, String> supportedVectorTypes; + + static { + supportedKeyTypes = new HashMap<>(); + supportedKeyTypes.put(String.class, "VARCHAR(255)"); + + supportedDataTypes = new HashMap<>(); + supportedDataTypes.put(String.class, "TEXT"); + supportedDataTypes.put(Integer.class, "INTEGER"); + supportedDataTypes.put(int.class, "INTEGER"); + supportedDataTypes.put(Long.class, "BIGINT"); + supportedDataTypes.put(long.class, "BIGINT"); + supportedDataTypes.put(Float.class, "REAL"); + supportedDataTypes.put(float.class, "REAL"); + supportedDataTypes.put(Double.class, "DOUBLE"); + supportedDataTypes.put(double.class, "DOUBLE"); + supportedDataTypes.put(Boolean.class, "BOOLEAN"); + supportedDataTypes.put(boolean.class, "BOOLEAN"); + supportedDataTypes.put(OffsetDateTime.class, "TIMESTAMPTZ"); + + supportedVectorTypes = new HashMap<>(); + supportedVectorTypes.put(String.class, "TEXT"); + supportedVectorTypes.put(List.class, "TEXT"); + supportedVectorTypes.put(Collection.class, "TEXT"); + } + + protected final Connection connection; + protected final String collectionsTable; + protected final String prefixForCollectionTables; + + @SuppressFBWarnings("EI_EXPOSE_REP2") + public JDBCVectorStoreDefaultQueryProvider( + @Nonnull Connection connection, + @Nonnull String collectionsTable, + @Nonnull String prefixForCollectionTables) { + this.connection = connection; + // Validate table name + if (!isValidSQLIdentifier(collectionsTable)) { + throw new IllegalArgumentException("Invalid collections table name: " + collectionsTable); + } + if (!isValidSQLIdentifier(prefixForCollectionTables)) { + throw new IllegalArgumentException("Invalid prefix for collection tables: " + prefixForCollectionTables); + } + + this.collectionsTable = collectionsTable; + this.prefixForCollectionTables = prefixForCollectionTables; + } + + public JDBCVectorStoreDefaultQueryProvider( + @Nonnull Connection 
connection) { + this(connection, DEFAULT_COLLECTIONS_TABLE, DEFAULT_PREFIX_FOR_COLLECTION_TABLES); + } + + /** + * Creates a new builder. + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Formats a wildcard string for a query. + * @param wildcards the number of wildcards + * @return the formatted wildcard string + */ + protected String getWildcardString(int wildcards) { + StringBuilder wildcardString = new StringBuilder(); + for (int i = 0; i < wildcards; ++i) { + wildcardString.append("?"); + if (i < wildcards - 1) { + wildcardString.append(", "); + } + } + return wildcardString.toString(); + } + + /** + * Formats the query columns from a record definition. + * @param fields the fields to get the columns from + * @return the formatted query columns + */ + protected String getQueryColumnsFromFields(List fields) { + return fields.stream().map(VectorStoreRecordField::getName) + .collect(Collectors.joining(", ")); + } + + protected String getColumnNamesAndTypes(List fields, Map, String> types) { + List columns = fields.stream() + .map(field -> field.getName() + " " + types.get(field.getType())) + .collect(Collectors.toList()); + + return String.join(", ", columns); + } + + protected String getCollectionTableName(String collectionName) { + return prefixForCollectionTables + collectionName; + } + + @Override + public void prepareVectorStore() throws SQLException { + String createCollectionsTable = + "CREATE TABLE IF NOT EXISTS " + collectionsTable + + " (collectionId VARCHAR(255) PRIMARY KEY);"; + + PreparedStatement createTable = connection.prepareStatement(createCollectionsTable); + createTable.execute(); + } + + @Override + public void validateSupportedTypes(Class recordClass, VectorStoreRecordDefinition recordDefinition) { + VectorStoreRecordDefinition.validateSupportedTypes( + Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)), supportedKeyTypes.keySet()); + 
VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getDataDeclaredFields(recordClass), supportedDataTypes.keySet()); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getVectorDeclaredFields(recordClass), supportedVectorTypes.keySet()); + } + + @Override + public boolean collectionExists(String collectionName) throws SQLException { + String query = "SELECT 1 FROM " + collectionsTable + " WHERE collectionId = ?"; + + PreparedStatement statement = connection.prepareStatement(query); + statement.setObject(1, collectionName); + + return statement.executeQuery().next(); + } + + @Override + public void createCollection(String collectionName, Class recordClass, VectorStoreRecordDefinition recordDefinition) throws SQLException { + Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass); + List dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass); + List vectorDeclaredFields = recordDefinition.getVectorDeclaredFields(recordClass); + + String createStorageTable = + "CREATE TABLE IF NOT EXISTS " + getCollectionTableName(collectionName) + + " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, " + + getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", " + + getColumnNamesAndTypes(vectorDeclaredFields, supportedVectorTypes) + ");"; + + PreparedStatement createTable = connection.prepareStatement(createStorageTable); + + String insertCollectionQuery = "INSERT INTO " + collectionsTable + " (collectionId) VALUES (?)"; + PreparedStatement insert = connection.prepareStatement(insertCollectionQuery); + insert.setObject(1, collectionName); + + createTable.execute(); + insert.execute(); + } + + @Override + public void deleteCollection(String collectionName) throws SQLException { + String deleteCollectionOperation = "DELETE FROM " + collectionsTable + " WHERE collectionId = ?"; + String dropTableOperation = "DROP TABLE " + getCollectionTableName(collectionName); + + 
PreparedStatement deleteCollection = connection.prepareStatement(deleteCollectionOperation); + deleteCollection.setObject(1, collectionName); + + PreparedStatement dropTable = connection.prepareStatement(dropTableOperation); + + dropTable.execute(); + deleteCollection.execute(); + } + + @Override + public ResultSet getCollectionNames() throws SQLException { + String query = "SELECT collectionId FROM " + collectionsTable; + + return connection.prepareStatement(query).executeQuery(); + } + + @Override + public ResultSet getRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, GetRecordOptions options) throws SQLException { + List fields; + if (options == null || options.includeVectors()) { + fields = recordDefinition.getAllFields(); + } else { + fields = recordDefinition.getNonVectorFields(); + } + + String query = "SELECT " + getQueryColumnsFromFields(fields) + + " FROM " + getCollectionTableName(collectionName) + + " WHERE " + recordDefinition.getKeyField().getName() + + " IN (" + getWildcardString(keys.size()) + ")"; + + PreparedStatement statement = connection.prepareStatement(query); + for (int i = 0; i < keys.size(); ++i) { + try { + statement.setObject(i + 1, keys.get(i)); + } catch (SQLException e) { + throw new SKException("Failed to set statement values", e); + } + } + + return statement.executeQuery(); + } + + @Override + public void upsertRecords(String collectionName, List records, VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) throws SQLException { + throw new UnsupportedOperationException( + "Upsert is not supported. 
Try with a specific query provider."); + } + + @Override + public void deleteRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) throws SQLException { + String query = "DELETE FROM " + getCollectionTableName(collectionName) + + " WHERE " + recordDefinition.getKeyField().getName() + + " IN (" + getWildcardString(keys.size()) + ")"; + + PreparedStatement statement = connection.prepareStatement(query); + for (int i = 0; i < keys.size(); ++i) { + try { + statement.setObject(i + 1, keys.get(i)); + } catch (SQLException e) { + throw new SKException("Failed to set statement values", e); + } + } + + statement.execute(); + } + + public static boolean isValidSQLIdentifier(String identifier) { + return identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*"); + } + + /** + * The builder for {@link JDBCVectorStoreDefaultQueryProvider}. + */ + public static class Builder + implements JDBCVectorStoreQueryProvider.Builder { + protected Connection connection; + protected String collectionsTable = DEFAULT_COLLECTIONS_TABLE; + protected String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + + /** + * Sets the connection. + * @param connection the connection + * @return the builder + */ + public Builder withConnection(Connection connection) { + this.connection = connection; + return this; + } + + /** + * Sets the collections table name. + * @param collectionsTable the collections table name + * @return the builder + */ + public Builder withCollectionsTable(String collectionsTable) { + this.collectionsTable = collectionsTable; + return this; + } + + /** + * Sets the prefix for collection tables. 
+ * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + this.prefixForCollectionTables = prefixForCollectionTables; + return this; + } + + @Override + public JDBCVectorStoreDefaultQueryProvider build() { + if (connection == null) { + throw new IllegalArgumentException("connection is required"); + } + + return new JDBCVectorStoreDefaultQueryProvider(connection, collectionsTable, prefixForCollectionTables); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java new file mode 100644 index 00000000..6ecb59ea --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java @@ -0,0 +1,101 @@ +package com.microsoft.semantickernel.connectors.data.jdbc; + +import javax.annotation.Nullable; + +public class JDBCVectorStoreOptions { + @Nullable + private final JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory; + @Nullable + private final JDBCVectorStoreQueryProvider queryProvider; + + /** + * Creates a new instance of the JDBC vector store options. + * + * @param vectorStoreRecordCollectionFactory The vector store record collection factory. + */ + public JDBCVectorStoreOptions( + @Nullable JDBCVectorStoreQueryProvider queryProvider, + @Nullable JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory) { + this.queryProvider = queryProvider; + this.vectorStoreRecordCollectionFactory = vectorStoreRecordCollectionFactory; + } + + /** + * Creates a new instance of the JDBC vector store options. + */ + public JDBCVectorStoreOptions() { + this(null, null); + } + + /** + * Gets the query provider. 
+ * + * @return the query provider + */ + @Nullable + public JDBCVectorStoreQueryProvider getQueryProvider() { + return queryProvider; + } + + /** + * Creates a new builder. + * + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the vector store record collection factory. + * + * @return the vector store record collection factory + */ + @Nullable + public JDBCVectorStoreRecordCollectionFactory getVectorStoreRecordCollectionFactory() { + return vectorStoreRecordCollectionFactory; + } + + /** + * Builder for JDBC vector store options. + * + */ + public static class Builder { + @Nullable + private JDBCVectorStoreQueryProvider queryProvider; + @Nullable + private JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory; + + /** + * Sets the query provider. + * + * @param queryProvider The query provider. + * @return The updated builder instance. + */ + public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvider) { + this.queryProvider = queryProvider; + return this; + } + + /** + * Sets the vector store record collection factory. + * + * @param vectorStoreRecordCollectionFactory The vector store record collection factory. + * @return The updated builder instance. + */ + public Builder withVectorStoreRecordCollectionFactory( + JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory) { + this.vectorStoreRecordCollectionFactory = vectorStoreRecordCollectionFactory; + return this; + } + + /** + * Builds the JDBC vector store options. + * + * @return The JDBC vector store options. 
+ */ + public JDBCVectorStoreOptions build() { + return new JDBCVectorStoreOptions(queryProvider, vectorStoreRecordCollectionFactory); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java new file mode 100644 index 00000000..6104de8b --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; + +/** + * The JDBC vector store query provider. + * Provides the necessary methods to interact with a JDBC vector store and vector store collections. + */ +public interface JDBCVectorStoreQueryProvider { + /** + * The default name for the collections table. + */ + String DEFAULT_COLLECTIONS_TABLE = "SKCollections"; + + /** + * The prefix for collection tables. + */ + String DEFAULT_PREFIX_FOR_COLLECTION_TABLES = "SKCollection_"; + + /** + * Prepares the vector store. + * Executes any necessary setup steps for the vector store. + * + * @throws SQLException if an error occurs + */ + void prepareVectorStore() throws SQLException; + + /** + * Checks if the types of the record class fields are supported. 
    /**
     * Gets the names of the collections.
     *
     * <p>NOTE(review): the default provider returns the result of {@code executeQuery()}
     * directly and closes neither the statement nor the result set, so the caller
     * presumably owns the returned {@link ResultSet} and should close it. Confirm and
     * state this for implementers.
     *
     * @return the result set; the default provider selects a single {@code collectionId}
     *         column
     * @throws SQLException if an error occurs
     */
    ResultSet getCollectionNames() throws SQLException;
+ * + * @param collectionName the collection name + * @param records the records + * @param vectorStoreRecordDefinition the record definition + * @param options the options + * @throws SQLException if an error occurs + */ + void upsertRecords(String collectionName, List records, VectorStoreRecordDefinition vectorStoreRecordDefinition, UpsertRecordOptions options) throws SQLException; + + /** + * Deletes records. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param options the options + * @throws SQLException if an error occurs + */ + void deleteRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) throws SQLException; + + /** + * The builder for the JDBC vector store query provider. + */ + interface Builder extends SemanticKernelBuilder { + + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java new file mode 100644 index 00000000..54f0b76d --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java @@ -0,0 +1,331 @@ +// Copyright (c) Microsoft. All rights reserved. 
+package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import com.microsoft.semantickernel.exceptions.SKException; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import javax.annotation.Nonnull; +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +public class JDBCVectorStoreRecordCollection + implements SQLVectorStoreRecordCollection { + private final String collectionName; + private final VectorStoreRecordDefinition recordDefinition; + private final JDBCVectorStoreRecordCollectionOptions options; + private final JDBCVectorStoreRecordMapper vectorStoreRecordMapper; + private final JDBCVectorStoreQueryProvider queryProvider; + + /** + * Creates a new instance of the JDBCVectorRecordStore. + * If using this constructor, call {@link #prepareAsync()} before using the record collection. + * + * @param connection The JDBC connection. + * @param options The options for the store. + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") + public JDBCVectorStoreRecordCollection( + @Nonnull Connection connection, + @Nonnull String collectionName, + @Nonnull JDBCVectorStoreRecordCollectionOptions options) { + this.collectionName = collectionName; + this.options = options; + + // If record definition is not provided, create one from the record class + recordDefinition = options.getRecordDefinition() == null + ? 
VectorStoreRecordDefinition.fromRecordClass(options.getRecordClass()) + : options.getRecordDefinition(); + + // If mapper is not provided, set a default one + if (options.getVectorStoreRecordMapper() == null) { + vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder() + .withRecordClass(options.getRecordClass()) + .withVectorStoreRecordDefinition(recordDefinition) + .build(); + } else { + vectorStoreRecordMapper = options.getVectorStoreRecordMapper(); + } + + // If the query provider is not provided, set a default one + if (options.getQueryProvider() == null) { + this.queryProvider = new JDBCVectorStoreDefaultQueryProvider(connection); + } else { + this.queryProvider = options.getQueryProvider(); + } + + // Check if the types are supported + queryProvider.validateSupportedTypes(options.getRecordClass(), recordDefinition); + } + + /** + * Gets the name of the collection. + * + * @return The name of the collection. + */ + @Override + public String getCollectionName() { + return collectionName; + } + + /** + * Checks if the collection exists in the store. + * + * @return A Mono emitting a boolean indicating if the collection exists. + */ + @Override + public Mono collectionExistsAsync() { + return Mono.fromCallable( + () -> { + try { + return queryProvider.collectionExists(this.collectionName); + } catch (SQLException e) { + throw new SKException("Failed to check if collection exists", e); + } + }) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Creates the collection in the store. + * + * @return A Mono representing the completion of the creation operation. 
+ */ + @Override + public Mono createCollectionAsync() { + return Mono.fromRunnable( + () -> { + try { + queryProvider.createCollection(this.collectionName, options.getRecordClass(), recordDefinition); + } catch (SQLException e) { + throw new SKException("Failed to create collection", e); + } + }) + .subscribeOn(Schedulers.boundedElastic()) + .then(); + } + + /** + * Creates the collection in the store if it does not exist. + * + * @return A Mono representing the completion of the creation operation. + */ + @Override + public Mono createCollectionIfNotExistsAsync() { + return collectionExistsAsync().map( + exists -> { + if (!exists) { + return createCollectionAsync(); + } + return Mono.empty(); + }) + .flatMap(mono -> mono) + .then(); + } + + /** + * Deletes the collection from the store. + * + * @return A Mono representing the completion of the deletion operation. + */ + @Override + public Mono deleteCollectionAsync() { + return Mono.fromRunnable( + () -> { + try { + queryProvider.deleteCollection(this.collectionName); + } catch (SQLException e) { + throw new SKException("Failed to delete collection", e); + } + }).subscribeOn(Schedulers.boundedElastic()).then(); + } + + /** + * Gets a record from the store. + * + * @param key The key of the record to get. + * @param options The options for getting the record. + * @return A Mono emitting the record. + */ + @Override + public Mono getAsync(String key, GetRecordOptions options) { + return this.getBatchAsync(Collections.singletonList(key), options) + .mapNotNull(records -> { + if (records.isEmpty()) { + return null; + } + return records.get(0); + }); + } + + /** + * Gets a batch of records from the store. + * + * @param keys The keys of the records to get. + * @param options The options for getting the records. + * @return A Mono emitting a collection of records. 
+ */ + @Override + public Mono> getBatchAsync(List keys, GetRecordOptions options) { + return Mono.fromCallable( + () -> { + List records = new ArrayList<>(); + + try { + ResultSet resultSet = queryProvider.getRecords(this.collectionName, keys, recordDefinition, options); + while (resultSet.next()) { + records.add(vectorStoreRecordMapper.mapStorageModeltoRecord(resultSet)); + } + } catch (SQLException e) { + throw new SKException("Failed to get records", e); + } + + return records; + }).subscribeOn(Schedulers.boundedElastic()); + } + + protected String getKeyFromRecord(Record data) { + try { + Field keyField = data.getClass().getDeclaredField(recordDefinition.getKeyField().getName()); + keyField.setAccessible(true); + return (String) keyField.get(data); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new SKException("Failed to get key from record", e); + } + } + + /** + * Inserts or updates a record in the store. + * + * @param data The record to upsert. + * @param options The options for upserting the record. + * @return A Mono emitting the key of the upserted record. + */ + @Override + public Mono upsertAsync(Record data, UpsertRecordOptions options) { + return this.upsertBatchAsync(Collections.singletonList(data), options) + .mapNotNull(keys -> { + if (keys.isEmpty()) { + return null; + } + return keys.get(0); + }); + } + + /** + * Inserts or updates a batch of records in the store. + * + * @param data The records to upsert. + * @param options The options for upserting the records. + * @return A Mono emitting a collection of keys of the upserted records. 
+ */ + @Override + public Mono> upsertBatchAsync(List data, UpsertRecordOptions options) { + return Mono.fromCallable( + () -> { + try { + queryProvider.upsertRecords(this.collectionName, data, recordDefinition, options); + + return data.stream().map(this::getKeyFromRecord).collect(Collectors.toList()); + } catch (SQLException e) { + throw new SKException("Failed to upsert records", e); + } + }) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Deletes a record from the store. + * + * @param key The key of the record to delete. + * @param options The options for deleting the record. + * @return A Mono representing the completion of the deletion operation. + */ + @Override + public Mono deleteAsync(String key, DeleteRecordOptions options) { + return this.deleteBatchAsync(Collections.singletonList(key), options); + } + + /** + * Deletes a batch of records from the store. + * + * @param keys The keys of the records to delete. + * @param options The options for deleting the records. + * @return A Mono representing the completion of the deletion operation. + */ + @Override + public Mono deleteBatchAsync(List keys, DeleteRecordOptions options) { + return Mono.fromRunnable( + () -> { + try { + queryProvider.deleteRecords(this.collectionName, keys, recordDefinition, options); + } catch (SQLException e) { + throw new SKException("Failed to delete records", e); + } + }).subscribeOn(Schedulers.boundedElastic()).then(); + } + + /** + * Prepares the collection for use. + * + * @return A Mono representing the completion of the preparation operation. 
+ */ + @Override + public Mono prepareAsync() { + return Mono.fromRunnable(() -> { + try { + queryProvider.prepareVectorStore(); + } catch (SQLException e) { + throw new SKException("Failed to prepare vector store record collection", e); + } + }).subscribeOn(Schedulers.boundedElastic()).then(); + } + + public static class Builder implements SemanticKernelBuilder> { + private Connection connection; + private String collectionName; + private JDBCVectorStoreRecordCollectionOptions options; + + public Builder withConnection(Connection connection) { + this.connection = connection; + return this; + } + + public Builder withCollectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + public Builder withOptions(JDBCVectorStoreRecordCollectionOptions options) { + this.options = options; + return this; + } + + @Override + public JDBCVectorStoreRecordCollection build() { + if (connection == null) { + throw new IllegalArgumentException("connection is required"); + } + if (collectionName == null) { + throw new IllegalArgumentException("collectionName is required"); + } + if (options == null) { + throw new IllegalArgumentException("options is required"); + } + + return new JDBCVectorStoreRecordCollection<>(connection, collectionName, options); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java new file mode 100644 index 00000000..4d92eb25 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java @@ -0,0 +1,19 @@ +package com.microsoft.semantickernel.connectors.data.jdbc; + +import java.sql.Connection; + +/** + * Factory for creating JDBC vector store record collections. 
    /**
     * Creates a new JDBC vector store record collection.
     *
     * @param connection     The JDBC connection the collection will use.
     * @param collectionName The name of the collection.
     * @param options        The options for the collection.
     * @return The new JDBC vector store record collection.
     */
    JDBCVectorStoreRecordCollection createVectorStoreRecordCollection(
        Connection connection,
        String collectionName,
        JDBCVectorStoreRecordCollectionOptions options);
+ * @return the record class + */ + public Class getRecordClass() { + return recordClass; + } + + /** + * Gets the record definition. + * @return the record definition + */ + public VectorStoreRecordDefinition getRecordDefinition() { + return recordDefinition; + } + + /** + * Gets the vector store record mapper. + * @return the vector store record mapper + */ + public JDBCVectorStoreRecordMapper getVectorStoreRecordMapper() { + return vectorStoreRecordMapper; + } + + /** + * Gets the query provider. + * @return the query provider + */ + public JDBCVectorStoreQueryProvider getQueryProvider() { + return queryProvider; + } + + public static class Builder { + private Class recordClass; + private VectorStoreRecordDefinition recordDefinition; + private JDBCVectorStoreRecordMapper vectorStoreRecordMapper; + private JDBCVectorStoreQueryProvider queryProvider; + + /** + * Sets the record class. + * @param recordClass the record class + * @return the builder + */ + public Builder withRecordClass(Class recordClass) { + this.recordClass = recordClass; + return this; + } + + /** + * Sets the record definition. + * @param recordDefinition the record definition + * @return the builder + */ + public Builder withRecordDefinition(VectorStoreRecordDefinition recordDefinition) { + this.recordDefinition = recordDefinition; + return this; + } + + /** + * Sets the vector store record mapper. + * @param vectorStoreRecordMapper the vector store record mapper + * @return the builder + */ + public Builder withVectorStoreRecordMapper( + JDBCVectorStoreRecordMapper vectorStoreRecordMapper) { + this.vectorStoreRecordMapper = vectorStoreRecordMapper; + return this; + } + + /** + * Sets the query provider. + * @param queryProvider the query provider + * @return the builder + */ + public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvider) { + this.queryProvider = queryProvider; + return this; + } + + /** + * Builds the options. 
+ * @return the options + */ + public JDBCVectorStoreRecordCollectionOptions build() { + if (recordClass == null) { + throw new IllegalArgumentException("recordClass is required"); + } + + return new JDBCVectorStoreRecordCollectionOptions<>( + recordClass, + recordDefinition, + vectorStoreRecordMapper, + queryProvider + ); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java new file mode 100644 index 00000000..6eff0c7d --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.exceptions.SKException; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; + +import java.sql.ResultSetMetaData; +import java.util.List; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.function.Function; + +public class JDBCVectorStoreRecordMapper + extends VectorStoreRecordMapper { + + /** + * Constructs a new instance of the VectorStoreRecordMapper. 
+ * + * @param storageModelToRecordMapper the function to convert a storage model to a record + */ + protected JDBCVectorStoreRecordMapper(Function storageModelToRecordMapper) { + super(null, storageModelToRecordMapper); + } + + /** + * Creates a new builder. + * + * @param the record type + * @return the builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Operation not supported. + */ + @Override + public ResultSet mapRecordToStorageModel(Record record) { + throw new UnsupportedOperationException("Not implemented"); + } + + public static class Builder + implements SemanticKernelBuilder> { + private Class recordClass; + private VectorStoreRecordDefinition vectorStoreRecordDefinition; + + /** + * Sets the record class. + * + * @param recordClass the record class + * @return the builder + */ + public Builder withRecordClass(Class recordClass) { + this.recordClass = recordClass; + return this; + } + + /** + * Sets the vector store record definition. + * + * @param vectorStoreRecordDefinition the vector store record definition + * @return the builder + */ + public Builder withVectorStoreRecordDefinition( + VectorStoreRecordDefinition vectorStoreRecordDefinition) { + this.vectorStoreRecordDefinition = vectorStoreRecordDefinition; + return this; + } + + /** + * Builds the {@link JDBCVectorStoreRecordMapper}. + * + * @return the {@link JDBCVectorStoreRecordMapper} + */ + public JDBCVectorStoreRecordMapper build() { + if (recordClass == null) { + throw new IllegalArgumentException("recordClass is required"); + } + if (vectorStoreRecordDefinition == null) { + throw new IllegalArgumentException("vectorStoreRecordDefinition is required"); + } + + return new JDBCVectorStoreRecordMapper<>( + resultSet -> { + try { + Constructor constructor = recordClass.getDeclaredConstructor(); + constructor.setAccessible(true); + Record record = (Record) constructor.newInstance(); + + // Select fields from the record definition. 
+ // Check if vector fields are present in the result set. + List fields; + ResultSetMetaData metaData = resultSet.getMetaData(); + if (metaData.getColumnCount() == vectorStoreRecordDefinition.getAllFields() + .size()) { + fields = vectorStoreRecordDefinition.getAllFields(); + } else { + fields = vectorStoreRecordDefinition.getNonVectorFields(); + } + + for (VectorStoreRecordField field : fields) { + Object value = resultSet.getObject(field.getName()); + Field recordField = recordClass.getDeclaredField(field.getName()); + recordField.setAccessible(true); + + // If the field is a vector field, deserialize the JSON string + if (field instanceof VectorStoreRecordVectorField) { + Class vectorType = recordField.getType(); + + // If the vector type is a string, set the value directly + if (vectorType.equals(String.class)) { + recordField.set(record, value); + } else { + // Deserialize the JSON string to the vector type + recordField.set(record, + new ObjectMapper().readValue((String) value, vectorType)); + } + } else { + recordField.set(record, value); + } + } + + return record; + } catch (NoSuchMethodException e) { + throw new SKException("Default constructor not found.", e); + } catch (InstantiationException | InvocationTargetException e) { + throw new SKException(String.format( + "SK cannot instantiate %s. A custom mapper is required.", + recordClass.getName()), e); + } catch (JsonProcessingException e) { + throw new SKException(String.format( + "SK cannot deserialize %s. 
A custom mapper is required.", + recordClass.getName()), e); + } catch (SQLException | NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + }); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java new file mode 100644 index 00000000..f7211f24 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import com.microsoft.semantickernel.exceptions.SKException; + +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.List; + +public class MySQLVectorStoreQueryProvider extends + JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { + + public MySQLVectorStoreQueryProvider(Connection connection, String collectionsTable, String prefixForCollectionTables) { + super(connection, collectionsTable, prefixForCollectionTables); + } + + /** + * Creates a new builder. 
+ * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + private void setStatementValues(PreparedStatement statement, Object record, List fields) { + for (int i = 0; i < fields.size(); ++i) { + VectorStoreRecordField field = fields.get(i); + try { + Field recordField = record.getClass().getDeclaredField(field.getName()); + recordField.setAccessible(true); + Object value = recordField.get(record); + + if (field instanceof VectorStoreRecordKeyField) { + statement.setObject(i + 1, (String) value); + } else if (field instanceof VectorStoreRecordVectorField) { + Class vectorType = record.getClass().getDeclaredField(field.getName()).getType(); + + // If the vector field is other than String, serialize it to JSON + if (vectorType.equals(String.class)) { + statement.setObject(i + 1, value); + } else { + // Serialize the vector to JSON + statement.setObject(i + 1, new ObjectMapper().writeValueAsString(value)); + } + } else { + statement.setObject(i + 1, value); + } + } catch (NoSuchFieldException | IllegalAccessException | SQLException e) { + throw new SKException("Failed to set statement values", e); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public void upsertRecords(String collectionName, List records, VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) throws SQLException { + List fields = recordDefinition.getAllFields(); + + StringBuilder onDuplicateKeyUpdate = new StringBuilder(); + for (int i = 0; i < fields.size(); ++i) { + VectorStoreRecordField field = fields.get(i); + if (i > 0) { + onDuplicateKeyUpdate.append(", "); + } + + onDuplicateKeyUpdate.append(field.getName()).append(" = VALUES(").append(field.getName()).append(")"); + } + + String query = "INSERT INTO " + getCollectionTableName(collectionName) + + " (" + getQueryColumnsFromFields(fields) + ")" + + " VALUES (" + getWildcardString(fields.size()) + ")" + + " ON DUPLICATE KEY UPDATE " + 
onDuplicateKeyUpdate; + + PreparedStatement statement = connection.prepareStatement(query); + + for (Object record : records) { + setStatementValues(statement, record, recordDefinition.getAllFields()); + statement.addBatch(); + } + + statement.executeBatch(); + } + + public static class Builder + extends JDBCVectorStoreDefaultQueryProvider.Builder { + public MySQLVectorStoreQueryProvider build() { + if (connection == null) { + throw new IllegalArgumentException("connection is required"); + } + + return new MySQLVectorStoreQueryProvider(connection, collectionsTable, prefixForCollectionTables); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java new file mode 100644 index 00000000..dd19f787 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java @@ -0,0 +1,15 @@ +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.VectorStore; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import reactor.core.publisher.Mono; + +public interface SQLVectorStore> extends VectorStore { + + /** + * Prepares the vector store. + * + * @return A {@link Mono} that completes when the vector store is prepared to be used. 
+ */ + Mono prepareAsync(); +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java new file mode 100644 index 00000000..3583a273 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java @@ -0,0 +1,14 @@ +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import reactor.core.publisher.Mono; + +public interface SQLVectorStoreRecordCollection extends VectorStoreRecordCollection { + + /** + * Prepares the vector store record collection. + * + * @return A {@link Mono} that completes when the vector store record collection is prepared to be used. + */ + Mono prepareAsync(); +} From 4fe5a421ad4cb7d050216071a252d3611a8a1f52 Mon Sep 17 00:00:00 2001 From: Milder Hernandez Cagua Date: Thu, 1 Aug 2024 02:28:23 -0700 Subject: [PATCH 3/6] Format --- .../services/audio/AudioToTextService.java | 3 +- .../services/openai/OpenAiServiceBuilder.java | 5 +- .../textcompletion/TextGenerationService.java | 3 +- ...reAISearchVectorStoreRecordCollection.java | 13 +- .../connectors/data/jdbc/JDBCVectorStore.java | 69 +++--- .../JDBCVectorStoreDefaultQueryProvider.java | 214 +++++++++++------- .../data/jdbc/JDBCVectorStoreOptions.java | 1 + .../jdbc/JDBCVectorStoreQueryProvider.java | 41 ++-- .../jdbc/JDBCVectorStoreRecordCollection.java | 99 +++----- ...DBCVectorStoreRecordCollectionFactory.java | 1 + ...DBCVectorStoreRecordCollectionOptions.java | 72 +++++- .../jdbc/MySQLVectorStoreQueryProvider.java | 44 ++-- .../connectors/data/jdbc/SQLVectorStore.java | 4 +- .../jdbc/SQLVectorStoreRecordCollection.java | 4 +- .../RedisVectorStoreRecordCollection.java | 9 +- .../VolatileVectorStoreRecordCollection.java 
| 3 +- .../VectorStoreRecordDefinition.java | 31 +-- 17 files changed, 351 insertions(+), 265 deletions(-) diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java index 70224596..871d4cb4 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java @@ -33,7 +33,8 @@ static Builder builder() { /** * Builder for the AudioToTextService. */ - abstract class Builder extends OpenAiServiceBuilder { + abstract class Builder + extends OpenAiServiceBuilder { } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java index e42b2c6e..5386cd83 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java @@ -11,8 +11,9 @@ * @param The service type * @param The builder type */ -public abstract class OpenAiServiceBuilder> implements - +public abstract class OpenAiServiceBuilder> + implements + SemanticKernelBuilder { @Nullable diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java index 8010b67f..0ab08f5f 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java @@ -61,6 
+61,7 @@ Flux getStreamingTextContentsAsync( /** * Builder for a TextGenerationService */ - abstract class Builder extends OpenAiServiceBuilder { + abstract class Builder + extends OpenAiServiceBuilder { } } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java index 70d517ac..5155299d 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java @@ -91,14 +91,15 @@ public AzureAISearchVectorStoreRecordCollection( // Validate supported types VectorStoreRecordDefinition.validateSupportedTypes( - Collections.singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), - supportedKeyTypes); + Collections + .singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), + supportedKeyTypes); VectorStoreRecordDefinition.validateSupportedTypes( - recordDefinition.getDataDeclaredFields(this.options.getRecordClass()), - supportedDataTypes); + recordDefinition.getDataDeclaredFields(this.options.getRecordClass()), + supportedDataTypes); VectorStoreRecordDefinition.validateSupportedTypes( - recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), - supportedVectorTypes); + recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), + supportedVectorTypes); // Add non-vector fields to the list nonVectorFields.add(this.recordDefinition.getKeyField().getName()); diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java 
b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java index df2155a0..7e02f4a8 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java @@ -1,15 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.jdbc; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; -import com.microsoft.semantickernel.exceptions.SKException; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.ArrayList; import java.util.List; /** @@ -27,14 +26,18 @@ public class JDBCVectorStore implements SQLVectorStore JDBCVectorStoreRecordCollection getCollection( - @Nonnull String collectionName, - @Nonnull Class recordClass, - @Nullable VectorStoreRecordDefinition recordDefinition) { + @Nonnull String collectionName, + @Nonnull Class recordClass, + @Nullable VectorStoreRecordDefinition recordDefinition) { if (this.options != null && this.options.getVectorStoreRecordCollectionFactory() != null) { return this.options.getVectorStoreRecordCollectionFactory() @@ -67,20 +70,20 @@ public JDBCVectorStoreRecordCollection getCollection( connection, collectionName, JDBCVectorStoreRecordCollectionOptions.builder() - .withRecordClass(recordClass) - .withRecordDefinition(recordDefinition) - .withQueryProvider(this.queryProvider) - .build()); - } - - return new JDBCVectorStoreRecordCollection<>( - connection, - collectionName, - JDBCVectorStoreRecordCollectionOptions.builder() .withRecordClass(recordClass) 
.withRecordDefinition(recordDefinition) .withQueryProvider(this.queryProvider) .build()); + } + + return new JDBCVectorStoreRecordCollection<>( + connection, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(recordClass) + .withRecordDefinition(recordDefinition) + .withQueryProvider(this.queryProvider) + .build()); } /** @@ -90,30 +93,17 @@ public JDBCVectorStoreRecordCollection getCollection( */ @Override public Mono> getCollectionNamesAsync() { - return Mono.fromCallable(() -> { - List collectionNames = new ArrayList<>(); - try { - ResultSet resultSet = queryProvider.getCollectionNames(); - while (resultSet.next()) { - collectionNames.add(resultSet.getString(1)); - } - - return collectionNames; - } catch (SQLException e) { - throw new SKException("Failed to get collection names.", e); - } - }); + return Mono.fromCallable(queryProvider::getCollectionNames) + .subscribeOn(Schedulers.boundedElastic()); } + /** + * Prepares the vector store. + */ @Override public Mono prepareAsync() { - return Mono.fromRunnable(() -> { - try { - queryProvider.prepareVectorStore(); - } catch (SQLException e) { - throw new SKException("Failed to prepare vector store.", e); - } - }); + return Mono.fromRunnable(queryProvider::prepareVectorStore) + .subscribeOn(Schedulers.boundedElastic()).then(); } /** @@ -129,6 +119,7 @@ public static class Builder { * @param connection the connection * @return the builder */ + @SuppressFBWarnings("EI_EXPOSE_REP2") public Builder withConnection(Connection connection) { this.connection = connection; return this; diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java index ceb2c544..9bf75afd 100644 --- 
a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java @@ -16,6 +16,7 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.time.OffsetDateTime; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -52,34 +53,20 @@ public class JDBCVectorStoreDefaultQueryProvider supportedVectorTypes.put(List.class, "TEXT"); supportedVectorTypes.put(Collection.class, "TEXT"); } - protected final Connection connection; protected final String collectionsTable; protected final String prefixForCollectionTables; @SuppressFBWarnings("EI_EXPOSE_REP2") - public JDBCVectorStoreDefaultQueryProvider( - @Nonnull Connection connection, - @Nonnull String collectionsTable, - @Nonnull String prefixForCollectionTables) { + protected JDBCVectorStoreDefaultQueryProvider( + @Nonnull Connection connection, + @Nonnull String collectionsTable, + @Nonnull String prefixForCollectionTables) { this.connection = connection; - // Validate table name - if (!isValidSQLIdentifier(collectionsTable)) { - throw new IllegalArgumentException("Invalid collections table name: " + collectionsTable); - } - if (!isValidSQLIdentifier(prefixForCollectionTables)) { - throw new IllegalArgumentException("Invalid prefix for collection tables: " + prefixForCollectionTables); - } - this.collectionsTable = collectionsTable; this.prefixForCollectionTables = prefixForCollectionTables; } - public JDBCVectorStoreDefaultQueryProvider( - @Nonnull Connection connection) { - this(connection, DEFAULT_COLLECTIONS_TABLE, DEFAULT_PREFIX_FOR_COLLECTION_TABLES); - } - /** * Creates a new builder. 
* @return the builder @@ -116,8 +103,8 @@ protected String getQueryColumnsFromFields(List fields) protected String getColumnNamesAndTypes(List fields, Map, String> types) { List columns = fields.stream() - .map(field -> field.getName() + " " + types.get(field.getType())) - .collect(Collectors.toList()); + .map(field -> field.getName() + " " + types.get(field.getType())) + .collect(Collectors.toList()); return String.join(", ", columns); } @@ -127,80 +114,126 @@ protected String getCollectionTableName(String collectionName) { } @Override - public void prepareVectorStore() throws SQLException { - String createCollectionsTable = - "CREATE TABLE IF NOT EXISTS " + collectionsTable - + " (collectionId VARCHAR(255) PRIMARY KEY);"; - - PreparedStatement createTable = connection.prepareStatement(createCollectionsTable); - createTable.execute(); + public void prepareVectorStore() { + String createCollectionsTable = "CREATE TABLE IF NOT EXISTS " + collectionsTable + + " (collectionId VARCHAR(255) PRIMARY KEY);"; + + try (PreparedStatement createTable = connection.prepareStatement(createCollectionsTable)) { + createTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to prepare vector store", e); + } } @Override - public void validateSupportedTypes(Class recordClass, VectorStoreRecordDefinition recordDefinition) { + public void validateSupportedTypes(Class recordClass, + VectorStoreRecordDefinition recordDefinition) { VectorStoreRecordDefinition.validateSupportedTypes( - Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)), supportedKeyTypes.keySet()); + Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)), + supportedKeyTypes.keySet()); VectorStoreRecordDefinition.validateSupportedTypes( - recordDefinition.getDataDeclaredFields(recordClass), supportedDataTypes.keySet()); + recordDefinition.getDataDeclaredFields(recordClass), supportedDataTypes.keySet()); 
VectorStoreRecordDefinition.validateSupportedTypes( - recordDefinition.getVectorDeclaredFields(recordClass), supportedVectorTypes.keySet()); + recordDefinition.getVectorDeclaredFields(recordClass), supportedVectorTypes.keySet()); } @Override - public boolean collectionExists(String collectionName) throws SQLException { + public boolean collectionExists(String collectionName) { + validateSQLidentifier(collectionsTable); + String query = "SELECT 1 FROM " + collectionsTable + " WHERE collectionId = ?"; - PreparedStatement statement = connection.prepareStatement(query); - statement.setObject(1, collectionName); + try (PreparedStatement statement = connection.prepareStatement(query)) { + statement.setObject(1, collectionName); - return statement.executeQuery().next(); + return statement.executeQuery().next(); + } catch (SQLException e) { + throw new SKException("Failed to check if collection exists", e); + } } @Override - public void createCollection(String collectionName, Class recordClass, VectorStoreRecordDefinition recordDefinition) throws SQLException { + public void createCollection(String collectionName, Class recordClass, + VectorStoreRecordDefinition recordDefinition) { + validateSQLidentifier(collectionName); + Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass); List dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass); List vectorDeclaredFields = recordDefinition.getVectorDeclaredFields(recordClass); - String createStorageTable = - "CREATE TABLE IF NOT EXISTS " + getCollectionTableName(collectionName) - + " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, " - + getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", " - + getColumnNamesAndTypes(vectorDeclaredFields, supportedVectorTypes) + ");"; + String createStorageTable = "CREATE TABLE IF NOT EXISTS " + + getCollectionTableName(collectionName) + + " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, " + + 
getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", " + + getColumnNamesAndTypes(vectorDeclaredFields, supportedVectorTypes) + ");"; - PreparedStatement createTable = connection.prepareStatement(createStorageTable); + String insertCollectionQuery = "INSERT INTO " + collectionsTable + + " (collectionId) VALUES (?)"; - String insertCollectionQuery = "INSERT INTO " + collectionsTable + " (collectionId) VALUES (?)"; - PreparedStatement insert = connection.prepareStatement(insertCollectionQuery); - insert.setObject(1, collectionName); + try (PreparedStatement createTable = connection.prepareStatement(createStorageTable)) { + createTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to create collection", e); + } - createTable.execute(); - insert.execute(); + try (PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) { + insert.setObject(1, collectionName); + insert.execute(); + } catch (SQLException e) { + throw new SKException("Failed to insert collection", e); + } } @Override - public void deleteCollection(String collectionName) throws SQLException { - String deleteCollectionOperation = "DELETE FROM " + collectionsTable + " WHERE collectionId = ?"; - String dropTableOperation = "DROP TABLE " + getCollectionTableName(collectionName); + public void deleteCollection(String collectionName) { + validateSQLidentifier(collectionsTable); + validateSQLidentifier(getCollectionTableName(collectionName)); - PreparedStatement deleteCollection = connection.prepareStatement(deleteCollectionOperation); - deleteCollection.setObject(1, collectionName); + String deleteCollectionOperation = "DELETE FROM " + collectionsTable + + " WHERE collectionId = ?"; + String dropTableOperation = "DROP TABLE " + getCollectionTableName(collectionName); - PreparedStatement dropTable = connection.prepareStatement(dropTableOperation); + try (PreparedStatement deleteCollection = connection + .prepareStatement(deleteCollectionOperation)) { 
+ deleteCollection.setObject(1, collectionName); + deleteCollection.execute(); + } catch (SQLException e) { + throw new SKException("Failed to delete collection", e); + } - dropTable.execute(); - deleteCollection.execute(); + try (PreparedStatement dropTable = connection.prepareStatement(dropTableOperation)) { + dropTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to drop table", e); + } } @Override - public ResultSet getCollectionNames() throws SQLException { + public List getCollectionNames() { + validateSQLidentifier(collectionsTable); + String query = "SELECT collectionId FROM " + collectionsTable; - return connection.prepareStatement(query).executeQuery(); + try (PreparedStatement statement = connection.prepareStatement(query)) { + List collectionNames = new ArrayList<>(); + ResultSet resultSet = statement.executeQuery(); + + while (resultSet.next()) { + collectionNames.add(resultSet.getString(1)); + } + + return Collections.unmodifiableList(collectionNames); + } catch (SQLException e) { + throw new SKException("Failed to get collection names", e); + } } @Override - public ResultSet getRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, GetRecordOptions options) throws SQLException { + public List getRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, + GetRecordOptions options) { + validateSQLidentifier(getCollectionTableName(collectionName)); + List fields; if (options == null || options.includeVectors()) { fields = recordDefinition.getAllFields(); @@ -209,48 +242,59 @@ public ResultSet getRecords(String collectionName, List keys, VectorStor } String query = "SELECT " + getQueryColumnsFromFields(fields) - + " FROM " + getCollectionTableName(collectionName) - + " WHERE " + recordDefinition.getKeyField().getName() - + " IN (" + getWildcardString(keys.size()) + ")"; + + " FROM " + getCollectionTableName(collectionName) + + " 
WHERE " + recordDefinition.getKeyField().getName() + + " IN (" + getWildcardString(keys.size()) + ")"; - PreparedStatement statement = connection.prepareStatement(query); - for (int i = 0; i < keys.size(); ++i) { - try { + try (PreparedStatement statement = connection.prepareStatement(query)) { + for (int i = 0; i < keys.size(); ++i) { statement.setObject(i + 1, keys.get(i)); - } catch (SQLException e) { - throw new SKException("Failed to set statement values", e); } - } - return statement.executeQuery(); + List records = new ArrayList<>(); + ResultSet resultSet = statement.executeQuery(); + + while (resultSet.next()) { + records.add(mapper.mapStorageModeltoRecord(resultSet)); + } + + return Collections.unmodifiableList(records); + } catch (SQLException e) { + throw new SKException("Failed to set statement values", e); + } } @Override - public void upsertRecords(String collectionName, List records, VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) throws SQLException { + public void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) { throw new UnsupportedOperationException( - "Upsert is not supported. Try with a specific query provider."); + "Upsert is not supported. 
Try with a specific query provider."); } @Override - public void deleteRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) throws SQLException { + public void deleteRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) { + validateSQLidentifier(getCollectionTableName(collectionName)); + String query = "DELETE FROM " + getCollectionTableName(collectionName) - + " WHERE " + recordDefinition.getKeyField().getName() - + " IN (" + getWildcardString(keys.size()) + ")"; + + " WHERE " + recordDefinition.getKeyField().getName() + + " IN (" + getWildcardString(keys.size()) + ")"; - PreparedStatement statement = connection.prepareStatement(query); - for (int i = 0; i < keys.size(); ++i) { - try { + try (PreparedStatement statement = connection.prepareStatement(query)) { + for (int i = 0; i < keys.size(); ++i) { statement.setObject(i + 1, keys.get(i)); - } catch (SQLException e) { - throw new SKException("Failed to set statement values", e); } - } - statement.execute(); + statement.execute(); + } catch (SQLException e) { + throw new SKException("Failed to set statement values", e); + } } - public static boolean isValidSQLIdentifier(String identifier) { - return identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*"); + public static void validateSQLidentifier(String identifier) { + if (!identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*")) { + throw new IllegalArgumentException("Invalid SQL identifier: " + identifier); + } } /** @@ -267,6 +311,7 @@ public static class Builder * @param connection the connection * @return the builder */ + @SuppressFBWarnings("EI_EXPOSE_REP2") public Builder withConnection(Connection connection) { this.connection = connection; return this; @@ -278,6 +323,7 @@ public Builder withConnection(Connection connection) { * @return the builder */ public Builder withCollectionsTable(String collectionsTable) { + 
validateSQLidentifier(collectionsTable); this.collectionsTable = collectionsTable; return this; } @@ -288,6 +334,7 @@ public Builder withCollectionsTable(String collectionsTable) { * @return the builder */ public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + validateSQLidentifier(prefixForCollectionTables); this.prefixForCollectionTables = prefixForCollectionTables; return this; } @@ -298,7 +345,8 @@ public JDBCVectorStoreDefaultQueryProvider build() { throw new IllegalArgumentException("connection is required"); } - return new JDBCVectorStoreDefaultQueryProvider(connection, collectionsTable, prefixForCollectionTables); + return new JDBCVectorStoreDefaultQueryProvider(connection, collectionsTable, + prefixForCollectionTables); } } } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java index 6ecb59ea..580868e9 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. 
package com.microsoft.semantickernel.connectors.data.jdbc; import javax.annotation.Nullable; diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java index 6104de8b..fd75dac7 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java @@ -7,8 +7,6 @@ import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; -import java.sql.ResultSet; -import java.sql.SQLException; import java.util.List; /** @@ -29,10 +27,8 @@ public interface JDBCVectorStoreQueryProvider { /** * Prepares the vector store. * Executes any necessary setup steps for the vector store. - * - * @throws SQLException if an error occurs */ - void prepareVectorStore() throws SQLException; + void prepareVectorStore(); /** * Checks if the types of the record class fields are supported. @@ -47,9 +43,8 @@ public interface JDBCVectorStoreQueryProvider { * * @param collectionName the collection name * @return true if the collection exists, false otherwise - * @throws SQLException if an error occurs */ - boolean collectionExists(String collectionName) throws SQLException; + boolean collectionExists(String collectionName); /** * Creates a collection. 
@@ -57,37 +52,37 @@ public interface JDBCVectorStoreQueryProvider { * @param collectionName the collection name * @param recordClass the record class * @param recordDefinition the record definition - * @throws SQLException if an error occurs */ - void createCollection(String collectionName, Class recordClass, VectorStoreRecordDefinition recordDefinition) throws SQLException; + void createCollection(String collectionName, Class recordClass, + VectorStoreRecordDefinition recordDefinition); /** * Deletes a collection. * * @param collectionName the collection name - * @throws SQLException if an error occurs */ - void deleteCollection(String collectionName) throws SQLException; + void deleteCollection(String collectionName); /** - * Gets the names of the collections. + * Gets the collection names. * - * @return the result set - * @throws SQLException if an error occurs + * @return the collection names */ - ResultSet getCollectionNames() throws SQLException; + List getCollectionNames(); /** - * Gets the records. + * Gets records. * * @param collectionName the collection name * @param keys the keys * @param recordDefinition the record definition + * @param mapper the mapper * @param options the options - * @return the result set - * @throws SQLException if an error occurs + * @return the records */ - ResultSet getRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, GetRecordOptions options) throws SQLException; + List getRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, + GetRecordOptions options); /** * Upserts records. 
@@ -96,9 +91,9 @@ public interface JDBCVectorStoreQueryProvider { * @param records the records * @param vectorStoreRecordDefinition the record definition * @param options the options - * @throws SQLException if an error occurs */ - void upsertRecords(String collectionName, List records, VectorStoreRecordDefinition vectorStoreRecordDefinition, UpsertRecordOptions options) throws SQLException; + void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition vectorStoreRecordDefinition, UpsertRecordOptions options); /** * Deletes records. @@ -107,9 +102,9 @@ public interface JDBCVectorStoreQueryProvider { * @param keys the keys * @param recordDefinition the record definition * @param options the options - * @throws SQLException if an error occurs */ - void deleteRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) throws SQLException; + void deleteRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options); /** * The builder for the JDBC vector store query provider. 
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java index 54f0b76d..0096794c 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java @@ -14,9 +14,6 @@ import javax.annotation.Nonnull; import java.lang.reflect.Field; import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -38,30 +35,32 @@ public class JDBCVectorStoreRecordCollection */ @SuppressFBWarnings("EI_EXPOSE_REP2") public JDBCVectorStoreRecordCollection( - @Nonnull Connection connection, - @Nonnull String collectionName, - @Nonnull JDBCVectorStoreRecordCollectionOptions options) { + @Nonnull Connection connection, + @Nonnull String collectionName, + @Nonnull JDBCVectorStoreRecordCollectionOptions options) { this.collectionName = collectionName; this.options = options; // If record definition is not provided, create one from the record class recordDefinition = options.getRecordDefinition() == null - ? VectorStoreRecordDefinition.fromRecordClass(options.getRecordClass()) - : options.getRecordDefinition(); + ? 
VectorStoreRecordDefinition.fromRecordClass(options.getRecordClass()) + : options.getRecordDefinition(); // If mapper is not provided, set a default one if (options.getVectorStoreRecordMapper() == null) { vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder() - .withRecordClass(options.getRecordClass()) - .withVectorStoreRecordDefinition(recordDefinition) - .build(); + .withRecordClass(options.getRecordClass()) + .withVectorStoreRecordDefinition(recordDefinition) + .build(); } else { vectorStoreRecordMapper = options.getVectorStoreRecordMapper(); } - + // If the query provider is not provided, set a default one if (options.getQueryProvider() == null) { - this.queryProvider = new JDBCVectorStoreDefaultQueryProvider(connection); + this.queryProvider = JDBCVectorStoreDefaultQueryProvider.builder() + .withConnection(connection) + .build(); } else { this.queryProvider = options.getQueryProvider(); } @@ -88,14 +87,8 @@ public String getCollectionName() { @Override public Mono collectionExistsAsync() { return Mono.fromCallable( - () -> { - try { - return queryProvider.collectionExists(this.collectionName); - } catch (SQLException e) { - throw new SKException("Failed to check if collection exists", e); - } - }) - .subscribeOn(Schedulers.boundedElastic()); + () -> queryProvider.collectionExists(this.collectionName)) + .subscribeOn(Schedulers.boundedElastic()); } /** @@ -106,15 +99,10 @@ public Mono collectionExistsAsync() { @Override public Mono createCollectionAsync() { return Mono.fromRunnable( - () -> { - try { - queryProvider.createCollection(this.collectionName, options.getRecordClass(), recordDefinition); - } catch (SQLException e) { - throw new SKException("Failed to create collection", e); - } - }) - .subscribeOn(Schedulers.boundedElastic()) - .then(); + () -> queryProvider.createCollection(this.collectionName, options.getRecordClass(), + recordDefinition)) + .subscribeOn(Schedulers.boundedElastic()) + .then(); } /** @@ -144,11 +132,7 @@ public Mono 
createCollectionIfNotExistsAsync() { public Mono deleteCollectionAsync() { return Mono.fromRunnable( () -> { - try { - queryProvider.deleteCollection(this.collectionName); - } catch (SQLException e) { - throw new SKException("Failed to delete collection", e); - } + queryProvider.deleteCollection(this.collectionName); }).subscribeOn(Schedulers.boundedElastic()).then(); } @@ -181,24 +165,15 @@ public Mono getAsync(String key, GetRecordOptions options) { public Mono> getBatchAsync(List keys, GetRecordOptions options) { return Mono.fromCallable( () -> { - List records = new ArrayList<>(); - - try { - ResultSet resultSet = queryProvider.getRecords(this.collectionName, keys, recordDefinition, options); - while (resultSet.next()) { - records.add(vectorStoreRecordMapper.mapStorageModeltoRecord(resultSet)); - } - } catch (SQLException e) { - throw new SKException("Failed to get records", e); - } - - return records; + return queryProvider.getRecords(this.collectionName, keys, recordDefinition, + vectorStoreRecordMapper, options); }).subscribeOn(Schedulers.boundedElastic()); } protected String getKeyFromRecord(Record data) { try { - Field keyField = data.getClass().getDeclaredField(recordDefinition.getKeyField().getName()); + Field keyField = data.getClass() + .getDeclaredField(recordDefinition.getKeyField().getName()); keyField.setAccessible(true); return (String) keyField.get(data); } catch (NoSuchFieldException | IllegalAccessException e) { @@ -235,13 +210,8 @@ public Mono upsertAsync(Record data, UpsertRecordOptions options) { public Mono> upsertBatchAsync(List data, UpsertRecordOptions options) { return Mono.fromCallable( () -> { - try { - queryProvider.upsertRecords(this.collectionName, data, recordDefinition, options); - - return data.stream().map(this::getKeyFromRecord).collect(Collectors.toList()); - } catch (SQLException e) { - throw new SKException("Failed to upsert records", e); - } + queryProvider.upsertRecords(this.collectionName, data, recordDefinition, 
options); + return data.stream().map(this::getKeyFromRecord).collect(Collectors.toList()); }) .subscribeOn(Schedulers.boundedElastic()); } @@ -269,11 +239,7 @@ public Mono deleteAsync(String key, DeleteRecordOptions options) { public Mono deleteBatchAsync(List keys, DeleteRecordOptions options) { return Mono.fromRunnable( () -> { - try { - queryProvider.deleteRecords(this.collectionName, keys, recordDefinition, options); - } catch (SQLException e) { - throw new SKException("Failed to delete records", e); - } + queryProvider.deleteRecords(this.collectionName, keys, recordDefinition, options); }).subscribeOn(Schedulers.boundedElastic()).then(); } @@ -284,20 +250,17 @@ public Mono deleteBatchAsync(List keys, DeleteRecordOptions option */ @Override public Mono prepareAsync() { - return Mono.fromRunnable(() -> { - try { - queryProvider.prepareVectorStore(); - } catch (SQLException e) { - throw new SKException("Failed to prepare vector store record collection", e); - } - }).subscribeOn(Schedulers.boundedElastic()).then(); + return Mono.fromRunnable(queryProvider::prepareVectorStore) + .subscribeOn(Schedulers.boundedElastic()).then(); } - public static class Builder implements SemanticKernelBuilder> { + public static class Builder + implements SemanticKernelBuilder> { private Connection connection; private String collectionName; private JDBCVectorStoreRecordCollectionOptions options; + @SuppressFBWarnings("EI_EXPOSE_REP2") public Builder withConnection(Connection connection) { this.connection = connection; return this; diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java index 4d92eb25..8461cd08 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java 
+++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java @@ -1,3 +1,4 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.jdbc; import java.sql.Connection; diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java index 6fe6dbaf..6b860b39 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java @@ -4,20 +4,37 @@ import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; public class JDBCVectorStoreRecordCollectionOptions { + + /** + * The default name for the collections table. + */ + public static final String DEFAULT_COLLECTIONS_TABLE = "SKCollections"; + + /** + * The prefix for collection tables. 
+ */ + public static final String DEFAULT_PREFIX_FOR_COLLECTION_TABLES = "SKCollection_"; + private final Class recordClass; private final JDBCVectorStoreRecordMapper vectorStoreRecordMapper; private final VectorStoreRecordDefinition recordDefinition; private final JDBCVectorStoreQueryProvider queryProvider; + private final String collectionsTableName; + private final String prefixForCollectionTables; - public JDBCVectorStoreRecordCollectionOptions( + private JDBCVectorStoreRecordCollectionOptions( Class recordClass, VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper vectorStoreRecordMapper, - JDBCVectorStoreQueryProvider queryProvider) { + JDBCVectorStoreQueryProvider queryProvider, + String collectionsTableName, + String prefixForCollectionTables) { this.recordClass = recordClass; this.recordDefinition = recordDefinition; this.vectorStoreRecordMapper = vectorStoreRecordMapper; this.queryProvider = queryProvider; + this.collectionsTableName = collectionsTableName; + this.prefixForCollectionTables = prefixForCollectionTables; } /** @@ -53,6 +70,22 @@ public JDBCVectorStoreRecordMapper getVectorStoreRecordMapper() { return vectorStoreRecordMapper; } + /** + * Gets the collections table. + * @return the collections table + */ + public String getCollectionsTableName() { + return collectionsTableName; + } + + /** + * Gets the prefix for collection tables. + * @return the prefix for collection tables + */ + public String getPrefixForCollectionTables() { + return prefixForCollectionTables; + } + /** * Gets the query provider. 
* @return the query provider @@ -61,11 +94,19 @@ public JDBCVectorStoreQueryProvider getQueryProvider() { return queryProvider; } + public static void validateSQLidentifier(String identifier) { + if (!identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*")) { + throw new IllegalArgumentException("Invalid SQL identifier: " + identifier); + } + } + public static class Builder { private Class recordClass; private VectorStoreRecordDefinition recordDefinition; private JDBCVectorStoreRecordMapper vectorStoreRecordMapper; private JDBCVectorStoreQueryProvider queryProvider; + private String collectionsTableName = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; /** * Sets the record class. @@ -108,6 +149,28 @@ public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvi return this; } + /** + * Sets the collections table name. + * @param collectionsTableName the collections table name + * @return the builder + */ + public Builder withCollectionsTableName(String collectionsTableName) { + validateSQLidentifier(collectionsTableName); + this.collectionsTableName = collectionsTableName; + return this; + } + + /** + * Sets the prefix for collection tables. + * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + validateSQLidentifier(prefixForCollectionTables); + this.prefixForCollectionTables = prefixForCollectionTables; + return this; + } + /** * Builds the options. 
* @return the options @@ -121,8 +184,9 @@ public JDBCVectorStoreRecordCollectionOptions build() { recordClass, recordDefinition, vectorStoreRecordMapper, - queryProvider - ); + queryProvider, + collectionsTableName, + prefixForCollectionTables); } } } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java index f7211f24..1110cb95 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java @@ -17,9 +17,10 @@ import java.util.List; public class MySQLVectorStoreQueryProvider extends - JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { + JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { - public MySQLVectorStoreQueryProvider(Connection connection, String collectionsTable, String prefixForCollectionTables) { + public MySQLVectorStoreQueryProvider(Connection connection, String collectionsTable, + String prefixForCollectionTables) { super(connection, collectionsTable, prefixForCollectionTables); } @@ -31,7 +32,8 @@ public static Builder builder() { return new Builder(); } - private void setStatementValues(PreparedStatement statement, Object record, List fields) { + private void setStatementValues(PreparedStatement statement, Object record, + List fields) { for (int i = 0; i < fields.size(); ++i) { VectorStoreRecordField field = fields.get(i); try { @@ -42,7 +44,8 @@ private void setStatementValues(PreparedStatement statement, Object record, List if (field instanceof VectorStoreRecordKeyField) { statement.setObject(i + 1, (String) value); } else if (field instanceof VectorStoreRecordVectorField) { - 
Class vectorType = record.getClass().getDeclaredField(field.getName()).getType(); + Class vectorType = record.getClass().getDeclaredField(field.getName()) + .getType(); // If the vector field is other than String, serialize it to JSON if (vectorType.equals(String.class)) { @@ -63,7 +66,10 @@ private void setStatementValues(PreparedStatement statement, Object record, List } @Override - public void upsertRecords(String collectionName, List records, VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) throws SQLException { + public void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) { + validateSQLidentifier(getCollectionTableName(collectionName)); + List fields = recordDefinition.getAllFields(); StringBuilder onDuplicateKeyUpdate = new StringBuilder(); @@ -73,22 +79,25 @@ public void upsertRecords(String collectionName, List records, VectorStoreRec onDuplicateKeyUpdate.append(", "); } - onDuplicateKeyUpdate.append(field.getName()).append(" = VALUES(").append(field.getName()).append(")"); + onDuplicateKeyUpdate.append(field.getName()).append(" = VALUES(") + .append(field.getName()).append(")"); } String query = "INSERT INTO " + getCollectionTableName(collectionName) - + " (" + getQueryColumnsFromFields(fields) + ")" - + " VALUES (" + getWildcardString(fields.size()) + ")" - + " ON DUPLICATE KEY UPDATE " + onDuplicateKeyUpdate; - - PreparedStatement statement = connection.prepareStatement(query); + + " (" + getQueryColumnsFromFields(fields) + ")" + + " VALUES (" + getWildcardString(fields.size()) + ")" + + " ON DUPLICATE KEY UPDATE " + onDuplicateKeyUpdate; + + try (PreparedStatement statement = connection.prepareStatement(query)) { + for (Object record : records) { + setStatementValues(statement, record, recordDefinition.getAllFields()); + statement.addBatch(); + } - for (Object record : records) { - setStatementValues(statement, record, 
recordDefinition.getAllFields()); - statement.addBatch(); + statement.executeBatch(); + } catch (SQLException e) { + throw new SKException("Failed to upsert records", e); } - - statement.executeBatch(); } public static class Builder @@ -98,7 +107,8 @@ public MySQLVectorStoreQueryProvider build() { throw new IllegalArgumentException("connection is required"); } - return new MySQLVectorStoreQueryProvider(connection, collectionsTable, prefixForCollectionTables); + return new MySQLVectorStoreQueryProvider(connection, collectionsTable, + prefixForCollectionTables); } } } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java index dd19f787..10e4d2ef 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java @@ -1,10 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.jdbc; import com.microsoft.semantickernel.data.VectorStore; import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import reactor.core.publisher.Mono; -public interface SQLVectorStore> extends VectorStore { +public interface SQLVectorStore> + extends VectorStore { /** * Prepares the vector store. 
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java index 3583a273..ff12c88b 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java @@ -1,9 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.jdbc; import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import reactor.core.publisher.Mono; -public interface SQLVectorStoreRecordCollection extends VectorStoreRecordCollection { +public interface SQLVectorStoreRecordCollection + extends VectorStoreRecordCollection { /** * Prepares the vector store record collection. 
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java index e8d9e1db..52d30bf8 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java @@ -82,11 +82,12 @@ public RedisVectorStoreRecordCollection( // Validate supported types VectorStoreRecordDefinition.validateSupportedTypes( - Collections.singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), - supportedKeyTypes); + Collections + .singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), + supportedKeyTypes); VectorStoreRecordDefinition.validateSupportedTypes( - recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), - supportedVectorTypes); + recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), + supportedVectorTypes); // If mapper is not provided, set a default one if (options.getVectorStoreRecordMapper() == null) { diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java index 8c1270ba..7a433dbb 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java @@ -44,7 +44,8 @@ public VolatileVectorStoreRecordCollection(String collectionName, // Validate the key type 
VectorStoreRecordDefinition.validateSupportedTypes( - Collections.singletonList(recordDefinition.getKeyDeclaredField(options.getRecordClass())), + Collections + .singletonList(recordDefinition.getKeyDeclaredField(options.getRecordClass())), supportedKeyTypes); } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java index e36a258b..dff27c18 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java @@ -56,7 +56,8 @@ public List getNonVectorFields() { return fields; } - private List getDeclaredFields(Class recordClass, List fields, String fieldType) { + private List getDeclaredFields(Class recordClass, List fields, + String fieldType) { List declaredFields = new ArrayList<>(); for (VectorStoreRecordField field : fields) { try { @@ -64,7 +65,8 @@ private List getDeclaredFields(Class recordClass, List recordClass) { return recordClass.getDeclaredField(keyField.getName()); } catch (NoSuchFieldException e) { throw new IllegalArgumentException( - "Key field not found in record class: " + keyField.getName()); + "Key field not found in record class: " + keyField.getName()); } } public List getDataDeclaredFields(Class recordClass) { return getDeclaredFields( - recordClass, - dataFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), - "Data"); + recordClass, + dataFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), + "Data"); } public List getVectorDeclaredFields(Class recordClass) { return getDeclaredFields( - recordClass, - vectorFields.stream().map(f -> (VectorStoreRecordField) 
f).collect(Collectors.toList()), - "Vector"); + recordClass, + vectorFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), + "Vector"); } private VectorStoreRecordDefinition( @@ -191,8 +193,8 @@ public static VectorStoreRecordDefinition fromRecordClass(Class recordClass) return checkFields(keyFields, dataFields, vectorFields); } - - public static void validateSupportedTypes(List declaredFields, Set> supportedTypes) { + public static void validateSupportedTypes(List declaredFields, + Set> supportedTypes) { Set> unsupportedTypes = new HashSet<>(); for (Field declaredField : declaredFields) { if (!supportedTypes.contains(declaredField.getType())) { @@ -201,9 +203,10 @@ public static void validateSupportedTypes(List declaredFields, Set Date: Fri, 2 Aug 2024 01:52:17 -0700 Subject: [PATCH 4/6] Add sample and change Connection for DataSource --- .../audio/OpenAiAudioToTextService.java | 3 +- .../audio/OpenAiTextToAudioService.java | 3 +- .../chatcompletion/OpenAIChatCompletion.java | 6 +- .../OpenAITextGenerationService.java | 3 +- api-test/integration-tests/pom.xml | 6 +- .../JDBCVectorStoreRecordCollectionTest.java | 15 +- .../memory/jdbc/JDBCVectorStoreTest.java | 24 ++- .../semantickernel-syntax-examples/pom.xml | 6 + .../memory/JDBC_DataStorage.java | 188 ++++++++++++++++++ .../connectors/data/jdbc/JDBCVectorStore.java | 36 ++-- .../JDBCVectorStoreDefaultQueryProvider.java | 109 +++++----- .../data/jdbc/JDBCVectorStoreOptions.java | 5 + .../jdbc/JDBCVectorStoreQueryProvider.java | 4 +- .../jdbc/JDBCVectorStoreRecordCollection.java | 48 +++-- ...DBCVectorStoreRecordCollectionFactory.java | 3 +- ...DBCVectorStoreRecordCollectionOptions.java | 30 +-- .../jdbc/MySQLVectorStoreQueryProvider.java | 50 ++++- 17 files changed, 399 insertions(+), 140 deletions(-) create mode 100644 samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java diff 
--git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java index 7becd67e..631f2cac 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java @@ -18,7 +18,8 @@ /** * Provides OpenAi implementation of audio to text service. */ -public class OpenAiAudioToTextService extends OpenAiService implements AudioToTextService { +public class OpenAiAudioToTextService extends OpenAiService + implements AudioToTextService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiAudioToTextService.class); diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java index 25071ca9..c698fab3 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java @@ -17,7 +17,8 @@ /** * Provides OpenAi implementation of text to audio service. 
*/ -public class OpenAiTextToAudioService extends OpenAiService implements TextToAudioService { +public class OpenAiTextToAudioService extends OpenAiService + implements TextToAudioService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiTextToAudioService.class); diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java index 83b6c9a7..84a8287e 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java @@ -79,7 +79,8 @@ /** * OpenAI chat completion service. */ -public class OpenAIChatCompletion extends OpenAiService implements ChatCompletionService { +public class OpenAIChatCompletion extends OpenAiService + implements ChatCompletionService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAIChatCompletion.class); @@ -1055,7 +1056,8 @@ static ChatRequestMessage getChatRequestMessage( /** * Builder for creating a new instance of {@link OpenAIChatCompletion}. 
*/ - public static class Builder extends OpenAiServiceBuilder { + public static class Builder + extends OpenAiServiceBuilder { @Override public OpenAIChatCompletion build() { diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java index 5c418649..13783229 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java @@ -30,7 +30,8 @@ /** * An OpenAI implementation of a {@link TextGenerationService}. */ -public class OpenAITextGenerationService extends OpenAiService implements TextGenerationService { +public class OpenAITextGenerationService extends OpenAiService + implements TextGenerationService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAITextGenerationService.class); diff --git a/api-test/integration-tests/pom.xml b/api-test/integration-tests/pom.xml index 126b3741..7e669927 100644 --- a/api-test/integration-tests/pom.xml +++ b/api-test/integration-tests/pom.xml @@ -68,9 +68,9 @@ 3.44.1.0 - com.mysql - mysql-connector-j - 8.2.0 + mysql + mysql-connector-java + 8.0.33 test diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java index efbe9638..6e80e4ac 100644 --- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java +++ 
b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java @@ -5,6 +5,7 @@ import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; import com.microsoft.semantickernel.tests.connectors.memory.Hotel; +import com.mysql.cj.jdbc.MysqlDataSource; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.testcontainers.containers.MySQLContainer; @@ -12,6 +13,7 @@ import org.testcontainers.junit.jupiter.Testcontainers; import javax.annotation.Nonnull; +import javax.sql.DataSource; import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; @@ -29,20 +31,23 @@ public class JDBCVectorStoreRecordCollectionTest { private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:5.7.34"); private static final String MYSQL_USER = "test"; private static final String MYSQL_PASSWORD = "test"; - private static Connection connection; + private static MysqlDataSource dataSource; @BeforeAll - static void setup() throws SQLException { - connection = DriverManager.getConnection(CONTAINER.getJdbcUrl(), MYSQL_USER, MYSQL_PASSWORD); + static void setup() { + dataSource = new MysqlDataSource(); + dataSource.setUrl(CONTAINER.getJdbcUrl()); + dataSource.setUser(MYSQL_USER); + dataSource.setPassword(MYSQL_PASSWORD); } private JDBCVectorStoreRecordCollection buildRecordCollection(@Nonnull String collectionName) { JDBCVectorStoreRecordCollection recordCollection = new JDBCVectorStoreRecordCollection<>( - connection, + dataSource, collectionName, JDBCVectorStoreRecordCollectionOptions.builder() .withRecordClass(Hotel.class) .withQueryProvider(MySQLVectorStoreQueryProvider.builder() - .withConnection(connection) + .withDataSource(dataSource) .build()) .build()); diff --git 
a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java index 0ebeed42..eb134dd0 100644 --- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java @@ -4,6 +4,7 @@ import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions; import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; import com.microsoft.semantickernel.tests.connectors.memory.Hotel; +import com.mysql.cj.jdbc.MysqlDataSource; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.testcontainers.containers.MySQLContainer; @@ -26,24 +27,29 @@ public class JDBCVectorStoreTest { private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:5.7.34"); private static final String MYSQL_USER = "test"; private static final String MYSQL_PASSWORD = "test"; - private static Connection connection; + private static MysqlDataSource dataSource; @BeforeAll - static void setup() throws SQLException { - connection = DriverManager.getConnection(CONTAINER.getJdbcUrl(), MYSQL_USER, MYSQL_PASSWORD); + static void setup() { + dataSource = new MysqlDataSource(); + dataSource.setUrl(CONTAINER.getJdbcUrl()); + dataSource.setUser(MYSQL_USER); + dataSource.setPassword(MYSQL_PASSWORD); } @Test public void getCollectionNamesAsync() { - JDBCVectorStoreOptions options = JDBCVectorStoreOptions.builder() - .withQueryProvider(MySQLVectorStoreQueryProvider.builder() - .withConnection(connection) - .build()) + MySQLVectorStoreQueryProvider queryProvider = MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) .build(); JDBCVectorStore 
vectorStore = JDBCVectorStore.builder() - .withConnection(connection) - .withOptions(options) + .withDataSource(dataSource) + .withOptions( + JDBCVectorStoreOptions.builder() + .withQueryProvider(queryProvider) + .build() + ) .build(); vectorStore.getCollectionNamesAsync().block(); diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml b/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml index bea43f73..dcf92999 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml @@ -88,6 +88,12 @@ 1.1.0 compile + + + mysql + mysql-connector-java + 8.0.33 + diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java new file mode 100644 index 00000000..2379e572 --- /dev/null +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java @@ -0,0 +1,188 @@ +package com.microsoft.semantickernel.samples.syntaxexamples.memory; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.credential.KeyCredential; +import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions; +import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import 
com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute; +import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; +import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute; +import com.mysql.cj.jdbc.MysqlDataSource; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import javax.sql.DataSource; +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class JDBC_DataStorage { + + private static final String CLIENT_KEY = System.getenv("CLIENT_KEY"); + private static final String AZURE_CLIENT_KEY = System.getenv("AZURE_CLIENT_KEY"); + + // Only required if AZURE_CLIENT_KEY is set + private static final String CLIENT_ENDPOINT = System.getenv("CLIENT_ENDPOINT"); + private static final String MODEL_ID = System.getenv() + .getOrDefault("EMBEDDING_MODEL_ID", "text-embedding-3-large"); + private static final int EMBEDDING_DIMENSIONS = 1536; + + // Run a MySQL server with: + // docker run -d --name mysql-container -e MYSQL_ROOT_PASSWORD=root -e MYSQL_DATABASE=sk -p 3306:3306 mysql:latest + + static class GitHubFile { + @VectorStoreRecordKeyAttribute() + private final String id; + @VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding") + private final String description; + @VectorStoreRecordDataAttribute + private final String link; + @VectorStoreRecordVectorAttribute(dimensions = EMBEDDING_DIMENSIONS, indexKind = "Hnsw") + private final List embedding; + + public GitHubFile() { + this(null, null, null, Collections.emptyList()); + } + + public GitHubFile( + String id, + String description, + String link, + List embedding) { + this.id = id; + this.description = description; + 
this.link = link; + this.embedding = embedding; + } + + public String getId() { + return id; + } + + public String getDescription() { + return description; + } + + static String encodeId(String realId) { + byte[] bytes = Base64.getUrlEncoder().encode(realId.getBytes(StandardCharsets.UTF_8)); + return new String(bytes, StandardCharsets.UTF_8); + } + } + + public static void main(String[] args) throws SQLException { + System.out.println("=============================================================="); + System.out.println("========== JDBC Vector Store Example =============="); + System.out.println("=============================================================="); + + OpenAIAsyncClient client; + + if (AZURE_CLIENT_KEY != null) { + client = new OpenAIClientBuilder() + .credential(new AzureKeyCredential(AZURE_CLIENT_KEY)) + .endpoint(CLIENT_ENDPOINT) + .buildAsyncClient(); + + } else { + client = new OpenAIClientBuilder() + .credential(new KeyCredential(CLIENT_KEY)) + .buildAsyncClient(); + } + + var embeddingGeneration = OpenAITextEmbeddingGenerationService.builder() + .withOpenAIAsyncClient(client) + .withModelId(MODEL_ID) + .withDimensions(EMBEDDING_DIMENSIONS) + .build(); + + var dataSource = new MysqlDataSource(); + dataSource.setUrl("jdbc:mysql://localhost:3306/sk"); + dataSource.setPassword("root"); + dataSource.setUser("root"); + + dataStorageWithMySQL(dataSource, embeddingGeneration); + } + + public static void dataStorageWithMySQL( + DataSource dataSource, + OpenAITextEmbeddingGenerationService embeddingGeneration) { + + // Build a query provider + var queryProvider = MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + + // Create a new vector store + var jdbcVectorStore = JDBCVectorStore.builder() + .withDataSource(dataSource) + .withOptions(JDBCVectorStoreOptions.builder() + .withQueryProvider(queryProvider) + .build()) + .build(); + + String collectionName = "skgithubfiles"; + var collection = 
jdbcVectorStore.getCollection(collectionName, GitHubFile.class, + null); + + // Create collection if it does not exist and store data + List ids = collection + .createCollectionIfNotExistsAsync() + .then(storeData(collection, embeddingGeneration, sampleData())) + .block(); + + List data = collection.getBatchAsync(ids, null).block(); + + data.forEach(gitHubFile -> System.out.println("Retrieved: " + gitHubFile.getDescription())); + } + + private static Mono> storeData( + VectorStoreRecordCollection recordStore, + OpenAITextEmbeddingGenerationService embeddingGeneration, + Map data) { + + return Flux.fromIterable(data.entrySet()) + .flatMap(entry -> { + System.out.println("Save '" + entry.getKey() + "' to memory."); + + return embeddingGeneration + .generateEmbeddingsAsync(Collections.singletonList(entry.getValue())) + .flatMap(embeddings -> { + GitHubFile gitHubFile = new GitHubFile( + GitHubFile.encodeId(entry.getKey()), + entry.getValue(), + entry.getKey(), + embeddings.get(0).getVector()); + return recordStore.upsertAsync(gitHubFile, null); + }); + }) + .collectList(); + } + + private static Map sampleData() { + return Arrays.stream(new String[][] { + { "https://github.com/microsoft/semantic-kernel/blob/main/README.md", + "README: Installation, getting started with Semantic Kernel, and how to contribute" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/samples/notebooks/dotnet/02-running-prompts-from-file.ipynb", + "Jupyter notebook describing how to pass prompts from a file to a semantic skill or function" }, + { "https://github.com/microsoft/semantic-kernel/tree/main/samples/skills/ChatSkill/ChatGPT", + "Sample demonstrating how to create a chat skill interfacing with ChatGPT" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/SemanticKernel/Memory/VolatileMemoryStore.cs", + "C# class that defines a volatile embedding store" }, + { 
"https://github.com/microsoft/semantic-kernel/blob/main/samples/dotnet/KernelHttpServer/README.md", + "README: How to set up a Semantic Kernel Service API using Azure Function Runtime v4" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/samples/apps/chat-summary-webapp-react/README.md", + "README: README associated with a sample chat summary react-based webapp" }, + }).collect(Collectors.toMap(element -> element[0], element -> element[1])); + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java index 7e02f4a8..5e497176 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java @@ -8,14 +8,14 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; -import java.sql.Connection; +import javax.sql.DataSource; import java.util.List; /** * A JDBC vector store. 
*/ public class JDBCVectorStore implements SQLVectorStore> { - private final Connection connection; + private final DataSource dataSource; private final JDBCVectorStoreOptions options; private final JDBCVectorStoreQueryProvider queryProvider; @@ -23,20 +23,20 @@ public class JDBCVectorStore implements SQLVectorStore JDBCVectorStoreRecordCollection getCollection( + public JDBCVectorStoreRecordCollection getCollection( @Nonnull String collectionName, @Nonnull Class recordClass, @Nullable VectorStoreRecordDefinition recordDefinition) { @@ -67,7 +67,7 @@ public JDBCVectorStoreRecordCollection getCollection( if (this.options != null && this.options.getVectorStoreRecordCollectionFactory() != null) { return this.options.getVectorStoreRecordCollectionFactory() .createVectorStoreRecordCollection( - connection, + dataSource, collectionName, JDBCVectorStoreRecordCollectionOptions.builder() .withRecordClass(recordClass) @@ -77,7 +77,7 @@ public JDBCVectorStoreRecordCollection getCollection( } return new JDBCVectorStoreRecordCollection<>( - connection, + dataSource, collectionName, JDBCVectorStoreRecordCollectionOptions.builder() .withRecordClass(recordClass) @@ -110,18 +110,18 @@ public Mono prepareAsync() { * Builder for creating a {@link JDBCVectorStore}. */ public static class Builder { - private Connection connection; + private DataSource dataSource; private JDBCVectorStoreOptions options; /** - * Sets the connection. + * Sets the data source. 
* - * @param connection the connection + * @param dataSource the data source * @return the builder */ @SuppressFBWarnings("EI_EXPOSE_REP2") - public Builder withConnection(Connection connection) { - this.connection = connection; + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; return this; } @@ -151,11 +151,11 @@ public JDBCVectorStore build() { * @return the {@link Mono} with the {@link JDBCVectorStore} */ public Mono buildAsync() { - if (connection == null) { - throw new IllegalArgumentException("connection is required"); + if (dataSource == null) { + throw new IllegalArgumentException("dataSource is required"); } - JDBCVectorStore vectorStore = new JDBCVectorStore(connection, options); + JDBCVectorStore vectorStore = new JDBCVectorStore(dataSource, options); return vectorStore.prepareAsync().thenReturn(vectorStore); } } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java index 9bf75afd..23264bf1 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java @@ -10,6 +10,7 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import javax.annotation.Nonnull; +import javax.sql.DataSource; import java.lang.reflect.Field; import java.sql.Connection; import java.sql.PreparedStatement; @@ -53,16 +54,16 @@ public class JDBCVectorStoreDefaultQueryProvider supportedVectorTypes.put(List.class, "TEXT"); supportedVectorTypes.put(Collection.class, "TEXT"); } - protected final Connection connection; - protected final String collectionsTable; - protected final String 
prefixForCollectionTables; + private final DataSource dataSource; + private final String collectionsTable; + private final String prefixForCollectionTables; - @SuppressFBWarnings("EI_EXPOSE_REP2") + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed protected JDBCVectorStoreDefaultQueryProvider( - @Nonnull Connection connection, + @Nonnull DataSource dataSource, @Nonnull String collectionsTable, @Nonnull String prefixForCollectionTables) { - this.connection = connection; + this.dataSource = dataSource; this.collectionsTable = collectionsTable; this.prefixForCollectionTables = prefixForCollectionTables; } @@ -110,15 +111,17 @@ protected String getColumnNamesAndTypes(List fields, Map, String } protected String getCollectionTableName(String collectionName) { - return prefixForCollectionTables + collectionName; + return validateSQLidentifier(prefixForCollectionTables + collectionName); } @Override public void prepareVectorStore() { - String createCollectionsTable = "CREATE TABLE IF NOT EXISTS " + collectionsTable + String createCollectionsTable = "CREATE TABLE IF NOT EXISTS " + + validateSQLidentifier(collectionsTable) + " (collectionId VARCHAR(255) PRIMARY KEY);"; - try (PreparedStatement createTable = connection.prepareStatement(createCollectionsTable)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement createTable = connection.prepareStatement(createCollectionsTable)) { createTable.execute(); } catch (SQLException e) { throw new SKException("Failed to prepare vector store", e); @@ -139,11 +142,11 @@ public void validateSupportedTypes(Class recordClass, @Override public boolean collectionExists(String collectionName) { - validateSQLidentifier(collectionsTable); - - String query = "SELECT 1 FROM " + collectionsTable + " WHERE collectionId = ?"; + String query = "SELECT 1 FROM " + validateSQLidentifier(collectionsTable) + + " WHERE collectionId = ?"; - try (PreparedStatement statement = connection.prepareStatement(query)) { 
+ try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { statement.setObject(1, collectionName); return statement.executeQuery().next(); @@ -153,10 +156,9 @@ public boolean collectionExists(String collectionName) { } @Override + @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers public void createCollection(String collectionName, Class recordClass, VectorStoreRecordDefinition recordDefinition) { - validateSQLidentifier(collectionName); - Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass); List dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass); List vectorDeclaredFields = recordDefinition.getVectorDeclaredFields(recordClass); @@ -167,16 +169,18 @@ public void createCollection(String collectionName, Class recordClass, + getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", " + getColumnNamesAndTypes(vectorDeclaredFields, supportedVectorTypes) + ");"; - String insertCollectionQuery = "INSERT INTO " + collectionsTable + String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable) + " (collectionId) VALUES (?)"; - try (PreparedStatement createTable = connection.prepareStatement(createStorageTable)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement createTable = connection.prepareStatement(createStorageTable)) { createTable.execute(); } catch (SQLException e) { throw new SKException("Failed to create collection", e); } - try (PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) { insert.setObject(1, collectionName); insert.execute(); } catch (SQLException e) { @@ -186,22 +190,21 @@ public void createCollection(String 
collectionName, Class recordClass, @Override public void deleteCollection(String collectionName) { - validateSQLidentifier(collectionsTable); - validateSQLidentifier(getCollectionTableName(collectionName)); - - String deleteCollectionOperation = "DELETE FROM " + collectionsTable + String deleteCollectionOperation = "DELETE FROM " + validateSQLidentifier(collectionsTable) + " WHERE collectionId = ?"; String dropTableOperation = "DROP TABLE " + getCollectionTableName(collectionName); - try (PreparedStatement deleteCollection = connection - .prepareStatement(deleteCollectionOperation)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement deleteCollection = connection + .prepareStatement(deleteCollectionOperation)) { deleteCollection.setObject(1, collectionName); deleteCollection.execute(); } catch (SQLException e) { throw new SKException("Failed to delete collection", e); } - try (PreparedStatement dropTable = connection.prepareStatement(dropTableOperation)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement dropTable = connection.prepareStatement(dropTableOperation)) { dropTable.execute(); } catch (SQLException e) { throw new SKException("Failed to drop table", e); @@ -210,11 +213,10 @@ public void deleteCollection(String collectionName) { @Override public List getCollectionNames() { - validateSQLidentifier(collectionsTable); + String query = "SELECT collectionId FROM " + validateSQLidentifier(collectionsTable); - String query = "SELECT collectionId FROM " + collectionsTable; - - try (PreparedStatement statement = connection.prepareStatement(query)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { List collectionNames = new ArrayList<>(); ResultSet resultSet = statement.executeQuery(); @@ -230,10 +232,8 @@ public List getCollectionNames() { @Override public List getRecords(String collectionName, List keys, - 
VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, - GetRecordOptions options) { - validateSQLidentifier(getCollectionTableName(collectionName)); - + VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, + GetRecordOptions options) { List fields; if (options == null || options.includeVectors()) { fields = recordDefinition.getAllFields(); @@ -246,7 +246,8 @@ public List getRecords(String collectionName, List keys + " WHERE " + recordDefinition.getKeyField().getName() + " IN (" + getWildcardString(keys.size()) + ")"; - try (PreparedStatement statement = connection.prepareStatement(query)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { for (int i = 0; i < keys.size(); ++i) { statement.setObject(i + 1, keys.get(i)); } @@ -274,13 +275,12 @@ public void upsertRecords(String collectionName, List records, @Override public void deleteRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) { - validateSQLidentifier(getCollectionTableName(collectionName)); - String query = "DELETE FROM " + getCollectionTableName(collectionName) + " WHERE " + recordDefinition.getKeyField().getName() + " IN (" + getWildcardString(keys.size()) + ")"; - try (PreparedStatement statement = connection.prepareStatement(query)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { for (int i = 0; i < keys.size(); ++i) { statement.setObject(i + 1, keys.get(i)); } @@ -291,10 +291,11 @@ public void deleteRecords(String collectionName, List keys, } } - public static void validateSQLidentifier(String identifier) { - if (!identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*")) { - throw new IllegalArgumentException("Invalid SQL identifier: " + identifier); + public static String validateSQLidentifier(String identifier) { + if 
(identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*")) { + return identifier; } + throw new IllegalArgumentException("Invalid SQL identifier: " + identifier); } /** @@ -302,18 +303,18 @@ public static void validateSQLidentifier(String identifier) { */ public static class Builder implements JDBCVectorStoreQueryProvider.Builder { - protected Connection connection; - protected String collectionsTable = DEFAULT_COLLECTIONS_TABLE; - protected String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + private DataSource dataSource; + private String collectionsTable = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; /** - * Sets the connection. - * @param connection the connection + * Sets the data source. + * @param dataSource the data source * @return the builder */ - @SuppressFBWarnings("EI_EXPOSE_REP2") - public Builder withConnection(Connection connection) { - this.connection = connection; + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; return this; } @@ -323,8 +324,7 @@ public Builder withConnection(Connection connection) { * @return the builder */ public Builder withCollectionsTable(String collectionsTable) { - validateSQLidentifier(collectionsTable); - this.collectionsTable = collectionsTable; + this.collectionsTable = validateSQLidentifier(collectionsTable); return this; } @@ -334,18 +334,17 @@ public Builder withCollectionsTable(String collectionsTable) { * @return the builder */ public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { - validateSQLidentifier(prefixForCollectionTables); - this.prefixForCollectionTables = prefixForCollectionTables; + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); return this; } @Override public JDBCVectorStoreDefaultQueryProvider build() { - if (connection == null) { - throw new 
IllegalArgumentException("connection is required"); + if (dataSource == null) { + throw new IllegalArgumentException("DataSource is required"); } - return new JDBCVectorStoreDefaultQueryProvider(connection, collectionsTable, + return new JDBCVectorStoreDefaultQueryProvider(dataSource, collectionsTable, prefixForCollectionTables); } } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java index 580868e9..adb6e13c 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.jdbc; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + import javax.annotation.Nullable; public class JDBCVectorStoreOptions { @@ -14,6 +16,7 @@ public class JDBCVectorStoreOptions { * * @param vectorStoreRecordCollectionFactory The vector store record collection factory. */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource in queryProvider is not exposed public JDBCVectorStoreOptions( @Nullable JDBCVectorStoreQueryProvider queryProvider, @Nullable JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory) { @@ -34,6 +37,7 @@ public JDBCVectorStoreOptions() { * @return the query provider */ @Nullable + @SuppressFBWarnings("EI_EXPOSE_REP") // DataSource in queryProvider is not exposed public JDBCVectorStoreQueryProvider getQueryProvider() { return queryProvider; } @@ -73,6 +77,7 @@ public static class Builder { * @param queryProvider The query provider. * @return The updated builder instance. 
*/ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource in queryProvider is not exposed public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvider) { this.queryProvider = queryProvider; return this; diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java index fd75dac7..26d976aa 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java @@ -81,8 +81,8 @@ void createCollection(String collectionName, Class recordClass, * @return the records */ List getRecords(String collectionName, List keys, - VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, - GetRecordOptions options); + VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, + GetRecordOptions options); /** * Upserts records. 
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java index 0096794c..dac65a4d 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java @@ -12,8 +12,8 @@ import reactor.core.scheduler.Schedulers; import javax.annotation.Nonnull; +import javax.sql.DataSource; import java.lang.reflect.Field; -import java.sql.Connection; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -27,15 +27,15 @@ public class JDBCVectorStoreRecordCollection private final JDBCVectorStoreQueryProvider queryProvider; /** - * Creates a new instance of the JDBCVectorRecordStore. - * If using this constructor, call {@link #prepareAsync()} before using the record collection. + * Creates a new instance of the {@link JDBCVectorStoreRecordCollection}. * - * @param connection The JDBC connection. - * @param options The options for the store. 
+ * @param dataSource the data source + * @param collectionName the name of the collection + * @param options the options */ - @SuppressFBWarnings("EI_EXPOSE_REP2") + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed public JDBCVectorStoreRecordCollection( - @Nonnull Connection connection, + @Nonnull DataSource dataSource, @Nonnull String collectionName, @Nonnull JDBCVectorStoreRecordCollectionOptions options) { this.collectionName = collectionName; @@ -59,7 +59,7 @@ public JDBCVectorStoreRecordCollection( // If the query provider is not provided, set a default one if (options.getQueryProvider() == null) { this.queryProvider = JDBCVectorStoreDefaultQueryProvider.builder() - .withConnection(connection) + .withDataSource(dataSource) .build(); } else { this.queryProvider = options.getQueryProvider(); @@ -256,21 +256,39 @@ public Mono prepareAsync() { public static class Builder implements SemanticKernelBuilder> { - private Connection connection; + private DataSource dataSource; private String collectionName; private JDBCVectorStoreRecordCollectionOptions options; - @SuppressFBWarnings("EI_EXPOSE_REP2") - public Builder withConnection(Connection connection) { - this.connection = connection; + /** + * Sets the data source. + * + * @param dataSource the data source + * @return the builder + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; return this; } + /** + * Sets the collection name. + * + * @param collectionName the collection name + * @return the builder + */ public Builder withCollectionName(String collectionName) { this.collectionName = collectionName; return this; } + /** + * Sets the options. 
+ * + * @param options the options + * @return the builder + */ public Builder withOptions(JDBCVectorStoreRecordCollectionOptions options) { this.options = options; return this; @@ -278,8 +296,8 @@ public Builder withOptions(JDBCVectorStoreRecordCollectionOptions build() { - if (connection == null) { - throw new IllegalArgumentException("connection is required"); + if (dataSource == null) { + throw new IllegalArgumentException("dataSource is required"); } if (collectionName == null) { throw new IllegalArgumentException("collectionName is required"); @@ -288,7 +306,7 @@ public JDBCVectorStoreRecordCollection build() { throw new IllegalArgumentException("options is required"); } - return new JDBCVectorStoreRecordCollection<>(connection, collectionName, options); + return new JDBCVectorStoreRecordCollection<>(dataSource, collectionName, options); } } } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java index 8461cd08..70b62a7e 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.jdbc; +import javax.sql.DataSource; import java.sql.Connection; /** @@ -14,7 +15,7 @@ public interface JDBCVectorStoreRecordCollectionFactory { * @return The new JDBC vector store record collection. 
*/ JDBCVectorStoreRecordCollection createVectorStoreRecordCollection( - Connection connection, + DataSource dataSource, String collectionName, JDBCVectorStoreRecordCollectionOptions options); } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java index 6b860b39..af1ec49e 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java @@ -2,19 +2,13 @@ package com.microsoft.semantickernel.connectors.data.jdbc; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -public class JDBCVectorStoreRecordCollectionOptions { - - /** - * The default name for the collections table. - */ - public static final String DEFAULT_COLLECTIONS_TABLE = "SKCollections"; - - /** - * The prefix for collection tables. 
- */ - public static final String DEFAULT_PREFIX_FOR_COLLECTION_TABLES = "SKCollection_"; +import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider.validateSQLidentifier; +import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_COLLECTIONS_TABLE; +import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_PREFIX_FOR_COLLECTION_TABLES; +public class JDBCVectorStoreRecordCollectionOptions { private final Class recordClass; private final JDBCVectorStoreRecordMapper vectorStoreRecordMapper; private final VectorStoreRecordDefinition recordDefinition; @@ -90,16 +84,11 @@ public String getPrefixForCollectionTables() { * Gets the query provider. * @return the query provider */ + @SuppressFBWarnings("EI_EXPOSE_REP") // DataSource in queryProvider is not exposed public JDBCVectorStoreQueryProvider getQueryProvider() { return queryProvider; } - public static void validateSQLidentifier(String identifier) { - if (!identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*")) { - throw new IllegalArgumentException("Invalid SQL identifier: " + identifier); - } - } - public static class Builder { private Class recordClass; private VectorStoreRecordDefinition recordDefinition; @@ -144,6 +133,7 @@ public Builder withVectorStoreRecordMapper( * @param queryProvider the query provider * @return the builder */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource in queryProvider is not exposed public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvider) { this.queryProvider = queryProvider; return this; @@ -155,8 +145,7 @@ public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvi * @return the builder */ public Builder withCollectionsTableName(String collectionsTableName) { - validateSQLidentifier(collectionsTableName); - this.collectionsTableName = collectionsTableName; + this.collectionsTableName = 
validateSQLidentifier(collectionsTableName); return this; } @@ -166,8 +155,7 @@ public Builder withCollectionsTableName(String collectionsTableName) { * @return the builder */ public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { - validateSQLidentifier(prefixForCollectionTables); - this.prefixForCollectionTables = prefixForCollectionTables; + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); return this; } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java index 1110cb95..fbf2afce 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java @@ -9,7 +9,9 @@ import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; import com.microsoft.semantickernel.exceptions.SKException; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import javax.sql.DataSource; import java.lang.reflect.Field; import java.sql.Connection; import java.sql.PreparedStatement; @@ -19,9 +21,13 @@ public class MySQLVectorStoreQueryProvider extends JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { - public MySQLVectorStoreQueryProvider(Connection connection, String collectionsTable, + private final DataSource dataSource; + + @SuppressFBWarnings("EI_EXPOSE_REP2") + private MySQLVectorStoreQueryProvider(DataSource dataSource, String collectionsTable, String prefixForCollectionTables) { - super(connection, collectionsTable, prefixForCollectionTables); + super(dataSource, 
collectionsTable, prefixForCollectionTables); + this.dataSource = dataSource; } /** @@ -66,6 +72,7 @@ private void setStatementValues(PreparedStatement statement, Object record, } @Override + @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers public void upsertRecords(String collectionName, List records, VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) { validateSQLidentifier(getCollectionTableName(collectionName)); @@ -88,7 +95,8 @@ public void upsertRecords(String collectionName, List records, + " VALUES (" + getWildcardString(fields.size()) + ")" + " ON DUPLICATE KEY UPDATE " + onDuplicateKeyUpdate; - try (PreparedStatement statement = connection.prepareStatement(query)) { + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { for (Object record : records) { setStatementValues(statement, record, recordDefinition.getAllFields()); statement.addBatch(); @@ -102,12 +110,42 @@ public void upsertRecords(String collectionName, List records, public static class Builder extends JDBCVectorStoreDefaultQueryProvider.Builder { + private DataSource dataSource; + private String collectionsTable = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + + @SuppressFBWarnings("EI_EXPOSE_REP2") + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the collections table name. + * @param collectionsTable the collections table name + * @return the builder + */ + public Builder withCollectionsTable(String collectionsTable) { + this.collectionsTable = validateSQLidentifier(collectionsTable); + return this; + } + + /** + * Sets the prefix for collection tables. 
+ * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); + return this; + } + public MySQLVectorStoreQueryProvider build() { - if (connection == null) { - throw new IllegalArgumentException("connection is required"); + if (dataSource == null) { + throw new SKException("DataSource is required"); } - return new MySQLVectorStoreQueryProvider(connection, collectionsTable, + return new MySQLVectorStoreQueryProvider(dataSource, collectionsTable, prefixForCollectionTables); } } From 473e9d189f3193c7d7eb3c36dfa920e3624932c2 Mon Sep 17 00:00:00 2001 From: Milder Hernandez Cagua Date: Mon, 5 Aug 2024 16:24:52 -0700 Subject: [PATCH 5/6] Add suggestions and update docs --- .../JDBCVectorStoreRecordCollectionTest.java | 2 +- .../JDBCVectorStoreDefaultQueryProvider.java | 68 +++++++++++++++++++ .../jdbc/JDBCVectorStoreRecordCollection.java | 11 +++ .../jdbc/MySQLVectorStoreQueryProvider.java | 8 +++ .../VectorStoreRecordDefinition.java | 10 ++- 5 files changed, 95 insertions(+), 4 deletions(-) diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java index 6e80e4ac..5e676b1c 100644 --- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java @@ -28,7 +28,7 @@ @Testcontainers public class JDBCVectorStoreRecordCollectionTest { @Container - private static final MySQLContainer CONTAINER = new 
MySQLContainer<>("mysql:5.7.34"); + private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:latest"); private static final String MYSQL_USER = "test"; private static final String MYSQL_PASSWORD = "test"; private static MysqlDataSource dataSource; diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java index 23264bf1..096f240a 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java @@ -114,6 +114,12 @@ protected String getCollectionTableName(String collectionName) { return validateSQLidentifier(prefixForCollectionTables + collectionName); } + /** + * Prepares the vector store. + * Executes any necessary setup steps for the vector store. + * + * @throws SKException if an error occurs while preparing the vector store + */ @Override public void prepareVectorStore() { String createCollectionsTable = "CREATE TABLE IF NOT EXISTS " @@ -128,6 +134,13 @@ public void prepareVectorStore() { } } + /** + * Checks if the types of the record class fields are supported. + * + * @param recordClass the record class + * @param recordDefinition the record definition + * @throws IllegalArgumentException if the types are not supported + */ @Override public void validateSupportedTypes(Class recordClass, VectorStoreRecordDefinition recordDefinition) { @@ -140,6 +153,13 @@ public void validateSupportedTypes(Class recordClass, recordDefinition.getVectorDeclaredFields(recordClass), supportedVectorTypes.keySet()); } + /** + * Checks if a collection exists. 
+ * + * @param collectionName the collection name + * @return true if the collection exists, false otherwise + * @throws SKException if an error occurs while checking if the collection exists + */ @Override public boolean collectionExists(String collectionName) { String query = "SELECT 1 FROM " + validateSQLidentifier(collectionsTable) @@ -155,6 +175,14 @@ public boolean collectionExists(String collectionName) { } } + /** + * Creates a collection. + * + * @param collectionName the collection name + * @param recordClass the record class + * @param recordDefinition the record definition + * @throws SKException if an error occurs while creating the collection + */ @Override @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers public void createCollection(String collectionName, Class recordClass, @@ -188,6 +216,12 @@ public void createCollection(String collectionName, Class recordClass, } } + /** + * Deletes a collection. + * + * @param collectionName the collection name + * @throws SKException if an error occurs while deleting the collection + */ @Override public void deleteCollection(String collectionName) { String deleteCollectionOperation = "DELETE FROM " + validateSQLidentifier(collectionsTable) @@ -211,6 +245,12 @@ public void deleteCollection(String collectionName) { } } + /** + * Gets the collection names. + * + * @return the collection names + * @throws SKException if an error occurs while getting the collection names + */ @Override public List getCollectionNames() { String query = "SELECT collectionId FROM " + validateSQLidentifier(collectionsTable); @@ -230,6 +270,18 @@ public List getCollectionNames() { } } + /** + * Gets a list of records from the store. 
+ * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param mapper the mapper + * @param options the options + * @return the records + * @param the record type + * @throws SKException if an error occurs while getting the records + */ @Override public List getRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, @@ -272,6 +324,15 @@ public void upsertRecords(String collectionName, List records, "Upsert is not supported. Try with a specific query provider."); } + /** + * Deletes records. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param options the options + * @throws SKException if an error occurs while deleting the records + */ @Override public void deleteRecords(String collectionName, List keys, VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) { @@ -291,6 +352,13 @@ public void deleteRecords(String collectionName, List keys, } } + /** + * Validates an SQL identifier. 
+ * + * @param identifier the identifier + * @return the identifier if it is valid + * @throws IllegalArgumentException if the identifier is invalid + */ public static String validateSQLidentifier(String identifier) { if (identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*")) { return identifier; diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java index dac65a4d..b9c0bd3c 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java @@ -83,6 +83,7 @@ public String getCollectionName() { * Checks if the collection exists in the store. * * @return A Mono emitting a boolean indicating if the collection exists. + * @throws SKException if the operation fails */ @Override public Mono collectionExistsAsync() { @@ -95,6 +96,7 @@ public Mono collectionExistsAsync() { * Creates the collection in the store. * * @return A Mono representing the completion of the creation operation. + * @throws SKException if the operation fails */ @Override public Mono createCollectionAsync() { @@ -109,6 +111,7 @@ public Mono createCollectionAsync() { * Creates the collection in the store if it does not exist. * * @return A Mono representing the completion of the creation operation. + * @throws SKException if the operation fails */ @Override public Mono createCollectionIfNotExistsAsync() { @@ -127,6 +130,7 @@ public Mono createCollectionIfNotExistsAsync() { * Deletes the collection from the store. * * @return A Mono representing the completion of the deletion operation. 
+ * @throws SKException if the operation fails */ @Override public Mono deleteCollectionAsync() { @@ -142,6 +146,7 @@ public Mono deleteCollectionAsync() { * @param key The key of the record to get. * @param options The options for getting the record. * @return A Mono emitting the record. + * @throws SKException if the operation fails */ @Override public Mono getAsync(String key, GetRecordOptions options) { @@ -160,6 +165,7 @@ public Mono getAsync(String key, GetRecordOptions options) { * @param keys The keys of the records to get. * @param options The options for getting the records. * @return A Mono emitting a collection of records. + * @throws SKException if the operation fails */ @Override public Mono> getBatchAsync(List keys, GetRecordOptions options) { @@ -187,6 +193,7 @@ protected String getKeyFromRecord(Record data) { * @param data The record to upsert. * @param options The options for upserting the record. * @return A Mono emitting the key of the upserted record. + * @throws SKException if the operation fails */ @Override public Mono upsertAsync(Record data, UpsertRecordOptions options) { @@ -205,6 +212,7 @@ public Mono upsertAsync(Record data, UpsertRecordOptions options) { * @param data The records to upsert. * @param options The options for upserting the records. * @return A Mono emitting a collection of keys of the upserted records. + * @throws SKException if the operation fails */ @Override public Mono> upsertBatchAsync(List data, UpsertRecordOptions options) { @@ -222,6 +230,7 @@ public Mono> upsertBatchAsync(List data, UpsertRecordOption * @param key The key of the record to delete. * @param options The options for deleting the record. * @return A Mono representing the completion of the deletion operation. 
+ * @throws SKException if the operation fails */ @Override public Mono deleteAsync(String key, DeleteRecordOptions options) { @@ -234,6 +243,7 @@ public Mono deleteAsync(String key, DeleteRecordOptions options) { * @param keys The keys of the records to delete. * @param options The options for deleting the records. * @return A Mono representing the completion of the deletion operation. + * @throws SKException if the operation fails */ @Override public Mono deleteBatchAsync(List keys, DeleteRecordOptions options) { @@ -247,6 +257,7 @@ public Mono deleteBatchAsync(List keys, DeleteRecordOptions option * Prepares the collection for use. * * @return A Mono representing the completion of the preparation operation. + * @throws SKException if the operation fails */ @Override public Mono prepareAsync() { diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java index fbf2afce..72ecd87e 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java @@ -71,6 +71,14 @@ private void setStatementValues(PreparedStatement statement, Object record, } } + /** + * Upserts records into the collection. 
+ * @param collectionName the collection name + * @param records the records to upsert + * @param recordDefinition the record definition + * @param options the upsert options + * @throws SKException if the upsert fails + */ @Override @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers public void upsertRecords(String collectionName, List records, diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java index dff27c18..39e04a3f 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java @@ -56,8 +56,12 @@ public List getNonVectorFields() { return fields; } + private enum DeclaredFieldType { + KEY, DATA, VECTOR + } + private List getDeclaredFields(Class recordClass, List fields, - String fieldType) { + DeclaredFieldType fieldType) { List declaredFields = new ArrayList<>(); for (VectorStoreRecordField field : fields) { try { @@ -85,14 +89,14 @@ public List getDataDeclaredFields(Class recordClass) { return getDeclaredFields( recordClass, dataFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), - "Data"); + DeclaredFieldType.DATA); } public List getVectorDeclaredFields(Class recordClass) { return getDeclaredFields( recordClass, vectorFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), - "Vector"); + DeclaredFieldType.VECTOR); } private VectorStoreRecordDefinition( From c2ac6757251788628372d92b70a5795ee6f5ea87 Mon Sep 17 00:00:00 2001 From: Milder Hernandez Cagua Date: Mon, 5 Aug 2024 16:49:04 -0700 Subject: 
[PATCH 6/6] Fix --- .../memory/jdbc/JDBCVectorStoreRecordCollectionTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java index 5e676b1c..6e80e4ac 100644 --- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java @@ -28,7 +28,7 @@ @Testcontainers public class JDBCVectorStoreRecordCollectionTest { @Container - private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:latest"); + private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:5.7.34"); private static final String MYSQL_USER = "test"; private static final String MYSQL_PASSWORD = "test"; private static MysqlDataSource dataSource;