From 6f06d3d80187080adb0a124b3c014cb57c33c367 Mon Sep 17 00:00:00 2001
From: Cagri Yonca
Date: Fri, 2 May 2025 14:13:27 +0200
Subject: [PATCH] fix: Kafka context propagation

Signed-off-by: Cagri Yonca
Co-authored-by: Paulo Vital
Signed-off-by: Cagri Yonca
---
 .../kafka/confluent_kafka_python.py         | 104 +++++----
 .../instrumentation/kafka/kafka_python.py   | 121 ++++++----
 src/instana/propagators/base_propagator.py  | 207 ++++++++++++------
 src/instana/propagators/kafka_propagator.py |  88 +++++++-
 src/instana/tracer.py                       |   4 +-
 tests/clients/kafka/test_confluent_kafka.py | 141 +++++++++++-
 tests/clients/kafka/test_kafka_python.py    | 169 +++++++++++++-
 7 files changed, 668 insertions(+), 166 deletions(-)

diff --git a/src/instana/instrumentation/kafka/confluent_kafka_python.py b/src/instana/instrumentation/kafka/confluent_kafka_python.py
index 1eef59c9..9c5d1194 100644
--- a/src/instana/instrumentation/kafka/confluent_kafka_python.py
+++ b/src/instana/instrumentation/kafka/confluent_kafka_python.py
@@ -66,7 +66,13 @@ def trace_kafka_produce(
         span.set_attribute("kafka.access", "produce")
 
         # context propagation
-        headers = args[6] if len(args) > 6 else kwargs.get("headers", {})
+        #
+        # As stated in the official documentation at
+        # https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#pythonclient-producer,
+        # headers can be either a list of (key, value) pairs or a
+        # dictionary. To maintain compatibility with the headers of the
+        # kafka-python library, we use a list of tuples.
+        headers = args[6] if len(args) > 6 else kwargs.get("headers", [])
         tracer.inject(
             span.context,
             Format.KAFKA_HEADERS,
@@ -75,44 +81,63 @@ def trace_kafka_produce(
         )
 
         try:
+            kwargs["headers"] = headers
             res = wrapped(*args, **kwargs)
         except Exception as exc:
             span.record_exception(exc)
         else:
             return res
 
-    def trace_kafka_consume(
-        wrapped: Callable[..., InstanaConfluentKafkaConsumer.consume],
-        instance: InstanaConfluentKafkaConsumer,
-        args: Tuple[int, str, Tuple[Any, ...]],
-        kwargs: Dict[str, Any],
-    ) -> List[confluent_kafka.Message]:
-        if tracing_is_off():
-            return wrapped(*args, **kwargs)
-
+    def create_span(
+        span_type: str,
+        topic: Optional[str] = "",
+        headers: Optional[List[Tuple[str, bytes]]] = None,
+        exception: Optional[Exception] = None,
+    ) -> None:
         tracer, parent_span, _ = get_tracer_tuple()
-
         parent_context = (
             parent_span.get_span_context()
             if parent_span
             else tracer.extract(
-                Format.KAFKA_HEADERS, {}, disable_w3c_trace_context=True
+                Format.KAFKA_HEADERS,
+                headers or [],
+                disable_w3c_trace_context=True,
             )
         )
-
         with tracer.start_as_current_span(
             "kafka-consumer", span_context=parent_context, kind=SpanKind.CONSUMER
         ) as span:
-            span.set_attribute("kafka.access", "consume")
+            if topic:
+                span.set_attribute("kafka.service", topic)
+            span.set_attribute("kafka.access", span_type)
 
-            try:
-                res = wrapped(*args, **kwargs)
-                if isinstance(res, list) and len(res) > 0:
-                    span.set_attribute("kafka.service", res[0].topic())
-            except Exception as exc:
-                span.record_exception(exc)
+            if exception:
+                span.record_exception(exception)
+
+    def trace_kafka_consume(
+        wrapped: Callable[..., InstanaConfluentKafkaConsumer.consume],
+        instance: InstanaConfluentKafkaConsumer,
+        args: Tuple[int, str, Tuple[Any, ...]],
+        kwargs: Dict[str, Any],
+    ) -> List[confluent_kafka.Message]:
+        if tracing_is_off():
+            return wrapped(*args, **kwargs)
+
+        res = None
+        exception = None
+
+        try:
+            res = wrapped(*args, **kwargs)
+        except Exception as exc:
+            exception = exc
+        finally:
+            if res:
+                for message in res:
+                    create_span("consume", message.topic(), message.headers())
             else:
-                return res
+                create_span("consume", exception=exception)
+
+        return res
 
     def trace_kafka_poll(
         wrapped: Callable[..., InstanaConfluentKafkaConsumer.poll],
@@ -123,29 +148,24 @@ def trace_kafka_poll(
         if tracing_is_off():
             return wrapped(*args, **kwargs)
 
-        tracer, parent_span, _ = get_tracer_tuple()
-
-        parent_context = (
-            parent_span.get_span_context()
-            if parent_span
-            else tracer.extract(
-                Format.KAFKA_HEADERS, {}, disable_w3c_trace_context=True
-            )
-        )
-
-        with tracer.start_as_current_span(
-            "kafka-consumer", span_context=parent_context, kind=SpanKind.CONSUMER
-        ) as span:
-            span.set_attribute("kafka.access", "poll")
+        res = None
+        exception = None
 
-            try:
-                res = wrapped(*args, **kwargs)
-                if res:
-                    span.set_attribute("kafka.service", res.topic())
-            except Exception as exc:
-                span.record_exception(exc)
+        try:
+            res = wrapped(*args, **kwargs)
+        except Exception as exc:
+            exception = exc
+        finally:
+            if res:
+                create_span("poll", res.topic(), res.headers())
             else:
-                return res
+                create_span(
+                    "poll",
+                    next(iter(instance.list_topics().topics)),
+                    exception=exception,
+                )
+
+        return res
 
     # Apply the monkey patch
     confluent_kafka.Producer = InstanaConfluentKafkaProducer
diff --git a/src/instana/instrumentation/kafka/kafka_python.py b/src/instana/instrumentation/kafka/kafka_python.py
index c4979cc1..ad26ec0e 100644
--- a/src/instana/instrumentation/kafka/kafka_python.py
+++ b/src/instana/instrumentation/kafka/kafka_python.py
@@ -1,7 +1,8 @@
 # (c) Copyright IBM Corp. 2025
 
 try:
-    from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple
+    import inspect
+    from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
     import kafka  # noqa: F401
     import wrapt
@@ -37,53 +38,77 @@ def trace_kafka_send(
             span.set_attribute("kafka.access", "send")
 
             # context propagation
+            headers = kwargs.get("headers", [])
             tracer.inject(
                 span.context,
                 Format.KAFKA_HEADERS,
-                kwargs.get("headers", {}),
+                headers,
                 disable_w3c_trace_context=True,
             )
 
             try:
+                kwargs["headers"] = headers
                 res = wrapped(*args, **kwargs)
             except Exception as exc:
                 span.record_exception(exc)
             else:
                 return res
 
-    @wrapt.patch_function_wrapper("kafka", "KafkaConsumer.__next__")
-    def trace_kafka_consume(
-        wrapped: Callable[..., "kafka.KafkaConsumer.__next__"],
-        instance: "kafka.KafkaConsumer",
-        args: Tuple[int, str, Tuple[Any, ...]],
-        kwargs: Dict[str, Any],
-    ) -> "FutureRecordMetadata":
-        if tracing_is_off():
-            return wrapped(*args, **kwargs)
-
+    def create_span(
+        span_type: str,
+        topic: Optional[str],
+        headers: Optional[List[Tuple[str, bytes]]] = None,
+        exception: Optional[Exception] = None,
+    ) -> None:
         tracer, parent_span, _ = get_tracer_tuple()
-
         parent_context = (
             parent_span.get_span_context()
             if parent_span
             else tracer.extract(
-                Format.KAFKA_HEADERS, {}, disable_w3c_trace_context=True
+                Format.KAFKA_HEADERS,
+                headers or [],
+                disable_w3c_trace_context=True,
             )
         )
-
         with tracer.start_as_current_span(
             "kafka-consumer", span_context=parent_context, kind=SpanKind.CONSUMER
         ) as span:
-            topic = list(instance.subscription())[0]
-            span.set_attribute("kafka.service", topic)
-            span.set_attribute("kafka.access", "consume")
+            if topic:
+                span.set_attribute("kafka.service", topic)
+            span.set_attribute("kafka.access", span_type)
+            if exception:
+                span.record_exception(exception)
 
-            try:
-                res = wrapped(*args, **kwargs)
-            except Exception as exc:
-                span.record_exception(exc)
+    @wrapt.patch_function_wrapper("kafka", "KafkaConsumer.__next__")
+    def trace_kafka_consume(
+        wrapped:
Callable[..., "kafka.KafkaConsumer.__next__"], + instance: "kafka.KafkaConsumer", + args: Tuple[int, str, Tuple[Any, ...]], + kwargs: Dict[str, Any], + ) -> "FutureRecordMetadata": + if tracing_is_off(): + return wrapped(*args, **kwargs) + + exception = None + res = None + + try: + res = wrapped(*args, **kwargs) + except Exception as exc: + exception = exc + finally: + if res: + create_span( + "consume", + res.topic if res else list(instance.subscription())[0], + res.headers, + ) else: - return res + create_span( + "consume", list(instance.subscription())[0], exception=exception + ) + + return res @wrapt.patch_function_wrapper("kafka", "KafkaConsumer.poll") def trace_kafka_poll( @@ -91,38 +116,40 @@ def trace_kafka_poll( instance: "kafka.KafkaConsumer", args: Tuple[int, str, Tuple[Any, ...]], kwargs: Dict[str, Any], - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: if tracing_is_off(): return wrapped(*args, **kwargs) - tracer, parent_span, _ = get_tracer_tuple() - # The KafkaConsumer.consume() from the kafka-python-ng call the # KafkaConsumer.poll() internally, so we do not consider it here. - if parent_span and parent_span.name == "kafka-consumer": + if any( + frame.function == "trace_kafka_consume" + for frame in inspect.getouterframes(inspect.currentframe(), 2) + ): return wrapped(*args, **kwargs) - parent_context = ( - parent_span.get_span_context() - if parent_span - else tracer.extract( - Format.KAFKA_HEADERS, {}, disable_w3c_trace_context=True - ) - ) - - with tracer.start_as_current_span( - "kafka-consumer", span_context=parent_context, kind=SpanKind.CONSUMER - ) as span: - topic = list(instance.subscription())[0] - span.set_attribute("kafka.service", topic) - span.set_attribute("kafka.access", "poll") - - try: - res = wrapped(*args, **kwargs) - except Exception as exc: - span.record_exception(exc) + exception = None + res = None + + try: + res = wrapped(*args, **kwargs) + except Exception as exc: + exception = exc + finally: + if res: + for partition, consumer_records in res.items(): + for message in consumer_records: + create_span( + "poll", + partition.topic, + message.headers if hasattr(message, "headers") else [], + ) else: - return res + create_span( + "poll", list(instance.subscription())[0], exception=exception + ) + + return res logger.debug("Instrumenting Kafka (kafka-python)") except ImportError: diff --git a/src/instana/propagators/base_propagator.py b/src/instana/propagators/base_propagator.py index 7778b8f7..a981c8a1 100644 --- a/src/instana/propagators/base_propagator.py +++ b/src/instana/propagators/base_propagator.py @@ -8,7 +8,14 @@ from instana.log import logger from instana.span_context import SpanContext -from instana.util.ids import header_to_id, header_to_long_id, hex_id, internal_id, internal_id_limited, hex_id_limited +from instana.util.ids import ( + header_to_id, + header_to_long_id, + hex_id, + internal_id, + internal_id_limited, + hex_id_limited, +) from instana.w3c_trace_context.traceparent import Traceparent from instana.w3c_trace_context.tracestate import Tracestate @@ -32,45 +39,50 @@ class BasePropagator(object): - HEADER_KEY_T = 'X-INSTANA-T' - HEADER_KEY_S = 'X-INSTANA-S' - HEADER_KEY_L = 'X-INSTANA-L' - HEADER_KEY_SYNTHETIC = 'X-INSTANA-SYNTHETIC' + HEADER_KEY_T = "X-INSTANA-T" + HEADER_KEY_S = "X-INSTANA-S" + HEADER_KEY_L = "X-INSTANA-L" + HEADER_KEY_SYNTHETIC = "X-INSTANA-SYNTHETIC" HEADER_KEY_TRACEPARENT = "traceparent" HEADER_KEY_TRACESTATE = "tracestate" HEADER_KEY_SERVER_TIMING = "Server-Timing" - LC_HEADER_KEY_T = 'x-instana-t' - 
LC_HEADER_KEY_S = 'x-instana-s' - LC_HEADER_KEY_L = 'x-instana-l' - LC_HEADER_KEY_SYNTHETIC = 'x-instana-synthetic' + LC_HEADER_KEY_T = "x-instana-t" + LC_HEADER_KEY_S = "x-instana-s" + LC_HEADER_KEY_L = "x-instana-l" + LC_HEADER_KEY_SYNTHETIC = "x-instana-synthetic" LC_HEADER_KEY_SERVER_TIMING = "server-timing" - ALT_LC_HEADER_KEY_T = 'http_x_instana_t' - ALT_LC_HEADER_KEY_S = 'http_x_instana_s' - ALT_LC_HEADER_KEY_L = 'http_x_instana_l' - ALT_LC_HEADER_KEY_SYNTHETIC = 'http_x_instana_synthetic' + ALT_LC_HEADER_KEY_T = "http_x_instana_t" + ALT_LC_HEADER_KEY_S = "http_x_instana_s" + ALT_LC_HEADER_KEY_L = "http_x_instana_l" + ALT_LC_HEADER_KEY_SYNTHETIC = "http_x_instana_synthetic" ALT_HEADER_KEY_TRACEPARENT = "http_traceparent" ALT_HEADER_KEY_TRACESTATE = "http_tracestate" ALT_LC_HEADER_KEY_SERVER_TIMING = "http_server_timing" # ByteArray variations - B_HEADER_KEY_T = b'x-instana-t' - B_HEADER_KEY_S = b'x-instana-s' - B_HEADER_KEY_L = b'x-instana-l' - B_HEADER_KEY_SYNTHETIC = b'x-instana-synthetic' - B_HEADER_KEY_TRACEPARENT = b'traceparent' - B_HEADER_KEY_TRACESTATE = b'tracestate' + B_HEADER_KEY_T = b"x-instana-t" + B_HEADER_KEY_S = b"x-instana-s" + B_HEADER_KEY_L = b"x-instana-l" + B_HEADER_KEY_SYNTHETIC = b"x-instana-synthetic" + B_HEADER_KEY_TRACEPARENT = b"traceparent" + B_HEADER_KEY_TRACESTATE = b"tracestate" B_HEADER_KEY_SERVER_TIMING = b"server-timing" - B_ALT_LC_HEADER_KEY_T = b'http_x_instana_t' - B_ALT_LC_HEADER_KEY_S = b'http_x_instana_s' - B_ALT_LC_HEADER_KEY_L = b'http_x_instana_l' - B_ALT_LC_HEADER_KEY_SYNTHETIC = b'http_x_instana_synthetic' - B_ALT_HEADER_KEY_TRACEPARENT = b'http_traceparent' - B_ALT_HEADER_KEY_TRACESTATE = b'http_tracestate' + B_ALT_LC_HEADER_KEY_T = b"http_x_instana_t" + B_ALT_LC_HEADER_KEY_S = b"http_x_instana_s" + B_ALT_LC_HEADER_KEY_L = b"http_x_instana_l" + B_ALT_LC_HEADER_KEY_SYNTHETIC = b"http_x_instana_synthetic" + B_ALT_HEADER_KEY_TRACEPARENT = b"http_traceparent" + B_ALT_HEADER_KEY_TRACESTATE = b"http_tracestate" B_ALT_LC_HEADER_KEY_SERVER_TIMING = b"http_server_timing" + # Kafka Modern Headers + KAFKA_HEADER_KEY_T = "x_instana_t" + KAFKA_HEADER_KEY_S = "x_instana_s" + KAFKA_HEADER_KEY_L_S = "x_instana_l_s" + def __init__(self): self._tp = Traceparent() self._ts = Tracestate() @@ -94,7 +106,9 @@ def extract_headers_dict(carrier: CarrierT) -> Optional[Dict]: else: dc = dict(carrier) except Exception: - logger.debug(f"base_propagator extract_headers_dict: Couldn't convert - {carrier}") + logger.debug( + f"base_propagator extract_headers_dict: Couldn't convert - {carrier}" + ) return dc @@ -113,7 +127,7 @@ def _get_ctx_level(level: str) -> int: return ctx_level @staticmethod - def _get_correlation_properties(level:str): + def _get_correlation_properties(level: str): """ Get the correlation values if they are present. 
@@ -122,12 +136,16 @@ def _get_correlation_properties(level:str): """ correlation_type, correlation_id = [None] * 2 try: - correlation_type = level.split(",")[1].split("correlationType=")[1].split(";")[0] + correlation_type = ( + level.split(",")[1].split("correlationType=")[1].split(";")[0] + ) if "correlationId" in level: - correlation_id = level.split(",")[1].split("correlationId=")[1].split(";")[0] + correlation_id = ( + level.split(",")[1].split("correlationId=")[1].split(";")[0] + ) except Exception: logger.debug("extract instana correlation type/id error:", exc_info=True) - + return correlation_type, correlation_id def _get_participating_trace_context(self, span_context: SpanContext): @@ -143,7 +161,9 @@ def _get_participating_trace_context(self, span_context: SpanContext): tp_trace_id = span_context.trace_id traceparent = span_context.traceparent tracestate = span_context.tracestate - traceparent = self._tp.update_traceparent(traceparent, tp_trace_id, span_context.span_id, span_context.level) + traceparent = self._tp.update_traceparent( + traceparent, tp_trace_id, span_context.span_id, span_context.level + ) # In suppression mode do not update the tracestate and # do not add the 'in=' key-value pair to the incoming tracestate @@ -151,19 +171,23 @@ def _get_participating_trace_context(self, span_context: SpanContext): if span_context.suppression: return traceparent, tracestate - tracestate = self._ts.update_tracestate(tracestate, hex_id_limited(span_context.trace_id), hex_id(span_context.span_id)) + tracestate = self._ts.update_tracestate( + tracestate, + hex_id_limited(span_context.trace_id), + hex_id(span_context.span_id), + ) return traceparent, tracestate def __determine_span_context( - self, - trace_id: int, - span_id: int, - level: str, - synthetic: bool, - traceparent, - tracestate, - disable_w3c_trace_context: bool, - ) -> SpanContext: + self, + trace_id: int, + span_id: int, + level: str, + synthetic: bool, + traceparent, + tracestate, + disable_w3c_trace_context: bool, + ) -> SpanContext: """ This method determines the span context depending on a set of conditions being met Detailed description of the conditions can be found in the instana internal technical-documentation, @@ -179,7 +203,9 @@ def __determine_span_context( :return: SpanContext """ correlation = False - disable_traceparent = os.environ.get("INSTANA_DISABLE_W3C_TRACE_CORRELATION", "") + disable_traceparent = os.environ.get( + "INSTANA_DISABLE_W3C_TRACE_CORRELATION", "" + ) instana_ancestor = None if level and "correlationType" in level: @@ -189,7 +215,7 @@ def __determine_span_context( ( ctx_level, ctx_synthetic, - ctx_trace_parent, + ctx_trace_parent, ctx_instana_ancestor, ctx_long_trace_id, ctx_correlation_type, @@ -214,8 +240,15 @@ def __determine_span_context( if len(hex_trace_id) > 16: ctx_long_trace_id = hex_trace_id - elif not disable_w3c_trace_context and traceparent and not trace_id and not span_id: - _, tp_trace_id, tp_parent_id, _ = self._tp.get_traceparent_fields(traceparent) + elif ( + not disable_w3c_trace_context + and traceparent + and not trace_id + and not span_id + ): + _, tp_trace_id, tp_parent_id, _ = self._tp.get_traceparent_fields( + traceparent + ) if tracestate and "in=" in tracestate: instana_ancestor = self._ts.get_instana_ancestor(tracestate) @@ -237,7 +270,9 @@ def __determine_span_context( ctx_synthetic = synthetic if correlation: - ctx_correlation_type, ctx_correlation_id = self._get_correlation_properties(level) + ctx_correlation_type, ctx_correlation_id = 
self._get_correlation_properties( + level + ) if traceparent: ctx_traceparent = traceparent @@ -246,7 +281,7 @@ def __determine_span_context( if ctx_trace_id: if isinstance(ctx_trace_id, int): # check if ctx_trace_id is a valid internal trace id - if (ctx_trace_id <= 2**64 - 1): + if ctx_trace_id <= 2**64 - 1: trace_id = ctx_trace_id else: trace_id = internal_id(hex_id_limited(ctx_trace_id)) @@ -257,7 +292,9 @@ def __determine_span_context( return SpanContext( trace_id=trace_id, - span_id=internal_id_limited(ctx_span_id) if ctx_span_id else INVALID_SPAN_ID, + span_id=internal_id_limited(ctx_span_id) + if ctx_span_id + else INVALID_SPAN_ID, is_remote=False, level=ctx_level, synthetic=ctx_synthetic, @@ -270,8 +307,9 @@ def __determine_span_context( tracestate=ctx_tracestate, ) - - def extract_instana_headers(self, dc: Dict[str, Any]) -> Tuple[Optional[int], Optional[int], Optional[str], Optional[bool]]: + def extract_instana_headers( + self, dc: Dict[str, Any] + ) -> Tuple[Optional[int], Optional[int], Optional[str], Optional[bool]]: """ Search carrier for the *HEADER* keys and return the tracing key-values. @@ -282,25 +320,44 @@ def extract_instana_headers(self, dc: Dict[str, Any]) -> Tuple[Optional[int], Op # Headers can exist in the standard X-Instana-T/S format or the alternate HTTP_X_INSTANA_T/S style try: - trace_id = dc.get(self.LC_HEADER_KEY_T) or dc.get(self.ALT_LC_HEADER_KEY_T) or dc.get( - self.B_HEADER_KEY_T) or dc.get(self.B_ALT_LC_HEADER_KEY_T) + trace_id = ( + dc.get(self.LC_HEADER_KEY_T) + or dc.get(self.ALT_LC_HEADER_KEY_T) + or dc.get(self.B_HEADER_KEY_T) + or dc.get(self.B_ALT_LC_HEADER_KEY_T) + or dc.get(self.KAFKA_HEADER_KEY_T.lower()) + ) if trace_id: trace_id = header_to_long_id(trace_id) - span_id = dc.get(self.LC_HEADER_KEY_S) or dc.get(self.ALT_LC_HEADER_KEY_S) or dc.get( - self.B_HEADER_KEY_S) or dc.get(self.B_ALT_LC_HEADER_KEY_S) + span_id = ( + dc.get(self.LC_HEADER_KEY_S) + or dc.get(self.ALT_LC_HEADER_KEY_S) + or dc.get(self.B_HEADER_KEY_S) + or dc.get(self.B_ALT_LC_HEADER_KEY_S) + or dc.get(self.KAFKA_HEADER_KEY_S.lower()) + ) if span_id: span_id = header_to_id(span_id) - level = dc.get(self.LC_HEADER_KEY_L) or dc.get(self.ALT_LC_HEADER_KEY_L) or dc.get( - self.B_HEADER_KEY_L) or dc.get(self.B_ALT_LC_HEADER_KEY_L) + level = ( + dc.get(self.LC_HEADER_KEY_L) + or dc.get(self.ALT_LC_HEADER_KEY_L) + or dc.get(self.B_HEADER_KEY_L) + or dc.get(self.B_ALT_LC_HEADER_KEY_L) + or dc.get(self.KAFKA_HEADER_KEY_L_S.lower()) + ) if level and isinstance(level, bytes): level = level.decode("utf-8") - synthetic = dc.get(self.LC_HEADER_KEY_SYNTHETIC) or dc.get(self.ALT_LC_HEADER_KEY_SYNTHETIC) or dc.get( - self.B_HEADER_KEY_SYNTHETIC) or dc.get(self.B_ALT_LC_HEADER_KEY_SYNTHETIC) + synthetic = ( + dc.get(self.LC_HEADER_KEY_SYNTHETIC) + or dc.get(self.ALT_LC_HEADER_KEY_SYNTHETIC) + or dc.get(self.B_HEADER_KEY_SYNTHETIC) + or dc.get(self.B_ALT_LC_HEADER_KEY_SYNTHETIC) + ) if synthetic: - synthetic = synthetic in ['1', b'1'] + synthetic = synthetic in ["1", b"1"] except Exception: logger.debug("extract error:", exc_info=True) @@ -317,13 +374,21 @@ def __extract_w3c_trace_context_headers(self, dc): traceparent, tracestate = [None] * 2 try: - traceparent = dc.get(self.HEADER_KEY_TRACEPARENT) or dc.get(self.ALT_HEADER_KEY_TRACEPARENT) or dc.get( - self.B_HEADER_KEY_TRACEPARENT) or dc.get(self.B_ALT_HEADER_KEY_TRACEPARENT) + traceparent = ( + dc.get(self.HEADER_KEY_TRACEPARENT) + or dc.get(self.ALT_HEADER_KEY_TRACEPARENT) + or dc.get(self.B_HEADER_KEY_TRACEPARENT) + or 
dc.get(self.B_ALT_HEADER_KEY_TRACEPARENT)
+        )
 
         if traceparent and isinstance(traceparent, bytes):
             traceparent = traceparent.decode("utf-8")
 
-        tracestate = dc.get(self.HEADER_KEY_TRACESTATE) or dc.get(self.ALT_HEADER_KEY_TRACESTATE) or dc.get(
-            self.B_HEADER_KEY_TRACESTATE) or dc.get(self.B_ALT_HEADER_KEY_TRACESTATE)
+        tracestate = (
+            dc.get(self.HEADER_KEY_TRACESTATE)
+            or dc.get(self.ALT_HEADER_KEY_TRACESTATE)
+            or dc.get(self.B_HEADER_KEY_TRACESTATE)
+            or dc.get(self.B_ALT_HEADER_KEY_TRACESTATE)
+        )
 
         if tracestate and isinstance(tracestate, bytes):
             tracestate = tracestate.decode("utf-8")
@@ -332,10 +397,12 @@
 
         return traceparent, tracestate
 
-    def extract(self, carrier: CarrierT, disable_w3c_trace_context: bool = False) -> Optional[SpanContext]:
+    def extract(
+        self, carrier: CarrierT, disable_w3c_trace_context: bool = False
+    ) -> Optional[SpanContext]:
         """
-        This method overrides one of the Base classes as with the introduction
-        of W3C trace context for the HTTP requests more extracting steps and
-        logic was required.
+        This method overrides the base class implementation because the
+        introduction of W3C trace context for HTTP requests required
+        additional extraction steps and logic.
 
         :param disable_w3c_trace_context:
@@ -349,9 +416,13 @@
             return None
 
         headers = {k.lower(): v for k, v in headers.items()}
-        trace_id, span_id, level, synthetic = self.extract_instana_headers(dc=headers)
+        trace_id, span_id, level, synthetic = self.extract_instana_headers(
+            dc=headers
+        )
         if not disable_w3c_trace_context:
-            traceparent, tracestate = self.__extract_w3c_trace_context_headers(dc=headers)
+            traceparent, tracestate = self.__extract_w3c_trace_context_headers(
+                dc=headers
+            )
             if traceparent:
                 traceparent = self._tp.validate(traceparent)
diff --git a/src/instana/propagators/kafka_propagator.py b/src/instana/propagators/kafka_propagator.py
index 6b22fb6e..ad182b13 100644
--- a/src/instana/propagators/kafka_propagator.py
+++ b/src/instana/propagators/kafka_propagator.py
@@ -1,5 +1,5 @@
 # (c) Copyright IBM Corp. 2025
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from opentelemetry.trace.span import format_span_id
 
@@ -22,15 +22,81 @@ class KafkaPropagator(BasePropagator):
     def __init__(self) -> None:
         super(KafkaPropagator, self).__init__()
 
+    # Assisted by watsonx Code Assistant
+    def extract_carrier_headers(self, carrier: CarrierT) -> Dict[str, Any]:
+        """
+        Extracts headers from a carrier object.
+
+        Args:
+            carrier (CarrierT): The carrier object to extract headers from.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the extracted headers.
+        """
+        dc = {}
+        try:
+            if isinstance(carrier, list):
+                for header in carrier:
+                    if isinstance(header, tuple):
+                        dc[header[0]] = header[1]
+                    elif isinstance(header, dict):
+                        for k, v in header.items():
+                            dc[k] = v
+            else:
+                dc = self.extract_headers_dict(carrier)
+        except Exception:
+            logger.debug(
+                f"kafka_propagator extract_carrier_headers: Couldn't convert - {carrier}"
+            )
+
+        return dc
+
+    def extract(
+        self, carrier: CarrierT, disable_w3c_trace_context: bool = False
+    ) -> Optional["SpanContext"]:
+        """
+        This method overrides the base class implementation because the
+        introduction of W3C trace context for Kafka requests required
+        additional extraction steps and logic.
+
+        Args:
+            carrier (CarrierT): The carrier object to extract headers from.
+ disable_w3c_trace_context (bool): A flag to disable the W3C trace context. + + Returns: + Optional["SpanContext"]: The extracted span context or None. + """ + try: + headers = self.extract_carrier_headers(carrier=carrier) + return super(KafkaPropagator, self).extract( + carrier=headers, + disable_w3c_trace_context=disable_w3c_trace_context, + ) + + except Exception: + logger.debug("kafka_propagator extract error:", exc_info=True) + + # Assisted by watsonx Code Assistant def inject( self, span_context: "SpanContext", carrier: CarrierT, disable_w3c_trace_context: bool = True, ) -> None: + """ + Inject the trace context into a carrier. + + Args: + span_context (SpanContext): The SpanContext object containing trace information. + carrier (CarrierT): The carrier object to store the trace context. + disable_w3c_trace_context (bool, optional): A boolean flag to disable W3C trace context. Defaults to True. + + Returns: + None + """ trace_id = span_context.trace_id span_id = span_context.span_id - dictionary_carrier = self.extract_headers_dict(carrier) + dictionary_carrier = self.extract_carrier_headers(carrier) if dictionary_carrier: # Suppression `level` made in the child context or in the parent context @@ -53,9 +119,21 @@ def inject_key_value(carrier, key, value): ) try: - inject_key_value(carrier, "X_INSTANA_L_S", serializable_level) - inject_key_value(carrier, "X_INSTANA_T", hex_id_limited(trace_id)) - inject_key_value(carrier, "X_INSTANA_S", format_span_id(span_id)) + inject_key_value( + carrier, + self.KAFKA_HEADER_KEY_L_S, + serializable_level.encode("utf-8"), + ) + inject_key_value( + carrier, + self.KAFKA_HEADER_KEY_T, + hex_id_limited(trace_id).encode("utf-8"), + ) + inject_key_value( + carrier, + self.KAFKA_HEADER_KEY_S, + format_span_id(span_id).encode("utf-8"), + ) except Exception: logger.debug("KafkaPropagator - inject error:", exc_info=True) diff --git a/src/instana/tracer.py b/src/instana/tracer.py index aed28d17..83ea05ec 100644 --- a/src/instana/tracer.py +++ b/src/instana/tracer.py @@ -246,7 +246,7 @@ def inject( self, span_context: SpanContext, format: Union[ - Format.BINARY, Format.HTTP_HEADERS, Format.TEXT_MAP, Format.KAFKA_HEADERS + Format.BINARY, Format.HTTP_HEADERS, Format.TEXT_MAP, Format.KAFKA_HEADERS # type: ignore ], carrier: "CarrierT", disable_w3c_trace_context: bool = False, @@ -261,7 +261,7 @@ def inject( def extract( self, format: Union[ - Format.BINARY, Format.HTTP_HEADERS, Format.TEXT_MAP, Format.KAFKA_HEADERS + Format.BINARY, Format.HTTP_HEADERS, Format.TEXT_MAP, Format.KAFKA_HEADERS # type: ignore ], carrier: "CarrierT", disable_w3c_trace_context: bool = False, diff --git a/tests/clients/kafka/test_confluent_kafka.py b/tests/clients/kafka/test_confluent_kafka.py index 827a8c95..0b995e81 100644 --- a/tests/clients/kafka/test_confluent_kafka.py +++ b/tests/clients/kafka/test_confluent_kafka.py @@ -12,7 +12,7 @@ from opentelemetry.trace import SpanKind from instana.singletons import agent, tracer -from tests.helpers import testenv +from tests.helpers import get_first_span_by_filter, testenv class TestConfluentKafka: @@ -186,3 +186,142 @@ def test_trace_confluent_kafka_error(self) -> None: kafka_span.data["kafka"]["error"] == "num_messages must be between 0 and 1000000 (1M)" ) + + def test_confluent_kafka_consumer_root_exit(self) -> None: + agent.options.allow_exit_as_root = True + + self.producer.produce(testenv["kafka_topic"] + "_1", b"raw_bytes") + self.producer.produce(testenv["kafka_topic"] + "_2", b"raw_bytes") + self.producer.flush(timeout=10) + + # 
Consume the events + consumer_config = self.kafka_config.copy() + consumer_config["group.id"] = "my-group" + consumer_config["auto.offset.reset"] = "earliest" + + consumer = Consumer(consumer_config) + consumer.subscribe( + [ + testenv["kafka_topic"] + "_1", + testenv["kafka_topic"] + "_2", + ] + ) + + consumer.consume(num_messages=2, timeout=60) # noqa: F841 + + consumer.close() + + spans = self.recorder.queued_spans() + assert len(spans) == 4 + + producer_span_1 = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "produce" + and span.data["kafka"]["service"] == "span-topic_1", + ) + producer_span_2 = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "produce" + and span.data["kafka"]["service"] == "span-topic_2", + ) + consumer_span_1 = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "consume" + and span.data["kafka"]["service"] == "span-topic_1", + ) + consumer_span_2 = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "consume" + and span.data["kafka"]["service"] == "span-topic_2", + ) + + # same trace id, different span ids + assert producer_span_1.t == consumer_span_1.t + assert producer_span_1.s == consumer_span_1.p + assert producer_span_1.s != consumer_span_1.s + + assert producer_span_2.t == consumer_span_2.t + assert producer_span_2.s == consumer_span_2.p + assert producer_span_2.s != consumer_span_2.s + + self.kafka_client.delete_topics( + [ + testenv["kafka_topic"] + "_1", + testenv["kafka_topic"] + "_2", + ] + ) + + def test_confluent_kafka_poll_root_exit(self) -> None: + agent.options.allow_exit_as_root = True + + # Produce some events + self.producer.produce(testenv["kafka_topic"], b"raw_bytes1") + self.producer.flush() + + # Consume the events + consumer_config = self.kafka_config.copy() + consumer_config["group.id"] = "my-group" + consumer_config["auto.offset.reset"] = "earliest" + + consumer = Consumer(consumer_config) + consumer.subscribe([testenv["kafka_topic"]]) + + msg = consumer.poll(timeout=30) # noqa: F841 + + consumer.close() + + spans = self.recorder.queued_spans() + assert len(spans) == 2 + + producer_span = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "produce" + and span.data["kafka"]["service"] == "span-topic", + ) + + poll_span = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "poll" + and span.data["kafka"]["service"] == "span-topic", + ) + + # Same traceId + assert producer_span.t == poll_span.t + assert producer_span.s == poll_span.p + assert producer_span.s != poll_span.s + + def test_confluent_kafka_poll_root_exit_error(self) -> None: + agent.options.allow_exit_as_root = True + + # Produce some events + self.producer.produce(testenv["kafka_topic"], b"raw_bytes1") + self.producer.flush() + + # Consume the events + consumer_config = self.kafka_config.copy() + consumer_config["group.id"] = "my-group" + consumer_config["auto.offset.reset"] = "earliest" + + consumer = Consumer(consumer_config) + consumer.subscribe([testenv["kafka_topic"]]) + + msg = consumer.poll(timeout="wrong_value") # noqa: F841 + + consumer.close() + + spans = self.recorder.queued_spans() + assert len(spans) == 2 + + poll_span = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "poll" + and 
span.data["kafka"]["service"] == "span-topic", + ) + assert poll_span.data["kafka"]["error"] == "must be real number, not str" diff --git a/tests/clients/kafka/test_kafka_python.py b/tests/clients/kafka/test_kafka_python.py index 3a0ecfde..5999ef09 100644 --- a/tests/clients/kafka/test_kafka_python.py +++ b/tests/clients/kafka/test_kafka_python.py @@ -9,7 +9,7 @@ from opentelemetry.trace import SpanKind from instana.singletons import agent, tracer -from tests.helpers import testenv +from tests.helpers import get_first_span_by_filter, testenv class TestKafkaPython: @@ -202,3 +202,170 @@ def test_trace_kafka_python_error(self) -> None: assert kafka_span.data["kafka"]["service"] == "inexistent_kafka_topic" assert kafka_span.data["kafka"]["access"] == "consume" assert kafka_span.data["kafka"]["error"] == "StopIteration()" + + def test_kafka_consumer_root_exit(self) -> None: + agent.options.allow_exit_as_root = True + + self.producer.send(testenv["kafka_topic"], b"raw_bytes") + self.producer.flush() + + # Consume the events + consumer = KafkaConsumer( + testenv["kafka_topic"], + bootstrap_servers=testenv["kafka_bootstrap_servers"], + auto_offset_reset="earliest", # consume earliest available messages + enable_auto_commit=False, # do not auto-commit offsets + consumer_timeout_ms=1000, + ) + + for msg in consumer: + if msg is None: + break + + consumer.close() + + spans = self.recorder.queued_spans() + assert len(spans) == 4 + + producer_span = spans[0] + consumer_span = spans[1] + + assert producer_span.s + assert producer_span.n == "kafka" + assert producer_span.data["kafka"]["access"] == "send" + assert producer_span.data["kafka"]["service"] == "span-topic" + + assert consumer_span.s + assert consumer_span.n == "kafka" + assert consumer_span.data["kafka"]["access"] == "consume" + assert consumer_span.data["kafka"]["service"] == "span-topic" + + assert producer_span.t == consumer_span.t + + def test_kafka_poll_root_exit(self) -> None: + agent.options.allow_exit_as_root = True + + self.kafka_client.create_topics( + [ + NewTopic( + name=testenv["kafka_topic"] + "_1", + num_partitions=1, + replication_factor=1, + ), + NewTopic( + name=testenv["kafka_topic"] + "_2", + num_partitions=1, + replication_factor=1, + ), + NewTopic( + name=testenv["kafka_topic"] + "_3", + num_partitions=1, + replication_factor=1, + ), + ] + ) + + self.producer.send(testenv["kafka_topic"] + "_1", b"raw_bytes1") + self.producer.send(testenv["kafka_topic"] + "_2", b"raw_bytes2") + self.producer.send(testenv["kafka_topic"] + "_3", b"raw_bytes3") + self.producer.flush() + + # Consume the events + consumer = KafkaConsumer( + bootstrap_servers=testenv["kafka_bootstrap_servers"], + auto_offset_reset="earliest", # consume earliest available messages + enable_auto_commit=False, # do not auto-commit offsets + consumer_timeout_ms=1000, + ) + topics = [ + testenv["kafka_topic"] + "_1", + testenv["kafka_topic"] + "_2", + testenv["kafka_topic"] + "_3", + ] + consumer.subscribe(topics) + + messages = consumer.poll(timeout_ms=1000) # noqa: F841 + consumer.close() + + spans = self.recorder.queued_spans() + assert len(spans) == 6 + + producer_span_1 = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "send" + and span.data["kafka"]["service"] == "span-topic_1", + ) + producer_span_2 = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "send" + and span.data["kafka"]["service"] == "span-topic_2", + ) + producer_span_3 = 
get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "send" + and span.data["kafka"]["service"] == "span-topic_3", + ) + + poll_span_1 = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "poll" + and span.data["kafka"]["service"] == "span-topic_1", + ) + poll_span_2 = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "poll" + and span.data["kafka"]["service"] == "span-topic_2", + ) + poll_span_3 = get_first_span_by_filter( + spans, + lambda span: span.n == "kafka" + and span.data["kafka"]["access"] == "poll" + and span.data["kafka"]["service"] == "span-topic_3", + ) + + assert producer_span_1.n == "kafka" + assert producer_span_1.data["kafka"]["access"] == "send" + assert producer_span_1.data["kafka"]["service"] == "span-topic_1" + + assert producer_span_2.n == "kafka" + assert producer_span_2.data["kafka"]["access"] == "send" + assert producer_span_2.data["kafka"]["service"] == "span-topic_2" + + assert producer_span_3.n == "kafka" + assert producer_span_3.data["kafka"]["access"] == "send" + assert producer_span_3.data["kafka"]["service"] == "span-topic_3" + + assert poll_span_1.n == "kafka" + assert poll_span_1.data["kafka"]["access"] == "poll" + assert poll_span_1.data["kafka"]["service"] == "span-topic_1" + + assert poll_span_2.n == "kafka" + assert poll_span_2.data["kafka"]["access"] == "poll" + assert poll_span_2.data["kafka"]["service"] == "span-topic_2" + + assert poll_span_3.n == "kafka" + assert poll_span_3.data["kafka"]["access"] == "poll" + assert poll_span_3.data["kafka"]["service"] == "span-topic_3" + + # same trace id, different span ids + assert producer_span_1.t == poll_span_1.t + assert producer_span_1.s != poll_span_1.s + + assert producer_span_2.t == poll_span_2.t + assert producer_span_2.s != poll_span_2.s + + assert producer_span_3.t == poll_span_3.t + assert producer_span_3.s != poll_span_3.s + + self.kafka_client.delete_topics( + [ + testenv["kafka_topic"] + "_1", + testenv["kafka_topic"] + "_2", + testenv["kafka_topic"] + "_3", + ] + )
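
The consumer-side extraction path this patch introduces can be exercised without a running broker. Below is a minimal sketch (assuming the instana package from this branch is importable; the header key names and list-of-tuples carrier shape are taken from the diff above, while the id values themselves are invented purely for illustration):

# round_trip_check.py -- illustrative only, not part of this patch.
from instana.propagators.kafka_propagator import KafkaPropagator

propagator = KafkaPropagator()

# Headers exactly as a Kafka client would deliver them: a list of
# (key, bytes) tuples using the new underscore-separated key names.
headers = [
    ("x_instana_t", b"1a2b3c4d5e6f7a8b"),
    ("x_instana_s", b"0102030405060708"),
    ("x_instana_l_s", b"1"),
]

# extract() converts the tuple list into a dict internally via
# extract_carrier_headers() and rebuilds the parent SpanContext.
ctx = propagator.extract(headers, disable_w3c_trace_context=True)
print(hex(ctx.trace_id), hex(ctx.span_id), ctx.level)

# A dict carrier with the same keys is accepted as well.
ctx2 = propagator.extract(dict(headers), disable_w3c_trace_context=True)
assert ctx2.trace_id == ctx.trace_id and ctx2.span_id == ctx.span_id

On the producer side, inject() writes the same shape: KAFKA_HEADER_KEY_T, KAFKA_HEADER_KEY_S, and KAFKA_HEADER_KEY_L_S carrying hex_id_limited(trace_id), format_span_id(span_id), and the serialized level, each UTF-8 encoded, which is what trace_kafka_produce and trace_kafka_send pass through Format.KAFKA_HEADERS.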