Skip to content
66 changes: 62 additions & 4 deletions bottlecap/src/lifecycle/invocation/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ use std::{
use chrono::{DateTime, Utc};
use datadog_trace_protobuf::pb::Span;
use datadog_trace_utils::{send_data::SendData, tracer_header_tags};
use serde_json::{json, Value};
use tokio::sync::mpsc::Sender;
use tracing::debug;

use crate::{
config::{self, AwsConfig},
lifecycle::invocation::{context::ContextBuffer, span_inferrer::SpanInferrer},
tags::provider,
traces::trace_processor,
traces::{
context::SpanContext,
propagation::{DatadogCompositePropagator, Propagator},
trace_processor,
},
};

pub const MS_TO_NS: f64 = 1_000_000.0;
Expand All @@ -23,6 +28,9 @@ pub struct Processor {
pub context_buffer: ContextBuffer,
inferrer: SpanInferrer,
pub span: Span,
pub extracted_span_context: Option<SpanContext>,
// Used to extract the trace context from inferred span, headers, or payload
propagator: DatadogCompositePropagator,
aws_config: AwsConfig,
tracer_detected: bool,
}
Expand All @@ -39,6 +47,8 @@ impl Processor {
.get_canonical_resource_name()
.unwrap_or("aws_lambda".to_string());

let propagator = DatadogCompositePropagator::new(Arc::clone(&config));

Processor {
context_buffer: ContextBuffer::default(),
inferrer: SpanInferrer::default(),
Expand All @@ -58,6 +68,8 @@ impl Processor {
meta_struct: HashMap::new(),
span_links: Vec::new(),
},
extracted_span_context: None,
propagator,
aws_config: aws_config.clone(),
tracer_detected: false,
}
Expand Down Expand Up @@ -166,21 +178,68 @@ impl Processor {
/// If this method is called, it means that we are operating in a Universally Instrumented
/// runtime. Therefore, we need to set the `tracer_detected` flag to `true`.
///
pub fn on_invocation_start(&mut self, payload: Vec<u8>) {
pub fn on_invocation_start(&mut self, headers: HashMap<String, String>, payload: Vec<u8>) {
self.tracer_detected = true;

// Reset trace context
self.span.trace_id = 0;
self.span.parent_id = 0;
self.span.span_id = 0;

self.inferrer.infer_span(&payload, &self.aws_config);
let payload_value = match serde_json::from_slice::<Value>(&payload) {
Ok(value) => value,
Err(_) => json!({}),
};

self.extracted_span_context = self.extract_span_context(&headers, &payload_value);
self.inferrer.infer_span(&payload_value, &self.aws_config);

if let Some(sc) = &self.extracted_span_context {
self.span.trace_id = sc.trace_id;
self.span.parent_id = sc.span_id;

// Set the right data to the correct root level span,
// If there's an inferred span, then that should be the root.
if self.inferrer.get_inferred_span().is_some() {
self.inferrer.set_parent_id(sc.span_id);
self.inferrer.extend_meta(sc.tags.clone());
} else {
self.span.meta.extend(sc.tags.clone());
}
}

if let Some(inferred_span) = self.inferrer.get_inferred_span() {
self.span.parent_id = inferred_span.span_id;
}
}

/// Looks for an upstream trace context and returns the first one found.
///
/// Sources are tried in this order:
/// 1. the carrier stored on the span inferrer,
/// 2. the `headers` entry of the JSON payload,
/// 3. the `headers` map passed in by the caller.
///
/// Returns `None` when the propagator extracts nothing from any of them.
fn extract_span_context(
    &mut self,
    headers: &HashMap<String, String>,
    payload_value: &Value,
) -> Option<SpanContext> {
    let from_inferred = self
        .inferrer
        .get_carrier()
        .and_then(|carrier| self.propagator.extract(&carrier));
    if from_inferred.is_some() {
        debug!("Extracted trace context from inferred span");
        return from_inferred;
    }

    let from_event = payload_value
        .get("headers")
        .and_then(|payload_headers| self.propagator.extract(payload_headers));
    if from_event.is_some() {
        debug!("Extracted trace context from event headers");
        return from_event;
    }

    let from_request = self.propagator.extract(headers);
    if from_request.is_some() {
        debug!("Extracted trace context from headers");
    }
    from_request
}

/// Given trace context information, set it to the current span.
///
pub fn on_invocation_end(
Expand All @@ -194,7 +253,6 @@ impl Processor {
self.span.span_id = span_id;

if self.inferrer.get_inferred_span().is_some() {
self.inferrer.set_parent_id(parent_id);
if let Some(status_code) = status_code {
self.inferrer.set_status_code(status_code);
}
Expand Down
111 changes: 62 additions & 49 deletions bottlecap/src/lifecycle/invocation/span_inferrer.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use datadog_trace_protobuf::pb::Span;
use rand::Rng;
use serde_json::Value;
Expand All @@ -16,6 +18,7 @@ const FUNCTION_TRIGGER_EVENT_SOURCE_ARN_TAG: &str = "function_trigger.event_sour
/// State written by `infer_span` when a payload matches a known trigger.
pub struct SpanInferrer {
// Span built for the matched trigger; `None` when nothing was inferred.
inferred_span: Option<Span>,
// Copied from the matched trigger's `is_async()`.
is_async_span: bool,
// Copied from the matched trigger's `get_carrier()`; exposed through
// `get_carrier` for trace-context extraction.
carrier: Option<HashMap<String, String>>,
}

impl Default for SpanInferrer {
Expand All @@ -30,65 +33,64 @@ impl SpanInferrer {
Self {
inferred_span: None,
is_async_span: false,
carrier: None,
}
}

/// Given a byte payload, try to deserialize it into a `serde_json::Value`
/// and try matching it to a `Trigger` implementation, which will create
/// an inferred span and set it to `self.inferred_span`
///
pub fn infer_span(&mut self, payload: &[u8], aws_config: &AwsConfig) {
pub fn infer_span(&mut self, payload_value: &Value, aws_config: &AwsConfig) {
self.inferred_span = None;
if let Ok(payload_value) = serde_json::from_slice::<Value>(payload) {
if APIGatewayHttpEvent::is_match(&payload_value) {
if let Some(t) = APIGatewayHttpEvent::new(payload_value) {
let mut span = Span {
span_id: Self::generate_span_id(),
..Default::default()
};

t.enrich_span(&mut span);
span.meta.extend([
(
FUNCTION_TRIGGER_EVENT_SOURCE_TAG.to_string(),
"api_gateway".to_string(),
),
(
FUNCTION_TRIGGER_EVENT_SOURCE_ARN_TAG.to_string(),
t.get_arn(&aws_config.region),
),
]);

self.is_async_span = t.is_async();
self.inferred_span = Some(span);
}
} else if APIGatewayRestEvent::is_match(&payload_value) {
if let Some(t) = APIGatewayRestEvent::new(payload_value) {
let mut span = Span {
span_id: Self::generate_span_id(),
..Default::default()
};

t.enrich_span(&mut span);
span.meta.extend([
(
FUNCTION_TRIGGER_EVENT_SOURCE_TAG.to_string(),
"api_gateway".to_string(),
),
(
FUNCTION_TRIGGER_EVENT_SOURCE_ARN_TAG.to_string(),
t.get_arn(&aws_config.region),
),
]);

self.is_async_span = t.is_async();
self.inferred_span = Some(span);
}
} else {
debug!("Unable to infer span from payload");
if APIGatewayHttpEvent::is_match(payload_value) {
if let Some(t) = APIGatewayHttpEvent::new(payload_value.clone()) {
let mut span = Span {
span_id: Self::generate_span_id(),
..Default::default()
};

t.enrich_span(&mut span);
span.meta.extend([
(
FUNCTION_TRIGGER_EVENT_SOURCE_TAG.to_string(),
"api_gateway".to_string(),
),
(
FUNCTION_TRIGGER_EVENT_SOURCE_ARN_TAG.to_string(),
t.get_arn(&aws_config.region),
),
]);

self.carrier = Some(t.get_carrier());
self.is_async_span = t.is_async();
self.inferred_span = Some(span);
}
} else if APIGatewayRestEvent::is_match(payload_value) {
if let Some(t) = APIGatewayRestEvent::new(payload_value.clone()) {
let mut span = Span {
span_id: Self::generate_span_id(),
..Default::default()
};

t.enrich_span(&mut span);
span.meta.extend([
(
FUNCTION_TRIGGER_EVENT_SOURCE_TAG.to_string(),
"api_gateway".to_string(),
),
(
FUNCTION_TRIGGER_EVENT_SOURCE_ARN_TAG.to_string(),
t.get_arn(&aws_config.region),
),
]);

self.carrier = Some(t.get_carrier());
self.is_async_span = t.is_async();
self.inferred_span = Some(span);
}
} else {
debug!("Unable to serialize payload");
debug!("Unable to infer span from payload");
}
}

Expand All @@ -101,6 +103,12 @@ impl SpanInferrer {
}
}

/// Merges the given key/value pairs into the inferred span's `meta`.
///
/// Does nothing when no span has been inferred.
pub fn extend_meta(&mut self, iter: HashMap<String, String>) {
    let Some(span) = self.inferred_span.as_mut() else {
        return;
    };
    span.meta.extend(iter);
}

pub fn set_status_code(&mut self, status_code: String) {
if let Some(s) = &mut self.inferred_span {
s.meta.insert("http.status_code".to_string(), status_code);
Expand Down Expand Up @@ -136,4 +144,9 @@ impl SpanInferrer {
/// Borrows the currently inferred span; the option is `None` when no span
/// has been inferred.
pub fn get_inferred_span(&self) -> &Option<Span> {
&self.inferred_span
}

/// Returns a clone of the stored carrier map, or `None` when no carrier
/// has been stored.
#[must_use]
pub fn get_carrier(&self) -> Option<HashMap<String, String>> {
self.carrier.clone()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ impl Trigger for APIGatewayHttpEvent {
.get("x-amz-invocation-type")
.is_some_and(|v| v == "Event")
}

/// Returns a clone of this event's `headers` map to serve as the carrier.
fn get_carrier(&self) -> HashMap<String, String> {
self.headers.clone()
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ impl Trigger for APIGatewayRestEvent {
.get("x-amz-invocation-type")
.is_some_and(|v| v == "Event")
}

/// Returns a clone of this event's `headers` map to serve as the carrier.
fn get_carrier(&self) -> HashMap<String, String> {
self.headers.clone()
}
}

#[cfg(test)]
Expand Down
1 change: 1 addition & 0 deletions bottlecap/src/lifecycle/invocation/triggers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub trait Trigger: Sized {
fn enrich_span(&self, span: &mut Span);
fn get_tags(&self) -> HashMap<String, String>;
fn get_arn(&self, region: &str) -> String;
/// Returns the key/value map that trace context can be extracted from.
fn get_carrier(&self) -> HashMap<String, String>;
fn is_async(&self) -> bool;
}

Expand Down
48 changes: 43 additions & 5 deletions bottlecap/src/lifecycle/listener.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright 2024-Present Datadog, Inc. https://www.datadoghq.com/
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
Expand All @@ -12,6 +13,10 @@ use tokio::sync::Mutex;
use tracing::{debug, error, warn};

use crate::lifecycle::invocation::processor::Processor as InvocationProcessor;
use crate::traces::propagation::text_map_propagator::{
DATADOG_HIGHER_ORDER_TRACE_ID_BITS_KEY, DATADOG_SAMPLING_PRIORITY_KEY, DATADOG_TAGS_KEY,
DATADOG_TRACE_ID_KEY,
};

const HELLO_PATH: &str = "/lambda/hello";
const START_INVOCATION_PATH: &str = "/lambda/start-invocation";
Expand Down Expand Up @@ -83,18 +88,37 @@ impl Listener {
invocation_processor: Arc<Mutex<InvocationProcessor>>,
) -> http::Result<Response<Body>> {
debug!("Received start invocation request");
let (_, body) = req.into_parts();
let (parts, body) = req.into_parts();
match hyper::body::to_bytes(body).await {
Ok(b) => {
let body = b.to_vec();
let mut processor = invocation_processor.lock().await;

processor.on_invocation_start(body);
let headers = Self::headers_to_map(parts.headers);

processor.on_invocation_start(headers, body);

let mut response = Response::builder().status(200);
if processor.span.trace_id != 0 {
response =
response.header("x-datadog-trace-id", processor.span.trace_id.to_string());

// If a `SpanContext` exists, then tell the tracer to use it.
// todo: update this whole code with DatadogHeaderPropagator::inject
// since this logic looks messy
if let Some(sp) = &processor.extracted_span_context {
response = response.header(DATADOG_TRACE_ID_KEY, sp.trace_id.to_string());
if let Some(priority) = sp.sampling.and_then(|s| s.priority) {
response =
response.header(DATADOG_SAMPLING_PRIORITY_KEY, priority.to_string());
}

// Handle 128 bit trace ids
if let Some(trace_id_higher_order_bits) =
sp.tags.get(DATADOG_HIGHER_ORDER_TRACE_ID_BITS_KEY)
{
response = response.header(
DATADOG_TAGS_KEY,
format!("{DATADOG_HIGHER_ORDER_TRACE_ID_BITS_KEY}={trace_id_higher_order_bits}"),
);
}
}

drop(processor);
Expand Down Expand Up @@ -128,6 +152,8 @@ impl Listener {

let mut processor = invocation_processor.lock().await;

// todo: fix this; the code is a copy of the existing logic in Go and does not
// account for the case where a 128-bit trace ID exists
let mut trace_id = 0;
if let Some(header) = headers.get("x-datadog-trace-id") {
if let Ok(header_value) = header.to_str() {
Expand Down Expand Up @@ -163,4 +189,16 @@ impl Listener {
.status(200)
.body(Body::from(json!({}).to_string()))
}

/// Flattens an `http::HeaderMap` into a `HashMap<String, String>`.
///
/// A value for which `to_str` fails is stored as an empty string. Entries
/// are inserted in iteration order, so a later value for the same header
/// name overwrites an earlier one.
fn headers_to_map(headers: http::HeaderMap) -> HashMap<String, String> {
    let mut flattened = HashMap::new();
    for (name, value) in &headers {
        let text = value.to_str().unwrap_or_default();
        flattened.insert(name.as_str().to_string(), text.to_string());
    }
    flattened
}
}
Loading