diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClient.java index 1894336043dc..927130e0ca7c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClient.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.indexing.client; import com.google.common.collect.ImmutableSet; +import org.apache.druid.client.indexing.TaskStatusResponse; import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskStatus; @@ -37,6 +38,7 @@ public class IndexerWorkerManagerClient implements WorkerManagerClient { private final OverlordClient overlordClient; + private final TaskLocationFetcher locationFetcher = new TaskLocationFetcher(); public IndexerWorkerManagerClient(final OverlordClient overlordClient) { @@ -65,16 +67,7 @@ public Map statuses(Set taskIds) @Override public TaskLocation location(String workerId) { - final TaskStatus response = FutureUtils.getUnchecked( - overlordClient.taskStatuses(ImmutableSet.of(workerId)), - true - ).get(workerId); - - if (response != null) { - return response.getLocation(); - } else { - return TaskLocation.unknown(); - } + return locationFetcher.getLocation(workerId); } @Override @@ -82,4 +75,31 @@ public void close() { // Nothing to do. The OverlordServiceClient is closed by the JVM lifecycle. 
} + + private class TaskLocationFetcher + { + TaskLocation getLocation(String workerId) + { + final TaskStatus taskStatus = FutureUtils.getUnchecked( + overlordClient.taskStatuses(ImmutableSet.of(workerId)), + true + ).get(workerId); + + if (taskStatus != null + && !TaskLocation.unknown().equals(taskStatus.getLocation())) { + return taskStatus.getLocation(); + } + + // Bulk status API gave no usable location (missing entry or unknown location); retry with the single-task status API + final TaskStatusResponse statusResponse = FutureUtils.getUnchecked( + overlordClient.taskStatus(workerId), + true + ); + if (statusResponse == null || statusResponse.getStatus() == null) { + return TaskLocation.unknown(); + } else { + return statusResponse.getStatus().getLocation(); + } + } + } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClientTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClientTest.java new file mode 100644 index 000000000000..4b53420cbb9d --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/client/IndexerWorkerManagerClientTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.indexing.client; + +import com.google.common.util.concurrent.Futures; +import org.apache.druid.client.indexing.TaskStatusResponse; +import org.apache.druid.indexer.TaskLocation; +import org.apache.druid.indexer.TaskState; +import org.apache.druid.indexer.TaskStatus; +import org.apache.druid.indexer.TaskStatusPlus; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.rpc.indexing.OverlordClient; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; + +import java.util.Collections; + +public class IndexerWorkerManagerClientTest +{ + + @Test + public void testGetLocationCallsMultiStatusApiByDefault() + { + final OverlordClient overlordClient = Mockito.mock(OverlordClient.class); + + final String taskId = "worker1"; + final TaskLocation expectedLocation = new TaskLocation("localhost", 1000, 1100, null); + Mockito.when(overlordClient.taskStatuses(Collections.singleton(taskId))).thenReturn( + Futures.immediateFuture( + Collections.singletonMap( + taskId, + new TaskStatus(taskId, TaskState.RUNNING, 100L, null, expectedLocation) + ) + ) + ); + + final IndexerWorkerManagerClient managerClient = new IndexerWorkerManagerClient(overlordClient); + Assert.assertEquals(expectedLocation, managerClient.location(taskId)); + + Mockito.verify(overlordClient, Mockito.times(1)).taskStatuses(ArgumentMatchers.anySet()); + Mockito.verify(overlordClient, Mockito.never()).taskStatus(ArgumentMatchers.anyString()); + } + + @Test + public void testGetLocationFallsBackToSingleTaskApiIfLocationIsUnknown() + { + final OverlordClient overlordClient = Mockito.mock(OverlordClient.class); + + final String taskId = "worker1"; + Mockito.when(overlordClient.taskStatuses(Collections.singleton(taskId))).thenReturn( + Futures.immediateFuture( + Collections.singletonMap( + taskId, + new TaskStatus(taskId, TaskState.RUNNING, 100L, null, TaskLocation.unknown()) + ) + ) + ); + + final 
TaskLocation expectedLocation = new TaskLocation("localhost", 1000, 1100, null); + final TaskStatusPlus taskStatus = new TaskStatusPlus( + taskId, + null, + null, + DateTimes.nowUtc(), + DateTimes.nowUtc(), + TaskState.RUNNING, + null, + 100L, + expectedLocation, + "wiki", + null + ); + + Mockito.when(overlordClient.taskStatus(taskId)).thenReturn( + Futures.immediateFuture(new TaskStatusResponse(taskId, taskStatus)) + ); + + final IndexerWorkerManagerClient managerClient = new IndexerWorkerManagerClient(overlordClient); + Assert.assertEquals(expectedLocation, managerClient.location(taskId)); + + Mockito.verify(overlordClient, Mockito.times(1)).taskStatuses(ArgumentMatchers.anySet()); + Mockito.verify(overlordClient, Mockito.times(1)).taskStatus(ArgumentMatchers.anyString()); + } + +} diff --git a/server/src/main/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocator.java b/server/src/main/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocator.java index 163c7e14e01a..3f5441318a5f 100644 --- a/server/src/main/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocator.java +++ b/server/src/main/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocator.java @@ -26,6 +26,8 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.client.indexing.TaskStatusResponse; +import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskState; import org.apache.druid.indexer.TaskStatus; @@ -55,6 +57,7 @@ public class SpecificTaskServiceLocator implements ServiceLocator private final String taskId; private final OverlordClient overlordClient; + private final TaskLocationFetcher locationFetcher = new TaskLocationFetcher(); private final Object lock = new Object(); @GuardedBy("lock") @@ -129,14 +132,20 @@ public void onSuccess(final Map 
taskStatusMap) lastKnownLocation = null; } else { lastKnownState = status.getStatusCode(); - + final TaskLocation location; if (TaskLocation.unknown().equals(status.getLocation())) { + location = locationFetcher.getLocation(); + } else { + location = status.getLocation(); + } + + if (TaskLocation.unknown().equals(location)) { lastKnownLocation = null; } else { lastKnownLocation = new ServiceLocation( - status.getLocation().getHost(), - status.getLocation().getPort(), - status.getLocation().getTlsPort(), + location.getHost(), + location.getPort(), + location.getTlsPort(), StringUtils.format("%s/%s", BASE_PATH, StringUtils.urlEncode(taskId)) ); } @@ -199,4 +208,20 @@ public void close() } } } + + private class TaskLocationFetcher + { + TaskLocation getLocation() + { + final TaskStatusResponse statusResponse = FutureUtils.getUnchecked( + overlordClient.taskStatus(taskId), + true + ); + if (statusResponse == null || statusResponse.getStatus() == null) { + return TaskLocation.unknown(); + } else { + return statusResponse.getStatus().getLocation(); + } + } + } } diff --git a/server/src/test/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocatorTest.java b/server/src/test/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocatorTest.java index 4888078af5dc..f75456797242 100644 --- a/server/src/test/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocatorTest.java +++ b/server/src/test/java/org/apache/druid/rpc/indexing/SpecificTaskServiceLocatorTest.java @@ -22,9 +22,12 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import org.apache.druid.client.indexing.TaskStatusResponse; import org.apache.druid.indexer.TaskLocation; import org.apache.druid.indexer.TaskState; import org.apache.druid.indexer.TaskStatus; +import org.apache.druid.indexer.TaskStatusPlus; +import org.apache.druid.java.util.common.DateTimes; import 
org.apache.druid.java.util.common.ISE; import org.apache.druid.rpc.ServiceLocation; import org.apache.druid.rpc.ServiceLocations; @@ -62,6 +65,25 @@ public void test_locate_noLocationYet() throws Exception { Mockito.when(overlordClient.taskStatuses(Collections.singleton(TASK_ID))) .thenReturn(status(TaskState.RUNNING, TaskLocation.unknown())); + final TaskStatusResponse response = new TaskStatusResponse( + TASK_ID, + new TaskStatusPlus( + TASK_ID, + null, + null, + DateTimes.nowUtc(), + DateTimes.EPOCH, + TaskState.RUNNING, + null, + null, + null, + TaskLocation.unknown(), + null, + null + ) + ); + Mockito.when(overlordClient.taskStatus(TASK_ID)) + .thenReturn(Futures.immediateFuture(response)); final SpecificTaskServiceLocator locator = new SpecificTaskServiceLocator(TASK_ID, overlordClient); final ListenableFuture future = locator.locate(); @@ -94,6 +116,25 @@ public void test_locate_taskSuccess() throws Exception { Mockito.when(overlordClient.taskStatuses(Collections.singleton(TASK_ID))) .thenReturn(status(TaskState.SUCCESS, TaskLocation.unknown())); + final TaskStatusResponse response = new TaskStatusResponse( + TASK_ID, + new TaskStatusPlus( + TASK_ID, + null, + null, + DateTimes.nowUtc(), + DateTimes.EPOCH, + TaskState.SUCCESS, + null, + null, + 100L, + TaskLocation.unknown(), + null, + null + ) + ); + Mockito.when(overlordClient.taskStatus(TASK_ID)) + .thenReturn(Futures.immediateFuture(response)); final SpecificTaskServiceLocator locator = new SpecificTaskServiceLocator(TASK_ID, overlordClient); final ListenableFuture future = locator.locate(); @@ -105,6 +146,25 @@ public void test_locate_taskFailed() throws Exception { Mockito.when(overlordClient.taskStatuses(Collections.singleton(TASK_ID))) .thenReturn(status(TaskState.FAILED, TaskLocation.unknown())); + final TaskStatusResponse response = new TaskStatusResponse( + TASK_ID, + new TaskStatusPlus( + TASK_ID, + null, + null, + DateTimes.nowUtc(), + DateTimes.EPOCH, + TaskState.FAILED, + null, + null, + 100L, 
+ TaskLocation.unknown(), + null, + null + ) + ); + Mockito.when(overlordClient.taskStatus(TASK_ID)) + .thenReturn(Futures.immediateFuture(response)); final SpecificTaskServiceLocator locator = new SpecificTaskServiceLocator(TASK_ID, overlordClient); final ListenableFuture future = locator.locate();