diff --git a/opentelemetry/build.gradle b/opentelemetry/build.gradle index 509960e5dbc..00d913c280d 100644 --- a/opentelemetry/build.gradle +++ b/opentelemetry/build.gradle @@ -14,8 +14,10 @@ dependencies { libraries.opentelemetry.api, libraries.auto.value.annotations - testImplementation testFixtures(project(':grpc-core')), - project(':grpc-testing'), + testImplementation project(':grpc-testing'), + project(':grpc-inprocess'), + testFixtures(project(':grpc-core')), + testFixtures(project(':grpc-api')), libraries.opentelemetry.sdk.testing, libraries.assertj.core // opentelemetry.sdk.testing uses compileOnly for assertj diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java index 03183ef4920..6b257a18a6c 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcOpenTelemetry.java @@ -33,10 +33,12 @@ import io.grpc.ManagedChannelBuilder; import io.grpc.MetricSink; import io.grpc.ServerBuilder; +import io.grpc.internal.GrpcUtil; import io.grpc.opentelemetry.internal.OpenTelemetryConstants; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.metrics.Meter; import io.opentelemetry.api.metrics.MeterProvider; +import io.opentelemetry.api.trace.Tracer; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -61,6 +63,10 @@ public Stopwatch get() { } }; + @VisibleForTesting + static boolean ENABLE_OTEL_TRACING = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_ENABLE_OTEL_TRACING", + false); + private final OpenTelemetry openTelemetrySdk; private final MeterProvider meterProvider; private final Meter meter; @@ -68,6 +74,7 @@ public Stopwatch get() { private final boolean disableDefault; private final OpenTelemetryMetricsResource resource; private final OpenTelemetryMetricsModule openTelemetryMetricsModule; + private final OpenTelemetryTracingModule openTelemetryTracingModule; private final List optionalLabels; private final MetricSink sink; @@ -88,6 +95,7 @@ private GrpcOpenTelemetry(Builder builder) { this.optionalLabels = ImmutableList.copyOf(builder.optionalLabels); this.openTelemetryMetricsModule = new OpenTelemetryMetricsModule( STOPWATCH_SUPPLIER, resource, optionalLabels, builder.plugins); + this.openTelemetryTracingModule = new OpenTelemetryTracingModule(openTelemetrySdk); this.sink = new OpenTelemetryMetricSink(meter, enableMetrics, disableDefault, optionalLabels); } @@ -125,6 +133,11 @@ MetricSink getSink() { return sink; } + @VisibleForTesting + Tracer getTracer() { + return this.openTelemetryTracingModule.getTracer(); + } + /** * Registers GrpcOpenTelemetry globally, applying its configuration to all subsequently created * gRPC channels and servers. @@ -152,6 +165,9 @@ public void configureChannelBuilder(ManagedChannelBuilder builder) { InternalManagedChannelBuilder.addMetricSink(builder, sink); InternalManagedChannelBuilder.interceptWithTarget( builder, openTelemetryMetricsModule::getClientInterceptor); + if (ENABLE_OTEL_TRACING) { + builder.intercept(openTelemetryTracingModule.getClientInterceptor()); + } } /** @@ -161,6 +177,11 @@ public void configureChannelBuilder(ManagedChannelBuilder builder) { */ public void configureServerBuilder(ServerBuilder serverBuilder) { serverBuilder.addStreamTracerFactory(openTelemetryMetricsModule.getServerTracerFactory()); + if (ENABLE_OTEL_TRACING) { + serverBuilder.addStreamTracerFactory( + openTelemetryTracingModule.getServerTracerFactory()); + serverBuilder.intercept(openTelemetryTracingModule.getServerSpanPropagationInterceptor()); + } } @VisibleForTesting @@ -342,6 +363,11 @@ public Builder disableAllMetrics() { return this; } + Builder enableTracing(boolean enable) { + ENABLE_OTEL_TRACING = enable; + return this; + } + /** * Returns a new {@link GrpcOpenTelemetry} built with the configuration of this {@link * Builder}. diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java index 5d5543dddda..ea1e7ab803f 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/InternalGrpcOpenTelemetry.java @@ -29,4 +29,8 @@ public static void builderPlugin( GrpcOpenTelemetry.Builder builder, InternalOpenTelemetryPlugin plugin) { builder.plugin(plugin); } + + public static void enableTracing(GrpcOpenTelemetry.Builder builder, boolean enable) { + builder.enableTracing(enable); + } } diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java index 11659c87708..6f2d3268ae0 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; +import static io.grpc.internal.GrpcUtil.IMPLEMENTATION_VERSION; import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; @@ -28,15 +29,21 @@ import io.grpc.ClientStreamTracer; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.ForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.ServerStreamTracer; +import io.grpc.opentelemetry.internal.OpenTelemetryConstants; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.common.AttributesBuilder; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; import io.opentelemetry.context.propagation.ContextPropagators; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.logging.Level; @@ -50,7 +57,7 @@ final class OpenTelemetryTracingModule { private static final Logger logger = Logger.getLogger(OpenTelemetryTracingModule.class.getName()); @VisibleForTesting - static final String OTEL_TRACING_SCOPE_NAME = "grpc-java"; + final io.grpc.Context.Key otelSpan = io.grpc.Context.key("opentelemetry-span-key"); @Nullable private static final AtomicIntegerFieldUpdater callEndedUpdater; @Nullable @@ -83,13 +90,23 @@ final class OpenTelemetryTracingModule { private final MetadataGetter metadataGetter = MetadataGetter.getInstance(); private final MetadataSetter metadataSetter = MetadataSetter.getInstance(); private final TracingClientInterceptor clientInterceptor = new TracingClientInterceptor(); + private final ServerInterceptor serverSpanPropagationInterceptor = + new TracingServerSpanPropagationInterceptor(); private final ServerTracerFactory serverTracerFactory = new ServerTracerFactory(); OpenTelemetryTracingModule(OpenTelemetry openTelemetry) { - this.otelTracer = checkNotNull(openTelemetry.getTracer(OTEL_TRACING_SCOPE_NAME), "otelTracer"); + this.otelTracer = checkNotNull(openTelemetry.getTracerProvider(), "tracerProvider") + .tracerBuilder(OpenTelemetryConstants.INSTRUMENTATION_SCOPE) + .setInstrumentationVersion(IMPLEMENTATION_VERSION) + .build(); this.contextPropagators = checkNotNull(openTelemetry.getPropagators(), "contextPropagators"); } + @VisibleForTesting + Tracer getTracer() { + return otelTracer; + } + /** * Creates a {@link CallAttemptsTracerFactory} for a new call. */ @@ -112,6 +129,10 @@ ClientInterceptor getClientInterceptor() { return clientInterceptor; } + ServerInterceptor getServerSpanPropagationInterceptor() { + return serverSpanPropagationInterceptor; + } + @VisibleForTesting final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory { volatile int callEnded; @@ -252,6 +273,11 @@ public void streamClosed(io.grpc.Status status) { endSpanWithStatus(span, status); } + @Override + public io.grpc.Context filterContext(io.grpc.Context context) { + return context.withValue(otelSpan, span); + } + @Override public void outboundMessageSent( int seqNo, long optionalWireSize, long optionalUncompressedSize) { @@ -293,6 +319,69 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata } } + @VisibleForTesting + final class TracingServerSpanPropagationInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + Span span = otelSpan.get(io.grpc.Context.current()); + if (span == null) { + logger.log(Level.FINE, "Server span not found. ServerTracerFactory for server " + + "tracing must be set."); + return next.startCall(call, headers); + } + Context serverCallContext = Context.current().with(span); + try (Scope scope = serverCallContext.makeCurrent()) { + return new ContextServerCallListener<>(next.startCall(call, headers), serverCallContext); + } + } + } + + private static class ContextServerCallListener extends + ForwardingServerCallListener.SimpleForwardingServerCallListener { + private final Context context; + + protected ContextServerCallListener(ServerCall.Listener delegate, Context context) { + super(delegate); + this.context = checkNotNull(context, "context"); + } + + @Override + public void onMessage(ReqT message) { + try (Scope scope = context.makeCurrent()) { + delegate().onMessage(message); + } + } + + @Override + public void onHalfClose() { + try (Scope scope = context.makeCurrent()) { + delegate().onHalfClose(); + } + } + + @Override + public void onCancel() { + try (Scope scope = context.makeCurrent()) { + delegate().onCancel(); + } + } + + @Override + public void onComplete() { + try (Scope scope = context.makeCurrent()) { + delegate().onComplete(); + } + } + + @Override + public void onReady() { + try (Scope scope = context.makeCurrent()) { + delegate().onReady(); + } + } + } + @VisibleForTesting final class TracingClientInterceptor implements ClientInterceptor { diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java index e4a0fa46e8b..1ae7b755a48 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcOpenTelemetryTest.java @@ -17,15 +17,26 @@ package io.grpc.opentelemetry; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import com.google.common.collect.ImmutableList; +import io.grpc.ClientInterceptor; +import io.grpc.ManagedChannelBuilder; import io.grpc.MetricSink; +import io.grpc.ServerBuilder; import io.grpc.internal.GrpcUtil; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.metrics.SdkMeterProvider; import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader; +import io.opentelemetry.sdk.trace.SdkTracerProvider; import java.util.Arrays; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -35,7 +46,19 @@ public class GrpcOpenTelemetryTest { private final InMemoryMetricReader inMemoryMetricReader = InMemoryMetricReader.create(); private final SdkMeterProvider meterProvider = SdkMeterProvider.builder().registerMetricReader(inMemoryMetricReader).build(); + private final SdkTracerProvider tracerProvider = SdkTracerProvider.builder().build(); private final OpenTelemetry noopOpenTelemetry = OpenTelemetry.noop(); + private boolean originalEnableOtelTracing; + + @Before + public void setup() { + originalEnableOtelTracing = GrpcOpenTelemetry.ENABLE_OTEL_TRACING; + } + + @After + public void tearDown() { + GrpcOpenTelemetry.ENABLE_OTEL_TRACING = originalEnableOtelTracing; + } @Test public void build() { @@ -56,6 +79,31 @@ public void build() { assertThat(openTelemetryModule.getOptionalLabels()).isEqualTo(ImmutableList.of("version")); } + @Test + public void buildTracer() { + OpenTelemetrySdk sdk = + OpenTelemetrySdk.builder().setTracerProvider(tracerProvider).build(); + + GrpcOpenTelemetry grpcOpenTelemetry = GrpcOpenTelemetry.newBuilder() + .enableTracing(true) + .sdk(sdk).build(); + + assertThat(grpcOpenTelemetry.getOpenTelemetryInstance()).isSameInstanceAs(sdk); + assertThat(grpcOpenTelemetry.getTracer()).isSameInstanceAs( + tracerProvider.tracerBuilder("grpc-java") + .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) + .build()); + ServerBuilder mockServerBuiler = mock(ServerBuilder.class); + grpcOpenTelemetry.configureServerBuilder(mockServerBuiler); + verify(mockServerBuiler, times(2)).addStreamTracerFactory(any()); + verify(mockServerBuiler).intercept(any()); + verifyNoMoreInteractions(mockServerBuiler); + + ManagedChannelBuilder mockChannelBuilder = mock(ManagedChannelBuilder.class); + grpcOpenTelemetry.configureChannelBuilder(mockChannelBuilder); + verify(mockChannelBuilder).intercept(any(ClientInterceptor.class)); + } + @Test public void builderDefaults() { GrpcOpenTelemetry module = GrpcOpenTelemetry.newBuilder().build(); @@ -73,6 +121,13 @@ public void builderDefaults() { assertThat(module.getEnableMetrics()).isEmpty(); assertThat(module.getOptionalLabels()).isEmpty(); assertThat(module.getSink()).isInstanceOf(MetricSink.class); + + assertThat(module.getTracer()).isSameInstanceAs(noopOpenTelemetry + .getTracerProvider() + .tracerBuilder("grpc-java") + .setInstrumentationVersion(GrpcUtil.IMPLEMENTATION_VERSION) + .build() + ); } @Test diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java index 68cba17e802..89e79d55d25 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java @@ -17,13 +17,14 @@ package io.grpc.opentelemetry; import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; -import static io.grpc.opentelemetry.OpenTelemetryTracingModule.OTEL_TRACING_SCOPE_NAME; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -38,14 +39,23 @@ import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; +import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.NoopServerCall; +import io.grpc.Server; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.opentelemetry.OpenTelemetryTracingModule.CallAttemptsTracerFactory; +import io.grpc.opentelemetry.internal.OpenTelemetryConstants; +import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.GrpcServerRule; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.trace.Span; @@ -54,6 +64,8 @@ import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.api.trace.TraceId; import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.api.trace.TracerBuilder; +import io.opentelemetry.api.trace.TracerProvider; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.opentelemetry.context.propagation.ContextPropagators; @@ -130,6 +142,8 @@ public String parse(InputStream stream) { public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); @Rule public final GrpcServerRule grpcServerRule = new GrpcServerRule().directExecutor(); + @Rule + public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); private Tracer tracerRule; @Mock private Tracer mockTracer; @@ -156,8 +170,15 @@ public String parse(InputStream stream) { @Before public void setUp() { - tracerRule = openTelemetryRule.getOpenTelemetry().getTracer(OTEL_TRACING_SCOPE_NAME); - when(mockOpenTelemetry.getTracer(OTEL_TRACING_SCOPE_NAME)).thenReturn(mockTracer); + tracerRule = openTelemetryRule.getOpenTelemetry().getTracer( + OpenTelemetryConstants.INSTRUMENTATION_SCOPE); + TracerProvider mockTracerProvider = mock(TracerProvider.class); + when(mockOpenTelemetry.getTracerProvider()).thenReturn(mockTracerProvider); + TracerBuilder mockTracerBuilder = mock(TracerBuilder.class); + when(mockTracerProvider.tracerBuilder(OpenTelemetryConstants.INSTRUMENTATION_SCOPE)) + .thenReturn(mockTracerBuilder); + when(mockTracerBuilder.setInstrumentationVersion(any())).thenReturn(mockTracerBuilder); + when(mockTracerBuilder.build()).thenReturn(mockTracer); when(mockOpenTelemetry.getPropagators()).thenReturn(ContextPropagators.create(mockPropagator)); when(mockSpanBuilder.startSpan()).thenReturn(mockAttemptSpan); when(mockSpanBuilder.setParent(any())).thenReturn(mockSpanBuilder); @@ -451,7 +472,8 @@ public ClientCall interceptCall( @Test public void clientStreamNeverCreatedStillRecordTracing() { - OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(mockOpenTelemetry); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); CallAttemptsTracerFactory callTracer = tracingModule.newClientCallTracer(mockClientSpan, method); @@ -570,6 +592,157 @@ public void grpcTraceBinPropagator() { Span.fromContext(contextArgumentCaptor.getValue()).getSpanContext()); } + @Test + public void testServerParentSpanPropagation() throws Exception { + final AtomicReference applicationSpan = new AtomicReference<>(); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + ServerServiceDefinition serviceDefinition = + ServerServiceDefinition.builder("package1.service2").addMethod( + method, new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + applicationSpan.set(Span.fromContext(Context.current())); + call.sendHeaders(new Metadata()); + call.sendMessage("Hello"); + call.close( + Status.PERMISSION_DENIED.withDescription("No you don't"), new Metadata()); + return mockServerCallListener; + } + }).build(); + + Server server = InProcessServerBuilder.forName("test-server-span") + .addService( + ServerInterceptors.intercept(serviceDefinition, + tracingModule.getServerSpanPropagationInterceptor())) + .addStreamTracerFactory(tracingModule.getServerTracerFactory()) + .directExecutor().build().start(); + grpcCleanupRule.register(server); + + ManagedChannel channel = InProcessChannelBuilder.forName("test-server-span") + .directExecutor().build(); + grpcCleanupRule.register(channel); + + Span parentSpan = tracerRule.spanBuilder("test-parent-span").startSpan(); + try (Scope scope = Context.current().with(parentSpan).makeCurrent()) { + Channel interceptedChannel = + ClientInterceptors.intercept( + channel, tracingModule.getClientInterceptor()); + ClientCall call = interceptedChannel.newCall(method, CALL_OPTIONS); + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + parentSpan.end(); + } + + verify(mockClientCallListener).onClose(statusCaptor.capture(), any(Metadata.class)); + Status rpcStatus = statusCaptor.getValue(); + assertEquals(rpcStatus.getCode(), Status.Code.PERMISSION_DENIED); + assertEquals(rpcStatus.getDescription(), "No you don't"); + assertEquals(applicationSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + + List spans = openTelemetryRule.getSpans(); + assertEquals(spans.size(), 4); + SpanData clientSpan = spans.get(2); + SpanData attemptSpan = spans.get(1); + + assertEquals(clientSpan.getName(), "Sent.package1.service2.method3"); + assertTrue(clientSpan.hasEnded()); + assertEquals(clientSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(clientSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + + assertEquals(attemptSpan.getName(), "Attempt.package1.service2.method3"); + assertTrue(attemptSpan.hasEnded()); + assertEquals(attemptSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(attemptSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + + SpanData serverSpan = spans.get(0); + assertEquals(serverSpan.getName(), "Recv.package1.service2.method3"); + assertTrue(serverSpan.hasEnded()); + assertEquals(serverSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(serverSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + } + + @Test + public void serverSpanPropagationInterceptor() throws Exception { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + Server server = InProcessServerBuilder.forName("test-span-propagation-interceptor") + .directExecutor().build().start(); + grpcCleanupRule.register(server); + final AtomicReference callbackSpan = new AtomicReference<>(); + ServerCall.Listener getContextListener = new ServerCall.Listener() { + @Override + public void onMessage(Integer message) { + callbackSpan.set(Span.fromContext(Context.current())); + } + + @Override + public void onHalfClose() { + callbackSpan.set(Span.fromContext(Context.current())); + } + + @Override + public void onCancel() { + callbackSpan.set(Span.fromContext(Context.current())); + } + + @Override + public void onComplete() { + callbackSpan.set(Span.fromContext(Context.current())); + } + }; + ServerInterceptor interceptor = tracingModule.getServerSpanPropagationInterceptor(); + @SuppressWarnings("unchecked") + ServerCallHandler handler = mock(ServerCallHandler.class); + when(handler.startCall(any(), any())).thenReturn(getContextListener); + ServerCall call = new NoopServerCall<>(); + Metadata metadata = new Metadata(); + ServerCall.Listener listener = interceptor.interceptCall(call, metadata, handler); + verify(handler).startCall(same(call), same(metadata)); + listener.onMessage(1); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onReady(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onCancel(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onHalfClose(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + listener.onComplete(); + assertEquals(callbackSpan.get(), Span.getInvalid()); + + Span parentSpan = tracerRule.spanBuilder("parent-span").startSpan(); + io.grpc.Context context = io.grpc.Context.current().withValue( + tracingModule.otelSpan, parentSpan); + io.grpc.Context previous = context.attach(); + try { + listener = interceptor.interceptCall(call, metadata, handler); + verify(handler, times(2)).startCall(same(call), same(metadata)); + listener.onMessage(1); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onReady(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onCancel(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onHalfClose(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + listener.onComplete(); + assertEquals(callbackSpan.get().getSpanContext().getTraceId(), + parentSpan.getSpanContext().getTraceId()); + } finally { + context.detach(previous); + } + } + @Test public void generateTraceSpanName() { assertEquals(