Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package io.druid.curator.discovery;

import com.google.common.base.Function;
import com.google.common.collect.Collections2;
import com.google.common.net.HostAndPort;
import com.metamx.common.lifecycle.LifecycleStart;
import com.metamx.common.lifecycle.LifecycleStop;
Expand All @@ -27,6 +29,8 @@
import org.apache.curator.x.discovery.ServiceProvider;

import java.io.IOException;
import java.util.Collection;
import java.util.Collections;

/**
*/
Expand All @@ -41,6 +45,40 @@ public ServerDiscoverySelector(ServiceProvider serviceProvider)
this.serviceProvider = serviceProvider;
}

private static final Function<ServiceInstance, Server> TO_SERVER = new Function<ServiceInstance, Server>()
{
@Override
public Server apply(final ServiceInstance instance)
{
return new Server()
{
@Override
public String getHost()
{
return HostAndPort.fromParts(getAddress(), getPort()).toString();
}

@Override
public String getAddress()
{
return instance.getAddress();
}

@Override
public int getPort()
{
return instance.getPort();
}

@Override
public String getScheme()
{
return "http";
}
};
}
};

@Override
public Server pick()
{
Expand All @@ -58,32 +96,18 @@ public Server pick()
return null;
}

return new Server()
{
@Override
public String getHost()
{
return HostAndPort.fromParts(getAddress(), getPort()).toString();
}

@Override
public String getAddress()
{
return instance.getAddress();
}

@Override
public int getPort()
{
return instance.getPort();
}
return TO_SERVER.apply(instance);
}

@Override
public String getScheme()
{
return "http";
}
};
public Collection<Server> getAll()
{
try {
return Collections2.transform(serviceProvider.getAllInstances(), TO_SERVER);
}
catch (Exception e) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blanket exception catching should handle InterruptedException properly

log.info(e, "Unable to get all instances");
return Collections.emptyList();
}
}

@LifecycleStart
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.MediaType;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLDecoder;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -110,18 +113,46 @@ public AsyncQueryForwardingServlet(
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
final boolean isSmile = SmileMediaTypes.APPLICATION_JACKSON_SMILE.equals(request.getContentType()) || APPLICATION_SMILE.equals(request.getContentType());
final boolean isSmile = SmileMediaTypes.APPLICATION_JACKSON_SMILE.equals(request.getContentType())
|| APPLICATION_SMILE.equals(request.getContentType());
final ObjectMapper objectMapper = isSmile ? smileMapper : jsonMapper;
request.setAttribute(OBJECTMAPPER_ATTRIBUTE, objectMapper);

String host = hostFinder.getDefaultHost();
request.setAttribute(HOST_ATTRIBUTE, host);

boolean isQuery = request.getMethod().equals(HttpMethod.POST.asString()) &&
request.getRequestURI().startsWith("/druid/v2");

// queries only exist for POST
if (isQuery) {
final String defaultHost = hostFinder.getDefaultHost();
request.setAttribute(HOST_ATTRIBUTE, defaultHost);

final boolean isQueryEndpoint = request.getRequestURI().startsWith("/druid/v2");

if (isQueryEndpoint && HttpMethod.DELETE.is(request.getMethod())) {
// query cancellation request
for (final String host : hostFinder.getAllHosts()) {
// send query cancellation to all brokers this query may have gone to
// to keep the code simple, the proxy servlet will also send a request to one of the default brokers
if (!host.equals(defaultHost)) {
// issue async requests
getHttpClient()
.newRequest(rewriteURI(request, host))
.method(HttpMethod.DELETE)
.send(
new Response.CompleteListener()
{
@Override
public void onComplete(Result result)
{
if (result.isFailed()) {
log.warn(
result.getFailure(),
"Failed to forward cancellation request to [%s]",
host
);
}
}
}
);
}
}
} else if (isQueryEndpoint && HttpMethod.POST.is(request.getMethod())) {
// query request
try {
Query inputQuery = objectMapper.readValue(request.getInputStream(), Query.class);
if (inputQuery != null) {
Expand Down Expand Up @@ -172,7 +203,8 @@ protected void customizeProxyRequest(Request proxyRequest, HttpServletRequest re
final ObjectMapper objectMapper = (ObjectMapper) request.getAttribute(OBJECTMAPPER_ATTRIBUTE);
try {
proxyRequest.content(new BytesContentProvider(objectMapper.writeValueAsBytes(query)));
} catch(JsonProcessingException e) {
}
catch (JsonProcessingException e) {
Throwables.propagate(e);
}
}
Expand All @@ -194,16 +226,29 @@ protected Response.Listener newProxyResponseListener(
@Override
protected URI rewriteURI(HttpServletRequest request)
{
final String host = (String) request.getAttribute(HOST_ATTRIBUTE);
final StringBuilder uri = new StringBuilder("http://");

uri.append(host);
uri.append(request.getRequestURI());
final String queryString = request.getQueryString();
if (queryString != null) {
uri.append("?").append(queryString);
return rewriteURI(request, (String) request.getAttribute(HOST_ATTRIBUTE));
}

protected URI rewriteURI(HttpServletRequest request, String host)
{
return makeURI(host, request.getRequestURI(), request.getQueryString());
}

protected static URI makeURI(String host, String requestURI, String rawQueryString)
{
try {
return new URI(
"http",
host,
requestURI,
rawQueryString == null ? null : URLDecoder.decode(rawQueryString, "UTF-8"),
null
);
}
catch (UnsupportedEncodingException | URISyntaxException e) {
log.error(e, "Unable to rewrite URI [%s]", e.getMessage());
throw Throwables.propagate(e);
}
return URI.create(uri.toString());
}

@Override
Expand Down Expand Up @@ -261,7 +306,7 @@ public void onComplete(Result result)
try {
emitter.emit(
DruidMetrics.makeQueryTimeMetric(jsonMapper, query, req.getRemoteAddr())
.build("query/time", requestTime)
.build("query/time", requestTime)
);

requestLogger.log(
Expand Down
39 changes: 33 additions & 6 deletions server/src/main/java/io/druid/server/router/QueryHostFinder.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package io.druid.server.router;

import com.google.common.base.Function;
import com.google.common.collect.FluentIterable;
import com.google.inject.Inject;
import com.metamx.common.ISE;
import com.metamx.common.Pair;
Expand All @@ -25,17 +27,18 @@
import io.druid.curator.discovery.ServerDiscoverySelector;
import io.druid.query.Query;

import java.util.Collection;
import java.util.concurrent.ConcurrentHashMap;

/**
*/
public class QueryHostFinder<T>
public class QueryHostFinder
{
private static EmittingLogger log = new EmittingLogger(QueryHostFinder.class);

private final TieredBrokerHostSelector hostSelector;

private final ConcurrentHashMap<String, Server> serverBackup = new ConcurrentHashMap<String, Server>();
private final ConcurrentHashMap<String, Server> serverBackup = new ConcurrentHashMap<>();

@Inject
public QueryHostFinder(
Expand All @@ -45,7 +48,7 @@ public QueryHostFinder(
this.hostSelector = hostSelector;
}

public Server findServer(Query<T> query)
public <T> Server findServer(Query<T> query)
{
final Pair<String, ServerDiscoverySelector> selected = hostSelector.select(query);
return findServerInner(selected);
Expand All @@ -57,7 +60,30 @@ public Server findDefaultServer()
return findServerInner(selected);
}

public String getHost(Query<T> query)
public Collection<String> getAllHosts()
{
return FluentIterable
.from((Collection<ServerDiscoverySelector>) hostSelector.getAllBrokers().values())
.transformAndConcat(
new Function<ServerDiscoverySelector, Iterable<Server>>()
{
@Override
public Iterable<Server> apply(ServerDiscoverySelector input)
{
return input.getAll();
}
}
).transform(new Function<Server, String>()
{
@Override
public String apply(Server input)
{
return input.getHost();
}
}).toList();
}

public <T> String getHost(Query<T> query)
{
Server server = findServer(query);

Expand All @@ -69,9 +95,10 @@ public String getHost(Query<T> query)
throw new ISE("No server found for query[%s]", query);
}

log.debug("Selected [%s]", server.getHost());
final String host = server.getHost();
log.debug("Selected [%s]", host);

return server.getHost();
return host;
}

public String getDefaultHost()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.joda.time.DateTime;
import org.joda.time.Interval;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -201,4 +202,9 @@ public Pair<String, ServerDiscoverySelector> getDefaultLookup()
final ServerDiscoverySelector retVal = selectorMap.get(brokerServiceName);
return new Pair<>(brokerServiceName, retVal);
}

public Map<String, ServerDiscoverySelector> getAllBrokers()
{
return Collections.unmodifiableMap(selectorMap);
}
}
Loading