diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java index 7becd67e..631f2cac 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java @@ -18,7 +18,8 @@ /** * Provides OpenAi implementation of audio to text service. */ -public class OpenAiAudioToTextService extends OpenAiService implements AudioToTextService { +public class OpenAiAudioToTextService extends OpenAiService + implements AudioToTextService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiAudioToTextService.class); diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java index 25071ca9..c698fab3 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java @@ -17,7 +17,8 @@ /** * Provides OpenAi implementation of text to audio service. */ -public class OpenAiTextToAudioService extends OpenAiService implements TextToAudioService { +public class OpenAiTextToAudioService extends OpenAiService + implements TextToAudioService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiTextToAudioService.class); diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java index 83b6c9a7..84a8287e 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java @@ -79,7 +79,8 @@ /** * OpenAI chat completion service. */ -public class OpenAIChatCompletion extends OpenAiService implements ChatCompletionService { +public class OpenAIChatCompletion extends OpenAiService + implements ChatCompletionService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAIChatCompletion.class); @@ -1055,7 +1056,8 @@ static ChatRequestMessage getChatRequestMessage( /** * Builder for creating a new instance of {@link OpenAIChatCompletion}. */ - public static class Builder extends OpenAiServiceBuilder { + public static class Builder + extends OpenAiServiceBuilder { @Override public OpenAIChatCompletion build() { diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java index 5c418649..13783229 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java @@ -30,7 +30,8 @@ /** * An OpenAI implementation of a {@link TextGenerationService}. */ -public class OpenAITextGenerationService extends OpenAiService implements TextGenerationService { +public class OpenAITextGenerationService extends OpenAiService + implements TextGenerationService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAITextGenerationService.class); diff --git a/api-test/integration-tests/pom.xml b/api-test/integration-tests/pom.xml index 126b3741..7e669927 100644 --- a/api-test/integration-tests/pom.xml +++ b/api-test/integration-tests/pom.xml @@ -68,9 +68,9 @@ 3.44.1.0 - com.mysql - mysql-connector-j - 8.2.0 + mysql + mysql-connector-java + 8.0.33 test diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java new file mode 100644 index 00000000..6e80e4ac --- /dev/null +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java @@ -0,0 +1,253 @@ +package com.microsoft.semantickernel.tests.connectors.memory.jdbc; + +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollection; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions; +import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.tests.connectors.memory.Hotel; +import com.mysql.cj.jdbc.MysqlDataSource; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import javax.annotation.Nonnull; +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +@Testcontainers +public class JDBCVectorStoreRecordCollectionTest { + @Container + private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:5.7.34"); + private static final String MYSQL_USER = "test"; + private static final String MYSQL_PASSWORD = "test"; + private static MysqlDataSource dataSource; + @BeforeAll + static void setup() { + dataSource = new MysqlDataSource(); + dataSource.setUrl(CONTAINER.getJdbcUrl()); + dataSource.setUser(MYSQL_USER); + dataSource.setPassword(MYSQL_PASSWORD); + } + + private JDBCVectorStoreRecordCollection buildRecordCollection(@Nonnull String collectionName) { + JDBCVectorStoreRecordCollection recordCollection = new JDBCVectorStoreRecordCollection<>( + dataSource, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(Hotel.class) + .withQueryProvider(MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build()) + .build()); + + recordCollection.prepareAsync().block(); + recordCollection.createCollectionIfNotExistsAsync().block(); + return recordCollection; + } + + @Test + public void buildRecordCollection() { + assertNotNull(buildRecordCollection("buildTest")); + } + + private List getHotels() { + return List.of( + new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), + new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f), 3.0), + new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0), + new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), + new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0) + ); + } + + @Test + public void upsertAndGetRecordAsync() { + String collectionName = "upsertAndGetRecordAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + for (Hotel hotel : hotels) { + recordStore.upsertAsync(hotel, null).block(); + } + + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + } + } + + @Test + public void getBatchAsync() { + String collectionName = "getBatchAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + for (Hotel hotel : hotels) { + recordStore.upsertAsync(hotel, null).block(); + } + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordStore.getBatchAsync(keys, null).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + } + + @Test + public void upsertBatchAndGetBatchAsync() { + String collectionName = "upsertBatchAndGetBatchAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordStore.getBatchAsync(keys, null).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + } + + @Test + public void insertAndReplaceAsync() { + String collectionName = "insertAndReplaceAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + recordStore.upsertBatchAsync(hotels, null).block(); + recordStore.upsertBatchAsync(hotels, null).block(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordStore.getBatchAsync(keys, null).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + } + + @Test + public void deleteRecordAsync() { + String collectionName = "deleteRecordAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + for (Hotel hotel : hotels) { + recordStore.deleteAsync(hotel.getId(), null).block(); + Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block(); + assertNull(retrievedHotel); + } + } + + @Test + public void deleteBatchAsync() { + String collectionName = "deleteBatchAsync"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + recordStore.deleteBatchAsync(keys, null).block(); + + for (String key : keys) { + Hotel retrievedHotel = recordStore.getAsync(key, null).block(); + assertNull(retrievedHotel); + } + } + + @Test + public void getWithNoVectors() { + String collectionName = "getWithNoVectors"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + GetRecordOptions options = GetRecordOptions.builder() + .includeVectors(false) + .build(); + + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + assertNull(retrievedHotel.getDescriptionEmbedding()); + } + + options = GetRecordOptions.builder() + .includeVectors(true) + .build(); + + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + assertNotNull(retrievedHotel.getDescriptionEmbedding()); + } + } + + @Test + public void getBatchWithNoVectors() { + String collectionName = "getBatchWithNoVectors"; + JDBCVectorStoreRecordCollection recordStore = buildRecordCollection(collectionName); + + List hotels = getHotels(); + recordStore.upsertBatchAsync(hotels, null).block(); + + GetRecordOptions options = GetRecordOptions.builder() + .includeVectors(false) + .build(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordStore.getBatchAsync(keys, options).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + + for (Hotel hotel : retrievedHotels) { + assertNull(hotel.getDescriptionEmbedding()); + } + + options = GetRecordOptions.builder() + .includeVectors(true) + .build(); + + retrievedHotels = recordStore.getBatchAsync(keys, options).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + + for (Hotel hotel : retrievedHotels) { + assertNotNull(hotel.getDescriptionEmbedding()); + } + } +} diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java new file mode 100644 index 00000000..eb134dd0 --- /dev/null +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java @@ -0,0 +1,70 @@ +package com.microsoft.semantickernel.tests.connectors.memory.jdbc; + +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions; +import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.tests.connectors.memory.Hotel; +import com.mysql.cj.jdbc.MysqlDataSource; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Testcontainers +public class JDBCVectorStoreTest { + @Container + private static final MySQLContainer CONTAINER = new MySQLContainer<>("mysql:5.7.34"); + private static final String MYSQL_USER = "test"; + private static final String MYSQL_PASSWORD = "test"; + private static MysqlDataSource dataSource; + + @BeforeAll + static void setup() { + dataSource = new MysqlDataSource(); + dataSource.setUrl(CONTAINER.getJdbcUrl()); + dataSource.setUser(MYSQL_USER); + dataSource.setPassword(MYSQL_PASSWORD); + } + + @Test + public void getCollectionNamesAsync() { + MySQLVectorStoreQueryProvider queryProvider = MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + + JDBCVectorStore vectorStore = JDBCVectorStore.builder() + .withDataSource(dataSource) + .withOptions( + JDBCVectorStoreOptions.builder() + .withQueryProvider(queryProvider) + .build() + ) + .build(); + + vectorStore.getCollectionNamesAsync().block(); + + List collectionNames = Arrays.asList("collection1", "collection2", "collection3"); + + for (String collectionName : collectionNames) { + vectorStore.getCollection(collectionName, Hotel.class, null).createCollectionAsync().block(); + } + + List retrievedCollectionNames = vectorStore.getCollectionNamesAsync().block(); + assertNotNull(retrievedCollectionNames); + assertEquals(collectionNames.size(), retrievedCollectionNames.size()); + for (String collectionName : collectionNames) { + assertTrue(retrievedCollectionNames.contains(collectionName)); + } + } +} diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml b/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml index bea43f73..dcf92999 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml @@ -88,6 +88,12 @@ 1.1.0 compile + + + mysql + mysql-connector-java + 8.0.33 + diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java new file mode 100644 index 00000000..2379e572 --- /dev/null +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java @@ -0,0 +1,188 @@ +package com.microsoft.semantickernel.samples.syntaxexamples.memory; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.credential.KeyCredential; +import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions; +import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute; +import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; +import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute; +import com.mysql.cj.jdbc.MysqlDataSource; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import javax.sql.DataSource; +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class JDBC_DataStorage { + + private static final String CLIENT_KEY = System.getenv("CLIENT_KEY"); + private static final String AZURE_CLIENT_KEY = System.getenv("AZURE_CLIENT_KEY"); + + // Only required if AZURE_CLIENT_KEY is set + private static final String CLIENT_ENDPOINT = System.getenv("CLIENT_ENDPOINT"); + private static final String MODEL_ID = System.getenv() + .getOrDefault("EMBEDDING_MODEL_ID", "text-embedding-3-large"); + private static final int EMBEDDING_DIMENSIONS = 1536; + + // Run a MySQL server with: + // docker run -d --name mysql-container -e MYSQL_ROOT_PASSWORD=root -e MYSQL_DATABASE=sk -p 3306:3306 mysql:latest + + static class GitHubFile { + @VectorStoreRecordKeyAttribute() + private final String id; + @VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding") + private final String description; + @VectorStoreRecordDataAttribute + private final String link; + @VectorStoreRecordVectorAttribute(dimensions = EMBEDDING_DIMENSIONS, indexKind = "Hnsw") + private final List embedding; + + public GitHubFile() { + this(null, null, null, Collections.emptyList()); + } + + public GitHubFile( + String id, + String description, + String link, + List embedding) { + this.id = id; + this.description = description; + this.link = link; + this.embedding = embedding; + } + + public String getId() { + return id; + } + + public String getDescription() { + return description; + } + + static String encodeId(String realId) { + byte[] bytes = Base64.getUrlEncoder().encode(realId.getBytes(StandardCharsets.UTF_8)); + return new String(bytes, StandardCharsets.UTF_8); + } + } + + public static void main(String[] args) throws SQLException { + System.out.println("=============================================================="); + System.out.println("========== JDBC Vector Store Example =============="); + System.out.println("=============================================================="); + + OpenAIAsyncClient client; + + if (AZURE_CLIENT_KEY != null) { + client = new OpenAIClientBuilder() + .credential(new AzureKeyCredential(AZURE_CLIENT_KEY)) + .endpoint(CLIENT_ENDPOINT) + .buildAsyncClient(); + + } else { + client = new OpenAIClientBuilder() + .credential(new KeyCredential(CLIENT_KEY)) + .buildAsyncClient(); + } + + var embeddingGeneration = OpenAITextEmbeddingGenerationService.builder() + .withOpenAIAsyncClient(client) + .withModelId(MODEL_ID) + .withDimensions(EMBEDDING_DIMENSIONS) + .build(); + + var dataSource = new MysqlDataSource(); + dataSource.setUrl("jdbc:mysql://localhost:3306/sk"); + dataSource.setPassword("root"); + dataSource.setUser("root"); + + dataStorageWithMySQL(dataSource, embeddingGeneration); + } + + public static void dataStorageWithMySQL( + DataSource dataSource, + OpenAITextEmbeddingGenerationService embeddingGeneration) { + + // Build a query provider + var queryProvider = MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + + // Create a new vector store + var jdbcVectorStore = JDBCVectorStore.builder() + .withDataSource(dataSource) + .withOptions(JDBCVectorStoreOptions.builder() + .withQueryProvider(queryProvider) + .build()) + .build(); + + String collectionName = "skgithubfiles"; + var collection = jdbcVectorStore.getCollection(collectionName, GitHubFile.class, + null); + + // Create collection if it does not exist and store data + List ids = collection + .createCollectionIfNotExistsAsync() + .then(storeData(collection, embeddingGeneration, sampleData())) + .block(); + + List data = collection.getBatchAsync(ids, null).block(); + + data.forEach(gitHubFile -> System.out.println("Retrieved: " + gitHubFile.getDescription())); + } + + private static Mono> storeData( + VectorStoreRecordCollection recordStore, + OpenAITextEmbeddingGenerationService embeddingGeneration, + Map data) { + + return Flux.fromIterable(data.entrySet()) + .flatMap(entry -> { + System.out.println("Save '" + entry.getKey() + "' to memory."); + + return embeddingGeneration + .generateEmbeddingsAsync(Collections.singletonList(entry.getValue())) + .flatMap(embeddings -> { + GitHubFile gitHubFile = new GitHubFile( + GitHubFile.encodeId(entry.getKey()), + entry.getValue(), + entry.getKey(), + embeddings.get(0).getVector()); + return recordStore.upsertAsync(gitHubFile, null); + }); + }) + .collectList(); + } + + private static Map sampleData() { + return Arrays.stream(new String[][] { + { "https://github.com/microsoft/semantic-kernel/blob/main/README.md", + "README: Installation, getting started with Semantic Kernel, and how to contribute" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/samples/notebooks/dotnet/02-running-prompts-from-file.ipynb", + "Jupyter notebook describing how to pass prompts from a file to a semantic skill or function" }, + { "https://github.com/microsoft/semantic-kernel/tree/main/samples/skills/ChatSkill/ChatGPT", + "Sample demonstrating how to create a chat skill interfacing with ChatGPT" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/SemanticKernel/Memory/VolatileMemoryStore.cs", + "C# class that defines a volatile embedding store" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/samples/dotnet/KernelHttpServer/README.md", + "README: How to set up a Semantic Kernel Service API using Azure Function Runtime v4" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/samples/apps/chat-summary-webapp-react/README.md", + "README: README associated with a sample chat summary react-based webapp" }, + }).collect(Collectors.toMap(element -> element[0], element -> element[1])); + } +} diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java index 70224596..871d4cb4 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java @@ -33,7 +33,8 @@ static Builder builder() { /** * Builder for the AudioToTextService. */ - abstract class Builder extends OpenAiServiceBuilder { + abstract class Builder + extends OpenAiServiceBuilder { } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java index e42b2c6e..5386cd83 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java @@ -11,8 +11,9 @@ * @param The service type * @param The builder type */ -public abstract class OpenAiServiceBuilder> implements - +public abstract class OpenAiServiceBuilder> + implements + SemanticKernelBuilder { @Nullable diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java index 8010b67f..0ab08f5f 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java @@ -61,6 +61,7 @@ Flux getStreamingTextContentsAsync( /** * Builder for a TextGenerationService */ - abstract class Builder extends OpenAiServiceBuilder { + abstract class Builder + extends OpenAiServiceBuilder { } } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java index 9576b122..5155299d 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java @@ -90,12 +90,16 @@ public AzureAISearchVectorStoreRecordCollection( : options.getRecordDefinition(); // Validate supported types - VectorStoreRecordDefinition.validateSupportedKeyTypes(this.options.getRecordClass(), - this.recordDefinition, supportedKeyTypes); - VectorStoreRecordDefinition.validateSupportedDataTypes(this.options.getRecordClass(), - this.recordDefinition, supportedDataTypes); - VectorStoreRecordDefinition.validateSupportedVectorTypes(this.options.getRecordClass(), - this.recordDefinition, supportedVectorTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + Collections + .singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), + supportedKeyTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getDataDeclaredFields(this.options.getRecordClass()), + supportedDataTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), + supportedVectorTypes); // Add non-vector fields to the list nonVectorFields.add(this.recordDefinition.getKeyField().getName()); diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java new file mode 100644 index 00000000..5e497176 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import javax.sql.DataSource; +import java.util.List; + +/** + * A JDBC vector store. + */ +public class JDBCVectorStore implements SQLVectorStore> { + private final DataSource dataSource; + private final JDBCVectorStoreOptions options; + private final JDBCVectorStoreQueryProvider queryProvider; + + /** + * Creates a new instance of the {@link JDBCVectorStore}. + * If using this constructor, call {@link #prepareAsync()} before using the vector store. + * + * @param dataSource the connection + * @param options the options + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public JDBCVectorStore(@Nonnull DataSource dataSource, + @Nullable JDBCVectorStoreOptions options) { + this.dataSource = dataSource; + this.options = options; + + if (this.options != null && this.options.getQueryProvider() != null) { + this.queryProvider = this.options.getQueryProvider(); + } else { + this.queryProvider = JDBCVectorStoreDefaultQueryProvider.builder() + .withDataSource(dataSource) + .build(); + } + } + + /** + * Creates a new builder for the vector store. + * + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets a collection from the vector store. + * + * @param collectionName The name of the collection. + * @param recordClass The class type of the record. + * @param recordDefinition The record definition. + * @return The collection. + */ + @Override + public JDBCVectorStoreRecordCollection getCollection( + @Nonnull String collectionName, + @Nonnull Class recordClass, + @Nullable VectorStoreRecordDefinition recordDefinition) { + + if (this.options != null && this.options.getVectorStoreRecordCollectionFactory() != null) { + return this.options.getVectorStoreRecordCollectionFactory() + .createVectorStoreRecordCollection( + dataSource, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(recordClass) + .withRecordDefinition(recordDefinition) + .withQueryProvider(this.queryProvider) + .build()); + } + + return new JDBCVectorStoreRecordCollection<>( + dataSource, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(recordClass) + .withRecordDefinition(recordDefinition) + .withQueryProvider(this.queryProvider) + .build()); + } + + /** + * Gets the names of all collections in the vector store. + * + * @return A list of collection names. + */ + @Override + public Mono> getCollectionNamesAsync() { + return Mono.fromCallable(queryProvider::getCollectionNames) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Prepares the vector store. + */ + @Override + public Mono prepareAsync() { + return Mono.fromRunnable(queryProvider::prepareVectorStore) + .subscribeOn(Schedulers.boundedElastic()).then(); + } + + /** + * Builder for creating a {@link JDBCVectorStore}. + */ + public static class Builder { + private DataSource dataSource; + private JDBCVectorStoreOptions options; + + /** + * Sets the data source. + * + * @param dataSource the data source + * @return the builder + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the options. + * + * @param options the options + * @return the builder + */ + public Builder withOptions(JDBCVectorStoreOptions options) { + this.options = options; + return this; + } + + /** + * Builds the {@link JDBCVectorStore}. + * + * @return the {@link JDBCVectorStore} + */ + public JDBCVectorStore build() { + return buildAsync().block(); + } + + /** + * Builds the {@link JDBCVectorStore} asynchronously. + * + * @return the {@link Mono} with the {@link JDBCVectorStore} + */ + public Mono buildAsync() { + if (dataSource == null) { + throw new IllegalArgumentException("dataSource is required"); + } + + JDBCVectorStore vectorStore = new JDBCVectorStore(dataSource, options); + return vectorStore.prepareAsync().thenReturn(vectorStore); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java new file mode 100644 index 00000000..096f240a --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java @@ -0,0 +1,419 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.exceptions.SKException; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import javax.annotation.Nonnull; +import javax.sql.DataSource; +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class JDBCVectorStoreDefaultQueryProvider + implements JDBCVectorStoreQueryProvider { + private static final Map, String> supportedKeyTypes; + private static final Map, String> supportedDataTypes; + private static final Map, String> supportedVectorTypes; + + static { + supportedKeyTypes = new HashMap<>(); + supportedKeyTypes.put(String.class, "VARCHAR(255)"); + + supportedDataTypes = new HashMap<>(); + supportedDataTypes.put(String.class, "TEXT"); + supportedDataTypes.put(Integer.class, "INTEGER"); + supportedDataTypes.put(int.class, "INTEGER"); + supportedDataTypes.put(Long.class, "BIGINT"); + supportedDataTypes.put(long.class, "BIGINT"); + supportedDataTypes.put(Float.class, "REAL"); + supportedDataTypes.put(float.class, "REAL"); + supportedDataTypes.put(Double.class, "DOUBLE"); + supportedDataTypes.put(double.class, "DOUBLE"); + supportedDataTypes.put(Boolean.class, "BOOLEAN"); + supportedDataTypes.put(boolean.class, "BOOLEAN"); + supportedDataTypes.put(OffsetDateTime.class, "TIMESTAMPTZ"); + + supportedVectorTypes = new HashMap<>(); + supportedVectorTypes.put(String.class, "TEXT"); + supportedVectorTypes.put(List.class, "TEXT"); + supportedVectorTypes.put(Collection.class, "TEXT"); + } + private final DataSource dataSource; + private final String collectionsTable; + private final String prefixForCollectionTables; + + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + protected JDBCVectorStoreDefaultQueryProvider( + @Nonnull DataSource dataSource, + @Nonnull String collectionsTable, + @Nonnull String prefixForCollectionTables) { + this.dataSource = dataSource; + this.collectionsTable = collectionsTable; + this.prefixForCollectionTables = prefixForCollectionTables; + } + + /** + * Creates a new builder. + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Formats a wildcard string for a query. + * @param wildcards the number of wildcards + * @return the formatted wildcard string + */ + protected String getWildcardString(int wildcards) { + StringBuilder wildcardString = new StringBuilder(); + for (int i = 0; i < wildcards; ++i) { + wildcardString.append("?"); + if (i < wildcards - 1) { + wildcardString.append(", "); + } + } + return wildcardString.toString(); + } + + /** + * Formats the query columns from a record definition. + * @param fields the fields to get the columns from + * @return the formatted query columns + */ + protected String getQueryColumnsFromFields(List fields) { + return fields.stream().map(VectorStoreRecordField::getName) + .collect(Collectors.joining(", ")); + } + + protected String getColumnNamesAndTypes(List fields, Map, String> types) { + List columns = fields.stream() + .map(field -> field.getName() + " " + types.get(field.getType())) + .collect(Collectors.toList()); + + return String.join(", ", columns); + } + + protected String getCollectionTableName(String collectionName) { + return validateSQLidentifier(prefixForCollectionTables + collectionName); + } + + /** + * Prepares the vector store. + * Executes any necessary setup steps for the vector store. + * + * @throws SKException if an error occurs while preparing the vector store + */ + @Override + public void prepareVectorStore() { + String createCollectionsTable = "CREATE TABLE IF NOT EXISTS " + + validateSQLidentifier(collectionsTable) + + " (collectionId VARCHAR(255) PRIMARY KEY);"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement createTable = connection.prepareStatement(createCollectionsTable)) { + createTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to prepare vector store", e); + } + } + + /** + * Checks if the types of the record class fields are supported. + * + * @param recordClass the record class + * @param recordDefinition the record definition + * @throws IllegalArgumentException if the types are not supported + */ + @Override + public void validateSupportedTypes(Class recordClass, + VectorStoreRecordDefinition recordDefinition) { + VectorStoreRecordDefinition.validateSupportedTypes( + Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)), + supportedKeyTypes.keySet()); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getDataDeclaredFields(recordClass), supportedDataTypes.keySet()); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getVectorDeclaredFields(recordClass), supportedVectorTypes.keySet()); + } + + /** + * Checks if a collection exists. + * + * @param collectionName the collection name + * @return true if the collection exists, false otherwise + * @throws SKException if an error occurs while checking if the collection exists + */ + @Override + public boolean collectionExists(String collectionName) { + String query = "SELECT 1 FROM " + validateSQLidentifier(collectionsTable) + + " WHERE collectionId = ?"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + statement.setObject(1, collectionName); + + return statement.executeQuery().next(); + } catch (SQLException e) { + throw new SKException("Failed to check if collection exists", e); + } + } + + /** + * Creates a collection. + * + * @param collectionName the collection name + * @param recordClass the record class + * @param recordDefinition the record definition + * @throws SKException if an error occurs while creating the collection + */ + @Override + @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers + public void createCollection(String collectionName, Class recordClass, + VectorStoreRecordDefinition recordDefinition) { + Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass); + List dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass); + List vectorDeclaredFields = recordDefinition.getVectorDeclaredFields(recordClass); + + String createStorageTable = "CREATE TABLE IF NOT EXISTS " + + getCollectionTableName(collectionName) + + " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, " + + getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", " + + getColumnNamesAndTypes(vectorDeclaredFields, supportedVectorTypes) + ");"; + + String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable) + + " (collectionId) VALUES (?)"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement createTable = connection.prepareStatement(createStorageTable)) { + createTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to create collection", e); + } + + try (Connection connection = dataSource.getConnection(); + PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) { + insert.setObject(1, collectionName); + insert.execute(); + } catch (SQLException e) { + throw new SKException("Failed to insert collection", e); + } + } + + /** + * Deletes a collection. + * + * @param collectionName the collection name + * @throws SKException if an error occurs while deleting the collection + */ + @Override + public void deleteCollection(String collectionName) { + String deleteCollectionOperation = "DELETE FROM " + validateSQLidentifier(collectionsTable) + + " WHERE collectionId = ?"; + String dropTableOperation = "DROP TABLE " + getCollectionTableName(collectionName); + + try (Connection connection = dataSource.getConnection(); + PreparedStatement deleteCollection = connection + .prepareStatement(deleteCollectionOperation)) { + deleteCollection.setObject(1, collectionName); + deleteCollection.execute(); + } catch (SQLException e) { + throw new SKException("Failed to delete collection", e); + } + + try (Connection connection = dataSource.getConnection(); + PreparedStatement dropTable = connection.prepareStatement(dropTableOperation)) { + dropTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to drop table", e); + } + } + + /** + * Gets the collection names. + * + * @return the collection names + * @throws SKException if an error occurs while getting the collection names + */ + @Override + public List getCollectionNames() { + String query = "SELECT collectionId FROM " + validateSQLidentifier(collectionsTable); + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + List collectionNames = new ArrayList<>(); + ResultSet resultSet = statement.executeQuery(); + + while (resultSet.next()) { + collectionNames.add(resultSet.getString(1)); + } + + return Collections.unmodifiableList(collectionNames); + } catch (SQLException e) { + throw new SKException("Failed to get collection names", e); + } + } + + /** + * Gets a list of records from the store. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param mapper the mapper + * @param options the options + * @return the records + * @param the record type + * @throws SKException if an error occurs while getting the records + */ + @Override + public List getRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, + GetRecordOptions options) { + List fields; + if (options == null || options.includeVectors()) { + fields = recordDefinition.getAllFields(); + } else { + fields = recordDefinition.getNonVectorFields(); + } + + String query = "SELECT " + getQueryColumnsFromFields(fields) + + " FROM " + getCollectionTableName(collectionName) + + " WHERE " + recordDefinition.getKeyField().getName() + + " IN (" + getWildcardString(keys.size()) + ")"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + for (int i = 0; i < keys.size(); ++i) { + statement.setObject(i + 1, keys.get(i)); + } + + List records = new ArrayList<>(); + ResultSet resultSet = statement.executeQuery(); + + while (resultSet.next()) { + records.add(mapper.mapStorageModeltoRecord(resultSet)); + } + + return Collections.unmodifiableList(records); + } catch (SQLException e) { + throw new SKException("Failed to set statement values", e); + } + } + + @Override + public void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) { + throw new UnsupportedOperationException( + "Upsert is not supported. Try with a specific query provider."); + } + + /** + * Deletes records. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param options the options + * @throws SKException if an error occurs while deleting the records + */ + @Override + public void deleteRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) { + String query = "DELETE FROM " + getCollectionTableName(collectionName) + + " WHERE " + recordDefinition.getKeyField().getName() + + " IN (" + getWildcardString(keys.size()) + ")"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + for (int i = 0; i < keys.size(); ++i) { + statement.setObject(i + 1, keys.get(i)); + } + + statement.execute(); + } catch (SQLException e) { + throw new SKException("Failed to set statement values", e); + } + } + + /** + * Validates an SQL identifier. + * + * @param identifier the identifier + * @return the identifier if it is valid + * @throws IllegalArgumentException if the identifier is invalid + */ + public static String validateSQLidentifier(String identifier) { + if (identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*")) { + return identifier; + } + throw new IllegalArgumentException("Invalid SQL identifier: " + identifier); + } + + /** + * The builder for {@link JDBCVectorStoreDefaultQueryProvider}. + */ + public static class Builder + implements JDBCVectorStoreQueryProvider.Builder { + private DataSource dataSource; + private String collectionsTable = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + + /** + * Sets the data source. + * @param dataSource the data source + * @return the builder + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the collections table name. + * @param collectionsTable the collections table name + * @return the builder + */ + public Builder withCollectionsTable(String collectionsTable) { + this.collectionsTable = validateSQLidentifier(collectionsTable); + return this; + } + + /** + * Sets the prefix for collection tables. + * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); + return this; + } + + @Override + public JDBCVectorStoreDefaultQueryProvider build() { + if (dataSource == null) { + throw new IllegalArgumentException("DataSource is required"); + } + + return new JDBCVectorStoreDefaultQueryProvider(dataSource, collectionsTable, + prefixForCollectionTables); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java new file mode 100644 index 00000000..adb6e13c --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import javax.annotation.Nullable; + +public class JDBCVectorStoreOptions { + @Nullable + private final JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory; + @Nullable + private final JDBCVectorStoreQueryProvider queryProvider; + + /** + * Creates a new instance of the JDBC vector store options. + * + * @param vectorStoreRecordCollectionFactory The vector store record collection factory. + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource in queryProvider is not exposed + public JDBCVectorStoreOptions( + @Nullable JDBCVectorStoreQueryProvider queryProvider, + @Nullable JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory) { + this.queryProvider = queryProvider; + this.vectorStoreRecordCollectionFactory = vectorStoreRecordCollectionFactory; + } + + /** + * Creates a new instance of the JDBC vector store options. + */ + public JDBCVectorStoreOptions() { + this(null, null); + } + + /** + * Gets the query provider. + * + * @return the query provider + */ + @Nullable + @SuppressFBWarnings("EI_EXPOSE_REP") // DataSource in queryProvider is not exposed + public JDBCVectorStoreQueryProvider getQueryProvider() { + return queryProvider; + } + + /** + * Creates a new builder. + * + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the vector store record collection factory. + * + * @return the vector store record collection factory + */ + @Nullable + public JDBCVectorStoreRecordCollectionFactory getVectorStoreRecordCollectionFactory() { + return vectorStoreRecordCollectionFactory; + } + + /** + * Builder for JDBC vector store options. + * + */ + public static class Builder { + @Nullable + private JDBCVectorStoreQueryProvider queryProvider; + @Nullable + private JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory; + + /** + * Sets the query provider. + * + * @param queryProvider The query provider. + * @return The updated builder instance. + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource in queryProvider is not exposed + public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvider) { + this.queryProvider = queryProvider; + return this; + } + + /** + * Sets the vector store record collection factory. + * + * @param vectorStoreRecordCollectionFactory The vector store record collection factory. + * @return The updated builder instance. + */ + public Builder withVectorStoreRecordCollectionFactory( + JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory) { + this.vectorStoreRecordCollectionFactory = vectorStoreRecordCollectionFactory; + return this; + } + + /** + * Builds the JDBC vector store options. + * + * @return The JDBC vector store options. + */ + public JDBCVectorStoreOptions build() { + return new JDBCVectorStoreOptions(queryProvider, vectorStoreRecordCollectionFactory); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java new file mode 100644 index 00000000..26d976aa --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; + +import java.util.List; + +/** + * The JDBC vector store query provider. + * Provides the necessary methods to interact with a JDBC vector store and vector store collections. + */ +public interface JDBCVectorStoreQueryProvider { + /** + * The default name for the collections table. + */ + String DEFAULT_COLLECTIONS_TABLE = "SKCollections"; + + /** + * The prefix for collection tables. + */ + String DEFAULT_PREFIX_FOR_COLLECTION_TABLES = "SKCollection_"; + + /** + * Prepares the vector store. + * Executes any necessary setup steps for the vector store. + */ + void prepareVectorStore(); + + /** + * Checks if the types of the record class fields are supported. + * + * @param recordClass the record class + * @param recordDefinition the record definition + */ + void validateSupportedTypes(Class recordClass, VectorStoreRecordDefinition recordDefinition); + + /** + * Checks if a collection exists. + * + * @param collectionName the collection name + * @return true if the collection exists, false otherwise + */ + boolean collectionExists(String collectionName); + + /** + * Creates a collection. + * + * @param collectionName the collection name + * @param recordClass the record class + * @param recordDefinition the record definition + */ + void createCollection(String collectionName, Class recordClass, + VectorStoreRecordDefinition recordDefinition); + + /** + * Deletes a collection. + * + * @param collectionName the collection name + */ + void deleteCollection(String collectionName); + + /** + * Gets the collection names. + * + * @return the collection names + */ + List getCollectionNames(); + + /** + * Gets records. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param mapper the mapper + * @param options the options + * @return the records + */ + List getRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper mapper, + GetRecordOptions options); + + /** + * Upserts records. + * + * @param collectionName the collection name + * @param records the records + * @param vectorStoreRecordDefinition the record definition + * @param options the options + */ + void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition vectorStoreRecordDefinition, UpsertRecordOptions options); + + /** + * Deletes records. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param options the options + */ + void deleteRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options); + + /** + * The builder for the JDBC vector store query provider. + */ + interface Builder extends SemanticKernelBuilder { + + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java new file mode 100644 index 00000000..b9c0bd3c --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import com.microsoft.semantickernel.exceptions.SKException; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import javax.annotation.Nonnull; +import javax.sql.DataSource; +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +public class JDBCVectorStoreRecordCollection + implements SQLVectorStoreRecordCollection { + private final String collectionName; + private final VectorStoreRecordDefinition recordDefinition; + private final JDBCVectorStoreRecordCollectionOptions options; + private final JDBCVectorStoreRecordMapper vectorStoreRecordMapper; + private final JDBCVectorStoreQueryProvider queryProvider; + + /** + * Creates a new instance of the {@link JDBCVectorStoreRecordCollection}. + * + * @param dataSource the data source + * @param collectionName the name of the collection + * @param options the options + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public JDBCVectorStoreRecordCollection( + @Nonnull DataSource dataSource, + @Nonnull String collectionName, + @Nonnull JDBCVectorStoreRecordCollectionOptions options) { + this.collectionName = collectionName; + this.options = options; + + // If record definition is not provided, create one from the record class + recordDefinition = options.getRecordDefinition() == null + ? VectorStoreRecordDefinition.fromRecordClass(options.getRecordClass()) + : options.getRecordDefinition(); + + // If mapper is not provided, set a default one + if (options.getVectorStoreRecordMapper() == null) { + vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder() + .withRecordClass(options.getRecordClass()) + .withVectorStoreRecordDefinition(recordDefinition) + .build(); + } else { + vectorStoreRecordMapper = options.getVectorStoreRecordMapper(); + } + + // If the query provider is not provided, set a default one + if (options.getQueryProvider() == null) { + this.queryProvider = JDBCVectorStoreDefaultQueryProvider.builder() + .withDataSource(dataSource) + .build(); + } else { + this.queryProvider = options.getQueryProvider(); + } + + // Check if the types are supported + queryProvider.validateSupportedTypes(options.getRecordClass(), recordDefinition); + } + + /** + * Gets the name of the collection. + * + * @return The name of the collection. + */ + @Override + public String getCollectionName() { + return collectionName; + } + + /** + * Checks if the collection exists in the store. + * + * @return A Mono emitting a boolean indicating if the collection exists. + * @throws SKException if the operation fails + */ + @Override + public Mono collectionExistsAsync() { + return Mono.fromCallable( + () -> queryProvider.collectionExists(this.collectionName)) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Creates the collection in the store. + * + * @return A Mono representing the completion of the creation operation. + * @throws SKException if the operation fails + */ + @Override + public Mono createCollectionAsync() { + return Mono.fromRunnable( + () -> queryProvider.createCollection(this.collectionName, options.getRecordClass(), + recordDefinition)) + .subscribeOn(Schedulers.boundedElastic()) + .then(); + } + + /** + * Creates the collection in the store if it does not exist. + * + * @return A Mono representing the completion of the creation operation. + * @throws SKException if the operation fails + */ + @Override + public Mono createCollectionIfNotExistsAsync() { + return collectionExistsAsync().map( + exists -> { + if (!exists) { + return createCollectionAsync(); + } + return Mono.empty(); + }) + .flatMap(mono -> mono) + .then(); + } + + /** + * Deletes the collection from the store. + * + * @return A Mono representing the completion of the deletion operation. + * @throws SKException if the operation fails + */ + @Override + public Mono deleteCollectionAsync() { + return Mono.fromRunnable( + () -> { + queryProvider.deleteCollection(this.collectionName); + }).subscribeOn(Schedulers.boundedElastic()).then(); + } + + /** + * Gets a record from the store. + * + * @param key The key of the record to get. + * @param options The options for getting the record. + * @return A Mono emitting the record. + * @throws SKException if the operation fails + */ + @Override + public Mono getAsync(String key, GetRecordOptions options) { + return this.getBatchAsync(Collections.singletonList(key), options) + .mapNotNull(records -> { + if (records.isEmpty()) { + return null; + } + return records.get(0); + }); + } + + /** + * Gets a batch of records from the store. + * + * @param keys The keys of the records to get. + * @param options The options for getting the records. + * @return A Mono emitting a collection of records. + * @throws SKException if the operation fails + */ + @Override + public Mono> getBatchAsync(List keys, GetRecordOptions options) { + return Mono.fromCallable( + () -> { + return queryProvider.getRecords(this.collectionName, keys, recordDefinition, + vectorStoreRecordMapper, options); + }).subscribeOn(Schedulers.boundedElastic()); + } + + protected String getKeyFromRecord(Record data) { + try { + Field keyField = data.getClass() + .getDeclaredField(recordDefinition.getKeyField().getName()); + keyField.setAccessible(true); + return (String) keyField.get(data); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new SKException("Failed to get key from record", e); + } + } + + /** + * Inserts or updates a record in the store. + * + * @param data The record to upsert. + * @param options The options for upserting the record. + * @return A Mono emitting the key of the upserted record. + * @throws SKException if the operation fails + */ + @Override + public Mono upsertAsync(Record data, UpsertRecordOptions options) { + return this.upsertBatchAsync(Collections.singletonList(data), options) + .mapNotNull(keys -> { + if (keys.isEmpty()) { + return null; + } + return keys.get(0); + }); + } + + /** + * Inserts or updates a batch of records in the store. + * + * @param data The records to upsert. + * @param options The options for upserting the records. + * @return A Mono emitting a collection of keys of the upserted records. + * @throws SKException if the operation fails + */ + @Override + public Mono> upsertBatchAsync(List data, UpsertRecordOptions options) { + return Mono.fromCallable( + () -> { + queryProvider.upsertRecords(this.collectionName, data, recordDefinition, options); + return data.stream().map(this::getKeyFromRecord).collect(Collectors.toList()); + }) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Deletes a record from the store. + * + * @param key The key of the record to delete. + * @param options The options for deleting the record. + * @return A Mono representing the completion of the deletion operation. + * @throws SKException if the operation fails + */ + @Override + public Mono deleteAsync(String key, DeleteRecordOptions options) { + return this.deleteBatchAsync(Collections.singletonList(key), options); + } + + /** + * Deletes a batch of records from the store. + * + * @param keys The keys of the records to delete. + * @param options The options for deleting the records. + * @return A Mono representing the completion of the deletion operation. + * @throws SKException if the operation fails + */ + @Override + public Mono deleteBatchAsync(List keys, DeleteRecordOptions options) { + return Mono.fromRunnable( + () -> { + queryProvider.deleteRecords(this.collectionName, keys, recordDefinition, options); + }).subscribeOn(Schedulers.boundedElastic()).then(); + } + + /** + * Prepares the collection for use. + * + * @return A Mono representing the completion of the preparation operation. + * @throws SKException if the operation fails + */ + @Override + public Mono prepareAsync() { + return Mono.fromRunnable(queryProvider::prepareVectorStore) + .subscribeOn(Schedulers.boundedElastic()).then(); + } + + public static class Builder + implements SemanticKernelBuilder> { + private DataSource dataSource; + private String collectionName; + private JDBCVectorStoreRecordCollectionOptions options; + + /** + * Sets the data source. + * + * @param dataSource the data source + * @return the builder + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the collection name. + * + * @param collectionName the collection name + * @return the builder + */ + public Builder withCollectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * Sets the options. + * + * @param options the options + * @return the builder + */ + public Builder withOptions(JDBCVectorStoreRecordCollectionOptions options) { + this.options = options; + return this; + } + + @Override + public JDBCVectorStoreRecordCollection build() { + if (dataSource == null) { + throw new IllegalArgumentException("dataSource is required"); + } + if (collectionName == null) { + throw new IllegalArgumentException("collectionName is required"); + } + if (options == null) { + throw new IllegalArgumentException("options is required"); + } + + return new JDBCVectorStoreRecordCollection<>(dataSource, collectionName, options); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java new file mode 100644 index 00000000..70b62a7e --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import javax.sql.DataSource; +import java.sql.Connection; + +/** + * Factory for creating JDBC vector store record collections. + */ +public interface JDBCVectorStoreRecordCollectionFactory { + /** + * Creates a new JDBC vector store record collection. + * + * @param options The options for the collection. + * @return The new JDBC vector store record collection. + */ + JDBCVectorStoreRecordCollection createVectorStoreRecordCollection( + DataSource dataSource, + String collectionName, + JDBCVectorStoreRecordCollectionOptions options); +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java new file mode 100644 index 00000000..af1ec49e --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider.validateSQLidentifier; +import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_COLLECTIONS_TABLE; +import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + +public class JDBCVectorStoreRecordCollectionOptions { + private final Class recordClass; + private final JDBCVectorStoreRecordMapper vectorStoreRecordMapper; + private final VectorStoreRecordDefinition recordDefinition; + private final JDBCVectorStoreQueryProvider queryProvider; + private final String collectionsTableName; + private final String prefixForCollectionTables; + + private JDBCVectorStoreRecordCollectionOptions( + Class recordClass, + VectorStoreRecordDefinition recordDefinition, + JDBCVectorStoreRecordMapper vectorStoreRecordMapper, + JDBCVectorStoreQueryProvider queryProvider, + String collectionsTableName, + String prefixForCollectionTables) { + this.recordClass = recordClass; + this.recordDefinition = recordDefinition; + this.vectorStoreRecordMapper = vectorStoreRecordMapper; + this.queryProvider = queryProvider; + this.collectionsTableName = collectionsTableName; + this.prefixForCollectionTables = prefixForCollectionTables; + } + + /** + * Creates a new builder. + * @param the record type + * @return the builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Gets the record class. + * @return the record class + */ + public Class getRecordClass() { + return recordClass; + } + + /** + * Gets the record definition. + * @return the record definition + */ + public VectorStoreRecordDefinition getRecordDefinition() { + return recordDefinition; + } + + /** + * Gets the vector store record mapper. + * @return the vector store record mapper + */ + public JDBCVectorStoreRecordMapper getVectorStoreRecordMapper() { + return vectorStoreRecordMapper; + } + + /** + * Gets the collections table. + * @return the collections table + */ + public String getCollectionsTableName() { + return collectionsTableName; + } + + /** + * Gets the prefix for collection tables. + * @return the prefix for collection tables + */ + public String getPrefixForCollectionTables() { + return prefixForCollectionTables; + } + + /** + * Gets the query provider. + * @return the query provider + */ + @SuppressFBWarnings("EI_EXPOSE_REP") // DataSource in queryProvider is not exposed + public JDBCVectorStoreQueryProvider getQueryProvider() { + return queryProvider; + } + + public static class Builder { + private Class recordClass; + private VectorStoreRecordDefinition recordDefinition; + private JDBCVectorStoreRecordMapper vectorStoreRecordMapper; + private JDBCVectorStoreQueryProvider queryProvider; + private String collectionsTableName = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + + /** + * Sets the record class. + * @param recordClass the record class + * @return the builder + */ + public Builder withRecordClass(Class recordClass) { + this.recordClass = recordClass; + return this; + } + + /** + * Sets the record definition. + * @param recordDefinition the record definition + * @return the builder + */ + public Builder withRecordDefinition(VectorStoreRecordDefinition recordDefinition) { + this.recordDefinition = recordDefinition; + return this; + } + + /** + * Sets the vector store record mapper. + * @param vectorStoreRecordMapper the vector store record mapper + * @return the builder + */ + public Builder withVectorStoreRecordMapper( + JDBCVectorStoreRecordMapper vectorStoreRecordMapper) { + this.vectorStoreRecordMapper = vectorStoreRecordMapper; + return this; + } + + /** + * Sets the query provider. + * @param queryProvider the query provider + * @return the builder + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource in queryProvider is not exposed + public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvider) { + this.queryProvider = queryProvider; + return this; + } + + /** + * Sets the collections table name. + * @param collectionsTableName the collections table name + * @return the builder + */ + public Builder withCollectionsTableName(String collectionsTableName) { + this.collectionsTableName = validateSQLidentifier(collectionsTableName); + return this; + } + + /** + * Sets the prefix for collection tables. + * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); + return this; + } + + /** + * Builds the options. + * @return the options + */ + public JDBCVectorStoreRecordCollectionOptions build() { + if (recordClass == null) { + throw new IllegalArgumentException("recordClass is required"); + } + + return new JDBCVectorStoreRecordCollectionOptions<>( + recordClass, + recordDefinition, + vectorStoreRecordMapper, + queryProvider, + collectionsTableName, + prefixForCollectionTables); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java new file mode 100644 index 00000000..6eff0c7d --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.exceptions.SKException; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; + +import java.sql.ResultSetMetaData; +import java.util.List; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.function.Function; + +public class JDBCVectorStoreRecordMapper + extends VectorStoreRecordMapper { + + /** + * Constructs a new instance of the VectorStoreRecordMapper. + * + * @param storageModelToRecordMapper the function to convert a storage model to a record + */ + protected JDBCVectorStoreRecordMapper(Function storageModelToRecordMapper) { + super(null, storageModelToRecordMapper); + } + + /** + * Creates a new builder. + * + * @param the record type + * @return the builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Operation not supported. + */ + @Override + public ResultSet mapRecordToStorageModel(Record record) { + throw new UnsupportedOperationException("Not implemented"); + } + + public static class Builder + implements SemanticKernelBuilder> { + private Class recordClass; + private VectorStoreRecordDefinition vectorStoreRecordDefinition; + + /** + * Sets the record class. + * + * @param recordClass the record class + * @return the builder + */ + public Builder withRecordClass(Class recordClass) { + this.recordClass = recordClass; + return this; + } + + /** + * Sets the vector store record definition. + * + * @param vectorStoreRecordDefinition the vector store record definition + * @return the builder + */ + public Builder withVectorStoreRecordDefinition( + VectorStoreRecordDefinition vectorStoreRecordDefinition) { + this.vectorStoreRecordDefinition = vectorStoreRecordDefinition; + return this; + } + + /** + * Builds the {@link JDBCVectorStoreRecordMapper}. + * + * @return the {@link JDBCVectorStoreRecordMapper} + */ + public JDBCVectorStoreRecordMapper build() { + if (recordClass == null) { + throw new IllegalArgumentException("recordClass is required"); + } + if (vectorStoreRecordDefinition == null) { + throw new IllegalArgumentException("vectorStoreRecordDefinition is required"); + } + + return new JDBCVectorStoreRecordMapper<>( + resultSet -> { + try { + Constructor constructor = recordClass.getDeclaredConstructor(); + constructor.setAccessible(true); + Record record = (Record) constructor.newInstance(); + + // Select fields from the record definition. + // Check if vector fields are present in the result set. + List fields; + ResultSetMetaData metaData = resultSet.getMetaData(); + if (metaData.getColumnCount() == vectorStoreRecordDefinition.getAllFields() + .size()) { + fields = vectorStoreRecordDefinition.getAllFields(); + } else { + fields = vectorStoreRecordDefinition.getNonVectorFields(); + } + + for (VectorStoreRecordField field : fields) { + Object value = resultSet.getObject(field.getName()); + Field recordField = recordClass.getDeclaredField(field.getName()); + recordField.setAccessible(true); + + // If the field is a vector field, deserialize the JSON string + if (field instanceof VectorStoreRecordVectorField) { + Class vectorType = recordField.getType(); + + // If the vector type is a string, set the value directly + if (vectorType.equals(String.class)) { + recordField.set(record, value); + } else { + // Deserialize the JSON string to the vector type + recordField.set(record, + new ObjectMapper().readValue((String) value, vectorType)); + } + } else { + recordField.set(record, value); + } + } + + return record; + } catch (NoSuchMethodException e) { + throw new SKException("Default constructor not found.", e); + } catch (InstantiationException | InvocationTargetException e) { + throw new SKException(String.format( + "SK cannot instantiate %s. A custom mapper is required.", + recordClass.getName()), e); + } catch (JsonProcessingException e) { + throw new SKException(String.format( + "SK cannot deserialize %s. A custom mapper is required.", + recordClass.getName()), e); + } catch (SQLException | NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + }); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java new file mode 100644 index 00000000..72ecd87e --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/MySQLVectorStoreQueryProvider.java @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import com.microsoft.semantickernel.exceptions.SKException; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import javax.sql.DataSource; +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.List; + +public class MySQLVectorStoreQueryProvider extends + JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { + + private final DataSource dataSource; + + @SuppressFBWarnings("EI_EXPOSE_REP2") + private MySQLVectorStoreQueryProvider(DataSource dataSource, String collectionsTable, + String prefixForCollectionTables) { + super(dataSource, collectionsTable, prefixForCollectionTables); + this.dataSource = dataSource; + } + + /** + * Creates a new builder. + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + private void setStatementValues(PreparedStatement statement, Object record, + List fields) { + for (int i = 0; i < fields.size(); ++i) { + VectorStoreRecordField field = fields.get(i); + try { + Field recordField = record.getClass().getDeclaredField(field.getName()); + recordField.setAccessible(true); + Object value = recordField.get(record); + + if (field instanceof VectorStoreRecordKeyField) { + statement.setObject(i + 1, (String) value); + } else if (field instanceof VectorStoreRecordVectorField) { + Class vectorType = record.getClass().getDeclaredField(field.getName()) + .getType(); + + // If the vector field is other than String, serialize it to JSON + if (vectorType.equals(String.class)) { + statement.setObject(i + 1, value); + } else { + // Serialize the vector to JSON + statement.setObject(i + 1, new ObjectMapper().writeValueAsString(value)); + } + } else { + statement.setObject(i + 1, value); + } + } catch (NoSuchFieldException | IllegalAccessException | SQLException e) { + throw new SKException("Failed to set statement values", e); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + /** + * Upserts records into the collection. + * @param collectionName the collection name + * @param records the records to upsert + * @param recordDefinition the record definition + * @param options the upsert options + * @throws SKException if the upsert fails + */ + @Override + @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers + public void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) { + validateSQLidentifier(getCollectionTableName(collectionName)); + + List fields = recordDefinition.getAllFields(); + + StringBuilder onDuplicateKeyUpdate = new StringBuilder(); + for (int i = 0; i < fields.size(); ++i) { + VectorStoreRecordField field = fields.get(i); + if (i > 0) { + onDuplicateKeyUpdate.append(", "); + } + + onDuplicateKeyUpdate.append(field.getName()).append(" = VALUES(") + .append(field.getName()).append(")"); + } + + String query = "INSERT INTO " + getCollectionTableName(collectionName) + + " (" + getQueryColumnsFromFields(fields) + ")" + + " VALUES (" + getWildcardString(fields.size()) + ")" + + " ON DUPLICATE KEY UPDATE " + onDuplicateKeyUpdate; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + for (Object record : records) { + setStatementValues(statement, record, recordDefinition.getAllFields()); + statement.addBatch(); + } + + statement.executeBatch(); + } catch (SQLException e) { + throw new SKException("Failed to upsert records", e); + } + } + + public static class Builder + extends JDBCVectorStoreDefaultQueryProvider.Builder { + private DataSource dataSource; + private String collectionsTable = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + + @SuppressFBWarnings("EI_EXPOSE_REP2") + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the collections table name. + * @param collectionsTable the collections table name + * @return the builder + */ + public Builder withCollectionsTable(String collectionsTable) { + this.collectionsTable = validateSQLidentifier(collectionsTable); + return this; + } + + /** + * Sets the prefix for collection tables. + * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); + return this; + } + + public MySQLVectorStoreQueryProvider build() { + if (dataSource == null) { + throw new SKException("DataSource is required"); + } + + return new MySQLVectorStoreQueryProvider(dataSource, collectionsTable, + prefixForCollectionTables); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java new file mode 100644 index 00000000..10e4d2ef --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.VectorStore; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import reactor.core.publisher.Mono; + +public interface SQLVectorStore> + extends VectorStore { + + /** + * Prepares the vector store. + * + * @return A {@link Mono} that completes when the vector store is prepared to be used. + */ + Mono prepareAsync(); +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java new file mode 100644 index 00000000..ff12c88b --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import reactor.core.publisher.Mono; + +public interface SQLVectorStoreRecordCollection + extends VectorStoreRecordCollection { + + /** + * Prepares the vector store record collection. + * + * @return A {@link Mono} that completes when the vector store record collection is prepared to be used. + */ + Mono prepareAsync(); +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java index 8783320b..52d30bf8 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java @@ -81,10 +81,13 @@ public RedisVectorStoreRecordCollection( } // Validate supported types - VectorStoreRecordDefinition.validateSupportedKeyTypes(options.getRecordClass(), - recordDefinition, supportedKeyTypes); - VectorStoreRecordDefinition.validateSupportedVectorTypes(options.getRecordClass(), - recordDefinition, supportedVectorTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + Collections + .singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), + supportedKeyTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), + supportedVectorTypes); // If mapper is not provided, set a default one if (options.getVectorStoreRecordMapper() == null) { diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java index 867cbf16..1466ac35 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java @@ -9,41 +9,40 @@ import java.util.List; public interface VectorStoreRecordCollection { - /** * Gets the name of the collection. * * @return The name of the collection. */ - public String getCollectionName(); + String getCollectionName(); /** * Checks if the collection exists in the store. * * @return A Mono emitting a boolean indicating if the collection exists. */ - public Mono collectionExistsAsync(); + Mono collectionExistsAsync(); /** * Creates the collection in the store. * * @return A Mono representing the completion of the creation operation. */ - public Mono createCollectionAsync(); + Mono createCollectionAsync(); /** * Creates the collection in the store if it does not exist. * * @return A Mono representing the completion of the creation operation. */ - public Mono createCollectionIfNotExistsAsync(); + Mono createCollectionIfNotExistsAsync(); /** * Deletes the collection from the store. * * @return A Mono representing the completion of the deletion operation. */ - public Mono deleteCollectionAsync(); + Mono deleteCollectionAsync(); /** * Gets a record from the store. diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java index e675d4cd..7a433dbb 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java @@ -43,8 +43,10 @@ public VolatileVectorStoreRecordCollection(String collectionName, } // Validate the key type - VectorStoreRecordDefinition.validateSupportedKeyTypes(options.getRecordClass(), - recordDefinition, supportedKeyTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + Collections + .singletonList(recordDefinition.getKeyDeclaredField(options.getRecordClass())), + supportedKeyTypes); } VolatileVectorStoreRecordCollection(String collectionName, diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java index a1914d2c..39e04a3f 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java @@ -5,13 +5,12 @@ import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.stream.Collectors; /** @@ -50,6 +49,56 @@ public List getAllFields() { return fields; } + public List getNonVectorFields() { + List fields = new ArrayList<>(); + fields.add(keyField); + fields.addAll(dataFields); + return fields; + } + + private enum DeclaredFieldType { + KEY, DATA, VECTOR + } + + private List getDeclaredFields(Class recordClass, List fields, + DeclaredFieldType fieldType) { + List declaredFields = new ArrayList<>(); + for (VectorStoreRecordField field : fields) { + try { + Field declaredField = recordClass.getDeclaredField(field.getName()); + declaredFields.add(declaredField); + } catch (NoSuchFieldException e) { + throw new IllegalArgumentException( + String.format("%s field not found in record class: %s", fieldType, + field.getName())); + } + } + return declaredFields; + } + + public Field getKeyDeclaredField(Class recordClass) { + try { + return recordClass.getDeclaredField(keyField.getName()); + } catch (NoSuchFieldException e) { + throw new IllegalArgumentException( + "Key field not found in record class: " + keyField.getName()); + } + } + + public List getDataDeclaredFields(Class recordClass) { + return getDeclaredFields( + recordClass, + dataFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), + DeclaredFieldType.DATA); + } + + public List getVectorDeclaredFields(Class recordClass) { + return getDeclaredFields( + recordClass, + vectorFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), + DeclaredFieldType.VECTOR); + } + private VectorStoreRecordDefinition( VectorStoreRecordKeyField keyField, List dataFields, @@ -148,71 +197,20 @@ public static VectorStoreRecordDefinition fromRecordClass(Class recordClass) return checkFields(keyFields, dataFields, vectorFields); } - private static String getSupportedTypesString(@Nullable HashSet> types) { - if (types == null || types.isEmpty()) { - return ""; - } - return types.stream().map(Class::getName).collect(Collectors.joining(", ")); - } - - public static void validateSupportedKeyTypes(@Nonnull Class recordClass, - @Nonnull VectorStoreRecordDefinition recordDefinition, - @Nonnull HashSet> supportedTypes) { - String supportedTypesString = getSupportedTypesString(supportedTypes); - - try { - Field declaredField = recordClass.getDeclaredField(recordDefinition.keyField.getName()); - + public static void validateSupportedTypes(List declaredFields, + Set> supportedTypes) { + Set> unsupportedTypes = new HashSet<>(); + for (Field declaredField : declaredFields) { if (!supportedTypes.contains(declaredField.getType())) { - throw new IllegalArgumentException( - "Unsupported key field type: " + declaredField.getType().getName() - + ". Supported types are: " + supportedTypesString); + unsupportedTypes.add(declaredField.getType()); } - } catch (NoSuchFieldException e) { - throw new IllegalArgumentException( - "Key field not found in record class: " + recordDefinition.keyField.getName()); } - } - - public static void validateSupportedDataTypes(@Nonnull Class recordClass, - @Nonnull VectorStoreRecordDefinition recordDefinition, - @Nonnull HashSet> supportedTypes) { - String supportedTypesString = getSupportedTypesString(supportedTypes); - - for (VectorStoreRecordDataField field : recordDefinition.dataFields) { - try { - Field declaredField = recordClass.getDeclaredField(field.getName()); - - if (!supportedTypes.contains(declaredField.getType())) { - throw new IllegalArgumentException( - "Unsupported data field type: " + declaredField.getType().getName() - + ". Supported types are: " + supportedTypesString); - } - } catch (NoSuchFieldException e) { - throw new IllegalArgumentException( - "Data field not found in record class: " + field.getName()); - } - } - } - - public static void validateSupportedVectorTypes(@Nonnull Class recordClass, - @Nonnull VectorStoreRecordDefinition recordDefinition, - @Nonnull HashSet> supportedTypes) { - String supportedTypesString = getSupportedTypesString(supportedTypes); - - for (VectorStoreRecordVectorField field : recordDefinition.vectorFields) { - try { - Field declaredField = recordClass.getDeclaredField(field.getName()); - - if (!supportedTypes.contains(declaredField.getType())) { - throw new IllegalArgumentException( - "Unsupported vector field type: " + declaredField.getType().getName() - + ". Supported types are: " + supportedTypesString); - } - } catch (NoSuchFieldException e) { - throw new IllegalArgumentException( - "Vector field not found in record class: " + field.getName()); - } + if (!unsupportedTypes.isEmpty()) { + throw new IllegalArgumentException( + String.format( + "Unsupported field types found in record class: %s. Supported types: %s", + unsupportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")), + supportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")))); } } }