diff --git a/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java b/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java index 706036e5e6f2..4ef7d5246db5 100644 --- a/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java +++ b/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java @@ -22,6 +22,10 @@ import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.base.Function; import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -30,7 +34,9 @@ import com.google.common.io.ByteSource; import com.google.common.util.concurrent.SettableFuture; import com.google.inject.Inject; +import com.metamx.common.Pair; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.common.config.JacksonConfigManager; @@ -46,8 +52,17 @@ import io.druid.indexing.overlord.TaskStorageQueryAdapter; import io.druid.indexing.overlord.WorkerTaskRunner; import io.druid.indexing.overlord.autoscaling.ScalingStats; +import io.druid.indexing.overlord.http.security.TaskResourceFilter; import io.druid.indexing.overlord.setup.WorkerBehaviorConfig; import io.druid.metadata.EntryExistsException; +import io.druid.server.http.security.ConfigResourceFilter; +import io.druid.server.http.security.StateResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import io.druid.tasklogs.TaskLogStreamer; import io.druid.timeline.DataSegment; import org.joda.time.DateTime; @@ -63,11 +78,13 @@ import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.io.IOException; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -85,6 +102,7 @@ public class OverlordResource private final TaskLogStreamer taskLogStreamer; private final JacksonConfigManager configManager; private final AuditManager auditManager; + private final AuthConfig authConfig; private AtomicReference workerConfigRef = null; @@ -94,7 +112,8 @@ public OverlordResource( TaskStorageQueryAdapter taskStorageQueryAdapter, TaskLogStreamer taskLogStreamer, JacksonConfigManager configManager, - AuditManager auditManager + AuditManager auditManager, + AuthConfig authConfig ) throws Exception { this.taskMaster = taskMaster; @@ -102,14 +121,35 @@ public OverlordResource( this.taskLogStreamer = taskLogStreamer; this.configManager = configManager; this.auditManager = auditManager; + this.authConfig = authConfig; } @POST @Path("/task") @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public Response taskPost(final Task task) + public Response taskPost( + final Task task, + @Context final HttpServletRequest req + ) { + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String dataSource = task.getDataSource(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.WRITE + ); + if (!authResult.isAllowed()) { + return Response.status(Response.Status.FORBIDDEN).header("Access-Check-Result", authResult).build(); + } + } + return asLeaderWith( taskMaster.getTaskQueue(), new Function() @@ -133,6 +173,7 @@ public Response apply(TaskQueue taskQueue) @GET @Path("/leader") + @ResourceFilters(StateResourceFilter.class) @Produces(MediaType.APPLICATION_JSON) public Response getLeader() { @@ -142,6 +183,7 @@ public Response getLeader() @GET @Path("/task/{taskid}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response getTaskPayload(@PathParam("taskid") String taskid) { return optionalTaskResponse(taskid, "payload", taskStorageQueryAdapter.getTask(taskid)); @@ -150,6 +192,7 @@ public Response getTaskPayload(@PathParam("taskid") String taskid) @GET @Path("/task/{taskid}/status") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response getTaskStatus(@PathParam("taskid") String taskid) { return optionalTaskResponse(taskid, "status", taskStorageQueryAdapter.getStatus(taskid)); @@ -158,6 +201,7 @@ public Response getTaskStatus(@PathParam("taskid") String taskid) @GET @Path("/task/{taskid}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response getTaskSegments(@PathParam("taskid") String taskid) { final Set segments = taskStorageQueryAdapter.getInsertedSegments(taskid); @@ -167,6 +211,7 @@ public Response getTaskSegments(@PathParam("taskid") String taskid) @POST @Path("/task/{taskid}/shutdown") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response doShutdown(@PathParam("taskid") final String taskid) { return asLeaderWith( @@ -186,6 +231,7 @@ public Response apply(TaskQueue taskQueue) @GET @Path("/worker") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response getWorkerConfig() { if (workerConfigRef == null) { @@ -199,11 +245,12 @@ public Response getWorkerConfig() @POST @Path("/worker") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response setWorkerConfig( final WorkerBehaviorConfig workerBehaviorConfig, @HeaderParam(AuditManager.X_DRUID_AUTHOR) @DefaultValue("") final String author, @HeaderParam(AuditManager.X_DRUID_COMMENT) @DefaultValue("") final String comment, - @Context HttpServletRequest req + @Context final HttpServletRequest req ) { if (!configManager.set( @@ -222,6 +269,7 @@ public Response setWorkerConfig( @GET @Path("/worker/history") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response getWorkerConfigHistory( @QueryParam("interval") final String interval, @QueryParam("count") final Integer count @@ -258,6 +306,7 @@ public Response getWorkerConfigHistory( @POST @Path("/action") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response doAction(final TaskActionHolder holder) { return asLeaderWith( @@ -292,7 +341,7 @@ public Response apply(TaskActionClient taskActionClient) @GET @Path("/waitingTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getWaitingTasks() + public Response getWaitingTasks(@Context final HttpServletRequest req) { return workItemsResponse( new Function>() @@ -302,7 +351,38 @@ public Collection apply(TaskRunner taskRunner) { // A bit roundabout, but works as a way of figuring out what tasks haven't been handed // off to the runner yet: - final List activeTasks = taskStorageQueryAdapter.getActiveTasks(); + final List allActiveTasks = taskStorageQueryAdapter.getActiveTasks(); + final List activeTasks; + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + activeTasks = ImmutableList.copyOf( + Iterables.filter( + allActiveTasks, + new Predicate() + { + @Override + public boolean apply(Task input) + { + Resource resource = new Resource(input.getDataSource(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ) + ); + } else { + activeTasks = allActiveTasks; + } final Set runnersKnownTasks = Sets.newHashSet( Iterables.transform( taskRunner.getKnownTasks(), @@ -346,7 +426,7 @@ public TaskLocation getLocation() @GET @Path("/pendingTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getPendingTasks() + public Response getPendingTasks(@Context final HttpServletRequest req) { return workItemsResponse( new Function>() @@ -354,7 +434,13 @@ public Response getPendingTasks() @Override public Collection apply(TaskRunner taskRunner) { - return taskRunner.getPendingTasks(); + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + return securedTaskRunnerWorkItem(taskRunner.getPendingTasks(), req); + } else { + return taskRunner.getPendingTasks(); + } + } } ); @@ -363,7 +449,7 @@ public Collection apply(TaskRunner taskRunner) @GET @Path("/runningTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getRunningTasks() + public Response getRunningTasks(@Context final HttpServletRequest req) { return workItemsResponse( new Function>() @@ -371,7 +457,12 @@ public Response getRunningTasks() @Override public Collection apply(TaskRunner taskRunner) { - return taskRunner.getRunningTasks(); + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + return securedTaskRunnerWorkItem(taskRunner.getRunningTasks(), req); + } else { + return taskRunner.getRunningTasks(); + } } } ); @@ -380,10 +471,50 @@ public Collection apply(TaskRunner taskRunner) @GET @Path("/completeTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getCompleteTasks() + public Response getCompleteTasks(@Context final HttpServletRequest req) { + final List recentlyFinishedTasks; + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + recentlyFinishedTasks = ImmutableList.copyOf( + Iterables.filter( + taskStorageQueryAdapter.getRecentlyFinishedTaskStatuses(), + new Predicate() + { + @Override + public boolean apply(TaskStatus input) + { + final String taskId = input.getId(); + final Optional optionalTask = taskStorageQueryAdapter.getTask(taskId); + if (!optionalTask.isPresent()) { + throw new WebApplicationException( + Response.serverError().entity( + String.format("No task information found for task with id: [%s]", taskId) + ).build() + ); + } + Resource resource = new Resource(optionalTask.get().getDataSource(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ) + ); + } else { + recentlyFinishedTasks = taskStorageQueryAdapter.getRecentlyFinishedTaskStatuses(); + } + final List completeTasks = Lists.transform( - taskStorageQueryAdapter.getRecentlyFinishedTaskStatuses(), + recentlyFinishedTasks, new Function() { @Override @@ -406,6 +537,7 @@ public TaskResponseObject apply(TaskStatus taskStatus) @GET @Path("/workers") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getWorkers() { return asLeaderWith( @@ -435,6 +567,7 @@ public Response apply(TaskRunner taskRunner) @GET @Path("/scaling") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getScalingState() { // Don't use asLeaderWith, since we want to return 200 instead of 503 when missing an autoscaler. @@ -449,6 +582,7 @@ public Response getScalingState() @GET @Path("/task/{taskid}/log") @Produces("text/plain") + @ResourceFilters(TaskResourceFilter.class) public Response doGetLog( @PathParam("taskid") final String taskid, @QueryParam("offset") @DefaultValue("0") final long offset @@ -528,6 +662,45 @@ private Response asLeaderWith(Optional x, Function f) } } + private Collection securedTaskRunnerWorkItem( + Collection collectionToFilter, + HttpServletRequest req + ) + { + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + return Collections2.filter( + collectionToFilter, + new Predicate() + { + @Override + public boolean apply(TaskRunnerWorkItem input) + { + final String taskId = input.getTaskId(); + final Optional optionalTask = taskStorageQueryAdapter.getTask(taskId); + if (!optionalTask.isPresent()) { + throw new WebApplicationException( + Response.serverError().entity( + String.format("No task information found for task with id: [%s]", taskId) + ).build() + ); + } + Resource resource = new Resource(optionalTask.get().getDataSource(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ); + } + static class TaskResponseObject { private final String id; diff --git a/indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java b/indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java new file mode 100644 index 000000000000..0866658c08a7 --- /dev/null +++ b/indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java @@ -0,0 +1,123 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.indexing.overlord.http.security; + +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.indexing.common.task.Task; +import io.druid.indexing.overlord.TaskStorageQueryAdapter; +import io.druid.server.http.security.AbstractResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.PathSegment; +import javax.ws.rs.core.Response; +import java.util.List; + +/** + * Use this ResourceFilter when the datasource information is present after "task" segment in the request Path + * Here are some example paths where this filter is used - + * - druid/indexer/v1/task/{taskid}/... + * Note - DO NOT use this filter at MiddleManager resources as TaskStorageQueryAdapter cannot be injected there + */ +public class TaskResourceFilter extends AbstractResourceFilter +{ + private final TaskStorageQueryAdapter taskStorageQueryAdapter; + + @Inject + public TaskResourceFilter(TaskStorageQueryAdapter taskStorageQueryAdapter, AuthConfig authConfig) { + super(authConfig); + this.taskStorageQueryAdapter = taskStorageQueryAdapter; + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String taskId = Preconditions.checkNotNull( + request.getPathSegments() + .get( + Iterables.indexOf( + request.getPathSegments(), + new Predicate() + { + @Override + public boolean apply(PathSegment input) + { + return input.getPath().equals("task"); + } + } + ) + 1 + ).getPath() + ); + + Optional taskOptional = taskStorageQueryAdapter.getTask(taskId); + if (!taskOptional.isPresent()) { + throw new WebApplicationException( + Response.status(Response.Status.BAD_REQUEST) + .entity(String.format("Cannot find any task with id: [%s]", taskId)) + .build() + ); + } + final String dataSourceName = Preconditions.checkNotNull(taskOptional.get().getDataSource()); + + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + final Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSourceName, ResourceType.DATASOURCE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException(Response.status(Response.Status.FORBIDDEN) + .entity( + String.format("Access-Check-Result: %s", authResult.toString()) + ) + .build()); + } + } + + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + List applicablePaths = ImmutableList.of("druid/indexer/v1/task/"); + for (String path : applicablePaths) { + if(requestPath.startsWith(path) && !requestPath.equals(path)) { + return true; + } + } + return false; + } +} diff --git a/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java b/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java index 9bb3bdc44b67..49641462e912 100644 --- a/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java +++ b/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java @@ -27,10 +27,13 @@ import com.google.common.io.ByteSource; import com.google.inject.Inject; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.indexing.overlord.TaskRunner; import io.druid.indexing.overlord.TaskRunnerWorkItem; import io.druid.indexing.worker.Worker; import io.druid.indexing.worker.WorkerCuratorCoordinator; +import io.druid.server.http.security.ConfigResourceFilter; +import io.druid.server.http.security.StateResourceFilter; import io.druid.tasklogs.TaskLogStreamer; import javax.ws.rs.DefaultValue; @@ -73,6 +76,7 @@ public WorkerResource( @POST @Path("/disable") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response doDisable() { try { @@ -93,6 +97,7 @@ public Response doDisable() @POST @Path("/enable") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response doEnable() { try { @@ -107,6 +112,7 @@ public Response doEnable() @GET @Path("/enabled") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response isEnabled() { try { @@ -122,6 +128,7 @@ public Response isEnabled() @GET @Path("/tasks") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getTasks() { try { @@ -149,6 +156,7 @@ public String apply(TaskRunnerWorkItem input) @POST @Path("/task/{taskid}/shutdown") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response doShutdown(@PathParam("taskid") String taskid) { try { @@ -164,6 +172,7 @@ public Response doShutdown(@PathParam("taskid") String taskid) @GET @Path("/task/{taskid}/log") @Produces("text/plain") + @ResourceFilters(StateResourceFilter.class) public Response doGetLog( @PathParam("taskid") String taskid, @QueryParam("offset") @DefaultValue("0") long offset diff --git a/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java index 5ef4fd3c8c03..173bd905c37a 100644 --- a/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java +++ b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java @@ -22,379 +22,226 @@ import com.google.common.base.Function; import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.MoreExecutors; -import com.metamx.common.Pair; -import com.metamx.common.guava.CloseQuietly; -import com.metamx.emitter.EmittingLogger; -import com.metamx.emitter.service.ServiceEmitter; -import io.druid.concurrent.Execs; -import io.druid.curator.PotentiallyGzippedCompressionProvider; -import io.druid.curator.discovery.NoopServiceAnnouncer; import io.druid.indexing.common.TaskLocation; import io.druid.indexing.common.TaskStatus; -import io.druid.indexing.common.actions.TaskActionClientFactory; -import io.druid.indexing.common.config.TaskStorageConfig; +import io.druid.indexing.common.TaskToolbox; +import io.druid.indexing.common.actions.TaskActionClient; +import io.druid.indexing.common.task.AbstractTask; import io.druid.indexing.common.task.NoopTask; import io.druid.indexing.common.task.Task; -import io.druid.indexing.overlord.HeapMemoryTaskStorage; -import io.druid.indexing.overlord.TaskLockbox; import io.druid.indexing.overlord.TaskMaster; import io.druid.indexing.overlord.TaskRunner; -import io.druid.indexing.overlord.TaskRunnerFactory; -import io.druid.indexing.overlord.TaskRunnerListener; import io.druid.indexing.overlord.TaskRunnerWorkItem; -import io.druid.indexing.overlord.TaskStorage; import io.druid.indexing.overlord.TaskStorageQueryAdapter; -import io.druid.indexing.overlord.autoscaling.ScalingStats; -import io.druid.indexing.overlord.config.TaskQueueConfig; -import io.druid.server.DruidNode; -import io.druid.server.initialization.IndexerZkConfig; -import io.druid.server.initialization.ZkPathsConfig; -import io.druid.server.metrics.NoopServiceEmitter; -import org.apache.curator.framework.CuratorFramework; -import org.apache.curator.framework.CuratorFrameworkFactory; -import org.apache.curator.retry.RetryOneTime; -import org.apache.curator.test.TestingServer; -import org.apache.curator.test.Timing; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; import org.easymock.EasyMock; -import org.joda.time.Period; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import javax.annotation.Nullable; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; -import java.util.ArrayList; import java.util.Collection; import java.util.List; -import java.util.Map; -import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicBoolean; public class OverlordResourceTest { - private static final TaskLocation TASK_LOCATION = new TaskLocation("dummy", 1000); - - private TestingServer server; - private Timing timing; - private CuratorFramework curator; - private TaskMaster taskMaster; - private TaskLockbox taskLockbox; - private TaskStorage taskStorage; - private TaskActionClientFactory taskActionClientFactory; - private CountDownLatch announcementLatch; - private DruidNode druidNode; private OverlordResource overlordResource; - private CountDownLatch[] taskCompletionCountDownLatches; - private CountDownLatch[] runTaskCountDownLatches; - - private void setupServerAndCurator() throws Exception - { - server = new TestingServer(); - timing = new Timing(); - curator = CuratorFrameworkFactory - .builder() - .connectString(server.getConnectString()) - .sessionTimeoutMs(timing.session()) - .connectionTimeoutMs(timing.connection()) - .retryPolicy(new RetryOneTime(1)) - .compressionProvider(new PotentiallyGzippedCompressionProvider(true)) - .build(); - } - - private void tearDownServerAndCurator() - { - CloseQuietly.close(curator); - CloseQuietly.close(server); - } + private TaskMaster taskMaster; + private TaskStorageQueryAdapter tsqa; + private HttpServletRequest req; + private TaskRunner taskRunner; @Before public void setUp() throws Exception { - taskLockbox = EasyMock.createStrictMock(TaskLockbox.class); - taskLockbox.syncFromStorage(); - EasyMock.expectLastCall().atLeastOnce(); - taskLockbox.add(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - taskLockbox.remove(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - - // for second Noop Task directly added to deep storage. - taskLockbox.add(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - taskLockbox.remove(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - - taskActionClientFactory = EasyMock.createStrictMock(TaskActionClientFactory.class); - EasyMock.expect(taskActionClientFactory.create(EasyMock.anyObject())) - .andReturn(null).anyTimes(); - EasyMock.replay(taskLockbox, taskActionClientFactory); + taskRunner = EasyMock.createMock(TaskRunner.class); + taskMaster = EasyMock.createStrictMock(TaskMaster.class); + tsqa = EasyMock.createStrictMock(TaskStorageQueryAdapter.class); + req = EasyMock.createStrictMock(HttpServletRequest.class); + + EasyMock.expect(taskMaster.getTaskRunner()).andReturn( + Optional.of(taskRunner) + ).anyTimes(); + + overlordResource = new OverlordResource( + taskMaster, + tsqa, + null, + null, + null, + new AuthConfig(true) + ); - taskStorage = new HeapMemoryTaskStorage(new TaskStorageConfig(null)); - runTaskCountDownLatches = new CountDownLatch[2]; - runTaskCountDownLatches[0] = new CountDownLatch(1); - runTaskCountDownLatches[1] = new CountDownLatch(1); - taskCompletionCountDownLatches = new CountDownLatch[2]; - taskCompletionCountDownLatches[0] = new CountDownLatch(1); - taskCompletionCountDownLatches[1] = new CountDownLatch(1); - announcementLatch = new CountDownLatch(1); - IndexerZkConfig indexerZkConfig = new IndexerZkConfig(new ZkPathsConfig(), null, null, null, null, null); - setupServerAndCurator(); - curator.start(); - curator.blockUntilConnected(); - curator.create().creatingParentsIfNeeded().forPath(indexerZkConfig.getLeaderLatchPath()); - druidNode = new DruidNode("hey", "what", 1234); - ServiceEmitter serviceEmitter = new NoopServiceEmitter(); - taskMaster = new TaskMaster( - new TaskQueueConfig(null, new Period(1), null, new Period(10)), - taskLockbox, - taskStorage, - taskActionClientFactory, - druidNode, - indexerZkConfig, - new TaskRunnerFactory() - { - @Override - public MockTaskRunner build() - { - return new MockTaskRunner(runTaskCountDownLatches, taskCompletionCountDownLatches); - } - }, - curator, - new NoopServiceAnnouncer() + EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN)).andReturn( + new AuthorizationInfo() { @Override - public void announce(DruidNode node) + public Access isAuthorized( + Resource resource, Action action + ) { - announcementLatch.countDown(); + if (resource.getName().equals("allow")) { + return new Access(true); + } else { + return new Access(false); + } } - }, - serviceEmitter + } ); - EmittingLogger.registerEmitter(serviceEmitter); } - @Test(timeout = 2000L) - public void testOverlordResource() throws Exception + @Test + public void testSecuredGetWaitingTask() throws Exception { - // basic task master lifecycle test - taskMaster.start(); - announcementLatch.await(); - while (!taskMaster.isLeading()) { - // I believe the control will never reach here and thread will never sleep but just to be on safe side - Thread.sleep(10); - } - Assert.assertEquals(taskMaster.getLeader(), druidNode.getHostAndPort()); - // Test Overlord resource stuff - overlordResource = new OverlordResource(taskMaster, new TaskStorageQueryAdapter(taskStorage), null, null, null); - Response response = overlordResource.getLeader(); - Assert.assertEquals(druidNode.getHostAndPort(), response.getEntity()); - - final String taskId_0 = "0"; - NoopTask task_0 = new NoopTask(taskId_0, 0, 0, null, null, null); - response = overlordResource.taskPost(task_0); - Assert.assertEquals(200, response.getStatus()); - Assert.assertEquals(ImmutableMap.of("task", taskId_0), response.getEntity()); + EasyMock.expect(tsqa.getActiveTasks()).andReturn( + ImmutableList.of( + getTaskWithIdAndDatasource("id_1", "allow"), + getTaskWithIdAndDatasource("id_2", "allow"), + getTaskWithIdAndDatasource("id_3", "deny"), + getTaskWithIdAndDatasource("id_4", "deny") + ) + ).once(); + + EasyMock.>expect(taskRunner.getKnownTasks()).andReturn( + ImmutableList.of( + new MockTaskRunnerWorkItem("id_1", null), + new MockTaskRunnerWorkItem("id_4", null) + ) + ); - // Duplicate task - should fail - response = overlordResource.taskPost(task_0); - Assert.assertEquals(400, response.getStatus()); + EasyMock.replay(taskRunner, taskMaster, tsqa, req); - // Task payload for task_0 should be present in taskStorage - response = overlordResource.getTaskPayload(taskId_0); - Assert.assertEquals(task_0, ((Map) response.getEntity()).get("payload")); + List responseObjects = (List) overlordResource.getWaitingTasks(req) + .getEntity(); + Assert.assertEquals(1, responseObjects.size()); + Assert.assertEquals("id_2", responseObjects.get(0).toJson().get("id")); + } - // Task not present in taskStorage - should fail - response = overlordResource.getTaskPayload("whatever"); - Assert.assertEquals(404, response.getStatus()); + @Test + public void testSecuredGetCompleteTasks() + { + List tasksIds = ImmutableList.of("id_1", "id_2", "id_3"); + EasyMock.expect(tsqa.getRecentlyFinishedTaskStatuses()).andReturn( + Lists.transform( + tasksIds, + new Function() + { + @Override + public TaskStatus apply(String input) + { + return TaskStatus.success(input); + } + } + ) + ).once(); + + EasyMock.expect(tsqa.getTask(tasksIds.get(0))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(0), "deny")) + ).once(); + EasyMock.expect(tsqa.getTask(tasksIds.get(1))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(1), "allow")) + ).once(); + EasyMock.expect(tsqa.getTask(tasksIds.get(2))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(2), "allow")) + ).once(); + EasyMock.replay(taskRunner, taskMaster, tsqa, req); + + List responseObjects = (List) overlordResource.getCompleteTasks(req) + .getEntity(); + + Assert.assertEquals(2, responseObjects.size()); + Assert.assertEquals(tasksIds.get(1), responseObjects.get(0).toJson().get("id")); + Assert.assertEquals(tasksIds.get(2), responseObjects.get(1).toJson().get("id")); + } - // Task status of the submitted task should be running - response = overlordResource.getTaskStatus(taskId_0); - Assert.assertEquals(taskId_0, ((Map) response.getEntity()).get("task")); - Assert.assertEquals( - TaskStatus.running(taskId_0).getStatusCode(), - ((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode() + @Test + public void testSecuredGetRunningTasks() + { + List tasksIds = ImmutableList.of("id_1", "id_2"); + EasyMock.>expect(taskRunner.getRunningTasks()).andReturn( + ImmutableList.of( + new MockTaskRunnerWorkItem(tasksIds.get(0), null), + new MockTaskRunnerWorkItem(tasksIds.get(1), null) + ) ); + EasyMock.expect(tsqa.getTask(tasksIds.get(0))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(0), "deny")) + ).once(); + EasyMock.expect(tsqa.getTask(tasksIds.get(1))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(1), "allow")) + ).once(); - // Simulate completion of task_0 - taskCompletionCountDownLatches[Integer.parseInt(taskId_0)].countDown(); - // Wait for taskQueue to handle success status of task_0 - waitForTaskStatus(taskId_0, TaskStatus.Status.SUCCESS); - - // Manually insert task in taskStorage - // Verifies sync from storage - final String taskId_1 = "1"; - NoopTask task_1 = new NoopTask(taskId_1, 0, 0, null, null, null); - taskStorage.insert(task_1, TaskStatus.running(taskId_1)); - // Wait for task runner to run task_1 - runTaskCountDownLatches[Integer.parseInt(taskId_1)].await(); + EasyMock.replay(taskRunner, taskMaster, tsqa, req); - response = overlordResource.getRunningTasks(); - // 1 task that was manually inserted should be in running state - Assert.assertEquals(1, (((List) response.getEntity()).size())); - final OverlordResource.TaskResponseObject taskResponseObject = ((List) response - .getEntity()).get(0); - Assert.assertEquals(taskId_1, taskResponseObject.toJson().get("id")); - Assert.assertEquals(TASK_LOCATION, taskResponseObject.toJson().get("location")); + List responseObjects = (List) overlordResource.getRunningTasks(req) + .getEntity(); - // Simulate completion of task_1 - taskCompletionCountDownLatches[Integer.parseInt(taskId_1)].countDown(); - // Wait for taskQueue to handle success status of task_1 - waitForTaskStatus(taskId_1, TaskStatus.Status.SUCCESS); - - // should return number of tasks which are not in running state - response = overlordResource.getCompleteTasks(); - Assert.assertEquals(2, (((List) response.getEntity()).size())); - taskMaster.stop(); - Assert.assertFalse(taskMaster.isLeading()); - EasyMock.verify(taskLockbox, taskActionClientFactory); + Assert.assertEquals(1, responseObjects.size()); + Assert.assertEquals(tasksIds.get(1), responseObjects.get(0).toJson().get("id")); } - /* Wait until the task with given taskId has the given Task Status - * These method will not timeout until the condition is met so calling method should ensure timeout - * This method also assumes that the task with given taskId is present - * */ - private void waitForTaskStatus(String taskId, TaskStatus.Status status) throws InterruptedException + @Test + public void testSecuredTaskPost() { - while (true) { - Response response = overlordResource.getTaskStatus(taskId); - if (status.equals(((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode())) { - break; - } - Thread.sleep(10); - } + EasyMock.replay(taskRunner, taskMaster, tsqa, req); + Task task = NoopTask.create(); + Response response = overlordResource.taskPost(task, req); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); } @After - public void tearDown() throws Exception + public void tearDown() { - tearDownServerAndCurator(); + EasyMock.verify(taskRunner, taskMaster, tsqa, req); } - public static class MockTaskRunner implements TaskRunner + private Task getTaskWithIdAndDatasource(String id, String datasource) { - private CountDownLatch[] completionLatches; - private CountDownLatch[] runLatches; - private ConcurrentHashMap taskRunnerWorkItems; - private List runningTasks; - private final AtomicBoolean started = new AtomicBoolean(false); - - public MockTaskRunner(CountDownLatch[] runLatches, CountDownLatch[] completionLatches) - { - this.runLatches = runLatches; - this.completionLatches = completionLatches; - this.taskRunnerWorkItems = new ConcurrentHashMap<>(); - this.runningTasks = new ArrayList<>(); - } - - @Override - public List>> restore() + return new AbstractTask(id, datasource, null) { - return ImmutableList.of(); - } - - public void registerListener(TaskRunnerListener listener, Executor executor) - { - // Overlord doesn't call this method - throw new UnsupportedOperationException(); - } - - @Override - public synchronized ListenableFuture run(final Task task) - { - final String taskId = task.getId(); - ListenableFuture future = MoreExecutors.listeningDecorator( - Execs.singleThreaded( - "noop_test_task_exec_%s" - ) - ).submit( - new Callable() - { - @Override - public TaskStatus call() throws Exception - { - // adding of task to list of runningTasks should be done before count down as - // getRunningTasks may not include the task for which latch has been counted down - // Count down to let know that task is actually running - // this is equivalent of getting process holder to run task in ForkingTaskRunner - runningTasks.add(taskId); - runLatches[Integer.parseInt(taskId)].countDown(); - // Wait for completion count down - completionLatches[Integer.parseInt(taskId)].await(); - taskRunnerWorkItems.remove(taskId); - runningTasks.remove(taskId); - return TaskStatus.success(taskId); - } - } - ); - TaskRunnerWorkItem taskRunnerWorkItem = new TaskRunnerWorkItem(taskId, future) + @Override + public String getType() { - @Override - public TaskLocation getLocation() - { - return TASK_LOCATION; - } - }; - taskRunnerWorkItems.put(taskId, taskRunnerWorkItem); - return future; - } - - @Override - public void shutdown(String taskid) {} - - @Override - public synchronized Collection getRunningTasks() - { - List runningTaskList = Lists.transform( - runningTasks, - new Function() - { - @Nullable - @Override - public TaskRunnerWorkItem apply(String input) - { - return taskRunnerWorkItems.get(input); - } - } - ); - return runningTaskList; - } - - @Override - public Collection getPendingTasks() - { - return ImmutableList.of(); - } + return null; + } - @Override - public Collection getKnownTasks() - { - return taskRunnerWorkItems.values(); - } + @Override + public boolean isReady(TaskActionClient taskActionClient) throws Exception + { + return false; + } - @Override - public Optional getScalingStats() - { - return Optional.absent(); - } + @Override + public TaskStatus run(TaskToolbox toolbox) throws Exception + { + return null; + } + }; + } - @Override - public void start() + private static class MockTaskRunnerWorkItem extends TaskRunnerWorkItem + { + public MockTaskRunnerWorkItem( + String taskId, + ListenableFuture result + ) { - started.set(true); + super(taskId, result); } @Override - public void stop() + public TaskLocation getLocation() { - started.set(false); + return null; } } + } diff --git a/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java new file mode 100644 index 000000000000..16df2895f323 --- /dev/null +++ b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java @@ -0,0 +1,413 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.indexing.overlord.http; + +import com.google.common.base.Function; +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import com.metamx.common.Pair; +import com.metamx.common.guava.CloseQuietly; +import com.metamx.emitter.EmittingLogger; +import com.metamx.emitter.service.ServiceEmitter; +import io.druid.concurrent.Execs; +import io.druid.curator.PotentiallyGzippedCompressionProvider; +import io.druid.curator.discovery.NoopServiceAnnouncer; +import io.druid.indexing.common.TaskLocation; +import io.druid.indexing.common.TaskStatus; +import io.druid.indexing.common.actions.TaskActionClientFactory; +import io.druid.indexing.common.config.TaskStorageConfig; +import io.druid.indexing.common.task.NoopTask; +import io.druid.indexing.common.task.Task; +import io.druid.indexing.overlord.HeapMemoryTaskStorage; +import io.druid.indexing.overlord.TaskLockbox; +import io.druid.indexing.overlord.TaskMaster; +import io.druid.indexing.overlord.TaskRunner; +import io.druid.indexing.overlord.TaskRunnerFactory; +import io.druid.indexing.overlord.TaskRunnerListener; +import io.druid.indexing.overlord.TaskRunnerWorkItem; +import io.druid.indexing.overlord.TaskStorage; +import io.druid.indexing.overlord.TaskStorageQueryAdapter; +import io.druid.indexing.overlord.autoscaling.ScalingStats; +import io.druid.indexing.overlord.config.TaskQueueConfig; +import io.druid.server.DruidNode; +import io.druid.server.initialization.IndexerZkConfig; +import io.druid.server.initialization.ZkPathsConfig; +import io.druid.server.metrics.NoopServiceEmitter; +import io.druid.server.security.AuthConfig; +import org.apache.curator.framework.CuratorFramework; +import org.apache.curator.framework.CuratorFrameworkFactory; +import org.apache.curator.retry.RetryOneTime; +import org.apache.curator.test.TestingServer; +import org.apache.curator.test.Timing; +import org.easymock.EasyMock; +import org.joda.time.Period; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import javax.annotation.Nullable; +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.Response; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; + +public class OverlordTest +{ + private static final TaskLocation TASK_LOCATION = new TaskLocation("dummy", 1000); + + private TestingServer server; + private Timing timing; + private CuratorFramework curator; + private TaskMaster taskMaster; + private TaskLockbox taskLockbox; + private TaskStorage taskStorage; + private TaskActionClientFactory taskActionClientFactory; + private CountDownLatch announcementLatch; + private DruidNode druidNode; + private OverlordResource overlordResource; + private CountDownLatch[] taskCompletionCountDownLatches; + private CountDownLatch[] runTaskCountDownLatches; + private HttpServletRequest req; + + private void setupServerAndCurator() throws Exception + { + server = new TestingServer(); + timing = new Timing(); + curator = CuratorFrameworkFactory + .builder() + .connectString(server.getConnectString()) + .sessionTimeoutMs(timing.session()) + .connectionTimeoutMs(timing.connection()) + .retryPolicy(new RetryOneTime(1)) + .compressionProvider(new PotentiallyGzippedCompressionProvider(true)) + .build(); + } + + private void tearDownServerAndCurator() + { + CloseQuietly.close(curator); + CloseQuietly.close(server); + } + + @Before + public void setUp() throws Exception + { + req = EasyMock.createStrictMock(HttpServletRequest.class); + taskLockbox = EasyMock.createStrictMock(TaskLockbox.class); + taskLockbox.syncFromStorage(); + EasyMock.expectLastCall().atLeastOnce(); + taskLockbox.add(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + taskLockbox.remove(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + + // for second Noop Task directly added to deep storage. + taskLockbox.add(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + taskLockbox.remove(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + + taskActionClientFactory = EasyMock.createStrictMock(TaskActionClientFactory.class); + EasyMock.expect(taskActionClientFactory.create(EasyMock.anyObject())) + .andReturn(null).anyTimes(); + EasyMock.replay(taskLockbox, taskActionClientFactory); + + taskStorage = new HeapMemoryTaskStorage(new TaskStorageConfig(null)); + runTaskCountDownLatches = new CountDownLatch[2]; + runTaskCountDownLatches[0] = new CountDownLatch(1); + runTaskCountDownLatches[1] = new CountDownLatch(1); + taskCompletionCountDownLatches = new CountDownLatch[2]; + taskCompletionCountDownLatches[0] = new CountDownLatch(1); + taskCompletionCountDownLatches[1] = new CountDownLatch(1); + announcementLatch = new CountDownLatch(1); + IndexerZkConfig indexerZkConfig = new IndexerZkConfig(new ZkPathsConfig(), null, null, null, null, null); + setupServerAndCurator(); + curator.start(); + curator.blockUntilConnected(); + curator.create().creatingParentsIfNeeded().forPath(indexerZkConfig.getLeaderLatchPath()); + druidNode = new DruidNode("hey", "what", 1234); + ServiceEmitter serviceEmitter = new NoopServiceEmitter(); + taskMaster = new TaskMaster( + new TaskQueueConfig(null, new Period(1), null, new Period(10)), + taskLockbox, + taskStorage, + taskActionClientFactory, + druidNode, + indexerZkConfig, + new TaskRunnerFactory() + { + @Override + public MockTaskRunner build() + { + return new MockTaskRunner(runTaskCountDownLatches, taskCompletionCountDownLatches); + } + }, + curator, + new NoopServiceAnnouncer() + { + @Override + public void announce(DruidNode node) + { + announcementLatch.countDown(); + } + }, + serviceEmitter + ); + EmittingLogger.registerEmitter(serviceEmitter); + } + + @Test(timeout = 2000L) + public void testOverlordRun() throws Exception + { + // basic task master lifecycle test + taskMaster.start(); + announcementLatch.await(); + while (!taskMaster.isLeading()) { + // I believe the control will never reach here and thread will never sleep but just to be on safe side + Thread.sleep(10); + } + Assert.assertEquals(taskMaster.getLeader(), druidNode.getHostAndPort()); + // Test Overlord resource stuff + overlordResource = new OverlordResource( + taskMaster, + new TaskStorageQueryAdapter(taskStorage), + null, + null, + null, + new AuthConfig() + ); + Response response = overlordResource.getLeader(); + Assert.assertEquals(druidNode.getHostAndPort(), response.getEntity()); + + final String taskId_0 = "0"; + NoopTask task_0 = new NoopTask(taskId_0, 0, 0, null, null, null); + response = overlordResource.taskPost(task_0, req); + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(ImmutableMap.of("task", taskId_0), response.getEntity()); + + // Duplicate task - should fail + response = overlordResource.taskPost(task_0, req); + Assert.assertEquals(400, response.getStatus()); + + // Task payload for task_0 should be present in taskStorage + response = overlordResource.getTaskPayload(taskId_0); + Assert.assertEquals(task_0, ((Map) response.getEntity()).get("payload")); + + // Task not present in taskStorage - should fail + response = overlordResource.getTaskPayload("whatever"); + Assert.assertEquals(404, response.getStatus()); + + // Task status of the submitted task should be running + response = overlordResource.getTaskStatus(taskId_0); + Assert.assertEquals(taskId_0, ((Map) response.getEntity()).get("task")); + Assert.assertEquals( + TaskStatus.running(taskId_0).getStatusCode(), + ((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode() + ); + + // Simulate completion of task_0 + taskCompletionCountDownLatches[Integer.parseInt(taskId_0)].countDown(); + // Wait for taskQueue to handle success status of task_0 + waitForTaskStatus(taskId_0, TaskStatus.Status.SUCCESS); + + // Manually insert task in taskStorage + // Verifies sync from storage + final String taskId_1 = "1"; + NoopTask task_1 = new NoopTask(taskId_1, 0, 0, null, null, null); + taskStorage.insert(task_1, TaskStatus.running(taskId_1)); + // Wait for task runner to run task_1 + runTaskCountDownLatches[Integer.parseInt(taskId_1)].await(); + + response = overlordResource.getRunningTasks(req); + // 1 task that was manually inserted should be in running state + Assert.assertEquals(1, (((List) response.getEntity()).size())); + final OverlordResource.TaskResponseObject taskResponseObject = ((List) response + .getEntity()).get(0); + Assert.assertEquals(taskId_1, taskResponseObject.toJson().get("id")); + Assert.assertEquals(TASK_LOCATION, taskResponseObject.toJson().get("location")); + + // Simulate completion of task_1 + taskCompletionCountDownLatches[Integer.parseInt(taskId_1)].countDown(); + // Wait for taskQueue to handle success status of task_1 + waitForTaskStatus(taskId_1, TaskStatus.Status.SUCCESS); + + // should return number of tasks which are not in running state + response = overlordResource.getCompleteTasks(req); + Assert.assertEquals(2, (((List) response.getEntity()).size())); + taskMaster.stop(); + Assert.assertFalse(taskMaster.isLeading()); + EasyMock.verify(taskLockbox, taskActionClientFactory); + } + + /* Wait until the task with given taskId has the given Task Status + * These method will not timeout until the condition is met so calling method should ensure timeout + * This method also assumes that the task with given taskId is present + * */ + private void waitForTaskStatus(String taskId, TaskStatus.Status status) throws InterruptedException + { + while (true) { + Response response = overlordResource.getTaskStatus(taskId); + if (status.equals(((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode())) { + break; + } + Thread.sleep(10); + } + } + + @After + public void tearDown() throws Exception + { + tearDownServerAndCurator(); + } + + public static class MockTaskRunner implements TaskRunner + { + private CountDownLatch[] completionLatches; + private CountDownLatch[] runLatches; + private ConcurrentHashMap taskRunnerWorkItems; + private List runningTasks; + + public MockTaskRunner(CountDownLatch[] runLatches, CountDownLatch[] completionLatches) + { + this.runLatches = runLatches; + this.completionLatches = completionLatches; + this.taskRunnerWorkItems = new ConcurrentHashMap<>(); + this.runningTasks = new ArrayList<>(); + } + + @Override + public List>> restore() + { + return ImmutableList.of(); + } + + @Override + public void registerListener(TaskRunnerListener listener, Executor executor) + { + // Overlord doesn't call this method + throw new UnsupportedOperationException(); + } + + @Override + public void stop() + { + // Do nothing + } + + @Override + public synchronized ListenableFuture run(final Task task) + { + final String taskId = task.getId(); + ListenableFuture future = MoreExecutors.listeningDecorator( + Execs.singleThreaded( + "noop_test_task_exec_%s" + ) + ).submit( + new Callable() + { + @Override + public TaskStatus call() throws Exception + { + // adding of task to list of runningTasks should be done before count down as + // getRunningTasks may not include the task for which latch has been counted down + // Count down to let know that task is actually running + // this is equivalent of getting process holder to run task in ForkingTaskRunner + runningTasks.add(taskId); + if (runLatches != null) { + runLatches[Integer.parseInt(taskId)].countDown(); + } + // Wait for completion count down + if (completionLatches != null) { + completionLatches[Integer.parseInt(taskId)].await(); + } + taskRunnerWorkItems.remove(taskId); + runningTasks.remove(taskId); + return TaskStatus.success(taskId); + } + } + ); + TaskRunnerWorkItem taskRunnerWorkItem = new TaskRunnerWorkItem(taskId, future) + { + @Override + public TaskLocation getLocation() + { + return TASK_LOCATION; + } + }; + taskRunnerWorkItems.put(taskId, taskRunnerWorkItem); + return future; + } + + @Override + public void shutdown(String taskid) {} + + @Override + public synchronized Collection getRunningTasks() + { + return Lists.transform( + runningTasks, + new Function() + { + @Nullable + @Override + public TaskRunnerWorkItem apply(String input) + { + return taskRunnerWorkItems.get(input); + } + } + ); + } + + @Override + public Collection getPendingTasks() + { + return ImmutableList.of(); + } + + @Override + public Collection getKnownTasks() + { + return taskRunnerWorkItems.values(); + } + + @Override + public Optional getScalingStats() + { + return Optional.absent(); + } + + @Override + public void start() + { + //Do nothing + } + } +} diff --git a/indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java b/indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java new file mode 100644 index 000000000000..a0aa98458cf1 --- /dev/null +++ b/indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.indexing.overlord.http.security; + +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Injector; +import com.sun.jersey.spi.container.ResourceFilter; +import io.druid.indexing.common.task.NoopTask; +import io.druid.indexing.common.task.Task; +import io.druid.indexing.overlord.TaskStorageQueryAdapter; +import io.druid.indexing.overlord.http.OverlordResource; +import io.druid.indexing.worker.http.WorkerResource; +import io.druid.server.http.security.AbstractResourceFilter; +import io.druid.server.http.security.ResourceFilterTestHelper; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; +import java.util.Collection; + +@RunWith(Parameterized.class) +public class SecurityResourceFilterTest extends ResourceFilterTestHelper +{ + + @Parameterized.Parameters + public static Collection data() + { + return ImmutableList.copyOf( + Iterables.concat( + getRequestPaths(OverlordResource.class, ImmutableList.>of(TaskStorageQueryAdapter.class)), + getRequestPaths(WorkerResource.class) + ) + ); + } + + private final String requestPath; + private final String requestMethod; + private final ResourceFilter resourceFilter; + private final Injector injector; + private final Task noopTask = new NoopTask(null, 0, 0, null, null, null); + + private static boolean mockedOnce; + private TaskStorageQueryAdapter tsqa; + + public SecurityResourceFilterTest( + String requestPath, + String requestMethod, + ResourceFilter resourceFilter, + Injector injector + ) + { + this.requestPath = requestPath; + this.requestMethod = requestMethod; + this.resourceFilter = resourceFilter; + this.injector = injector; + } + + @Before + public void setUp() throws Exception + { + if (resourceFilter instanceof TaskResourceFilter && !mockedOnce) { + // Since we are creating the mocked tsqa object only once and getting that object from Guice here therefore + // if the mockedOnce check is not done then we will call EasyMock.expect and EasyMock.replay on the mocked object + // multiple times and it will throw exceptions + tsqa = injector.getInstance(TaskStorageQueryAdapter.class); + EasyMock.expect(tsqa.getTask(EasyMock.anyString())).andReturn(Optional.of(noopTask)).anyTimes(); + EasyMock.replay(tsqa); + mockedOnce = true; + } + setUp(resourceFilter); + } + + @Test + public void testDatasourcesResourcesFilteringAccess() + { + setUpMockExpectations(requestPath, true, requestMethod); + EasyMock.expect(request.getEntity(Task.class)).andReturn(noopTask).anyTimes(); + // As request object is a strict mock the ordering of expected calls matters + // therefore adding the expectation below again as getEntity is called before getMethod + EasyMock.expect(request.getMethod()).andReturn(requestMethod).anyTimes(); + EasyMock.replay(req, request, authorizationInfo); + resourceFilter.getRequestFilter().filter(request); + Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + } + + @Test(expected = WebApplicationException.class) + public void testDatasourcesResourcesFilteringNoAccess() + { + setUpMockExpectations(requestPath, false, requestMethod); + EasyMock.expect(request.getEntity(Task.class)).andReturn(noopTask).anyTimes(); + EasyMock.expect(request.getMethod()).andReturn(requestMethod).anyTimes(); + EasyMock.replay(req, request, authorizationInfo); + Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + try { + resourceFilter.getRequestFilter().filter(request); + } + catch (WebApplicationException e) { + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), e.getResponse().getStatus()); + throw e; + } + } + + @Test + public void testDatasourcesResourcesFilteringBadPath() + { + final String badRequestPath = requestPath.replaceAll("\\w+", "droid"); + EasyMock.expect(request.getPath()).andReturn(badRequestPath).anyTimes(); + EasyMock.replay(req, request, authorizationInfo); + Assert.assertFalse(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(badRequestPath)); + } + + @After + public void tearDown() + { + EasyMock.verify(req, request, authorizationInfo); + if (tsqa != null) { + EasyMock.verify(tsqa); + } + } + +} diff --git a/server/src/main/java/io/druid/guice/security/DruidAuthModule.java b/server/src/main/java/io/druid/guice/security/DruidAuthModule.java new file mode 100644 index 000000000000..e89c8ca23679 --- /dev/null +++ b/server/src/main/java/io/druid/guice/security/DruidAuthModule.java @@ -0,0 +1,44 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.guice.security; + +import com.fasterxml.jackson.databind.Module; +import com.google.inject.Binder; +import io.druid.guice.JsonConfigProvider; +import io.druid.initialization.DruidModule; +import io.druid.server.security.AuthConfig; + +import java.util.Collections; +import java.util.List; + +public class DruidAuthModule implements DruidModule +{ + @Override + public List getJacksonModules() + { + return Collections.emptyList(); + } + + @Override + public void configure(Binder binder) + { + JsonConfigProvider.bind(binder, "druid.auth", AuthConfig.class); + } +} diff --git a/server/src/main/java/io/druid/initialization/Initialization.java b/server/src/main/java/io/druid/initialization/Initialization.java index 0bfc8c0c7bdd..0752036575e2 100644 --- a/server/src/main/java/io/druid/initialization/Initialization.java +++ b/server/src/main/java/io/druid/initialization/Initialization.java @@ -57,6 +57,7 @@ import io.druid.guice.annotations.Json; import io.druid.guice.annotations.Smile; import io.druid.guice.http.HttpClientModule; +import io.druid.guice.security.DruidAuthModule; import io.druid.metadata.storage.derby.DerbyMetadataStorageDruidModule; import io.druid.server.initialization.EmitterModule; import io.druid.server.initialization.jetty.JettyServerModule; @@ -318,7 +319,9 @@ public static Injector makeInjectorWithModules(final Injector baseInjector, Iter { final ModuleList defaultModules = new ModuleList(baseInjector); defaultModules.addModules( + // New modules should be added after Log4jShutterDownerModule new Log4jShutterDownerModule(), + new DruidAuthModule(), new LifecycleModule(), EmitterModule.class, HttpClientModule.global(), diff --git a/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java b/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java index 415c5c9101bf..ff6cb39e8e23 100644 --- a/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java +++ b/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java @@ -143,7 +143,7 @@ public class EventReceiverFirehose implements ChatHandler, Firehose, EventReceiv public EventReceiverFirehose(MapInputRowParser parser) { - this.buffer = new ArrayBlockingQueue(bufferSize); + this.buffer = new ArrayBlockingQueue<>(bufferSize); this.parser = parser; } diff --git a/server/src/main/java/io/druid/server/ClientInfoResource.java b/server/src/main/java/io/druid/server/ClientInfoResource.java index e3a653fe3716..9b800b891d79 100644 --- a/server/src/main/java/io/druid/server/ClientInfoResource.java +++ b/server/src/main/java/io/druid/server/ClientInfoResource.java @@ -19,13 +19,17 @@ package io.druid.server; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.inject.Inject; +import com.metamx.common.Pair; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.FilteredServerInventoryView; @@ -34,6 +38,13 @@ import io.druid.client.selector.ServerSelector; import io.druid.query.TableDataSource; import io.druid.query.metadata.SegmentMetadataQueryConfig; +import io.druid.server.http.security.DatasourceResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import io.druid.timeline.DataSegment; import io.druid.timeline.TimelineLookup; import io.druid.timeline.TimelineObjectHolder; @@ -41,14 +52,17 @@ import org.joda.time.DateTime; import org.joda.time.Interval; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -67,18 +81,21 @@ public class ClientInfoResource private FilteredServerInventoryView serverInventoryView; private TimelineServerView timelineServerView; private SegmentMetadataQueryConfig segmentMetadataQueryConfig; + private final AuthConfig authConfig; @Inject public ClientInfoResource( FilteredServerInventoryView serverInventoryView, TimelineServerView timelineServerView, - SegmentMetadataQueryConfig segmentMetadataQueryConfig + SegmentMetadataQueryConfig segmentMetadataQueryConfig, + AuthConfig authConfig ) { this.serverInventoryView = serverInventoryView; this.timelineServerView = timelineServerView; this.segmentMetadataQueryConfig = (segmentMetadataQueryConfig == null) ? new SegmentMetadataQueryConfig() : segmentMetadataQueryConfig; + this.authConfig = authConfig; } private Map> getSegmentsForDatasources() @@ -98,14 +115,41 @@ private Map> getSegmentsForDatasources() @GET @Produces(MediaType.APPLICATION_JSON) - public Iterable getDataSources() + public Iterable getDataSources(@Context final HttpServletRequest request) { - return getSegmentsForDatasources().keySet(); + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) request.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + return Collections2.filter( + getSegmentsForDatasources().keySet(), + new Predicate() + { + @Override + public boolean apply(String input) + { + Resource resource = new Resource(input, ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ); + } else { + return getSegmentsForDatasources().keySet(); + } } @GET @Path("/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Map getDatasource( @PathParam("dataSourceName") String dataSourceName, @QueryParam("interval") String interval, @@ -193,6 +237,7 @@ public int compare(Interval o1, Interval o2) @GET @Path("/{dataSourceName}/dimensions") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Iterable getDatasourceDimensions( @PathParam("dataSourceName") String dataSourceName, @QueryParam("interval") String interval @@ -225,6 +270,7 @@ public Iterable getDatasourceDimensions( @GET @Path("/{dataSourceName}/metrics") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Iterable getDatasourceMetrics( @PathParam("dataSourceName") String dataSourceName, @QueryParam("interval") String interval diff --git a/server/src/main/java/io/druid/server/QueryManager.java b/server/src/main/java/io/druid/server/QueryManager.java index 3e2b3b510791..49252c8c0ad6 100644 --- a/server/src/main/java/io/druid/server/QueryManager.java +++ b/server/src/main/java/io/druid/server/QueryManager.java @@ -27,20 +27,28 @@ import io.druid.query.Query; import io.druid.query.QueryWatcher; +import java.util.List; import java.util.Set; public class QueryManager implements QueryWatcher { - final SetMultimap queries; + + private final SetMultimap queries; + private final SetMultimap queryDatasources; public QueryManager() { this.queries = Multimaps.synchronizedSetMultimap( HashMultimap.create() ); + this.queryDatasources = Multimaps.synchronizedSetMultimap( + HashMultimap.create() + ); } - public boolean cancelQuery(String id) { + public boolean cancelQuery(String id) + { + queryDatasources.removeAll(id); Set futures = queries.removeAll(id); boolean success = true; for (ListenableFuture future : futures) { @@ -52,7 +60,9 @@ public boolean cancelQuery(String id) { public void registerQuery(Query query, final ListenableFuture future) { final String id = query.getId(); + final List datasources = query.getDataSource().getNames(); queries.put(id, future); + queryDatasources.putAll(id, datasources); future.addListener( new Runnable() { @@ -60,9 +70,17 @@ public void registerQuery(Query query, final ListenableFuture future) public void run() { queries.remove(id, future); + for (String datasource : datasources) { + queryDatasources.remove(id, datasource); + } } }, MoreExecutors.sameThreadExecutor() ); } + + public Set getQueryDatasources(final String queryId) + { + return queryDatasources.get(queryId); + } } diff --git a/server/src/main/java/io/druid/server/QueryResource.java b/server/src/main/java/io/druid/server/QueryResource.java index 0b9ac2b0fa50..63e37e338da3 100644 --- a/server/src/main/java/io/druid/server/QueryResource.java +++ b/server/src/main/java/io/druid/server/QueryResource.java @@ -22,11 +22,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import com.google.common.collect.MapMaker; import com.google.common.io.CountingOutputStream; import com.google.inject.Inject; +import com.metamx.common.ISE; import com.metamx.common.guava.Sequence; import com.metamx.common.guava.Sequences; import com.metamx.common.guava.Yielder; @@ -42,6 +44,12 @@ import io.druid.query.QuerySegmentWalker; import io.druid.server.initialization.ServerConfig; import io.druid.server.log.RequestLogger; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import org.joda.time.DateTime; import javax.servlet.http.HttpServletRequest; @@ -61,6 +69,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.util.Map; +import java.util.Set; import java.util.UUID; /** @@ -81,6 +90,7 @@ public class QueryResource private final ServiceEmitter emitter; private final RequestLogger requestLogger; private final QueryManager queryManager; + private final AuthConfig authConfig; @Inject public QueryResource( @@ -90,7 +100,8 @@ public QueryResource( QuerySegmentWalker texasRanger, ServiceEmitter emitter, RequestLogger requestLogger, - QueryManager queryManager + QueryManager queryManager, + AuthConfig authConfig ) { this.config = config; @@ -100,16 +111,39 @@ public QueryResource( this.emitter = emitter; this.requestLogger = requestLogger; this.queryManager = queryManager; + this.authConfig = authConfig; } @DELETE @Path("{id}") @Produces(MediaType.APPLICATION_JSON) - public Response getServer(@PathParam("id") String queryId) + public Response getServer(@PathParam("id") String queryId, @Context final HttpServletRequest req) { if (log.isDebugEnabled()) { log.debug("Received cancel request for query [%s]", queryId); } + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + Set datasources = queryManager.getQueryDatasources(queryId); + if (datasources == null) { + log.warn("QueryId [%s] not registered with QueryManager, cannot cancel", queryId); + } else { + for (String dataSource : datasources) { + Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.WRITE + ); + if (!authResult.isAllowed()) { + return Response.status(Response.Status.FORBIDDEN).header("Access-Check-Result", authResult).build(); + } + } + } + } queryManager.cancelQuery(queryId); return Response.status(Response.Status.ACCEPTED).build(); } @@ -120,7 +154,7 @@ public Response getServer(@PathParam("id") String queryId) public Response doPost( InputStream in, @QueryParam("pretty") String pretty, - @Context final HttpServletRequest req // used only to get request content-type and remote address + @Context final HttpServletRequest req // used to get request content-type, remote address and AuthorizationInfo ) throws IOException { final long start = System.currentTimeMillis(); @@ -160,6 +194,24 @@ public Response doPost( log.debug("Got query [%s]", query); } + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + if (authorizationInfo != null) { + for (String dataSource : query.getDataSource().getNames()) { + Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.READ + ); + if (!authResult.isAllowed()) { + return Response.status(Response.Status.FORBIDDEN).header("Access-Check-Result", authResult).build(); + } + } + } else { + throw new ISE("WTF?! Security is enabled but no authorization info found in the request"); + } + } + final Map responseContext = new MapMaker().makeMap(); final Sequence res = query.run(texasRanger, responseContext); final Sequence results; diff --git a/server/src/main/java/io/druid/server/StatusResource.java b/server/src/main/java/io/druid/server/StatusResource.java index f5012daafec7..edbd65b4fdb7 100644 --- a/server/src/main/java/io/druid/server/StatusResource.java +++ b/server/src/main/java/io/druid/server/StatusResource.java @@ -21,8 +21,10 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.initialization.DruidModule; import io.druid.initialization.Initialization; +import io.druid.server.http.security.StateResourceFilter; import javax.ws.rs.GET; import javax.ws.rs.Path; @@ -35,6 +37,7 @@ /** */ @Path("/status") +@ResourceFilters(StateResourceFilter.class) public class StatusResource { @GET diff --git a/server/src/main/java/io/druid/server/http/BrokerResource.java b/server/src/main/java/io/druid/server/http/BrokerResource.java index 7e9701a39b7c..7adc968e402b 100644 --- a/server/src/main/java/io/druid/server/http/BrokerResource.java +++ b/server/src/main/java/io/druid/server/http/BrokerResource.java @@ -21,7 +21,9 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.BrokerServerView; +import io.druid.server.http.security.StateResourceFilter; import javax.ws.rs.GET; import javax.ws.rs.Path; @@ -30,6 +32,7 @@ import javax.ws.rs.core.Response; @Path("/druid/broker/v1") +@ResourceFilters(StateResourceFilter.class) public class BrokerResource { private final BrokerServerView brokerServerView; diff --git a/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java b/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java index 0d955b915bf2..c4e572a15a5c 100644 --- a/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java +++ b/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java @@ -19,15 +19,15 @@ package io.druid.server.http; +import com.google.common.collect.ImmutableMap; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.common.config.JacksonConfigManager; import io.druid.server.coordinator.CoordinatorDynamicConfig; - +import io.druid.server.http.security.ConfigResourceFilter; import org.joda.time.Interval; -import com.google.common.collect.ImmutableMap; - import javax.inject.Inject; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.Consumes; @@ -45,6 +45,7 @@ /** */ @Path("/druid/coordinator/v1/config") +@ResourceFilters(ConfigResourceFilter.class) public class CoordinatorDynamicConfigsResource { private final JacksonConfigManager manager; diff --git a/server/src/main/java/io/druid/server/http/CoordinatorResource.java b/server/src/main/java/io/druid/server/http/CoordinatorResource.java index ac13e9ec22f9..20f6805dae12 100644 --- a/server/src/main/java/io/druid/server/http/CoordinatorResource.java +++ b/server/src/main/java/io/druid/server/http/CoordinatorResource.java @@ -24,8 +24,10 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.inject.Inject; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.server.coordinator.DruidCoordinator; import io.druid.server.coordinator.LoadQueuePeon; +import io.druid.server.http.security.StateResourceFilter; import io.druid.timeline.DataSegment; import javax.ws.rs.GET; @@ -38,6 +40,7 @@ /** */ @Path("/druid/coordinator/v1") +@ResourceFilters(StateResourceFilter.class) public class CoordinatorResource { private final DruidCoordinator coordinator; diff --git a/server/src/main/java/io/druid/server/http/DatasourcesResource.java b/server/src/main/java/io/druid/server/http/DatasourcesResource.java index 8aa035f96694..274e03492c5a 100644 --- a/server/src/main/java/io/druid/server/http/DatasourcesResource.java +++ b/server/src/main/java/io/druid/server/http/DatasourcesResource.java @@ -31,6 +31,7 @@ import com.metamx.common.guava.Comparators; import com.metamx.common.guava.FunctionalIterable; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.CoordinatorServerView; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; @@ -39,6 +40,9 @@ import io.druid.client.indexing.IndexingServiceClient; import io.druid.metadata.MetadataSegmentManager; import io.druid.query.TableDataSource; +import io.druid.server.http.security.DatasourceResourceFilter; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; import io.druid.timeline.DataSegment; import io.druid.timeline.TimelineLookup; import io.druid.timeline.TimelineObjectHolder; @@ -47,6 +51,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; import javax.ws.rs.GET; @@ -55,6 +60,7 @@ import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.util.Comparator; @@ -73,28 +79,38 @@ public class DatasourcesResource private final CoordinatorServerView serverInventoryView; private final MetadataSegmentManager databaseSegmentManager; private final IndexingServiceClient indexingServiceClient; + private final AuthConfig authConfig; @Inject public DatasourcesResource( CoordinatorServerView serverInventoryView, MetadataSegmentManager databaseSegmentManager, - @Nullable IndexingServiceClient indexingServiceClient + @Nullable IndexingServiceClient indexingServiceClient, + AuthConfig authConfig ) { this.serverInventoryView = serverInventoryView; this.databaseSegmentManager = databaseSegmentManager; this.indexingServiceClient = indexingServiceClient; + this.authConfig = authConfig; } @GET @Produces(MediaType.APPLICATION_JSON) public Response getQueryableDataSources( @QueryParam("full") String full, - @QueryParam("simple") String simple + @QueryParam("simple") String simple, + @Context final HttpServletRequest req ) { Response.ResponseBuilder builder = Response.ok(); - final Set datasources = InventoryViewUtils.getDataSources(serverInventoryView); + final Set datasources = authConfig.isEnabled() ? + InventoryViewUtils.getSecuredDataSources( + serverInventoryView, + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN) + ) : + InventoryViewUtils.getDataSources(serverInventoryView); + if (full != null) { return builder.entity(datasources).build(); } else if (simple != null) { @@ -135,12 +151,14 @@ public String apply(DruidDataSource dataSource) @GET @Path("/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getTheDataSource( @PathParam("dataSourceName") final String dataSourceName, @QueryParam("full") final String full ) { DruidDataSource dataSource = getDataSource(dataSourceName); + if (dataSource == null) { return Response.noContent().build(); } @@ -155,6 +173,7 @@ public Response getTheDataSource( @POST @Path("/{dataSourceName}") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response enableDataSource( @PathParam("dataSourceName") final String dataSourceName ) @@ -175,6 +194,7 @@ public Response enableDataSource( @DELETE @Deprecated @Path("/{dataSourceName}") + @ResourceFilters(DatasourceResourceFilter.class) @Produces(MediaType.APPLICATION_JSON) public Response deleteDataSource( @PathParam("dataSourceName") final String dataSourceName, @@ -253,6 +273,7 @@ public Response deleteDataSourceSpecificInterval( @GET @Path("/{dataSourceName}/intervals") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceIntervals( @PathParam("dataSourceName") String dataSourceName, @QueryParam("simple") String simple, @@ -313,6 +334,7 @@ public Response getSegmentDataSourceIntervals( @GET @Path("/{dataSourceName}/intervals/{interval}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSpecificInterval( @PathParam("dataSourceName") String dataSourceName, @PathParam("interval") String interval, @@ -380,6 +402,7 @@ public Response getSegmentDataSourceSpecificInterval( @GET @Path("/{dataSourceName}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSegments( @PathParam("dataSourceName") String dataSourceName, @QueryParam("full") String full @@ -413,6 +436,7 @@ public Object apply(DataSegment segment) @GET @Path("/{dataSourceName}/segments/{segmentId}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId @@ -436,6 +460,7 @@ public Response getSegmentDataSourceSegment( @DELETE @Path("/{dataSourceName}/segments/{segmentId}") + @ResourceFilters(DatasourceResourceFilter.class) public Response deleteDatasourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId @@ -451,6 +476,7 @@ public Response deleteDatasourceSegment( @POST @Path("/{dataSourceName}/segments/{segmentId}") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response enableDatasourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId @@ -466,6 +492,7 @@ public Response enableDatasourceSegment( @GET @Path("/{dataSourceName}/tiers") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceTiers( @PathParam("dataSourceName") String dataSourceName ) @@ -624,6 +651,7 @@ private Map> getSimpleDatasource(String dataSourceNa @GET @Path("/{dataSourceName}/intervals/{interval}/serverview") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSpecificInterval( @PathParam("dataSourceName") String dataSourceName, @PathParam("interval") String interval, diff --git a/server/src/main/java/io/druid/server/http/HistoricalResource.java b/server/src/main/java/io/druid/server/http/HistoricalResource.java index 4680cf29c6c5..bc77ce0fc056 100644 --- a/server/src/main/java/io/druid/server/http/HistoricalResource.java +++ b/server/src/main/java/io/druid/server/http/HistoricalResource.java @@ -20,7 +20,9 @@ package io.druid.server.http; import com.google.common.collect.ImmutableMap; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.server.coordination.ZkCoordinator; +import io.druid.server.http.security.StateResourceFilter; import javax.inject.Inject; import javax.ws.rs.GET; @@ -30,6 +32,7 @@ import javax.ws.rs.core.Response; @Path("/druid/historical/v1") +@ResourceFilters(StateResourceFilter.class) public class HistoricalResource { private final ZkCoordinator coordinator; diff --git a/server/src/main/java/io/druid/server/http/IntervalsResource.java b/server/src/main/java/io/druid/server/http/IntervalsResource.java index 103330fc50ab..29c8a1f4f86f 100644 --- a/server/src/main/java/io/druid/server/http/IntervalsResource.java +++ b/server/src/main/java/io/druid/server/http/IntervalsResource.java @@ -25,14 +25,18 @@ import com.metamx.common.guava.Comparators; import io.druid.client.DruidDataSource; import io.druid.client.InventoryView; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; import io.druid.timeline.DataSegment; import org.joda.time.Interval; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.util.Comparator; @@ -45,35 +49,43 @@ public class IntervalsResource { private final InventoryView serverInventoryView; + private final AuthConfig authConfig; @Inject public IntervalsResource( - InventoryView serverInventoryView + InventoryView serverInventoryView, + AuthConfig authConfig ) { this.serverInventoryView = serverInventoryView; + this.authConfig = authConfig; } @GET @Produces(MediaType.APPLICATION_JSON) - public Response getIntervals() + public Response getIntervals(@Context final HttpServletRequest req) { - final Comparator comparator = Comparators.inverse(Comparators.intervalsByStartThenEnd()); - final Set datasources = InventoryViewUtils.getDataSources(serverInventoryView); - - final Map>> retVal = Maps.newTreeMap(comparator); - for (DruidDataSource dataSource : datasources) { - for (DataSegment dataSegment : dataSource.getSegments()) { - Map> interval = retVal.get(dataSegment.getInterval()); - if (interval == null) { - Map> tmp = Maps.newHashMap(); - retVal.put(dataSegment.getInterval(), tmp); - } - setProperties(retVal, dataSource, dataSegment); + final Comparator comparator = Comparators.inverse(Comparators.intervalsByStartThenEnd()); + final Set datasources = authConfig.isEnabled() ? + InventoryViewUtils.getSecuredDataSources( + serverInventoryView, + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN) + ) : + InventoryViewUtils.getDataSources(serverInventoryView); + + final Map>> retVal = Maps.newTreeMap(comparator); + for (DruidDataSource dataSource : datasources) { + for (DataSegment dataSegment : dataSource.getSegments()) { + Map> interval = retVal.get(dataSegment.getInterval()); + if (interval == null) { + Map> tmp = Maps.newHashMap(); + retVal.put(dataSegment.getInterval(), tmp); } + setProperties(retVal, dataSource, dataSegment); } + } - return Response.ok(retVal).build(); + return Response.ok(retVal).build(); } @GET @@ -82,13 +94,20 @@ public Response getIntervals() public Response getSpecificIntervals( @PathParam("interval") String interval, @QueryParam("simple") String simple, - @QueryParam("full") String full + @QueryParam("full") String full, + @Context final HttpServletRequest req ) { final Interval theInterval = new Interval(interval.replace("_", "/")); - final Set datasources = InventoryViewUtils.getDataSources(serverInventoryView); + final Set datasources = authConfig.isEnabled() ? + InventoryViewUtils.getSecuredDataSources( + serverInventoryView, + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN) + ) : + InventoryViewUtils.getDataSources(serverInventoryView); final Comparator comparator = Comparators.inverse(Comparators.intervalsByStartThenEnd()); + if (full != null) { final Map>> retVal = Maps.newTreeMap(comparator); for (DruidDataSource dataSource : datasources) { diff --git a/server/src/main/java/io/druid/server/http/InventoryViewUtils.java b/server/src/main/java/io/druid/server/http/InventoryViewUtils.java index df39f5e70c13..62cb5109eadb 100644 --- a/server/src/main/java/io/druid/server/http/InventoryViewUtils.java +++ b/server/src/main/java/io/druid/server/http/InventoryViewUtils.java @@ -20,18 +20,30 @@ package io.druid.server.http; import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.metamx.common.ISE; +import com.metamx.common.Pair; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; import java.util.Set; import java.util.TreeSet; -public class InventoryViewUtils { +public class InventoryViewUtils +{ public static Set getDataSources(InventoryView serverInventoryView) { @@ -64,4 +76,38 @@ public Iterable apply(DruidServer input) ); return dataSources; } + + public static Set getSecuredDataSources( + InventoryView inventoryView, + final AuthorizationInfo authorizationInfo + ) + { + if (authorizationInfo == null) { + throw new ISE("Invalid to call a secured method with null AuthorizationInfo!!"); + } else { + final Map, Access> resourceAccessMap = new HashMap<>(); + return ImmutableSet.copyOf( + Iterables.filter( + getDataSources(inventoryView), + new Predicate() + { + @Override + public boolean apply(DruidDataSource input) + { + Resource resource = new Resource(input.getName(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ) + ); + } + } } diff --git a/server/src/main/java/io/druid/server/http/MetadataResource.java b/server/src/main/java/io/druid/server/http/MetadataResource.java index 294165402f32..e480121b8b9f 100644 --- a/server/src/main/java/io/druid/server/http/MetadataResource.java +++ b/server/src/main/java/io/druid/server/http/MetadataResource.java @@ -20,26 +20,42 @@ package io.druid.server.http; import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.inject.Inject; +import com.metamx.common.Pair; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidDataSource; import io.druid.indexing.overlord.IndexerMetadataStorageCoordinator; import io.druid.metadata.MetadataSegmentManager; +import io.druid.server.http.security.DatasourceResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import io.druid.timeline.DataSegment; import org.joda.time.Interval; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.GET; import javax.ws.rs.POST; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.io.IOException; +import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** */ @@ -48,15 +64,18 @@ public class MetadataResource { private final MetadataSegmentManager metadataSegmentManager; private final IndexerMetadataStorageCoordinator metadataStorageCoordinator; + private final AuthConfig authConfig; @Inject public MetadataResource( MetadataSegmentManager metadataSegmentManager, - IndexerMetadataStorageCoordinator metadataStorageCoordinator + IndexerMetadataStorageCoordinator metadataStorageCoordinator, + AuthConfig authConfig ) { this.metadataSegmentManager = metadataSegmentManager; this.metadataStorageCoordinator = metadataStorageCoordinator; + this.authConfig = authConfig; } @GET @@ -64,20 +83,88 @@ public MetadataResource( @Produces(MediaType.APPLICATION_JSON) public Response getDatabaseDataSources( @QueryParam("full") String full, - @QueryParam("includeDisabled") String includeDisabled + @QueryParam("includeDisabled") String includeDisabled, + @Context final HttpServletRequest req ) { Response.ResponseBuilder builder = Response.status(Response.Status.OK); + + final Collection druidDataSources; + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + if (includeDisabled != null) { + return builder.entity( + Collections2.filter( + metadataSegmentManager.getAllDatasourceNames(), + new Predicate() + { + @Override + public boolean apply(String input) + { + Resource resource = new Resource(input, ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + )).build(); + } else { + druidDataSources = + Collections2.filter( + metadataSegmentManager.getInventory(), + new Predicate() + { + @Override + public boolean apply(DruidDataSource input) + { + Resource resource = new Resource(input.getName(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ); + } + } else { + druidDataSources = metadataSegmentManager.getInventory(); + } + if (includeDisabled != null) { - return builder.entity(metadataSegmentManager.getAllDatasourceNames()).build(); + return builder.entity( + Collections2.transform( + druidDataSources, + new Function() + { + @Override + public String apply(DruidDataSource input) + { + return input.getName(); + } + } + ) + ).build(); } if (full != null) { - return builder.entity(metadataSegmentManager.getInventory()).build(); + return builder.entity(druidDataSources).build(); } List dataSourceNames = Lists.newArrayList( Iterables.transform( - metadataSegmentManager.getInventory(), + druidDataSources, new Function() { @Override @@ -97,6 +184,7 @@ public String apply(DruidDataSource dataSource) @GET @Path("/datasources/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSource( @PathParam("dataSourceName") final String dataSourceName ) @@ -112,6 +200,7 @@ public Response getDatabaseSegmentDataSource( @GET @Path("/datasources/{dataSourceName}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSourceSegments( @PathParam("dataSourceName") String dataSourceName, @QueryParam("full") String full @@ -145,13 +234,14 @@ public String apply(DataSegment segment) @POST @Path("/datasources/{dataSourceName}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSourceSegments( @PathParam("dataSourceName") String dataSourceName, @QueryParam("full") String full, List intervals ) { - List segments = null; + List segments; try { segments = metadataStorageCoordinator.getUsedSegmentsForIntervals(dataSourceName, intervals); } @@ -182,6 +272,7 @@ public String apply(DataSegment segment) @GET @Path("/datasources/{dataSourceName}/segments/{segmentId}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId diff --git a/server/src/main/java/io/druid/server/http/RulesResource.java b/server/src/main/java/io/druid/server/http/RulesResource.java index fdacb228ea63..1d93d61df7d8 100644 --- a/server/src/main/java/io/druid/server/http/RulesResource.java +++ b/server/src/main/java/io/druid/server/http/RulesResource.java @@ -21,13 +21,14 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; - +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.audit.AuditEntry; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.metadata.MetadataRuleManager; import io.druid.server.coordinator.rules.Rule; - +import io.druid.server.http.security.RulesResourceFilter; +import io.druid.server.http.security.StateResourceFilter; import org.joda.time.Interval; import javax.servlet.http.HttpServletRequest; @@ -43,7 +44,6 @@ import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; - import java.util.List; /** @@ -66,6 +66,7 @@ public RulesResource( @GET @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getRules() { return Response.ok(databaseRuleManager.getAllRules()).build(); @@ -74,6 +75,7 @@ public Response getRules() @GET @Path("/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(RulesResourceFilter.class) public Response getDatasourceRules( @PathParam("dataSourceName") final String dataSourceName, @QueryParam("full") final String full @@ -91,6 +93,7 @@ public Response getDatasourceRules( @POST @Path("/{dataSourceName}") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(RulesResourceFilter.class) public Response setDatasourceRules( @PathParam("dataSourceName") final String dataSourceName, final List rules, @@ -112,6 +115,7 @@ public Response setDatasourceRules( @GET @Path("/{dataSourceName}/history") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(RulesResourceFilter.class) public Response getDatasourceRuleHistory( @PathParam("dataSourceName") final String dataSourceName, @QueryParam("interval") final String interval, @@ -131,6 +135,7 @@ public Response getDatasourceRuleHistory( @GET @Path("/history") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getDatasourceRuleHistory( @QueryParam("interval") final String interval, @QueryParam("count") final Integer count diff --git a/server/src/main/java/io/druid/server/http/ServersResource.java b/server/src/main/java/io/druid/server/http/ServersResource.java index 33665fda81d2..70308eb8ebb0 100644 --- a/server/src/main/java/io/druid/server/http/ServersResource.java +++ b/server/src/main/java/io/druid/server/http/ServersResource.java @@ -25,8 +25,10 @@ import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.inject.Inject; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.http.security.StateResourceFilter; import io.druid.timeline.DataSegment; import javax.ws.rs.GET; @@ -41,6 +43,7 @@ /** */ @Path("/druid/coordinator/v1/servers") +@ResourceFilters(StateResourceFilter.class) public class ServersResource { private static Map makeSimpleServer(DruidServer input) diff --git a/server/src/main/java/io/druid/server/http/TiersResource.java b/server/src/main/java/io/druid/server/http/TiersResource.java index 6990dae2839a..db9189e56e5c 100644 --- a/server/src/main/java/io/druid/server/http/TiersResource.java +++ b/server/src/main/java/io/druid/server/http/TiersResource.java @@ -28,9 +28,11 @@ import com.google.common.collect.Table; import com.google.inject.Inject; import com.metamx.common.MapUtils; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.http.security.StateResourceFilter; import io.druid.timeline.DataSegment; import org.joda.time.Interval; @@ -47,6 +49,7 @@ /** */ @Path("/druid/coordinator/v1/tiers") +@ResourceFilters(StateResourceFilter.class) public class TiersResource { private final InventoryView serverInventoryView; diff --git a/server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java b/server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java new file mode 100644 index 000000000000..a8a1fb4cb4e1 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java @@ -0,0 +1,89 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import com.sun.jersey.spi.container.ContainerRequestFilter; +import com.sun.jersey.spi.container.ContainerResponseFilter; +import com.sun.jersey.spi.container.ResourceFilter; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.Context; + +public abstract class AbstractResourceFilter implements ResourceFilter, ContainerRequestFilter +{ + //https://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-520005 + @Context + private HttpServletRequest req; + + private final AuthConfig authConfig; + + @Inject + public AbstractResourceFilter(AuthConfig authConfig) + { + this.authConfig = authConfig; + } + + @Override + public ContainerRequestFilter getRequestFilter() + { + return this; + } + + @Override + public ContainerResponseFilter getResponseFilter() + { + return null; + } + + public HttpServletRequest getReq() + { + return req; + } + + public AuthConfig getAuthConfig() + { + return authConfig; + } + + public AbstractResourceFilter setReq(HttpServletRequest req) + { + this.req = req; + return this; + } + + protected Action getAction(ContainerRequest request) + { + Action action; + switch (request.getMethod()) { + case "GET": + case "HEAD": + action = Action.READ; + break; + default: + action = Action.WRITE; + } + return action; + } + + public abstract boolean isApplicable(String requestPath); +} diff --git a/server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java b/server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java new file mode 100644 index 000000000000..61fc28f16269 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java @@ -0,0 +1,85 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; + +/** + * Use this ResourceFilter at end points where Druid Cluster configuration is read or written + * Here are some example paths where this filter is used - + * - druid/worker/v1 + * - druid/indexer/v1 + * - druid/coordinator/v1/config + * Note - Currently the resource name for all end points is set to "CONFIG" however if more fine grained access control + * is required the resource name can be set to specific config properties. + */ +public class ConfigResourceFilter extends AbstractResourceFilter +{ + @Inject + public ConfigResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String resourceName = "CONFIG"; + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + + final Access authResult = authorizationInfo.isAuthorized( + new Resource(resourceName, ResourceType.CONFIG), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + return requestPath.startsWith("druid/worker/v1") || + requestPath.startsWith("druid/indexer/v1") || + requestPath.startsWith("druid/coordinator/v1/config"); + } +} diff --git a/server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java b/server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java new file mode 100644 index 000000000000..ccbeab866008 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java @@ -0,0 +1,110 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.PathSegment; +import javax.ws.rs.core.Response; +import java.util.List; + +/** + * Use this ResourceFilter when the datasource information is present after "datasources" segment in the request Path + * Here are some example paths where this filter is used - + * - druid/coordinator/v1/datasources/{dataSourceName}/... + * - druid/coordinator/v1/metadata/datasources/{dataSourceName}/... + * - druid/v2/datasources/{dataSourceName}/... + */ +public class DatasourceResourceFilter extends AbstractResourceFilter +{ + @Inject + public DatasourceResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String dataSourceName = request.getPathSegments() + .get( + Iterables.indexOf( + request.getPathSegments(), + new Predicate() + { + @Override + public boolean apply(PathSegment input) + { + return input.getPath().equals("datasources"); + } + } + ) + 1 + ).getPath(); + Preconditions.checkNotNull(dataSourceName); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + final Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSourceName, ResourceType.DATASOURCE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + List applicablePaths = ImmutableList.of( + "druid/coordinator/v1/datasources/", + "druid/coordinator/v1/metadata/datasources/", + "druid/v2/datasources/" + ); + for (String path : applicablePaths) { + if(requestPath.startsWith(path) && !requestPath.equals(path)) { + return true; + } + } + return false; + } +} diff --git a/server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java b/server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java new file mode 100644 index 000000000000..0e87fab200fe --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java @@ -0,0 +1,106 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.PathSegment; +import javax.ws.rs.core.Response; +import java.util.List; + + +/** + * Use this ResourceFilter when the datasource information is present after "rules" segment in the request Path + * Here are some example paths where this filter is used - + * - druid/coordinator/v1/rules/ + * */ + +public class RulesResourceFilter extends AbstractResourceFilter +{ + @Inject + public RulesResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String dataSourceName = request.getPathSegments() + .get( + Iterables.indexOf( + request.getPathSegments(), + new Predicate() + { + @Override + public boolean apply(PathSegment input) + { + return input.getPath().equals("rules"); + } + } + ) + 1 + ).getPath(); + Preconditions.checkNotNull(dataSourceName); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + final Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSourceName, ResourceType.DATASOURCE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + List applicablePaths = ImmutableList.of("druid/coordinator/v1/rules/"); + for (String path : applicablePaths) { + if(requestPath.startsWith(path) && !requestPath.equals(path)) { + return true; + } + } + return false; + } +} diff --git a/server/src/main/java/io/druid/server/http/security/StateResourceFilter.java b/server/src/main/java/io/druid/server/http/security/StateResourceFilter.java new file mode 100644 index 000000000000..b4d9d40195f4 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/StateResourceFilter.java @@ -0,0 +1,97 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; + +/** + * Use this ResourceFilter at end points where Druid Cluster State is read or written + * Here are some example paths where this filter is used - + * - druid/broker/v1 + * - druid/coordinator/v1 + * - druid/historical/v1 + * - druid/indexer/v1 + * - druid/coordinator/v1/rules + * - druid/coordinator/v1/tiers + * - druid/worker/v1 + * - druid/coordinator/v1/servers + * - status + * Note - Currently the resource name for all end points is set to "STATE" however if more fine grained access control + * is required the resource name can be set to specific state properties. + */ +public class StateResourceFilter extends AbstractResourceFilter +{ + @Inject + public StateResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String resourceName = "STATE"; + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + + final Access authResult = authorizationInfo.isAuthorized( + new Resource(resourceName, ResourceType.STATE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + + return request; + } + + public boolean isApplicable(String requestPath) + { + return requestPath.startsWith("druid/broker/v1") || + requestPath.startsWith("druid/coordinator/v1") || + requestPath.startsWith("druid/historical/v1") || + requestPath.startsWith("druid/indexer/v1") || + requestPath.startsWith("druid/coordinator/v1/rules") || + requestPath.startsWith("druid/coordinator/v1/tiers") || + requestPath.startsWith("druid/worker/v1") || + requestPath.startsWith("druid/coordinator/v1/servers") || + requestPath.startsWith("status"); + } +} diff --git a/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java b/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java index a0ad9b765b19..66fd4c1a6fda 100644 --- a/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java +++ b/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java @@ -28,11 +28,9 @@ import com.metamx.metrics.KeyedDiff; import com.metamx.metrics.MonitorUtils; import io.druid.query.DruidMetrics; -import io.druid.segment.realtime.firehose.EventReceiverFirehoseFactory; import java.util.Map; import java.util.Properties; -import java.util.concurrent.atomic.AtomicLong; public class EventReceiverFirehoseMonitor extends AbstractMonitor { diff --git a/server/src/main/java/io/druid/server/security/Access.java b/server/src/main/java/io/druid/server/security/Access.java new file mode 100644 index 000000000000..a70e579f3a4c --- /dev/null +++ b/server/src/main/java/io/druid/server/security/Access.java @@ -0,0 +1,51 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public class Access +{ + private final boolean allowed; + private String message; + + public Access(boolean allowed) { + this(allowed, ""); + } + + public Access(boolean allowed, String message) { + this.allowed = allowed; + this.message = message; + } + + public boolean isAllowed() { + return allowed; + } + + public Access setMessage(String message) + { + this.message = message; + return this; + } + + @Override + public String toString() + { + return String.format("Allowed:%s, Message:%s", allowed, message); + } +} diff --git a/server/src/main/java/io/druid/server/security/Action.java b/server/src/main/java/io/druid/server/security/Action.java new file mode 100644 index 000000000000..2b7606b58dd8 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/Action.java @@ -0,0 +1,26 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public enum Action +{ + READ, + WRITE +} diff --git a/server/src/main/java/io/druid/server/security/AuthConfig.java b/server/src/main/java/io/druid/server/security/AuthConfig.java new file mode 100644 index 000000000000..8ade4ce6c415 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/AuthConfig.java @@ -0,0 +1,85 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +public class AuthConfig +{ + /** + * Use this String as the attribute name for the request attribute to pass {@link AuthorizationInfo} + * from the servlet filter to the jersey resource + * */ + public static final String DRUID_AUTH_TOKEN = "Druid-Auth-Token"; + + public AuthConfig() { + this(false); + } + + @JsonCreator + public AuthConfig( + @JsonProperty("enabled") boolean enabled + ){ + this.enabled = enabled; + } + /** + * If druid.auth.enabled is set to true then an implementation of AuthorizationInfo + * must be provided and it must be set as a request attribute possibly inside the servlet filter + * injected in the filter chain using your own extension + * */ + @JsonProperty + private final boolean enabled; + + public boolean isEnabled() + { + return enabled; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + AuthConfig that = (AuthConfig) o; + + return enabled == that.enabled; + + } + + @Override + public int hashCode() + { + return (enabled ? 1 : 0); + } + + @Override + public String toString() + { + return "AuthConfig{" + + "enabled=" + enabled + + '}'; + } +} diff --git a/server/src/main/java/io/druid/server/security/AuthorizationInfo.java b/server/src/main/java/io/druid/server/security/AuthorizationInfo.java new file mode 100644 index 000000000000..31097a935477 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/AuthorizationInfo.java @@ -0,0 +1,44 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +/** + * This interface should be used to store as well as process Authorization Information + * An extension can be used to inject servlet filter which will create objects of this type + * and set it as a request attribute with attribute name as {@link AuthConfig#DRUID_AUTH_TOKEN}. + * In the jersey resources if the authorization is enabled depending on {@link AuthConfig#enabled} + * the {@link #isAuthorized(Resource, Action)} method will be used to perform authorization checks + * */ +public interface AuthorizationInfo +{ + /** + * Perform authorization checks for the given {@link Resource} and {@link Action}. + * resource and action objects should be instantiated depending on + * the specific endPoint where the check is being performed. + * Modeling Principal and specific way of performing authorization checks is + * entirely implementation dependent. + * + * @param resource information about resource that is being accessed + * @param action action to be performed on the resource + * @return a {@link Access} object having {@link Access#allowed} set to true if authorized otherwise set to false + * and optionally {@link Access#message} set to appropriate message + * */ + Access isAuthorized(Resource resource, Action action); +} diff --git a/server/src/main/java/io/druid/server/security/Resource.java b/server/src/main/java/io/druid/server/security/Resource.java new file mode 100644 index 000000000000..d3c74fb52899 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/Resource.java @@ -0,0 +1,69 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public class Resource +{ + private final String name; + private final ResourceType type; + + public Resource(String name, ResourceType type) + { + this.name = name; + this.type = type; + } + + public String getName() + { + return name; + } + + public ResourceType getType() + { + return type; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Resource resource = (Resource) o; + + if (!name.equals(resource.name)) { + return false; + } + return type == resource.type; + + } + + @Override + public int hashCode() + { + int result = name.hashCode(); + result = 31 * result + type.hashCode(); + return result; + } +} diff --git a/server/src/main/java/io/druid/server/security/ResourceType.java b/server/src/main/java/io/druid/server/security/ResourceType.java new file mode 100644 index 000000000000..818bf9ca947d --- /dev/null +++ b/server/src/main/java/io/druid/server/security/ResourceType.java @@ -0,0 +1,27 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public enum ResourceType +{ + DATASOURCE, + CONFIG, + STATE +} diff --git a/server/src/test/java/io/druid/server/ClientInfoResourceTest.java b/server/src/test/java/io/druid/server/ClientInfoResourceTest.java index a81938a7284f..1436ab2534b2 100644 --- a/server/src/test/java/io/druid/server/ClientInfoResourceTest.java +++ b/server/src/test/java/io/druid/server/ClientInfoResourceTest.java @@ -47,6 +47,7 @@ import io.druid.client.selector.ServerSelector; import io.druid.query.TableDataSource; import io.druid.query.metadata.SegmentMetadataQueryConfig; +import io.druid.server.security.AuthConfig; import io.druid.timeline.DataSegment; import io.druid.timeline.VersionedIntervalTimeline; import io.druid.timeline.partition.NumberedShardSpec; @@ -411,7 +412,7 @@ private ClientInfoResource getResourceTestHelper( SegmentMetadataQueryConfig segmentMetadataQueryConfig ) { - return new ClientInfoResource(serverInventoryView, timelineServerView, segmentMetadataQueryConfig) + return new ClientInfoResource(serverInventoryView, timelineServerView, segmentMetadataQueryConfig, new AuthConfig()) { @Override protected DateTime getCurrentTime() diff --git a/server/src/test/java/io/druid/server/QueryResourceTest.java b/server/src/test/java/io/druid/server/QueryResourceTest.java index ed2b3f1091f4..dabd6b575e8c 100644 --- a/server/src/test/java/io/druid/server/QueryResourceTest.java +++ b/server/src/test/java/io/druid/server/QueryResourceTest.java @@ -20,9 +20,13 @@ package io.druid.server; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; import com.metamx.common.guava.Sequence; import com.metamx.common.guava.Sequences; import com.metamx.emitter.service.ServiceEmitter; +import io.druid.concurrent.Execs; import io.druid.jackson.DefaultObjectMapper; import io.druid.query.Query; import io.druid.query.QueryRunner; @@ -31,9 +35,15 @@ import io.druid.server.initialization.ServerConfig; import io.druid.server.log.NoopRequestLogger; import io.druid.server.metrics.NoopServiceEmitter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; import org.easymock.EasyMock; import org.joda.time.Interval; import org.joda.time.Period; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; @@ -45,6 +55,8 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; /** * @@ -97,6 +109,9 @@ public QueryRunner getQueryRunnerForSegments( private static final ServiceEmitter noopServiceEmitter = new NoopServiceEmitter(); + private QueryResource queryResource; + private QueryManager queryManager; + @BeforeClass public static void staticSetup() { @@ -106,9 +121,19 @@ public static void staticSetup() @Before public void setup() { - EasyMock.expect(testServletRequest.getContentType()).andReturn(MediaType.APPLICATION_JSON); + EasyMock.expect(testServletRequest.getContentType()).andReturn(MediaType.APPLICATION_JSON).anyTimes(); EasyMock.expect(testServletRequest.getRemoteAddr()).andReturn("localhost").anyTimes(); - EasyMock.replay(testServletRequest); + queryManager = new QueryManager(); + queryResource = new QueryResource( + serverConfig, + jsonMapper, + jsonMapper, + testSegmentWalker, + new NoopServiceEmitter(), + new NoopRequestLogger(), + queryManager, + new AuthConfig() + ); } private static final String simpleTimeSeriesQuery = "{\n" @@ -129,42 +154,273 @@ public void setup() @Test public void testGoodQuery() throws IOException { - QueryResource queryResource = new QueryResource( + EasyMock.replay(testServletRequest); + Response response = queryResource.doPost( + new ByteArrayInputStream(simpleTimeSeriesQuery.getBytes("UTF-8")), + null /*pretty*/, + testServletRequest + ); + Assert.assertNotNull(response); + } + + @Test + public void testBadQuery() throws IOException + { + EasyMock.replay(testServletRequest); + Response response = queryResource.doPost( + new ByteArrayInputStream("Meka Leka Hi Meka Hiney Ho".getBytes("UTF-8")), + null /*pretty*/, + testServletRequest + ); + Assert.assertNotNull(response); + Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); + } + + @Test + public void testSecuredQuery() throws Exception + { + EasyMock.expect(testServletRequest.getAttribute(EasyMock.anyString())).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + if (resource.getName().equals("allow")) { + return new Access(true); + } else { + return new Access(false); + } + } + } + ).times(2); + EasyMock.replay(testServletRequest); + + queryResource = new QueryResource( serverConfig, jsonMapper, jsonMapper, testSegmentWalker, new NoopServiceEmitter(), new NoopRequestLogger(), - new QueryManager() + queryManager, + new AuthConfig(true) ); - Response respone = queryResource.doPost( + + Response response = queryResource.doPost( new ByteArrayInputStream(simpleTimeSeriesQuery.getBytes("UTF-8")), null /*pretty*/, testServletRequest ); - Assert.assertNotNull(respone); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); + + response = queryResource.doPost( + new ByteArrayInputStream("{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"}".getBytes("UTF-8")), + null /*pretty*/, + testServletRequest + ); + + Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + } - @Test - public void testBadQuery() throws IOException + @Test(timeout = 60_000L) + public void testSecuredGetServer() throws Exception + { + final CountDownLatch waitForCancellationLatch = new CountDownLatch(1); + final CountDownLatch waitFinishLatch = new CountDownLatch(2); + final CountDownLatch startAwaitLatch = new CountDownLatch(1); + final CountDownLatch cancelledCountDownLatch = new CountDownLatch(1); + + EasyMock.expect(testServletRequest.getAttribute(EasyMock.anyString())).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + // READ action corresponds to the query + // WRITE corresponds to cancellation of query + if (action.equals(Action.READ)) { + try { + waitForCancellationLatch.await(); + } + catch (InterruptedException e) { + // When the query is cancelled the control will reach here, + // countdown the latch and rethrow the exception so that error response is returned for the query + cancelledCountDownLatch.countDown(); + Throwables.propagate(e); + } + return new Access(true); + } else { + return new Access(true); + } + } + } + ).times(2); + EasyMock.replay(testServletRequest); + + queryResource = new QueryResource( + serverConfig, + jsonMapper, + jsonMapper, + testSegmentWalker, + new NoopServiceEmitter(), + new NoopRequestLogger(), + queryManager, + new AuthConfig(true) + ); + + final String queryString = "{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"," + + "\"context\":{\"queryId\":\"id_1\"}}"; + ObjectMapper mapper = new DefaultObjectMapper(); + Query query = mapper.readValue(queryString, Query.class); + + ListenableFuture future = MoreExecutors.listeningDecorator( + Execs.singleThreaded("test_query_resource_%s") + ).submit( + new Runnable() + { + @Override + public void run() + { + try { + startAwaitLatch.countDown(); + Response response = queryResource.doPost( + new ByteArrayInputStream(queryString.getBytes("UTF-8")), + null, + testServletRequest + ); + + Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); + } + catch (IOException e) { + Throwables.propagate(e); + } + waitFinishLatch.countDown(); + } + } + ); + + queryManager.registerQuery(query, future); + startAwaitLatch.await(); + + Executors.newSingleThreadExecutor().submit( + new Runnable() + { + @Override + public void run() + { + Response response = queryResource.getServer("id_1", testServletRequest); + Assert.assertEquals(Response.Status.ACCEPTED.getStatusCode(), response.getStatus()); + waitForCancellationLatch.countDown(); + waitFinishLatch.countDown(); + } + } + ); + waitFinishLatch.await(); + cancelledCountDownLatch.await(); + } + + @Test(timeout = 60_000L) + public void testDenySecuredGetServer() throws Exception { + final CountDownLatch waitForCancellationLatch = new CountDownLatch(1); + final CountDownLatch waitFinishLatch = new CountDownLatch(2); + final CountDownLatch startAwaitLatch = new CountDownLatch(1); + + EasyMock.expect(testServletRequest.getAttribute(EasyMock.anyString())).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + // READ action corresponds to the query + // WRITE corresponds to cancellation of query + if (action.equals(Action.READ)) { + try { + waitForCancellationLatch.await(); + } + catch (InterruptedException e) { + Throwables.propagate(e); + } + return new Access(true); + } else { + // Deny access to cancel the query + return new Access(false); + } + } + } + ).times(2); + EasyMock.replay(testServletRequest); - QueryResource queryResource = new QueryResource( + queryResource = new QueryResource( serverConfig, jsonMapper, jsonMapper, testSegmentWalker, new NoopServiceEmitter(), new NoopRequestLogger(), - new QueryManager() + queryManager, + new AuthConfig(true) ); - Response respone = queryResource.doPost( - new ByteArrayInputStream("Meka Leka Hi Meka Hiney Ho".getBytes("UTF-8")), - null /*pretty*/, - testServletRequest + + final String queryString = "{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"," + + "\"context\":{\"queryId\":\"id_1\"}}"; + ObjectMapper mapper = new DefaultObjectMapper(); + Query query = mapper.readValue(queryString, Query.class); + + ListenableFuture future = MoreExecutors.listeningDecorator( + Execs.singleThreaded("test_query_resource_%s") + ).submit( + new Runnable() + { + @Override + public void run() + { + try { + startAwaitLatch.countDown(); + Response response = queryResource.doPost( + new ByteArrayInputStream(queryString.getBytes("UTF-8")), + null, + testServletRequest + ); + Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + } + catch (IOException e) { + Throwables.propagate(e); + } + waitFinishLatch.countDown(); + } + } + ); + + queryManager.registerQuery(query, future); + startAwaitLatch.await(); + + Executors.newSingleThreadExecutor().submit( + new Runnable() + { + @Override + public void run() + { + Response response = queryResource.getServer("id_1", testServletRequest); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); + waitForCancellationLatch.countDown(); + waitFinishLatch.countDown(); + } + } ); - Assert.assertNotNull(respone); - Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), respone.getStatus()); + waitFinishLatch.await(); + } + + @After + public void tearDown() + { + EasyMock.verify(testServletRequest); } } diff --git a/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java b/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java index 51f5cbb88527..71147cdaa7bb 100644 --- a/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java +++ b/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java @@ -25,6 +25,11 @@ import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.indexing.IndexingServiceClient; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; import io.druid.timeline.DataSegment; import org.easymock.EasyMock; import org.joda.time.Interval; @@ -32,6 +37,7 @@ import org.junit.Before; import org.junit.Test; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; import java.util.ArrayList; import java.util.HashMap; @@ -47,10 +53,12 @@ public class DatasourcesResourceTest private DruidServer server; private List listDataSources; private List dataSegmentList; + private HttpServletRequest request; @Before public void setUp() { + request = EasyMock.createStrictMock(HttpServletRequest.class); inventoryView = EasyMock.createStrictMock(CoordinatorServerView.class); server = EasyMock.createStrictMock(DruidServer.class); dataSegmentList = new ArrayList<>(); @@ -94,8 +102,12 @@ public void setUp() ) ); listDataSources = new ArrayList<>(); - listDataSources.add(new DruidDataSource("datasource1", new HashMap()).addSegment("part1", dataSegmentList.get(0))); - listDataSources.add(new DruidDataSource("datasource2", new HashMap()).addSegment("part1", dataSegmentList.get(1))); + listDataSources.add( + new DruidDataSource("datasource1", new HashMap()).addSegment("part1", dataSegmentList.get(0)) + ); + listDataSources.add( + new DruidDataSource("datasource2", new HashMap()).addSegment("part1", dataSegmentList.get(1)) + ); } @Test @@ -108,8 +120,8 @@ public void testGetFullQueryableDataSources() throws Exception ImmutableList.of(server) ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); - Response response = datasourcesResource.getQueryableDataSources("full", null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); + Response response = datasourcesResource.getQueryableDataSources("full", null, request); Set result = (Set) response.getEntity(); DruidDataSource[] resultantDruidDataSources = new DruidDataSource[result.size()]; result.toArray(resultantDruidDataSources); @@ -117,7 +129,7 @@ public void testGetFullQueryableDataSources() throws Exception Assert.assertEquals(2, resultantDruidDataSources.length); Assert.assertArrayEquals(listDataSources.toArray(), resultantDruidDataSources); - response = datasourcesResource.getQueryableDataSources(null, null); + response = datasourcesResource.getQueryableDataSources(null, null, request); List result1 = (List) response.getEntity(); Assert.assertEquals(200, response.getStatus()); Assert.assertEquals(2, result1.size()); @@ -126,6 +138,53 @@ public void testGetFullQueryableDataSources() throws Exception EasyMock.verify(inventoryView, server); } + @Test + public void testSecuredGetFullQueryableDataSources() throws Exception + { + EasyMock.expect(server.getDataSources()).andReturn( + ImmutableList.of(listDataSources.get(0), listDataSources.get(1)) + ).atLeastOnce(); + EasyMock.expect(inventoryView.getInventory()).andReturn( + ImmutableList.of(server) + ).atLeastOnce(); + EasyMock.expect(request.getAttribute(AuthConfig.DRUID_AUTH_TOKEN)).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + if (resource.getName().equals("datasource1")) { + return new Access(true); + } else { + return new Access(false); + } + } + } + ).times(2); + EasyMock.replay(inventoryView, server, request); + + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig(true)); + Response response = datasourcesResource.getQueryableDataSources("full", null, request); + Set result = (Set) response.getEntity(); + DruidDataSource[] resultantDruidDataSources = new DruidDataSource[result.size()]; + result.toArray(resultantDruidDataSources); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(1, resultantDruidDataSources.length); + Assert.assertArrayEquals(listDataSources.subList(0, 1).toArray(), resultantDruidDataSources); + + response = datasourcesResource.getQueryableDataSources(null, null, request); + List result1 = (List) response.getEntity(); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(1, result1.size()); + Assert.assertTrue(result1.contains("datasource1")); + + EasyMock.verify(inventoryView, server, request); + } + @Test public void testGetSimpleQueryableDataSources() throws Exception { @@ -145,8 +204,8 @@ public void testGetSimpleQueryableDataSources() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); - Response response = datasourcesResource.getQueryableDataSources(null, "simple"); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); + Response response = datasourcesResource.getQueryableDataSources(null, "simple", request); Assert.assertEquals(200, response.getStatus()); List> results = (List>) response.getEntity(); int index = 0; @@ -172,7 +231,7 @@ public void testFullGetTheDataSource() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getTheDataSource("datasource1", "full"); DruidDataSource result = (DruidDataSource) response.getEntity(); Assert.assertEquals(200, response.getStatus()); @@ -189,7 +248,7 @@ public void testNullGetTheDataSource() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Assert.assertEquals(204, datasourcesResource.getTheDataSource("none", null).getStatus()); EasyMock.verify(inventoryView, server); } @@ -211,7 +270,7 @@ public void testSimpleGetTheDataSource() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getTheDataSource("datasource1", null); Assert.assertEquals(200, response.getStatus()); Map> result = (Map>) response.getEntity(); @@ -250,7 +309,7 @@ public void testSimpleGetTheDataSourceManyTiers() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server, server2, server3); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getTheDataSource("datasource1", null); Assert.assertEquals(200, response.getStatus()); Map> result = (Map>) response.getEntity(); @@ -281,7 +340,7 @@ public void testGetSegmentDataSourceIntervals() List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-22T00:00:00.000Z/2010-01-23T00:00:00.000Z")); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getSegmentDataSourceIntervals("invalidDataSource", null, null); Assert.assertEquals(response.getEntity(), null); @@ -328,7 +387,7 @@ public void testGetSegmentDataSourceSpecificInterval() ).atLeastOnce(); EasyMock.replay(inventoryView); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getSegmentDataSourceSpecificInterval( "invalidDataSource", "2010-01-01/P1D", @@ -395,7 +454,7 @@ public void testDeleteDataSourceSpecificInterval() throws Exception EasyMock.expectLastCall().once(); EasyMock.replay(indexingServiceClient, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient, new AuthConfig()); Response response = datasourcesResource.deleteDataSourceSpecificInterval("datasource1", interval); Assert.assertEquals(200, response.getStatus()); @@ -407,7 +466,7 @@ public void testDeleteDataSourceSpecificInterval() throws Exception public void testDeleteDataSource() { IndexingServiceClient indexingServiceClient = EasyMock.createStrictMock(IndexingServiceClient.class); EasyMock.replay(indexingServiceClient, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient, new AuthConfig()); Response response = datasourcesResource.deleteDataSource("datasource", "true", "???"); Assert.assertEquals(400, response.getStatus()); Assert.assertNotNull(response.getEntity()); diff --git a/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java b/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java index b77842bff8dd..4fb50795c85a 100644 --- a/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java +++ b/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java @@ -22,13 +22,16 @@ import com.google.common.collect.ImmutableList; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.security.AuthConfig; import io.druid.timeline.DataSegment; import org.easymock.EasyMock; import org.joda.time.Interval; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; import java.util.ArrayList; import java.util.List; @@ -40,12 +43,15 @@ public class IntervalsResourceTest private InventoryView inventoryView; private DruidServer server; private List dataSegmentList; + private HttpServletRequest request; @Before public void setUp() { inventoryView = EasyMock.createStrictMock(InventoryView.class); server = EasyMock.createStrictMock(DruidServer.class); + request = EasyMock.createStrictMock(HttpServletRequest.class); + dataSegmentList = new ArrayList<>(); dataSegmentList.add( new DataSegment( @@ -103,9 +109,9 @@ public void testGetIntervals() List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); expectedIntervals.add(new Interval("2010-01-22T00:00:00.000Z/2010-01-23T00:00:00.000Z")); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getIntervals(); + Response response = intervalsResource.getIntervals(request); TreeMap>> actualIntervals = (TreeMap) response.getEntity(); Assert.assertEquals(2, actualIntervals.size()); Assert.assertEquals(expectedIntervals.get(1), actualIntervals.firstKey()); @@ -117,7 +123,6 @@ public void testGetIntervals() Assert.assertEquals(5L, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("size")); Assert.assertEquals(1, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("count")); - EasyMock.verify(inventoryView); } @Test @@ -130,16 +135,15 @@ public void testSimpleGetSpecificIntervals() List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", "simple", null); + Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", "simple", null, request); Map> actualIntervals = (Map) response.getEntity(); Assert.assertEquals(1, actualIntervals.size()); Assert.assertTrue(actualIntervals.containsKey(expectedIntervals.get(0))); Assert.assertEquals(25L, actualIntervals.get(expectedIntervals.get(0)).get("size")); Assert.assertEquals(2, actualIntervals.get(expectedIntervals.get(0)).get("count")); - EasyMock.verify(inventoryView); } @Test @@ -152,9 +156,9 @@ public void testFullGetSpecificIntervals() List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, "full"); + Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, "full", request); TreeMap>> actualIntervals = (TreeMap) response.getEntity(); Assert.assertEquals(1, actualIntervals.size()); Assert.assertEquals(expectedIntervals.get(0), actualIntervals.firstKey()); @@ -163,7 +167,6 @@ public void testFullGetSpecificIntervals() Assert.assertEquals(5L, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("size")); Assert.assertEquals(1, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("count")); - EasyMock.verify(inventoryView); } @Test @@ -174,14 +177,19 @@ public void testGetSpecificIntervals() ).atLeastOnce(); EasyMock.replay(inventoryView); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, null); + Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, null, request); Map actualIntervals = (Map) response.getEntity(); Assert.assertEquals(2, actualIntervals.size()); Assert.assertEquals(25L, actualIntervals.get("size")); Assert.assertEquals(2, actualIntervals.get("count")); + } + + @After + public void tearDown() { EasyMock.verify(inventoryView); } + } diff --git a/server/src/test/java/io/druid/server/http/RulesResourceTest.java b/server/src/test/java/io/druid/server/http/RulesResourceTest.java index 283026f82cf6..d153397cee95 100644 --- a/server/src/test/java/io/druid/server/http/RulesResourceTest.java +++ b/server/src/test/java/io/druid/server/http/RulesResourceTest.java @@ -20,12 +20,10 @@ package io.druid.server.http; import com.google.common.collect.ImmutableList; - import io.druid.audit.AuditEntry; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.metadata.MetadataRuleManager; - import org.easymock.EasyMock; import org.joda.time.DateTime; import org.joda.time.Interval; @@ -34,7 +32,6 @@ import org.junit.Test; import javax.ws.rs.core.Response; - import java.util.List; import java.util.Map; @@ -255,4 +252,5 @@ public void testGetAllDatasourcesRuleHistoryWithWrongCount() EasyMock.verify(auditManager); } + } diff --git a/server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java b/server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java new file mode 100644 index 000000000000..ae317314b21e --- /dev/null +++ b/server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java @@ -0,0 +1,245 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Binder; +import com.google.inject.Guice; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Module; +import com.sun.jersey.spi.container.ContainerRequest; +import com.sun.jersey.spi.container.ResourceFilter; +import com.sun.jersey.spi.container.ResourceFilters; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import org.easymock.EasyMock; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.DELETE; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.core.MultivaluedMap; +import javax.ws.rs.core.PathSegment; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +public class ResourceFilterTestHelper +{ + public HttpServletRequest req; + public AuthorizationInfo authorizationInfo; + public ContainerRequest request; + + public void setUp(ResourceFilter resourceFilter) throws Exception + { + req = EasyMock.createStrictMock(HttpServletRequest.class); + request = EasyMock.createStrictMock(ContainerRequest.class); + authorizationInfo = EasyMock.createStrictMock(AuthorizationInfo.class); + + // Memory barrier + synchronized (this) { + ((AbstractResourceFilter) resourceFilter).setReq(req); + } + } + + public void setUpMockExpectations( + String requestPath, + boolean authCheckResult, + String requestMethod + ) + { + EasyMock.expect(request.getPath()).andReturn(requestPath).anyTimes(); + EasyMock.expect(request.getPathSegments()).andReturn( + ImmutableList.copyOf( + Iterables.transform( + Arrays.asList(requestPath.split("/")), + new Function() + { + @Override + public PathSegment apply(final String input) + { + return new PathSegment() + { + @Override + public String getPath() + { + return input; + } + + @Override + public MultivaluedMap getMatrixParameters() + { + return null; + } + }; + } + } + ) + ) + ).anyTimes(); + EasyMock.expect(request.getMethod()).andReturn(requestMethod).anyTimes(); + EasyMock.expect(req.getAttribute(EasyMock.anyString())).andReturn(authorizationInfo).atLeastOnce(); + EasyMock.expect(authorizationInfo.isAuthorized( + EasyMock.anyObject(Resource.class), + EasyMock.anyObject(Action.class) + )).andReturn( + new Access(authCheckResult) + ).atLeastOnce(); + + } + + public static Collection getRequestPaths(final Class clazz) + { + return getRequestPaths(clazz, ImmutableList.>of(), ImmutableList.>of()); + } + + public static Collection getRequestPaths( + final Class clazz, + final Iterable> mockableInjections + ) + { + return getRequestPaths(clazz, mockableInjections, ImmutableList.>of()); + } + + public static Collection getRequestPaths( + final Class clazz, + final Iterable> mockableInjections, + final Iterable> mockableKeys + ) + { + return getRequestPaths(clazz, mockableInjections, mockableKeys, ImmutableList.of()); + } + + // Feeds in an array of [ PathName, MethodName, ResourceFilter , Injector] + public static Collection getRequestPaths( + final Class clazz, + final Iterable> mockableInjections, + final Iterable> mockableKeys, + final Iterable injectedObjs + ) + { + final Injector injector = Guice.createInjector( + new Module() + { + @Override + public void configure(Binder binder) + { + for (Class clazz : mockableInjections) { + binder.bind(clazz).toInstance(EasyMock.createNiceMock(clazz)); + } + for (Object obj : injectedObjs) { + binder.bind((Class) obj.getClass()).toInstance(obj); + } + for (Key key : mockableKeys) { + binder.bind((Key) key).toInstance(EasyMock.createNiceMock(key.getTypeLiteral().getRawType())); + } + binder.bind(AuthConfig.class).toInstance(new AuthConfig(true)); + } + } + ); + final String basepath = ((Path) clazz.getAnnotation(Path.class)).value().substring(1); //Ignore the first "/" + final List> baseResourceFilters = + clazz.getAnnotation(ResourceFilters.class) == null ? Collections.>emptyList() : + ImmutableList.copyOf(((ResourceFilters) clazz.getAnnotation(ResourceFilters.class)).value()); + + return ImmutableList.copyOf( + Iterables.concat( + // Step 3 - Merge all the Objects arrays for each endpoints + Iterables.transform( + // Step 2 - + // For each endpoint, make an Object array containing + // - Request Path like "druid/../../.." + // - Request Method like "GET" or "POST" or "DELETE" + // - Resource Filter instance for the endpoint + Iterables.filter( + // Step 1 - + // Filter out non resource endpoint methods + // and also the endpoints that does not have any + // ResourceFilters applied to them + ImmutableList.copyOf(clazz.getDeclaredMethods()), + new Predicate() + { + @Override + public boolean apply(Method input) + { + return input.getAnnotation(GET.class) != null + || input.getAnnotation(POST.class) != null + || input.getAnnotation(DELETE.class) != null + && (input.getAnnotation(ResourceFilters.class) != null + || !baseResourceFilters.isEmpty()); + } + } + ), + new Function>() + { + @Override + public Collection apply(final Method method) + { + final List> resourceFilters = + method.getAnnotation(ResourceFilters.class) == null ? baseResourceFilters : + ImmutableList.copyOf(method.getAnnotation(ResourceFilters.class).value()); + + return Collections2.transform( + resourceFilters, + new Function, Object[]>() + { + @Override + public Object[] apply(Class input) + { + if (method.getAnnotation(Path.class) != null) { + return new Object[]{ + String.format("%s%s", basepath, method.getAnnotation(Path.class).value()), + input.getAnnotation(GET.class) == null ? (method.getAnnotation(DELETE.class) == null + ? "POST" + : "DELETE") : "GET", + injector.getInstance(input), + injector + }; + } else { + return new Object[]{ + basepath, + input.getAnnotation(GET.class) == null ? (method.getAnnotation(DELETE.class) == null + ? "POST" + : "DELETE") : "GET", + injector.getInstance(input), + injector + }; + } + } + } + ); + } + } + ) + ) + ); + } +} diff --git a/server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java b/server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java new file mode 100644 index 000000000000..4a7cd0de8258 --- /dev/null +++ b/server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Injector; +import com.sun.jersey.spi.container.ResourceFilter; +import io.druid.server.ClientInfoResource; +import io.druid.server.QueryResource; +import io.druid.server.StatusResource; +import io.druid.server.http.BrokerResource; +import io.druid.server.http.CoordinatorDynamicConfigsResource; +import io.druid.server.http.CoordinatorResource; +import io.druid.server.http.DatasourcesResource; +import io.druid.server.http.HistoricalResource; +import io.druid.server.http.IntervalsResource; +import io.druid.server.http.MetadataResource; +import io.druid.server.http.RulesResource; +import io.druid.server.http.ServersResource; +import io.druid.server.http.TiersResource; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; +import java.util.Collection; + +@RunWith(Parameterized.class) +public class SecurityResourceFilterTest extends ResourceFilterTestHelper +{ + @Parameterized.Parameters + public static Collection data() + { + return ImmutableList.copyOf( + Iterables.concat( + getRequestPaths(CoordinatorResource.class), + getRequestPaths(DatasourcesResource.class), + getRequestPaths(BrokerResource.class), + getRequestPaths(HistoricalResource.class), + getRequestPaths(IntervalsResource.class), + getRequestPaths(MetadataResource.class), + getRequestPaths(RulesResource.class), + getRequestPaths(ServersResource.class), + getRequestPaths(TiersResource.class), + getRequestPaths(ClientInfoResource.class), + getRequestPaths(CoordinatorDynamicConfigsResource.class), + getRequestPaths(QueryResource.class), + getRequestPaths(StatusResource.class) + ) + ); + } + + private final String requestPath; + private final String requestMethod; + private final ResourceFilter resourceFilter; + private final Injector injector; + + public SecurityResourceFilterTest( + String requestPath, + String requestMethod, + ResourceFilter resourceFilter, + Injector injector + ) + { + this.requestPath = requestPath; + this.requestMethod = requestMethod; + this.resourceFilter = resourceFilter; + this.injector = injector; + } + + @Before + public void setUp() throws Exception + { + setUp(resourceFilter); + } + + @Test + public void testDatasourcesResourcesFilteringAccess() + { + setUpMockExpectations(requestPath, true, requestMethod); + EasyMock.replay(req, request, authorizationInfo); + Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + resourceFilter.getRequestFilter().filter(request); + EasyMock.verify(req, request, authorizationInfo); + } + + @Test(expected = WebApplicationException.class) + public void testDatasourcesResourcesFilteringNoAccess() + { + setUpMockExpectations(requestPath, false, requestMethod); + EasyMock.replay(req, request, authorizationInfo); + //Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + try { + resourceFilter.getRequestFilter().filter(request); + } + catch (WebApplicationException e) { + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), e.getResponse().getStatus()); + throw e; + } + EasyMock.verify(req, request, authorizationInfo); + } + + @Test + public void testDatasourcesResourcesFilteringBadPath() + { + EasyMock.replay(req, request, authorizationInfo); + final String badRequestPath = requestPath.replaceAll("\\w+", "droid"); + Assert.assertFalse(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(badRequestPath)); + EasyMock.verify(req, request, authorizationInfo); + } + +}