Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
/**
* Provides OpenAi implementation of audio to text service.
*/
public class OpenAiAudioToTextService extends OpenAiService<OpenAIAsyncClient> implements AudioToTextService {
public class OpenAiAudioToTextService extends OpenAiService<OpenAIAsyncClient>
implements AudioToTextService {

private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiAudioToTextService.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
/**
* Provides OpenAi implementation of text to audio service.
*/
public class OpenAiTextToAudioService extends OpenAiService<OpenAIAsyncClient> implements TextToAudioService {
public class OpenAiTextToAudioService extends OpenAiService<OpenAIAsyncClient>
implements TextToAudioService {

private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiTextToAudioService.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@
/**
* OpenAI chat completion service.
*/
public class OpenAIChatCompletion extends OpenAiService<OpenAIAsyncClient> implements ChatCompletionService {
public class OpenAIChatCompletion extends OpenAiService<OpenAIAsyncClient>
implements ChatCompletionService {

private static final Logger LOGGER = LoggerFactory.getLogger(OpenAIChatCompletion.class);

Expand Down Expand Up @@ -1055,7 +1056,8 @@ static ChatRequestMessage getChatRequestMessage(
/**
* Builder for creating a new instance of {@link OpenAIChatCompletion}.
*/
public static class Builder extends OpenAiServiceBuilder<OpenAIAsyncClient, OpenAIChatCompletion, Builder> {
public static class Builder
extends OpenAiServiceBuilder<OpenAIAsyncClient, OpenAIChatCompletion, Builder> {

@Override
public OpenAIChatCompletion build() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
/**
* An OpenAI implementation of a {@link TextGenerationService}.
*/
public class OpenAITextGenerationService extends OpenAiService<OpenAIAsyncClient> implements TextGenerationService {
public class OpenAITextGenerationService extends OpenAiService<OpenAIAsyncClient>
implements TextGenerationService {

private static final Logger LOGGER = LoggerFactory.getLogger(OpenAITextGenerationService.class);

Expand Down
6 changes: 3 additions & 3 deletions api-test/integration-tests/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@
<version>3.44.1.0</version>
</dependency>
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
<version>8.2.0</version>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.33</version>
<scope>test</scope>
</dependency>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;

import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollection;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.mysql.cj.jdbc.MysqlDataSource;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.MySQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import javax.annotation.Nonnull;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;

@Testcontainers
public class JDBCVectorStoreRecordCollectionTest {
@Container
private static final MySQLContainer<?> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
private static final String MYSQL_USER = "test";
private static final String MYSQL_PASSWORD = "test";
private static MysqlDataSource dataSource;
@BeforeAll
static void setup() {
dataSource = new MysqlDataSource();
dataSource.setUrl(CONTAINER.getJdbcUrl());
dataSource.setUser(MYSQL_USER);
dataSource.setPassword(MYSQL_PASSWORD);
}

private JDBCVectorStoreRecordCollection<Hotel> buildRecordCollection(@Nonnull String collectionName) {
JDBCVectorStoreRecordCollection<Hotel> recordCollection = new JDBCVectorStoreRecordCollection<>(
dataSource,
collectionName,
JDBCVectorStoreRecordCollectionOptions.<Hotel>builder()
.withRecordClass(Hotel.class)
.withQueryProvider(MySQLVectorStoreQueryProvider.builder()
.withDataSource(dataSource)
.build())
.build());

recordCollection.prepareAsync().block();
recordCollection.createCollectionIfNotExistsAsync().block();
return recordCollection;
}

@Test
public void buildRecordCollection() {
assertNotNull(buildRecordCollection("buildTest"));
}

private List<Hotel> getHotels() {
return List.of(
new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0),
new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f), 3.0),
new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0),
new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0),
new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0)
);
}

@Test
public void upsertAndGetRecordAsync() {
String collectionName = "upsertAndGetRecordAsync";
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

List<Hotel> hotels = getHotels();
for (Hotel hotel : hotels) {
recordStore.upsertAsync(hotel, null).block();
}

for (Hotel hotel : hotels) {
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
assertNotNull(retrievedHotel);
assertEquals(hotel.getId(), retrievedHotel.getId());
}
}

@Test
public void getBatchAsync() {
String collectionName = "getBatchAsync";
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

List<Hotel> hotels = getHotels();
for (Hotel hotel : hotels) {
recordStore.upsertAsync(hotel, null).block();
}

List<String> keys = new ArrayList<>();
for (Hotel hotel : hotels) {
keys.add(hotel.getId());
}

List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, null).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
}

@Test
public void upsertBatchAndGetBatchAsync() {
String collectionName = "upsertBatchAndGetBatchAsync";
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();

List<String> keys = new ArrayList<>();
for (Hotel hotel : hotels) {
keys.add(hotel.getId());
}

List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, null).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
}

@Test
public void insertAndReplaceAsync() {
String collectionName = "insertAndReplaceAsync";
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();
recordStore.upsertBatchAsync(hotels, null).block();
recordStore.upsertBatchAsync(hotels, null).block();

List<String> keys = new ArrayList<>();
for (Hotel hotel : hotels) {
keys.add(hotel.getId());
}

List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, null).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
}

@Test
public void deleteRecordAsync() {
String collectionName = "deleteRecordAsync";
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();

for (Hotel hotel : hotels) {
recordStore.deleteAsync(hotel.getId(), null).block();
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
assertNull(retrievedHotel);
}
}

@Test
public void deleteBatchAsync() {
String collectionName = "deleteBatchAsync";
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();

List<String> keys = new ArrayList<>();
for (Hotel hotel : hotels) {
keys.add(hotel.getId());
}

recordStore.deleteBatchAsync(keys, null).block();

for (String key : keys) {
Hotel retrievedHotel = recordStore.getAsync(key, null).block();
assertNull(retrievedHotel);
}
}

@Test
public void getWithNoVectors() {
String collectionName = "getWithNoVectors";
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();

GetRecordOptions options = GetRecordOptions.builder()
.includeVectors(false)
.build();

for (Hotel hotel : hotels) {
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block();
assertNotNull(retrievedHotel);
assertEquals(hotel.getId(), retrievedHotel.getId());
assertNull(retrievedHotel.getDescriptionEmbedding());
}

options = GetRecordOptions.builder()
.includeVectors(true)
.build();

for (Hotel hotel : hotels) {
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), options).block();
assertNotNull(retrievedHotel);
assertEquals(hotel.getId(), retrievedHotel.getId());
assertNotNull(retrievedHotel.getDescriptionEmbedding());
}
}

@Test
public void getBatchWithNoVectors() {
String collectionName = "getBatchWithNoVectors";
JDBCVectorStoreRecordCollection<Hotel> recordStore = buildRecordCollection(collectionName);

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();

GetRecordOptions options = GetRecordOptions.builder()
.includeVectors(false)
.build();

List<String> keys = new ArrayList<>();
for (Hotel hotel : hotels) {
keys.add(hotel.getId());
}

List<Hotel> retrievedHotels = recordStore.getBatchAsync(keys, options).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());

for (Hotel hotel : retrievedHotels) {
assertNull(hotel.getDescriptionEmbedding());
}

options = GetRecordOptions.builder()
.includeVectors(true)
.build();

retrievedHotels = recordStore.getBatchAsync(keys, options).block();
assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());

for (Hotel hotel : retrievedHotels) {
assertNotNull(hotel.getDescriptionEmbedding());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;

import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.mysql.cj.jdbc.MysqlDataSource;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.MySQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

@Testcontainers
public class JDBCVectorStoreTest {
@Container
private static final MySQLContainer<?> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
private static final String MYSQL_USER = "test";
private static final String MYSQL_PASSWORD = "test";
private static MysqlDataSource dataSource;

@BeforeAll
static void setup() {
dataSource = new MysqlDataSource();
dataSource.setUrl(CONTAINER.getJdbcUrl());
dataSource.setUser(MYSQL_USER);
dataSource.setPassword(MYSQL_PASSWORD);
}

@Test
public void getCollectionNamesAsync() {
MySQLVectorStoreQueryProvider queryProvider = MySQLVectorStoreQueryProvider.builder()
.withDataSource(dataSource)
.build();

JDBCVectorStore vectorStore = JDBCVectorStore.builder()
.withDataSource(dataSource)
.withOptions(
JDBCVectorStoreOptions.builder()
.withQueryProvider(queryProvider)
.build()
)
.build();

vectorStore.getCollectionNamesAsync().block();

List<String> collectionNames = Arrays.asList("collection1", "collection2", "collection3");

for (String collectionName : collectionNames) {
vectorStore.getCollection(collectionName, Hotel.class, null).createCollectionAsync().block();
}

List<String> retrievedCollectionNames = vectorStore.getCollectionNamesAsync().block();
assertNotNull(retrievedCollectionNames);
assertEquals(collectionNames.size(), retrievedCollectionNames.size());
for (String collectionName : collectionNames) {
assertTrue(retrievedCollectionNames.contains(collectionName));
}
}
}
Loading