diff --git a/api-test/integration-tests/pom.xml b/api-test/integration-tests/pom.xml index 7e669927..cc93c7fe 100644 --- a/api-test/integration-tests/pom.xml +++ b/api-test/integration-tests/pom.xml @@ -73,6 +73,11 @@ 8.0.33 test + + org.postgresql + postgresql + 42.7.2 + org.testcontainers diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java index e43842fc..ad10ad64 100644 --- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java @@ -18,7 +18,7 @@ public class Hotel { @VectorStoreRecordVectorAttribute(dimensions = 3) private final List descriptionEmbedding; @VectorStoreRecordDataAttribute - private final double rating; + private double rating; public Hotel() { this(null, null, 0, null, null, 0.0); @@ -56,4 +56,8 @@ public List getDescriptionEmbedding() { public double getRating() { return rating; } + + public void setRating(double rating) { + this.rating = rating; + } } diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java index c873ea5b..8bee5a76 100644 --- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java @@ -1,104 +1,151 @@ package com.microsoft.semantickernel.tests.connectors.memory.jdbc; -import static org.junit.jupiter.api.Assertions.assertEquals; -import 
static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; - +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider; import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollection; import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions; -import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider; import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; import com.microsoft.semantickernel.tests.connectors.memory.Hotel; import com.mysql.cj.jdbc.MysqlDataSource; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import javax.annotation.Nonnull; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.postgresql.ds.PGSimpleDataSource; import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import javax.annotation.Nonnull; +import javax.sql.DataSource; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + @Testcontainers public class JDBCVectorStoreRecordCollectionTest { @Container - private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:5.7.34"); - private static final String 
MYSQL_USER = "test"; - private static final String MYSQL_PASSWORD = "test"; - private static MysqlDataSource dataSource; - - @BeforeAll - static void setup() { - dataSource = new MysqlDataSource(); - dataSource.setUrl(CONTAINER.getJdbcUrl()); - dataSource.setUser(MYSQL_USER); - dataSource.setPassword(MYSQL_PASSWORD); + private static final MySQLContainer MYSQL_CONTAINER = new MySQLContainer<>("mysql:5.7.34"); + + private static final DockerImageName PGVECTOR = DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres"); + @Container + private static final PostgreSQLContainer POSTGRESQL_CONTAINER = new PostgreSQLContainer<>(PGVECTOR); + + public enum QueryProvider { + MySQL, + PostgreSQL } - private JDBCVectorStoreRecordCollection buildRecordCollection( - @Nonnull String collectionName) { - JDBCVectorStoreRecordCollection recordCollection = new JDBCVectorStoreRecordCollection<>( - dataSource, - collectionName, - JDBCVectorStoreRecordCollectionOptions.builder() - .withRecordClass(Hotel.class) - .withQueryProvider(MySQLVectorStoreQueryProvider.builder() - .withDataSource(dataSource) - .build()) - .build()); + private JDBCVectorStoreRecordCollection buildRecordCollection(QueryProvider provider, @Nonnull String collectionName) { + JDBCVectorStoreQueryProvider queryProvider; + DataSource dataSource; + + switch (provider) { + case MySQL: + MysqlDataSource mysqlDataSource = new MysqlDataSource(); + mysqlDataSource.setUrl(MYSQL_CONTAINER.getJdbcUrl()); + mysqlDataSource.setUser(MYSQL_CONTAINER.getUsername()); + mysqlDataSource.setPassword(MYSQL_CONTAINER.getPassword()); + dataSource = mysqlDataSource; + queryProvider = MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + break; + case PostgreSQL: + PGSimpleDataSource pgSimpleDataSource = new PGSimpleDataSource(); + pgSimpleDataSource.setUrl(POSTGRESQL_CONTAINER.getJdbcUrl()); + pgSimpleDataSource.setUser(POSTGRESQL_CONTAINER.getUsername()); + 
pgSimpleDataSource.setPassword(POSTGRESQL_CONTAINER.getPassword()); + dataSource = pgSimpleDataSource; + queryProvider = PostgreSQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + break; + default: + throw new IllegalArgumentException("Unknown query provider: " + provider); + } + + + JDBCVectorStoreRecordCollection recordCollection = new JDBCVectorStoreRecordCollection<>( + dataSource, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(Hotel.class) + .withQueryProvider(queryProvider) + .build()); recordCollection.prepareAsync().block(); recordCollection.createCollectionIfNotExistsAsync().block(); return recordCollection; } - @Test - public void buildRecordCollection() { - assertNotNull(buildRecordCollection("buildTest")); + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void buildRecordCollection(QueryProvider provider) { + assertNotNull(buildRecordCollection(provider, "buildTest")); } private List getHotels() { return List.of( - new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), - 4.0), - new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f), - 3.0), - new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f), - 5.0), - new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f), - 4.0), - new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f), - 5.0) + new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), + new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f), 3.0), + new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0), + new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), + new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0) ); } - 
@Test - public void upsertAndGetRecordAsync() { + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void upsertAndGetRecordAsync(QueryProvider provider) { String collectionName = "upsertAndGetRecordAsync"; - JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); List hotels = getHotels(); for (Hotel hotel : hotels) { - recordStore.upsertAsync(hotel, null).block(); + recordCollection.upsertAsync(hotel, null).block(); + } + + // Verify the records stored by the first upsert + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + assertEquals(hotel.getRating(), retrievedHotel.getRating()); + + // Update the rating + hotel.setRating(1.0); + } + + // Upsert the second time with updated rating + for (Hotel hotel : hotels) { + recordCollection.upsertAsync(hotel, null).block(); } for (Hotel hotel : hotels) { - Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block(); + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block(); assertNotNull(retrievedHotel); assertEquals(hotel.getId(), retrievedHotel.getId()); + assertEquals(1.0, retrievedHotel.getRating()); } } - @Test - public void getBatchAsync() { + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void getBatchAsync(QueryProvider provider) { String collectionName = "getBatchAsync"; - JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); List hotels = getHotels(); for (Hotel hotel : hotels) { - recordStore.upsertAsync(hotel, null).block(); + recordCollection.upsertAsync(hotel, null).block(); } List keys = new ArrayList<>(); @@ -106,99 +153,104 @@ public void getBatchAsync() { 
keys.add(hotel.getId()); } - List retrievedHotels = recordStore.getBatchAsync(keys, null).block(); + List retrievedHotels = recordCollection.getBatchAsync(keys, null).block(); assertNotNull(retrievedHotels); assertEquals(hotels.size(), retrievedHotels.size()); } - @Test - public void upsertBatchAndGetBatchAsync() { + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void upsertBatchAndGetBatchAsync(QueryProvider provider) { String collectionName = "upsertBatchAndGetBatchAsync"; - JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); List hotels = getHotels(); - recordStore.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); List keys = new ArrayList<>(); for (Hotel hotel : hotels) { keys.add(hotel.getId()); } - List retrievedHotels = recordStore.getBatchAsync(keys, null).block(); + List retrievedHotels = recordCollection.getBatchAsync(keys, null).block(); assertNotNull(retrievedHotels); assertEquals(hotels.size(), retrievedHotels.size()); } - @Test - public void insertAndReplaceAsync() { + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void insertAndReplaceAsync(QueryProvider provider) { String collectionName = "insertAndReplaceAsync"; - JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); List hotels = getHotels(); - recordStore.upsertBatchAsync(hotels, null).block(); - recordStore.upsertBatchAsync(hotels, null).block(); - recordStore.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); List keys = new ArrayList<>(); for (Hotel hotel : hotels) { 
keys.add(hotel.getId()); } - List retrievedHotels = recordStore.getBatchAsync(keys, null).block(); + List retrievedHotels = recordCollection.getBatchAsync(keys, null).block(); assertNotNull(retrievedHotels); assertEquals(hotels.size(), retrievedHotels.size()); } - @Test - public void deleteRecordAsync() { + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void deleteRecordAsync(QueryProvider provider) { String collectionName = "deleteRecordAsync"; - JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); List hotels = getHotels(); - recordStore.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); for (Hotel hotel : hotels) { - recordStore.deleteAsync(hotel.getId(), null).block(); - Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block(); + recordCollection.deleteAsync(hotel.getId(), null).block(); + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block(); assertNull(retrievedHotel); } } - @Test - public void deleteBatchAsync() { + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void deleteBatchAsync(QueryProvider provider) { String collectionName = "deleteBatchAsync"; - JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); List hotels = getHotels(); - recordStore.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); List keys = new ArrayList<>(); for (Hotel hotel : hotels) { keys.add(hotel.getId()); } - recordStore.deleteBatchAsync(keys, null).block(); + recordCollection.deleteBatchAsync(keys, null).block(); for (String key : keys) { - Hotel retrievedHotel = recordStore.getAsync(key, null).block(); + Hotel retrievedHotel = 
recordCollection.getAsync(key, null).block(); assertNull(retrievedHotel); } } - @Test - public void getWithNoVectors() { + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void getWithNoVectors(QueryProvider provider) { String collectionName = "getWithNoVectors"; - JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); List hotels = getHotels(); - recordStore.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); GetRecordOptions options = GetRecordOptions.builder() .includeVectors(false) .build(); for (Hotel hotel : hotels) { - Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block(); + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block(); assertNotNull(retrievedHotel); assertEquals(hotel.getId(), retrievedHotel.getId()); assertNull(retrievedHotel.getDescriptionEmbedding()); @@ -209,20 +261,21 @@ public void getWithNoVectors() { .build(); for (Hotel hotel : hotels) { - Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block(); + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block(); assertNotNull(retrievedHotel); assertEquals(hotel.getId(), retrievedHotel.getId()); assertNotNull(retrievedHotel.getDescriptionEmbedding()); } } - @Test - public void getBatchWithNoVectors() { + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void getBatchWithNoVectors(QueryProvider provider) { String collectionName = "getBatchWithNoVectors"; - JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); List hotels = getHotels(); - recordStore.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); GetRecordOptions options = 
GetRecordOptions.builder() .includeVectors(false) @@ -233,7 +286,7 @@ public void getBatchWithNoVectors() { keys.add(hotel.getId()); } - List retrievedHotels = recordStore.getBatchAsync(keys, options).block(); + List retrievedHotels = recordCollection.getBatchAsync(keys, options).block(); assertNotNull(retrievedHotels); assertEquals(hotels.size(), retrievedHotels.size()); @@ -245,7 +298,7 @@ public void getBatchWithNoVectors() { .includeVectors(true) .build(); - retrievedHotels = recordStore.getBatchAsync(keys, options).block(); + retrievedHotels = recordCollection.getBatchAsync(keys, options).block(); assertNotNull(retrievedHotels); assertEquals(hotels.size(), retrievedHotels.size()); diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java index e0906a1c..8c2fbfd0 100644 --- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java @@ -1,45 +1,71 @@ package com.microsoft.semantickernel.tests.connectors.memory.jdbc; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore; import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions; -import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider; +import 
com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider; import com.microsoft.semantickernel.tests.connectors.memory.Hotel; import com.mysql.cj.jdbc.MysqlDataSource; -import java.util.Arrays; -import java.util.List; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.postgresql.ds.PGSimpleDataSource; import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import javax.annotation.Nonnull; +import javax.sql.DataSource; +import java.util.Arrays; +import java.util.List; + +import com.microsoft.semantickernel.tests.connectors.memory.jdbc.JDBCVectorStoreRecordCollectionTest.QueryProvider; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; @Testcontainers public class JDBCVectorStoreTest { @Container - private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:5.7.34"); - private static final String MYSQL_USER = "test"; - private static final String MYSQL_PASSWORD = "test"; - private static MysqlDataSource dataSource; - - @BeforeAll - static void setup() { - dataSource = new MysqlDataSource(); - dataSource.setUrl(CONTAINER.getJdbcUrl()); - dataSource.setUser(MYSQL_USER); - dataSource.setPassword(MYSQL_PASSWORD); - } + private static final MySQLContainer MYSQL_CONTAINER = new MySQLContainer<>("mysql:5.7.34"); + + private static final DockerImageName PGVECTOR = DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres"); + @Container + private static final PostgreSQLContainer POSTGRESQL_CONTAINER = new 
PostgreSQLContainer<>(PGVECTOR); + + private JDBCVectorStore buildVectorStore(QueryProvider provider) { + JDBCVectorStoreQueryProvider queryProvider; + DataSource dataSource; + + switch (provider) { + case MySQL: + MysqlDataSource mysqlDataSource = new MysqlDataSource(); + mysqlDataSource.setUrl(MYSQL_CONTAINER.getJdbcUrl()); + mysqlDataSource.setUser(MYSQL_CONTAINER.getUsername()); + mysqlDataSource.setPassword(MYSQL_CONTAINER.getPassword()); + dataSource = mysqlDataSource; + queryProvider = MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + break; + case PostgreSQL: + PGSimpleDataSource pgSimpleDataSource = new PGSimpleDataSource(); + pgSimpleDataSource.setUrl(POSTGRESQL_CONTAINER.getJdbcUrl()); + pgSimpleDataSource.setUser(POSTGRESQL_CONTAINER.getUsername()); + pgSimpleDataSource.setPassword(POSTGRESQL_CONTAINER.getPassword()); + dataSource = pgSimpleDataSource; + queryProvider = PostgreSQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + break; + default: + throw new IllegalArgumentException("Unknown query provider: " + provider); + } - @Test - public void getCollectionNamesAsync() { - MySQLVectorStoreQueryProvider queryProvider = MySQLVectorStoreQueryProvider.builder() - .withDataSource(dataSource) - .build(); - JDBCVectorStore vectorStore = JDBCVectorStore.builder() + JDBCVectorStore vectorStore = JDBCVectorStore.builder() .withDataSource(dataSource) .withOptions( JDBCVectorStoreOptions.builder() @@ -48,6 +74,16 @@ public void getCollectionNamesAsync() { ) .build(); + vectorStore.prepareAsync().block(); + return vectorStore; + } + + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void getCollectionNamesAsync(QueryProvider provider) { + JDBCVectorStore vectorStore = buildVectorStore(provider); + vectorStore.getCollectionNamesAsync().block(); List collectionNames = Arrays.asList("collection1", "collection2", "collection3"); diff --git 
a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java index 8aa4bddb..ed1d8bd5 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java @@ -8,7 +8,7 @@ import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService; import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore; import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions; -import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider; import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; diff --git a/semantickernel-experimental/pom.xml b/semantickernel-experimental/pom.xml index 20fa172a..3efd8ab0 100644 --- a/semantickernel-experimental/pom.xml +++ b/semantickernel-experimental/pom.xml @@ -109,6 +109,11 @@ + + org.postgresql + postgresql + 42.7.2 + diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java index 096f240a..f1795083 100644 --- 
a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.jdbc; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; import com.microsoft.semantickernel.exceptions.SKException; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; @@ -24,14 +25,27 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.Stream; public class JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { - private static final Map, String> supportedKeyTypes; - private static final Map, String> supportedDataTypes; - private static final Map, String> supportedVectorTypes; - static { + private Map, String> supportedKeyTypes; + private Map, String> supportedDataTypes; + private Map, String> supportedVectorTypes; + private final DataSource dataSource; + private final String collectionsTable; + private final String prefixForCollectionTables; + + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + protected JDBCVectorStoreDefaultQueryProvider( + @Nonnull DataSource dataSource, + @Nonnull String collectionsTable, + @Nonnull String prefixForCollectionTables) { + this.dataSource = dataSource; + this.collectionsTable = collectionsTable; + this.prefixForCollectionTables = prefixForCollectionTables; + supportedKeyTypes = new HashMap<>(); supportedKeyTypes.put(String.class, "VARCHAR(255)"); @@ -54,19 +68,6 @@ public class JDBCVectorStoreDefaultQueryProvider supportedVectorTypes.put(List.class, "TEXT"); supportedVectorTypes.put(Collection.class, "TEXT"); } - 
private final DataSource dataSource; - private final String collectionsTable; - private final String prefixForCollectionTables; - - @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed - protected JDBCVectorStoreDefaultQueryProvider( - @Nonnull DataSource dataSource, - @Nonnull String collectionsTable, - @Nonnull String prefixForCollectionTables) { - this.dataSource = dataSource; - this.collectionsTable = collectionsTable; - this.prefixForCollectionTables = prefixForCollectionTables; - } /** * Creates a new builder. @@ -82,14 +83,9 @@ public static Builder builder() { * @return the formatted wildcard string */ protected String getWildcardString(int wildcards) { - StringBuilder wildcardString = new StringBuilder(); - for (int i = 0; i < wildcards; ++i) { - wildcardString.append("?"); - if (i < wildcards - 1) { - wildcardString.append(", "); - } - } - return wildcardString.toString(); + return Stream.generate(() -> "?") + .limit(wildcards) + .collect(Collectors.joining(", ")); } /** @@ -102,6 +98,12 @@ protected String getQueryColumnsFromFields(List fields) .collect(Collectors.joining(", ")); } + /** + * Formats the column names and types for a table. + * @param fields the fields + * @param types the types + * @return the formatted column names and types + */ protected String getColumnNamesAndTypes(List fields, Map, String> types) { List columns = fields.stream() .map(field -> field.getName() + " " + types.get(field.getType())) @@ -114,6 +116,36 @@ protected String getCollectionTableName(String collectionName) { return validateSQLidentifier(prefixForCollectionTables + collectionName); } + /** + * Gets the supported key types and their corresponding SQL types. + * + * @return the supported key types + */ + @Override + public Map, String> getSupportedKeyTypes() { + return new HashMap<>(this.supportedKeyTypes); + } + + /** + * Gets the supported data types and their corresponding SQL types. 
+ * + * @return the supported data types + */ + @Override + public Map, String> getSupportedDataTypes() { + return new HashMap<>(this.supportedDataTypes); + } + + /** + * Gets the supported vector types and their corresponding SQL types. + * + * @return the supported vector types + */ + @Override + public Map, String> getSupportedVectorTypes() { + return new HashMap<>(this.supportedVectorTypes); + } + /** * Prepares the vector store. * Executes any necessary setup steps for the vector store. @@ -146,11 +178,12 @@ public void validateSupportedTypes(Class recordClass, VectorStoreRecordDefinition recordDefinition) { VectorStoreRecordDefinition.validateSupportedTypes( Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)), - supportedKeyTypes.keySet()); + getSupportedKeyTypes().keySet()); VectorStoreRecordDefinition.validateSupportedTypes( - recordDefinition.getDataDeclaredFields(recordClass), supportedDataTypes.keySet()); + recordDefinition.getDataDeclaredFields(recordClass), getSupportedDataTypes().keySet()); VectorStoreRecordDefinition.validateSupportedTypes( - recordDefinition.getVectorDeclaredFields(recordClass), supportedVectorTypes.keySet()); + recordDefinition.getVectorDeclaredFields(recordClass), + getSupportedVectorTypes().keySet()); } /** @@ -194,8 +227,8 @@ public void createCollection(String collectionName, Class recordClass, String createStorageTable = "CREATE TABLE IF NOT EXISTS " + getCollectionTableName(collectionName) + " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, " - + getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", " - + getColumnNamesAndTypes(vectorDeclaredFields, supportedVectorTypes) + ");"; + + getColumnNamesAndTypes(dataDeclaredFields, getSupportedDataTypes()) + ", " + + getColumnNamesAndTypes(vectorDeclaredFields, getSupportedVectorTypes()) + ");"; String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable) + " (collectionId) VALUES (?)"; @@ -284,7 +317,8 
@@ public List getCollectionNames() { */ @Override public List getRecords(String collectionName, List keys, - VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, + VectorStoreRecordDefinition recordDefinition, + VectorStoreRecordMapper mapper, GetRecordOptions options) { List fields; if (options == null || options.includeVectors()) { diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java index 26d976aa..6009b885 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java @@ -2,12 +2,15 @@ package com.microsoft.semantickernel.connectors.data.jdbc; import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import java.sql.ResultSet; import java.util.List; +import java.util.Map; /** * The JDBC vector store query provider. @@ -24,6 +27,27 @@ public interface JDBCVectorStoreQueryProvider { */ String DEFAULT_PREFIX_FOR_COLLECTION_TABLES = "SKCollection_"; + /** + * Gets the supported key types and their corresponding SQL types. + * + * @return the supported key types + */ + Map, String> getSupportedKeyTypes(); + + /** + * Gets the supported data types and their corresponding SQL types. 
+ * + * @return the supported data types + */ + Map, String> getSupportedDataTypes(); + + /** + * Gets the supported vector types and their corresponding SQL types. + * + * @return the supported vector types + */ + Map, String> getSupportedVectorTypes(); + /** * Prepares the vector store. * Executes any necessary setup steps for the vector store. @@ -81,7 +105,8 @@ void createCollection(String collectionName, Class recordClass, * @return the records */ List getRecords(String collectionName, List keys, - VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, + VectorStoreRecordDefinition recordDefinition, + VectorStoreRecordMapper mapper, GetRecordOptions options); /** diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java index 6135b284..44fc2338 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java @@ -2,6 +2,10 @@ package com.microsoft.semantickernel.connectors.data.jdbc; import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreRecordMapper; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; import 
com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; @@ -10,6 +14,7 @@ import com.microsoft.semantickernel.exceptions.SKException; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.lang.reflect.Field; +import java.sql.ResultSet; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -23,8 +28,8 @@ public class JDBCVectorStoreRecordCollection private final String collectionName; private final VectorStoreRecordDefinition recordDefinition; + private final VectorStoreRecordMapper vectorStoreRecordMapper; private final JDBCVectorStoreRecordCollectionOptions options; - private final JDBCVectorStoreRecordMapper vectorStoreRecordMapper; private final JDBCVectorStoreQueryProvider queryProvider; /** @@ -47,16 +52,6 @@ public JDBCVectorStoreRecordCollection( ? VectorStoreRecordDefinition.fromRecordClass(options.getRecordClass()) : options.getRecordDefinition(); - // If mapper is not provided, set a default one - if (options.getVectorStoreRecordMapper() == null) { - vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder() - .withRecordClass(options.getRecordClass()) - .withVectorStoreRecordDefinition(recordDefinition) - .build(); - } else { - vectorStoreRecordMapper = options.getVectorStoreRecordMapper(); - } - // If the query provider is not provided, set a default one if (options.getQueryProvider() == null) { this.queryProvider = JDBCVectorStoreDefaultQueryProvider.builder() @@ -66,6 +61,31 @@ public JDBCVectorStoreRecordCollection( this.queryProvider = options.getQueryProvider(); } + // If mapper is not provided, set a default one + if (options.getVectorStoreRecordMapper() == null) { + // Default mapper for PostgreSQL + if (this.queryProvider instanceof PostgreSQLVectorStoreQueryProvider) { + vectorStoreRecordMapper = PostgreSQLVectorStoreRecordMapper.builder() + .withRecordClass(options.getRecordClass()) + .withVectorStoreRecordDefinition(recordDefinition) + .build(); + // Default mapper for 
MySQL + } else if (this.queryProvider instanceof MySQLVectorStoreQueryProvider) { + vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder() + .withRecordClass(options.getRecordClass()) + .withVectorStoreRecordDefinition(recordDefinition) + .build(); + // Default mapper for other databases + } else { + vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder() + .withRecordClass(options.getRecordClass()) + .withVectorStoreRecordDefinition(recordDefinition) + .build(); + } + } else { + vectorStoreRecordMapper = options.getVectorStoreRecordMapper(); + } + // Check if the types are supported queryProvider.validateSupportedTypes(options.getRecordClass(), recordDefinition); } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java index af1ec49e..f6aa871d 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java @@ -1,16 +1,19 @@ // Copyright (c) Microsoft. All rights reserved. 
package com.microsoft.semantickernel.connectors.data.jdbc; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.sql.ResultSet; + import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider.validateSQLidentifier; import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_COLLECTIONS_TABLE; import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_PREFIX_FOR_COLLECTION_TABLES; public class JDBCVectorStoreRecordCollectionOptions { private final Class recordClass; - private final JDBCVectorStoreRecordMapper vectorStoreRecordMapper; + private final VectorStoreRecordMapper vectorStoreRecordMapper; private final VectorStoreRecordDefinition recordDefinition; private final JDBCVectorStoreQueryProvider queryProvider; private final String collectionsTableName; @@ -19,7 +22,7 @@ public class JDBCVectorStoreRecordCollectionOptions { private JDBCVectorStoreRecordCollectionOptions( Class recordClass, VectorStoreRecordDefinition recordDefinition, - JDBCVectorStoreRecordMapper vectorStoreRecordMapper, + VectorStoreRecordMapper vectorStoreRecordMapper, JDBCVectorStoreQueryProvider queryProvider, String collectionsTableName, String prefixForCollectionTables) { @@ -60,7 +63,7 @@ public VectorStoreRecordDefinition getRecordDefinition() { * Gets the vector store record mapper. 
* @return the vector store record mapper */ - public JDBCVectorStoreRecordMapper getVectorStoreRecordMapper() { + public VectorStoreRecordMapper getVectorStoreRecordMapper() { return vectorStoreRecordMapper; } @@ -92,7 +95,7 @@ public JDBCVectorStoreQueryProvider getQueryProvider() { public static class Builder { private Class recordClass; private VectorStoreRecordDefinition recordDefinition; - private JDBCVectorStoreRecordMapper vectorStoreRecordMapper; + private VectorStoreRecordMapper vectorStoreRecordMapper; private JDBCVectorStoreQueryProvider queryProvider; private String collectionsTableName = DEFAULT_COLLECTIONS_TABLE; private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; @@ -123,7 +126,7 @@ public Builder withRecordDefinition(VectorStoreRecordDefinition recordDe * @return the builder */ public Builder withVectorStoreRecordMapper( - JDBCVectorStoreRecordMapper vectorStoreRecordMapper) { + VectorStoreRecordMapper vectorStoreRecordMapper) { this.vectorStoreRecordMapper = vectorStoreRecordMapper; return this; } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java similarity index 96% rename from semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java rename to semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java index 72ecd87e..ff19017c 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java @@ -1,8 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. 
-package com.microsoft.semantickernel.connectors.data.jdbc; +package com.microsoft.semantickernel.connectors.data.mysql; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField; diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java new file mode 100644 index 00000000..d9d5deff --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java @@ -0,0 +1,338 @@ +// Copyright (c) Microsoft. All rights reserved. 
+package com.microsoft.semantickernel.connectors.data.postgres;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider;
+import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField;
+import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions;
+import com.microsoft.semantickernel.exceptions.SKException;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+
+import javax.sql.DataSource;
+import java.lang.reflect.Field;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.time.OffsetDateTime;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * JDBC vector store query provider for PostgreSQL.
+ *
+ * <p>Vector fields are created as {@code VECTOR(n)} columns (or {@code TEXT} when the record
+ * field is a {@code String}), vector parameters are bound with a {@code ::vector} cast, and
+ * upserts use {@code INSERT ... ON CONFLICT ... DO UPDATE}. {@link #prepareVectorStore()}
+ * issues {@code CREATE EXTENSION IF NOT EXISTS vector}.
+ *
+ * <p>NOTE(review): generic type arguments in this listing were lost in extraction and have been
+ * reconstructed from usage (e.g. {@code Map<Class<?>, String>}) - confirm against the repository.
+ */
+public class PostgreSQLVectorStoreQueryProvider extends
+    JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider {
+
+    // Shared so that a mapper is not allocated for every vector value that is serialized.
+    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+    private final Map<Class<?>, String> supportedKeyTypes;
+    private final Map<Class<?>, String> supportedDataTypes;
+    private final Map<Class<?>, String> supportedVectorTypes;
+
+    private final DataSource dataSource;
+    private final String collectionsTable;
+    private final String prefixForCollectionTables;
+
+    /**
+     * Creates the provider and registers the Java type to PostgreSQL column type mappings.
+     *
+     * @param dataSource the data source used to obtain connections
+     * @param collectionsTable the name of the table that lists the collections
+     * @param prefixForCollectionTables the prefix for the per-collection tables
+     */
+    @SuppressFBWarnings("EI_EXPOSE_REP2")
+    private PostgreSQLVectorStoreQueryProvider(DataSource dataSource, String collectionsTable,
+        String prefixForCollectionTables) {
+        super(dataSource, collectionsTable, prefixForCollectionTables);
+        this.dataSource = dataSource;
+        this.collectionsTable = collectionsTable;
+        this.prefixForCollectionTables = prefixForCollectionTables;
+
+        supportedKeyTypes = new HashMap<>();
+        supportedKeyTypes.put(String.class, "VARCHAR(255)");
+
+        supportedDataTypes = new HashMap<>();
+        supportedDataTypes.put(String.class, "TEXT");
+        supportedDataTypes.put(Integer.class, "INTEGER");
+        supportedDataTypes.put(int.class, "INTEGER");
+        supportedDataTypes.put(Long.class, "BIGINT");
+        supportedDataTypes.put(long.class, "BIGINT");
+        supportedDataTypes.put(Float.class, "REAL");
+        supportedDataTypes.put(float.class, "REAL");
+        supportedDataTypes.put(Double.class, "DOUBLE PRECISION");
+        supportedDataTypes.put(double.class, "DOUBLE PRECISION");
+        supportedDataTypes.put(Boolean.class, "BOOLEAN");
+        supportedDataTypes.put(boolean.class, "BOOLEAN");
+        supportedDataTypes.put(OffsetDateTime.class, "TIMESTAMPTZ");
+
+        supportedVectorTypes = new HashMap<>();
+        // Fix: this entry used to be written into supportedDataTypes a second time, so String
+        // vector fields were absent from the vector map and their column type rendered as "null".
+        supportedVectorTypes.put(String.class, "TEXT");
+        supportedVectorTypes.put(List.class, "VECTOR(%d)");
+        supportedVectorTypes.put(Collection.class, "VECTOR(%d)");
+    }
+
+    /**
+     * Gets the supported key types and their corresponding SQL types.
+     *
+     * @return a copy of the supported key types
+     */
+    @Override
+    public Map<Class<?>, String> getSupportedKeyTypes() {
+        return new HashMap<>(this.supportedKeyTypes);
+    }
+
+    /**
+     * Gets the supported data types and their corresponding SQL types.
+     *
+     * @return a copy of the supported data types
+     */
+    @Override
+    public Map<Class<?>, String> getSupportedDataTypes() {
+        return new HashMap<>(this.supportedDataTypes);
+    }
+
+    /**
+     * Gets the supported vector types and their corresponding SQL types.
+     *
+     * @return a copy of the supported vector types
+     */
+    @Override
+    public Map<Class<?>, String> getSupportedVectorTypes() {
+        return new HashMap<>(this.supportedVectorTypes);
+    }
+
+    /**
+     * Creates a new builder.
+     * @return the builder
+     */
+    public static PostgreSQLVectorStoreQueryProvider.Builder builder() {
+        return new PostgreSQLVectorStoreQueryProvider.Builder();
+    }
+
+    /**
+     * Prepares the vector store.
+     * Runs the parent preparation and then creates the {@code vector} extension if it is missing.
+     *
+     * @throws SKException if an error occurs while preparing the vector store
+     */
+    @Override
+    public void prepareVectorStore() {
+        super.prepareVectorStore();
+
+        // Create the vector extension
+        String pgVector = "CREATE EXTENSION IF NOT EXISTS vector";
+
+        try (Connection connection = dataSource.getConnection();
+            PreparedStatement createPgVector = connection.prepareStatement(pgVector)) {
+            createPgVector.execute();
+        } catch (SQLException e) {
+            throw new SKException("Failed to prepare vector store", e);
+        }
+    }
+
+    /**
+     * Builds the comma-separated {@code name TYPE} column list for the vector fields.
+     *
+     * @param fields the vector fields of the record definition
+     * @param recordClass the record class that declares the fields
+     * @return the column definitions, or an empty string when there are no vector fields
+     * @throws SKException if a field is not declared on the record class or its Java type has
+     *         no registered vector column type
+     */
+    private String getColumnNamesAndTypesForVectorFields(List<VectorStoreRecordVectorField> fields,
+        Class<?> recordClass) {
+        StringBuilder columnNames = new StringBuilder();
+        for (VectorStoreRecordVectorField field : fields) {
+            try {
+                Class<?> fieldType = recordClass.getDeclaredField(field.getName()).getType();
+                String sqlTypeTemplate = supportedVectorTypes.get(fieldType);
+                if (sqlTypeTemplate == null) {
+                    // Without this guard an unmapped type reached String.format(null, ...).
+                    throw new SKException(String.format(
+                        "Unsupported vector field type %s for field %s",
+                        fieldType.getName(), field.getName()));
+                }
+
+                if (columnNames.length() > 0) {
+                    columnNames.append(", ");
+                }
+
+                // String vectors are stored as TEXT; every other type takes its dimensions.
+                String sqlType = fieldType.equals(String.class)
+                    ? sqlTypeTemplate
+                    : String.format(sqlTypeTemplate, field.getDimensions());
+                columnNames.append(field.getName()).append(" ").append(sqlType);
+            } catch (NoSuchFieldException e) {
+                throw new SKException(
+                    "Vector field " + field.getName() + " is not declared on "
+                        + recordClass.getName(),
+                    e);
+            }
+        }
+
+        return columnNames.toString();
+    }
+
+    /**
+     * Creates a collection.
+     * Creates the storage table if it does not exist and registers the collection name in the
+     * collections table.
+     *
+     * @param collectionName the collection name
+     * @param recordClass the record class
+     * @param recordDefinition the record definition
+     * @throws SKException if an error occurs while creating the collection
+     */
+    @Override
+    @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
+    public void createCollection(String collectionName, Class<?> recordClass,
+        VectorStoreRecordDefinition recordDefinition) {
+        Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass);
+        List<Field> dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass);
+
+        String createStorageTable = "CREATE TABLE IF NOT EXISTS "
+            + getCollectionTableName(collectionName)
+            + " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, "
+            + getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", "
+            + getColumnNamesAndTypesForVectorFields(recordDefinition.getVectorFields(), recordClass)
+            + ");";
+
+        String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
+            + " (collectionId) VALUES (?)";
+
+        try (Connection connection = dataSource.getConnection();
+            PreparedStatement createTable = connection.prepareStatement(createStorageTable)) {
+            createTable.execute();
+        } catch (SQLException e) {
+            throw new SKException("Failed to create collection", e);
+        }
+
+        try (Connection connection = dataSource.getConnection();
+            PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) {
+            insert.setObject(1, collectionName);
+            insert.execute();
+        } catch (SQLException e) {
+            throw new SKException("Failed to insert collection", e);
+        }
+    }
+
+    /**
+     * Binds the field values of a record to the statement parameters, in field order.
+     *
+     * @param statement the statement to bind to
+     * @param record the record to read the values from
+     * @param fields the fields, in the same order as the statement placeholders
+     * @throws SKException if a field cannot be read, serialized or bound
+     */
+    private void setStatementValues(PreparedStatement statement, Object record,
+        List<VectorStoreRecordField> fields) {
+        for (int i = 0; i < fields.size(); ++i) {
+            VectorStoreRecordField field = fields.get(i);
+            int parameterIndex = i + 1;
+            try {
+                Field recordField = record.getClass().getDeclaredField(field.getName());
+                recordField.setAccessible(true);
+                Object value = recordField.get(record);
+
+                if (field instanceof VectorStoreRecordKeyField) {
+                    statement.setObject(parameterIndex, (String) value);
+                } else if (field instanceof VectorStoreRecordVectorField
+                    && value != null
+                    && !recordField.getType().equals(String.class)) {
+                    // Non-String vectors are sent as their JSON array text, e.g. "[1.0,2.0]".
+                    // A null vector falls through so it is bound as SQL NULL rather than the
+                    // string "null".
+                    statement.setString(parameterIndex, OBJECT_MAPPER.writeValueAsString(value));
+                } else {
+                    statement.setObject(parameterIndex, value);
+                }
+            } catch (NoSuchFieldException | IllegalAccessException | SQLException
+                | JsonProcessingException e) {
+                throw new SKException("Failed to set statement values", e);
+            }
+        }
+    }
+
+    /**
+     * Builds the {@code ?} placeholder list, appending {@code ::vector} for vector fields.
+     *
+     * @param fields the fields, one placeholder per field
+     * @return the comma-separated placeholders
+     */
+    private String getWildcardStringWithCast(List<VectorStoreRecordField> fields) {
+        StringBuilder wildcardString = new StringBuilder();
+        int wildcards = fields.size();
+        for (int i = 0; i < wildcards; ++i) {
+            if (i > 0) {
+                wildcardString.append(", ");
+            }
+            wildcardString.append("?");
+            // Add casting for vector fields
+            if (fields.get(i) instanceof VectorStoreRecordVectorField) {
+                wildcardString.append("::vector");
+            }
+        }
+        return wildcardString.toString();
+    }
+
+    /**
+     * Upserts records into the collection.
+     * All records are sent as one batch; on a key conflict every non-key column is overwritten
+     * with the incoming value.
+     *
+     * @param collectionName the collection name
+     * @param records the records to upsert
+     * @param recordDefinition the record definition
+     * @param options the upsert options
+     * @throws SKException if the upsert fails
+     */
+    @Override
+    @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
+    public void upsertRecords(String collectionName, List<?> records,
+        VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) {
+        validateSQLidentifier(getCollectionTableName(collectionName));
+        List<VectorStoreRecordField> fields = recordDefinition.getAllFields();
+
+        StringBuilder onDuplicateKeyUpdate = new StringBuilder();
+        for (VectorStoreRecordField field : fields) {
+            if (field instanceof VectorStoreRecordKeyField) {
+                continue;
+            }
+            if (onDuplicateKeyUpdate.length() > 0) {
+                onDuplicateKeyUpdate.append(", ");
+            }
+            onDuplicateKeyUpdate.append(field.getName())
+                .append(" = EXCLUDED.")
+                .append(field.getName());
+        }
+
+        String query = "INSERT INTO " + getCollectionTableName(collectionName)
+            + " (" + getQueryColumnsFromFields(fields) + ")"
+            + " VALUES (" + getWildcardStringWithCast(fields) + ")"
+            + " ON CONFLICT (" + recordDefinition.getKeyField().getName() + ") DO UPDATE SET "
+            + onDuplicateKeyUpdate;
+
+        try (Connection connection = dataSource.getConnection();
+            PreparedStatement statement = connection.prepareStatement(query)) {
+            for (Object record : records) {
+                setStatementValues(statement, record, fields);
+                statement.addBatch();
+            }
+
+            statement.executeBatch();
+        } catch (SQLException e) {
+            throw new SKException("Failed to upsert records", e);
+        }
+    }
+
+    /**
+     * Builder for {@link PostgreSQLVectorStoreQueryProvider}.
+     */
+    public static class Builder
+        extends JDBCVectorStoreDefaultQueryProvider.Builder {
+        private DataSource dataSource;
+        private String collectionsTable = DEFAULT_COLLECTIONS_TABLE;
+        private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES;
+
+        /**
+         * Sets the data source.
+         * @param dataSource the data source
+         * @return the builder
+         */
+        @SuppressFBWarnings("EI_EXPOSE_REP2")
+        public PostgreSQLVectorStoreQueryProvider.Builder withDataSource(DataSource dataSource) {
+            this.dataSource = dataSource;
+            return this;
+        }
+
+        /**
+         * Sets the collections table name.
+         * @param collectionsTable the collections table name
+         * @return the builder
+         */
+        public PostgreSQLVectorStoreQueryProvider.Builder withCollectionsTable(
+            String collectionsTable) {
+            this.collectionsTable = validateSQLidentifier(collectionsTable);
+            return this;
+        }
+
+        /**
+         * Sets the prefix for collection tables.
+         * @param prefixForCollectionTables the prefix for collection tables
+         * @return the builder
+         */
+        public PostgreSQLVectorStoreQueryProvider.Builder withPrefixForCollectionTables(
+            String prefixForCollectionTables) {
+            this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables);
+            return this;
+        }
+
+        /**
+         * Builds the {@link PostgreSQLVectorStoreQueryProvider}.
+         * @return the query provider
+         * @throws SKException if no data source was set
+         */
+        public PostgreSQLVectorStoreQueryProvider build() {
+            if (dataSource == null) {
+                throw new SKException("DataSource is required");
+            }
+
+            return new PostgreSQLVectorStoreQueryProvider(dataSource, collectionsTable,
+                prefixForCollectionTables);
+        }
+    }
+}
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreRecordMapper.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreRecordMapper.java
new file mode 100644
index 00000000..83b821c3
--- /dev/null
+++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreRecordMapper.java
@@ -0,0 +1,146 @@
+// Copyright (c) Microsoft. All rights reserved.
+package com.microsoft.semantickernel.connectors.data.postgres;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.microsoft.semantickernel.builders.SemanticKernelBuilder;
+import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField;
+import com.microsoft.semantickernel.exceptions.SKException;
+import org.postgresql.util.PGobject;
+
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+import java.time.OffsetDateTime;
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * Maps a row of a PostgreSQL {@link ResultSet} to a record instance.
+ *
+ * <p>NOTE(review): generic type arguments in this listing were lost in extraction and have been
+ * reconstructed from usage - confirm against the repository.
+ *
+ * @param <Record> the record type
+ */
+public class PostgreSQLVectorStoreRecordMapper<Record>
+    extends VectorStoreRecordMapper<Record, ResultSet> {
+
+    // Shared so that a mapper is not allocated for every vector column that is read.
+    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+    /**
+     * Constructs a new instance of the PostgreSQLVectorStoreRecordMapper.
+     *
+     * @param storageModelToRecordMapper the function to convert a storage model to a record
+     */
+    protected PostgreSQLVectorStoreRecordMapper(
+        Function<ResultSet, Record> storageModelToRecordMapper) {
+        // Only the storage-to-record direction is provided; the other mapper is passed as null.
+        super(null, storageModelToRecordMapper);
+    }
+
+    /**
+     * Creates a new builder.
+     *
+     * @param <Record> the record type
+     * @return the builder
+     */
+    public static <Record> Builder<Record> builder() {
+        return new Builder<>();
+    }
+
+    /**
+     * Builder for {@link PostgreSQLVectorStoreRecordMapper}.
+     *
+     * @param <Record> the record type
+     */
+    public static class Builder<Record>
+        implements SemanticKernelBuilder<PostgreSQLVectorStoreRecordMapper<Record>> {
+        private Class<Record> recordClass;
+        private VectorStoreRecordDefinition vectorStoreRecordDefinition;
+
+        /**
+         * Sets the record class.
+         *
+         * @param recordClass the record class
+         * @return the builder
+         */
+        public Builder<Record> withRecordClass(Class<Record> recordClass) {
+            this.recordClass = recordClass;
+            return this;
+        }
+
+        /**
+         * Sets the vector store record definition.
+         *
+         * @param vectorStoreRecordDefinition the vector store record definition
+         * @return the builder
+         */
+        public Builder<Record> withVectorStoreRecordDefinition(
+            VectorStoreRecordDefinition vectorStoreRecordDefinition) {
+            this.vectorStoreRecordDefinition = vectorStoreRecordDefinition;
+            return this;
+        }
+
+        /**
+         * Builds the {@link PostgreSQLVectorStoreRecordMapper}.
+         *
+         * <p>The produced mapper instantiates the record through its no-argument constructor
+         * and fills the fields reflectively from the current row. It throws
+         * {@link SKException} when the record cannot be instantiated, a column cannot be read,
+         * or a vector cannot be deserialized.
+         *
+         * @return the {@link PostgreSQLVectorStoreRecordMapper}
+         * @throws IllegalArgumentException if the record class or the record definition is missing
+         */
+        public PostgreSQLVectorStoreRecordMapper<Record> build() {
+            if (recordClass == null) {
+                throw new IllegalArgumentException("recordClass is required");
+            }
+            if (vectorStoreRecordDefinition == null) {
+                throw new IllegalArgumentException("vectorStoreRecordDefinition is required");
+            }
+
+            return new PostgreSQLVectorStoreRecordMapper<>(
+                resultSet -> {
+                    try {
+                        Constructor<?> constructor = recordClass.getDeclaredConstructor();
+                        constructor.setAccessible(true);
+                        Record record = (Record) constructor.newInstance();
+
+                        // Select fields from the record definition.
+                        // The column count decides whether vector columns were selected: a row
+                        // with one column per field carries vectors, any other row does not.
+                        List<VectorStoreRecordField> fields;
+                        ResultSetMetaData metaData = resultSet.getMetaData();
+                        if (metaData.getColumnCount() == vectorStoreRecordDefinition.getAllFields()
+                            .size()) {
+                            fields = vectorStoreRecordDefinition.getAllFields();
+                        } else {
+                            fields = vectorStoreRecordDefinition.getNonVectorFields();
+                        }
+
+                        for (VectorStoreRecordField field : fields) {
+                            Field recordField = recordClass.getDeclaredField(field.getName());
+                            recordField.setAccessible(true);
+                            Class<?> fieldType = recordField.getType();
+
+                            // A plain getObject on a TIMESTAMPTZ column yields java.sql.Timestamp,
+                            // which cannot be assigned to an OffsetDateTime field, so ask the
+                            // driver for the target type explicitly.
+                            Object value = OffsetDateTime.class.equals(fieldType)
+                                ? resultSet.getObject(field.getName(), OffsetDateTime.class)
+                                : resultSet.getObject(field.getName());
+
+                            if (value == null) {
+                                // SQL NULL: primitives keep their default (Field.set would
+                                // reject null), reference fields are set to null.
+                                if (!fieldType.isPrimitive()) {
+                                    recordField.set(record, null);
+                                }
+                                continue;
+                            }
+
+                            if (field instanceof VectorStoreRecordVectorField
+                                && !fieldType.equals(String.class)) {
+                                // Deserialize the pgvector text form to the vector type
+                                PGobject pgObject = (PGobject) value;
+                                recordField.set(record,
+                                    OBJECT_MAPPER.readValue(pgObject.getValue(), fieldType));
+                            } else {
+                                // Data fields and String vectors are assigned as returned.
+                                recordField.set(record, value);
+                            }
+                        }
+
+                        return record;
+                    } catch (NoSuchMethodException e) {
+                        throw new SKException("Default constructor not found.", e);
+                    } catch (InstantiationException | InvocationTargetException e) {
+                        throw new SKException(String.format(
+                            "SK cannot instantiate %s. A custom mapper is required.",
+                            recordClass.getName()), e);
+                    } catch (JsonProcessingException e) {
+                        throw new SKException(String.format(
+                            "SK cannot deserialize %s. A custom mapper is required.",
+                            recordClass.getName()), e);
+                    } catch (SQLException | NoSuchFieldException | IllegalAccessException e) {
+                        throw new SKException(String.format(
+                            "SK cannot map the result set to %s. A custom mapper is required.",
+                            recordClass.getName()), e);
+                    }
+                });
+        }
+    }
+}