diff --git a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableClientContext.java b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableClientContext.java index a952db68..bbf5e198 100644 --- a/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableClientContext.java +++ b/azurefunctions/src/main/java/com/microsoft/durabletask/azurefunctions/DurableClientContext.java @@ -8,7 +8,7 @@ import com.microsoft.azure.functions.HttpStatus; import com.microsoft.durabletask.DurableTaskClient; -import com.microsoft.durabletask.DurableTaskGrpcClientBuilder; +import com.microsoft.durabletask.DurableTaskGrpcClientFactory; import com.microsoft.durabletask.OrchestrationMetadata; import com.microsoft.durabletask.OrchestrationRuntimeStatus; @@ -45,6 +45,10 @@ public String getTaskHubName() { * @return the Durable Task client object associated with the current function invocation. */ public DurableTaskClient getClient() { + if (this.client != null) { + return this.client; + } + if (this.rpcBaseUrl == null || this.rpcBaseUrl.length() == 0) { throw new IllegalStateException("The client context wasn't populated with an RPC base URL!"); } @@ -56,7 +60,7 @@ public DurableTaskClient getClient() { throw new IllegalStateException("The client context RPC base URL was invalid!", ex); } - this.client = new DurableTaskGrpcClientBuilder().port(rpcURL.getPort()).build(); + this.client = DurableTaskGrpcClientFactory.getClient(rpcURL.getPort(), null); return this.client; } @@ -78,9 +82,7 @@ public HttpResponseMessage waitForCompletionOrCreateCheckStatusResponse( HttpRequestMessage request, String instanceId, Duration timeout) { - if (this.client == null) { - this.client = getClient(); - } + this.client = getClient(); OrchestrationMetadata orchestration; try { orchestration = this.client.waitForInstanceCompletion(instanceId, timeout, true); diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java index 52d072b8..2ebb6ec8 100644 --- a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java +++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClient.java @@ -61,6 +61,18 @@ public final class DurableTaskGrpcClient extends DurableTaskClient { this.sidecarClient = TaskHubSidecarServiceGrpc.newBlockingStub(sidecarGrpcChannel); } + DurableTaskGrpcClient(int port, String defaultVersion) { + this.dataConverter = new JacksonDataConverter(); + this.defaultVersion = defaultVersion; + + // Need to keep track of this channel so we can dispose it on close() + this.managedSidecarChannel = ManagedChannelBuilder + .forAddress("localhost", port) + .usePlaintext() + .build(); + this.sidecarClient = TaskHubSidecarServiceGrpc.newBlockingStub(this.managedSidecarChannel); + } + /** * Closes the internally managed gRPC channel, if one exists. *

diff --git a/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClientFactory.java b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClientFactory.java new file mode 100644 index 00000000..f22a4ba7 --- /dev/null +++ b/client/src/main/java/com/microsoft/durabletask/DurableTaskGrpcClientFactory.java @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.durabletask; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +public final class DurableTaskGrpcClientFactory { + private static final ConcurrentMap portToClientMap = new ConcurrentHashMap<>(); + + // Private to prevent instantiation and enforce a singleton pattern + private DurableTaskGrpcClientFactory() { + } + + public static DurableTaskClient getClient(int port, String defaultVersion) { + return portToClientMap.computeIfAbsent(port, p -> new DurableTaskGrpcClient(p, defaultVersion)); + } +} \ No newline at end of file diff --git a/client/src/test/java/com/microsoft/durabletask/DurableTaskGrpcClientFactoryTest.java b/client/src/test/java/com/microsoft/durabletask/DurableTaskGrpcClientFactoryTest.java new file mode 100644 index 00000000..561f1659 --- /dev/null +++ b/client/src/test/java/com/microsoft/durabletask/DurableTaskGrpcClientFactoryTest.java @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.durabletask; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for DurableTaskGrpcClientFactory. + */ +public class DurableTaskGrpcClientFactoryTest { + + private static final String DEFAULT_VERSION = null; + + @Test + void getClient_samePort_returnsSameInstance() { + // Arrange + int port = 5001; + + // Act + DurableTaskClient client1 = DurableTaskGrpcClientFactory.getClient(port, DEFAULT_VERSION); + DurableTaskClient client2 = DurableTaskGrpcClientFactory.getClient(port, DEFAULT_VERSION); + + // Assert + assertNotNull(client1, "First client should not be null"); + assertNotNull(client2, "Second client should not be null"); + assertSame(client1, client2, "getClient should return the same instance for the same port"); + } + + @Test + void getClient_differentPorts_returnsDifferentInstances() { + // Arrange + int port1 = 5002; + int port2 = 5003; + + // Act + DurableTaskClient client1 = DurableTaskGrpcClientFactory.getClient(port1, DEFAULT_VERSION); + DurableTaskClient client2 = DurableTaskGrpcClientFactory.getClient(port2, DEFAULT_VERSION); + + // Assert + assertNotNull(client1, "Client for port1 should not be null"); + assertNotNull(client2, "Client for port2 should not be null"); + assertNotSame(client1, client2, "getClient should return different instances for different ports"); + } + + @Test + void getClient_multiplePorts_maintainsCorrectMapping() { + // Arrange + int port1 = 5004; + int port2 = 5005; + int port3 = 5006; + + // Act + DurableTaskClient client1 = DurableTaskGrpcClientFactory.getClient(port1, DEFAULT_VERSION); + DurableTaskClient client2 = DurableTaskGrpcClientFactory.getClient(port2, DEFAULT_VERSION); + DurableTaskClient client3 = DurableTaskGrpcClientFactory.getClient(port3, DEFAULT_VERSION); + + // Request the same ports again + DurableTaskClient client1Again = DurableTaskGrpcClientFactory.getClient(port1, DEFAULT_VERSION); + DurableTaskClient client2Again = DurableTaskGrpcClientFactory.getClient(port2, DEFAULT_VERSION); + DurableTaskClient client3Again = DurableTaskGrpcClientFactory.getClient(port3, DEFAULT_VERSION); + + // Assert + // Verify each port returns the same instance + assertSame(client1, client1Again, "Port " + port1 + " should return the same instance"); + assertSame(client2, client2Again, "Port " + port2 + " should return the same instance"); + assertSame(client3, client3Again, "Port " + port3 + " should return the same instance"); + + // Verify all instances are different from each other + assertNotSame(client1, client2, "Client for port1 and port2 should be different"); + assertNotSame(client1, client3, "Client for port1 and port3 should be different"); + assertNotSame(client2, client3, "Client for port2 and port3 should be different"); + } +}