Skip to content

Commit 3ac07d0

Browse files
M-ElsaeedMohammed Ehab
andcommitted
Ensure EventHandlerLoader Thread Safety. (#95)
* Make handler response buffers thread safe. * Add multiconcurrency tests * ThreadLocal instead of Allocating new buffers every invoke. * Thread local log4jContextPutMethod * Fix indentations * Add CountDownLatch to ensure all calls are done simultaneously. --------- Co-authored-by: Mohammed Ehab <moehabe@amazon.com>
1 parent ba4ef1c commit 3ac07d0

File tree

2 files changed

+96
-20
lines changed

2 files changed

+96
-20
lines changed

aws-lambda-java-runtime-interface-client/src/main/java/com/amazonaws/services/lambda/runtime/api/client/EventHandlerLoader.java

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ private enum Platform {
5757
UNKNOWN
5858
}
5959

60-
private static volatile PojoSerializer<LambdaClientContext> contextSerializer;
61-
private static volatile PojoSerializer<LambdaCognitoIdentity> cognitoSerializer;
60+
private static volatile ThreadLocal<PojoSerializer<LambdaClientContext>> contextSerializer = new ThreadLocal<>();
61+
private static volatile ThreadLocal<PojoSerializer<LambdaCognitoIdentity>> cognitoSerializer = new ThreadLocal<>();
6262

63-
private static final EnumMap<Platform, Map<Type, PojoSerializer<Object>>> typeCache = new EnumMap<>(Platform.class);
63+
private static final ThreadLocal<EnumMap<Platform, Map<Type, PojoSerializer<Object>>>> typeCache = ThreadLocal.withInitial(() -> new EnumMap<>(Platform.class));
6464

6565
private static final Comparator<Method> methodPriority = new Comparator<Method>() {
6666
public int compare(Method lhs, Method rhs) {
@@ -127,10 +127,11 @@ private static PojoSerializer<Object> getSerializer(Platform platform, Type type
127127
}
128128

129129
private static PojoSerializer<Object> getSerializerCached(Platform platform, Type type) {
130-
Map<Type, PojoSerializer<Object>> cache = typeCache.get(platform);
130+
EnumMap<Platform, Map<Type, PojoSerializer<Object>>> threadTypeCache = typeCache.get();
131+
Map<Type, PojoSerializer<Object>> cache = threadTypeCache.get(platform);
131132
if (cache == null) {
132133
cache = new HashMap<>();
133-
typeCache.put(platform, cache);
134+
threadTypeCache.put(platform, cache);
134135
}
135136

136137
PojoSerializer<Object> serializer = cache.get(type);
@@ -143,17 +144,17 @@ private static PojoSerializer<Object> getSerializerCached(Platform platform, Typ
143144
}
144145

145146
private static PojoSerializer<LambdaClientContext> getContextSerializer() {
146-
if (contextSerializer == null) {
147-
contextSerializer = GsonFactory.getInstance().getSerializer(LambdaClientContext.class);
147+
if (contextSerializer.get() == null) {
148+
contextSerializer.set(GsonFactory.getInstance().getSerializer(LambdaClientContext.class));
148149
}
149-
return contextSerializer;
150+
return contextSerializer.get();
150151
}
151152

152153
private static PojoSerializer<LambdaCognitoIdentity> getCognitoSerializer() {
153-
if (cognitoSerializer == null) {
154-
cognitoSerializer = GsonFactory.getInstance().getSerializer(LambdaCognitoIdentity.class);
154+
if (cognitoSerializer.get() == null) {
155+
cognitoSerializer.set(GsonFactory.getInstance().getSerializer(LambdaCognitoIdentity.class));
155156
}
156-
return cognitoSerializer;
157+
return cognitoSerializer.get();
157158
}
158159

159160

@@ -527,15 +528,14 @@ private static LambdaRequestHandler wrapPojoHandler(RequestHandler instance, Typ
527528

528529
private static LambdaRequestHandler wrapRequestStreamHandler(final RequestStreamHandler handler) {
529530
return new LambdaRequestHandler() {
530-
private final ByteArrayOutputStream output = new ByteArrayOutputStream(1024);
531-
private Functions.V2<String, String> log4jContextPutMethod = null;
531+
private final ThreadLocal<ByteArrayOutputStream> outputBuffers = ThreadLocal.withInitial(() -> new ByteArrayOutputStream(1024));
532+
private ThreadLocal<Functions.V2<String, String>> log4jContextPutMethod = new ThreadLocal<>();
532533

533-
private void safeAddRequestIdToLog4j(String log4jContextClassName,
534-
InvocationRequest request, Class contextMapValueClass) {
534+
private void safeAddRequestIdToLog4j(String log4jContextClassName, InvocationRequest request, Class contextMapValueClass) {
535535
try {
536536
Class<?> log4jContextClass = ReflectUtil.loadClass(AWSLambda.getCustomerClassLoader(), log4jContextClassName);
537-
log4jContextPutMethod = ReflectUtil.loadStaticV2(log4jContextClass, "put", false, String.class, contextMapValueClass);
538-
log4jContextPutMethod.call("AWSRequestId", request.getId());
537+
log4jContextPutMethod.set(ReflectUtil.loadStaticV2(log4jContextClass, "put", false, String.class, contextMapValueClass));
538+
log4jContextPutMethod.get().call("AWSRequestId", request.getId());
539539
} catch (Exception e) {
540540
// nothing to do here
541541
}
@@ -558,6 +558,7 @@ private void safeAddContextToLambdaLogger(LambdaContext context) {
558558
}
559559

560560
public ByteArrayOutputStream call(InvocationRequest request) throws Error, Exception {
561+
ByteArrayOutputStream output = outputBuffers.get();
561562
output.reset();
562563

563564
LambdaCognitoIdentity cognitoIdentity = null;
@@ -591,7 +592,7 @@ public ByteArrayOutputStream call(InvocationRequest request) throws Error, Excep
591592
safeAddRequestIdToLog4j("org.apache.log4j.MDC", request, Object.class);
592593
safeAddRequestIdToLog4j("org.apache.logging.log4j.ThreadContext", request, String.class);
593594
// if put method not assigned in either call to safeAddRequestIdtoLog4j then log4jContextPutMethod = null
594-
if (log4jContextPutMethod == null) {
595+
if (log4jContextPutMethod.get() == null) {
595596
System.err.println("Customer using log4j appender but unable to load either "
596597
+ "org.apache.log4j.MDC or org.apache.logging.log4j.ThreadContext. "
597598
+ "Customer cannot see RequestId in log4j log lines.");

aws-lambda-java-runtime-interface-client/src/test/java/com/amazonaws/services/lambda/runtime/api/client/EventHandlerLoaderTest.java

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,16 @@
44
import org.junit.jupiter.api.Test;
55

66
import java.io.ByteArrayOutputStream;
7+
import java.util.ArrayList;
8+
import java.util.List;
9+
import java.util.concurrent.CountDownLatch;
10+
import java.util.concurrent.ExecutorService;
11+
import java.util.concurrent.Executors;
12+
import java.util.concurrent.Future;
13+
import java.util.concurrent.TimeUnit;
714

815
import static org.junit.jupiter.api.Assertions.assertEquals;
16+
import static org.junit.jupiter.api.Assertions.assertTrue;
917

1018
class EventHandlerLoaderTest {
1119

@@ -37,7 +45,6 @@ void PojoHandlerTest_oneParamEvent() throws Exception {
3745
assertSuccessfulInvocation(lambdaRequestHandler);
3846
}
3947

40-
4148
@Test
4249
void PojoHandlerTest_oneParamContext() throws Exception {
4350
String handler = "test.lambda.handlers.POJOHanlderImpl::oneParamHandler_context";
@@ -74,4 +81,72 @@ private static InvocationRequest getTestInvocationRequest() {
7481
invocationRequest.setXrayTraceId("traceId");
7582
return invocationRequest;
7683
}
77-
}
84+
85+
// Multithreaded test methods
86+
87+
@Test
88+
void RequestHandlerTest_Multithreaded() throws Exception {
89+
testHandlerConcurrency("test.lambda.handlers.RequestHandlerImpl");
90+
}
91+
92+
@Test
93+
void RequestStreamHandlerTest_Multithreaded() throws Exception {
94+
testHandlerConcurrency("test.lambda.handlers.RequestStreamHandlerImpl");
95+
}
96+
97+
@Test
98+
void PojoHandlerTest_noParams_Multithreaded() throws Exception {
99+
testHandlerConcurrency("test.lambda.handlers.POJOHanlderImpl::noParamsHandler");
100+
}
101+
102+
@Test
103+
void PojoHandlerTest_oneParamEvent_Multithreaded() throws Exception {
104+
testHandlerConcurrency("test.lambda.handlers.POJOHanlderImpl::oneParamHandler_event");
105+
}
106+
107+
@Test
108+
void PojoHandlerTest_oneParamContext_Multithreaded() throws Exception {
109+
testHandlerConcurrency("test.lambda.handlers.POJOHanlderImpl::oneParamHandler_context");
110+
}
111+
112+
@Test
113+
void PojoHandlerTest_twoParams_Multithreaded() throws Exception {
114+
testHandlerConcurrency("test.lambda.handlers.POJOHanlderImpl::twoParamsHandler");
115+
}
116+
117+
private void testHandlerConcurrency(String handlerName) throws Exception {
118+
// Create one handler instance
119+
LambdaRequestHandler handler = getLambdaRequestHandler(handlerName);
120+
121+
int threadCount = 10;
122+
ExecutorService executor = Executors.newFixedThreadPool(threadCount);
123+
List<Future<String>> futures = new ArrayList<>();
124+
CountDownLatch startLatch = new CountDownLatch(1);
125+
126+
try {
127+
for (int i = 0; i < threadCount; i++) {
128+
futures.add(executor.submit(() -> {
129+
try {
130+
InvocationRequest request = getTestInvocationRequest();
131+
startLatch.await();
132+
ByteArrayOutputStream result = handler.call(request);
133+
return result.toString();
134+
} catch (Exception e) {
135+
throw new RuntimeException(e);
136+
}
137+
}));
138+
}
139+
140+
// Release all threads simultaneously and Verify all invocations return the expected result
141+
startLatch.countDown();
142+
143+
for (Future<String> future : futures) {
144+
String result = future.get(5, TimeUnit.SECONDS);
145+
assertEquals("\"success\"", result);
146+
}
147+
} finally {
148+
executor.shutdown();
149+
assertTrue(executor.awaitTermination(10, TimeUnit.SECONDS));
150+
}
151+
}
152+
}

0 commit comments

Comments
 (0)