diff --git a/dd-java-agent/instrumentation/aws-java-sqs-1.0/src/main/java/datadog/trace/instrumentation/aws/v1/sqs/SqsInterceptor.java b/dd-java-agent/instrumentation/aws-java-sqs-1.0/src/main/java/datadog/trace/instrumentation/aws/v1/sqs/SqsInterceptor.java index 4b2e431a6a6..4b353f12591 100644 --- a/dd-java-agent/instrumentation/aws-java-sqs-1.0/src/main/java/datadog/trace/instrumentation/aws/v1/sqs/SqsInterceptor.java +++ b/dd-java-agent/instrumentation/aws-java-sqs-1.0/src/main/java/datadog/trace/instrumentation/aws/v1/sqs/SqsInterceptor.java @@ -12,6 +12,7 @@ import com.amazonaws.AmazonWebServiceRequest; import com.amazonaws.handlers.RequestHandler2; +import com.amazonaws.services.sqs.model.MessageAttributeValue; import com.amazonaws.services.sqs.model.ReceiveMessageRequest; import com.amazonaws.services.sqs.model.SendMessageBatchRequest; import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry; @@ -22,7 +23,11 @@ import datadog.trace.api.datastreams.DataStreamsContext; import datadog.trace.bootstrap.ContextStore; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import java.util.ArrayList; +import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; public class SqsInterceptor extends RequestHandler2 { @@ -42,9 +47,14 @@ public AmazonWebServiceRequest beforeMarshalling(AmazonWebServiceRequest request Propagator dsmPropagator = Propagators.forConcern(DSM_CONCERN); Context context = newContext(request, queueUrl); + // making a copy of the MessageAttributes before modifying them because they can be stored in + // a kind of ImmutableMap + Map messageAttributes = + new HashMap<>(smRequest.getMessageAttributes()); + dsmPropagator.inject(context, messageAttributes, SETTER); // note: modifying message attributes has to be done before marshalling, otherwise the changes // are not reflected in the actual request (and the MD5 check on send will fail). - dsmPropagator.inject(context, smRequest.getMessageAttributes(), SETTER); + smRequest.setMessageAttributes(messageAttributes); } else if (request instanceof SendMessageBatchRequest) { SendMessageBatchRequest smbRequest = (SendMessageBatchRequest) request; @@ -54,13 +64,18 @@ public AmazonWebServiceRequest beforeMarshalling(AmazonWebServiceRequest request Propagator dsmPropagator = Propagators.forConcern(DSM_CONCERN); Context context = newContext(request, queueUrl); for (SendMessageBatchRequestEntry entry : smbRequest.getEntries()) { - dsmPropagator.inject(context, entry.getMessageAttributes(), SETTER); + Map messageAttributes = + new HashMap<>(entry.getMessageAttributes()); + dsmPropagator.inject(context, messageAttributes, SETTER); + entry.setMessageAttributes(messageAttributes); } } else if (request instanceof ReceiveMessageRequest) { ReceiveMessageRequest rmRequest = (ReceiveMessageRequest) request; if (rmRequest.getMessageAttributeNames().size() < 10 && !rmRequest.getMessageAttributeNames().contains(DATADOG_KEY)) { - rmRequest.getMessageAttributeNames().add(DATADOG_KEY); + List attributeNames = new ArrayList<>(rmRequest.getMessageAttributeNames()); + attributeNames.add(DATADOG_KEY); + rmRequest.setMessageAttributeNames(attributeNames); } } return request; diff --git a/dd-java-agent/instrumentation/aws-java-sqs-1.0/src/test/groovy/SqsClientTest.groovy b/dd-java-agent/instrumentation/aws-java-sqs-1.0/src/test/groovy/SqsClientTest.groovy index e2d3961f79e..9ae956e5068 100644 --- a/dd-java-agent/instrumentation/aws-java-sqs-1.0/src/test/groovy/SqsClientTest.groovy +++ b/dd-java-agent/instrumentation/aws-java-sqs-1.0/src/test/groovy/SqsClientTest.groovy @@ -7,7 +7,9 @@ import com.amazonaws.client.builder.AwsClientBuilder import com.amazonaws.services.sqs.AmazonSQSClientBuilder import com.amazonaws.services.sqs.model.Message import com.amazonaws.services.sqs.model.MessageAttributeValue +import com.amazonaws.services.sqs.model.ReceiveMessageRequest import com.amazonaws.services.sqs.model.SendMessageRequest +import com.google.common.collect.ImmutableMap import datadog.trace.agent.test.naming.VersionedNamingTestBase import datadog.trace.agent.test.utils.TraceUtils import datadog.trace.api.Config @@ -87,9 +89,9 @@ abstract class SqsClientTest extends VersionedNamingTestBase { def "trace details propagated via SQS system message attributes"() { setup: def client = AmazonSQSClientBuilder.standard() - .withEndpointConfiguration(endpoint) - .withCredentials(credentialsProvider) - .build() + .withEndpointConfiguration(endpoint) + .withCredentials(credentialsProvider) + .build() def queueUrl = client.createQueue('somequeue').queueUrl TEST_WRITER.clear() @@ -188,6 +190,56 @@ abstract class SqsClientTest extends VersionedNamingTestBase { client.shutdown() } + @IgnoreIf({ !instance.isDataStreamsEnabled() }) + def "propagation even when message attributes are readonly"() { + setup: + def client = AmazonSQSClientBuilder.standard() + .withEndpointConfiguration(endpoint) + .withCredentials(credentialsProvider) + .build() + def queueUrl = client.createQueue('somequeue').queueUrl + TEST_WRITER.clear() + + when: + TraceUtils.runUnderTrace('parent', { + def my_attribute = new MessageAttributeValue() + my_attribute.setStringValue("hello world") + my_attribute.setDataType("String") + def readonlyAttributes = ImmutableMap.of("my_key", my_attribute) + def req = new SendMessageRequest(queueUrl, 'sometext') + req.setMessageAttributes(readonlyAttributes) + client.sendMessage(req) + }) + + TEST_DATA_STREAMS_WRITER.waitForGroups(1) + + then: + assertTraces(1) { + trace(2) { + basicSpan(it, "parent") + span { + serviceName expectedService("SQS", "SendMessage") + operationName expectedOperation("SQS", "SendMessage") + resourceName "SQS.SendMessage" + spanType DDSpanTypes.HTTP_CLIENT + errored false + childOf(span(0)) + } + } + } + + and: + def recv = new ReceiveMessageRequest(queueUrl) + recv.withMessageAttributeNames("my_key") + def messages = client.receiveMessage(recv).messages + + assert messages[0].messageAttributes.containsKey("my_key") // what we set initially + assert messages[0].messageAttributes.containsKey("_datadog") // what was injected + + cleanup: + client.shutdown() + } + @IgnoreIf({ instance.isDataStreamsEnabled() }) def "trace details propagated via embedded SQS message attribute (string)"() { setup: @@ -196,8 +248,8 @@ abstract class SqsClientTest extends VersionedNamingTestBase { when: def message = new Message() message.addMessageAttributesEntry('_datadog', new MessageAttributeValue().withDataType('String').withStringValue( - "{\"x-datadog-trace-id\": \"4948377316357291421\", \"x-datadog-parent-id\": \"6746998015037429512\", \"x-datadog-sampling-priority\": \"1\"}" - )) + "{\"x-datadog-trace-id\": \"4948377316357291421\", \"x-datadog-parent-id\": \"6746998015037429512\", \"x-datadog-sampling-priority\": \"1\"}" + )) def messages = new TracingList([message], "http://localhost:${address.port}/000000000000/somequeue") messages.forEach {/* consume to create message spans */ } @@ -237,8 +289,8 @@ abstract class SqsClientTest extends VersionedNamingTestBase { when: def message = new Message() message.addMessageAttributesEntry('_datadog', new MessageAttributeValue().withDataType('Binary').withBinaryValue( - headerValue - )) + headerValue + )) def messages = new TracingList([message], "http://localhost:${address.port}/000000000000/somequeue") messages.forEach {/* consume to create message spans */ } @@ -281,9 +333,9 @@ abstract class SqsClientTest extends VersionedNamingTestBase { def "trace details propagated from SQS to JMS"() { setup: def client = AmazonSQSClientBuilder.standard() - .withEndpointConfiguration(endpoint) - .withCredentials(credentialsProvider) - .build() + .withEndpointConfiguration(endpoint) + .withCredentials(credentialsProvider) + .build() def connectionFactory = new SQSConnectionFactory(new ProviderConfiguration(), client) def connection = connectionFactory.createConnection() @@ -295,12 +347,12 @@ abstract class SqsClientTest extends VersionedNamingTestBase { when: def ddMsgAttribute = new MessageAttributeValue() - .withBinaryValue(ByteBuffer.wrap("hello world".getBytes(Charset.defaultCharset()))) - .withDataType("Binary") + .withBinaryValue(ByteBuffer.wrap("hello world".getBytes(Charset.defaultCharset()))) + .withDataType("Binary") connection.start() TraceUtils.runUnderTrace('parent') { client.sendMessage(new SendMessageRequest(queue.queueUrl, 'sometext') - .withMessageAttributes([_datadog: ddMsgAttribute])) + .withMessageAttributes([_datadog: ddMsgAttribute])) } def message = consumer.receive() consumer.receiveNoWait() @@ -558,9 +610,9 @@ class SqsClientV1DataStreamsForkedTest extends SqsClientTest { def "Data streams context extracted from message body"() { setup: def client = AmazonSQSClientBuilder.standard() - .withEndpointConfiguration(endpoint) - .withCredentials(credentialsProvider) - .build() + .withEndpointConfiguration(endpoint) + .withCredentials(credentialsProvider) + .build() def queueUrl = client.createQueue('somequeue').queueUrl TEST_WRITER.clear() @@ -588,9 +640,9 @@ class SqsClientV1DataStreamsForkedTest extends SqsClientTest { def "Data streams context not extracted from message body when message attributes are not present"() { setup: def client = AmazonSQSClientBuilder.standard() - .withEndpointConfiguration(endpoint) - .withCredentials(credentialsProvider) - .build() + .withEndpointConfiguration(endpoint) + .withCredentials(credentialsProvider) + .build() def queueUrl = client.createQueue('somequeue').queueUrl TEST_WRITER.clear() @@ -619,9 +671,9 @@ class SqsClientV1DataStreamsForkedTest extends SqsClientTest { def "Data streams context not extracted from message body when message is not a Json"() { setup: def client = AmazonSQSClientBuilder.standard() - .withEndpointConfiguration(endpoint) - .withCredentials(credentialsProvider) - .build() + .withEndpointConfiguration(endpoint) + .withCredentials(credentialsProvider) + .build() def queueUrl = client.createQueue('somequeue').queueUrl TEST_WRITER.clear()