Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ public void testGetActiveTaskRedactsPassword() throws JsonProcessingException
final String password = "AbCd_1234";
final ObjectMapper mapper = getObjectMapper();

final HttpInputSourceConfig httpInputSourceConfig = new HttpInputSourceConfig(Collections.singleton("http"));
final HttpInputSourceConfig httpInputSourceConfig = new HttpInputSourceConfig(Collections.singleton("http"), null);
mapper.setInjectableValues(new InjectableValues.Std()
.addValue(HttpInputSourceConfig.class, httpInputSourceConfig)
.addValue(ObjectMapper.class, new DefaultObjectMapper())
Expand Down Expand Up @@ -562,6 +562,7 @@ public void testGetActiveTaskRedactsPassword() throws JsonProcessingException
"user",
new DefaultPasswordProvider(password),
null,
null,
httpInputSourceConfig),
new NoopInputFormat(),
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.net.URI;
import java.net.URLConnection;
import java.util.Base64;
import java.util.Map;

public class HttpEntity extends RetryingInputEntity
{
Expand All @@ -45,15 +46,19 @@ public class HttpEntity extends RetryingInputEntity
@Nullable
private final PasswordProvider httpAuthenticationPasswordProvider;

private final Map<String, String> requestHeaders;

HttpEntity(
URI uri,
@Nullable String httpAuthenticationUsername,
@Nullable PasswordProvider httpAuthenticationPasswordProvider
@Nullable PasswordProvider httpAuthenticationPasswordProvider,
@Nullable Map<String, String> requestHeaders
)
{
this.uri = uri;
this.httpAuthenticationUsername = httpAuthenticationUsername;
this.httpAuthenticationPasswordProvider = httpAuthenticationPasswordProvider;
this.requestHeaders = requestHeaders;
}

@Override
Expand All @@ -65,7 +70,7 @@ public URI getUri()
@Override
protected InputStream readFrom(long offset) throws IOException
{
return openInputStream(uri, httpAuthenticationUsername, httpAuthenticationPasswordProvider, offset);
return openInputStream(uri, httpAuthenticationUsername, httpAuthenticationPasswordProvider, offset, requestHeaders);
}

@Override
Expand All @@ -80,10 +85,15 @@ public Predicate<Throwable> getRetryCondition()
return t -> t instanceof IOException;
}

public static InputStream openInputStream(URI object, String userName, PasswordProvider passwordProvider, long offset)
public static InputStream openInputStream(URI object, String userName, PasswordProvider passwordProvider, long offset, final Map<String, String> requestHeaders)
throws IOException
{
final URLConnection urlConnection = object.toURL().openConnection();
if (requestHeaders != null && requestHeaders.size() > 0) {
for (Map.Entry<String, String> entry : requestHeaders.entrySet()) {
urlConnection.addRequestProperty(entry.getKey(), entry.getValue());
}
}
if (!Strings.isNullOrEmpty(userName) && passwordProvider != null) {
String userPass = userName + ":" + passwordProvider.getPassword();
String basicAuthString = "Basic " + Base64.getEncoder().encodeToString(StringUtils.toUtf8(userPass));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.druid.data.input.impl.systemfield.SystemFieldDecoratorFactory;
import org.apache.druid.data.input.impl.systemfield.SystemFieldInputSource;
import org.apache.druid.data.input.impl.systemfield.SystemFields;
import org.apache.druid.error.InvalidInput;
import org.apache.druid.java.util.common.CloseableIterators;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
Expand All @@ -47,6 +48,7 @@
import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;
Expand All @@ -64,13 +66,15 @@ public class HttpInputSource
private final PasswordProvider httpAuthenticationPasswordProvider;
private final SystemFields systemFields;
private final HttpInputSourceConfig config;
private final Map<String, String> requestHeaders;

@JsonCreator
public HttpInputSource(
@JsonProperty("uris") List<URI> uris,
@JsonProperty("httpAuthenticationUsername") @Nullable String httpAuthenticationUsername,
@JsonProperty("httpAuthenticationPassword") @Nullable PasswordProvider httpAuthenticationPasswordProvider,
@JsonProperty(SYSTEM_FIELDS_PROPERTY) @Nullable SystemFields systemFields,
@JsonProperty("requestHeaders") @Nullable Map<String, String> requestHeaders,
@JacksonInject HttpInputSourceConfig config
)
{
Expand All @@ -80,17 +84,11 @@ public HttpInputSource(
this.httpAuthenticationUsername = httpAuthenticationUsername;
this.httpAuthenticationPasswordProvider = httpAuthenticationPasswordProvider;
this.systemFields = systemFields == null ? SystemFields.none() : systemFields;
this.requestHeaders = requestHeaders == null ? Collections.emptyMap() : requestHeaders;
throwIfForbiddenHeaders(config, this.requestHeaders);
this.config = config;
}

@JsonIgnore
@Nonnull
@Override
public Set<String> getTypes()
{
return Collections.singleton(TYPE_KEY);
}

public static void throwIfInvalidProtocols(HttpInputSourceConfig config, List<URI> uris)
{
for (URI uri : uris) {
Expand All @@ -100,6 +98,27 @@ public static void throwIfInvalidProtocols(HttpInputSourceConfig config, List<UR
}
}

public static void throwIfForbiddenHeaders(HttpInputSourceConfig config, Map<String, String> requestHeaders)
{
if (config.getAllowedHeaders().size() > 0) {
for (Map.Entry<String, String> entry : requestHeaders.entrySet()) {
if (!config.getAllowedHeaders().contains(StringUtils.toLowerCase(entry.getKey()))) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are the keys in allowedHeaders always lower case?

throw InvalidInput.exception("Got forbidden header %s, allowed headers are only %s ",
entry.getKey(), config.getAllowedHeaders()
);
}
}
}
}

@JsonIgnore
@Nonnull
@Override
public Set<String> getTypes()
{
return Collections.singleton(TYPE_KEY);
}

@JsonProperty
public List<URI> getUris()
{
Expand Down Expand Up @@ -128,6 +147,14 @@ public PasswordProvider getHttpAuthenticationPasswordProvider()
return httpAuthenticationPasswordProvider;
}

@Nullable
@JsonProperty("requestHeaders")
@JsonInclude(JsonInclude.Include.NON_NULL)
public Map<String, String> getRequestHeaders()
{
return requestHeaders;
}

@Override
public Stream<InputSplit<URI>> createSplits(InputFormat inputFormat, @Nullable SplitHintSpec splitHintSpec)
{
Expand All @@ -148,6 +175,7 @@ public SplittableInputSource<URI> withSplit(InputSplit<URI> split)
httpAuthenticationUsername,
httpAuthenticationPasswordProvider,
systemFields,
requestHeaders,
config
);
}
Expand Down Expand Up @@ -181,7 +209,8 @@ protected InputSourceReader formattableReader(
createSplits(inputFormat, null).map(split -> new HttpEntity(
split.get(),
httpAuthenticationUsername,
httpAuthenticationPasswordProvider
httpAuthenticationPasswordProvider,
requestHeaders
)).iterator()
),
SystemFieldDecoratorFactory.fromInputSource(this),
Expand All @@ -203,13 +232,21 @@ public boolean equals(Object o)
&& Objects.equals(httpAuthenticationUsername, that.httpAuthenticationUsername)
&& Objects.equals(httpAuthenticationPasswordProvider, that.httpAuthenticationPasswordProvider)
&& Objects.equals(systemFields, that.systemFields)
&& Objects.equals(requestHeaders, that.requestHeaders)
&& Objects.equals(config, that.config);
}

@Override
public int hashCode()
{
return Objects.hash(uris, httpAuthenticationUsername, httpAuthenticationPasswordProvider, systemFields, config);
return Objects.hash(
uris,
httpAuthenticationUsername,
httpAuthenticationPasswordProvider,
systemFields,
requestHeaders,
config
);
}

@Override
Expand All @@ -226,6 +263,7 @@ public String toString()
", httpAuthenticationUsername=" + httpAuthenticationUsername +
", httpAuthenticationPasswordProvider=" + httpAuthenticationPasswordProvider +
(systemFields.getFields().isEmpty() ? "" : ", systemFields=" + systemFields) +
", requestHeaders = " + requestHeaders +
"}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.druid.java.util.common.StringUtils;

import javax.annotation.Nullable;
import java.util.Collections;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
Expand All @@ -38,21 +39,33 @@ public class HttpInputSourceConfig
@JsonProperty
private final Set<String> allowedProtocols;

@JsonProperty
private final Set<String> allowedHeaders;

@JsonCreator
public HttpInputSourceConfig(
@JsonProperty("allowedProtocols") @Nullable Set<String> allowedProtocols
@JsonProperty("allowedProtocols") @Nullable Set<String> allowedProtocols,
@JsonProperty("allowedHeaders") @Nullable Set<String> allowedHeaders
)
{
this.allowedProtocols = allowedProtocols == null || allowedProtocols.isEmpty()
? DEFAULT_ALLOWED_PROTOCOLS
: allowedProtocols.stream().map(StringUtils::toLowerCase).collect(Collectors.toSet());
this.allowedHeaders = allowedHeaders == null
? Collections.emptySet()
: allowedHeaders.stream().map(StringUtils::toLowerCase).collect(Collectors.toSet());
}

public Set<String> getAllowedProtocols()
{
return allowedProtocols;
}

public Set<String> getAllowedHeaders()
{
return allowedHeaders;
}

@Override
public boolean equals(Object o)
{
Expand All @@ -63,20 +76,24 @@ public boolean equals(Object o)
return false;
}
HttpInputSourceConfig that = (HttpInputSourceConfig) o;
return Objects.equals(allowedProtocols, that.allowedProtocols);
return Objects.equals(allowedProtocols, that.allowedProtocols) && Objects.equals(
allowedHeaders,
that.allowedHeaders
);
}

@Override
public int hashCode()
{
return Objects.hash(allowedProtocols);
return Objects.hash(allowedProtocols, allowedHeaders);
}

@Override
public String toString()
{
return "HttpInputSourceConfig{" +
"allowedProtocols=" + allowedProtocols +
", allowedHeaders=" + allowedHeaders +
'}';
}
}
Expand Down
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add one test with non-empty headers map?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added test

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

package org.apache.druid.data.input.impl;

import com.google.common.collect.ImmutableMap;
import com.google.common.net.HttpHeaders;
import com.sun.net.httpserver.Headers;
import com.sun.net.httpserver.HttpServer;
import org.apache.commons.io.IOUtils;
import org.apache.druid.java.util.common.StringUtils;
Expand All @@ -42,6 +44,8 @@
import java.net.URL;
import java.net.URLConnection;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;

public class HttpEntityTest
{
Expand Down Expand Up @@ -96,8 +100,61 @@ public void testOpenInputStream() throws IOException, URISyntaxException
server.start();

URI url = new URI("http://" + server.getAddress().getHostName() + ":" + server.getAddress().getPort() + "/test");
inputStream = HttpEntity.openInputStream(url, "", null, 0);
inputStreamPartial = HttpEntity.openInputStream(url, "", null, 5);
inputStream = HttpEntity.openInputStream(url, "", null, 0, Collections.emptyMap());
inputStreamPartial = HttpEntity.openInputStream(url, "", null, 5, Collections.emptyMap());
inputStream.skip(5);

Check notice

Code scanning / CodeQL

Ignored error status of call

Method testOpenInputStream ignores exceptional return value of InputStream.skip.
Assert.assertTrue(IOUtils.contentEquals(inputStream, inputStreamPartial));
}
finally {
IOUtils.closeQuietly(inputStream);
IOUtils.closeQuietly(inputStreamPartial);
if (server != null) {
server.stop(0);
}
if (serverSocket != null) {
serverSocket.close();
}
}
}

@Test
public void testRequestHeaders() throws IOException, URISyntaxException
{
HttpServer server = null;
InputStream inputStream = null;
InputStream inputStreamPartial = null;
ServerSocket serverSocket = null;
Map<String, String> requestHeaders = ImmutableMap.of("r-Cookie", "test", "Content-Type", "application/json");
try {
serverSocket = new ServerSocket(0);
int port = serverSocket.getLocalPort();
// closing port so that the httpserver can use. Can cause race conditions.
serverSocket.close();
server = HttpServer.create(new InetSocketAddress("localhost", port), 0);
server.createContext(
"/test",
(httpExchange) -> {
Headers headers = httpExchange.getRequestHeaders();
for (Map.Entry<String, String> entry : requestHeaders.entrySet()) {
Assert.assertTrue(headers.containsKey(entry.getKey()));
Assert.assertEquals(headers.get(entry.getKey()).get(0), entry.getValue());
}
String payload = "12345678910";
byte[] outputBytes = payload.getBytes(StandardCharsets.UTF_8);
httpExchange.sendResponseHeaders(200, outputBytes.length);
OutputStream os = httpExchange.getResponseBody();
httpExchange.getResponseHeaders().set(HttpHeaders.CONTENT_TYPE, "application/octet-stream");
httpExchange.getResponseHeaders().set(HttpHeaders.CONTENT_LENGTH, String.valueOf(outputBytes.length));
httpExchange.getResponseHeaders().set(HttpHeaders.CONTENT_RANGE, "bytes 0");
os.write(outputBytes);
os.close();
}
);
server.start();

URI url = new URI("http://" + server.getAddress().getHostName() + ":" + server.getAddress().getPort() + "/test");
inputStream = HttpEntity.openInputStream(url, "", null, 0, requestHeaders);
inputStreamPartial = HttpEntity.openInputStream(url, "", null, 5, requestHeaders);
inputStream.skip(5);
Assert.assertTrue(IOUtils.contentEquals(inputStream, inputStreamPartial));
}
Expand All @@ -119,7 +176,7 @@ public void testWithServerSupportingRanges() throws IOException
long offset = 15;
String contentRange = StringUtils.format("bytes %d-%d/%d", offset, 1000, 1000);
Mockito.when(urlConnection.getHeaderField(HttpHeaders.CONTENT_RANGE)).thenReturn(contentRange);
HttpEntity.openInputStream(uri, "", null, offset);
HttpEntity.openInputStream(uri, "", null, offset, Collections.emptyMap());
Mockito.verify(inputStreamMock, Mockito.times(0)).skip(offset);
}

Expand All @@ -128,7 +185,7 @@ public void testWithServerNotSupportingRanges() throws IOException
{
long offset = 15;
Mockito.when(urlConnection.getHeaderField(HttpHeaders.CONTENT_RANGE)).thenReturn(null);
HttpEntity.openInputStream(uri, "", null, offset);
HttpEntity.openInputStream(uri, "", null, offset, Collections.emptyMap());
Mockito.verify(inputStreamMock, Mockito.times(1)).skip(offset);
}

Expand All @@ -137,7 +194,7 @@ public void testWithServerNotSupportingBytesRanges() throws IOException
{
long offset = 15;
Mockito.when(urlConnection.getHeaderField(HttpHeaders.CONTENT_RANGE)).thenReturn("token 2-12/12");
HttpEntity.openInputStream(uri, "", null, offset);
HttpEntity.openInputStream(uri, "", null, offset, Collections.emptyMap());
Mockito.verify(inputStreamMock, Mockito.times(1)).skip(offset);
}
}
Loading