diff --git a/src/it/java/io/weaviate/containers/Weaviate.java b/src/it/java/io/weaviate/containers/Weaviate.java index bcc1ba7d4..d08d8a202 100644 --- a/src/it/java/io/weaviate/containers/Weaviate.java +++ b/src/it/java/io/weaviate/containers/Weaviate.java @@ -14,7 +14,7 @@ public class Weaviate extends WeaviateContainer { private WeaviateClient clientInstance; - public static final String VERSION = "1.29.0"; + public static final String VERSION = "1.29.1"; public static final String DOCKER_IMAGE = "semitechnologies/weaviate"; /** diff --git a/src/it/java/io/weaviate/integration/CollectionsITest.java b/src/it/java/io/weaviate/integration/CollectionsITest.java index be270a42f..06c70b1e5 100644 --- a/src/it/java/io/weaviate/integration/CollectionsITest.java +++ b/src/it/java/io/weaviate/integration/CollectionsITest.java @@ -8,9 +8,11 @@ import io.weaviate.ConcurrentTest; import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.api.collections.CollectionConfig; +import io.weaviate.client6.v1.api.collections.InvertedIndex; import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.Replication; import io.weaviate.client6.v1.api.collections.VectorIndex; -import io.weaviate.client6.v1.api.collections.WeaviateCollection; import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.containers.Container; @@ -29,8 +31,8 @@ public void testCreateGetDelete() throws IOException { var thingsCollection = client.collections.getConfig(collectionName); Assertions.assertThat(thingsCollection).get() - .hasFieldOrPropertyWithValue("name", collectionName) - .extracting(WeaviateCollection::vectors, InstanceOfAssertFactories.map(String.class, VectorIndex.class)) + .hasFieldOrPropertyWithValue("collectionName", collectionName) + .extracting(CollectionConfig::vectors, InstanceOfAssertFactories.map(String.class, VectorIndex.class)) .as("default vector").extractingByKey("default") .satisfies(defaultVector -> { Assertions.assertThat(defaultVector).extracting(VectorIndex::vectorizer) @@ -61,7 +63,7 @@ public void testCrossReferences() throws IOException { .as("after create Things").get() .satisfies(c -> { Assertions.assertThat(c.references()) - .as("ownedBy").filteredOn(p -> p.name().equals("ownedBy")).first() + .as("ownedBy").filteredOn(p -> p.propertyName().equals("ownedBy")).first() .extracting(p -> p.dataTypes(), InstanceOfAssertFactories.LIST) .containsOnly(nsOwners); }); @@ -81,7 +83,7 @@ public void testCrossReferences() throws IOException { .as("after add property").get() .satisfies(c -> { Assertions.assertThat(c.references()) - .as("soldIn").filteredOn(p -> p.name().equals("soldIn")).first() + .as("soldIn").filteredOn(p -> p.propertyName().equals("soldIn")).first() .extracting(p -> p.dataTypes(), InstanceOfAssertFactories.LIST) .containsOnly(nsOnlineStores, nsMarkets); }); @@ -105,7 +107,7 @@ public void testListDeleteAll() throws IOException { var all = client.collections.list(); Assertions.assertThat(all) .hasSizeGreaterThanOrEqualTo(3) - .extracting(WeaviateCollection::name) + .extracting(CollectionConfig::collectionName) .contains(nsA, nsB, nsC); client.collections.deleteAll(); @@ -114,4 +116,47 @@ public void testListDeleteAll() throws IOException { Assertions.assertThat(all.isEmpty()); } + + @Test + public void testUpdateCollection() throws IOException { + var nsBoxes = ns("Boxes"); + var nsThings = ns("Things"); + + client.collections.create(nsBoxes); + + client.collections.create(nsThings, + collection -> collection + .description("Things stored in boxes") + .properties( + Property.text("name"), + Property.integer("width", + w -> w.description("how wide this thing is"))) + .invertedIndex(idx -> idx.cleanupIntervalSeconds(10)) + .replication(repl -> repl.asyncEnabled(true))); + + var things = client.collections.use(nsThings); + + // Act + things.config.update(nsThings, collection -> collection + .description("Things stored on shelves") + .propertyDescription("width", "not height") + .invertedIndex(idx -> idx.cleanupIntervalSeconds(30)) + .replication(repl -> repl.asyncEnabled(false))); + + // Assert + var updated = things.config.get(); + Assertions.assertThat(updated).get() + .returns("Things stored on shelves", CollectionConfig::description) + .satisfies(collection -> { + Assertions.assertThat(collection) + .extracting(CollectionConfig::properties, InstanceOfAssertFactories.list(Property.class)) + .extracting(Property::description).contains("not height"); + + Assertions.assertThat(collection) + .extracting(CollectionConfig::invertedIndex).returns(30, InvertedIndex::cleanupIntervalSeconds); + + Assertions.assertThat(collection) + .extracting(CollectionConfig::replication).returns(false, Replication::asyncEnabled); + }); + } } diff --git a/src/it/java/io/weaviate/integration/ReferencesITest.java b/src/it/java/io/weaviate/integration/ReferencesITest.java index 68aa7b623..8ee613b4f 100644 --- a/src/it/java/io/weaviate/integration/ReferencesITest.java +++ b/src/it/java/io/weaviate/integration/ReferencesITest.java @@ -61,7 +61,7 @@ public void testReferences() throws IOException { .as("Artists: create collection") .extracting(c -> c.references().stream().findFirst()) .as("has one reference property").extracting(Optional::get) - .returns("hasAwards", ReferenceProperty::name) + .returns("hasAwards", ReferenceProperty::propertyName) .extracting(ReferenceProperty::dataTypes, InstanceOfAssertFactories.list(String.class)) .containsOnly(nsGrammy, nsOscar); @@ -87,7 +87,7 @@ public void testReferences() throws IOException { Assertions.assertThat(collectionArtists).get() .as("Artists: add reference to Movies") .extracting(c -> c.references().stream() - .filter(property -> property.name().equals("featuredIn")).findFirst()) + .filter(property -> property.propertyName().equals("featuredIn")).findFirst()) .as("featuredIn reference property").extracting(Optional::get) .extracting(ReferenceProperty::dataTypes, InstanceOfAssertFactories.list(String.class)) .containsOnly(nsMovies); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java new file mode 100644 index 000000000..39820b34f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java @@ -0,0 +1,321 @@ +package io.weaviate.client6.v1.api.collections; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.annotations.SerializedName; +import com.google.gson.internal.Streams; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record CollectionConfig( + @SerializedName("class") String collectionName, + @SerializedName("description") String description, + @SerializedName("properties") List properties, + List references, + @SerializedName("vectorConfig") Map vectors, + @SerializedName("multiTenancyConfig") MultiTenancy multiTenancy, + @SerializedName("shardingConfig") Sharding sharding, + @SerializedName("replicationConfig") Replication replication, + @SerializedName("invertedIndexConfig") InvertedIndex invertedIndex, + List rerankerModules, + Generative generativeModule) { + + public static CollectionConfig of(String collectionName) { + return of(collectionName, ObjectBuilder.identity()); + } + + public static CollectionConfig of(String collectionName, Function> fn) { + return fn.apply(new Builder(collectionName)).build(); + } + + /** + * Returns a {@link Builder} with all current values of + * {@code WeaviateCollection} pre-filled. + */ + public Builder edit() { + return new Builder(collectionName) + .description(description) + .properties(properties) + .references(references) + .vectors(vectors) + .multiTenancy(multiTenancy) + .sharding(sharding) + .replication(replication) + .invertedIndex(invertedIndex) + .rerankerModules(rerankerModules != null ? rerankerModules : new ArrayList<>()) + .generativeModule(generativeModule); + } + + /** Create a copy of this {@code WeaviateCollection} and edit parts of it. */ + public CollectionConfig edit(Function> fn) { + return fn.apply(edit()).build(); + } + + public CollectionConfig(Builder builder) { + this( + builder.collectionName, + builder.description, + builder.propertyList(), + builder.referenceList(), + builder.vectors, + builder.multiTenancy, + builder.sharding, + builder.replication, + builder.invertedIndex, + builder.rerankerModules, + builder.generativeModule); + } + + public static class Builder implements ObjectBuilder { + // Required parameters; + private final String collectionName; + + private String description; + private Map properties = new HashMap<>(); + private Map references = new HashMap<>(); + private Map vectors = new HashMap<>(); + private MultiTenancy multiTenancy; + private Sharding sharding; + private Replication replication; + private InvertedIndex invertedIndex; + private List rerankerModules = new ArrayList<>(); + private Generative generativeModule; + + public Builder(String collectionName) { + this.collectionName = collectionName; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder properties(Property... properties) { + return properties(Arrays.asList(properties)); + } + + public Builder properties(List properties) { + properties.forEach(property -> this.properties.put(property.propertyName(), property)); + return this; + } + + private List propertyList() { + return this.properties.values().stream().toList(); + } + + public Builder references(ReferenceProperty... references) { + return references(Arrays.asList(references)); + } + + public Builder references(List references) { + references.forEach(reference -> this.references.put(reference.propertyName(), reference)); + return this; + } + + private List referenceList() { + return this.references.values().stream().toList(); + } + + public Builder vector(VectorIndex vector) { + this.vectors.put(VectorIndex.DEFAULT_VECTOR_NAME, vector); + return this; + } + + public Builder vector(String name, VectorIndex vector) { + this.vectors.put(name, vector); + return this; + } + + public Builder vectors(Map vectors) { + this.vectors.putAll(vectors); + return this; + } + + public Builder vectors(Function>> fn) { + this.vectors = fn.apply(new VectorsBuilder()).build(); + return this; + } + + public static class VectorsBuilder implements ObjectBuilder> { + private Map vectors = new HashMap<>(); + + public VectorsBuilder vector(String name, VectorIndex vector) { + vectors.put(name, vector); + return this; + } + + @Override + public Map build() { + return this.vectors; + } + } + + public Builder sharding(Sharding sharding) { + this.sharding = sharding; + return this; + } + + public Builder sharding(Function> fn) { + this.sharding = Sharding.of(fn); + return this; + } + + public Builder multiTenancy(MultiTenancy multiTenancy) { + this.multiTenancy = multiTenancy; + return this; + } + + public Builder multiTenancy(Function> fn) { + this.multiTenancy = MultiTenancy.of(fn); + return this; + } + + public Builder replication(Replication replication) { + this.replication = replication; + return this; + } + + public Builder replication(Function> fn) { + this.replication = Replication.of(fn); + return this; + } + + public Builder invertedIndex(InvertedIndex invertedIndex) { + this.invertedIndex = invertedIndex; + return this; + } + + public Builder invertedIndex(Function> fn) { + this.invertedIndex = InvertedIndex.of(fn); + return this; + } + + public Builder rerankerModules(Reranker... rerankerModules) { + return rerankerModules(Arrays.asList(rerankerModules)); + } + + public Builder rerankerModules(List rerankerModules) { + this.rerankerModules.addAll(rerankerModules); + return this; + } + + public Builder generativeModule(Generative generativeModule) { + this.generativeModule = generativeModule; + return this; + } + + @Override + public CollectionConfig build() { + return new CollectionConfig(this); + } + } + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + if (type.getRawType() != CollectionConfig.class) { + return null; + } + + final var delegate = gson.getDelegateAdapter(this, (TypeToken) type); + return (TypeAdapter) new TypeAdapter() { + + @Override + public void write(JsonWriter out, CollectionConfig value) throws IOException { + var jsonObject = delegate.toJsonTree(value).getAsJsonObject(); + + // References must be merged with properties. + var references = jsonObject.remove("references").getAsJsonArray(); + var properties = jsonObject.get("properties").getAsJsonArray(); + properties.addAll(references); + + // Reranker and Generative module configs belong to the "moduleConfig". + var rerankerModules = jsonObject.remove("rerankerModules").getAsJsonArray(); + var generativeModule = jsonObject.remove("generativeModule"); + if (!rerankerModules.isEmpty() || !generativeModule.isJsonNull()) { + var modules = new JsonObject(); + + // Copy configuration for each reranker module. + rerankerModules.forEach(reranker -> { + reranker.getAsJsonObject().entrySet() + .stream().forEach(entry -> modules.add(entry.getKey(), entry.getValue())); + }); + + // Copy configuration for each generative module. + generativeModule.getAsJsonObject().entrySet() + .stream().forEach(entry -> modules.add(entry.getKey(), entry.getValue())); + + jsonObject.add("moduleConfig", modules); + } + + Streams.write(jsonObject, out); + } + + @Override + public CollectionConfig read(JsonReader in) throws IOException { + var jsonObject = JsonParser.parseReader(in).getAsJsonObject(); + + var mixedProperties = jsonObject.get("properties").getAsJsonArray(); + var references = new JsonArray(); + var properties = new JsonArray(); + + for (var property : mixedProperties) { + var dataTypes = property.getAsJsonObject().get("dataType").getAsJsonArray(); + if (dataTypes.size() == 1 && DataType.KNOWN_TYPES.contains(dataTypes.get(0).getAsString())) { + properties.add(property); + } else { + references.add(property); + } + } + + jsonObject.add("properties", properties); + jsonObject.add("references", references); + + if (!jsonObject.has("vectorConfig")) { + jsonObject.add("vectorConfig", new JsonObject()); + } + + // Separate modules into reranker- and generative- modules. + var rerankerModules = new JsonArray(); + if (jsonObject.has("moduleConfig")) { + var moduleConfig = jsonObject.remove("moduleConfig").getAsJsonObject(); + + moduleConfig.entrySet().stream() + .forEach(entry -> { + var module = new JsonObject(); + var name = entry.getKey(); + module.add(name, entry.getValue()); + + if (name.startsWith("reranker-")) { + rerankerModules.add(module); + } else if (name.startsWith("generative-")) { + jsonObject.add("generativeModule", module); + } + }); + } + jsonObject.add("rerankerModules", rerankerModules); + + return delegate.fromJsonTree(jsonObject); + } + }.nullSafe(); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CreateCollectionRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/CreateCollectionRequest.java index 25fe319ef..4573a4d06 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CreateCollectionRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CreateCollectionRequest.java @@ -7,12 +7,12 @@ import io.weaviate.client6.v1.internal.json.JSON; import io.weaviate.client6.v1.internal.rest.Endpoint; -public record CreateCollectionRequest(WeaviateCollection collection) { - public static final Endpoint _ENDPOINT = Endpoint.of( +public record CreateCollectionRequest(CollectionConfig collection) { + public static final Endpoint _ENDPOINT = Endpoint.of( request -> "POST", request -> "/schema/", (gson, request) -> JSON.serialize(request.collection), request -> Collections.emptyMap(), code -> code != HttpStatus.SC_SUCCESS, - (gson, response) -> JSON.deserialize(response, WeaviateCollection.class)); + (gson, response) -> JSON.deserialize(response, CollectionConfig.class)); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java new file mode 100644 index 000000000..2354d5b7a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -0,0 +1,110 @@ +package io.weaviate.client6.v1.api.collections; + +import java.io.IOException; +import java.util.EnumMap; +import java.util.Map; +import java.util.function.Function; + +import com.google.gson.Gson; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonToken; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.json.JsonEnum; + +public interface Generative { + public enum Kind implements JsonEnum { + COHERE("generative-cohere"); + + private static final Map jsonValueMap = JsonEnum.collectNames(Kind.values()); + private final String jsonValue; + + private Kind(String jsonValue) { + this.jsonValue = jsonValue; + } + + @Override + public String jsonValue() { + return this.jsonValue; + } + + public static Kind valueOfJson(String jsonValue) { + return JsonEnum.valueOfJson(jsonValue, jsonValueMap, Kind.class); + } + } + + Kind _kind(); + + Object _self(); + + public static Generative cohere() { + return CohereGenerative.of(); + } + + public static Generative cohere(Function> fn) { + return CohereGenerative.of(fn); + } + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + private static final EnumMap> readAdapters = new EnumMap<>( + Generative.Kind.class); + + private final void addAdapter(Gson gson, Generative.Kind kind, Class cls) { + readAdapters.put(kind, (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(cls))); + } + + private final void init(Gson gson) { + addAdapter(gson, Generative.Kind.COHERE, CohereGenerative.class); + } + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + var rawType = type.getRawType(); + if (!Generative.class.isAssignableFrom(rawType)) { + return null; + } + + if (readAdapters.isEmpty()) { + init(gson); + } + + final TypeAdapter writeAdapter = (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(rawType)); + return (TypeAdapter) new TypeAdapter() { + + @Override + public void write(JsonWriter out, Generative value) throws IOException { + out.beginObject(); + out.name(value._kind().jsonValue()); + writeAdapter.write(out, (T) value._self()); + out.endObject(); + } + + @Override + public Generative read(JsonReader in) throws IOException { + in.beginObject(); + var moduleName = in.nextName(); + try { + var kind = Generative.Kind.valueOfJson(moduleName); + var adapter = readAdapters.get(kind); + return adapter.read(in); + } catch (IllegalArgumentException e) { + return null; + } finally { + if (in.peek() == JsonToken.BEGIN_OBJECT) { + in.beginObject(); + } + in.endObject(); + } + } + }.nullSafe(); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/GetConfigRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/GetConfigRequest.java index 2027428ce..39914ec99 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/GetConfigRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/GetConfigRequest.java @@ -9,11 +9,11 @@ import io.weaviate.client6.v1.internal.rest.Endpoint; public record GetConfigRequest(String collectionName) { - public static final Endpoint> _ENDPOINT = Endpoint.of( + public static final Endpoint> _ENDPOINT = Endpoint.of( request -> "GET", request -> "/schema/" + request.collectionName, (gson, request) -> null, request -> Collections.emptyMap(), code -> code != HttpStatus.SC_SUCCESS, - (gson, response) -> Optional.ofNullable(JSON.deserialize(response, WeaviateCollection.class))); + (gson, response) -> Optional.ofNullable(JSON.deserialize(response, CollectionConfig.class))); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/InvertedIndex.java b/src/main/java/io/weaviate/client6/v1/api/collections/InvertedIndex.java new file mode 100644 index 000000000..3787c1e35 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/InvertedIndex.java @@ -0,0 +1,165 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record InvertedIndex( + @SerializedName("cleanupIntervalSeconds") Integer cleanupIntervalSeconds, + @SerializedName("bm25") Bm25 bm25, + @SerializedName("stopwords") Stopwords stopwords, + @SerializedName("indexTimestamps") Boolean indexTimestamps, + @SerializedName("indexNullState") Boolean indexNulls, + @SerializedName("indexPropertyLength") Boolean indexPropertyLength, + @SerializedName("usingBlockMaxWAND") Boolean useBlockMaxWAND) { + + public static InvertedIndex of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public record Bm25( + @SerializedName("b") Integer b, + @SerializedName("k1") Integer k1) { + + public static Bm25 of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public Bm25(Builder builder) { + this(builder.b, builder.k1); + } + + public static class Builder implements ObjectBuilder { + private Integer b; + private Integer k1; + + public Builder b(int b) { + this.b = b; + return this; + } + + public Builder k1(int k1) { + this.k1 = k1; + return this; + } + + @Override + public Bm25 build() { + return new Bm25(this); + } + } + } + + public record Stopwords( + @SerializedName("preset") String preset, + @SerializedName("additions") List additions, + @SerializedName("removals") List removals) { + + public static Stopwords of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public Stopwords(Builder builder) { + this(builder.preset, builder.additions, builder.removals); + } + + public static class Builder implements ObjectBuilder { + private String preset; + private List additions; + private List removals; + + public Builder preset(String preset) { + this.preset = preset; + return this; + } + + public Builder add(String... additions) { + return add(Arrays.asList(additions)); + } + + public Builder add(List additions) { + this.additions.addAll(additions); + return this; + } + + public Builder remove(String... removals) { + return remove(Arrays.asList(removals)); + } + + public Builder remove(List removals) { + this.removals.addAll(removals); + return this; + } + + @Override + public Stopwords build() { + return new Stopwords(this); + } + } + } + + public InvertedIndex(Builder builder) { + this( + builder.cleanupIntervalSeconds, + builder.bm25, + builder.stopwords, + builder.indexTimestamps, + builder.indexNulls, + builder.indexPropertyLength, + builder.useBlockMaxWAND); + } + + public static class Builder implements ObjectBuilder { + private Integer cleanupIntervalSeconds; + private Bm25 bm25; + private Stopwords stopwords; + private Boolean indexTimestamps; + private Boolean indexNulls; + private Boolean indexPropertyLength; + private Boolean useBlockMaxWAND; + + public Builder cleanupIntervalSeconds(int cleanupIntervalSeconds) { + this.cleanupIntervalSeconds = cleanupIntervalSeconds; + return this; + } + + public Builder bm25(Function> fn) { + this.bm25 = Bm25.of(fn); + return this; + } + + public Builder stopwords(Function> fn) { + this.stopwords = Stopwords.of(fn); + return this; + } + + public Builder indexTimestamps(Boolean indexTimestamps) { + this.indexTimestamps = indexTimestamps; + return this; + } + + public Builder indexNulls(Boolean indexNulls) { + this.indexNulls = indexNulls; + return this; + } + + public Builder indexPropertyLength(Boolean indexPropertyLength) { + this.indexPropertyLength = indexPropertyLength; + return this; + } + + public Builder useBlockMaxWAND(Boolean useBlockMaxWAND) { + this.useBlockMaxWAND = useBlockMaxWAND; + return this; + } + + @Override + public InvertedIndex build() { + return new InvertedIndex(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ListCollectionRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/ListCollectionRequest.java index 2c7535743..46652fe3a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/ListCollectionRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/ListCollectionRequest.java @@ -5,13 +5,11 @@ import org.apache.hc.core5.http.HttpStatus; -import com.google.gson.reflect.TypeToken; - import io.weaviate.client6.v1.internal.json.JSON; import io.weaviate.client6.v1.internal.rest.Endpoint; public record ListCollectionRequest() { - public static final Endpoint> _ENDPOINT = Endpoint.of( + public static final Endpoint> _ENDPOINT = Endpoint.of( request -> "GET", request -> "/schema", (gson, request) -> null, diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ListCollectionResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/ListCollectionResponse.java index 4523db1fe..5044431e1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/ListCollectionResponse.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/ListCollectionResponse.java @@ -4,5 +4,5 @@ import com.google.gson.annotations.SerializedName; -public record ListCollectionResponse(@SerializedName("classes") List collections) { +public record ListCollectionResponse(@SerializedName("classes") List collections) { } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/MultiTenancy.java b/src/main/java/io/weaviate/client6/v1/api/collections/MultiTenancy.java new file mode 100644 index 000000000..8365c95b7 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/MultiTenancy.java @@ -0,0 +1,42 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record MultiTenancy( + @SerializedName("autoTenantCreation") Boolean createAutomatically, + @SerializedName("autoTenantActivate") Boolean activateAutomatically) { + + public static MultiTenancy of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public MultiTenancy(Builder builder) { + this( + builder.createAutomatically, + builder.activateAutomatically); + } + + public static class Builder implements ObjectBuilder { + private Boolean createAutomatically; + private Boolean activateAutomatically; + + public Builder createAutomatically(boolean createAutomatically) { + this.createAutomatically = createAutomatically; + return this; + } + + public Builder activateAutomatically(boolean activateAutomatically) { + this.activateAutomatically = activateAutomatically; + return this; + } + + @Override + public MultiTenancy build() { + return new MultiTenancy(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Property.java b/src/main/java/io/weaviate/client6/v1/api/collections/Property.java index fb1e636b7..08e638032 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Property.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Property.java @@ -2,30 +2,45 @@ import java.util.Arrays; import java.util.List; +import java.util.function.Function; import com.google.gson.annotations.SerializedName; +import io.weaviate.client6.v1.internal.ObjectBuilder; + public record Property( - @SerializedName("name") String name, - @SerializedName("dataType") List dataTypes) { + @SerializedName("name") String propertyName, + @SerializedName("dataType") List dataTypes, + @SerializedName("description") String description, + @SerializedName("indexInverted") Boolean indexInverted, + @SerializedName("indexFilterable") Boolean indexFilterable, + @SerializedName("indexRangeFilters") Boolean indexRangeFilters, + @SerializedName("indexSearchable") Boolean indexSearchable, + @SerializedName("skipVectorization") Boolean skipVectorization, + @SerializedName("vectorizePropertyName") Boolean vectorizePropertyName) { - public Property(String name, String dataType) { - this(name, List.of(dataType)); + public static Property text(String name) { + return text(name, ObjectBuilder.identity()); } - /** Add text property with default configuration. */ - public static Property text(String name) { - return new Property(name, DataType.TEXT); + public static Property text(String name, Function> fn) { + return fn.apply(new Builder(name, DataType.TEXT)).build(); } - /** Add integer property with default configuration. */ public static Property integer(String name) { - return new Property(name, DataType.INT); + return integer(name, ObjectBuilder.identity()); + } + + public static Property integer(String name, Function> fn) { + return fn.apply(new Builder(name, DataType.INT)).build(); } - /** Add blob property with default configuration. */ public static Property blob(String name) { - return new Property(name, DataType.BLOB); + return blob(name, ObjectBuilder.identity()); + } + + public static Property blob(String name, Function> fn) { + return fn.apply(new Builder(name, DataType.BLOB)).build(); } public static ReferenceProperty reference(String name, String... collections) { @@ -35,4 +50,106 @@ public static ReferenceProperty reference(String name, String... collections) { public static ReferenceProperty reference(String name, List collections) { return new ReferenceProperty(name, collections); } + + public Builder edit() { + return new Builder(propertyName, dataTypes) + .description(description) + .indexInverted(indexInverted) + .indexFilterable(indexFilterable) + .indexRangeFilters(indexRangeFilters) + .indexSearchable(indexSearchable) + .skipVectorization(skipVectorization) + .vectorizePropertyName(vectorizePropertyName); + } + + public Property edit(Function> fn) { + return fn.apply(edit()).build(); + } + + public Property(Builder builder) { + this( + builder.propertyName, + builder.dataTypes, + builder.description, + builder.indexInverted, + builder.indexFilterable, + builder.indexRangeFilters, + builder.indexSearchable, + builder.skipVectorization, + builder.vectorizePropertyName); + } + + public static class Builder implements ObjectBuilder { + // Required parameters. + private final String propertyName; + + // Optional parameters. + private List dataTypes; + private String description; + private Boolean indexInverted; + private Boolean indexFilterable; + private Boolean indexRangeFilters; + private Boolean indexSearchable; + private Boolean skipVectorization; + private Boolean vectorizePropertyName; + + public Builder(String propertyName, String dataType) { + this.propertyName = propertyName; + this.dataTypes = List.of(dataType); + } + + public Builder(String propertyName, String... dataTypes) { + this(propertyName, Arrays.asList(dataTypes)); + } + + public Builder(String propertyName, List dataTypes) { + this.propertyName = propertyName; + this.dataTypes = List.copyOf(dataTypes); + } + + public Builder dataTypes(List dataTypes) { + this.dataTypes = dataTypes; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder indexInverted(Boolean indexInverted) { + this.indexInverted = indexInverted; + return this; + } + + public Builder indexFilterable(Boolean indexFilterable) { + this.indexFilterable = indexFilterable; + return this; + } + + public Builder indexRangeFilters(Boolean indexRangeFilters) { + this.indexRangeFilters = indexRangeFilters; + return this; + } + + public Builder indexSearchable(Boolean indexSearchable) { + this.indexSearchable = indexSearchable; + return this; + } + + public Builder skipVectorization(Boolean skipVectorization) { + this.skipVectorization = skipVectorization; + return this; + } + + public Builder vectorizePropertyName(Boolean vectorizePropertyName) { + this.vectorizePropertyName = vectorizePropertyName; + return this; + } + + @Override + public Property build() { + return new Property(this); + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ReferenceProperty.java b/src/main/java/io/weaviate/client6/v1/api/collections/ReferenceProperty.java index 3ed9b5aed..6fa0ade98 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/ReferenceProperty.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/ReferenceProperty.java @@ -5,10 +5,10 @@ import com.google.gson.annotations.SerializedName; public record ReferenceProperty( - @SerializedName("name") String name, + @SerializedName("name") String propertyName, @SerializedName("dataType") List dataTypes) { public Property toProperty() { - return new Property(name, dataTypes); + return new Property.Builder(propertyName, dataTypes).build(); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Replication.java b/src/main/java/io/weaviate/client6/v1/api/collections/Replication.java new file mode 100644 index 000000000..2b175bae5 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Replication.java @@ -0,0 +1,59 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record Replication( + @SerializedName("factor") Integer replicationFactor, + @SerializedName("asyncEnabled") Boolean asyncEnabled, + @SerializedName("deletionStrategy") DeletionStrategy deletionStrategy) { + + public static enum DeletionStrategy { + @SerializedName("NoAutomatedResolution") + NO_AUTOMATED_RESOLUTION, + @SerializedName("DeleteOnConflict") + DELETE_ON_CONFLICT, + @SerializedName("TimeBasedResolution") + TIME_BASED_RESOLUTION; + } + + public static Replication of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public Replication(Builder builder) { + this( + builder.replicationFactor, + builder.asyncEnabled, + builder.deletionStrategy); + } + + public static class Builder implements ObjectBuilder { + private Integer replicationFactor; + private Boolean asyncEnabled; + private DeletionStrategy deletionStrategy; + + public Builder replicationFactor(int replicationFactor) { + this.replicationFactor = replicationFactor; + return this; + } + + public Builder asyncEnabled(boolean asyncEnabled) { + this.asyncEnabled = asyncEnabled; + return this; + } + + public Builder deletionStrategy(DeletionStrategy deletionStrategy) { + this.deletionStrategy = deletionStrategy; + return this; + } + + @Override + public Replication build() { + return new Replication(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java new file mode 100644 index 000000000..fb473b40c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java @@ -0,0 +1,110 @@ +package io.weaviate.client6.v1.api.collections; + +import java.io.IOException; +import java.util.EnumMap; +import java.util.Map; +import java.util.function.Function; + +import com.google.gson.Gson; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonToken; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.api.collections.rerankers.CohereReranker; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.json.JsonEnum; + +public interface Reranker { + public enum Kind implements JsonEnum { + COHERE("reranker-cohere"); + + private static final Map jsonValueMap = JsonEnum.collectNames(Kind.values()); + private final String jsonValue; + + private Kind(String jsonValue) { + this.jsonValue = jsonValue; + } + + @Override + public String jsonValue() { + return this.jsonValue; + } + + public static Kind valueOfJson(String jsonValue) { + return JsonEnum.valueOfJson(jsonValue, jsonValueMap, Kind.class); + } + } + + Kind _kind(); + + Object _self(); + + public static Reranker cohere() { + return CohereReranker.of(); + } + + public static Reranker cohere(Function> fn) { + return CohereReranker.of(fn); + } + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + private static final EnumMap> readAdapters = new EnumMap<>( + Reranker.Kind.class); + + private final void addAdapter(Gson gson, Reranker.Kind kind, Class cls) { + readAdapters.put(kind, (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(cls))); + } + + private final void init(Gson gson) { + addAdapter(gson, Reranker.Kind.COHERE, CohereReranker.class); + } + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + var rawType = type.getRawType(); + if (!Reranker.class.isAssignableFrom(rawType)) { + return null; + } + + if (readAdapters.isEmpty()) { + init(gson); + } + + final TypeAdapter writeAdapter = (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(rawType)); + return (TypeAdapter) new TypeAdapter() { + + @Override + public void write(JsonWriter out, Reranker value) throws IOException { + out.beginObject(); + out.name(value._kind().jsonValue()); + writeAdapter.write(out, (T) value._self()); + out.endObject(); + } + + @Override + public Reranker read(JsonReader in) throws IOException { + in.beginObject(); + var rerankerName = in.nextName(); + try { + var kind = Reranker.Kind.valueOfJson(rerankerName); + var adapter = readAdapters.get(kind); + return adapter.read(in); + } catch (IllegalArgumentException e) { + return null; + } finally { + if (in.peek() == JsonToken.BEGIN_OBJECT) { + in.beginObject(); + } + in.endObject(); + } + } + }.nullSafe(); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Sharding.java b/src/main/java/io/weaviate/client6/v1/api/collections/Sharding.java new file mode 100644 index 000000000..79cdcdd09 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Sharding.java @@ -0,0 +1,50 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record Sharding( + @SerializedName("virtualPerPhysical") Integer virtualPerPhysical, + @SerializedName("desiredCound") Integer desiredCount, + @SerializedName("desiredVirtualCount") Integer desiredVirtualCount) { + + public static Sharding of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public Sharding(Builder builder) { + this( + builder.virtualPerPhysical, + builder.desiredCount, + builder.desiredVirtualCount); + } + + public static class Builder implements ObjectBuilder { + private Integer virtualPerPhysical; + private Integer desiredCount; + private Integer desiredVirtualCount; + + public Builder virtualPerPhysical(int virtualPerPhysical) { + this.virtualPerPhysical = virtualPerPhysical; + return this; + } + + public Builder desiredCount(int desiredCount) { + this.desiredCount = desiredCount; + return this; + } + + public Builder desiredVirtualCount(int desiredVirtualCount) { + this.desiredVirtualCount = desiredVirtualCount; + return this; + } + + @Override + public Sharding build() { + return new Sharding(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java index 7752cfb07..7ca6568ab 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java @@ -51,11 +51,11 @@ public static Kind valueOfJson(String jsonValue) { public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { INSTANCE; - private static final EnumMap> readAdapters = new EnumMap<>( + private static final EnumMap> delegateAdapters = new EnumMap<>( Vectorizer.Kind.class); private final void addAdapter(Gson gson, Vectorizer.Kind kind, Class cls) { - readAdapters.put(kind, (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(cls))); + delegateAdapters.put(kind, (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(cls))); } private final void init(Gson gson) { @@ -74,7 +74,7 @@ public TypeAdapter create(Gson gson, TypeToken type) { return null; } - if (readAdapters.isEmpty()) { + if (delegateAdapters.isEmpty()) { init(gson); } @@ -82,11 +82,11 @@ public TypeAdapter create(Gson gson, TypeToken type) { @Override public void write(JsonWriter out, Vectorizer value) throws IOException { - var writeAdapter = readAdapters.get(value._kind()); + TypeAdapter adapter = (TypeAdapter) delegateAdapters.get(value._kind()); out.beginObject(); out.name(value._kind().jsonValue()); - ((TypeAdapter) writeAdapter).write(out, (T) value._self()); + adapter.write(out, (T) value._self()); out.endObject(); } @@ -96,7 +96,7 @@ public Vectorizer read(JsonReader in) throws IOException { var vectorizerName = in.nextName(); try { var kind = Vectorizer.Kind.valueOfJson(vectorizerName); - var adapter = readAdapters.get(kind); + var adapter = delegateAdapters.get(kind); return adapter.read(in); } catch (IllegalArgumentException e) { return null; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollection.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollection.java deleted file mode 100644 index 6dc8fa306..000000000 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollection.java +++ /dev/null @@ -1,167 +0,0 @@ -package io.weaviate.client6.v1.api.collections; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Function; - -import com.google.gson.Gson; -import com.google.gson.JsonArray; -import com.google.gson.JsonParser; -import com.google.gson.TypeAdapter; -import com.google.gson.TypeAdapterFactory; -import com.google.gson.annotations.SerializedName; -import com.google.gson.internal.Streams; -import com.google.gson.reflect.TypeToken; -import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonWriter; - -import io.weaviate.client6.v1.internal.ObjectBuilder; - -public record WeaviateCollection( - @SerializedName("class") String name, - @SerializedName("description") String description, - @SerializedName("properties") List properties, - List references, - @SerializedName("vectorConfig") Map vectors) { - - public static WeaviateCollection of(String collectionName) { - return of(collectionName, ObjectBuilder.identity()); - } - - public static WeaviateCollection of(String collectionName, Function> fn) { - return fn.apply(new Builder(collectionName)).build(); - } - - public WeaviateCollection(Builder builder) { - this( - builder.collectionName, - builder.description, - builder.properties, - builder.references, - builder.vectors); - } - - public static class Builder implements ObjectBuilder { - // Required parameters; - private final String collectionName; - - private String description; - private List properties = new ArrayList<>(); - private List references = new ArrayList<>(); - private Map vectors = new HashMap<>(); - - public Builder(String collectionName) { - this.collectionName = collectionName; - } - - public Builder description(String description) { - this.description = description; - return this; - } - - public Builder properties(Property... properties) { - return properties(Arrays.asList(properties)); - } - - public Builder properties(List properties) { - this.properties = properties; - return this; - } - - public Builder references(ReferenceProperty... references) { - return references(Arrays.asList(references)); - } - - public Builder references(List references) { - this.references = references; - return this; - } - - public Builder vector(VectorIndex vector) { - this.vectors.put(VectorIndex.DEFAULT_VECTOR_NAME, vector); - return this; - } - - public Builder vector(String name, VectorIndex vector) { - this.vectors.put(name, vector); - return this; - } - - public Builder vectors(Function>> fn) { - this.vectors = fn.apply(new VectorsBuilder()).build(); - return this; - } - - public static class VectorsBuilder implements ObjectBuilder> { - private Map vectors = new HashMap<>(); - - public VectorsBuilder vector(String name, VectorIndex vector) { - vectors.put(name, vector); - return this; - } - - @Override - public Map build() { - return this.vectors; - } - } - - @Override - public WeaviateCollection build() { - return new WeaviateCollection(this); - } - } - - public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { - INSTANCE; - - @SuppressWarnings("unchecked") - @Override - public TypeAdapter create(Gson gson, TypeToken type) { - if (type.getRawType() != WeaviateCollection.class) { - return null; - } - - final var delegate = gson.getDelegateAdapter(this, (TypeToken) type); - return (TypeAdapter) new TypeAdapter() { - - @Override - public void write(JsonWriter out, WeaviateCollection value) throws IOException { - var jsonObject = delegate.toJsonTree(value).getAsJsonObject(); - - var references = jsonObject.remove("references").getAsJsonArray(); - var properties = jsonObject.get("properties").getAsJsonArray(); - properties.addAll(references); - - Streams.write(jsonObject, out); - } - - @Override - public WeaviateCollection read(JsonReader in) throws IOException { - var jsonObject = JsonParser.parseReader(in).getAsJsonObject(); - - var mixedProperties = jsonObject.get("properties").getAsJsonArray(); - var references = new JsonArray(); - var properties = new JsonArray(); - - for (var property : mixedProperties) { - var dataTypes = property.getAsJsonObject().get("dataType").getAsJsonArray(); - if (dataTypes.size() == 1 && DataType.KNOWN_TYPES.contains(dataTypes.get(0).getAsString())) { - properties.add(property); - } else { - references.add(property); - } - } - - jsonObject.add("properties", properties); - jsonObject.add("references", references); - return delegate.fromJsonTree(jsonObject); - } - }.nullSafe(); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java index c14c7ccfa..03d4c6b33 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java @@ -24,25 +24,25 @@ public CollectionHandle> use(String collectionName) { return new CollectionHandle<>(restTransport, grpcTransport, CollectionDescriptor.ofMap(collectionName)); } - public WeaviateCollection create(String name) throws IOException { - return create(WeaviateCollection.of(name)); + public CollectionConfig create(String name) throws IOException { + return create(CollectionConfig.of(name)); } - public WeaviateCollection create(String name, - Function> fn) throws IOException { - return create(WeaviateCollection.of(name, fn)); + public CollectionConfig create(String name, + Function> fn) throws IOException { + return create(CollectionConfig.of(name, fn)); } - public WeaviateCollection create(WeaviateCollection collection) throws IOException { + public CollectionConfig create(CollectionConfig collection) throws IOException { return this.restTransport.performRequest(new CreateCollectionRequest(collection), CreateCollectionRequest._ENDPOINT); } - public Optional getConfig(String name) throws IOException { + public Optional getConfig(String name) throws IOException { return this.restTransport.performRequest(new GetConfigRequest(name), GetConfigRequest._ENDPOINT); } - public List list() throws IOException { + public List list() throws IOException { return this.restTransport.performRequest(new ListCollectionRequest(), ListCollectionRequest._ENDPOINT); } @@ -52,7 +52,7 @@ public void delete(String name) throws IOException { public void deleteAll() throws IOException { for (var collection : list()) { - delete(collection.name()); + delete(collection.collectionName()); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java index 6bdc4721a..d357d56cc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java @@ -26,25 +26,25 @@ public CollectionHandleAsync> use(String collectionName) { CollectionDescriptor.ofMap(collectionName)); } - public CompletableFuture create(String name) { - return create(WeaviateCollection.of(name)); + public CompletableFuture create(String name) { + return create(CollectionConfig.of(name)); } - public CompletableFuture create(String name, - Function> fn) { - return create(WeaviateCollection.of(name, fn)); + public CompletableFuture create(String name, + Function> fn) { + return create(CollectionConfig.of(name, fn)); } - public CompletableFuture create(WeaviateCollection collection) { + public CompletableFuture create(CollectionConfig collection) { return this.restTransport.performRequestAsync(new CreateCollectionRequest(collection), CreateCollectionRequest._ENDPOINT); } - public CompletableFuture> getConfig(String name) { + public CompletableFuture> getConfig(String name) { return this.restTransport.performRequestAsync(new GetConfigRequest(name), GetConfigRequest._ENDPOINT); } - public CompletableFuture> list() { + public CompletableFuture> list() { return this.restTransport.performRequestAsync(new ListCollectionRequest(), ListCollectionRequest._ENDPOINT); } @@ -55,7 +55,7 @@ public CompletableFuture delete(String name) { public CompletableFuture deleteAll() throws IOException { return list().thenCompose(collections -> { var futures = collections.stream() - .map(collection -> delete(collection.name())) + .map(collection -> delete(collection.collectionName())) .toArray(CompletableFuture[]::new); return CompletableFuture.allOf(futures); }); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java new file mode 100644 index 000000000..3b233f01f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java @@ -0,0 +1,112 @@ +package io.weaviate.client6.v1.api.collections.config; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import org.apache.hc.core5.http.HttpStatus; + +import io.weaviate.client6.v1.api.collections.CollectionConfig; +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.InvertedIndex; +import io.weaviate.client6.v1.api.collections.Replication; +import io.weaviate.client6.v1.api.collections.Reranker; +import io.weaviate.client6.v1.api.collections.VectorIndex; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record UpdateCollectionRequest(CollectionConfig collection) { + + public static final Endpoint _ENDPOINT = Endpoint.of( + request -> "PUT", + request -> "/schema/" + request.collection.collectionName(), + (gson, request) -> JSON.serialize(request.collection), + request -> Collections.emptyMap(), + code -> code != HttpStatus.SC_SUCCESS, + (gson, response) -> null); + + public static UpdateCollectionRequest of(CollectionConfig collection, + Function> fn) { + return fn.apply(new Builder(collection)).build(); + } + + public UpdateCollectionRequest(Builder builder) { + this(builder.newCollection.build()); + } + + public static class Builder implements ObjectBuilder { + // For updating property descriptions + private final CollectionConfig currentCollection; + // Builder for the updated collection config. + private final CollectionConfig.Builder newCollection; + + public Builder(CollectionConfig currentCollection) { + this.currentCollection = currentCollection; + this.newCollection = currentCollection.edit(); + } + + public Builder description(String description) { + this.newCollection.description(description); + return this; + } + + public Builder propertyDescription(String propertyName, String description) { + for (var property : currentCollection.properties()) { + if (property.propertyName().equals(propertyName)) { + var newProperty = property.edit(p -> p.description(description)); + this.newCollection.properties(newProperty); + break; + } + } + return this; + } + + public Builder replication(Replication replication) { + this.newCollection.replication(replication); + return this; + } + + public Builder replication(Function> fn) { + this.newCollection.replication(fn); + return this; + } + + public Builder invertedIndex(InvertedIndex invertedIndex) { + this.newCollection.invertedIndex(invertedIndex); + return this; + } + + public Builder invertedIndex(Function> fn) { + this.newCollection.invertedIndex(fn); + return this; + } + + public Builder rerankerModules(Reranker... rerankerModules) { + this.newCollection.rerankerModules(rerankerModules); + return this; + } + + public Builder rerankerModules(List rerankerModules) { + this.newCollection.rerankerModules(rerankerModules); + return this; + } + + public Builder generativeModule(Generative generativeModule) { + this.newCollection.generativeModule(generativeModule); + return this; + } + + @SafeVarargs + public final Builder vectors(Map.Entry... vectors) { + this.newCollection.vectors(Map.ofEntries(vectors)); + return this; + } + + @Override + public UpdateCollectionRequest build() { + return new UpdateCollectionRequest(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java index 0fa3cb860..9979bfa3a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java @@ -2,37 +2,47 @@ import java.io.IOException; import java.util.Optional; +import java.util.function.Function; +import io.weaviate.client6.v1.api.collections.CollectionConfig; import io.weaviate.client6.v1.api.collections.Property; -import io.weaviate.client6.v1.api.collections.WeaviateCollection; import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClient; +import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.RestTransport; public class WeaviateConfigClient { - private final RestTransport transport; + private final RestTransport restTransport; private final WeaviateCollectionsClient collectionsClient; protected final CollectionDescriptor collection; public WeaviateConfigClient(CollectionDescriptor collection, RestTransport restTransport, GrpcTransport grpcTransport) { - this.transport = restTransport; + this.restTransport = restTransport; this.collectionsClient = new WeaviateCollectionsClient(restTransport, grpcTransport); this.collection = collection; } - public Optional get() throws IOException { + public Optional get() throws IOException { return collectionsClient.getConfig(collection.name()); } public void addProperty(Property property) throws IOException { - this.transport.performRequest(new AddPropertyRequest(collection.name(), property), AddPropertyRequest._ENDPOINT); + this.restTransport.performRequest(new AddPropertyRequest(collection.name(), property), + AddPropertyRequest._ENDPOINT); } - public void addReference(String name, String... dataTypes) throws IOException { - this.addProperty(Property.reference(name, dataTypes).toProperty()); + public void addReference(String propertyName, String... dataTypes) throws IOException { + this.addProperty(Property.reference(propertyName, dataTypes).toProperty()); + } + + public void update(String collectionName, + Function> fn) throws IOException { + var thisCollection = get().orElseThrow(); // TODO: use descriptive error + this.restTransport.performRequest(UpdateCollectionRequest.of(thisCollection, fn), + UpdateCollectionRequest._ENDPOINT); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java index 54e586a2b..001f7e4ca 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java @@ -3,38 +3,49 @@ import java.io.IOException; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import io.weaviate.client6.v1.api.collections.CollectionConfig; import io.weaviate.client6.v1.api.collections.Property; -import io.weaviate.client6.v1.api.collections.WeaviateCollection; import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClientAsync; +import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; import io.weaviate.client6.v1.internal.rest.RestTransport; public class WeaviateConfigClientAsync { - private final RestTransport transport; + private final RestTransport restTransport; private final WeaviateCollectionsClientAsync collectionsClient; - protected final CollectionDescriptor collection; + protected final CollectionDescriptor collectionDescriptor; public WeaviateConfigClientAsync(CollectionDescriptor collection, RestTransport restTransport, GrpcTransport grpcTransport) { - this.transport = restTransport; + this.restTransport = restTransport; this.collectionsClient = new WeaviateCollectionsClientAsync(restTransport, grpcTransport); - this.collection = collection; + this.collectionDescriptor = collection; } - public CompletableFuture> get() throws IOException { - return collectionsClient.getConfig(collection.name()); + public CompletableFuture> get() throws IOException { + return collectionsClient.getConfig(collectionDescriptor.name()); } public CompletableFuture addProperty(Property property) throws IOException { - return this.transport.performRequestAsync(new AddPropertyRequest(collection.name(), property), + return this.restTransport.performRequestAsync(new AddPropertyRequest(collectionDescriptor.name(), property), AddPropertyRequest._ENDPOINT); } public CompletableFuture addReference(String name, String... dataTypes) throws IOException { return this.addProperty(Property.reference(name, dataTypes).toProperty()); } + + public CompletableFuture update(String collectionName, + Function> fn) throws IOException { + return get().thenCompose(maybeCollection -> { + var thisCollection = maybeCollection.orElseThrow(); + return this.restTransport.performRequestAsync(UpdateCollectionRequest.of(thisCollection, fn), + UpdateCollectionRequest._ENDPOINT); + }); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java new file mode 100644 index 000000000..b95ffc601 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java @@ -0,0 +1,96 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record CohereGenerative( + @SerializedName("kProperty") String kProperty, + @SerializedName("model") String model, + @SerializedName("maxTokensProperty") Integer maxTokensProperty, + @SerializedName("returnLikelihoodsProperty") String returnLikelihoodsProperty, + @SerializedName("stopSequencesProperty") List stopSequencesProperty, + @SerializedName("temperatureProperty") String temperatureProperty) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.COHERE; + } + + @Override + public Object _self() { + return this; + } + + public static CohereGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static CohereGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public CohereGenerative(Builder builder) { + this( + builder.kProperty, + builder.model, + builder.maxTokensProperty, + builder.returnLikelihoodsProperty, + builder.stopSequencesProperty, + builder.temperatureProperty); + } + + public static class Builder implements ObjectBuilder { + private String kProperty; + private String model; + private Integer maxTokensProperty; + private String returnLikelihoodsProperty; + private List stopSequencesProperty = new ArrayList<>(); + private String temperatureProperty; + + public Builder kProperty(String kProperty) { + this.kProperty = kProperty; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder maxTokensProperty(int maxTokensProperty) { + this.maxTokensProperty = maxTokensProperty; + return this; + } + + public Builder returnLikelihoodsProperty(String returnLikelihoodsProperty) { + this.returnLikelihoodsProperty = returnLikelihoodsProperty; + return this; + } + + public Builder stopSequencesProperty(String... stopSequencesProperty) { + return stopSequencesProperty(Arrays.asList(stopSequencesProperty)); + } + + public Builder stopSequencesProperty(List stopSequencesProperty) { + this.stopSequencesProperty = stopSequencesProperty; + return this; + } + + public Builder temperatureProperty(String temperatureProperty) { + this.temperatureProperty = temperatureProperty; + return this; + } + + @Override + public CohereGenerative build() { + return new CohereGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/CohereReranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/CohereReranker.java new file mode 100644 index 000000000..ba435e572 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/CohereReranker.java @@ -0,0 +1,51 @@ +package io.weaviate.client6.v1.api.collections.rerankers; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Reranker; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record CohereReranker( + @SerializedName("model") String model) implements Reranker { + + public static final String RERANK_ENGLISH_V2 = "rerank-english-v2.0"; + public static final String RERANK_MULTILINGUAL_V2 = "rerank-mulilingual-v2.0"; + + @Override + public Kind _kind() { + return Reranker.Kind.COHERE; + } + + @Override + public Object _self() { + return this; + } + + public static CohereReranker of() { + return of(ObjectBuilder.identity()); + } + + public static CohereReranker of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public CohereReranker(Builder builder) { + this(builder.model); + } + + public static class Builder implements ObjectBuilder { + private String model; + + public Builder model(String model) { + this.model = model; + return this; + } + + @Override + public CohereReranker build() { + return new CohereReranker(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java index 91ae742ce..8285e574b 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java +++ b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java @@ -12,13 +12,17 @@ public final class JSON { gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.WeaviateObject.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapterFactory( - io.weaviate.client6.v1.api.collections.WeaviateCollection.CustomTypeAdapterFactory.INSTANCE); + io.weaviate.client6.v1.api.collections.CollectionConfig.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.Vectors.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.Vectorizer.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.VectorIndex.CustomTypeAdapterFactory.INSTANCE); + gsonBuilder.registerTypeAdapterFactory( + io.weaviate.client6.v1.api.collections.Reranker.CustomTypeAdapterFactory.INSTANCE); + gsonBuilder.registerTypeAdapterFactory( + io.weaviate.client6.v1.api.collections.Generative.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapter( io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer.class, diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index e1ab5a9ee..8f4f211a1 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -14,14 +14,17 @@ import com.jparams.junit4.JParamsTestRunner; import com.jparams.junit4.data.DataMethod; +import io.weaviate.client6.v1.api.collections.CollectionConfig; +import io.weaviate.client6.v1.api.collections.Generative; import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.Reranker; import io.weaviate.client6.v1.api.collections.VectorIndex; import io.weaviate.client6.v1.api.collections.Vectorizer; import io.weaviate.client6.v1.api.collections.Vectors; -import io.weaviate.client6.v1.api.collections.WeaviateCollection; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.data.Reference; +import io.weaviate.client6.v1.api.collections.rerankers.CohereReranker; import io.weaviate.client6.v1.api.collections.vectorindex.Distance; import io.weaviate.client6.v1.api.collections.vectorindex.Flat; import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; @@ -184,8 +187,8 @@ public static Object[][] testCases() { // WeaviateCollection.CustomTypeAdapterFactory { - WeaviateCollection.class, - WeaviateCollection.of("Things", things -> things + CollectionConfig.class, + CollectionConfig.of("Things", things -> things .description("A collection of things") .properties( Property.text("shape"), @@ -249,6 +252,43 @@ public static Object[][] testCases() { } """, }, + + // Reranker.CustomTypeAdapterFactory + { + Reranker.class, + Reranker.cohere(rerank -> rerank + .model(CohereReranker.RERANK_ENGLISH_V2)), + """ + { + "reranker-cohere": { + "model": "rerank-english-v2.0" + } + } + """, + }, + + { + Generative.class, + Generative.cohere(generate -> generate + .kProperty("k-property") + .maxTokensProperty(10) + .model("example-model") + .returnLikelihoodsProperty("likelihood") + .stopSequencesProperty("stop", "halt") + .temperatureProperty("celcius")), + """ + { + "generative-cohere": { + "kProperty": "k-property", + "maxTokensProperty": 10, + "model": "example-model", + "returnLikelihoodsProperty": "likelihood", + "stopSequencesProperty": ["stop", "halt"], + "temperatureProperty": "celcius" + } + } + """, + }, }; }