diff --git a/pom.xml b/pom.xml
index a5d5c559a..6503ced8a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -65,7 +65,7 @@
17
1.18.36
2.12.1
- 5.4.2
+ 5.4.3
3.17.0
5.12.0
1.20.5
@@ -194,6 +194,7 @@
org.mock-server
+
mockserver-netty
${mock-server.version}
test
diff --git a/src/it/java/io/weaviate/containers/Weaviate.java b/src/it/java/io/weaviate/containers/Weaviate.java
index c70342fd1..bcc1ba7d4 100644
--- a/src/it/java/io/weaviate/containers/Weaviate.java
+++ b/src/it/java/io/weaviate/containers/Weaviate.java
@@ -9,7 +9,6 @@
import org.testcontainers.weaviate.WeaviateContainer;
-import io.weaviate.client6.v1.api.Config;
import io.weaviate.client6.v1.api.WeaviateClient;
public class Weaviate extends WeaviateContainer {
@@ -29,8 +28,15 @@ public WeaviateClient getClient() {
start();
}
if (clientInstance == null) {
- var config = new Config("http", getHttpHostAddress(), getGrpcHostAddress());
- clientInstance = new WeaviateClient(config);
+ try {
+ clientInstance = WeaviateClient.local(
+ conn -> conn
+ .host(getHost())
+ .httpPort(getMappedPort(8080))
+ .grpcPort(getMappedPort(50051)));
+ } catch (Exception e) {
+ throw new RuntimeException("create WeaviateClient for Weaviate container", e);
+ }
}
return clientInstance;
}
@@ -46,7 +52,6 @@ public static Weaviate.Builder custom() {
public static class Builder {
private String versionTag;
private Set enableModules = new HashSet<>();
- private String defaultVectorizerModule;
private boolean telemetry;
private Map environment = new HashMap<>();
diff --git a/src/it/java/io/weaviate/integration/AuthorizationITest.java b/src/it/java/io/weaviate/integration/AuthorizationITest.java
new file mode 100644
index 000000000..30422c67a
--- /dev/null
+++ b/src/it/java/io/weaviate/integration/AuthorizationITest.java
@@ -0,0 +1,53 @@
+package io.weaviate.integration;
+
+import java.io.IOException;
+import java.util.Collections;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockserver.integration.ClientAndServer;
+import org.mockserver.model.HttpRequest;
+
+import io.weaviate.ConcurrentTest;
+import io.weaviate.client6.v1.api.Authorization;
+import io.weaviate.client6.v1.internal.rest.DefaultRestTransport;
+import io.weaviate.client6.v1.internal.rest.Endpoint;
+import io.weaviate.client6.v1.internal.rest.RestTransportOptions;
+
+public class AuthorizationITest extends ConcurrentTest {
+ private ClientAndServer mockServer;
+
+ @Before
+ public void startMockServer() {
+ mockServer = ClientAndServer.startClientAndServer(8080);
+ }
+
+ @Test
+ public void testAuthorization_apiKey() throws IOException {
+ var transportOptions = new RestTransportOptions(
+ "http", "localhost", 8080,
+ Collections.emptyMap(), Authorization.apiKey("my-api-key"));
+
+ try (final var restClient = new DefaultRestTransport(transportOptions)) {
+ restClient.performRequest(null, Endpoint.of(
+ request -> "GET",
+ request -> "/",
+ (gson, request) -> null,
+ request -> null,
+ code -> code != 200,
+ (gson, response) -> null));
+ }
+
+ mockServer.verify(
+ HttpRequest.request()
+ .withMethod("GET")
+ .withPath("/v1/")
+ .withHeader("Authorization", "Bearer my-api-key"));
+ }
+
+ @After
+ public void stopMockServer() {
+ mockServer.stop();
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/Authorization.java b/src/main/java/io/weaviate/client6/v1/api/Authorization.java
new file mode 100644
index 000000000..9bcfc2ed1
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/api/Authorization.java
@@ -0,0 +1,10 @@
+package io.weaviate.client6.v1.api;
+
+import io.weaviate.client6.v1.internal.TokenProvider;
+import io.weaviate.client6.v1.internal.TokenProvider.Token;
+
+public class Authorization {
+ public static TokenProvider apiKey(String apiKey) {
+ return TokenProvider.staticToken(new Token(apiKey));
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java
index 2e7d9391d..bd005d167 100644
--- a/src/main/java/io/weaviate/client6/v1/api/Config.java
+++ b/src/main/java/io/weaviate/client6/v1/api/Config.java
@@ -1,68 +1,166 @@
package io.weaviate.client6.v1.api;
-import java.util.Collections;
+import java.net.URI;
+import java.util.HashMap;
import java.util.Map;
+import java.util.function.Function;
+import io.weaviate.client6.v1.internal.ObjectBuilder;
+import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions;
-import io.weaviate.client6.v1.internal.rest.TransportOptions;
-
-public class Config {
- private final String version = "v1";
- private final String scheme;
- private final String httpHost;
- private final String grpcHost;
- private final Map headers = Collections.emptyMap();
-
- public Config(String scheme, String httpHost, String grpcHost) {
- this.scheme = scheme;
- this.httpHost = httpHost;
- this.grpcHost = grpcHost;
+import io.weaviate.client6.v1.internal.rest.RestTransportOptions;
+
+public record Config(
+ String scheme,
+ String httpHost,
+ int httpPort,
+ String grpcHost,
+ int grpcPort,
+ Map headers,
+ TokenProvider tokenProvider) {
+
+ public static Config of(String scheme, Function> fn) {
+ return fn.apply(new Custom(scheme)).build();
}
- public String baseUrl() {
- return scheme + "://" + httpHost + "/" + version;
+ public Config(Builder> builder) {
+ this(
+ builder.scheme,
+ builder.httpHost,
+ builder.httpPort,
+ builder.grpcHost,
+ builder.grpcPort,
+ builder.headers,
+ builder.tokenProvider);
}
- public String grpcAddress() {
- if (grpcHost.contains(":")) {
- return grpcHost;
- }
- // FIXME: use secure port (433) if scheme == https
- return String.format("%s:80", grpcHost);
+ public RestTransportOptions restTransportOptions() {
+ return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider);
}
- public TransportOptions rest() {
- return new TransportOptions() {
+ public GrpcChannelOptions grpcTransportOptions() {
+ return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider);
+ }
- @Override
- public String host() {
- return baseUrl();
- }
+ abstract static class Builder> implements ObjectBuilder {
+ // Required parameters;
+ protected final String scheme;
+
+ protected String httpHost;
+ protected int httpPort;
+ protected String grpcHost;
+ protected int grpcPort;
+ protected TokenProvider tokenProvider;
+ protected Map headers = new HashMap<>();
- @Override
- public Map headers() {
- return headers;
+ protected Builder(String scheme) {
+ this.scheme = scheme;
+ }
+
+ @SuppressWarnings("unchecked")
+ public SELF setHeader(String key, String value) {
+ this.headers.put(key, value);
+ return (SELF) this;
+ }
+
+ @SuppressWarnings("unchecked")
+ public SELF setHeaders(Map headers) {
+ this.headers = Map.copyOf(headers);
+ return (SELF) this;
+ }
+
+ private static final String HEADER_X_WEAVIATE_CLUSTER_URL = "X-Weaviate-Cluster-URL";
+
+ /**
+ * isWeaviateDomain returns true if the host matches weaviate.io,
+ * semi.technology, or weaviate.cloud domain.
+ */
+ private static boolean isWeaviateDomain(String host) {
+ var lower = host.toLowerCase();
+ return lower.contains("weaviate.io") ||
+ lower.contains("semi.technology") ||
+ lower.contains("weaviate.cloud");
+ }
+
+ @Override
+ public Config build() {
+ if (isWeaviateDomain(httpHost) && tokenProvider != null) {
+ setHeader(HEADER_X_WEAVIATE_CLUSTER_URL, "https://" + httpHost + ":" + httpPort);
}
+ return new Config(this);
+ }
+ }
+
+ public static class Local extends Builder {
+ public Local() {
+ super("http");
+ host("localhost");
+ httpPort(8080);
+ grpcPort(50051);
+ }
+
+ public Local host(String host) {
+ this.httpHost = host;
+ this.grpcHost = host;
+ return this;
+ }
- };
+ public Local httpPort(int port) {
+ this.httpPort = port;
+ return this;
+ }
+
+ public Local grpcPort(int port) {
+ this.grpcPort = port;
+ return this;
+ }
}
- public GrpcChannelOptions grpc() {
- return new GrpcChannelOptions() {
- @Override
- public String host() {
- return grpcAddress();
- }
+ public static class WeaviateCloud extends Builder {
+ public WeaviateCloud(String clusterUrl, TokenProvider tokenProvider) {
+ this(URI.create(clusterUrl), tokenProvider);
+ }
- @Override
- public boolean useTls() {
- return scheme.equals("https");
- }
+ public WeaviateCloud(URI clusterUrl, TokenProvider tokenProvider) {
+ super("https");
+ this.httpHost = clusterUrl.getHost();
+ this.httpPort = 443;
+ this.grpcHost = "grpc-" + httpPort;
+ this.grpcPort = 443;
+ this.tokenProvider = tokenProvider;
+ }
+ }
- @Override
- public Map headers() {
- return headers;
- }
- };
+ public static class Custom extends Builder {
+ public Custom(String scheme) {
+ super(scheme);
+ httpPort(scheme == "https" ? 443 : 80);
+ grpcPort(scheme == "https" ? 443 : 80);
+ }
+
+ public Custom httpHost(String host) {
+ this.httpHost = host;
+ return this;
+ }
+
+ public Custom httpPort(int port) {
+ this.grpcPort = port;
+ return this;
+ }
+
+ public Custom grpcHost(String host) {
+ this.grpcHost = host;
+ return this;
+ }
+
+ public Custom grpcPort(int port) {
+ this.grpcPort = port;
+ return this;
+ }
+
+ public Custom authorization(TokenProvider tokenProvider) {
+ this.tokenProvider = tokenProvider;
+ return this;
+ }
}
}
diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java
index f2ceeff24..7f41fbffc 100644
--- a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java
+++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java
@@ -2,8 +2,10 @@
import java.io.Closeable;
import java.io.IOException;
+import java.util.function.Function;
import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClient;
+import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.DefaultGrpcTransport;
import io.weaviate.client6.v1.internal.grpc.GrpcTransport;
import io.weaviate.client6.v1.internal.rest.DefaultRestTransport;
@@ -20,8 +22,8 @@ public class WeaviateClient implements Closeable {
public WeaviateClient(Config config) {
this.config = config;
- this.restTransport = new DefaultRestTransport(config.rest());
- this.grpcTransport = new DefaultGrpcTransport(config.grpc());
+ this.restTransport = new DefaultRestTransport(config.restTransportOptions());
+ this.grpcTransport = new DefaultGrpcTransport(config.grpcTransportOptions());
this.collections = new WeaviateCollectionsClient(restTransport, grpcTransport);
}
@@ -30,6 +32,25 @@ public WeaviateClientAsync async() {
return new WeaviateClientAsync(config);
}
+ public static WeaviateClient local() {
+ return local(ObjectBuilder.identity());
+ }
+
+ public static WeaviateClient local(Function> fn) {
+ var config = new Config.Local();
+ return new WeaviateClient(fn.apply(config).build());
+ }
+
+ public static WeaviateClient wcd(String clusterUrl, String apiKey) {
+ return wcd(clusterUrl, apiKey, ObjectBuilder.identity());
+ }
+
+ public static WeaviateClient wcd(String clusterUrl, String apiKey,
+ Function> fn) {
+ var config = new Config.WeaviateCloud(clusterUrl, Authorization.apiKey(apiKey));
+ return new WeaviateClient(fn.apply(config).build());
+ }
+
@Override
public void close() throws IOException {
this.restTransport.close();
diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java
index a33927292..af7d7acc3 100644
--- a/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java
+++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java
@@ -2,8 +2,10 @@
import java.io.Closeable;
import java.io.IOException;
+import java.util.function.Function;
import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClientAsync;
+import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.DefaultGrpcTransport;
import io.weaviate.client6.v1.internal.grpc.GrpcTransport;
import io.weaviate.client6.v1.internal.rest.DefaultRestTransport;
@@ -16,12 +18,31 @@ public class WeaviateClientAsync implements Closeable {
public final WeaviateCollectionsClientAsync collections;
public WeaviateClientAsync(Config config) {
- this.restTransport = new DefaultRestTransport(config.rest());
- this.grpcTransport = new DefaultGrpcTransport(config.grpc());
+ this.restTransport = new DefaultRestTransport(config.restTransportOptions());
+ this.grpcTransport = new DefaultGrpcTransport(config.grpcTransportOptions());
this.collections = new WeaviateCollectionsClientAsync(restTransport, grpcTransport);
}
+ public static WeaviateClientAsync local() {
+ return local(ObjectBuilder.identity());
+ }
+
+ public static WeaviateClientAsync local(Function> fn) {
+ var config = new Config.Local();
+ return new WeaviateClientAsync(fn.apply(config).build());
+ }
+
+ public static WeaviateClientAsync wcd(String clusterUrl, String apiKey) {
+ return wcd(clusterUrl, apiKey, ObjectBuilder.identity());
+ }
+
+ public static WeaviateClientAsync wcd(String clusterUrl, String apiKey,
+ Function> fn) {
+ var config = new Config.WeaviateCloud(clusterUrl, Authorization.apiKey(apiKey));
+ return new WeaviateClientAsync(fn.apply(config).build());
+ }
+
@Override
public void close() throws IOException {
this.restTransport.close();
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java
index d0aa0a815..6fdecbb84 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java
@@ -13,7 +13,7 @@ public record ObjectMetadata(
@SerializedName("lastUpdateTImeUnix") Long lastUpdatedAt) implements WeaviateMetadata {
public ObjectMetadata(Builder builder) {
- this(builder.id, builder.vectors, null, null);
+ this(builder.uuid, builder.vectors, null, null);
}
public static ObjectMetadata of(Function> fn) {
@@ -21,11 +21,11 @@ public static ObjectMetadata of(Function>
}
public static class Builder implements ObjectBuilder {
- private String id;
+ private String uuid;
private Vectors vectors;
- public Builder id(String id) {
- this.id = id;
+ public Builder uuid(String uuid) {
+ this.uuid = uuid;
return this;
}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java
deleted file mode 100644
index b7f5f9128..000000000
--- a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java
+++ /dev/null
@@ -1,9 +0,0 @@
-package io.weaviate.client6.v1.api.collections;
-
-import java.util.List;
-
-import io.weaviate.client6.v1.api.collections.query.QueryMetadata;
-
-public record ObjectReference(
- List, QueryMetadata>> objects) {
-}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java
index 84367b67e..2f7345f2a 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java
@@ -166,7 +166,7 @@ public void write(JsonWriter out, WeaviateObject, ?, ?> value) throws IOExcept
builder.properties(propertiesAdapter.fromJsonTree(trueProperties));
- metadata.id(object.get("id").getAsString());
+ metadata.uuid(object.get("id").getAsString());
builder.metadata(metadata.build());
return builder.build();
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java
index 6654c9c8d..f44515536 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java
@@ -54,7 +54,7 @@ public Builder(String collectionName, T properties) {
}
public Builder uuid(String uuid) {
- this.metadata.id(uuid);
+ this.metadata.uuid(uuid);
return this;
}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java
index 59cdee22a..d54678e67 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java
@@ -16,7 +16,7 @@ public static class Builder implements ObjectBuilder {
private Float certainty;
private Vectors vectors;
- public final Builder id(String uuid) {
+ public final Builder uuid(String uuid) {
this.uuid = uuid;
return this;
}
diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java
index fd40952e5..becf356b1 100644
--- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java
+++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java
@@ -88,7 +88,7 @@ private static WeaviateObject unmarshalResultObjec
CollectionDescriptor descriptor) {
var res = unmarshalReferences(propertiesResult, metadataResult, descriptor);
var metadata = new QueryMetadata.Builder()
- .id(res.metadata().uuid())
+ .uuid(res.metadata().uuid())
.distance(metadataResult.getDistance())
.certainty(metadataResult.getCertainty())
.vectors(res.metadata().vectors());
@@ -146,7 +146,7 @@ private static WeaviateObject unmarshalReferences
ObjectMetadata metadata = null;
if (metadataResult != null) {
var metadataBuilder = new ObjectMetadata.Builder()
- .id(metadataResult.getId());
+ .uuid(metadataResult.getId());
var vectors = new Vectors.Builder();
for (final var vector : metadataResult.getVectorsList()) {
diff --git a/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java
new file mode 100644
index 000000000..af69a456b
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java
@@ -0,0 +1,13 @@
+package io.weaviate.client6.v1.internal;
+
+@FunctionalInterface
+public interface TokenProvider {
+ Token getToken();
+
+ public record Token(String accessToken) {
+ }
+
+ public static TokenProvider staticToken(Token token) {
+ return () -> token;
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java
new file mode 100644
index 000000000..03ee045b7
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java
@@ -0,0 +1,41 @@
+package io.weaviate.client6.v1.internal;
+
+public abstract class TransportOptions {
+ private final String scheme;
+ private final String host;
+ private final int port;
+ private final TokenProvider tokenProvider;
+ private final H headers;
+
+ protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider) {
+ this.scheme = scheme;
+ this.host = host;
+ this.port = port;
+ this.tokenProvider = tokenProvider;
+ this.headers = headers;
+ }
+
+ public boolean isSecure() {
+ return scheme == "https";
+ }
+
+ public String scheme() {
+ return this.scheme;
+ }
+
+ public String host() {
+ return this.host;
+ }
+
+ public int port() {
+ return this.port;
+ }
+
+ public TokenProvider tokenProvider() {
+ return this.tokenProvider;
+ }
+
+ public H headers() {
+ return this.headers;
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java
index f071c9005..82aea9598 100644
--- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java
+++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java
@@ -9,7 +9,6 @@
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
-import io.grpc.Metadata;
import io.grpc.stub.MetadataUtils;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub;
@@ -21,13 +20,23 @@ public final class DefaultGrpcTransport implements GrpcTransport {
private final WeaviateBlockingStub blockingStub;
private final WeaviateFutureStub futureStub;
- private static final int HTTP_PORT = 80;
- private static final int HTTPS_PORT = 443;
+ public DefaultGrpcTransport(GrpcChannelOptions transportOptions) {
+ this.channel = buildChannel(transportOptions);
- public DefaultGrpcTransport(GrpcChannelOptions channelOptions) {
- this.channel = buildChannel(channelOptions);
- this.blockingStub = WeaviateGrpc.newBlockingStub(channel);
- this.futureStub = WeaviateGrpc.newFutureStub(channel);
+ var blockingStub = WeaviateGrpc.newBlockingStub(channel)
+ .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));
+
+ var futureStub = WeaviateGrpc.newFutureStub(channel)
+ .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));
+
+ if (transportOptions.tokenProvider() != null) {
+ var credentials = new TokenCallCredentials(transportOptions.tokenProvider());
+ blockingStub = blockingStub.withCallCredentials(credentials);
+ futureStub = futureStub.withCallCredentials(credentials);
+ }
+
+ this.blockingStub = blockingStub;
+ this.futureStub = futureStub;
}
@Override
@@ -70,24 +79,16 @@ public void onFailure(Throwable t) {
return completable;
}
- private static ManagedChannel buildChannel(GrpcChannelOptions options) {
- // var port = options.useTls() ? HTTPS_PORT : HTTP_PORT;
- // var channel = ManagedChannelBuilder.forAddress(options.host(), port);
- var channel = ManagedChannelBuilder.forTarget(options.host());
+ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) {
+ var channel = ManagedChannelBuilder.forAddress(transportOptions.host(), transportOptions.port());
- if (options.useTls()) {
+ if (transportOptions.isSecure()) {
channel.useTransportSecurity();
} else {
channel.usePlaintext();
}
- var headers = new Metadata();
- for (final var header : options.headers().entrySet()) {
- var key = Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER);
- headers.put(key, header.getValue());
-
- }
- channel.intercept(MetadataUtils.newAttachHeadersInterceptor(headers));
+ channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));
return channel.build();
}
diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java
index 517345844..da67cb0c2 100644
--- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java
+++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java
@@ -1,15 +1,24 @@
package io.weaviate.client6.v1.internal.grpc;
-import java.util.Collections;
import java.util.Map;
-// TODO: unify with rest.TransportOptions?
-public interface GrpcChannelOptions {
- String host();
+import io.grpc.Metadata;
+import io.weaviate.client6.v1.internal.TokenProvider;
+import io.weaviate.client6.v1.internal.TransportOptions;
- default Map headers() {
- return Collections.emptyMap();
+public class GrpcChannelOptions extends TransportOptions {
+ public GrpcChannelOptions(String scheme, String host, int port, Map headers,
+ TokenProvider tokenProvider) {
+ super(scheme, host, port, buildMetadata(headers), tokenProvider);
}
- boolean useTls();
+ private static final Metadata buildMetadata(Map headers) {
+ var metadata = new Metadata();
+ for (var header : headers.entrySet()) {
+ metadata.put(
+ Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER),
+ header.getValue());
+ }
+ return metadata;
+ }
}
diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java
new file mode 100644
index 000000000..c24a9093a
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java
@@ -0,0 +1,33 @@
+package io.weaviate.client6.v1.internal.grpc;
+
+import java.util.concurrent.Executor;
+
+import io.grpc.CallCredentials;
+import io.grpc.Metadata;
+import io.grpc.Status;
+import io.weaviate.client6.v1.internal.TokenProvider;
+
+class TokenCallCredentials extends CallCredentials {
+ private static final Metadata.Key AUTHORIZATION = Metadata.Key.of("Authorization",
+ Metadata.ASCII_STRING_MARSHALLER);
+
+ private final TokenProvider tokenProvider;
+
+ TokenCallCredentials(TokenProvider tokenProvider) {
+ this.tokenProvider = tokenProvider;
+ }
+
+ @Override
+ public void applyRequestMetadata(RequestInfo requestInfo, Executor executor, MetadataApplier metadataApplier) {
+ executor.execute(() -> {
+ try {
+ var headers = new Metadata();
+ var token = tokenProvider.getToken().accessToken();
+ headers.put(AUTHORIZATION, "Bearer " + token);
+ metadataApplier.apply(headers);
+ } catch (Exception e) {
+ metadataApplier.fail(Status.UNAUTHENTICATED.withCause(e));
+ }
+ });
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/AuthorizationInterceptor.java b/src/main/java/io/weaviate/client6/v1/internal/rest/AuthorizationInterceptor.java
new file mode 100644
index 000000000..9fe109d23
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/internal/rest/AuthorizationInterceptor.java
@@ -0,0 +1,29 @@
+package io.weaviate.client6.v1.internal.rest;
+
+import java.io.IOException;
+
+import org.apache.hc.core5.http.EntityDetails;
+import org.apache.hc.core5.http.HttpException;
+import org.apache.hc.core5.http.HttpRequest;
+import org.apache.hc.core5.http.HttpRequestInterceptor;
+import org.apache.hc.core5.http.message.BasicHeader;
+import org.apache.hc.core5.http.protocol.HttpContext;
+
+import io.weaviate.client6.v1.internal.TokenProvider;
+
+class AuthorizationInterceptor implements HttpRequestInterceptor {
+ private static final String AUTHORIZATION = "Authorization";
+
+ private final TokenProvider tokenProvider;
+
+ AuthorizationInterceptor(TokenProvider tokenProvider) {
+ this.tokenProvider = tokenProvider;
+ }
+
+ @Override
+ public void process(HttpRequest request, EntityDetails entity, HttpContext context)
+ throws HttpException, IOException {
+ var token = tokenProvider.getToken().accessToken();
+ request.addHeader(new BasicHeader(AUTHORIZATION, "Bearer " + token));
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java
index 470df5e89..f2b12f3b3 100644
--- a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java
+++ b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java
@@ -22,15 +22,29 @@
public class DefaultRestTransport implements RestTransport {
private final CloseableHttpClient httpClient;
private final CloseableHttpAsyncClient httpClientAsync;
- private final TransportOptions transportOptions;
+ private final RestTransportOptions transportOptions;
+ // TODO: retire
private static final Gson gson = new GsonBuilder().create();
- public DefaultRestTransport(TransportOptions options) {
- this.transportOptions = options;
- this.httpClient = HttpClients.createDefault();
- this.httpClientAsync = HttpAsyncClients.createDefault();
- httpClientAsync.start();
+ public DefaultRestTransport(RestTransportOptions transportOptions) {
+ this.transportOptions = transportOptions;
+
+ // TODO: doesn't make sense to spin up both?
+ var httpClient = HttpClients.custom()
+ .setDefaultHeaders(transportOptions.headers());
+ var httpClientAsync = HttpAsyncClients.custom()
+ .setDefaultHeaders(transportOptions.headers());
+
+ if (transportOptions.tokenProvider() != null) {
+ var interceptor = new AuthorizationInterceptor(transportOptions.tokenProvider());
+ httpClient.addRequestInterceptorFirst(interceptor);
+ httpClientAsync.addRequestInterceptorFirst(interceptor);
+ }
+
+ this.httpClient = httpClient.build();
+ this.httpClientAsync = httpClientAsync.build();
+ this.httpClientAsync.start();
}
@Override
@@ -76,7 +90,7 @@ public void cancelled() {
private SimpleHttpRequest prepareSimpleRequest(RequestT request, Endpoint endpoint) {
var method = endpoint.method(request);
- var uri = transportOptions.host() + endpoint.requestUrl(request);
+ var uri = transportOptions.baseUrl() + endpoint.requestUrl(request);
// TODO: apply options;
var body = endpoint.body(gson, request);
@@ -89,7 +103,7 @@ private SimpleHttpRequest prepareSimpleRequest(RequestT request, Endp
private ClassicHttpRequest prepareClassicRequest(RequestT request, Endpoint endpoint) {
var method = endpoint.method(request);
- var uri = transportOptions.host() + endpoint.requestUrl(request);
+ var uri = transportOptions.baseUrl() + endpoint.requestUrl(request);
// TODO: apply options;
var req = ClassicRequestBuilder.create(method).setUri(uri);
diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java
new file mode 100644
index 000000000..795695e72
--- /dev/null
+++ b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java
@@ -0,0 +1,31 @@
+package io.weaviate.client6.v1.internal.rest;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Map;
+
+import org.apache.hc.core5.http.message.BasicHeader;
+
+import io.weaviate.client6.v1.internal.TokenProvider;
+import io.weaviate.client6.v1.internal.TransportOptions;
+
+public final class RestTransportOptions extends TransportOptions> {
+ private static final String API_VERSION = "v1";
+
+ public RestTransportOptions(String scheme, String host, int port, Map headers,
+ TokenProvider tokenProvider) {
+ super(scheme, host, port, buildHeaders(headers), tokenProvider);
+ }
+
+ private static final Collection buildHeaders(Map headers) {
+ var basicHeaders = new HashSet();
+ for (var header : headers.entrySet()) {
+ basicHeaders.add(new BasicHeader(header.getKey(), header.getValue()));
+ }
+ return basicHeaders;
+ }
+
+ public String baseUrl() {
+ return scheme() + "://" + host() + ":" + port() + "/" + API_VERSION;
+ }
+}
diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/rest/TransportOptions.java
deleted file mode 100644
index 9ddb3fa70..000000000
--- a/src/main/java/io/weaviate/client6/v1/internal/rest/TransportOptions.java
+++ /dev/null
@@ -1,12 +0,0 @@
-package io.weaviate.client6.v1.internal.rest;
-
-import java.util.Collections;
-import java.util.Map;
-
-public interface TransportOptions {
- String host();
-
- default Map headers() {
- return Collections.emptyMap();
- }
-}
diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java
index 0be90cba7..e1ab5a9ee 100644
--- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java
+++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java
@@ -237,7 +237,7 @@ public static Object[][] testCases() {
"Things",
Map.of("title", "ThingOne"),
Map.of("hasRef", List.of(Reference.uuids("ref-1"))),
- ObjectMetadata.of(meta -> meta.id("thing-1"))),
+ ObjectMetadata.of(meta -> meta.uuid("thing-1"))),
"""
{
"class": "Things",