diff --git a/src/it/java/io/weaviate/integration/CollectionsITest.java b/src/it/java/io/weaviate/integration/CollectionsITest.java index 06c70b1e5..dcde8a399 100644 --- a/src/it/java/io/weaviate/integration/CollectionsITest.java +++ b/src/it/java/io/weaviate/integration/CollectionsITest.java @@ -13,6 +13,8 @@ import io.weaviate.client6.v1.api.collections.Property; import io.weaviate.client6.v1.api.collections.Replication; import io.weaviate.client6.v1.api.collections.VectorIndex; +import io.weaviate.client6.v1.api.collections.config.Shard; +import io.weaviate.client6.v1.api.collections.config.ShardStatus; import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.containers.Container; @@ -159,4 +161,27 @@ public void testUpdateCollection() throws IOException { .extracting(CollectionConfig::replication).returns(false, Replication::asyncEnabled); }); } + + @Test + public void testShards() throws IOException { + var nsShatteredCups = ns("ShatteredCups"); + client.collections.create(nsShatteredCups); + var cups = client.collections.use(nsShatteredCups); + + // Act: get initial shard state + var shards = cups.config.getShards(); + + Assertions.assertThat(shards).as("single-tenant collections has 1 shard").hasSize(1); + var singleShard = shards.get(0); + + // Act: flip the status + var wantStatus = singleShard.status().equals("READY") ? ShardStatus.READONLY : ShardStatus.READY; + var updated = cups.config.updateShards(wantStatus, singleShard.name()); + + Assertions.assertThat(updated) + .as("shard status changed") + .hasSize(1) + .extracting(Shard::status) + .containsOnly(wantStatus.name()); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java new file mode 100644 index 000000000..aa5638ff2 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/GetShardsRequest.java @@ -0,0 +1,24 @@ +package io.weaviate.client6.v1.api.collections.config; + +import java.util.Collections; +import java.util.List; + +import org.apache.hc.core5.http.HttpStatus; + +import com.google.gson.reflect.TypeToken; + +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record GetShardsRequest(String collectionName) { + + @SuppressWarnings("unchecked") + public static final Endpoint> _ENDPOINT = Endpoint.of( + request -> "GET", + request -> "/schema/" + request.collectionName + "/shards", // TODO: tenant support + (gson, request) -> null, + request -> Collections.emptyMap(), + code -> code != HttpStatus.SC_SUCCESS, + (gson, response) -> (List) JSON.deserialize(response, TypeToken.getParameterized( + List.class, Shard.class))); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/Shard.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/Shard.java new file mode 100644 index 000000000..f0797668f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/Shard.java @@ -0,0 +1,9 @@ +package io.weaviate.client6.v1.api.collections.config; + +import com.google.gson.annotations.SerializedName; + +public record Shard( + @SerializedName("name") String name, + @SerializedName("status") String status, + @SerializedName("vectorQueueSize") long vectorQueueSize) { +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/ShardStatus.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/ShardStatus.java new file mode 100644 index 000000000..bc5bb06d2 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/ShardStatus.java @@ -0,0 +1,10 @@ +package io.weaviate.client6.v1.api.collections.config; + +import com.google.gson.annotations.SerializedName; + +public enum ShardStatus { + @SerializedName("READY") + READY, + @SerializedName("READONLY") + READONLY; +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateShardStatusRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateShardStatusRequest.java new file mode 100644 index 000000000..95431d273 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateShardStatusRequest.java @@ -0,0 +1,19 @@ +package io.weaviate.client6.v1.api.collections.config; + +import java.util.Collections; +import java.util.Map; + +import org.apache.hc.core5.http.HttpStatus; + +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record UpdateShardStatusRequest(String collection, String shard, ShardStatus status) { + public static final Endpoint _ENDPOINT = Endpoint.of( + request -> "PUT", + request -> "/schema/" + request.collection + "/shards/" + request.shard, + (gson, request) -> JSON.serialize(Map.of("status", request.status)), + request -> Collections.emptyMap(), + code -> code != HttpStatus.SC_SUCCESS, + (gson, response) -> null); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java index 9979bfa3a..b2a0fb493 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java @@ -1,6 +1,8 @@ package io.weaviate.client6.v1.api.collections.config; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.Optional; import java.util.function.Function; @@ -45,4 +47,21 @@ public void update(String collectionName, this.restTransport.performRequest(UpdateCollectionRequest.of(thisCollection, fn), UpdateCollectionRequest._ENDPOINT); } + + public List getShards() throws IOException { + return this.restTransport.performRequest(new GetShardsRequest(collection.name()), GetShardsRequest._ENDPOINT); + } + + public List updateShards(ShardStatus status, String... shards) throws IOException { + return updateShards(status, Arrays.asList(shards)); + } + + public List updateShards(ShardStatus status, List shards) throws IOException { + for (var shard : shards) { + this.restTransport.performRequest( + new UpdateShardStatusRequest(collection.name(), shard, status), + UpdateShardStatusRequest._ENDPOINT); + } + return getShards(); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java index 001f7e4ca..a418a47aa 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java @@ -1,6 +1,8 @@ package io.weaviate.client6.v1.api.collections.config; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Function; @@ -48,4 +50,22 @@ public CompletableFuture update(String collectionName, UpdateCollectionRequest._ENDPOINT); }); } + + public CompletableFuture> getShards() { + return this.restTransport.performRequestAsync(new GetShardsRequest(collectionDescriptor.name()), + GetShardsRequest._ENDPOINT); + } + + public CompletableFuture> updateShards(ShardStatus status, String... shards) throws IOException { + return updateShards(status, Arrays.asList(shards)); + } + + public CompletableFuture> updateShards(ShardStatus status, List shards) throws IOException { + var updates = shards.stream().map( + shard -> this.restTransport.performRequestAsync( + new UpdateShardStatusRequest(collectionDescriptor.name(), shard, status), + UpdateShardStatusRequest._ENDPOINT)) + .toArray(CompletableFuture[]::new); + return CompletableFuture.allOf(updates).thenCompose(__ -> getShards()); + } }