diff --git a/pom.xml b/pom.xml index a5d5c559a..6503ced8a 100644 --- a/pom.xml +++ b/pom.xml @@ -65,7 +65,7 @@ 17 1.18.36 2.12.1 - 5.4.2 + 5.4.3 3.17.0 5.12.0 1.20.5 @@ -194,6 +194,7 @@ org.mock-server + mockserver-netty ${mock-server.version} test diff --git a/src/it/java/io/weaviate/containers/Weaviate.java b/src/it/java/io/weaviate/containers/Weaviate.java index c70342fd1..bcc1ba7d4 100644 --- a/src/it/java/io/weaviate/containers/Weaviate.java +++ b/src/it/java/io/weaviate/containers/Weaviate.java @@ -9,7 +9,6 @@ import org.testcontainers.weaviate.WeaviateContainer; -import io.weaviate.client6.v1.api.Config; import io.weaviate.client6.v1.api.WeaviateClient; public class Weaviate extends WeaviateContainer { @@ -29,8 +28,15 @@ public WeaviateClient getClient() { start(); } if (clientInstance == null) { - var config = new Config("http", getHttpHostAddress(), getGrpcHostAddress()); - clientInstance = new WeaviateClient(config); + try { + clientInstance = WeaviateClient.local( + conn -> conn + .host(getHost()) + .httpPort(getMappedPort(8080)) + .grpcPort(getMappedPort(50051))); + } catch (Exception e) { + throw new RuntimeException("create WeaviateClient for Weaviate container", e); + } } return clientInstance; } @@ -46,7 +52,6 @@ public static Weaviate.Builder custom() { public static class Builder { private String versionTag; private Set enableModules = new HashSet<>(); - private String defaultVectorizerModule; private boolean telemetry; private Map environment = new HashMap<>(); diff --git a/src/it/java/io/weaviate/integration/AuthorizationITest.java b/src/it/java/io/weaviate/integration/AuthorizationITest.java new file mode 100644 index 000000000..30422c67a --- /dev/null +++ b/src/it/java/io/weaviate/integration/AuthorizationITest.java @@ -0,0 +1,53 @@ +package io.weaviate.integration; + +import java.io.IOException; +import java.util.Collections; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.HttpRequest; + +import io.weaviate.ConcurrentTest; +import io.weaviate.client6.v1.api.Authorization; +import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; +import io.weaviate.client6.v1.internal.rest.Endpoint; +import io.weaviate.client6.v1.internal.rest.RestTransportOptions; + +public class AuthorizationITest extends ConcurrentTest { + private ClientAndServer mockServer; + + @Before + public void startMockServer() { + mockServer = ClientAndServer.startClientAndServer(8080); + } + + @Test + public void testAuthorization_apiKey() throws IOException { + var transportOptions = new RestTransportOptions( + "http", "localhost", 8080, + Collections.emptyMap(), Authorization.apiKey("my-api-key")); + + try (final var restClient = new DefaultRestTransport(transportOptions)) { + restClient.performRequest(null, Endpoint.of( + request -> "GET", + request -> "/", + (gson, request) -> null, + request -> null, + code -> code != 200, + (gson, response) -> null)); + } + + mockServer.verify( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/") + .withHeader("Authorization", "Bearer my-api-key")); + } + + @After + public void stopMockServer() { + mockServer.stop(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/Authorization.java b/src/main/java/io/weaviate/client6/v1/api/Authorization.java new file mode 100644 index 000000000..9bcfc2ed1 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/Authorization.java @@ -0,0 +1,10 @@ +package io.weaviate.client6.v1.api; + +import io.weaviate.client6.v1.internal.TokenProvider; +import io.weaviate.client6.v1.internal.TokenProvider.Token; + +public class Authorization { + public static TokenProvider apiKey(String apiKey) { + return TokenProvider.staticToken(new Token(apiKey)); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java index 2e7d9391d..bd005d167 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Config.java +++ b/src/main/java/io/weaviate/client6/v1/api/Config.java @@ -1,68 +1,166 @@ package io.weaviate.client6.v1.api; -import java.util.Collections; +import java.net.URI; +import java.util.HashMap; import java.util.Map; +import java.util.function.Function; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; -import io.weaviate.client6.v1.internal.rest.TransportOptions; - -public class Config { - private final String version = "v1"; - private final String scheme; - private final String httpHost; - private final String grpcHost; - private final Map headers = Collections.emptyMap(); - - public Config(String scheme, String httpHost, String grpcHost) { - this.scheme = scheme; - this.httpHost = httpHost; - this.grpcHost = grpcHost; +import io.weaviate.client6.v1.internal.rest.RestTransportOptions; + +public record Config( + String scheme, + String httpHost, + int httpPort, + String grpcHost, + int grpcPort, + Map headers, + TokenProvider tokenProvider) { + + public static Config of(String scheme, Function> fn) { + return fn.apply(new Custom(scheme)).build(); } - public String baseUrl() { - return scheme + "://" + httpHost + "/" + version; + public Config(Builder builder) { + this( + builder.scheme, + builder.httpHost, + builder.httpPort, + builder.grpcHost, + builder.grpcPort, + builder.headers, + builder.tokenProvider); } - public String grpcAddress() { - if (grpcHost.contains(":")) { - return grpcHost; - } - // FIXME: use secure port (433) if scheme == https - return String.format("%s:80", grpcHost); + public RestTransportOptions restTransportOptions() { + return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider); } - public TransportOptions rest() { - return new TransportOptions() { + public GrpcChannelOptions grpcTransportOptions() { + return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider); + } - @Override - public String host() { - return baseUrl(); - } + abstract static class Builder> implements ObjectBuilder { + // Required parameters; + protected final String scheme; + + protected String httpHost; + protected int httpPort; + protected String grpcHost; + protected int grpcPort; + protected TokenProvider tokenProvider; + protected Map headers = new HashMap<>(); - @Override - public Map headers() { - return headers; + protected Builder(String scheme) { + this.scheme = scheme; + } + + @SuppressWarnings("unchecked") + public SELF setHeader(String key, String value) { + this.headers.put(key, value); + return (SELF) this; + } + + @SuppressWarnings("unchecked") + public SELF setHeaders(Map headers) { + this.headers = Map.copyOf(headers); + return (SELF) this; + } + + private static final String HEADER_X_WEAVIATE_CLUSTER_URL = "X-Weaviate-Cluster-URL"; + + /** + * isWeaviateDomain returns true if the host matches weaviate.io, + * semi.technology, or weaviate.cloud domain. + */ + private static boolean isWeaviateDomain(String host) { + var lower = host.toLowerCase(); + return lower.contains("weaviate.io") || + lower.contains("semi.technology") || + lower.contains("weaviate.cloud"); + } + + @Override + public Config build() { + if (isWeaviateDomain(httpHost) && tokenProvider != null) { + setHeader(HEADER_X_WEAVIATE_CLUSTER_URL, "https://" + httpHost + ":" + httpPort); } + return new Config(this); + } + } + + public static class Local extends Builder { + public Local() { + super("http"); + host("localhost"); + httpPort(8080); + grpcPort(50051); + } + + public Local host(String host) { + this.httpHost = host; + this.grpcHost = host; + return this; + } - }; + public Local httpPort(int port) { + this.httpPort = port; + return this; + } + + public Local grpcPort(int port) { + this.grpcPort = port; + return this; + } } - public GrpcChannelOptions grpc() { - return new GrpcChannelOptions() { - @Override - public String host() { - return grpcAddress(); - } + public static class WeaviateCloud extends Builder { + public WeaviateCloud(String clusterUrl, TokenProvider tokenProvider) { + this(URI.create(clusterUrl), tokenProvider); + } - @Override - public boolean useTls() { - return scheme.equals("https"); - } + public WeaviateCloud(URI clusterUrl, TokenProvider tokenProvider) { + super("https"); + this.httpHost = clusterUrl.getHost(); + this.httpPort = 443; + this.grpcHost = "grpc-" + httpPort; + this.grpcPort = 443; + this.tokenProvider = tokenProvider; + } + } - @Override - public Map headers() { - return headers; - } - }; + public static class Custom extends Builder { + public Custom(String scheme) { + super(scheme); + httpPort(scheme == "https" ? 443 : 80); + grpcPort(scheme == "https" ? 443 : 80); + } + + public Custom httpHost(String host) { + this.httpHost = host; + return this; + } + + public Custom httpPort(int port) { + this.grpcPort = port; + return this; + } + + public Custom grpcHost(String host) { + this.grpcHost = host; + return this; + } + + public Custom grpcPort(int port) { + this.grpcPort = port; + return this; + } + + public Custom authorization(TokenProvider tokenProvider) { + this.tokenProvider = tokenProvider; + return this; + } } } diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java index f2ceeff24..7f41fbffc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java @@ -2,8 +2,10 @@ import java.io.Closeable; import java.io.IOException; +import java.util.function.Function; import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClient; +import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.DefaultGrpcTransport; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; @@ -20,8 +22,8 @@ public class WeaviateClient implements Closeable { public WeaviateClient(Config config) { this.config = config; - this.restTransport = new DefaultRestTransport(config.rest()); - this.grpcTransport = new DefaultGrpcTransport(config.grpc()); + this.restTransport = new DefaultRestTransport(config.restTransportOptions()); + this.grpcTransport = new DefaultGrpcTransport(config.grpcTransportOptions()); this.collections = new WeaviateCollectionsClient(restTransport, grpcTransport); } @@ -30,6 +32,25 @@ public WeaviateClientAsync async() { return new WeaviateClientAsync(config); } + public static WeaviateClient local() { + return local(ObjectBuilder.identity()); + } + + public static WeaviateClient local(Function> fn) { + var config = new Config.Local(); + return new WeaviateClient(fn.apply(config).build()); + } + + public static WeaviateClient wcd(String clusterUrl, String apiKey) { + return wcd(clusterUrl, apiKey, ObjectBuilder.identity()); + } + + public static WeaviateClient wcd(String clusterUrl, String apiKey, + Function> fn) { + var config = new Config.WeaviateCloud(clusterUrl, Authorization.apiKey(apiKey)); + return new WeaviateClient(fn.apply(config).build()); + } + @Override public void close() throws IOException { this.restTransport.close(); diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java index a33927292..af7d7acc3 100644 --- a/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java @@ -2,8 +2,10 @@ import java.io.Closeable; import java.io.IOException; +import java.util.function.Function; import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClientAsync; +import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.DefaultGrpcTransport; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; @@ -16,12 +18,31 @@ public class WeaviateClientAsync implements Closeable { public final WeaviateCollectionsClientAsync collections; public WeaviateClientAsync(Config config) { - this.restTransport = new DefaultRestTransport(config.rest()); - this.grpcTransport = new DefaultGrpcTransport(config.grpc()); + this.restTransport = new DefaultRestTransport(config.restTransportOptions()); + this.grpcTransport = new DefaultGrpcTransport(config.grpcTransportOptions()); this.collections = new WeaviateCollectionsClientAsync(restTransport, grpcTransport); } + public static WeaviateClientAsync local() { + return local(ObjectBuilder.identity()); + } + + public static WeaviateClientAsync local(Function> fn) { + var config = new Config.Local(); + return new WeaviateClientAsync(fn.apply(config).build()); + } + + public static WeaviateClientAsync wcd(String clusterUrl, String apiKey) { + return wcd(clusterUrl, apiKey, ObjectBuilder.identity()); + } + + public static WeaviateClientAsync wcd(String clusterUrl, String apiKey, + Function> fn) { + var config = new Config.WeaviateCloud(clusterUrl, Authorization.apiKey(apiKey)); + return new WeaviateClientAsync(fn.apply(config).build()); + } + @Override public void close() throws IOException { this.restTransport.close(); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java index d0aa0a815..6fdecbb84 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java @@ -13,7 +13,7 @@ public record ObjectMetadata( @SerializedName("lastUpdateTImeUnix") Long lastUpdatedAt) implements WeaviateMetadata { public ObjectMetadata(Builder builder) { - this(builder.id, builder.vectors, null, null); + this(builder.uuid, builder.vectors, null, null); } public static ObjectMetadata of(Function> fn) { @@ -21,11 +21,11 @@ public static ObjectMetadata of(Function> } public static class Builder implements ObjectBuilder { - private String id; + private String uuid; private Vectors vectors; - public Builder id(String id) { - this.id = id; + public Builder uuid(String uuid) { + this.uuid = uuid; return this; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java deleted file mode 100644 index b7f5f9128..000000000 --- a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java +++ /dev/null @@ -1,9 +0,0 @@ -package io.weaviate.client6.v1.api.collections; - -import java.util.List; - -import io.weaviate.client6.v1.api.collections.query.QueryMetadata; - -public record ObjectReference( - List, QueryMetadata>> objects) { -} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java index 84367b67e..2f7345f2a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java @@ -166,7 +166,7 @@ public void write(JsonWriter out, WeaviateObject value) throws IOExcept builder.properties(propertiesAdapter.fromJsonTree(trueProperties)); - metadata.id(object.get("id").getAsString()); + metadata.uuid(object.get("id").getAsString()); builder.metadata(metadata.build()); return builder.build(); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java index 6654c9c8d..f44515536 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java @@ -54,7 +54,7 @@ public Builder(String collectionName, T properties) { } public Builder uuid(String uuid) { - this.metadata.id(uuid); + this.metadata.uuid(uuid); return this; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java index 59cdee22a..d54678e67 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java @@ -16,7 +16,7 @@ public static class Builder implements ObjectBuilder { private Float certainty; private Vectors vectors; - public final Builder id(String uuid) { + public final Builder uuid(String uuid) { this.uuid = uuid; return this; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index fd40952e5..becf356b1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -88,7 +88,7 @@ private static WeaviateObject unmarshalResultObjec CollectionDescriptor descriptor) { var res = unmarshalReferences(propertiesResult, metadataResult, descriptor); var metadata = new QueryMetadata.Builder() - .id(res.metadata().uuid()) + .uuid(res.metadata().uuid()) .distance(metadataResult.getDistance()) .certainty(metadataResult.getCertainty()) .vectors(res.metadata().vectors()); @@ -146,7 +146,7 @@ private static WeaviateObject unmarshalReferences ObjectMetadata metadata = null; if (metadataResult != null) { var metadataBuilder = new ObjectMetadata.Builder() - .id(metadataResult.getId()); + .uuid(metadataResult.getId()); var vectors = new Vectors.Builder(); for (final var vector : metadataResult.getVectorsList()) { diff --git a/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java new file mode 100644 index 000000000..af69a456b --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java @@ -0,0 +1,13 @@ +package io.weaviate.client6.v1.internal; + +@FunctionalInterface +public interface TokenProvider { + Token getToken(); + + public record Token(String accessToken) { + } + + public static TokenProvider staticToken(Token token) { + return () -> token; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java new file mode 100644 index 000000000..03ee045b7 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/TransportOptions.java @@ -0,0 +1,41 @@ +package io.weaviate.client6.v1.internal; + +public abstract class TransportOptions { + private final String scheme; + private final String host; + private final int port; + private final TokenProvider tokenProvider; + private final H headers; + + protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider) { + this.scheme = scheme; + this.host = host; + this.port = port; + this.tokenProvider = tokenProvider; + this.headers = headers; + } + + public boolean isSecure() { + return scheme == "https"; + } + + public String scheme() { + return this.scheme; + } + + public String host() { + return this.host; + } + + public int port() { + return this.port; + } + + public TokenProvider tokenProvider() { + return this.tokenProvider; + } + + public H headers() { + return this.headers; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java index f071c9005..82aea9598 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java @@ -9,7 +9,6 @@ import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; -import io.grpc.Metadata; import io.grpc.stub.MetadataUtils; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; @@ -21,13 +20,23 @@ public final class DefaultGrpcTransport implements GrpcTransport { private final WeaviateBlockingStub blockingStub; private final WeaviateFutureStub futureStub; - private static final int HTTP_PORT = 80; - private static final int HTTPS_PORT = 443; + public DefaultGrpcTransport(GrpcChannelOptions transportOptions) { + this.channel = buildChannel(transportOptions); - public DefaultGrpcTransport(GrpcChannelOptions channelOptions) { - this.channel = buildChannel(channelOptions); - this.blockingStub = WeaviateGrpc.newBlockingStub(channel); - this.futureStub = WeaviateGrpc.newFutureStub(channel); + var blockingStub = WeaviateGrpc.newBlockingStub(channel) + .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); + + var futureStub = WeaviateGrpc.newFutureStub(channel) + .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); + + if (transportOptions.tokenProvider() != null) { + var credentials = new TokenCallCredentials(transportOptions.tokenProvider()); + blockingStub = blockingStub.withCallCredentials(credentials); + futureStub = futureStub.withCallCredentials(credentials); + } + + this.blockingStub = blockingStub; + this.futureStub = futureStub; } @Override @@ -70,24 +79,16 @@ public void onFailure(Throwable t) { return completable; } - private static ManagedChannel buildChannel(GrpcChannelOptions options) { - // var port = options.useTls() ? HTTPS_PORT : HTTP_PORT; - // var channel = ManagedChannelBuilder.forAddress(options.host(), port); - var channel = ManagedChannelBuilder.forTarget(options.host()); + private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) { + var channel = ManagedChannelBuilder.forAddress(transportOptions.host(), transportOptions.port()); - if (options.useTls()) { + if (transportOptions.isSecure()) { channel.useTransportSecurity(); } else { channel.usePlaintext(); } - var headers = new Metadata(); - for (final var header : options.headers().entrySet()) { - var key = Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER); - headers.put(key, header.getValue()); - - } - channel.intercept(MetadataUtils.newAttachHeadersInterceptor(headers)); + channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); return channel.build(); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java index 517345844..da67cb0c2 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java @@ -1,15 +1,24 @@ package io.weaviate.client6.v1.internal.grpc; -import java.util.Collections; import java.util.Map; -// TODO: unify with rest.TransportOptions? -public interface GrpcChannelOptions { - String host(); +import io.grpc.Metadata; +import io.weaviate.client6.v1.internal.TokenProvider; +import io.weaviate.client6.v1.internal.TransportOptions; - default Map headers() { - return Collections.emptyMap(); +public class GrpcChannelOptions extends TransportOptions { + public GrpcChannelOptions(String scheme, String host, int port, Map headers, + TokenProvider tokenProvider) { + super(scheme, host, port, buildMetadata(headers), tokenProvider); } - boolean useTls(); + private static final Metadata buildMetadata(Map headers) { + var metadata = new Metadata(); + for (var header : headers.entrySet()) { + metadata.put( + Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER), + header.getValue()); + } + return metadata; + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java new file mode 100644 index 000000000..c24a9093a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java @@ -0,0 +1,33 @@ +package io.weaviate.client6.v1.internal.grpc; + +import java.util.concurrent.Executor; + +import io.grpc.CallCredentials; +import io.grpc.Metadata; +import io.grpc.Status; +import io.weaviate.client6.v1.internal.TokenProvider; + +class TokenCallCredentials extends CallCredentials { + private static final Metadata.Key AUTHORIZATION = Metadata.Key.of("Authorization", + Metadata.ASCII_STRING_MARSHALLER); + + private final TokenProvider tokenProvider; + + TokenCallCredentials(TokenProvider tokenProvider) { + this.tokenProvider = tokenProvider; + } + + @Override + public void applyRequestMetadata(RequestInfo requestInfo, Executor executor, MetadataApplier metadataApplier) { + executor.execute(() -> { + try { + var headers = new Metadata(); + var token = tokenProvider.getToken().accessToken(); + headers.put(AUTHORIZATION, "Bearer " + token); + metadataApplier.apply(headers); + } catch (Exception e) { + metadataApplier.fail(Status.UNAUTHENTICATED.withCause(e)); + } + }); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/AuthorizationInterceptor.java b/src/main/java/io/weaviate/client6/v1/internal/rest/AuthorizationInterceptor.java new file mode 100644 index 000000000..9fe109d23 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/AuthorizationInterceptor.java @@ -0,0 +1,29 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.io.IOException; + +import org.apache.hc.core5.http.EntityDetails; +import org.apache.hc.core5.http.HttpException; +import org.apache.hc.core5.http.HttpRequest; +import org.apache.hc.core5.http.HttpRequestInterceptor; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.hc.core5.http.protocol.HttpContext; + +import io.weaviate.client6.v1.internal.TokenProvider; + +class AuthorizationInterceptor implements HttpRequestInterceptor { + private static final String AUTHORIZATION = "Authorization"; + + private final TokenProvider tokenProvider; + + AuthorizationInterceptor(TokenProvider tokenProvider) { + this.tokenProvider = tokenProvider; + } + + @Override + public void process(HttpRequest request, EntityDetails entity, HttpContext context) + throws HttpException, IOException { + var token = tokenProvider.getToken().accessToken(); + request.addHeader(new BasicHeader(AUTHORIZATION, "Bearer " + token)); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java index 470df5e89..f2b12f3b3 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java @@ -22,15 +22,29 @@ public class DefaultRestTransport implements RestTransport { private final CloseableHttpClient httpClient; private final CloseableHttpAsyncClient httpClientAsync; - private final TransportOptions transportOptions; + private final RestTransportOptions transportOptions; + // TODO: retire private static final Gson gson = new GsonBuilder().create(); - public DefaultRestTransport(TransportOptions options) { - this.transportOptions = options; - this.httpClient = HttpClients.createDefault(); - this.httpClientAsync = HttpAsyncClients.createDefault(); - httpClientAsync.start(); + public DefaultRestTransport(RestTransportOptions transportOptions) { + this.transportOptions = transportOptions; + + // TODO: doesn't make sense to spin up both? + var httpClient = HttpClients.custom() + .setDefaultHeaders(transportOptions.headers()); + var httpClientAsync = HttpAsyncClients.custom() + .setDefaultHeaders(transportOptions.headers()); + + if (transportOptions.tokenProvider() != null) { + var interceptor = new AuthorizationInterceptor(transportOptions.tokenProvider()); + httpClient.addRequestInterceptorFirst(interceptor); + httpClientAsync.addRequestInterceptorFirst(interceptor); + } + + this.httpClient = httpClient.build(); + this.httpClientAsync = httpClientAsync.build(); + this.httpClientAsync.start(); } @Override @@ -76,7 +90,7 @@ public void cancelled() { private SimpleHttpRequest prepareSimpleRequest(RequestT request, Endpoint endpoint) { var method = endpoint.method(request); - var uri = transportOptions.host() + endpoint.requestUrl(request); + var uri = transportOptions.baseUrl() + endpoint.requestUrl(request); // TODO: apply options; var body = endpoint.body(gson, request); @@ -89,7 +103,7 @@ private SimpleHttpRequest prepareSimpleRequest(RequestT request, Endp private ClassicHttpRequest prepareClassicRequest(RequestT request, Endpoint endpoint) { var method = endpoint.method(request); - var uri = transportOptions.host() + endpoint.requestUrl(request); + var uri = transportOptions.baseUrl() + endpoint.requestUrl(request); // TODO: apply options; var req = ClassicRequestBuilder.create(method).setUri(uri); diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java new file mode 100644 index 000000000..795695e72 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransportOptions.java @@ -0,0 +1,31 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Map; + +import org.apache.hc.core5.http.message.BasicHeader; + +import io.weaviate.client6.v1.internal.TokenProvider; +import io.weaviate.client6.v1.internal.TransportOptions; + +public final class RestTransportOptions extends TransportOptions> { + private static final String API_VERSION = "v1"; + + public RestTransportOptions(String scheme, String host, int port, Map headers, + TokenProvider tokenProvider) { + super(scheme, host, port, buildHeaders(headers), tokenProvider); + } + + private static final Collection buildHeaders(Map headers) { + var basicHeaders = new HashSet(); + for (var header : headers.entrySet()) { + basicHeaders.add(new BasicHeader(header.getKey(), header.getValue())); + } + return basicHeaders; + } + + public String baseUrl() { + return scheme() + "://" + host() + ":" + port() + "/" + API_VERSION; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/rest/TransportOptions.java deleted file mode 100644 index 9ddb3fa70..000000000 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/TransportOptions.java +++ /dev/null @@ -1,12 +0,0 @@ -package io.weaviate.client6.v1.internal.rest; - -import java.util.Collections; -import java.util.Map; - -public interface TransportOptions { - String host(); - - default Map headers() { - return Collections.emptyMap(); - } -} diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index 0be90cba7..e1ab5a9ee 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -237,7 +237,7 @@ public static Object[][] testCases() { "Things", Map.of("title", "ThingOne"), Map.of("hasRef", List.of(Reference.uuids("ref-1"))), - ObjectMetadata.of(meta -> meta.id("thing-1"))), + ObjectMetadata.of(meta -> meta.uuid("thing-1"))), """ { "class": "Things",