Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.druid.guice.annotations.Client;
import io.druid.guice.annotations.Global;
import io.druid.initialization.DruidModule;
import io.druid.server.router.Router;

import javax.net.ssl.SSLContext;
import java.util.List;
Expand All @@ -46,5 +47,6 @@ public void configure(Binder binder)
binder.bind(SSLContext.class).toProvider(SSLContextProvider.class);
binder.bind(SSLContext.class).annotatedWith(Global.class).toProvider(SSLContextProvider.class);
binder.bind(SSLContext.class).annotatedWith(Client.class).toProvider(SSLContextProvider.class);
binder.bind(SSLContext.class).annotatedWith(Router.class).toProvider(SSLContextProvider.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.inject.Provider;
import com.metamx.emitter.EmittingLogger;
import com.metamx.emitter.service.ServiceEmitter;
import io.druid.client.selector.Server;
import io.druid.guice.annotations.Json;
import io.druid.guice.annotations.Smile;
import io.druid.guice.http.DruidHttpClientConfig;
Expand Down Expand Up @@ -72,6 +73,7 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
private static final String APPLICATION_SMILE = "application/smile";

private static final String HOST_ATTRIBUTE = "io.druid.proxy.to.host";
private static final String SCHEME_ATTRIBUTE = "io.druid.proxy.to.host.scheme";
private static final String QUERY_ATTRIBUTE = "io.druid.proxy.query";
private static final String OBJECTMAPPER_ATTRIBUTE = "io.druid.proxy.objectMapper";

Expand Down Expand Up @@ -169,35 +171,31 @@ protected void service(HttpServletRequest request, HttpServletResponse response)
final ObjectMapper objectMapper = isSmile ? smileMapper : jsonMapper;
request.setAttribute(OBJECTMAPPER_ATTRIBUTE, objectMapper);

final String defaultHost = hostFinder.getDefaultHost();
request.setAttribute(HOST_ATTRIBUTE, defaultHost);
final Server defaultServer = hostFinder.getDefaultServer();
request.setAttribute(HOST_ATTRIBUTE, defaultServer.getHost());
request.setAttribute(SCHEME_ATTRIBUTE, defaultServer.getScheme());

final boolean isQueryEndpoint = request.getRequestURI().startsWith("/druid/v2");

if (isQueryEndpoint && HttpMethod.DELETE.is(request.getMethod())) {
// query cancellation request
for (final String host : hostFinder.getAllHosts()) {
for (final Server server: hostFinder.getAllServers()) {
// send query cancellation to all brokers this query may have gone to
// to keep the code simple, the proxy servlet will also send a request to one of the default brokers
if (!host.equals(defaultHost)) {
if (!server.getHost().equals(defaultServer.getHost())) {
// issue async requests
broadcastClient
.newRequest(rewriteURI(request, host))
.newRequest(rewriteURI(request, server.getScheme(), server.getHost()))
.method(HttpMethod.DELETE)
.timeout(CANCELLATION_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS)
.send(
new Response.CompleteListener()
{
@Override
public void onComplete(Result result)
{
if (result.isFailed()) {
log.warn(
result.getFailure(),
"Failed to forward cancellation request to [%s]",
host
);
}
result -> {
if (result.isFailed()) {
log.warn(
result.getFailure(),
"Failed to forward cancellation request to [%s]",
server.getHost()
);
}
}
);
Expand All @@ -209,7 +207,9 @@ public void onComplete(Result result)
try {
Query inputQuery = objectMapper.readValue(request.getInputStream(), Query.class);
if (inputQuery != null) {
request.setAttribute(HOST_ATTRIBUTE, hostFinder.getHost(inputQuery));
final Server server = hostFinder.getServer(inputQuery);
request.setAttribute(HOST_ATTRIBUTE, server.getHost());
request.setAttribute(SCHEME_ATTRIBUTE, server.getScheme());
if (inputQuery.getId() == null) {
inputQuery = inputQuery.withId(UUID.randomUUID().toString());
}
Expand Down Expand Up @@ -289,19 +289,19 @@ protected Response.Listener newProxyResponseListener(
@Override
protected String rewriteTarget(HttpServletRequest request)
{
return rewriteURI(request, (String) request.getAttribute(HOST_ATTRIBUTE)).toString();
return rewriteURI(request, (String) request.getAttribute(SCHEME_ATTRIBUTE), (String) request.getAttribute(HOST_ATTRIBUTE)).toString();
}

protected URI rewriteURI(HttpServletRequest request, String host)
protected URI rewriteURI(HttpServletRequest request, String scheme, String host)
{
return makeURI(host, request.getRequestURI(), request.getQueryString());
return makeURI(scheme, host, request.getRequestURI(), request.getQueryString());
}

protected static URI makeURI(String host, String requestURI, String rawQueryString)
protected static URI makeURI(String scheme, String host, String requestURI, String rawQueryString)
{
try {
return new URI(
"http",
scheme,
host,
requestURI,
rawQueryString == null ? null : URLDecoder.decode(rawQueryString, "UTF-8"),
Expand Down
38 changes: 10 additions & 28 deletions server/src/main/java/io/druid/server/router/QueryHostFinder.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

package io.druid.server.router;

import com.google.common.base.Function;
import com.google.common.collect.FluentIterable;
import com.google.inject.Inject;
import com.metamx.emitter.EmittingLogger;
import io.druid.client.selector.Server;
Expand All @@ -31,6 +29,7 @@
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
*/
Expand Down Expand Up @@ -62,30 +61,14 @@ public Server findDefaultServer()
return findServerInner(selected);
}

public Collection<String> getAllHosts()
public Collection<Server> getAllServers()
{
return FluentIterable
.from((Collection<List<Server>>) hostSelector.getAllBrokers().values())
.transformAndConcat(
new Function<List<Server>, Iterable<Server>>()
{
@Override
public Iterable<Server> apply(List<Server> input)
{
return input;
}
}
).transform(new Function<Server, String>()
{
@Override
public String apply(Server input)
{
return input.getHost();
}
}).toList();
return ((Collection<List<Server>>) hostSelector.getAllBrokers().values()).stream()
.flatMap(Collection::stream)
.collect(Collectors.toList());
}

public <T> String getHost(Query<T> query)
public <T> Server getServer(Query<T> query)
{
Server server = findServer(query);

Expand All @@ -97,13 +80,12 @@ public <T> String getHost(Query<T> query)
throw new ISE("No server found for query[%s]", query);
}

final String host = server.getHost();
log.debug("Selected [%s]", host);
log.debug("Selected [%s]", server.getHost());

return host;
return server;
}

public String getDefaultHost()
public Server getDefaultServer()
{
Server server = findDefaultServer();

Expand All @@ -115,7 +97,7 @@ public String getDefaultHost()
throw new ISE("No default server found!");
}

return server.getHost();
return server;
}

private Server findServerInner(final Pair<String, Server> selected)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.servlet.GuiceFilter;

import io.druid.common.utils.SocketUtil;
import io.druid.guice.GuiceInjectors;
import io.druid.guice.Jerseys;
Expand Down Expand Up @@ -108,7 +107,9 @@ protected Injector setupInjector()
public void configure(Binder binder)
{
JsonConfigProvider.bindInstance(
binder, Key.get(DruidNode.class, Self.class), new DruidNode("test", "localhost", null, null, new ServerConfig())
binder,
Key.get(DruidNode.class, Self.class),
new DruidNode("test", "localhost", null, null, new ServerConfig())
);
binder.bind(JettyServerInitializer.class).to(ProxyJettyServerInit.class).in(LazySingleton.class);
Jerseys.addResource(binder, SlowResource.class);
Expand Down Expand Up @@ -197,24 +198,24 @@ public void initialize(Server server, Injector injector)
final QueryHostFinder hostFinder = new QueryHostFinder(null)
{
@Override
public String getHost(Query query)
public io.druid.client.selector.Server getServer(Query query)
{
return "localhost:" + node.getPlaintextPort();
return new TestServer("http", "localhost", node.getPlaintextPort());
}

@Override
public String getDefaultHost()
public io.druid.client.selector.Server getDefaultServer()
{
return "localhost:" + node.getPlaintextPort();
return new TestServer("http", "localhost", node.getPlaintextPort());
}

@Override
public Collection<String> getAllHosts()
public Collection<io.druid.client.selector.Server> getAllServers()
{
return ImmutableList.of(
"localhost:" + node.getPlaintextPort(),
"localhost:" + port1,
"localhost:" + port2
new TestServer("http", "localhost", node.getPlaintextPort()),
new TestServer("http", "localhost", port1),
new TestServer("http", "localhost", port2)
);
}
};
Expand All @@ -241,9 +242,9 @@ public void log(RequestLogLine requestLogLine) throws IOException
)
{
@Override
protected URI rewriteURI(HttpServletRequest request, String host)
protected URI rewriteURI(HttpServletRequest request, String scheme, String host)
{
String uri = super.rewriteURI(request, host).toString();
String uri = super.rewriteURI(request, scheme, host).toString();
if (uri.contains("/druid/v2")) {
return URI.create(uri.replace("/druid/v2", "/default"));
}
Expand Down Expand Up @@ -272,14 +273,15 @@ public void testRewriteURI() throws Exception
// test params
Assert.assertEquals(
new URI("http://localhost:1234/some/path?param=1"),
AsyncQueryForwardingServlet.makeURI("localhost:1234", "/some/path", "param=1")
AsyncQueryForwardingServlet.makeURI("http", "localhost:1234", "/some/path", "param=1")
);

// HttpServletRequest.getQueryString returns encoded form
// use ascii representation in case URI is using non-ascii characters
Assert.assertEquals(
"http://[2a00:1450:4007:805::1007]:1234/some/path?param=1&param2=%E2%82%AC",
AsyncQueryForwardingServlet.makeURI(
"http",
HostAndPort.fromParts("2a00:1450:4007:805::1007", 1234).toString(),
"/some/path",
"param=1&param2=%E2%82%AC"
Expand All @@ -289,7 +291,46 @@ public void testRewriteURI() throws Exception
// test null query
Assert.assertEquals(
new URI("http://localhost/"),
AsyncQueryForwardingServlet.makeURI("localhost", "/", null)
AsyncQueryForwardingServlet.makeURI("http", "localhost", "/", null)
);
}

private static class TestServer implements io.druid.client.selector.Server
{

private final String scheme;
private final String address;
private final int port;

public TestServer(String scheme, String address, int port)
{
this.scheme = scheme;
this.address = address;
this.port = port;
}

@Override
public String getScheme()
{
return scheme;
}

@Override
public String getHost()
{
return address + ":" + port;
}

@Override
public String getAddress()
{
return address;
}

@Override
public int getPort()
{
return port;
}
}
}