Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
<maven.compiler.release>17</maven.compiler.release>
<lombok.version>1.18.36</lombok.version>
<gson.version>2.12.1</gson.version>
<httpclient.version>5.4.2</httpclient.version>
<httpclient.version>5.4.3</httpclient.version>
<lang3.version>3.17.0</lang3.version>
<junit.version>5.12.0</junit.version>
<testcontainers.version>1.20.5</testcontainers.version>
Expand Down Expand Up @@ -194,6 +194,7 @@
</dependency>
<dependency>
<groupId>org.mock-server</groupId>
<!-- TODO: check if we can reduce JAR size by using mockserver-netty-no-dependencies -->
<artifactId>mockserver-netty</artifactId>
<version>${mock-server.version}</version>
<scope>test</scope>
Expand Down
13 changes: 9 additions & 4 deletions src/it/java/io/weaviate/containers/Weaviate.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.testcontainers.weaviate.WeaviateContainer;

import io.weaviate.client6.v1.api.Config;
import io.weaviate.client6.v1.api.WeaviateClient;

public class Weaviate extends WeaviateContainer {
Expand All @@ -29,8 +28,15 @@ public WeaviateClient getClient() {
start();
}
if (clientInstance == null) {
var config = new Config("http", getHttpHostAddress(), getGrpcHostAddress());
clientInstance = new WeaviateClient(config);
try {
clientInstance = WeaviateClient.local(
conn -> conn
.host(getHost())
.httpPort(getMappedPort(8080))
.grpcPort(getMappedPort(50051)));
} catch (Exception e) {
throw new RuntimeException("create WeaviateClient for Weaviate container", e);
}
}
return clientInstance;
}
Expand All @@ -46,7 +52,6 @@ public static Weaviate.Builder custom() {
public static class Builder {
private String versionTag;
private Set<String> enableModules = new HashSet<>();
private String defaultVectorizerModule;
private boolean telemetry;

private Map<String, String> environment = new HashMap<>();
Expand Down
53 changes: 53 additions & 0 deletions src/it/java/io/weaviate/integration/AuthorizationITest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.weaviate.integration;

import java.io.IOException;
import java.util.Collections;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockserver.integration.ClientAndServer;
import org.mockserver.model.HttpRequest;

import io.weaviate.ConcurrentTest;
import io.weaviate.client6.v1.api.Authorization;
import io.weaviate.client6.v1.internal.rest.DefaultRestTransport;
import io.weaviate.client6.v1.internal.rest.Endpoint;
import io.weaviate.client6.v1.internal.rest.RestTransportOptions;

public class AuthorizationITest extends ConcurrentTest {
private ClientAndServer mockServer;

@Before
public void startMockServer() {
mockServer = ClientAndServer.startClientAndServer(8080);
}

@Test
public void testAuthorization_apiKey() throws IOException {
var transportOptions = new RestTransportOptions(
"http", "localhost", 8080,
Collections.emptyMap(), Authorization.apiKey("my-api-key"));

try (final var restClient = new DefaultRestTransport(transportOptions)) {
restClient.performRequest(null, Endpoint.of(
request -> "GET",
request -> "/",
(gson, request) -> null,
request -> null,
code -> code != 200,
(gson, response) -> null));
}

mockServer.verify(
HttpRequest.request()
.withMethod("GET")
.withPath("/v1/")
.withHeader("Authorization", "Bearer my-api-key"));
}

@After
public void stopMockServer() {
mockServer.stop();
}
}
10 changes: 10 additions & 0 deletions src/main/java/io/weaviate/client6/v1/api/Authorization.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package io.weaviate.client6.v1.api;

import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.TokenProvider.Token;

public class Authorization {
public static TokenProvider apiKey(String apiKey) {
return TokenProvider.staticToken(new Token(apiKey));
}
}
192 changes: 145 additions & 47 deletions src/main/java/io/weaviate/client6/v1/api/Config.java
Original file line number Diff line number Diff line change
@@ -1,68 +1,166 @@
package io.weaviate.client6.v1.api;

import java.util.Collections;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions;
import io.weaviate.client6.v1.internal.rest.TransportOptions;

public class Config {
private final String version = "v1";
private final String scheme;
private final String httpHost;
private final String grpcHost;
private final Map<String, String> headers = Collections.emptyMap();

public Config(String scheme, String httpHost, String grpcHost) {
this.scheme = scheme;
this.httpHost = httpHost;
this.grpcHost = grpcHost;
import io.weaviate.client6.v1.internal.rest.RestTransportOptions;

public record Config(
String scheme,
String httpHost,
int httpPort,
String grpcHost,
int grpcPort,
Map<String, String> headers,
TokenProvider tokenProvider) {

public static Config of(String scheme, Function<Custom, ObjectBuilder<Config>> fn) {
return fn.apply(new Custom(scheme)).build();
}

public String baseUrl() {
return scheme + "://" + httpHost + "/" + version;
public Config(Builder<?> builder) {
this(
builder.scheme,
builder.httpHost,
builder.httpPort,
builder.grpcHost,
builder.grpcPort,
builder.headers,
builder.tokenProvider);
}

public String grpcAddress() {
if (grpcHost.contains(":")) {
return grpcHost;
}
// FIXME: use secure port (433) if scheme == https
return String.format("%s:80", grpcHost);
public RestTransportOptions restTransportOptions() {
return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider);
}

public TransportOptions rest() {
return new TransportOptions() {
public GrpcChannelOptions grpcTransportOptions() {
return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider);
}

@Override
public String host() {
return baseUrl();
}
abstract static class Builder<SELF extends Builder<SELF>> implements ObjectBuilder<Config> {
// Required parameters;
protected final String scheme;

protected String httpHost;
protected int httpPort;
protected String grpcHost;
protected int grpcPort;
protected TokenProvider tokenProvider;
protected Map<String, String> headers = new HashMap<>();

@Override
public Map<String, String> headers() {
return headers;
protected Builder(String scheme) {
this.scheme = scheme;
}

@SuppressWarnings("unchecked")
public SELF setHeader(String key, String value) {
this.headers.put(key, value);
return (SELF) this;
}

@SuppressWarnings("unchecked")
public SELF setHeaders(Map<String, String> headers) {
this.headers = Map.copyOf(headers);
return (SELF) this;
}

private static final String HEADER_X_WEAVIATE_CLUSTER_URL = "X-Weaviate-Cluster-URL";

/**
* isWeaviateDomain returns true if the host matches weaviate.io,
* semi.technology, or weaviate.cloud domain.
*/
private static boolean isWeaviateDomain(String host) {
var lower = host.toLowerCase();
return lower.contains("weaviate.io") ||
lower.contains("semi.technology") ||
lower.contains("weaviate.cloud");
}

@Override
public Config build() {
if (isWeaviateDomain(httpHost) && tokenProvider != null) {
setHeader(HEADER_X_WEAVIATE_CLUSTER_URL, "https://" + httpHost + ":" + httpPort);
}
return new Config(this);
}
}

public static class Local extends Builder<Local> {
public Local() {
super("http");
host("localhost");
httpPort(8080);
grpcPort(50051);
}

public Local host(String host) {
this.httpHost = host;
this.grpcHost = host;
return this;
}

};
public Local httpPort(int port) {
this.httpPort = port;
return this;
}

public Local grpcPort(int port) {
this.grpcPort = port;
return this;
}
}

public GrpcChannelOptions grpc() {
return new GrpcChannelOptions() {
@Override
public String host() {
return grpcAddress();
}
public static class WeaviateCloud extends Builder<WeaviateCloud> {
public WeaviateCloud(String clusterUrl, TokenProvider tokenProvider) {
this(URI.create(clusterUrl), tokenProvider);
}

@Override
public boolean useTls() {
return scheme.equals("https");
}
public WeaviateCloud(URI clusterUrl, TokenProvider tokenProvider) {
super("https");
this.httpHost = clusterUrl.getHost();
this.httpPort = 443;
this.grpcHost = "grpc-" + httpPort;
this.grpcPort = 443;
this.tokenProvider = tokenProvider;
}
}

@Override
public Map<String, String> headers() {
return headers;
}
};
public static class Custom extends Builder<Custom> {
public Custom(String scheme) {
super(scheme);
httpPort(scheme == "https" ? 443 : 80);
grpcPort(scheme == "https" ? 443 : 80);
}

public Custom httpHost(String host) {
this.httpHost = host;
return this;
}

public Custom httpPort(int port) {
this.grpcPort = port;
return this;
}

public Custom grpcHost(String host) {
this.grpcHost = host;
return this;
}

public Custom grpcPort(int port) {
this.grpcPort = port;
return this;
}

public Custom authorization(TokenProvider tokenProvider) {
this.tokenProvider = tokenProvider;
return this;
}
}
}
25 changes: 23 additions & 2 deletions src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import java.io.Closeable;
import java.io.IOException;
import java.util.function.Function;

import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClient;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.DefaultGrpcTransport;
import io.weaviate.client6.v1.internal.grpc.GrpcTransport;
import io.weaviate.client6.v1.internal.rest.DefaultRestTransport;
Expand All @@ -20,8 +22,8 @@ public class WeaviateClient implements Closeable {

public WeaviateClient(Config config) {
this.config = config;
this.restTransport = new DefaultRestTransport(config.rest());
this.grpcTransport = new DefaultGrpcTransport(config.grpc());
this.restTransport = new DefaultRestTransport(config.restTransportOptions());
this.grpcTransport = new DefaultGrpcTransport(config.grpcTransportOptions());

this.collections = new WeaviateCollectionsClient(restTransport, grpcTransport);
}
Expand All @@ -30,6 +32,25 @@ public WeaviateClientAsync async() {
return new WeaviateClientAsync(config);
}

public static WeaviateClient local() {
return local(ObjectBuilder.identity());
}

public static WeaviateClient local(Function<Config.Local, ObjectBuilder<Config>> fn) {
var config = new Config.Local();
return new WeaviateClient(fn.apply(config).build());
}

public static WeaviateClient wcd(String clusterUrl, String apiKey) {
return wcd(clusterUrl, apiKey, ObjectBuilder.identity());
}

public static WeaviateClient wcd(String clusterUrl, String apiKey,
Function<Config.WeaviateCloud, ObjectBuilder<Config>> fn) {
var config = new Config.WeaviateCloud(clusterUrl, Authorization.apiKey(apiKey));
return new WeaviateClient(fn.apply(config).build());
}

@Override
public void close() throws IOException {
this.restTransport.close();
Expand Down
Loading