diff --git a/api-test/integration-tests/pom.xml b/api-test/integration-tests/pom.xml
index 7e669927..cc93c7fe 100644
--- a/api-test/integration-tests/pom.xml
+++ b/api-test/integration-tests/pom.xml
@@ -73,6 +73,11 @@
8.0.33
test
+
+ org.postgresql
+ postgresql
+ 42.7.2
+
org.testcontainers
diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java
index e43842fc..ad10ad64 100644
--- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java
+++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java
@@ -18,7 +18,7 @@ public class Hotel {
@VectorStoreRecordVectorAttribute(dimensions = 3)
private final List descriptionEmbedding;
@VectorStoreRecordDataAttribute
- private final double rating;
+ private double rating;
public Hotel() {
this(null, null, 0, null, null, 0.0);
@@ -56,4 +56,8 @@ public List getDescriptionEmbedding() {
public double getRating() {
return rating;
}
+
+ public void setRating(double rating) {
+ this.rating = rating;
+ }
}
diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java
index c873ea5b..8bee5a76 100644
--- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java
+++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java
@@ -1,104 +1,151 @@
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertNull;
-
+import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollection;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
-import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
+import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
+import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.mysql.cj.jdbc.MysqlDataSource;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import javax.annotation.Nonnull;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
+import org.postgresql.ds.PGSimpleDataSource;
import org.testcontainers.containers.MySQLContainer;
+import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.utility.DockerImageName;
+
+import javax.annotation.Nonnull;
+import javax.sql.DataSource;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
@Testcontainers
public class JDBCVectorStoreRecordCollectionTest {
@Container
- private static final MySQLContainer> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
- private static final String MYSQL_USER = "test";
- private static final String MYSQL_PASSWORD = "test";
- private static MysqlDataSource dataSource;
-
- @BeforeAll
- static void setup() {
- dataSource = new MysqlDataSource();
- dataSource.setUrl(CONTAINER.getJdbcUrl());
- dataSource.setUser(MYSQL_USER);
- dataSource.setPassword(MYSQL_PASSWORD);
+ private static final MySQLContainer> MYSQL_CONTAINER = new MySQLContainer<>("mysql:5.7.34");
+
+ private static final DockerImageName PGVECTOR = DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres");
+ @Container
+ private static final PostgreSQLContainer> POSTGRESQL_CONTAINER = new PostgreSQLContainer<>(PGVECTOR);
+
+ public enum QueryProvider {
+ MySQL,
+ PostgreSQL
}
- private JDBCVectorStoreRecordCollection buildRecordCollection(
- @Nonnull String collectionName) {
- JDBCVectorStoreRecordCollection recordCollection = new JDBCVectorStoreRecordCollection<>(
- dataSource,
- collectionName,
- JDBCVectorStoreRecordCollectionOptions.builder()
- .withRecordClass(Hotel.class)
- .withQueryProvider(MySQLVectorStoreQueryProvider.builder()
- .withDataSource(dataSource)
- .build())
- .build());
+ private JDBCVectorStoreRecordCollection buildRecordCollection(QueryProvider provider, @Nonnull String collectionName) {
+ JDBCVectorStoreQueryProvider queryProvider;
+ DataSource dataSource;
+
+ switch (provider) {
+ case MySQL:
+ MysqlDataSource mysqlDataSource = new MysqlDataSource();
+ mysqlDataSource.setUrl(MYSQL_CONTAINER.getJdbcUrl());
+ mysqlDataSource.setUser(MYSQL_CONTAINER.getUsername());
+ mysqlDataSource.setPassword(MYSQL_CONTAINER.getPassword());
+ dataSource = mysqlDataSource;
+ queryProvider = MySQLVectorStoreQueryProvider.builder()
+ .withDataSource(dataSource)
+ .build();
+ break;
+ case PostgreSQL:
+ PGSimpleDataSource pgSimpleDataSource = new PGSimpleDataSource();
+ pgSimpleDataSource.setUrl(POSTGRESQL_CONTAINER.getJdbcUrl());
+ pgSimpleDataSource.setUser(POSTGRESQL_CONTAINER.getUsername());
+ pgSimpleDataSource.setPassword(POSTGRESQL_CONTAINER.getPassword());
+ dataSource = pgSimpleDataSource;
+ queryProvider = PostgreSQLVectorStoreQueryProvider.builder()
+ .withDataSource(dataSource)
+ .build();
+ break;
+ default:
+ throw new IllegalArgumentException("Unknown query provider: " + provider);
+ }
+
+
+ JDBCVectorStoreRecordCollection recordCollection = new JDBCVectorStoreRecordCollection<>(
+ dataSource,
+ collectionName,
+ JDBCVectorStoreRecordCollectionOptions.builder()
+ .withRecordClass(Hotel.class)
+ .withQueryProvider(queryProvider)
+ .build());
recordCollection.prepareAsync().block();
recordCollection.createCollectionIfNotExistsAsync().block();
return recordCollection;
}
- @Test
- public void buildRecordCollection() {
- assertNotNull(buildRecordCollection("buildTest"));
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void buildRecordCollection(QueryProvider provider) {
+ assertNotNull(buildRecordCollection(provider, "buildTest"));
}
private List getHotels() {
return List.of(
- new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f),
- 4.0),
- new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f),
- 3.0),
- new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f),
- 5.0),
- new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f),
- 4.0),
- new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f),
- 5.0)
+ new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0),
+ new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f), 3.0),
+ new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0),
+ new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0),
+ new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0)
);
}
- @Test
- public void upsertAndGetRecordAsync() {
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void upsertAndGetRecordAsync(QueryProvider provider) {
String collectionName = "upsertAndGetRecordAsync";
- JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName);
+ JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName);
List hotels = getHotels();
for (Hotel hotel : hotels) {
- recordStore.upsertAsync(hotel, null).block();
+ recordCollection.upsertAsync(hotel, null).block();
+ }
+
+ // Upsert the first time
+ for (Hotel hotel : hotels) {
+ Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
+ assertNotNull(retrievedHotel);
+ assertEquals(hotel.getId(), retrievedHotel.getId());
+ assertEquals(hotel.getRating(), retrievedHotel.getRating());
+
+ // Update the rating
+ hotel.setRating(1.0);
+ }
+
+ // Upsert the second time with updated rating
+ for (Hotel hotel : hotels) {
+ recordCollection.upsertAsync(hotel, null).block();
}
for (Hotel hotel : hotels) {
- Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
+ Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
assertNotNull(retrievedHotel);
assertEquals(hotel.getId(), retrievedHotel.getId());
+ assertEquals(1.0, retrievedHotel.getRating());
}
}
- @Test
- public void getBatchAsync() {
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void getBatchAsync(QueryProvider provider) {
String collectionName = "getBatchAsync";
- JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName);
+ JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName);
List hotels = getHotels();
for (Hotel hotel : hotels) {
- recordStore.upsertAsync(hotel, null).block();
+ recordCollection.upsertAsync(hotel, null).block();
}
List keys = new ArrayList<>();
@@ -106,99 +153,104 @@ public void getBatchAsync() {
keys.add(hotel.getId());
}
- List retrievedHotels = recordStore.getBatchAsync(keys, null).block();
+ List retrievedHotels = recordCollection.getBatchAsync(keys, null).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
}
- @Test
- public void upsertBatchAndGetBatchAsync() {
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void upsertBatchAndGetBatchAsync(QueryProvider provider) {
String collectionName = "upsertBatchAndGetBatchAsync";
- JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName);
+ JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName);
List hotels = getHotels();
- recordStore.upsertBatchAsync(hotels, null).block();
+ recordCollection.upsertBatchAsync(hotels, null).block();
List keys = new ArrayList<>();
for (Hotel hotel : hotels) {
keys.add(hotel.getId());
}
- List retrievedHotels = recordStore.getBatchAsync(keys, null).block();
+ List retrievedHotels = recordCollection.getBatchAsync(keys, null).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
}
- @Test
- public void insertAndReplaceAsync() {
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void insertAndReplaceAsync(QueryProvider provider) {
String collectionName = "insertAndReplaceAsync";
- JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName);
+ JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName);
List hotels = getHotels();
- recordStore.upsertBatchAsync(hotels, null).block();
- recordStore.upsertBatchAsync(hotels, null).block();
- recordStore.upsertBatchAsync(hotels, null).block();
+ recordCollection.upsertBatchAsync(hotels, null).block();
+ recordCollection.upsertBatchAsync(hotels, null).block();
+ recordCollection.upsertBatchAsync(hotels, null).block();
List keys = new ArrayList<>();
for (Hotel hotel : hotels) {
keys.add(hotel.getId());
}
- List retrievedHotels = recordStore.getBatchAsync(keys, null).block();
+ List retrievedHotels = recordCollection.getBatchAsync(keys, null).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
}
- @Test
- public void deleteRecordAsync() {
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void deleteRecordAsync(QueryProvider provider) {
String collectionName = "deleteRecordAsync";
- JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName);
+ JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName);
List hotels = getHotels();
- recordStore.upsertBatchAsync(hotels, null).block();
+ recordCollection.upsertBatchAsync(hotels, null).block();
for (Hotel hotel : hotels) {
- recordStore.deleteAsync(hotel.getId(), null).block();
- Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
+ recordCollection.deleteAsync(hotel.getId(), null).block();
+ Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
assertNull(retrievedHotel);
}
}
- @Test
- public void deleteBatchAsync() {
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void deleteBatchAsync(QueryProvider provider) {
String collectionName = "deleteBatchAsync";
- JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName);
+ JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName);
List hotels = getHotels();
- recordStore.upsertBatchAsync(hotels, null).block();
+ recordCollection.upsertBatchAsync(hotels, null).block();
List keys = new ArrayList<>();
for (Hotel hotel : hotels) {
keys.add(hotel.getId());
}
- recordStore.deleteBatchAsync(keys, null).block();
+ recordCollection.deleteBatchAsync(keys, null).block();
for (String key : keys) {
- Hotel retrievedHotel = recordStore.getAsync(key, null).block();
+ Hotel retrievedHotel = recordCollection.getAsync(key, null).block();
assertNull(retrievedHotel);
}
}
- @Test
- public void getWithNoVectors() {
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void getWithNoVectors(QueryProvider provider) {
String collectionName = "getWithNoVectors";
- JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName);
+ JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName);
List hotels = getHotels();
- recordStore.upsertBatchAsync(hotels, null).block();
+ recordCollection.upsertBatchAsync(hotels, null).block();
GetRecordOptions options = GetRecordOptions.builder()
.includeVectors(false)
.build();
for (Hotel hotel : hotels) {
- Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block();
+ Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block();
assertNotNull(retrievedHotel);
assertEquals(hotel.getId(), retrievedHotel.getId());
assertNull(retrievedHotel.getDescriptionEmbedding());
@@ -209,20 +261,21 @@ public void getWithNoVectors() {
.build();
for (Hotel hotel : hotels) {
- Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block();
+ Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block();
assertNotNull(retrievedHotel);
assertEquals(hotel.getId(), retrievedHotel.getId());
assertNotNull(retrievedHotel.getDescriptionEmbedding());
}
}
- @Test
- public void getBatchWithNoVectors() {
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void getBatchWithNoVectors(QueryProvider provider) {
String collectionName = "getBatchWithNoVectors";
- JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName);
+ JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName);
List hotels = getHotels();
- recordStore.upsertBatchAsync(hotels, null).block();
+ recordCollection.upsertBatchAsync(hotels, null).block();
GetRecordOptions options = GetRecordOptions.builder()
.includeVectors(false)
@@ -233,7 +286,7 @@ public void getBatchWithNoVectors() {
keys.add(hotel.getId());
}
- List retrievedHotels = recordStore.getBatchAsync(keys, options).block();
+ List retrievedHotels = recordCollection.getBatchAsync(keys, options).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
@@ -245,7 +298,7 @@ public void getBatchWithNoVectors() {
.includeVectors(true)
.build();
- retrievedHotels = recordStore.getBatchAsync(keys, options).block();
+ retrievedHotels = recordCollection.getBatchAsync(keys, options).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java
index e0906a1c..8c2fbfd0 100644
--- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java
+++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java
@@ -1,45 +1,71 @@
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
-import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
+import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
+import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
+import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.mysql.cj.jdbc.MysqlDataSource;
-import java.util.Arrays;
-import java.util.List;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
+import org.postgresql.ds.PGSimpleDataSource;
import org.testcontainers.containers.MySQLContainer;
+import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.utility.DockerImageName;
+
+import javax.annotation.Nonnull;
+import javax.sql.DataSource;
+import java.util.Arrays;
+import java.util.List;
+
+import com.microsoft.semantickernel.tests.connectors.memory.jdbc.JDBCVectorStoreRecordCollectionTest.QueryProvider;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
@Testcontainers
public class JDBCVectorStoreTest {
@Container
- private static final MySQLContainer> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
- private static final String MYSQL_USER = "test";
- private static final String MYSQL_PASSWORD = "test";
- private static MysqlDataSource dataSource;
-
- @BeforeAll
- static void setup() {
- dataSource = new MysqlDataSource();
- dataSource.setUrl(CONTAINER.getJdbcUrl());
- dataSource.setUser(MYSQL_USER);
- dataSource.setPassword(MYSQL_PASSWORD);
- }
+ private static final MySQLContainer> MYSQL_CONTAINER = new MySQLContainer<>("mysql:5.7.34");
+
+ private static final DockerImageName PGVECTOR = DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres");
+ @Container
+ private static final PostgreSQLContainer> POSTGRESQL_CONTAINER = new PostgreSQLContainer<>(PGVECTOR);
+
+ private JDBCVectorStore buildVectorStore(QueryProvider provider) {
+ JDBCVectorStoreQueryProvider queryProvider;
+ DataSource dataSource;
+
+ switch (provider) {
+ case MySQL:
+ MysqlDataSource mysqlDataSource = new MysqlDataSource();
+ mysqlDataSource.setUrl(MYSQL_CONTAINER.getJdbcUrl());
+ mysqlDataSource.setUser(MYSQL_CONTAINER.getUsername());
+ mysqlDataSource.setPassword(MYSQL_CONTAINER.getPassword());
+ dataSource = mysqlDataSource;
+ queryProvider = MySQLVectorStoreQueryProvider.builder()
+ .withDataSource(dataSource)
+ .build();
+ break;
+ case PostgreSQL:
+ PGSimpleDataSource pgSimpleDataSource = new PGSimpleDataSource();
+ pgSimpleDataSource.setUrl(POSTGRESQL_CONTAINER.getJdbcUrl());
+ pgSimpleDataSource.setUser(POSTGRESQL_CONTAINER.getUsername());
+ pgSimpleDataSource.setPassword(POSTGRESQL_CONTAINER.getPassword());
+ dataSource = pgSimpleDataSource;
+ queryProvider = PostgreSQLVectorStoreQueryProvider.builder()
+ .withDataSource(dataSource)
+ .build();
+ break;
+ default:
+ throw new IllegalArgumentException("Unknown query provider: " + provider);
+ }
- @Test
- public void getCollectionNamesAsync() {
- MySQLVectorStoreQueryProvider queryProvider = MySQLVectorStoreQueryProvider.builder()
- .withDataSource(dataSource)
- .build();
- JDBCVectorStore vectorStore = JDBCVectorStore.builder()
+ JDBCVectorStore vectorStore = JDBCVectorStore.builder()
.withDataSource(dataSource)
.withOptions(
JDBCVectorStoreOptions.builder()
@@ -48,6 +74,16 @@ public void getCollectionNamesAsync() {
)
.build();
+ vectorStore.prepareAsync().block();
+ return vectorStore;
+ }
+
+
+ @ParameterizedTest
+ @EnumSource(QueryProvider.class)
+ public void getCollectionNamesAsync(QueryProvider provider) {
+ JDBCVectorStore vectorStore = buildVectorStore(provider);
+
vectorStore.getCollectionNamesAsync().block();
List collectionNames = Arrays.asList("collection1", "collection2", "collection3");
diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java
index 8aa4bddb..ed1d8bd5 100644
--- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java
+++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java
@@ -8,7 +8,7 @@
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
-import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
+import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
diff --git a/semantickernel-experimental/pom.xml b/semantickernel-experimental/pom.xml
index 20fa172a..3efd8ab0 100644
--- a/semantickernel-experimental/pom.xml
+++ b/semantickernel-experimental/pom.xml
@@ -109,6 +109,11 @@
+
+ org.postgresql
+ postgresql
+ 42.7.2
+
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java
index 096f240a..f1795083 100644
--- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java
+++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.connectors.data.jdbc;
+import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
@@ -24,14 +25,27 @@
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
public class JDBCVectorStoreDefaultQueryProvider
implements JDBCVectorStoreQueryProvider {
- private static final Map, String> supportedKeyTypes;
- private static final Map, String> supportedDataTypes;
- private static final Map, String> supportedVectorTypes;
- static {
+ private Map, String> supportedKeyTypes;
+ private Map, String> supportedDataTypes;
+ private Map, String> supportedVectorTypes;
+ private final DataSource dataSource;
+ private final String collectionsTable;
+ private final String prefixForCollectionTables;
+
+ @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed
+ protected JDBCVectorStoreDefaultQueryProvider(
+ @Nonnull DataSource dataSource,
+ @Nonnull String collectionsTable,
+ @Nonnull String prefixForCollectionTables) {
+ this.dataSource = dataSource;
+ this.collectionsTable = collectionsTable;
+ this.prefixForCollectionTables = prefixForCollectionTables;
+
supportedKeyTypes = new HashMap<>();
supportedKeyTypes.put(String.class, "VARCHAR(255)");
@@ -54,19 +68,6 @@ public class JDBCVectorStoreDefaultQueryProvider
supportedVectorTypes.put(List.class, "TEXT");
supportedVectorTypes.put(Collection.class, "TEXT");
}
- private final DataSource dataSource;
- private final String collectionsTable;
- private final String prefixForCollectionTables;
-
- @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed
- protected JDBCVectorStoreDefaultQueryProvider(
- @Nonnull DataSource dataSource,
- @Nonnull String collectionsTable,
- @Nonnull String prefixForCollectionTables) {
- this.dataSource = dataSource;
- this.collectionsTable = collectionsTable;
- this.prefixForCollectionTables = prefixForCollectionTables;
- }
/**
* Creates a new builder.
@@ -82,14 +83,9 @@ public static Builder builder() {
* @return the formatted wildcard string
*/
protected String getWildcardString(int wildcards) {
- StringBuilder wildcardString = new StringBuilder();
- for (int i = 0; i < wildcards; ++i) {
- wildcardString.append("?");
- if (i < wildcards - 1) {
- wildcardString.append(", ");
- }
- }
- return wildcardString.toString();
+ return Stream.generate(() -> "?")
+ .limit(wildcards)
+ .collect(Collectors.joining(", "));
}
/**
@@ -102,6 +98,12 @@ protected String getQueryColumnsFromFields(List fields)
.collect(Collectors.joining(", "));
}
+ /**
+ * Formats the column names and types for a table.
+ * @param fields the fields
+ * @param types the types
+ * @return the formatted column names and types
+ */
protected String getColumnNamesAndTypes(List fields, Map, String> types) {
List columns = fields.stream()
.map(field -> field.getName() + " " + types.get(field.getType()))
@@ -114,6 +116,36 @@ protected String getCollectionTableName(String collectionName) {
return validateSQLidentifier(prefixForCollectionTables + collectionName);
}
+ /**
+ * Gets the supported key types and their corresponding SQL types.
+ *
+ * @return the supported key types
+ */
+ @Override
+ public Map, String> getSupportedKeyTypes() {
+ return new HashMap<>(this.supportedKeyTypes);
+ }
+
+ /**
+ * Gets the supported data types and their corresponding SQL types.
+ *
+ * @return the supported data types
+ */
+ @Override
+ public Map, String> getSupportedDataTypes() {
+ return new HashMap<>(this.supportedDataTypes);
+ }
+
+ /**
+ * Gets the supported vector types and their corresponding SQL types.
+ *
+ * @return the supported vector types
+ */
+ @Override
+ public Map, String> getSupportedVectorTypes() {
+ return new HashMap<>(this.supportedVectorTypes);
+ }
+
/**
* Prepares the vector store.
* Executes any necessary setup steps for the vector store.
@@ -146,11 +178,12 @@ public void validateSupportedTypes(Class> recordClass,
VectorStoreRecordDefinition recordDefinition) {
VectorStoreRecordDefinition.validateSupportedTypes(
Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)),
- supportedKeyTypes.keySet());
+ getSupportedKeyTypes().keySet());
VectorStoreRecordDefinition.validateSupportedTypes(
- recordDefinition.getDataDeclaredFields(recordClass), supportedDataTypes.keySet());
+ recordDefinition.getDataDeclaredFields(recordClass), getSupportedDataTypes().keySet());
VectorStoreRecordDefinition.validateSupportedTypes(
- recordDefinition.getVectorDeclaredFields(recordClass), supportedVectorTypes.keySet());
+ recordDefinition.getVectorDeclaredFields(recordClass),
+ getSupportedVectorTypes().keySet());
}
/**
@@ -194,8 +227,8 @@ public void createCollection(String collectionName, Class> recordClass,
String createStorageTable = "CREATE TABLE IF NOT EXISTS "
+ getCollectionTableName(collectionName)
+ " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, "
- + getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", "
- + getColumnNamesAndTypes(vectorDeclaredFields, supportedVectorTypes) + ");";
+ + getColumnNamesAndTypes(dataDeclaredFields, getSupportedDataTypes()) + ", "
+ + getColumnNamesAndTypes(vectorDeclaredFields, getSupportedVectorTypes()) + ");";
String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
+ " (collectionId) VALUES (?)";
@@ -284,7 +317,8 @@ public List getCollectionNames() {
*/
@Override
public List getRecords(String collectionName, List keys,
- VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper,
+ VectorStoreRecordDefinition recordDefinition,
+ VectorStoreRecordMapper mapper,
GetRecordOptions options) {
List fields;
if (options == null || options.includeVectors()) {
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java
index 26d976aa..6009b885 100644
--- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java
+++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java
@@ -2,12 +2,15 @@
package com.microsoft.semantickernel.connectors.data.jdbc;
import com.microsoft.semantickernel.builders.SemanticKernelBuilder;
+import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions;
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions;
+import java.sql.ResultSet;
import java.util.List;
+import java.util.Map;
/**
* The JDBC vector store query provider.
@@ -24,6 +27,27 @@ public interface JDBCVectorStoreQueryProvider {
*/
String DEFAULT_PREFIX_FOR_COLLECTION_TABLES = "SKCollection_";
+ /**
+ * Gets the supported key types and their corresponding SQL types.
+ *
+ * @return the supported key types
+ */
+ Map, String> getSupportedKeyTypes();
+
+ /**
+ * Gets the supported data types and their corresponding SQL types.
+ *
+ * @return the supported data types
+ */
+ Map, String> getSupportedDataTypes();
+
+ /**
+ * Gets the supported vector types and their corresponding SQL types.
+ *
+ * @return the supported vector types
+ */
+ Map, String> getSupportedVectorTypes();
+
/**
* Prepares the vector store.
* Executes any necessary setup steps for the vector store.
@@ -81,7 +105,8 @@ void createCollection(String collectionName, Class> recordClass,
* @return the records
*/
List getRecords(String collectionName, List keys,
- VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper,
+ VectorStoreRecordDefinition recordDefinition,
+ VectorStoreRecordMapper mapper,
GetRecordOptions options);
/**
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java
index 6135b284..44fc2338 100644
--- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java
+++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java
@@ -2,6 +2,10 @@
package com.microsoft.semantickernel.connectors.data.jdbc;
import com.microsoft.semantickernel.builders.SemanticKernelBuilder;
+import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
+import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
+import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreRecordMapper;
+import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions;
@@ -10,6 +14,7 @@
import com.microsoft.semantickernel.exceptions.SKException;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.lang.reflect.Field;
+import java.sql.ResultSet;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
@@ -23,8 +28,8 @@ public class JDBCVectorStoreRecordCollection
private final String collectionName;
private final VectorStoreRecordDefinition recordDefinition;
+ private final VectorStoreRecordMapper vectorStoreRecordMapper;
private final JDBCVectorStoreRecordCollectionOptions options;
- private final JDBCVectorStoreRecordMapper vectorStoreRecordMapper;
private final JDBCVectorStoreQueryProvider queryProvider;
/**
@@ -47,16 +52,6 @@ public JDBCVectorStoreRecordCollection(
? VectorStoreRecordDefinition.fromRecordClass(options.getRecordClass())
: options.getRecordDefinition();
- // If mapper is not provided, set a default one
- if (options.getVectorStoreRecordMapper() == null) {
- vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder()
- .withRecordClass(options.getRecordClass())
- .withVectorStoreRecordDefinition(recordDefinition)
- .build();
- } else {
- vectorStoreRecordMapper = options.getVectorStoreRecordMapper();
- }
-
// If the query provider is not provided, set a default one
if (options.getQueryProvider() == null) {
this.queryProvider = JDBCVectorStoreDefaultQueryProvider.builder()
@@ -66,6 +61,31 @@ public JDBCVectorStoreRecordCollection(
this.queryProvider = options.getQueryProvider();
}
+ // If mapper is not provided, set a default one
+ if (options.getVectorStoreRecordMapper() == null) {
+ // Default mapper for PostgreSQL
+ if (this.queryProvider instanceof PostgreSQLVectorStoreQueryProvider) {
+ vectorStoreRecordMapper = PostgreSQLVectorStoreRecordMapper.builder()
+ .withRecordClass(options.getRecordClass())
+ .withVectorStoreRecordDefinition(recordDefinition)
+ .build();
+ // Default mapper for MySQL
+ } else if (this.queryProvider instanceof MySQLVectorStoreQueryProvider) {
+ vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder()
+ .withRecordClass(options.getRecordClass())
+ .withVectorStoreRecordDefinition(recordDefinition)
+ .build();
+ // Default mapper for other databases
+ } else {
+ vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder()
+ .withRecordClass(options.getRecordClass())
+ .withVectorStoreRecordDefinition(recordDefinition)
+ .build();
+ }
+ } else {
+ vectorStoreRecordMapper = options.getVectorStoreRecordMapper();
+ }
+
// Check if the types are supported
queryProvider.validateSupportedTypes(options.getRecordClass(), recordDefinition);
}
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java
index af1ec49e..f6aa871d 100644
--- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java
+++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java
@@ -1,16 +1,19 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.connectors.data.jdbc;
+import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import java.sql.ResultSet;
+
import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider.validateSQLidentifier;
import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_COLLECTIONS_TABLE;
import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_PREFIX_FOR_COLLECTION_TABLES;
public class JDBCVectorStoreRecordCollectionOptions {
private final Class recordClass;
- private final JDBCVectorStoreRecordMapper vectorStoreRecordMapper;
+ private final VectorStoreRecordMapper vectorStoreRecordMapper;
private final VectorStoreRecordDefinition recordDefinition;
private final JDBCVectorStoreQueryProvider queryProvider;
private final String collectionsTableName;
@@ -19,7 +22,7 @@ public class JDBCVectorStoreRecordCollectionOptions {
private JDBCVectorStoreRecordCollectionOptions(
Class recordClass,
VectorStoreRecordDefinition recordDefinition,
- JDBCVectorStoreRecordMapper vectorStoreRecordMapper,
+ VectorStoreRecordMapper vectorStoreRecordMapper,
JDBCVectorStoreQueryProvider queryProvider,
String collectionsTableName,
String prefixForCollectionTables) {
@@ -60,7 +63,7 @@ public VectorStoreRecordDefinition getRecordDefinition() {
* Gets the vector store record mapper.
* @return the vector store record mapper
*/
- public JDBCVectorStoreRecordMapper getVectorStoreRecordMapper() {
+ public VectorStoreRecordMapper getVectorStoreRecordMapper() {
return vectorStoreRecordMapper;
}
@@ -92,7 +95,7 @@ public JDBCVectorStoreQueryProvider getQueryProvider() {
public static class Builder {
private Class recordClass;
private VectorStoreRecordDefinition recordDefinition;
- private JDBCVectorStoreRecordMapper vectorStoreRecordMapper;
+ private VectorStoreRecordMapper vectorStoreRecordMapper;
private JDBCVectorStoreQueryProvider queryProvider;
private String collectionsTableName = DEFAULT_COLLECTIONS_TABLE;
private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES;
@@ -123,7 +126,7 @@ public Builder withRecordDefinition(VectorStoreRecordDefinition recordDe
* @return the builder
*/
public Builder withVectorStoreRecordMapper(
- JDBCVectorStoreRecordMapper vectorStoreRecordMapper) {
+ VectorStoreRecordMapper vectorStoreRecordMapper) {
this.vectorStoreRecordMapper = vectorStoreRecordMapper;
return this;
}
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java
similarity index 96%
rename from semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java
rename to semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java
index 72ecd87e..ff19017c 100644
--- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java
+++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java
@@ -1,8 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.
-package com.microsoft.semantickernel.connectors.data.jdbc;
+package com.microsoft.semantickernel.connectors.data.mysql;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
+import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider;
+import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField;
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java
new file mode 100644
index 00000000..d9d5deff
--- /dev/null
+++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java
@@ -0,0 +1,338 @@
+// Copyright (c) Microsoft. All rights reserved.
+package com.microsoft.semantickernel.connectors.data.postgres;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider;
+import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField;
+import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions;
+import com.microsoft.semantickernel.exceptions.SKException;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+
+import javax.sql.DataSource;
+import java.lang.reflect.Field;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.time.OffsetDateTime;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class PostgreSQLVectorStoreQueryProvider extends
+ JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider {
+
+ private Map, String> supportedKeyTypes;
+ private Map, String> supportedDataTypes;
+ private Map, String> supportedVectorTypes;
+
+ private final DataSource dataSource;
+ private final String collectionsTable;
+ private final String prefixForCollectionTables;
+
+ @SuppressFBWarnings("EI_EXPOSE_REP2")
+ private PostgreSQLVectorStoreQueryProvider(DataSource dataSource, String collectionsTable,
+ String prefixForCollectionTables) {
+ super(dataSource, collectionsTable, prefixForCollectionTables);
+ this.dataSource = dataSource;
+ this.collectionsTable = collectionsTable;
+ this.prefixForCollectionTables = prefixForCollectionTables;
+
+ supportedKeyTypes = new HashMap<>();
+ supportedKeyTypes.put(String.class, "VARCHAR(255)");
+
+ supportedDataTypes = new HashMap<>();
+ supportedDataTypes.put(String.class, "TEXT");
+ supportedDataTypes.put(Integer.class, "INTEGER");
+ supportedDataTypes.put(int.class, "INTEGER");
+ supportedDataTypes.put(Long.class, "BIGINT");
+ supportedDataTypes.put(long.class, "BIGINT");
+ supportedDataTypes.put(Float.class, "REAL");
+ supportedDataTypes.put(float.class, "REAL");
+ supportedDataTypes.put(Double.class, "DOUBLE PRECISION");
+ supportedDataTypes.put(double.class, "DOUBLE PRECISION");
+ supportedDataTypes.put(Boolean.class, "BOOLEAN");
+ supportedDataTypes.put(boolean.class, "BOOLEAN");
+ supportedDataTypes.put(OffsetDateTime.class, "TIMESTAMPTZ");
+
+ supportedVectorTypes = new HashMap<>();
+ supportedDataTypes.put(String.class, "TEXT");
+ supportedVectorTypes.put(List.class, "VECTOR(%d)");
+ supportedVectorTypes.put(Collection.class, "VECTOR(%d)");
+ }
+
+ /**
+ * Gets the supported key types and their corresponding SQL types.
+ *
+ * @return the supported key types
+ */
+ @Override
+ public Map, String> getSupportedKeyTypes() {
+ return new HashMap<>(this.supportedKeyTypes);
+ }
+
+ /**
+ * Gets the supported data types and their corresponding SQL types.
+ *
+ * @return the supported data types
+ */
+ @Override
+ public Map, String> getSupportedDataTypes() {
+ return new HashMap<>(this.supportedDataTypes);
+ }
+
+ /**
+ * Gets the supported vector types and their corresponding SQL types.
+ *
+ * @return the supported vector types
+ */
+ @Override
+ public Map, String> getSupportedVectorTypes() {
+ return new HashMap<>(this.supportedVectorTypes);
+ }
+
+ /**
+ * Creates a new builder.
+ * @return the builder
+ */
+ public static PostgreSQLVectorStoreQueryProvider.Builder builder() {
+ return new PostgreSQLVectorStoreQueryProvider.Builder();
+ }
+
+ /**
+ * Prepares the vector store.
+ * Executes any necessary setup steps for the vector store.
+ *
+ * @throws SKException if an error occurs while preparing the vector store
+ */
+ @Override
+ public void prepareVectorStore() {
+ super.prepareVectorStore();
+
+ // Create the vector extension
+ String pgVector = "CREATE EXTENSION IF NOT EXISTS vector";
+
+ try (Connection connection = dataSource.getConnection();
+ PreparedStatement createPgVector = connection.prepareStatement(pgVector)) {
+ createPgVector.execute();
+ } catch (SQLException e) {
+ throw new SKException("Failed to prepare vector store", e);
+ }
+ }
+
+ private String getColumnNamesAndTypesForVectorFields(List fields,
+ Class> recordClass) {
+ StringBuilder columnNames = new StringBuilder();
+ for (VectorStoreRecordVectorField field : fields) {
+ try {
+ Field declaredField = recordClass.getDeclaredField(field.getName());
+ if (columnNames.length() > 0) {
+ columnNames.append(", ");
+ }
+
+ if (declaredField.getType().equals(String.class)) {
+ columnNames.append(field.getName()).append(" ")
+ .append(supportedVectorTypes.get(String.class));
+ } else {
+ // Get the vector type and dimensions
+ String type = String.format(supportedVectorTypes.get(declaredField.getType()),
+ field.getDimensions());
+ columnNames.append(field.getName()).append(" ").append(type);
+ }
+ } catch (NoSuchFieldException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ return columnNames.toString();
+ }
+
+ /**
+ * Creates a collection.
+ *
+ * @param collectionName the collection name
+ * @param recordClass the record class
+ * @param recordDefinition the record definition
+ * @throws SKException if an error occurs while creating the collection
+ */
+ @Override
+ @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
+ public void createCollection(String collectionName, Class> recordClass,
+ VectorStoreRecordDefinition recordDefinition) {
+ Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass);
+ List dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass);
+
+ String createStorageTable = "CREATE TABLE IF NOT EXISTS "
+ + getCollectionTableName(collectionName)
+ + " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, "
+ + getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", "
+ + getColumnNamesAndTypesForVectorFields(recordDefinition.getVectorFields(), recordClass)
+ + ");";
+
+ String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
+ + " (collectionId) VALUES (?)";
+
+ try (Connection connection = dataSource.getConnection();
+ PreparedStatement createTable = connection.prepareStatement(createStorageTable)) {
+ createTable.execute();
+ } catch (SQLException e) {
+ throw new SKException("Failed to create collection", e);
+ }
+
+ try (Connection connection = dataSource.getConnection();
+ PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) {
+ insert.setObject(1, collectionName);
+ insert.execute();
+ } catch (SQLException e) {
+ throw new SKException("Failed to insert collection", e);
+ }
+ }
+
+ private void setStatementValues(PreparedStatement statement, Object record,
+ List fields) {
+ for (int i = 0; i < fields.size(); ++i) {
+ VectorStoreRecordField field = fields.get(i);
+ try {
+ Field recordField = record.getClass().getDeclaredField(field.getName());
+ recordField.setAccessible(true);
+ Object value = recordField.get(record);
+
+ if (field instanceof VectorStoreRecordKeyField) {
+ statement.setObject(i + 1, (String) value);
+ } else if (field instanceof VectorStoreRecordVectorField) {
+ Class> vectorType = record.getClass().getDeclaredField(field.getName())
+ .getType();
+
+ // If the vector field is other than String, serialize it to JSON
+ if (vectorType.equals(String.class)) {
+ statement.setObject(i + 1, value);
+ } else {
+ // Serialize the vector to JSON
+ statement.setString(i + 1, new ObjectMapper().writeValueAsString(value));
+ }
+ } else {
+ statement.setObject(i + 1, value);
+ }
+ } catch (NoSuchFieldException | IllegalAccessException | SQLException e) {
+ throw new SKException("Failed to set statement values", e);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ private String getWildcardStringWithCast(List fields) {
+ StringBuilder wildcardString = new StringBuilder();
+ int wildcards = fields.size();
+ for (int i = 0; i < wildcards; ++i) {
+ if (i > 0) {
+ wildcardString.append(", ");
+ }
+ wildcardString.append("?");
+ // Add casting for vector fields
+ if (fields.get(i) instanceof VectorStoreRecordVectorField) {
+ wildcardString.append("::vector");
+ }
+ }
+ return wildcardString.toString();
+ }
+
+ /**
+ * Upserts records into the collection.
+ * @param collectionName the collection name
+ * @param records the records to upsert
+ * @param recordDefinition the record definition
+ * @param options the upsert options
+ * @throws SKException if the upsert fails
+ */
+ @Override
+ @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
+ public void upsertRecords(String collectionName, List> records,
+ VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) {
+ validateSQLidentifier(getCollectionTableName(collectionName));
+ List fields = recordDefinition.getAllFields();
+
+ StringBuilder onDuplicateKeyUpdate = new StringBuilder();
+ for (VectorStoreRecordField field : fields) {
+ if (field instanceof VectorStoreRecordKeyField) {
+ continue;
+ }
+ if (onDuplicateKeyUpdate.length() > 0) {
+ onDuplicateKeyUpdate.append(", ");
+ }
+ onDuplicateKeyUpdate.append(field.getName())
+ .append(" = EXCLUDED.")
+ .append(field.getName());
+ }
+
+ String query = "INSERT INTO " + getCollectionTableName(collectionName)
+ + " (" + getQueryColumnsFromFields(fields) + ")"
+ + " VALUES (" + getWildcardStringWithCast(fields) + ")"
+ + " ON CONFLICT (" + recordDefinition.getKeyField().getName() + ") DO UPDATE SET "
+ + onDuplicateKeyUpdate;
+
+ try (Connection connection = dataSource.getConnection();
+ PreparedStatement statement = connection.prepareStatement(query)) {
+ for (Object record : records) {
+ setStatementValues(statement, record, recordDefinition.getAllFields());
+ statement.addBatch();
+ }
+
+ statement.executeBatch();
+ } catch (SQLException e) {
+ throw new SKException("Failed to upsert records", e);
+ }
+ }
+
+ public static class Builder
+ extends JDBCVectorStoreDefaultQueryProvider.Builder {
+ private DataSource dataSource;
+ private String collectionsTable = DEFAULT_COLLECTIONS_TABLE;
+ private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES;
+
+ @SuppressFBWarnings("EI_EXPOSE_REP2")
+ public PostgreSQLVectorStoreQueryProvider.Builder withDataSource(DataSource dataSource) {
+ this.dataSource = dataSource;
+ return this;
+ }
+
+ /**
+ * Sets the collections table name.
+ * @param collectionsTable the collections table name
+ * @return the builder
+ */
+ public PostgreSQLVectorStoreQueryProvider.Builder withCollectionsTable(
+ String collectionsTable) {
+ this.collectionsTable = validateSQLidentifier(collectionsTable);
+ return this;
+ }
+
+ /**
+ * Sets the prefix for collection tables.
+ * @param prefixForCollectionTables the prefix for collection tables
+ * @return the builder
+ */
+ public PostgreSQLVectorStoreQueryProvider.Builder withPrefixForCollectionTables(
+ String prefixForCollectionTables) {
+ this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables);
+ return this;
+ }
+
+ public PostgreSQLVectorStoreQueryProvider build() {
+ if (dataSource == null) {
+ throw new SKException("DataSource is required");
+ }
+
+ return new PostgreSQLVectorStoreQueryProvider(dataSource, collectionsTable,
+ prefixForCollectionTables);
+ }
+ }
+}
diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreRecordMapper.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreRecordMapper.java
new file mode 100644
index 00000000..83b821c3
--- /dev/null
+++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreRecordMapper.java
@@ -0,0 +1,146 @@
+// Copyright (c) Microsoft. All rights reserved.
+package com.microsoft.semantickernel.connectors.data.postgres;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.microsoft.semantickernel.builders.SemanticKernelBuilder;
+import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
+import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField;
+import com.microsoft.semantickernel.exceptions.SKException;
+import org.postgresql.util.PGobject;
+
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+import java.util.List;
+import java.util.function.Function;
+
+public class PostgreSQLVectorStoreRecordMapper
+ extends VectorStoreRecordMapper {
+
+ /**
+ * Constructs a new instance of the VectorStoreRecordMapper.
+ *
+ * @param storageModelToRecordMapper the function to convert a storage model to a record
+ */
+ protected PostgreSQLVectorStoreRecordMapper(
+ Function storageModelToRecordMapper) {
+ super(null, storageModelToRecordMapper);
+ }
+
+ /**
+ * Creates a new builder.
+ *
+ * @param the record type
+ * @return the builder
+ */
+ public static Builder builder() {
+ return new Builder<>();
+ }
+
+ public static class Builder
+ implements SemanticKernelBuilder> {
+ private Class recordClass;
+ private VectorStoreRecordDefinition vectorStoreRecordDefinition;
+
+ /**
+ * Sets the record class.
+ *
+ * @param recordClass the record class
+ * @return the builder
+ */
+ public Builder withRecordClass(Class recordClass) {
+ this.recordClass = recordClass;
+ return this;
+ }
+
+ /**
+ * Sets the vector store record definition.
+ *
+ * @param vectorStoreRecordDefinition the vector store record definition
+ * @return the builder
+ */
+ public Builder withVectorStoreRecordDefinition(
+ VectorStoreRecordDefinition vectorStoreRecordDefinition) {
+ this.vectorStoreRecordDefinition = vectorStoreRecordDefinition;
+ return this;
+ }
+
+ /**
+ * Builds the {@link PostgreSQLVectorStoreRecordMapper}.
+ *
+ * @return the {@link PostgreSQLVectorStoreRecordMapper}
+ */
+ public PostgreSQLVectorStoreRecordMapper build() {
+ if (recordClass == null) {
+ throw new IllegalArgumentException("recordClass is required");
+ }
+ if (vectorStoreRecordDefinition == null) {
+ throw new IllegalArgumentException("vectorStoreRecordDefinition is required");
+ }
+
+ return new PostgreSQLVectorStoreRecordMapper<>(
+ resultSet -> {
+ try {
+ Constructor> constructor = recordClass.getDeclaredConstructor();
+ constructor.setAccessible(true);
+ Record record = (Record) constructor.newInstance();
+
+ // Select fields from the record definition.
+ // Check if vector fields are present in the result set.
+ List fields;
+ ResultSetMetaData metaData = resultSet.getMetaData();
+ if (metaData.getColumnCount() == vectorStoreRecordDefinition.getAllFields()
+ .size()) {
+ fields = vectorStoreRecordDefinition.getAllFields();
+ } else {
+ fields = vectorStoreRecordDefinition.getNonVectorFields();
+ }
+
+ for (VectorStoreRecordField field : fields) {
+ Object value = resultSet.getObject(field.getName());
+ Field recordField = recordClass.getDeclaredField(field.getName());
+ recordField.setAccessible(true);
+
+ // If the field is a vector field, deserialize the JSON string
+ if (field instanceof VectorStoreRecordVectorField) {
+ Class> vectorType = recordField.getType();
+
+ // If the vector type is a string, set the value directly
+ if (vectorType.equals(String.class)) {
+ recordField.set(record, value);
+ } else {
+ // Deserialize the pgvector string to the vector type
+ PGobject pgObject = (PGobject) value;
+ recordField.set(record,
+ new ObjectMapper().readValue(pgObject.getValue(),
+ vectorType));
+ }
+ } else {
+ recordField.set(record, value);
+ }
+ }
+
+ return record;
+ } catch (NoSuchMethodException e) {
+ throw new SKException("Default constructor not found.", e);
+ } catch (InstantiationException | InvocationTargetException e) {
+ throw new SKException(String.format(
+ "SK cannot instantiate %s. A custom mapper is required.",
+ recordClass.getName()), e);
+ } catch (JsonProcessingException e) {
+ throw new SKException(String.format(
+ "SK cannot deserialize %s. A custom mapper is required.",
+ recordClass.getName()), e);
+ } catch (SQLException | NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+ }
+}