Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions api-test/integration-tests/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@
<version>8.0.33</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<version>42.7.2</version> <!-- Use the latest version -->
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class Hotel {
@VectorStoreRecordVectorAttribute(dimensions = 3)
private final List<Float> descriptionEmbedding;
@VectorStoreRecordDataAttribute
private final double rating;
private double rating;

public Hotel() {
this(null, null, 0, null, null, 0.0);
Expand Down Expand Up @@ -56,4 +56,8 @@ public List<Float> getDescriptionEmbedding() {
public double getRating() {
return rating;
}

public void setRating(double rating) {
this.rating = rating;
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,45 +1,71 @@
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.mysql.cj.jdbc.MysqlDataSource;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.postgresql.ds.PGSimpleDataSource;
import org.testcontainers.containers.MySQLContainer;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;

import javax.annotation.Nonnull;
import javax.sql.DataSource;
import java.util.Arrays;
import java.util.List;

import com.microsoft.semantickernel.tests.connectors.memory.jdbc.JDBCVectorStoreRecordCollectionTest.QueryProvider;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

@Testcontainers
public class JDBCVectorStoreTest {
@Container
private static final MySQLContainer<?> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
private static final String MYSQL_USER = "test";
private static final String MYSQL_PASSWORD = "test";
private static MysqlDataSource dataSource;

@BeforeAll
static void setup() {
dataSource = new MysqlDataSource();
dataSource.setUrl(CONTAINER.getJdbcUrl());
dataSource.setUser(MYSQL_USER);
dataSource.setPassword(MYSQL_PASSWORD);
}
private static final MySQLContainer<?> MYSQL_CONTAINER = new MySQLContainer<>("mysql:5.7.34");

private static final DockerImageName PGVECTOR = DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres");
@Container
private static final PostgreSQLContainer<?> POSTGRESQL_CONTAINER = new PostgreSQLContainer<>(PGVECTOR);

private JDBCVectorStore buildVectorStore(QueryProvider provider) {
JDBCVectorStoreQueryProvider queryProvider;
DataSource dataSource;

switch (provider) {
case MySQL:
MysqlDataSource mysqlDataSource = new MysqlDataSource();
mysqlDataSource.setUrl(MYSQL_CONTAINER.getJdbcUrl());
mysqlDataSource.setUser(MYSQL_CONTAINER.getUsername());
mysqlDataSource.setPassword(MYSQL_CONTAINER.getPassword());
dataSource = mysqlDataSource;
queryProvider = MySQLVectorStoreQueryProvider.builder()
.withDataSource(dataSource)
.build();
break;
case PostgreSQL:
PGSimpleDataSource pgSimpleDataSource = new PGSimpleDataSource();
pgSimpleDataSource.setUrl(POSTGRESQL_CONTAINER.getJdbcUrl());
pgSimpleDataSource.setUser(POSTGRESQL_CONTAINER.getUsername());
pgSimpleDataSource.setPassword(POSTGRESQL_CONTAINER.getPassword());
dataSource = pgSimpleDataSource;
queryProvider = PostgreSQLVectorStoreQueryProvider.builder()
.withDataSource(dataSource)
.build();
break;
default:
throw new IllegalArgumentException("Unknown query provider: " + provider);
}

@Test
public void getCollectionNamesAsync() {
MySQLVectorStoreQueryProvider queryProvider = MySQLVectorStoreQueryProvider.builder()
.withDataSource(dataSource)
.build();

JDBCVectorStore vectorStore = JDBCVectorStore.builder()
JDBCVectorStore vectorStore = JDBCVectorStore.builder()
.withDataSource(dataSource)
.withOptions(
JDBCVectorStoreOptions.builder()
Expand All @@ -48,6 +74,16 @@ public void getCollectionNamesAsync() {
)
.build();

vectorStore.prepareAsync().block();
return vectorStore;
}


@ParameterizedTest
@EnumSource(QueryProvider.class)
public void getCollectionNamesAsync(QueryProvider provider) {
JDBCVectorStore vectorStore = buildVectorStore(provider);

vectorStore.getCollectionNamesAsync().block();

List<String> collectionNames = Arrays.asList("collection1", "collection2", "collection3");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
Expand Down
5 changes: 5 additions & 0 deletions semantickernel-experimental/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@
</exclusions>
</dependency>

<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<version>42.7.2</version> <!-- Use the latest version -->
</dependency>

</dependencies>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.connectors.data.jdbc;

import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
Expand All @@ -24,14 +25,27 @@
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class JDBCVectorStoreDefaultQueryProvider
implements JDBCVectorStoreQueryProvider {
private static final Map<Class<?>, String> supportedKeyTypes;
private static final Map<Class<?>, String> supportedDataTypes;
private static final Map<Class<?>, String> supportedVectorTypes;

static {
private Map<Class<?>, String> supportedKeyTypes;
private Map<Class<?>, String> supportedDataTypes;
private Map<Class<?>, String> supportedVectorTypes;
private final DataSource dataSource;
private final String collectionsTable;
private final String prefixForCollectionTables;

@SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed
protected JDBCVectorStoreDefaultQueryProvider(
@Nonnull DataSource dataSource,
@Nonnull String collectionsTable,
@Nonnull String prefixForCollectionTables) {
this.dataSource = dataSource;
this.collectionsTable = collectionsTable;
this.prefixForCollectionTables = prefixForCollectionTables;

supportedKeyTypes = new HashMap<>();
supportedKeyTypes.put(String.class, "VARCHAR(255)");

Expand All @@ -54,19 +68,6 @@ public class JDBCVectorStoreDefaultQueryProvider
supportedVectorTypes.put(List.class, "TEXT");
supportedVectorTypes.put(Collection.class, "TEXT");
}
private final DataSource dataSource;
private final String collectionsTable;
private final String prefixForCollectionTables;

@SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed
protected JDBCVectorStoreDefaultQueryProvider(
@Nonnull DataSource dataSource,
@Nonnull String collectionsTable,
@Nonnull String prefixForCollectionTables) {
this.dataSource = dataSource;
this.collectionsTable = collectionsTable;
this.prefixForCollectionTables = prefixForCollectionTables;
}

/**
* Creates a new builder.
Expand All @@ -82,14 +83,9 @@ public static Builder builder() {
* @return the formatted wildcard string
*/
protected String getWildcardString(int wildcards) {
StringBuilder wildcardString = new StringBuilder();
for (int i = 0; i < wildcards; ++i) {
wildcardString.append("?");
if (i < wildcards - 1) {
wildcardString.append(", ");
}
}
return wildcardString.toString();
return Stream.generate(() -> "?")
.limit(wildcards)
.collect(Collectors.joining(", "));
}

/**
Expand All @@ -102,6 +98,12 @@ protected String getQueryColumnsFromFields(List<VectorStoreRecordField> fields)
.collect(Collectors.joining(", "));
}

/**
* Formats the column names and types for a table.
* @param fields the fields
* @param types the types
* @return the formatted column names and types
*/
protected String getColumnNamesAndTypes(List<Field> fields, Map<Class<?>, String> types) {
List<String> columns = fields.stream()
.map(field -> field.getName() + " " + types.get(field.getType()))
Expand All @@ -114,6 +116,36 @@ protected String getCollectionTableName(String collectionName) {
return validateSQLidentifier(prefixForCollectionTables + collectionName);
}

/**
* Gets the supported key types and their corresponding SQL types.
*
* @return the supported key types
*/
@Override
public Map<Class<?>, String> getSupportedKeyTypes() {
return new HashMap<>(this.supportedKeyTypes);
}

/**
* Gets the supported data types and their corresponding SQL types.
*
* @return the supported data types
*/
@Override
public Map<Class<?>, String> getSupportedDataTypes() {
return new HashMap<>(this.supportedDataTypes);
}

/**
* Gets the supported vector types and their corresponding SQL types.
*
* @return the supported vector types
*/
@Override
public Map<Class<?>, String> getSupportedVectorTypes() {
return new HashMap<>(this.supportedVectorTypes);
}

/**
* Prepares the vector store.
* Executes any necessary setup steps for the vector store.
Expand Down Expand Up @@ -146,11 +178,12 @@ public void validateSupportedTypes(Class<?> recordClass,
VectorStoreRecordDefinition recordDefinition) {
VectorStoreRecordDefinition.validateSupportedTypes(
Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)),
supportedKeyTypes.keySet());
getSupportedKeyTypes().keySet());
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getDataDeclaredFields(recordClass), supportedDataTypes.keySet());
recordDefinition.getDataDeclaredFields(recordClass), getSupportedDataTypes().keySet());
VectorStoreRecordDefinition.validateSupportedTypes(
recordDefinition.getVectorDeclaredFields(recordClass), supportedVectorTypes.keySet());
recordDefinition.getVectorDeclaredFields(recordClass),
getSupportedVectorTypes().keySet());
}

/**
Expand Down Expand Up @@ -194,8 +227,8 @@ public void createCollection(String collectionName, Class<?> recordClass,
String createStorageTable = "CREATE TABLE IF NOT EXISTS "
+ getCollectionTableName(collectionName)
+ " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, "
+ getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", "
+ getColumnNamesAndTypes(vectorDeclaredFields, supportedVectorTypes) + ");";
+ getColumnNamesAndTypes(dataDeclaredFields, getSupportedDataTypes()) + ", "
+ getColumnNamesAndTypes(vectorDeclaredFields, getSupportedVectorTypes()) + ");";

String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
+ " (collectionId) VALUES (?)";
Expand Down Expand Up @@ -284,7 +317,8 @@ public List<String> getCollectionNames() {
*/
@Override
public <Record> List<Record> getRecords(String collectionName, List<String> keys,
VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper<Record> mapper,
VectorStoreRecordDefinition recordDefinition,
VectorStoreRecordMapper<Record, ResultSet> mapper,
GetRecordOptions options) {
List<VectorStoreRecordField> fields;
if (options == null || options.includeVectors()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
package com.microsoft.semantickernel.connectors.data.jdbc;

import com.microsoft.semantickernel.builders.SemanticKernelBuilder;
import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions;
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions;

import java.sql.ResultSet;
import java.util.List;
import java.util.Map;

/**
* The JDBC vector store query provider.
Expand All @@ -24,6 +27,27 @@ public interface JDBCVectorStoreQueryProvider {
*/
String DEFAULT_PREFIX_FOR_COLLECTION_TABLES = "SKCollection_";

/**
* Gets the supported key types and their corresponding SQL types.
*
* @return the supported key types
*/
Map<Class<?>, String> getSupportedKeyTypes();

/**
* Gets the supported data types and their corresponding SQL types.
*
* @return the supported data types
*/
Map<Class<?>, String> getSupportedDataTypes();

/**
* Gets the supported vector types and their corresponding SQL types.
*
* @return the supported vector types
*/
Map<Class<?>, String> getSupportedVectorTypes();

/**
* Prepares the vector store.
* Executes any necessary setup steps for the vector store.
Expand Down Expand Up @@ -81,7 +105,8 @@ void createCollection(String collectionName, Class<?> recordClass,
* @return the records
*/
<Record> List<Record> getRecords(String collectionName, List<String> keys,
VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper<Record> mapper,
VectorStoreRecordDefinition recordDefinition,
VectorStoreRecordMapper<Record, ResultSet> mapper,
GetRecordOptions options);

/**
Expand Down
Loading