diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a04221e8fc..3c3d45af50 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -50,6 +50,7 @@ libdd-library-config*/ @DataDog/apm-sdk-capabilities-rust libdd-libunwind*/ @DataDog/libdatadog-profiling libdd-log*/ @DataDog/apm-common-components-core libdd-profiling*/ @DataDog/libdatadog-profiling +libdd-shared-runtime*/ @DataDog/apm-common-components-core libdd-telemetry*/ @DataDog/apm-common-components-core libdd-tinybytes @DataDog/apm-common-components-core libdd-trace-normalization @DataDog/serverless @DataDog/libdatadog-apm diff --git a/Cargo.lock b/Cargo.lock index 392e8bda81..cb3ead76b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3007,6 +3007,7 @@ version = "3.0.1" dependencies = [ "anyhow", "arc-swap", + "async-trait", "bytes", "clap", "criterion", @@ -3021,6 +3022,7 @@ dependencies = [ "libdd-ddsketch", "libdd-dogstatsd-client", "libdd-log", + "libdd-shared-runtime", "libdd-telemetry", "libdd-tinybytes", "libdd-trace-protobuf", @@ -3048,6 +3050,7 @@ dependencies = [ "libdd-capabilities-impl", "libdd-common-ffi", "libdd-data-pipeline", + "libdd-shared-runtime", "libdd-tinybytes", "libdd-trace-utils", "rmp-serde", @@ -3247,11 +3250,34 @@ dependencies = [ "prost", ] +[[package]] +name = "libdd-shared-runtime" +version = "1.0.0" +dependencies = [ + "async-trait", + "futures", + "libdd-capabilities", + "libdd-common", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "libdd-shared-runtime-ffi" +version = "30.0.0" +dependencies = [ + "build_common", + "libdd-shared-runtime", + "tracing", +] + [[package]] name = "libdd-telemetry" version = "4.0.0" dependencies = [ "anyhow", + "async-trait", "base64 0.22.1", "bytes", "futures", @@ -3261,6 +3287,7 @@ dependencies = [ "libc", "libdd-common", "libdd-ddsketch", + "libdd-shared-runtime", "serde", "serde_json", "sys-info", diff --git a/Cargo.toml b/Cargo.toml index 6ab51ec8eb..7431bd304b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,8 @@ members = [ "spawn_worker", "tests/spawn_from_lib", "bin_tests", + "libdd-shared-runtime", + "libdd-shared-runtime-ffi", "libdd-data-pipeline", "libdd-data-pipeline-ffi", "libdd-ddsketch", diff --git a/libdd-common/src/lib.rs b/libdd-common/src/lib.rs index bab6087337..786af97cd9 100644 --- a/libdd-common/src/lib.rs +++ b/libdd-common/src/lib.rs @@ -38,7 +38,6 @@ pub mod threading; #[cfg(not(target_arch = "wasm32"))] pub mod timeout; pub mod unix_utils; -pub mod worker; /// Extension trait for `Mutex` to provide a method that acquires a lock, panicking if the lock is /// poisoned. diff --git a/libdd-data-pipeline-ffi/Cargo.toml b/libdd-data-pipeline-ffi/Cargo.toml index dd17c7bb56..f755eb0606 100644 --- a/libdd-data-pipeline-ffi/Cargo.toml +++ b/libdd-data-pipeline-ffi/Cargo.toml @@ -31,6 +31,7 @@ libdd-trace-utils = { path = "../libdd-trace-utils" } [dependencies] libdd-capabilities-impl = { version = "0.1.0", path = "../libdd-capabilities-impl" } libdd-data-pipeline = { path = "../libdd-data-pipeline" } +libdd-shared-runtime = { version = "1.0.0", path = "../libdd-shared-runtime" } libdd-common-ffi = { path = "../libdd-common-ffi", default-features = false } libdd-tinybytes = { path = "../libdd-tinybytes" } tracing = { version = "0.1", default-features = false } diff --git a/libdd-data-pipeline-ffi/cbindgen.toml b/libdd-data-pipeline-ffi/cbindgen.toml index 79742aad99..63c23705ca 100644 --- a/libdd-data-pipeline-ffi/cbindgen.toml +++ b/libdd-data-pipeline-ffi/cbindgen.toml @@ -40,4 +40,4 @@ must_use = "DDOG_CHECK_RETURN" [parse] parse_deps = true -include = ["libdd-common", "libdd-common-ffi", "libdd-data-pipeline"] +include = ["libdd-common", "libdd-common-ffi", "libdd-shared-runtime", "libdd-data-pipeline"] diff --git a/libdd-data-pipeline-ffi/src/trace_exporter.rs b/libdd-data-pipeline-ffi/src/trace_exporter.rs index 5355dc6bb2..440e78e1ed 100644 --- a/libdd-data-pipeline-ffi/src/trace_exporter.rs +++ b/libdd-data-pipeline-ffi/src/trace_exporter.rs @@ -9,7 +9,6 @@ use libdd_common_ffi::{ CharSlice, {slice::AsBytes, slice::ByteSlice}, }; - use libdd_data_pipeline::trace_exporter::{ TelemetryConfig, TraceExporter as GenericTraceExporter, TraceExporterInputFormat, TraceExporterOutputFormat, @@ -17,7 +16,8 @@ use libdd_data_pipeline::trace_exporter::{ type TraceExporter = GenericTraceExporter; -use std::{ptr::NonNull, time::Duration}; +use libdd_shared_runtime::SharedRuntime; +use std::{ptr::NonNull, sync::Arc, time::Duration}; use tracing::debug; #[inline] @@ -73,6 +73,7 @@ pub struct TraceExporterConfig { process_tags: Option, test_session_token: Option, connection_timeout: Option, + shared_runtime: Option>, otlp_endpoint: Option, } @@ -420,6 +421,36 @@ pub unsafe extern "C" fn ddog_trace_exporter_config_set_connection_timeout( ) } +/// Sets a shared runtime for the TraceExporter to use for background workers. +/// +/// `handle` must have been initialized with [`ddog_shared_runtime_new`]. +/// +/// When set, the exporter will use the provided runtime instead of creating its own. +/// This allows multiple exporters (or other components) to share a single runtime. +/// The config holds a clone of the `Arc` (increments the strong count), so the +/// original handle remains valid and must still be freed with +/// [`ddog_shared_runtime_free`]. +#[no_mangle] +pub unsafe extern "C" fn ddog_trace_exporter_config_set_shared_runtime( + config: Option<&mut TraceExporterConfig>, + handle: Option>, +) -> Option> { + catch_panic!( + match (config, handle) { + (Some(config), Some(handle)) => { + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + // Increment the strong count before reconstructing so the config's Arc + // is independent from the caller's handle. + Arc::increment_strong_count(handle.as_ptr()); + config.shared_runtime = Some(Arc::from_raw(handle.as_ptr())); + None + } + _ => gen_error!(ErrorCode::InvalidArgument), + }, + gen_error!(ErrorCode::Panic) + ) +} + /// Enables OTLP HTTP/JSON export and sets the endpoint URL. /// /// When set, traces are sent to this URL in OTLP HTTP/JSON format instead of the Datadog @@ -502,6 +533,10 @@ pub unsafe extern "C" fn ddog_trace_exporter_new( builder.enable_health_metrics(); } + if let Some(runtime) = config.shared_runtime.clone() { + builder.set_shared_runtime(runtime); + } + if let Some(ref url) = config.otlp_endpoint { builder.set_otlp_endpoint(url); } diff --git a/libdd-data-pipeline/Cargo.toml b/libdd-data-pipeline/Cargo.toml index bd46fe5d58..8d5db11a82 100644 --- a/libdd-data-pipeline/Cargo.toml +++ b/libdd-data-pipeline/Cargo.toml @@ -14,6 +14,7 @@ autobenches = false [dependencies] anyhow = { version = "1.0" } arc-swap = "1.7.1" +async-trait = "0.1" http = "1" http-body-util = "0.1" tracing = { version = "0.1", default-features = false } @@ -30,7 +31,8 @@ uuid = { version = "1.10.0", features = ["v4"] } tokio-util = "0.7.11" libdd-capabilities = { path = "../libdd-capabilities", version = "0.1.0" } libdd-common = { version = "3.0.2", path = "../libdd-common", default-features = false } -libdd-telemetry = { version = "4.0.0", path = "../libdd-telemetry", default-features = false, optional = true } +libdd-shared-runtime = { version = "1.0.0", path = "../libdd-shared-runtime" } +libdd-telemetry = { version = "4.0.0", path = "../libdd-telemetry", default-features = false, optional = true} libdd-trace-protobuf = { version = "3.0.1", path = "../libdd-trace-protobuf" } libdd-trace-stats = { version = "2.0.0", path = "../libdd-trace-stats" } libdd-trace-utils = { version = "3.0.1", path = "../libdd-trace-utils", default-features = false } diff --git a/libdd-data-pipeline/examples/send-traces-with-stats.rs b/libdd-data-pipeline/examples/send-traces-with-stats.rs index caa0ec6360..460018c587 100644 --- a/libdd-data-pipeline/examples/send-traces-with-stats.rs +++ b/libdd-data-pipeline/examples/send-traces-with-stats.rs @@ -9,9 +9,11 @@ use libdd_data_pipeline::trace_exporter::{ use libdd_log::logger::{ logger_configure_std, logger_set_log_level, LogEventLevel, StdConfig, StdTarget, }; +use libdd_shared_runtime::SharedRuntime; use libdd_trace_protobuf::pb; use std::{ collections::HashMap, + sync::Arc, time::{Duration, UNIX_EPOCH}, }; @@ -54,6 +56,8 @@ fn main() { .expect("Failed to configure logger"); logger_set_log_level(LogEventLevel::Debug).expect("Failed to set log level"); + let shared_runtime = Arc::new(SharedRuntime::new().expect("Failed to create runtime")); + let args = Args::parse(); let telemetry_cfg = TelemetryConfig::default(); let mut builder = TraceExporter::::builder(); @@ -68,6 +72,7 @@ fn main() { .set_language_version(env!("CARGO_PKG_RUST_VERSION")) .set_input_format(TraceExporterInputFormat::V04) .set_output_format(TraceExporterOutputFormat::V04) + .set_shared_runtime(shared_runtime.clone()) .enable_telemetry(telemetry_cfg) .enable_stats(Duration::from_secs(10)); let exporter = builder @@ -89,7 +94,7 @@ fn main() { let data = rmp_serde::to_vec_named(&traces).expect("Failed to serialize traces"); exporter.send(data.as_ref()).expect("Failed to send traces"); - exporter + shared_runtime .shutdown(None) - .expect("Failed to shutdown exporter"); + .expect("Failed to shutdown runtime"); } diff --git a/libdd-data-pipeline/src/agent_info/fetcher.rs b/libdd-data-pipeline/src/agent_info/fetcher.rs index 415321d79d..865bd5a87c 100644 --- a/libdd-data-pipeline/src/agent_info/fetcher.rs +++ b/libdd-data-pipeline/src/agent_info/fetcher.rs @@ -8,15 +8,16 @@ use super::{ AGENT_INFO_CACHE, }; use anyhow::{anyhow, Result}; +use async_trait::async_trait; use bytes::Bytes; use libdd_capabilities::{HttpClientTrait, MaybeSend}; -use libdd_common::{worker::Worker, Endpoint}; +use libdd_common::Endpoint; +use libdd_shared_runtime::Worker; use sha2::{Digest, Sha256}; use std::marker::PhantomData; use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc; -#[cfg(not(target_arch = "wasm32"))] use tokio::time::sleep; use tracing::{debug, warn}; /// Whether the agent reported the same value or not. @@ -129,9 +130,12 @@ async fn fetch_and_hash_response( /// Fetch the info endpoint and update an ArcSwap keeping it up-to-date. /// -/// Once the run method has been started, the fetcher will -/// update the global info state based on the given refresh interval. You can access the current -/// state with [`crate::agent_info::get_agent_info`] +/// This type implements [`libdd_shared_runtime::Worker`] and is intended to be driven by a worker +/// runner such as [`libdd_shared_runtime::SharedRuntime`]. +/// In that lifecycle, `trigger()` waits for the next refresh event and `run()` performs a single +/// fetch. +/// +/// You can access the current state with [`crate::agent_info::get_agent_info`]. /// /// # Response observer /// When the fetcher is created it also returns a [`ResponseObserver`] which can be used to check @@ -141,8 +145,8 @@ async fn fetch_and_hash_response( /// # Example /// ```no_run /// # use anyhow::Result; -/// # use libdd_common::worker::Worker; /// # use libdd_capabilities_impl::NativeCapabilities; +/// # use libdd_shared_runtime::Worker; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// // Define the endpoint @@ -154,10 +158,9 @@ async fn fetch_and_hash_response( /// >::new( /// endpoint, std::time::Duration::from_secs(5 * 60) /// ); -/// // Start the runner -/// tokio::spawn(async move { -/// fetcher.run().await; -/// }); +/// // Start the fetcher on a shared runtime +/// let runtime = libdd_shared_runtime::SharedRuntime::new()?; +/// runtime.spawn_worker(fetcher)?; /// /// // Get the Arc to access the info /// let agent_info_arc = agent_info::get_agent_info(); @@ -179,6 +182,7 @@ pub struct AgentInfoFetcher { info_endpoint: Endpoint, refresh_interval: Duration, trigger_rx: Option>, + trigger_tx: mpsc::Sender<()>, /// `H` must live on the struct because `Worker::run(&mut self)` (a fixed /// trait signature) calls `fetch_info_with_state::()` internally. _phantom: PhantomData, @@ -199,6 +203,7 @@ impl AgentInfoFetcher { info_endpoint, refresh_interval, trigger_rx: Some(trigger_rx), + trigger_tx: trigger_tx.clone(), _phantom: PhantomData, }; @@ -216,56 +221,53 @@ impl AgentInfoFetcher { } } +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Worker for AgentInfoFetcher { - /// Start fetching the info endpoint with the given interval. - /// - /// # Warning - /// This method does not return and should be called within a dedicated task. - async fn run(&mut self) { - #[cfg(target_arch = "wasm32")] - { - // Worker is never started on wasm; this is unreachable. + async fn initial_trigger(&mut self) { + // Skip initial wait if cache is not populated + if AGENT_INFO_CACHE.load().is_none() { return; } + self.trigger().await + } - #[cfg(not(target_arch = "wasm32"))] - { - // Skip the first fetch if some info is present to avoid calling the /info endpoint - // at fork for heavy-forking environment. - if AGENT_INFO_CACHE.load().is_none() { - self.fetch_and_update().await; - } - - // Main loop waiting for a trigger event or the end of the refresh interval to trigger - // the fetch. - loop { - match &mut self.trigger_rx { - Some(trigger_rx) => { - tokio::select! { - // Wait for manual trigger (new state from headers) - trigger = trigger_rx.recv() => { - if trigger.is_some() { - self.fetch_and_update().await; - } else { - // The channel has been closed - self.trigger_rx = None; - } - } - // Regular periodic fetch timer - _ = sleep(self.refresh_interval) => { - self.fetch_and_update().await; - } - }; - } - None => { - // If the trigger channel is closed we only use timed fetch. - sleep(self.refresh_interval).await; - self.fetch_and_update().await; + async fn trigger(&mut self) { + // Wait for either a manual trigger or the refresh interval + match &mut self.trigger_rx { + Some(trigger_rx) => { + tokio::select! { + // Wait for manual trigger (new state from headers) + trigger = trigger_rx.recv() => { + if trigger.is_none() { + // The channel has been closed + self.trigger_rx = None; + } } + // Regular periodic fetch timer + _ = sleep(self.refresh_interval) => {} } } + None => { + // If the trigger channel is closed we only use timed fetch. + sleep(self.refresh_interval).await; + } } } + + async fn on_pause(&mut self) { + // Release the IoStack waker stored in trigger_rx by waking the channel and drain the + // message to avoid a spurious fetch on restart. If the channel is not empty then it has + // already been waked. + if self.trigger_rx.as_ref().is_some_and(|rx| rx.is_empty()) { + let _ = self.trigger_tx.try_send(()); + self.drain(); + }; + } + + async fn run(&mut self) { + self.fetch_and_update().await; + } } impl AgentInfoFetcher { @@ -350,6 +352,7 @@ mod single_threaded_tests { use crate::agent_info; use httpmock::prelude::*; use libdd_capabilities_impl::NativeCapabilities; + use libdd_shared_runtime::SharedRuntime; const TEST_INFO: &str = r#"{ "version": "0.0.0", @@ -553,31 +556,33 @@ mod single_threaded_tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn test_agent_info_fetcher_run() { + #[test] + fn test_agent_info_fetcher_run() { AGENT_INFO_CACHE.store(None); let server = MockServer::start(); - let mock_v1 = server - .mock_async(|when, then| { - when.path("/info"); - then.status(200) - .header("content-type", "application/json") - .body(r#"{"version":"1"}"#); - }) - .await; + let mut mock_v1 = server.mock(|when, then| { + when.path("/info"); + then.status(200) + .header("content-type", "application/json") + .body(r#"{"version":"1"}"#); + }); let endpoint = Endpoint::from_url(server.url("/info").parse().unwrap()); - let (mut fetcher, _response_observer) = AgentInfoFetcher::::new( + let (fetcher, _response_observer) = AgentInfoFetcher::::new( endpoint.clone(), Duration::from_millis(100), ); assert!(agent_info::get_agent_info().is_none()); - tokio::spawn(async move { - fetcher.run().await; - }); + let shared_runtime = SharedRuntime::new().unwrap(); + shared_runtime.spawn_worker(fetcher).unwrap(); // Wait until the info is fetched + let start = std::time::Instant::now(); while agent_info::get_agent_info().is_none() { - tokio::time::sleep(Duration::from_millis(100)).await; + assert!( + start.elapsed() <= Duration::from_secs(10), + "Timeout waiting for first /info fetch" + ); + std::thread::sleep(Duration::from_millis(100)); } let version_1 = agent_info::get_agent_info() @@ -588,22 +593,24 @@ mod single_threaded_tests { .clone() .unwrap(); assert_eq!(version_1, "1"); - mock_v1.assert_async().await; // Update the info endpoint - mock_v1.delete_async().await; - let mock_v2 = server - .mock_async(|when, then| { - when.path("/info"); - then.status(200) - .header("content-type", "application/json") - .body(r#"{"version":"2"}"#); - }) - .await; + mock_v1.delete(); + let mock_v2 = server.mock(|when, then| { + when.path("/info"); + then.status(200) + .header("content-type", "application/json") + .body(r#"{"version":"2"}"#); + }); // Wait for second fetch - while mock_v2.calls_async().await == 0 { - tokio::time::sleep(Duration::from_millis(100)).await; + let start = std::time::Instant::now(); + while mock_v2.calls() == 0 { + assert!( + start.elapsed() <= Duration::from_secs(10), + "Timeout waiting for second /info fetch" + ); + std::thread::sleep(Duration::from_millis(100)); } // This check is not 100% deterministic, but between the time the mock returns the response @@ -622,22 +629,20 @@ mod single_threaded_tests { assert_eq!(version_2, "2"); break; } - tokio::time::sleep(Duration::from_millis(100)).await; + std::thread::sleep(Duration::from_millis(100)); } } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn test_agent_info_trigger_different_state() { + #[test] + fn test_agent_info_trigger_different_state() { let server = MockServer::start(); - let mock = server - .mock_async(|when, then| { - when.path("/info"); - then.status(200) - .header("content-type", "application/json") - .body(r#"{"version":"triggered"}"#); - }) - .await; + let mock = server.mock(|when, then| { + when.path("/info"); + then.status(200) + .header("content-type", "application/json") + .body(r#"{"version":"triggered"}"#); + }); // Populate the cache with initial state AGENT_INFO_CACHE.store(Some(Arc::new(AgentInfo { @@ -646,12 +651,12 @@ mod single_threaded_tests { }))); let endpoint = Endpoint::from_url(server.url("/info").parse().unwrap()); - let (mut fetcher, response_observer) = + let (fetcher, response_observer) = + // Interval is too long to fetch during the test AgentInfoFetcher::::new(endpoint, Duration::from_secs(3600)); - tokio::spawn(async move { - fetcher.run().await; - }); + let shared_runtime = SharedRuntime::new().unwrap(); + shared_runtime.spawn_worker(fetcher).unwrap(); // Create a mock HTTP response with the new agent state let response = http::Response::builder() @@ -668,13 +673,13 @@ mod single_threaded_tests { const SLEEP_DURATION_MS: u64 = 10; let mut attempts = 0; - while mock.calls_async().await == 0 && attempts < MAX_ATTEMPTS { + while mock.calls() == 0 && attempts < MAX_ATTEMPTS { attempts += 1; - tokio::time::sleep(Duration::from_millis(SLEEP_DURATION_MS)).await; + std::thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); } // Should trigger a fetch since the state is different - mock.assert_calls_async(1).await; + mock.assert_calls(1); // Wait for the cache to be updated with proper timeout let mut attempts = 0; @@ -688,7 +693,7 @@ mod single_threaded_tests { } } attempts += 1; - tokio::time::sleep(Duration::from_millis(SLEEP_DURATION_MS)).await; + std::thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); } // Verify the cache was updated @@ -708,17 +713,15 @@ mod single_threaded_tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn test_agent_info_trigger_same_state() { + #[test] + fn test_agent_info_trigger_same_state() { let server = MockServer::start(); - let mock = server - .mock_async(|when, then| { - when.path("/info"); - then.status(200) - .header("content-type", "application/json") - .body(r#"{"version":"same"}"#); - }) - .await; + let mock = server.mock(|when, then| { + when.path("/info"); + then.status(200) + .header("content-type", "application/json") + .body(r#"{"version":"same"}"#); + }); let same_json = r#"{"version":"same"}"#; let same_hash = calculate_hash(same_json); @@ -730,12 +733,11 @@ mod single_threaded_tests { }))); let endpoint = Endpoint::from_url(server.url("/info").parse().unwrap()); - let (mut fetcher, response_observer) = - AgentInfoFetcher::::new(endpoint, Duration::from_secs(3600)); + let (fetcher, response_observer) = + AgentInfoFetcher::::new(endpoint, Duration::from_secs(3600)); // Very long interval - tokio::spawn(async move { - fetcher.run().await; - }); + let shared_runtime = SharedRuntime::new().unwrap(); + shared_runtime.spawn_worker(fetcher).unwrap(); // Create a mock HTTP response with the same agent state let response = http::Response::builder() @@ -748,9 +750,9 @@ mod single_threaded_tests { response_observer.check_response(&response); // Wait to ensure no fetch occurs - tokio::time::sleep(Duration::from_millis(500)).await; + std::thread::sleep(Duration::from_millis(500)); // Should not trigger a fetch since the state is the same - mock.assert_calls_async(0).await; + mock.assert_calls(0); } } diff --git a/libdd-data-pipeline/src/lib.rs b/libdd-data-pipeline/src/lib.rs index d442db4a7a..33d080c022 100644 --- a/libdd-data-pipeline/src/lib.rs +++ b/libdd-data-pipeline/src/lib.rs @@ -13,8 +13,6 @@ pub mod agent_info; mod health_metrics; pub(crate) mod otlp; -#[cfg(not(target_arch = "wasm32"))] -mod pausable_worker; #[allow(missing_docs)] pub mod stats_exporter; #[cfg(feature = "telemetry")] diff --git a/libdd-data-pipeline/src/pausable_worker.rs b/libdd-data-pipeline/src/pausable_worker.rs deleted file mode 100644 index 290442b8de..0000000000 --- a/libdd-data-pipeline/src/pausable_worker.rs +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ -// SPDX-License-Identifier: Apache-2.0 - -//! Defines a pausable worker to be able to stop background processes before forks - -use libdd_capabilities::MaybeSend; -use libdd_common::worker::Worker; -use std::fmt::Display; -use tokio::{ - runtime::Runtime, - select, - task::{JoinError, JoinHandle}, -}; -use tokio_util::sync::CancellationToken; - -/// A pausable worker which can be paused and restarted on forks. -/// -/// Used to allow a [`libdd_common::worker::Worker`] to be paused while saving its state when -/// dropping a tokio runtime to be able to restart with the same state on a new runtime. This is -/// used to stop all threads before a fork to avoid deadlocks in child. -/// -/// # Time-to-pause -/// This loop should yield regularly to reduce time-to-pause. See [`tokio::task::yield_now`]. -/// -/// # Cancellation safety -/// The main loop can be interrupted at any yield point (`.await`ed call). The state of the worker -/// at this point will be saved and used to restart the worker. To be able to safely restart, the -/// worker must be in a valid state on every call to `.await`. -/// See [`tokio::select#cancellation-safety`] for more details. -#[derive(Debug)] -pub enum PausableWorker { - Running { - handle: JoinHandle, - stop_token: CancellationToken, - }, - Paused { - worker: T, - }, - InvalidState, -} - -#[derive(Debug)] -pub enum PausableWorkerError { - InvalidState, - TaskAborted, -} - -impl Display for PausableWorkerError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - PausableWorkerError::InvalidState => { - write!(f, "Worker is in an invalid state and must be recreated.") - } - PausableWorkerError::TaskAborted => { - write!(f, "Worker task has been aborted and state has been lost.") - } - } - } -} - -impl core::error::Error for PausableWorkerError {} - -impl PausableWorker { - /// Create a new pausable worker from the given worker. - pub fn new(worker: T) -> Self { - Self::Paused { worker } - } - - /// Start the worker on the given runtime. - /// - /// The worker's main loop will be run on the runtime. - /// - /// # Errors - /// Fails if the worker is in an invalid state. - pub fn start(&mut self, rt: &Runtime) -> Result<(), PausableWorkerError> { - if let Self::Running { .. } = self { - Ok(()) - } else if let Self::Paused { mut worker } = std::mem::replace(self, Self::InvalidState) { - // Worker is temporarily in an invalid state, but since this block is failsafe it will - // be replaced by a valid state. - let stop_token = CancellationToken::new(); - let cloned_token = stop_token.clone(); - - let handle = rt.spawn(async move { - select! { - _ = worker.run() => {worker} - _ = cloned_token.cancelled() => {worker} - } - }); - - *self = PausableWorker::Running { handle, stop_token }; - Ok(()) - } else { - Err(PausableWorkerError::InvalidState) - } - } - - /// Pause the worker saving it's state to be restarted. - /// - /// # Errors - /// Fails if the worker handle has been aborted preventing the worker from being retrieved. - pub async fn pause(&mut self) -> Result<(), PausableWorkerError> { - match self { - PausableWorker::Running { handle, stop_token } => { - stop_token.cancel(); - if let Ok(worker) = handle.await { - *self = PausableWorker::Paused { worker }; - Ok(()) - } else { - // The task has been aborted and the worker can't be retrieved. - *self = PausableWorker::InvalidState; - Err(PausableWorkerError::TaskAborted) - } - } - PausableWorker::Paused { .. } => Ok(()), - PausableWorker::InvalidState => Err(PausableWorkerError::InvalidState), - } - } - - /// Wait for the run method of the worker to exit. - pub async fn join(self) -> Result<(), JoinError> { - if let PausableWorker::Running { handle, .. } = self { - handle.await?; - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use tokio::{runtime::Builder, time::sleep}; - - use super::*; - use std::{ - sync::mpsc::{channel, Sender}, - time::Duration, - }; - - /// Test worker incrementing the state and sending it with the sender. - struct TestWorker { - state: u32, - sender: Sender, - } - - impl Worker for TestWorker { - async fn run(&mut self) { - loop { - let _ = self.sender.send(self.state); - self.state += 1; - sleep(Duration::from_millis(100)).await; - } - } - } - - #[test] - fn test_restart() { - let (sender, receiver) = channel::(); - let worker = TestWorker { state: 0, sender }; - let runtime = Builder::new_multi_thread().enable_time().build().unwrap(); - let mut pausable_worker = PausableWorker::new(worker); - - pausable_worker.start(&runtime).unwrap(); - - assert_eq!(receiver.recv().unwrap(), 0); - runtime.block_on(async { pausable_worker.pause().await.unwrap() }); - // Empty the message queue and get the last message - let mut next_message = 1; - for message in receiver.try_iter() { - next_message = message + 1; - } - pausable_worker.start(&runtime).unwrap(); - assert_eq!(receiver.recv().unwrap(), next_message); - } -} diff --git a/libdd-data-pipeline/src/stats_exporter.rs b/libdd-data-pipeline/src/stats_exporter.rs index e1b7103d16..e23b0b5f4b 100644 --- a/libdd-data-pipeline/src/stats_exporter.rs +++ b/libdd-data-pipeline/src/stats_exporter.rs @@ -11,14 +11,13 @@ use std::{ }; use crate::trace_exporter::TracerMetadata; +use async_trait::async_trait; use libdd_capabilities::{HttpClientTrait, MaybeSend}; -use libdd_common::{worker::Worker, Endpoint}; +use libdd_common::Endpoint; +use libdd_shared_runtime::Worker; use libdd_trace_protobuf::pb; use libdd_trace_stats::span_concentrator::SpanConcentrator; use libdd_trace_utils::send_with_retry::{send_with_retry, RetryStrategy}; -#[cfg(not(target_arch = "wasm32"))] -use tokio::select; -use tokio_util::sync::CancellationToken; use tracing::error; const STATS_ENDPOINT_PATH: &str = "/v0.6/stats"; @@ -34,7 +33,6 @@ pub struct StatsExporter { endpoint: Endpoint, meta: TracerMetadata, sequence_id: AtomicU64, - cancellation_token: CancellationToken, client: H, } @@ -52,7 +50,6 @@ impl StatsExporter { concentrator: Arc>, meta: TracerMetadata, endpoint: Endpoint, - cancellation_token: CancellationToken, client: H, ) -> Self { Self { @@ -61,7 +58,6 @@ impl StatsExporter { endpoint, meta, sequence_id: AtomicU64::new(0), - cancellation_token, client, } } @@ -136,30 +132,21 @@ impl StatsExporter { } } +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl Worker for StatsExporter { - /// Run loop of the stats exporter - /// - /// Once started, the stats exporter will flush and send stats on every `self.flush_interval`. - /// If the `self.cancellation_token` is cancelled, the exporter will force flush all stats and - /// return. + async fn trigger(&mut self) { + tokio::time::sleep(self.flush_interval).await; + } + + /// Flush and send stats on every trigger. async fn run(&mut self) { - #[cfg(target_arch = "wasm32")] - { - return; - } + let _ = self.send(false).await; + } - #[cfg(not(target_arch = "wasm32"))] - loop { - select! { - _ = self.cancellation_token.cancelled() => { - let _ = self.send(true).await; - break; - }, - _ = tokio::time::sleep(self.flush_interval) => { - let _ = self.send(false).await; - }, - }; - } + async fn shutdown(&mut self) { + // Force flush all stats on shutdown + let _ = self.send(true).await; } } @@ -206,6 +193,7 @@ mod tests { use httpmock::prelude::*; use httpmock::MockServer; use libdd_capabilities_impl::NativeCapabilities; + use libdd_shared_runtime::SharedRuntime; use libdd_trace_utils::span::{trace_utils, v04::SpanSlice}; use libdd_trace_utils::test_utils::poll_for_mock_hit; use time::Duration; @@ -283,7 +271,6 @@ mod tests { Arc::new(Mutex::new(get_test_concentrator())), get_test_metadata(), Endpoint::from_url(stats_url_from_agent_url(&server.url("/")).unwrap()), - CancellationToken::new(), NativeCapabilities::new_client(), ); @@ -311,7 +298,6 @@ mod tests { Arc::new(Mutex::new(get_test_concentrator())), get_test_metadata(), Endpoint::from_url(stats_url_from_agent_url(&server.url("/")).unwrap()), - CancellationToken::new(), NativeCapabilities::new_client(), ); @@ -325,80 +311,81 @@ mod tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn test_run() { - let server = MockServer::start_async().await; - - let mut mock = server - .mock_async(|when, then| { - when.method(POST) - .header("Content-type", "application/msgpack") - .path("/v0.6/stats") - .body_includes("libdatadog-test") - .body_includes("key1:value1,key2:value2"); - then.status(200).body(""); - }) - .await; + #[test] + fn test_run() { + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + + let server = MockServer::start(); + + let mut mock = server.mock(|when, then| { + when.method(POST) + .header("Content-type", "application/msgpack") + .path("/v0.6/stats") + .body_includes("libdatadog-test") + .body_includes("key1:value1,key2:value2"); + then.status(200).body(""); + }); - let mut stats_exporter = StatsExporter::::new( - BUCKETS_DURATION, + let stats_exporter = StatsExporter::::new( + // Use smaller buckets duration to speed up test + Duration::from_secs(1), Arc::new(Mutex::new(get_test_concentrator())), get_test_metadata(), Endpoint::from_url(stats_url_from_agent_url(&server.url("/")).unwrap()), - CancellationToken::new(), NativeCapabilities::new_client(), ); - tokio::time::pause(); - tokio::spawn(async move { - stats_exporter.run().await; - }); - // Wait for the stats to be flushed - tokio::time::sleep(BUCKETS_DURATION + Duration::from_secs(1)).await; - // Resume time to sleep while the stats are being sent - tokio::time::resume(); + let _handle = shared_runtime + .spawn_worker(stats_exporter) + .expect("Failed to spawn worker"); + + // Wait for stats to be flushed + std::thread::sleep(Duration::from_secs(1)); + assert!( - poll_for_mock_hit(&mut mock, 10, 100, 1, false).await, + shared_runtime + .block_on(poll_for_mock_hit(&mut mock, 10, 100, 1, false)) + .expect("Failed to use runtime"), "Expected max retry attempts" ); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn test_cancellation_token() { - let server = MockServer::start_async().await; - - let mut mock = server - .mock_async(|when, then| { - when.method(POST) - .header("Content-type", "application/msgpack") - .path("/v0.6/stats") - .body_includes("libdatadog-test") - .body_includes("key1:value1,key2:value2"); - then.status(200).body(""); - }) - .await; + #[test] + fn test_worker_shutdown() { + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + + let server = MockServer::start(); + + let mut mock = server.mock(|when, then| { + when.method(POST) + .header("Content-type", "application/msgpack") + .path("/v0.6/stats") + .body_includes("libdatadog-test") + .body_includes("key1:value1,key2:value2"); + then.status(200).body(""); + }); let buckets_duration = Duration::from_secs(10); - let cancellation_token = CancellationToken::new(); - let mut stats_exporter = StatsExporter::::new( + let stats_exporter = StatsExporter::::new( buckets_duration, Arc::new(Mutex::new(get_test_concentrator())), get_test_metadata(), Endpoint::from_url(stats_url_from_agent_url(&server.url("/")).unwrap()), - cancellation_token.clone(), NativeCapabilities::new_client(), ); - tokio::spawn(async move { - stats_exporter.run().await; - }); - // Cancel token to trigger force flush - cancellation_token.cancel(); + let _handle = shared_runtime + .spawn_worker(stats_exporter) + .expect("Failed to spawn worker"); + + shared_runtime.shutdown(None).unwrap(); assert!( - poll_for_mock_hit(&mut mock, 10, 100, 1, false).await, + shared_runtime + .block_on(poll_for_mock_hit(&mut mock, 10, 100, 1, false)) + .expect("Failed to get runtime"), "Expected max retry attempts" ); } diff --git a/libdd-data-pipeline/src/telemetry/mod.rs b/libdd-data-pipeline/src/telemetry/mod.rs index 3d9f628291..86b7a302a7 100644 --- a/libdd-data-pipeline/src/telemetry/mod.rs +++ b/libdd-data-pipeline/src/telemetry/mod.rs @@ -16,7 +16,6 @@ use libdd_trace_utils::{ trace_utils::SendDataResult, }; use std::{collections::HashMap, time::Duration}; -use tokio::runtime::Handle; /// Structure to build a Telemetry client. /// @@ -100,7 +99,7 @@ impl TelemetryClientBuilder { } /// Builds the telemetry client. - pub fn build(self, runtime: Handle) -> (TelemetryClient, TelemetryWorker) { + pub fn build(self) -> (TelemetryClient, TelemetryWorker) { #[allow(clippy::unwrap_used)] let mut builder = TelemetryWorkerBuilder::new_fetch_host( self.service_name.unwrap(), @@ -118,7 +117,7 @@ impl TelemetryClientBuilder { builder.runtime_id = Some(id); } - let (worker_handle, worker) = builder.build_worker(runtime); + let (worker_handle, worker) = builder.build_worker(None); ( TelemetryClient { @@ -302,44 +301,38 @@ impl TelemetryClient { .send_msg(TelemetryActions::Lifecycle(LifecycleAction::Start)) .await; } - - /// Shutdowns the telemetry client. - pub async fn shutdown(self) { - _ = self - .worker - .send_msg(TelemetryActions::Lifecycle(LifecycleAction::Stop)) - .await; - } } #[cfg(test)] mod tests { + use super::*; use bytes::Bytes; use httpmock::Method::POST; use httpmock::MockServer; use libdd_capabilities::HttpError; - use libdd_common::worker::Worker; + use libdd_shared_runtime::{SharedRuntime, WorkerHandle}; + use libdd_trace_utils::test_utils::poll_for_mock_hits; use regex::Regex; use tokio::time::sleep; - use super::*; - - async fn get_test_client(url: &str) -> TelemetryClient { - let (client, mut worker) = TelemetryClientBuilder::default() + fn get_test_client(url: &str, runtime: &SharedRuntime) -> (TelemetryClient, WorkerHandle) { + let (client, worker) = TelemetryClientBuilder::default() .set_service_name("test_service") .set_service_version("test_version") .set_env("test_env") .set_language("test_language") .set_language_version("test_language_version") .set_tracer_version("test_tracer_version") + .set_runtime_id("foo") .set_url(url) .set_heartbeat(100) .set_debug_enabled(true) - .build(Handle::current()); - tokio::spawn(async move { worker.run().await }); - client + .build(); + let handle = runtime + .spawn_worker(worker) + .expect("Failed to spawn worker"); + (client, handle) } - #[test] fn builder_test() { let builder = TelemetryClientBuilder::default() @@ -371,306 +364,313 @@ mod tests { } #[cfg_attr(miri, ignore)] - #[tokio::test(flavor = "multi_thread")] - async fn spawn_test() { - let _ = TelemetryClientBuilder::default() - .set_service_name("test_service") - .set_service_version("test_version") - .set_env("test_env") - .set_language("test_language") - .set_language_version("test_language_version") - .set_tracer_version("test_tracer_version") - .build(Handle::current()); - } - - #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn api_bytes_test() { + #[test] + fn api_bytes_test() { let payload = Regex::new(r#""metric":"trace_api.bytes","tags":\["src_library:libdatadog"\],"sketch_b64":".+","common":true,"interval":\d+,"type":"distribution""#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { bytes_sent: 1, ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn requests_test() { + #[test] + fn requests_test() { let payload = Regex::new(r#""metric":"trace_api.requests","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog"\],"common":true,"type":"count""#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { requests_count: 1, ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn responses_per_code_test() { + #[test] + fn responses_per_code_test() { let payload = Regex::new(r#""metric":"trace_api.responses","points":\[\[\d+,1\.0\]\],"tags":\["status_code:200","src_library:libdatadog"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { responses_count_per_code: HashMap::from([(200, 1)]), ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn errors_timeout_test() { + #[test] + fn errors_timeout_test() { let payload = Regex::new(r#""metric":"trace_api.errors","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","type:timeout"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { errors_timeout: 1, ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn errors_network_test() { + #[test] + fn errors_network_test() { let payload = Regex::new(r#""metric":"trace_api.errors","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","type:network"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { errors_network: 1, ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn errors_status_code_test() { + #[test] + fn errors_status_code_test() { let payload = Regex::new(r#""metric":"trace_api.errors","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","type:status_code"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { errors_status_code: 1, ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn chunks_sent_test() { + #[test] + fn chunks_sent_test() { let payload = Regex::new(r#""metric":"trace_chunks_sent","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { chunks_sent: 1, ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn chunks_dropped_send_failure_test() { + #[test] + fn chunks_dropped_send_failure_test() { let payload = Regex::new(r#""metric":"trace_chunks_dropped","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","reason:send_failure"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { chunks_dropped_send_failure: 1, ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn chunks_dropped_p0_test() { + #[test] + fn chunks_dropped_p0_test() { let payload = Regex::new(r#""metric":"trace_chunks_dropped","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","reason:p0_drop"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { chunks_dropped_p0: 1, ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn chunks_dropped_serialization_error_test() { + #[test] + fn chunks_dropped_serialization_error_test() { let payload = Regex::new(r#""metric":"trace_chunks_dropped","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","reason:serialization_error"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }); let data = SendPayloadTelemetry { chunks_dropped_serialization_error: 1, ..Default::default() }; - - let client = get_test_client(&server.url("/")).await; - - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); + }) + .expect("Failed to get runtime"); } #[test] @@ -772,8 +772,8 @@ mod tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn telemetry_from_build_error_test() { + #[test] + fn telemetry_from_build_error_test() { let result = Err(SendWithRetryError::Build(5)); let telemetry = SendPayloadTelemetry::from_retry_result(&result, 1, 2, 0); assert_eq!( @@ -816,88 +816,66 @@ mod tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn runtime_id_test() { - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_includes(r#""runtime_id":"foo""#); - then.status(200).body(""); - }) - .await; - - let (client, mut worker) = TelemetryClientBuilder::default() - .set_service_name("test_service") - .set_service_version("test_version") - .set_env("test_env") - .set_language("test_language") - .set_language_version("test_language_version") - .set_tracer_version("test_tracer_version") - .set_url(&server.url("/")) - .set_heartbeat(100) - .set_runtime_id("foo") - .build(Handle::current()); - tokio::spawn(async move { worker.run().await }); - - client.start().await; - client - .send(&SendPayloadTelemetry { - requests_count: 1, - ..Default::default() + #[test] + fn runtime_id_test() { + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_includes(r#""runtime_id":"foo""#); + then.status(200).body(""); + }); + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + client + .send(&SendPayloadTelemetry { + requests_count: 1, + ..Default::default() + }) + .unwrap(); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); }) - .unwrap(); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - // One payload generate-metrics - telemetry_srv.assert_calls_async(1).await; + .expect("Failed to get runtime"); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn application_metadata_test() { - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST) - .body_includes(r#""application":{"service_name":"test_service","service_version":"test_version","env":"test_env","language_name":"test_language","language_version":"test_language_version","tracer_version":"test_tracer_version"}"#); - then.status(200).body(""); - }) - .await; - - let (client, mut worker) = TelemetryClientBuilder::default() - .set_service_name("test_service") - .set_service_version("test_version") - .set_env("test_env") - .set_language("test_language") - .set_language_version("test_language_version") - .set_tracer_version("test_tracer_version") - .set_url(&server.url("/")) - .set_heartbeat(100) - .set_runtime_id("foo") - .build(Handle::current()); - tokio::spawn(async move { worker.run().await }); - - client.start().await; - client - .send(&SendPayloadTelemetry { - requests_count: 1, - ..Default::default() + #[test] + fn application_metadata_test() { + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let server = MockServer::start(); + let mut telemetry_srv = server.mock(|when, then| { + when.method(POST).body_includes( + r#""application":{"service_name":"test_service","service_version":"test_version","env":"test_env","language_name":"test_language","language_version":"test_language_version","tracer_version":"test_tracer_version"}"#, + ); + then.status(200).body(""); + }); + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + shared_runtime + .block_on(async { + client.start().await; + client + .send(&SendPayloadTelemetry { + requests_count: 1, + ..Default::default() + }) + .unwrap(); + // Wait for send to be processed + sleep(Duration::from_millis(100)).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hits(&mut telemetry_srv, 1000, 10, 1).await, + "telemetry server did not receive calls within timeout" + ); }) - .unwrap(); - client.shutdown().await; - // Wait for the server to receive at least one call, but don't hang forever. - let start = std::time::Instant::now(); - while telemetry_srv.calls_async().await == 0 { - if start.elapsed() > Duration::from_secs(180) { - panic!("telemetry server did not receive calls within timeout"); - } - sleep(Duration::from_millis(10)).await; - } - // One payload generate-metrics - telemetry_srv.assert_calls_async(1).await; + .expect("Failed to get runtime"); } } diff --git a/libdd-data-pipeline/src/trace_exporter/builder.rs b/libdd-data-pipeline/src/trace_exporter/builder.rs index b98a328c84..29c3d78657 100644 --- a/libdd-data-pipeline/src/trace_exporter/builder.rs +++ b/libdd-data-pipeline/src/trace_exporter/builder.rs @@ -4,8 +4,6 @@ use crate::agent_info::AgentInfoFetcher; use crate::otlp::config::{OtlpProtocol, DEFAULT_OTLP_TIMEOUT}; use crate::otlp::OtlpTraceConfig; -#[cfg(not(target_arch = "wasm32"))] -use crate::pausable_worker::PausableWorker; #[cfg(feature = "telemetry")] use crate::telemetry::TelemetryClientBuilder; use crate::trace_exporter::agent_response::AgentResponsePayloadVersion; @@ -22,7 +20,8 @@ use arc_swap::ArcSwap; use libdd_capabilities::{HttpClientTrait, MaybeSend}; use libdd_common::{parse_uri, tag, Endpoint}; use libdd_dogstatsd_client::new; -use std::sync::{Arc, Mutex}; +use libdd_shared_runtime::SharedRuntime; +use std::sync::Arc; use std::time::Duration; const DEFAULT_AGENT_URL: &str = "http://127.0.0.1:8126"; @@ -55,6 +54,7 @@ pub struct TraceExporterBuilder { peer_tags: Vec, #[cfg(feature = "telemetry")] telemetry: Option, + shared_runtime: Option>, health_metrics_enabled: bool, test_session_token: Option, agent_rates_payload_version_enabled: bool, @@ -216,6 +216,12 @@ impl TraceExporterBuilder { self } + /// Set a shared runtime used by the exporter for background workers. + pub fn set_shared_runtime(&mut self, shared_runtime: Arc) -> &mut Self { + self.shared_runtime = Some(shared_runtime); + self + } + /// Enables health metrics emission. pub fn enable_health_metrics(&mut self) -> &mut Self { self.health_metrics_enabled = true; @@ -268,7 +274,13 @@ impl TraceExporterBuilder { )); } - let runtime = Arc::new(super::build_runtime()?); + let shared_runtime = + self.shared_runtime + .unwrap_or(Arc::new(SharedRuntime::new().map_err(|e| { + TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration( + e.to_string(), + )) + })?)); let dogstatsd = self.dogstatsd_url.and_then(|u| { new(Endpoint::from_slice(&u)).ok() // If we couldn't set the endpoint return @@ -285,14 +297,12 @@ impl TraceExporterBuilder { #[allow(unused_mut)] let mut stats = StatsComputationStatus::Disabled; - let info_endpoint = Endpoint::from_url(add_path(&agent_url, INFO_ENDPOINT)); - let (info_fetcher, info_response_observer) = - AgentInfoFetcher::::new(info_endpoint.clone(), Duration::from_secs(5 * 60)); - #[cfg(not(target_arch = "wasm32"))] { - let mut info_fetcher_worker = PausableWorker::new(info_fetcher); - info_fetcher_worker.start(&runtime).map_err(|e| { + let info_endpoint = Endpoint::from_url(add_path(&agent_url, INFO_ENDPOINT)); + let (info_fetcher, info_response_observer) = + AgentInfoFetcher::::new(info_endpoint.clone(), Duration::from_secs(5 * 60)); + let info_fetcher_handle = shared_runtime.spawn_worker(info_fetcher).map_err(|e| { TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration(e.to_string())) })?; @@ -301,7 +311,7 @@ impl TraceExporterBuilder { } #[cfg(feature = "telemetry")] - let (telemetry_client, telemetry_worker) = { + let (telemetry_client, telemetry_handle) = { let telemetry = self.telemetry.map(|telemetry_config| { let mut builder = TelemetryClientBuilder::default() .set_language(&self.language) @@ -316,20 +326,23 @@ impl TraceExporterBuilder { if let Some(id) = telemetry_config.runtime_id { builder = builder.set_runtime_id(&id); } - builder.build(runtime.handle().clone()) + Ok(builder.build()) }); - match telemetry { - Some((client, worker)) => { - let mut telemetry_worker = PausableWorker::new(worker); - telemetry_worker.start(&runtime).map_err(|e| { + Some(Ok((client, worker))) => { + let handle = shared_runtime.spawn_worker(worker).map_err(|e| { + TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration( + e.to_string(), + )) + })?; + shared_runtime.block_on(client.start()).map_err(|e| { TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration( e.to_string(), )) })?; - runtime.block_on(client.start()); - (Some(client), Some(telemetry_worker)) + (Some(client), Some(handle)) } + Some(Err(e)) => return Err(e), None => (None, None), } }; @@ -362,7 +375,7 @@ impl TraceExporterBuilder { input_format: self.input_format, output_format: self.output_format, client_computed_top_level: self.client_computed_top_level, - runtime: Arc::new(Mutex::new(Some(runtime))), + shared_runtime, dogstatsd, common_stats_tags: vec![libdatadog_version], client_side_stats: ArcSwap::new(stats.into()), @@ -372,12 +385,11 @@ impl TraceExporterBuilder { telemetry: telemetry_client, health_metrics_enabled: self.health_metrics_enabled, client: H::new_client(), - workers: Arc::new(Mutex::new(TraceExporterWorkers { - info: info_fetcher_worker, - stats: None, + workers: TraceExporterWorkers { + info_fetcher: info_fetcher_handle, #[cfg(feature = "telemetry")] - telemetry: telemetry_worker, - })), + telemetry: telemetry_handle, + }, agent_payload_response_version: self .agent_rates_payload_version_enabled .then(AgentResponsePayloadVersion::new), @@ -415,7 +427,9 @@ impl TraceExporterBuilder { #[cfg(target_arch = "wasm32")] { - drop(info_fetcher); + let info_endpoint = Endpoint::from_url(add_path(&agent_url, INFO_ENDPOINT)); + let (_info_fetcher, info_response_observer) = + AgentInfoFetcher::::new(info_endpoint, Duration::from_secs(5 * 60)); Ok(TraceExporter { endpoint: Endpoint { @@ -445,7 +459,7 @@ impl TraceExporterBuilder { input_format: self.input_format, output_format: self.output_format, client_computed_top_level: self.client_computed_top_level, - runtime: Arc::new(Mutex::new(Some(runtime))), + shared_runtime, dogstatsd, common_stats_tags: vec![libdatadog_version], client_side_stats: ArcSwap::new(stats.into()), @@ -571,6 +585,18 @@ mod tests { assert!(exporter.telemetry.is_none()); } + #[cfg_attr(miri, ignore)] + #[test] + fn test_set_shared_runtime() { + let mut builder = TraceExporterBuilder::default(); + let shared_runtime = Arc::new(SharedRuntime::new().unwrap()); + builder.set_shared_runtime(shared_runtime.clone()); + + let exporter = builder.build::().unwrap(); + + assert!(Arc::ptr_eq(&exporter.shared_runtime, &shared_runtime)); + } + #[test] #[cfg_attr(miri, ignore)] fn test_builder_error() { diff --git a/libdd-data-pipeline/src/trace_exporter/mod.rs b/libdd-data-pipeline/src/trace_exporter/mod.rs index b99ad3d954..180193f507 100644 --- a/libdd-data-pipeline/src/trace_exporter/mod.rs +++ b/libdd-data-pipeline/src/trace_exporter/mod.rs @@ -14,21 +14,16 @@ use self::agent_response::AgentResponse; use self::metrics::MetricsEmitter; use self::stats::StatsComputationStatus; use self::trace_serializer::TraceSerializer; -#[cfg(not(target_arch = "wasm32"))] -use crate::agent_info::AgentInfoFetcher; use crate::agent_info::ResponseObserver; use crate::otlp::{map_traces_to_otlp, send_otlp_traces_http, OtlpResourceInfo, OtlpTraceConfig}; -#[cfg(not(target_arch = "wasm32"))] -use crate::pausable_worker::PausableWorker; -#[cfg(not(target_arch = "wasm32"))] -use crate::stats_exporter::StatsExporter; #[cfg(feature = "telemetry")] use crate::telemetry::{SendPayloadTelemetry, TelemetryClient}; use crate::trace_exporter::agent_response::{ AgentResponsePayloadVersion, DATADOG_RATES_PAYLOAD_VERSION, }; -use crate::trace_exporter::error::InternalErrorKind; -use crate::trace_exporter::error::{RequestError, TraceExporterError}; +use crate::trace_exporter::error::{ + InternalErrorKind, RequestError, ShutdownError, TraceExporterError, +}; use crate::{ agent_info::{self, schema::AgentInfo}, health_metrics, @@ -41,10 +36,9 @@ use http::uri::PathAndQuery; use http::Uri; use libdd_capabilities::{HttpClientTrait, MaybeSend}; use libdd_common::tag::Tag; -use libdd_common::{Endpoint, MutexExt}; +use libdd_common::Endpoint; use libdd_dogstatsd_client::Client; -#[cfg(all(not(target_arch = "wasm32"), feature = "telemetry"))] -use libdd_telemetry::worker::TelemetryWorker; +use libdd_shared_runtime::{SharedRuntime, WorkerHandle}; use libdd_trace_utils::msgpack_decoder; use libdd_trace_utils::send_with_retry::{ send_with_retry, RetryStrategy, SendWithRetryError, SendWithRetryResult, @@ -52,10 +46,10 @@ use libdd_trace_utils::send_with_retry::{ use libdd_trace_utils::span::{v04::Span, TraceData}; use libdd_trace_utils::trace_utils::TracerHeaderTags; use std::io; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Duration; use std::{borrow::Borrow, str::FromStr}; -use tokio::runtime::Runtime; +use tokio::task::JoinSet; use tracing::{debug, error, warn}; const INFO_ENDPOINT: &str = "/info"; @@ -124,20 +118,6 @@ fn add_path(url: &Uri, path: &str) -> Uri { Uri::from_parts(parts).unwrap() } -pub(crate) fn build_runtime() -> io::Result { - #[cfg(not(target_arch = "wasm32"))] - { - tokio::runtime::Builder::new_multi_thread() - .worker_threads(1) - .enable_all() - .build() - } - #[cfg(target_arch = "wasm32")] - { - tokio::runtime::Builder::new_current_thread().build() - } -} - #[derive(Clone, Default, Debug)] pub struct TracerMetadata { pub hostname: String, @@ -177,16 +157,13 @@ impl<'a> From<&'a TracerMetadata> for HeaderMap { } } +/// Handles for the background workers owned by a [`TraceExporter`]. #[cfg(not(target_arch = "wasm32"))] -/// Background workers managed by a [`TraceExporter`]. -/// -/// `H` is the HTTP client implementation, see [`HttpClientTrait`]. #[derive(Debug)] -pub(crate) struct TraceExporterWorkers { - pub info: PausableWorker>, - pub stats: Option>>, +pub(crate) struct TraceExporterWorkers { + info_fetcher: WorkerHandle, #[cfg(feature = "telemetry")] - pub telemetry: Option>, + telemetry: Option, } /// The TraceExporter ingest traces from the tracers serialized as messagepack and forward them to @@ -230,7 +207,7 @@ pub struct TraceExporter { metadata: TracerMetadata, input_format: TraceExporterInputFormat, output_format: TraceExporterOutputFormat, - runtime: Arc>>>, + shared_runtime: Arc, /// None if dogstatsd is disabled dogstatsd: Option, common_stats_tags: Vec, @@ -244,7 +221,7 @@ pub struct TraceExporter { health_metrics_enabled: bool, client: H, #[cfg(not(target_arch = "wasm32"))] - workers: Arc>>, + workers: TraceExporterWorkers, agent_payload_response_version: Option, /// When set, traces are exported via OTLP HTTP/JSON instead of the Datadog agent. otlp_config: Option, @@ -256,122 +233,65 @@ impl TraceExporter { TraceExporterBuilder::default() } - /// Return the existing runtime or create a new one and start all workers - fn runtime(&self) -> Result, TraceExporterError> { - let mut runtime_guard = self.runtime.lock_or_panic(); - match runtime_guard.as_ref() { - Some(runtime) => { - // Runtime already running - Ok(runtime.clone()) - } - None => { - let runtime = Arc::new(build_runtime()?); - *runtime_guard = Some(runtime.clone()); - #[cfg(not(target_arch = "wasm32"))] - self.start_all_workers(&runtime)?; - Ok(runtime) + /// Stop the background workers owned by this exporter. + /// + /// Only the workers spawned for this exporter are stopped. Workers from other components + /// sharing the same [`SharedRuntime`] are unaffected. + /// + /// # Errors + /// Returns [`SharedRuntimeError::ShutdownTimedOut`] if a timeout was given and elapsed before + /// all workers finished. + pub fn shutdown(self, timeout: Option) -> Result<(), TraceExporterError> { + let runtime = self.shared_runtime.clone(); + if let Some(timeout) = timeout { + match runtime + .block_on(async { tokio::time::timeout(timeout, self.shutdown_workers()).await }) + .map_err(TraceExporterError::Io)? + { + Ok(()) => Ok(()), + Err(_) => Err(TraceExporterError::Shutdown(ShutdownError::TimedOut( + timeout, + ))), } + } else { + runtime + .block_on(self.shutdown_workers()) + .map_err(TraceExporterError::Io)?; + Ok(()) } } - /// Manually start all workers - #[cfg(not(target_arch = "wasm32"))] - pub fn run_worker(&self) -> Result<(), TraceExporterError> { - self.runtime()?; - Ok(()) - } - - #[cfg(not(target_arch = "wasm32"))] - /// Start all workers with the given runtime - fn start_all_workers(&self, runtime: &Arc) -> Result<(), TraceExporterError> { - let mut workers = self.workers.lock_or_panic(); - - self.start_info_worker(&mut workers, runtime)?; - self.start_stats_worker(&mut workers, runtime)?; - self.start_telemetry_worker(&mut workers, runtime)?; - - Ok(()) - } + async fn shutdown_workers(self) { + #[cfg(not(target_arch = "wasm32"))] + { + let mut join_set = JoinSet::new(); - #[cfg(not(target_arch = "wasm32"))] - /// Start the info worker - fn start_info_worker( - &self, - workers: &mut TraceExporterWorkers, - runtime: &Arc, - ) -> Result<(), TraceExporterError> { - workers.info.start(runtime).map_err(|e| { - TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) - }) - } + // Extract the stats handle before moving other fields. + if let StatsComputationStatus::Enabled { worker_handle, .. } = + &**self.client_side_stats.load() + { + let handle = worker_handle.clone(); + join_set.spawn(async move { handle.stop().await }); + } - #[cfg(not(target_arch = "wasm32"))] - /// Start the stats worker if present - fn start_stats_worker( - &self, - workers: &mut TraceExporterWorkers, - runtime: &Arc, - ) -> Result<(), TraceExporterError> { - if let Some(stats_worker) = &mut workers.stats { - stats_worker.start(runtime).map_err(|e| { - TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) - })?; - } - Ok(()) - } + let info_fetcher = self.workers.info_fetcher; + join_set.spawn(async move { info_fetcher.stop().await }); - #[cfg(all(not(target_arch = "wasm32"), feature = "telemetry"))] - fn start_telemetry_worker( - &self, - workers: &mut TraceExporterWorkers, - runtime: &Arc, - ) -> Result<(), TraceExporterError> { - if let Some(telemetry_worker) = &mut workers.telemetry { - telemetry_worker.start(runtime).map_err(|e| { - TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) - })?; - if let Some(client) = &self.telemetry { - runtime.block_on(client.start()); + #[cfg(feature = "telemetry")] + if let Some(telemetry) = self.workers.telemetry { + join_set.spawn(async move { telemetry.stop().await }); } - } - Ok(()) - } - #[cfg(all(not(target_arch = "wasm32"), not(feature = "telemetry")))] - fn start_telemetry_worker( - &self, - _workers: &mut TraceExporterWorkers, - _runtime: &Arc, - ) -> Result<(), TraceExporterError> { - Ok(()) - } - - #[cfg(not(target_arch = "wasm32"))] - pub fn stop_worker(&self) { - let runtime = self.runtime.lock_or_panic().take(); - if let Some(ref rt) = runtime { - let mut workers = self.workers.lock_or_panic(); - rt.block_on(async { - let _ = workers.info.pause().await; - if let Some(stats_worker) = &mut workers.stats { - let _ = stats_worker.pause().await; - }; - #[cfg(feature = "telemetry")] - if let Some(telemetry_worker) = &mut workers.telemetry { - let _ = telemetry_worker.pause().await; - }; - }); - } - if let PausableWorker::Paused { worker } = &mut self.workers.lock_or_panic().info { - self.info_response_observer.manual_trigger(); - worker.drain(); + while let Some(result) = join_set.join_next().await { + if let Ok(Err(e)) = result { + error!("Worker failed to shutdown: {:?}", e); + } + } } - drop(runtime); - } - #[cfg(target_arch = "wasm32")] - pub fn stop_worker(&self) { - let _ = self.runtime.lock_or_panic().take(); + // On wasm32 workers are no-ops, nothing to stop. + #[cfg(target_arch = "wasm32")] + let _ = self; } /// Send msgpack serialized traces to the agent @@ -397,7 +317,8 @@ impl TraceExporter { Ok(res) } - /// Async version of [`Self::send`] for platforms that cannot use `block_on` (e.g. wasm). + /// **WARNING**: This method is experimental and should not be used for production. + /// Async version of [`Self::send`] for platforms that cannot use `block_on` (e.g. wasm) pub async fn send_async(&self, data: &[u8]) -> Result { self.check_agent_info(); @@ -433,58 +354,6 @@ impl TraceExporter { Ok(res) } - /// Safely shutdown the TraceExporter and all related tasks - #[cfg(not(target_arch = "wasm32"))] - pub fn shutdown(mut self, timeout: Option) -> Result<(), TraceExporterError> { - let mut builder = tokio::runtime::Builder::new_current_thread(); - builder.enable_all(); - let runtime = builder.build()?; - - if let Some(timeout) = timeout { - return match runtime - .block_on(async { tokio::time::timeout(timeout, self.shutdown_async()).await }) - { - Ok(()) => Ok(()), - Err(_e) => Err(TraceExporterError::Shutdown( - error::ShutdownError::TimedOut(timeout), - )), - }; - } - - runtime.block_on(self.shutdown_async()); - Ok(()) - } - - #[cfg(not(target_arch = "wasm32"))] - /// Future used inside `Self::shutdown`. - /// - /// This function should not take ownership of the trace exporter as it will cause the runtime - /// stored in the trace exporter to be dropped in a non-blocking context causing a panic. - async fn shutdown_async(&mut self) { - let stats_status = self.client_side_stats.load(); - if let StatsComputationStatus::Enabled { - cancellation_token, .. - } = stats_status.as_ref() - { - cancellation_token.cancel(); - - let stats_worker = self.workers.lock_or_panic().stats.take(); - - if let Some(stats_worker) = stats_worker { - let _ = stats_worker.join().await; - } - } - #[cfg(feature = "telemetry")] - if let Some(telemetry) = self.telemetry.take() { - telemetry.shutdown().await; - let telemetry_worker = self.workers.lock_or_panic().telemetry.take(); - - if let Some(telemetry_worker) = telemetry_worker { - let _ = telemetry_worker.join().await; - } - } - } - #[cfg(not(target_arch = "wasm32"))] /// Check if agent info state has changed fn has_agent_info_state_changed(&self, agent_info: &Arc) -> bool { @@ -506,13 +375,12 @@ impl TraceExporter { let ctx = stats::StatsContext { metadata: &self.metadata, endpoint_url: &self.endpoint.url, - runtime: &self.runtime, + shared_runtime: &self.shared_runtime, }; stats::handle_stats_disabled_by_agent( &ctx, &agent_info, &self.client_side_stats, - &self.workers, self.client.clone(), ); } @@ -522,14 +390,13 @@ impl TraceExporter { let ctx = stats::StatsContext { metadata: &self.metadata, endpoint_url: &self.endpoint.url, - runtime: &self.runtime, + shared_runtime: &self.shared_runtime, }; stats::handle_stats_enabled( &ctx, &agent_info, stats_concentrator, &self.client_side_stats, - &self.workers, ); } } @@ -603,8 +470,8 @@ impl TraceExporter { trace_chunks: Vec>>, ) -> Result { self.check_agent_info(); - self.runtime()? - .block_on(async { self.send_trace_chunks_inner(trace_chunks).await }) + self.shared_runtime + .block_on(async { self.send_trace_chunks_inner(trace_chunks).await })? } /// Send a list of trace chunks to the agent, asynchronously (or OTLP when configured). @@ -681,8 +548,8 @@ impl TraceExporter { None, ); - self.runtime()? - .block_on(async { self.send_trace_chunks_inner(traces).await }) + self.shared_runtime + .block_on(async { self.send_trace_chunks_inner(traces).await })? } /// Send traces payload to agent with retry and telemetry reporting @@ -963,7 +830,7 @@ impl TraceExporter { #[cfg(not(target_arch = "wasm32"))] /// Test only function to check if the stats computation is active and the worker is running pub fn is_stats_worker_active(&self) -> bool { - stats::is_stats_worker_active(&self.client_side_stats, &self.workers) + stats::is_stats_worker_active(&self.client_side_stats) } } @@ -993,8 +860,6 @@ mod tests { use libdd_trace_utils::span::v04::SpanBytes; use libdd_trace_utils::span::v05; use std::net; - use std::time::Duration; - use tokio::time::sleep; // v05 messagepack empty payload -> [[""], []] const V5_EMPTY: [u8; 4] = [0x92, 0x91, 0xA0, 0x90]; @@ -1637,15 +1502,7 @@ mod tests { traces_endpoint.assert_calls(1); while metrics_endpoint.calls() == 0 { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + std::thread::sleep(Duration::from_millis(100)); } metrics_endpoint.assert_calls(1); } @@ -1695,15 +1552,7 @@ mod tests { traces_endpoint.assert_calls(1); while metrics_endpoint.calls() == 0 { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + std::thread::sleep(Duration::from_millis(100)); } metrics_endpoint.assert_calls(1); } @@ -1764,15 +1613,7 @@ mod tests { traces_endpoint.assert_calls(1); while metrics_endpoint.calls() == 0 { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + std::thread::sleep(Duration::from_millis(100)); } metrics_endpoint.assert_calls(1); } @@ -1947,21 +1788,11 @@ mod tests { // Wait for the info fetcher to get the config while mock_info.calls() == 0 { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + std::thread::sleep(Duration::from_millis(100)); } let _ = exporter.send(data.as_ref()).unwrap(); - exporter.shutdown(None).unwrap(); - mock_traces.assert(); } @@ -2031,15 +1862,6 @@ mod tests { ); mock_otlp.assert(); } - - #[test] - #[cfg_attr(miri, ignore)] - fn stop_and_start_runtime() { - let builder = TraceExporter::::builder(); - let exporter = builder.build::().unwrap(); - exporter.stop_worker(); - exporter.run_worker().unwrap(); - } } #[cfg(test)] @@ -2050,8 +1872,6 @@ mod single_threaded_tests { use libdd_capabilities_impl::NativeCapabilities; use libdd_trace_utils::msgpack_encoder; use libdd_trace_utils::span::v04::SpanBytes; - use std::time::Duration; - use tokio::time::sleep; #[cfg_attr(miri, ignore)] #[test] @@ -2083,6 +1903,8 @@ mod single_threaded_tests { .body(r#"{"version":"1","client_drop_p0s":true,"endpoints":["/v0.4/traces","/v0.6/stats"]}"#); }); + let runtime = Arc::new(SharedRuntime::new().unwrap()); + let mut builder = TraceExporter::::builder(); builder .set_url(&server.url("/")) @@ -2094,6 +1916,7 @@ mod single_threaded_tests { .set_language_interpreter("v8") .set_input_format(TraceExporterInputFormat::V04) .set_output_format(TraceExporterOutputFormat::V04) + .set_shared_runtime(runtime.clone()) .enable_stats(Duration::from_secs(10)); let exporter = builder.build::().unwrap(); @@ -2106,15 +1929,7 @@ mod single_threaded_tests { // Wait for the info fetcher to get the config while agent_info::get_agent_info().is_none() { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + std::thread::sleep(Duration::from_millis(100)); } let result = exporter.send(data.as_ref()); @@ -2131,7 +1946,7 @@ mod single_threaded_tests { std::thread::sleep(Duration::from_millis(10)); } - exporter.shutdown(None).unwrap(); + runtime.shutdown(None).unwrap(); // Wait for the mock server to process the stats for _ in 0..1000 { @@ -2183,6 +1998,8 @@ mod single_threaded_tests { .body(r#"{"version":"1","client_drop_p0s":true,"endpoints":["/v0.4/traces","/v0.6/stats"]}"#); }); + let runtime = Arc::new(SharedRuntime::new().unwrap()); + let mut builder = TraceExporter::::builder(); builder .set_url(&server.url("/")) @@ -2194,6 +2011,7 @@ mod single_threaded_tests { .set_language_interpreter("v8") .set_input_format(TraceExporterInputFormat::V04) .set_output_format(TraceExporterOutputFormat::V04) + .set_shared_runtime(runtime.clone()) .enable_stats(Duration::from_secs(10)); let exporter = builder.build::().unwrap(); @@ -2211,15 +2029,7 @@ mod single_threaded_tests { // Wait for agent_info to be present so that sending a trace will trigger the stats worker // to start while agent_info::get_agent_info().is_none() { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + std::thread::sleep(Duration::from_millis(100)); } exporter.send(data.as_ref()).unwrap(); @@ -2234,7 +2044,7 @@ mod single_threaded_tests { std::thread::sleep(Duration::from_millis(10)); } - exporter + runtime .shutdown(Some(Duration::from_millis(5))) .unwrap_err(); // The shutdown should timeout diff --git a/libdd-data-pipeline/src/trace_exporter/stats.rs b/libdd-data-pipeline/src/trace_exporter/stats.rs index 4a3835a40b..7014b0e7a4 100644 --- a/libdd-data-pipeline/src/trace_exporter/stats.rs +++ b/libdd-data-pipeline/src/trace_exporter/stats.rs @@ -16,13 +16,11 @@ use libdd_capabilities::{HttpClientTrait, MaybeSend}; #[cfg(not(target_arch = "wasm32"))] use libdd_common::Endpoint; use libdd_common::MutexExt; +use libdd_shared_runtime::{SharedRuntime, WorkerHandle}; use libdd_trace_stats::span_concentrator::SpanConcentrator; use std::sync::{Arc, Mutex}; use std::time::Duration; #[cfg(not(target_arch = "wasm32"))] -use tokio::runtime::Runtime; -use tokio_util::sync::CancellationToken; -#[cfg(not(target_arch = "wasm32"))] use tracing::{debug, error}; #[cfg(not(target_arch = "wasm32"))] @@ -39,7 +37,7 @@ pub(crate) const STATS_ENDPOINT: &str = "/v0.6/stats"; pub(crate) struct StatsContext<'a> { pub metadata: &'a super::TracerMetadata, pub endpoint_url: &'a http::Uri, - pub runtime: &'a Arc>>>, + pub shared_runtime: &'a SharedRuntime, } #[derive(Debug)] @@ -54,7 +52,7 @@ pub(crate) enum StatsComputationStatus { /// Client-side stats is enabled Enabled { stats_concentrator: Arc>, - cancellation_token: CancellationToken, + worker_handle: WorkerHandle, }, } @@ -75,7 +73,6 @@ fn get_span_kinds_for_stats(agent_info: &Arc) -> Vec { pub(crate) fn start_stats_computation( ctx: &StatsContext, client_side_stats: &ArcSwap, - workers: &Arc>>, span_kinds: Vec, peer_tags: Vec, client: H, @@ -87,13 +84,10 @@ pub(crate) fn start_stats_computation>, - cancellation_token: &CancellationToken, - workers: &Arc>>, client_side_stats: &ArcSwap, client: H, ) -> anyhow::Result<()> { @@ -117,26 +109,17 @@ fn create_and_start_stats_worker( +pub(crate) fn stop_stats_computation( ctx: &StatsContext, client_side_stats: &ArcSwap, - workers: &Arc>>, ) { if let StatsComputationStatus::Enabled { stats_concentrator, - cancellation_token, + worker_handle, } = &**client_side_stats.load() { - let runtime_guard = ctx.runtime.lock_or_panic(); - if let Some(rt) = runtime_guard.as_ref() { - rt.block_on(async { - cancellation_token.cancel(); - }); - workers.lock_or_panic().stats = None; - let bucket_size = stats_concentrator.lock_or_panic().get_bucket_size(); - - client_side_stats.store(Arc::new(StatsComputationStatus::DisabledByAgent { - bucket_size, - })); + let bucket_size = stats_concentrator.lock_or_panic().get_bucket_size(); + client_side_stats.store(Arc::new(StatsComputationStatus::DisabledByAgent { + bucket_size, + })); + match ctx.shared_runtime.block_on(worker_handle.clone().stop()) { + Ok(Err(e)) => error!("Failed to stop stats worker: {e}"), + Err(e) => error!("Failed to stop stats worker: {e}"), + _ => {} } } } @@ -177,14 +156,12 @@ pub(crate) fn handle_stats_disabled_by_agent, client_side_stats: &ArcSwap, - workers: &Arc>>, client: H, ) { if agent_info.info.client_drop_p0s.is_some_and(|v| v) { let status = start_stats_computation( ctx, client_side_stats, - workers, get_span_kinds_for_stats(agent_info), agent_info.info.peer_tags.clone().unwrap_or_default(), client, @@ -200,19 +177,18 @@ pub(crate) fn handle_stats_disabled_by_agent( +pub(crate) fn handle_stats_enabled( ctx: &StatsContext, agent_info: &Arc, stats_concentrator: &Mutex, client_side_stats: &ArcSwap, - workers: &Arc>>, ) { if agent_info.info.client_drop_p0s.is_some_and(|v| v) { let mut concentrator = stats_concentrator.lock_or_panic(); concentrator.set_span_kinds(get_span_kinds_for_stats(agent_info)); concentrator.set_peer_tags(agent_info.info.peer_tags.clone().unwrap_or_default()); } else { - stop_stats_computation(ctx, client_side_stats, workers); + stop_stats_computation(ctx, client_side_stats); debug!("Client-side stats computation has been disabled by the agent") } } @@ -273,25 +249,9 @@ pub(crate) fn process_traces_for_stats( #[cfg(test)] #[cfg(not(target_arch = "wasm32"))] /// Test only function to check if the stats computation is active and the worker is running -pub(crate) fn is_stats_worker_active( - client_side_stats: &ArcSwap, - workers: &Arc>>, -) -> bool { - if !matches!( +pub(crate) fn is_stats_worker_active(client_side_stats: &ArcSwap) -> bool { + matches!( **client_side_stats.load(), StatsComputationStatus::Enabled { .. } - ) { - return false; - } - - if let Ok(workers) = workers.try_lock() { - if let Some(stats_worker) = &workers.stats { - return matches!( - stats_worker, - crate::pausable_worker::PausableWorker::Running { .. } - ); - } - } - - false + ) } diff --git a/libdd-data-pipeline/tests/test_fetch_info.rs b/libdd-data-pipeline/tests/test_fetch_info.rs index af5125741b..81ec7acea8 100644 --- a/libdd-data-pipeline/tests/test_fetch_info.rs +++ b/libdd-data-pipeline/tests/test_fetch_info.rs @@ -4,9 +4,10 @@ #[cfg(test)] mod tracing_integration_tests { use libdd_capabilities_impl::NativeCapabilities; - use libdd_common::{worker::Worker, Endpoint}; + use libdd_common::Endpoint; use libdd_data_pipeline::agent_info; use libdd_data_pipeline::agent_info::{fetch_info, AgentInfoFetcher}; + use libdd_shared_runtime::Worker; use libdd_trace_utils::test_utils::datadog_test_agent::DatadogTestAgent; use std::time::Duration; diff --git a/libdd-shared-runtime-ffi/Cargo.toml b/libdd-shared-runtime-ffi/Cargo.toml new file mode 100644 index 0000000000..d965ef9240 --- /dev/null +++ b/libdd-shared-runtime-ffi/Cargo.toml @@ -0,0 +1,26 @@ +# Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "libdd-shared-runtime-ffi" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +publish = false + +[lib] +crate-type = ["lib", "staticlib", "cdylib"] +bench = false + +[features] +default = ["cbindgen", "catch_panic"] +catch_panic = [] +cbindgen = ["build_common/cbindgen"] + +[build-dependencies] +build_common = { path = "../build-common" } + +[dependencies] +libdd-shared-runtime = { version = "1.0.0", path = "../libdd-shared-runtime" } +tracing = { version = "0.1", default-features = false } diff --git a/libdd-shared-runtime-ffi/build.rs b/libdd-shared-runtime-ffi/build.rs new file mode 100644 index 0000000000..5cefa15d31 --- /dev/null +++ b/libdd-shared-runtime-ffi/build.rs @@ -0,0 +1,11 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +extern crate build_common; + +use build_common::generate_and_configure_header; + +fn main() { + let header_name = "shared-runtime.h"; + generate_and_configure_header(header_name); +} diff --git a/libdd-shared-runtime-ffi/cbindgen.toml b/libdd-shared-runtime-ffi/cbindgen.toml new file mode 100644 index 0000000000..f294158074 --- /dev/null +++ b/libdd-shared-runtime-ffi/cbindgen.toml @@ -0,0 +1,28 @@ +# Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +# SPDX-License-Identifier: Apache-2.0 + +language = "C" +cpp_compat = true +tab_width = 2 +header = """// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 +""" +include_guard = "DDOG_SHARED_RUNTIME_H" + +[export] +prefix = "ddog_" +renaming_overrides_prefixing = true + +[export.mangle] +rename_types = "PascalCase" + +[enum] +prefix_with_name = true +rename_variants = "ScreamingSnakeCase" + +[fn] +must_use = "DDOG_CHECK_RETURN" + +[parse] +parse_deps = true +include = ["libdd-shared-runtime"] diff --git a/libdd-shared-runtime-ffi/src/lib.rs b/libdd-shared-runtime-ffi/src/lib.rs new file mode 100644 index 0000000000..3387a8d03c --- /dev/null +++ b/libdd-shared-runtime-ffi/src/lib.rs @@ -0,0 +1,37 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 +#![cfg_attr(not(test), deny(clippy::panic))] +#![cfg_attr(not(test), deny(clippy::unwrap_used))] +#![cfg_attr(not(test), deny(clippy::expect_used))] +#![cfg_attr(not(test), deny(clippy::todo))] +#![cfg_attr(not(test), deny(clippy::unimplemented))] + +mod shared_runtime; + +#[cfg(all(feature = "catch_panic", panic = "unwind"))] +macro_rules! catch_panic { + ($f:expr, $err:expr) => { + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| $f)) { + Ok(ret) => ret, + Err(info) => { + if let Some(s) = info.downcast_ref::() { + tracing::error!("panic: {}", s); + } else if let Some(s) = info.downcast_ref::<&str>() { + tracing::error!("panic: {}", s); + } else { + tracing::error!("panic: unable to retrieve panic context"); + } + $err + } + } + }; +} + +#[cfg(any(not(feature = "catch_panic"), panic = "abort"))] +macro_rules! catch_panic { + ($f:expr, $err:expr) => { + $f + }; +} + +pub(crate) use catch_panic; diff --git a/libdd-shared-runtime-ffi/src/shared_runtime.rs b/libdd-shared-runtime-ffi/src/shared_runtime.rs new file mode 100644 index 0000000000..7a9a0ac084 --- /dev/null +++ b/libdd-shared-runtime-ffi/src/shared_runtime.rs @@ -0,0 +1,336 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +use crate::catch_panic; +use libdd_shared_runtime::{SharedRuntime, SharedRuntimeError}; +use std::ffi::{c_char, CString}; +use std::ptr::NonNull; +use std::sync::Arc; + +/// Error codes for SharedRuntime FFI operations. +#[repr(C)] +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum SharedRuntimeErrorCode { + /// Invalid argument provided (e.g. null handle). + InvalidArgument, + /// The runtime is not available or in an invalid state. + RuntimeUnavailable, + /// Failed to acquire a lock on internal state. + LockFailed, + /// A worker operation failed. + WorkerError, + /// Failed to create the tokio runtime. + RuntimeCreation, + /// Shutdown timed out. + ShutdownTimedOut, + /// An unexpected panic occurred inside the FFI call. + #[cfg(feature = "catch_panic")] + Panic, +} + +/// Error returned by SharedRuntime FFI functions. +#[repr(C)] +pub struct SharedRuntimeFFIError { + pub code: SharedRuntimeErrorCode, + /// The error message is always defined when the error is returned by a ddog_shared_runtime + /// ffi. + pub msg: *mut c_char, +} + +impl SharedRuntimeFFIError { + fn new(code: SharedRuntimeErrorCode, msg: &str) -> Self { + Self { + code, + msg: CString::new(msg).unwrap_or_default().into_raw(), + } + } +} + +impl From for SharedRuntimeFFIError { + fn from(err: SharedRuntimeError) -> Self { + let code = match &err { + SharedRuntimeError::RuntimeUnavailable => SharedRuntimeErrorCode::RuntimeUnavailable, + SharedRuntimeError::LockFailed(_) => SharedRuntimeErrorCode::LockFailed, + SharedRuntimeError::WorkerError(_) => SharedRuntimeErrorCode::WorkerError, + SharedRuntimeError::RuntimeCreation(_) => SharedRuntimeErrorCode::RuntimeCreation, + SharedRuntimeError::ShutdownTimedOut(_) => SharedRuntimeErrorCode::ShutdownTimedOut, + }; + SharedRuntimeFFIError::new(code, &err.to_string()) + } +} + +impl Drop for SharedRuntimeFFIError { + fn drop(&mut self) { + if !self.msg.is_null() { + // SAFETY: `msg` is always produced by `CString::into_raw` in `new`. + unsafe { + drop(CString::from_raw(self.msg)); + self.msg = std::ptr::null_mut(); + } + } + } +} + +macro_rules! panic_error { + () => { + Some(Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::Panic, + "panic", + ))) + }; +} + +/// Frees a `SharedRuntimeFFIError`. After this call the pointer is invalid. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_error_free(error: Option>) { + catch_panic!(drop(error), ()) +} + +/// Create a new `SharedRuntime`. +/// +/// On success writes a raw handle into `*out_handle` and returns `None`. +/// On failure leaves `*out_handle` unchanged and returns an error. +/// +/// The caller owns the handle and must eventually pass it to +/// [`ddog_shared_runtime_free`] (or another consumer that takes ownership). +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_new( + out_handle: NonNull<*const SharedRuntime>, +) -> Option> { + catch_panic!( + match SharedRuntime::new() { + Ok(runtime) => { + out_handle.as_ptr().write(Arc::into_raw(Arc::new(runtime))); + None + } + Err(err) => Some(Box::new(SharedRuntimeFFIError::from(err))), + }, + panic_error!() + ) +} + +/// Free a handle, decrementing the `Arc` strong count. +/// +/// The underlying runtime may not be dropped if other components are still using it. +/// Use [`ddog_shared_runtime_shutdown`] to cleanly stop workers. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_free(handle: *const SharedRuntime) { + catch_panic!( + { + if !handle.is_null() { + // SAFETY: handle was produced by Arc::into_raw; this call takes ownership. + drop(Arc::from_raw(handle)); + } + }, + () + ) +} + +/// Must be called in the parent process before `fork()`. +/// +/// Pauses all workers so that no background threads are running during the +/// fork, preventing deadlocks in the child process. +/// +/// Returns an error if `handle` is null. +/// The handle must have been initialized with `ddog_shared_runtime_new`. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_before_fork( + handle: Option<&SharedRuntime>, +) -> Option> { + catch_panic!( + { + match handle { + Some(runtime) => { + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + runtime.before_fork(); + None + } + None => Some(Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "handle is null", + ))), + } + }, + panic_error!() + ) +} + +/// Must be called in the parent process after `fork()`. +/// +/// Restarts all workers that were paused by [`ddog_shared_runtime_before_fork`]. +/// +/// Returns `None` on success, or an error if workers could not be restarted. +/// The handle must have been initialized with `ddog_shared_runtime_new`. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_after_fork_parent( + handle: Option<&SharedRuntime>, +) -> Option> { + catch_panic!( + { + match handle { + Some(runtime) => { + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + match runtime.after_fork_parent() { + Ok(()) => None, + Err(err) => Some(Box::new(SharedRuntimeFFIError::from(err))), + } + } + None => Some(Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "handle is null", + ))), + } + }, + panic_error!() + ) +} + +/// Must be called in the child process after `fork()`. +/// +/// Creates a fresh tokio runtime and restarts all workers. The original +/// runtime cannot be safely reused after a fork. +/// +/// Returns `None` on success, or an error if the runtime could not be +/// reinitialized. +/// The handle must have been initialized with `ddog_shared_runtime_new`. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_after_fork_child( + handle: Option<&SharedRuntime>, +) -> Option> { + catch_panic!( + { + match handle { + Some(runtime) => { + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + match runtime.after_fork_child() { + Ok(()) => None, + Err(err) => Some(Box::new(SharedRuntimeFFIError::from(err))), + } + } + None => Some(Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "handle is null", + ))), + } + }, + panic_error!() + ) +} + +/// Shut down the `SharedRuntime`, stopping all workers. +/// +/// `timeout_ms` is the maximum time to wait for workers to stop, in +/// milliseconds. Pass `0` for no timeout. +/// +/// Returns `None` on success, or `SharedRuntimeErrorCode::ShutdownTimedOut` +/// if the timeout was reached. +/// The handle must have been initialized with `ddog_shared_runtime_new`. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_shutdown( + handle: Option<&SharedRuntime>, + timeout_ms: u64, +) -> Option> { + catch_panic!( + { + match handle { + Some(runtime) => { + let timeout = if timeout_ms > 0 { + Some(std::time::Duration::from_millis(timeout_ms)) + } else { + None + }; + + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + match runtime.shutdown(timeout) { + Ok(()) => None, + Err(err) => Some(Box::new(SharedRuntimeFFIError::from(err))), + } + } + None => Some(Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "handle is null", + ))), + } + }, + panic_error!() + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem::MaybeUninit; + + #[test] + fn test_new_and_free() { + unsafe { + let mut handle: MaybeUninit<*const SharedRuntime> = MaybeUninit::uninit(); + let err = ddog_shared_runtime_new(NonNull::new_unchecked(handle.as_mut_ptr())); + assert!(err.is_none()); + ddog_shared_runtime_free(handle.assume_init()); + } + } + + #[test] + fn test_before_after_fork_null() { + unsafe { + let err = ddog_shared_runtime_before_fork(None); + assert_eq!(err.unwrap().code, SharedRuntimeErrorCode::InvalidArgument); + + let err = ddog_shared_runtime_after_fork_parent(None); + assert_eq!(err.unwrap().code, SharedRuntimeErrorCode::InvalidArgument); + + let err = ddog_shared_runtime_after_fork_child(None); + assert_eq!(err.unwrap().code, SharedRuntimeErrorCode::InvalidArgument); + } + } + + #[test] + fn test_fork_lifecycle() { + unsafe { + let mut handle: MaybeUninit<*const SharedRuntime> = MaybeUninit::uninit(); + ddog_shared_runtime_new(NonNull::new_unchecked(handle.as_mut_ptr())); + let handle = handle.assume_init(); + + let err = ddog_shared_runtime_before_fork(std::mem::transmute::< + *const SharedRuntime, + Option<&SharedRuntime>, + >(handle)); + assert!(err.is_none(), "{:?}", err.map(|e| e.code)); + + let err = ddog_shared_runtime_after_fork_parent(std::mem::transmute::< + *const SharedRuntime, + Option<&SharedRuntime>, + >(handle)); + assert!(err.is_none(), "{:?}", err.map(|e| e.code)); + + ddog_shared_runtime_free(handle); + } + } + + #[test] + fn test_shutdown() { + unsafe { + let mut handle: MaybeUninit<*const SharedRuntime> = MaybeUninit::uninit(); + ddog_shared_runtime_new(NonNull::new_unchecked(handle.as_mut_ptr())); + let handle = handle.assume_init(); + + let err = ddog_shared_runtime_shutdown( + std::mem::transmute::<*const SharedRuntime, Option<&SharedRuntime>>(handle), + 0, + ); + assert!(err.is_none()); + + ddog_shared_runtime_free(handle); + } + } + + #[test] + fn test_error_free() { + let error = Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "test error", + )); + unsafe { ddog_shared_runtime_error_free(Some(error)) }; + } +} diff --git a/libdd-shared-runtime/Cargo.toml b/libdd-shared-runtime/Cargo.toml new file mode 100644 index 0000000000..612cfee8c9 --- /dev/null +++ b/libdd-shared-runtime/Cargo.toml @@ -0,0 +1,28 @@ +# Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "libdd-shared-runtime" +version = "1.0.0" +description = "Shared tokio runtime with fork-safe worker management for Datadog libraries" +homepage = "https://github.com/DataDog/libdatadog/tree/main/libdd-shared-runtime" +repository = "https://github.com/DataDog/libdatadog/tree/main/libdd-shared-runtime" +edition.workspace = true +rust-version.workspace = true +license.workspace = true + +[lib] +crate-type = ["lib"] +bench = false + +[dependencies] +async-trait = "0.1" +futures = { version = "0.3", default-features = false, features = ["alloc"] } +tokio = { version = "1.23", features = ["rt", "macros", "time"] } +tokio-util = "0.7.11" +tracing = { version = "0.1", default-features = false } +libdd-capabilities = { path = "../libdd-capabilities", version = "0.1.0" } +libdd-common = { version = "3.0.2", path = "../libdd-common", default-features = false } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio = { version = "1.23", features = ["rt-multi-thread"] } diff --git a/libdd-shared-runtime/src/lib.rs b/libdd-shared-runtime/src/lib.rs new file mode 100644 index 0000000000..e9dc0ee642 --- /dev/null +++ b/libdd-shared-runtime/src/lib.rs @@ -0,0 +1,24 @@ +// Copyright 2026-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 +#![cfg_attr(not(test), deny(clippy::panic))] +#![cfg_attr(not(test), deny(clippy::unwrap_used))] +#![cfg_attr(not(test), deny(clippy::expect_used))] +#![cfg_attr(not(test), deny(clippy::todo))] +#![cfg_attr(not(test), deny(clippy::unimplemented))] + +//! A shared tokio runtime for running background workers across multiple components. +//! +//! This crate provides [`SharedRuntime`], which owns a single tokio runtime and manages +//! [`PausableWorker`]s on it. Components such as the trace exporter can share one runtime +//! instead of each creating their own, reducing thread and resource overhead. +//! +//! [`SharedRuntime`] also provides fork-safety hooks (`before_fork`, `after_fork_parent`, +//! `after_fork_child`) that pause and restart workers around `fork()` calls, preventing +//! deadlocks in child processes. + +pub mod shared_runtime; +pub mod worker; + +// Top-level re-exports for convenience +pub use shared_runtime::{SharedRuntime, SharedRuntimeError, WorkerHandle, WorkerHandleError}; +pub use worker::Worker; diff --git a/libdd-shared-runtime/src/shared_runtime/mod.rs b/libdd-shared-runtime/src/shared_runtime/mod.rs new file mode 100644 index 0000000000..21213acd12 --- /dev/null +++ b/libdd-shared-runtime/src/shared_runtime/mod.rs @@ -0,0 +1,576 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +//! SharedRuntime for managing [`PausableWorker`]s across fork boundaries. +//! +//! This module provides a SharedRuntime that manages a tokio runtime and allows +//! spawning PausableWorkers on it. It also provides hooks for safely handling +//! fork operations by pausing workers before fork and restarting them appropriately +//! in parent and child processes. + +pub(crate) mod pausable_worker; + +use crate::worker::Worker; +use futures::stream::{FuturesUnordered, StreamExt}; +use libdd_common::MutexExt; +use pausable_worker::{PausableWorker, PausableWorkerError}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::{fmt, io}; +use tokio::runtime::{Builder, Runtime}; +use tracing::{debug, error}; + +type BoxedWorker = Box; + +#[derive(Debug)] +struct WorkerEntry { + id: u64, + worker: PausableWorker, +} + +/// Handle to a worker registered on a [`SharedRuntime`]. +/// +/// This handle can be used to stop the worker. +#[derive(Clone, Debug)] +pub struct WorkerHandle { + worker_id: u64, + workers: Arc>>, +} + +#[derive(Debug)] +pub enum WorkerHandleError { + AlreadyStopped, + WorkerError(PausableWorkerError), +} + +impl fmt::Display for WorkerHandleError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::AlreadyStopped => { + write!(f, "Worker has already been stopped") + } + Self::WorkerError(err) => write!(f, "Worker error: {}", err), + } + } +} + +impl std::error::Error for WorkerHandleError {} + +impl From for WorkerHandleError { + fn from(err: PausableWorkerError) -> Self { + Self::WorkerError(err) + } +} + +impl WorkerHandle { + /// Stop the worker and execute the shutdown logic. + /// + /// # Errors + /// Returns an error if the worker has already been stopped. + /// + /// # Cancel safety + /// This function is *NOT* cancel safe and shouldn't be called in [Worker::trigger]. + /// If cancelled, the stopped worker can end up in an invalid state if a fork occurs while + /// stopping. + pub async fn stop(self) -> Result<(), WorkerHandleError> { + let mut worker = { + let mut workers_lock = self.workers.lock_or_panic(); + let Some(position) = workers_lock + .iter() + .position(|entry| entry.id == self.worker_id) + else { + return Err(WorkerHandleError::AlreadyStopped); + }; + let WorkerEntry { worker, .. } = workers_lock.swap_remove(position); + worker + }; + worker.pause().await?; + worker.shutdown().await; + Ok(()) + } +} + +/// Errors that can occur when using SharedRuntime. +#[derive(Debug)] +pub enum SharedRuntimeError { + /// The runtime is not available or in an invalid state. + RuntimeUnavailable, + /// Failed to acquire a lock on internal state. + LockFailed(String), + /// A worker operation failed. + WorkerError(PausableWorkerError), + /// Failed to create the tokio runtime. + RuntimeCreation(io::Error), + /// Shutdown timed out. + ShutdownTimedOut(std::time::Duration), +} + +impl fmt::Display for SharedRuntimeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::RuntimeUnavailable => { + write!(f, "Runtime is not available or in an invalid state") + } + Self::LockFailed(msg) => write!(f, "Failed to acquire lock: {}", msg), + Self::WorkerError(err) => write!(f, "Worker error: {}", err), + Self::RuntimeCreation(err) => { + write!(f, "Failed to create runtime: {}", err) + } + Self::ShutdownTimedOut(duration) => { + write!(f, "Shutdown timed out after {:?}", duration) + } + } + } +} + +impl std::error::Error for SharedRuntimeError {} + +impl From for SharedRuntimeError { + fn from(err: PausableWorkerError) -> Self { + SharedRuntimeError::WorkerError(err) + } +} + +impl From for SharedRuntimeError { + fn from(err: io::Error) -> Self { + SharedRuntimeError::RuntimeCreation(err) + } +} + +/// A shared runtime that manages PausableWorkers and provides fork safety hooks. +/// +/// The SharedRuntime owns a tokio runtime and tracks PausableWorkers spawned on it. +/// It provides methods to safely pause workers before forking and restart them +/// after fork in both parent and child processes. +/// +/// # Mutex lock order +/// When locking both [Self::runtime] and [Self::workers], the mutex must be locked in the order of +/// the fields in the struct. When possible avoid holding both locks simultaneously. +#[derive(Debug)] +pub struct SharedRuntime { + runtime: Arc>>>, + workers: Arc>>, + next_worker_id: AtomicU64, +} + +/// Build a tokio runtime appropriate for the current platform. +/// +/// On wasm32, a single-threaded current-thread runtime is used since multi-threading +/// is not available. On all other platforms a multi-threaded runtime is used. +fn build_runtime() -> Result { + #[cfg(not(target_arch = "wasm32"))] + { + Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + } + #[cfg(target_arch = "wasm32")] + { + Builder::new_current_thread().enable_all().build() + } +} + +impl SharedRuntime { + /// Create a new SharedRuntime with a default tokio runtime. + /// + /// # Errors + /// Returns an error if the tokio runtime cannot be created. + pub fn new() -> Result { + debug!("Creating new SharedRuntime"); + let runtime = build_runtime()?; + + Ok(Self { + runtime: Arc::new(Mutex::new(Some(Arc::new(runtime)))), + workers: Arc::new(Mutex::new(Vec::new())), + next_worker_id: AtomicU64::new(1), + }) + } + + /// Spawn a PausableWorker on this runtime. + /// + /// The worker will be tracked by this SharedRuntime and will be paused/resumed + /// during fork operations. + /// + /// # Errors + /// Returns an error if the runtime is not available or the worker cannot be started. + pub fn spawn_worker( + &self, + worker: T, + ) -> Result { + let boxed_worker: BoxedWorker = Box::new(worker); + debug!(?boxed_worker, "Spawning worker on SharedRuntime"); + let mut pausable_worker = PausableWorker::new(boxed_worker); + + // Hold the workers lock while starting the worker to avoid a race with + // before_fork: without this, before_fork could run after the worker is started but + // before it's added to the list, not pausing the worker before the runtime is dropped. + let runtime = self.runtime.lock_or_panic().clone(); + let mut workers_guard = self.workers.lock_or_panic(); + + // If the runtime is not available, the worker will be started + // when the runtime is recreated (after_fork_parent/child). + if let Some(runtime) = runtime { + if let Err(e) = pausable_worker.start(&runtime) { + return Err(e.into()); + } + } + + let worker_id = self.next_worker_id.fetch_add(1, Ordering::Relaxed); + + workers_guard.push(WorkerEntry { + id: worker_id, + worker: pausable_worker, + }); + + Ok(WorkerHandle { + worker_id, + workers: self.workers.clone(), + }) + } + + /// Hook to be called before forking. + /// + /// This method pauses all workers and prepares the runtime for forking. + /// It ensures that no background tasks are running when the fork occurs, + /// preventing potential deadlocks in the child process. + /// + /// Worker errors are logged but do not cause the function to fail. + /// If the worker fails to pause it is dropped without calling shutdown. + #[cfg(not(target_arch = "wasm32"))] + pub fn before_fork(&self) { + debug!("before_fork: pausing all workers"); + if let Some(runtime) = self.runtime.lock_or_panic().take() { + let mut workers_lock = self.workers.lock_or_panic(); + runtime.block_on(async { + let futures: FuturesUnordered<_> = workers_lock + .iter_mut() + .map(|worker_entry| async { + if let Err(e) = worker_entry.worker.pause().await { + error!("Worker failed to pause before fork: {:?}", e); + } + }) + .collect(); + + futures.collect::<()>().await; + }); + } + } + + fn restart_runtime(&self) -> Result<(), SharedRuntimeError> { + let mut runtime_lock = self.runtime.lock_or_panic(); + if runtime_lock.is_none() { + *runtime_lock = Some(Arc::new(build_runtime()?)); + } + Ok(()) + } + + /// Hook to be called in the parent process after forking. + /// + /// This method restarts workers and resumes normal operation in the parent process. + /// The runtime may need to be recreated if it was shut down in before_fork. + /// + /// # Errors + /// Returns an error if workers cannot be restarted or the runtime cannot be recreated. + #[cfg(not(target_arch = "wasm32"))] + pub fn after_fork_parent(&self) -> Result<(), SharedRuntimeError> { + debug!("after_fork_parent: restarting runtime and workers"); + self.restart_runtime()?; + + let runtime_lock = self.runtime.lock_or_panic(); + let runtime = runtime_lock + .as_ref() + .ok_or(SharedRuntimeError::RuntimeUnavailable)? + .clone(); + drop(runtime_lock); + + let mut workers_lock = self.workers.lock_or_panic(); + + // Restart all workers + for worker_entry in workers_lock.iter_mut() { + worker_entry.worker.start(&runtime)?; + } + + Ok(()) + } + + /// Hook to be called in the child process after forking. + /// + /// This method reinitializes the runtime and workers in the child process. + /// A new runtime must be created since tokio runtimes cannot be safely forked. + /// Workers are reset and restarted to resume operations in the child. + /// + /// # Errors + /// Returns an error if the runtime cannot be reinitialized or workers cannot be started. + #[cfg(not(target_arch = "wasm32"))] + pub fn after_fork_child(&self) -> Result<(), SharedRuntimeError> { + debug!("after_fork_child: reinitializing runtime and workers"); + self.restart_runtime()?; + + let runtime_lock = self.runtime.lock_or_panic(); + let runtime = runtime_lock + .as_ref() + .ok_or(SharedRuntimeError::RuntimeUnavailable)? + .clone(); + drop(runtime_lock); + + let mut workers_lock = self.workers.lock_or_panic(); + + // Restart all workers in child process + for worker_entry in workers_lock.iter_mut() { + worker_entry.worker.reset(); + worker_entry.worker.start(&runtime)?; + } + + Ok(()) + } + + /// Run a future to completion on the shared runtime, blocking the current thread. + /// + /// If the runtime is not available (e.g. after calling before_fork), a temporary + /// single-threaded runtime is used. + /// + /// # Errors + /// Returns an error if it fails to create a fallback runtime. + pub fn block_on(&self, f: F) -> Result { + let runtime = match self.runtime.lock_or_panic().as_ref() { + None => Arc::new(Builder::new_current_thread().enable_all().build()?), + Some(runtime) => runtime.clone(), + }; + Ok(runtime.block_on(f)) + } + + /// Shutdown the runtime and all workers synchronously with optional timeout. + /// + /// Worker errors are logged but do not cause the function to fail. + /// + /// # Errors + /// Returns an error only if shutdown times out. + pub fn shutdown(&self, timeout: Option) -> Result<(), SharedRuntimeError> { + debug!(?timeout, "Shutting down SharedRuntime"); + match self.runtime.lock_or_panic().take() { + Some(runtime) => { + if let Some(timeout) = timeout { + match runtime.block_on(async { + tokio::time::timeout(timeout, self.shutdown_async()).await + }) { + Ok(()) => Ok(()), + Err(_) => Err(SharedRuntimeError::ShutdownTimedOut(timeout)), + } + } else { + runtime.block_on(self.shutdown_async()); + Ok(()) + } + } + None => Ok(()), // The runtime is not running so there's nothing to shutdown + } + } + + /// Shutdown all workers asynchronously. + /// + /// This should be called during application shutdown to cleanly stop all + /// background workers and the runtime. + /// + /// Worker errors are logged but do not cause the function to fail. + pub async fn shutdown_async(&self) { + debug!("Shutting down all workers asynchronously"); + let workers = { + let mut workers_lock = self.workers.lock_or_panic(); + std::mem::take(&mut *workers_lock) + }; + + let futures: FuturesUnordered<_> = workers + .into_iter() + .map(|mut worker_entry| async move { + if let Err(e) = worker_entry.worker.pause().await { + error!("Worker failed to shutdown: {:?}", e); + return; + } + worker_entry.worker.shutdown().await; + }) + .collect(); + + futures.collect::<()>().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use std::sync::mpsc::{channel, Receiver, Sender}; + use std::time::Duration; + use tokio::time::sleep; + + #[derive(Debug)] + struct TestWorker { + state: i32, + sender: Sender, + } + + fn make_test_worker() -> (TestWorker, Receiver) { + let (sender, receiver) = channel::(); + (TestWorker { state: 0, sender }, receiver) + } + + #[async_trait] + impl Worker for TestWorker { + async fn run(&mut self) { + let _ = self.sender.send(self.state); + self.state += 1; + } + + async fn trigger(&mut self) { + sleep(Duration::from_millis(100)).await; + } + + fn reset(&mut self) { + self.state = 0; + } + + async fn shutdown(&mut self) { + self.state = -1; + let _ = self.sender.send(self.state); + } + } + + #[test] + fn test_shared_runtime_creation() { + let shared_runtime = SharedRuntime::new(); + assert!(shared_runtime.is_ok()); + } + + #[test] + fn test_spawn_worker() { + let shared_runtime = SharedRuntime::new().unwrap(); + let (worker, receiver) = make_test_worker(); + + let result = shared_runtime.spawn_worker(worker); + assert!(result.is_ok()); + assert_eq!(shared_runtime.workers.lock_or_panic().len(), 1); + + // Verify the worker is actually running by receiving its first output + assert_eq!( + receiver + .recv_timeout(Duration::from_secs(1)) + .expect("worker did not run"), + 0 + ); + } + + #[test] + fn test_worker_handle_stop() { + let rt = tokio::runtime::Runtime::new().unwrap(); + let shared_runtime = SharedRuntime::new().unwrap(); + let (worker, receiver) = make_test_worker(); + + let handle = shared_runtime.spawn_worker(worker).unwrap(); + assert_eq!(shared_runtime.workers.lock_or_panic().len(), 1); + + // Wait for at least one run before stopping + receiver + .recv_timeout(Duration::from_secs(1)) + .expect("worker did not run"); + + rt.block_on(async { + assert!(handle.stop().await.is_ok()); + }); + + assert_eq!(shared_runtime.workers.lock_or_panic().len(), 0); + + // Drain all messages after stop — the last one must be the shutdown sentinel + let mut last = receiver + .recv_timeout(Duration::from_secs(1)) + .expect("shutdown did not send a value"); + while let Ok(v) = receiver.try_recv() { + last = v; + } + assert_eq!(last, -1); + } + + #[test] + fn test_before_and_after_fork_parent() { + let shared_runtime = SharedRuntime::new().unwrap(); + let (worker, receiver) = make_test_worker(); + + shared_runtime.spawn_worker(worker).unwrap(); + + // Let the worker run until state > 0 so that preservation is observable + let mut state_before_fork = 0; + while state_before_fork == 0 { + state_before_fork = receiver + .recv_timeout(Duration::from_secs(1)) + .expect("worker did not advance state before fork"); + } + + shared_runtime.before_fork(); + // Drain pre-fork buffered messages now that the worker is paused + while receiver.try_recv().is_ok() {} + + assert!(shared_runtime.after_fork_parent().is_ok()); + + // State must be preserved (not reset) after fork in the parent + let after_fork_value = receiver + .recv_timeout(Duration::from_secs(1)) + .expect("worker did not resume after fork"); + assert!( + after_fork_value > state_before_fork, + "after_fork_parent should preserve state: got {after_fork_value}, expected > {state_before_fork}" + ); + } + + #[test] + fn test_after_fork_child() { + let shared_runtime = SharedRuntime::new().unwrap(); + let (worker, receiver) = make_test_worker(); + + shared_runtime.spawn_worker(worker).unwrap(); + + // Let the worker run until state > 0 so that the reset is observable + let mut state_before_fork = 0; + while state_before_fork == 0 { + state_before_fork = receiver + .recv_timeout(Duration::from_secs(1)) + .expect("worker did not advance state before fork"); + } + + shared_runtime.before_fork(); + // Drain pre-fork buffered messages now that the worker is paused + while receiver.try_recv().is_ok() {} + + assert!(shared_runtime.after_fork_child().is_ok()); + + // State must be reset to 0 in the child + let after_fork_value = receiver + .recv_timeout(Duration::from_secs(1)) + .expect("worker did not resume after fork child"); + assert_eq!( + after_fork_value, 0, + "after_fork_child should reset state to 0, got {after_fork_value}" + ); + } + + #[test] + fn test_shutdown() { + let shared_runtime = SharedRuntime::new().unwrap(); + let (worker, receiver) = make_test_worker(); + + shared_runtime.spawn_worker(worker).unwrap(); + + // Wait for at least one run before shutting down + receiver + .recv_timeout(Duration::from_secs(1)) + .expect("worker did not run"); + + shared_runtime.shutdown(None).unwrap(); + + // Drain all messages after shutdown — the last one must be the shutdown sentinel + let mut last = receiver + .recv_timeout(Duration::from_secs(1)) + .expect("shutdown did not send a value"); + while let Ok(v) = receiver.try_recv() { + last = v; + } + assert_eq!(last, -1); + } +} diff --git a/libdd-shared-runtime/src/shared_runtime/pausable_worker.rs b/libdd-shared-runtime/src/shared_runtime/pausable_worker.rs new file mode 100644 index 0000000000..e3dcec8701 --- /dev/null +++ b/libdd-shared-runtime/src/shared_runtime/pausable_worker.rs @@ -0,0 +1,218 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +//! Defines a pausable worker to be able to stop background processes before forks + +use crate::worker::Worker; +use libdd_capabilities::MaybeSend; +use std::fmt::Display; +use tokio::{runtime::Runtime, select, task::JoinHandle}; +use tokio_util::sync::CancellationToken; +use tracing::debug; + +/// A pausable worker which can be paused and restarted on forks. +/// +/// Used to allow a [`super::Worker`] to be paused while saving its state when +/// dropping a tokio runtime to be able to restart with the same state on a new runtime. This is +/// used to stop all threads before a fork to avoid deadlocks in child. +#[derive(Debug)] +pub enum PausableWorker { + Running { + handle: JoinHandle, + stop_token: CancellationToken, + }, + Paused { + worker: T, + }, + InvalidState, +} + +#[derive(Debug)] +pub enum PausableWorkerError { + InvalidState, + TaskAborted, +} + +impl Display for PausableWorkerError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PausableWorkerError::InvalidState => { + write!(f, "Worker is in an invalid state and must be recreated.") + } + PausableWorkerError::TaskAborted => { + write!(f, "Worker task has been aborted and state has been lost.") + } + } + } +} + +impl core::error::Error for PausableWorkerError {} + +impl PausableWorker { + /// Create a new pausable worker from the given worker. + pub fn new(worker: T) -> Self { + Self::Paused { worker } + } + + /// Start the worker on the given runtime. + /// + /// The worker's main loop will be run on the runtime. + pub fn start(&mut self, rt: &Runtime) -> Result<(), PausableWorkerError> { + #[cfg(target_arch = "wasm32")] + return Ok(()); + #[cfg(not(target_arch = "wasm32"))] + match self { + PausableWorker::Running { .. } => Ok(()), + PausableWorker::Paused { worker } => { + debug!(?worker, "Starting pausable worker"); + let PausableWorker::Paused { mut worker } = + std::mem::replace(self, PausableWorker::InvalidState) + else { + // Unreachable + return Ok(()); + }; + + // Worker is temporarily in an invalid state, but since this block is failsafe it + // will be replaced by a valid state. + let stop_token = CancellationToken::new(); + let cloned_token = stop_token.clone(); + let handle = rt.spawn(async move { + // First iteration using initial_trigger + select! { + // Always check for cancellation first, to reduce time-to-pause in case + // the initial trigger is always ready. + biased; + + _ = cloned_token.cancelled() => { + return worker; + } + _ = worker.initial_trigger() => { + worker.run().await; + } + } + + // Regular iterations + loop { + select! { + // Always check for cancellation first, to reduce time-to-pause in case + // the trigger is always ready. + biased; + + _ = cloned_token.cancelled() => { + break; + } + _ = worker.trigger() => { + worker.run().await; + } + } + } + worker + }); + + *self = PausableWorker::Running { handle, stop_token }; + Ok(()) + } + PausableWorker::InvalidState => Err(PausableWorkerError::InvalidState), + } + } + + /// Pause the worker and wait for it to complete, storing its state for restart. + /// + /// # Errors + /// Fails if the worker handle has been aborted preventing the worker from being retrieved. + pub async fn pause(&mut self) -> Result<(), PausableWorkerError> { + match self { + PausableWorker::Running { .. } => { + debug!("Waiting for worker to pause"); + let PausableWorker::Running { handle, stop_token } = + std::mem::replace(self, PausableWorker::InvalidState) + else { + // Unreachable + return Ok(()); + }; + + if !stop_token.is_cancelled() { + stop_token.cancel(); + } + + if let Ok(mut worker) = handle.await { + debug!(?worker, "Worker paused successfully"); + worker.on_pause().await; + *self = PausableWorker::Paused { worker }; + Ok(()) + } else { + // The task has been aborted and the worker can't be retrieved. + *self = PausableWorker::InvalidState; + Err(PausableWorkerError::TaskAborted) + } + } + PausableWorker::Paused { .. } => Ok(()), + PausableWorker::InvalidState => Err(PausableWorkerError::InvalidState), + } + } + + /// Reset the worker state (e.g. in a fork child). + pub fn reset(&mut self) { + if let PausableWorker::Paused { worker } = self { + worker.reset(); + } + } + + /// Shutdown the worker. + pub async fn shutdown(&mut self) { + if let PausableWorker::Paused { worker } = self { + worker.shutdown().await; + } + } +} + +#[cfg(test)] +mod tests { + use async_trait::async_trait; + use tokio::{runtime::Builder, time::sleep}; + + use super::*; + use std::{ + sync::mpsc::{channel, Sender}, + time::Duration, + }; + + /// Test worker incrementing the state and sending it with the sender. + #[derive(Debug)] + struct TestWorker { + state: u32, + sender: Sender, + } + + #[async_trait] + impl Worker for TestWorker { + async fn run(&mut self) { + let _ = self.sender.send(self.state); + self.state += 1; + } + + async fn trigger(&mut self) { + sleep(Duration::from_millis(100)).await; + } + } + + #[test] + fn test_restart() { + let (sender, receiver) = channel::(); + let worker = TestWorker { state: 0, sender }; + let runtime = Builder::new_multi_thread().enable_time().build().unwrap(); + let mut pausable_worker = PausableWorker::new(worker); + + pausable_worker.start(&runtime).unwrap(); + + assert_eq!(receiver.recv().unwrap(), 0); + runtime.block_on(async { pausable_worker.pause().await.unwrap() }); + // Empty the message queue and get the last message + let mut next_message = 1; + for message in receiver.try_iter() { + next_message = message + 1; + } + pausable_worker.start(&runtime).unwrap(); + assert_eq!(receiver.recv().unwrap(), next_message); + } +} diff --git a/libdd-shared-runtime/src/worker.rs b/libdd-shared-runtime/src/worker.rs new file mode 100644 index 0000000000..9d76ab8374 --- /dev/null +++ b/libdd-shared-runtime/src/worker.rs @@ -0,0 +1,76 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +use async_trait::async_trait; +use libdd_capabilities::MaybeSend; + +/// A background worker meant to be spawned on a [`SharedRuntime`](crate::SharedRuntime). +/// +/// # Lifecycle +/// The worker's [`run`](Self::run) method is executed every time [`trigger`](Self::trigger) +/// returns. On startup [`initial_trigger`](Self::initial_trigger) is called before the first +/// [`run`](Self::run). +/// +/// # Cancellation safety +/// The `trigger` function can be interrupted at any yield point (`.await`ed call). The state of the +/// worker at this point will be saved and used to restart the worker. To be able to safely restart, +/// the worker must be in a valid state on every call to `.await` within the trigger function. +/// See [`tokio::select#cancellation-safety`] for more details. +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +pub trait Worker: std::fmt::Debug + MaybeSend { + /// Main worker function + /// + /// Code in this function must always use timeout on long-running await calls to avoid + /// blocking forks if an await call takes too long to complete. + async fn run(&mut self); + + /// Function called between each `run` to wait for the next run. + /// This function should be cancellation safe as it can be cancelled at any yield point. + async fn trigger(&mut self); + + /// Alternative trigger called on start to provide custom behavior. + /// Defaults to `trigger` behavior. + async fn initial_trigger(&mut self) { + self.trigger().await + } + + /// Reset the worker state. Called in the child after a fork to cleanup parent state. + fn reset(&mut self) {} + + /// Hook called after the worker has been paused (e.g. before a fork). + /// Default is a no-op. + async fn on_pause(&mut self) {} + + /// Hook called when the app is shutting down. Can be used to flush remaining data. + async fn shutdown(&mut self) {} +} + +// Blanket implementation for boxed trait objects +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +impl Worker for Box { + async fn run(&mut self) { + (**self).run().await + } + + async fn trigger(&mut self) { + (**self).trigger().await + } + + async fn initial_trigger(&mut self) { + (**self).initial_trigger().await + } + + fn reset(&mut self) { + (**self).reset() + } + + async fn on_pause(&mut self) { + (**self).on_pause().await + } + + async fn shutdown(&mut self) { + (**self).shutdown().await + } +} diff --git a/libdd-telemetry/Cargo.toml b/libdd-telemetry/Cargo.toml index 66b8cd3f48..eaf69baec2 100644 --- a/libdd-telemetry/Cargo.toml +++ b/libdd-telemetry/Cargo.toml @@ -18,6 +18,7 @@ https = ["libdd-common/https"] [dependencies] anyhow = { version = "1.0" } +async-trait = "0.1" base64 = "0.22" futures = { version = "0.3", default-features = false } http-body-util = "0.1" @@ -31,6 +32,7 @@ uuid = { version = "1.3", features = ["v4"] } hashbrown = "0.15" bytes = "1.4" libdd-common = { version = "3.0.2", path = "../libdd-common", default-features = false } +libdd-shared-runtime = { version = "1.0.0", path = "../libdd-shared-runtime" } libdd-ddsketch = { version = "1.0.1", path = "../libdd-ddsketch" } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] diff --git a/libdd-telemetry/src/worker/mod.rs b/libdd-telemetry/src/worker/mod.rs index 2bcbd0163b..a7a19d09fe 100644 --- a/libdd-telemetry/src/worker/mod.rs +++ b/libdd-telemetry/src/worker/mod.rs @@ -11,7 +11,9 @@ use crate::{ metrics::{ContextKey, MetricBuckets, MetricContexts}, }; -use libdd_common::{http_common, tag::Tag, worker::Worker}; +use async_trait::async_trait; +use libdd_common::{http_common, tag::Tag}; +use libdd_shared_runtime::Worker; use std::iter::Sum; use std::ops::Add; @@ -140,6 +142,7 @@ pub struct TelemetryWorker { metrics_flush_interval: Duration, deadlines: scheduler::Scheduler, data: TelemetryWorkerData, + next_action: Option, } impl Debug for TelemetryWorker { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -157,58 +160,68 @@ impl Debug for TelemetryWorker { } } +#[async_trait] impl Worker for TelemetryWorker { - // Runs a state machine that waits for actions, either from the worker's - // mailbox, or scheduled actions from the worker's deadline object. - async fn run(&mut self) { - debug!( - worker.flavor = ?self.flavor, - worker.runtime_id = %self.runtime_id, - "Starting telemetry worker" - ); - - loop { - if self.cancellation_token.is_cancelled() { - debug!( - worker.runtime_id = %self.runtime_id, - "Telemetry worker cancelled, shutting down" - ); - return; - } + async fn trigger(&mut self) { + if self.next_action.is_some() { + // An action is already available and hasn't been executed + return; + } + // Wait for the next action and store it + let action = self.recv_next_action().await; + self.next_action = Some(action); + } - let action = self.recv_next_action().await; + // Processes a single action from the state machine + async fn run(&mut self) { + // Take the action that was stored by trigger() + if let Some(action) = self.next_action.take() { debug!( worker.runtime_id = %self.runtime_id, action = ?action, "Received telemetry action" ); - let action_result = match self.flavor { + // When running as a [libdd_shared_runtime::Worker] Shutdown is handled by stopping the + // Worker from the handle and not by sending stop action + let _action_result = match self.flavor { TelemetryWorkerFlavor::Full => self.dispatch_action(action).await, TelemetryWorkerFlavor::MetricsLogs => { self.dispatch_metrics_logs_action(action).await } }; - - match action_result { - ControlFlow::Continue(()) => {} - ControlFlow::Break(()) => { - debug!( - worker.runtime_id = %self.runtime_id, - worker.restartable = self.config.restartable, - "Telemetry worker received break signal" - ); - if !self.config.restartable { - break; - } - } - }; } + } - debug!( - worker.runtime_id = %self.runtime_id, - "Telemetry worker stopped" - ); + /// Reset the worker state in the child process after a fork. + /// + /// Discards inherited pending telemetry state and dedupe history without sending anything, and + /// drains the mailbox so that actions queued before the fork are not processed by the + /// child. + fn reset(&mut self) { + // Drain all actions queued in the mailbox before the fork. + while self.mailbox.try_recv().is_ok() {} + + // Discard any action that was staged by the last trigger() call. + self.next_action = None; + + // Clear all unbuffered telemetry data; the child must not send pre-fork data. + self.data.logs = store::QueueHashMap::default(); + self.data.metric_buckets = MetricBuckets::default(); + self.data.dependencies.clear(); + self.data.integrations.clear(); + self.data.configurations.clear(); + self.data.endpoints.clear(); + } + + async fn shutdown(&mut self) { + let stop_action = TelemetryActions::Lifecycle(LifecycleAction::Stop); + let _action_result = match self.flavor { + TelemetryWorkerFlavor::Full => self.dispatch_action(stop_action).await, + TelemetryWorkerFlavor::MetricsLogs => { + self.dispatch_metrics_logs_action(stop_action).await + } + }; } } @@ -828,6 +841,59 @@ impl TelemetryWorker { metric_buckets: self.data.metric_buckets.stats(), } } + + // Runs a state machine that waits for actions, either from the worker's + // mailbox, or scheduled actions from the worker's deadline object. + async fn run_loop(mut self) { + debug!( + worker.flavor = ?self.flavor, + worker.runtime_id = %self.runtime_id, + "Starting telemetry worker" + ); + + loop { + if self.cancellation_token.is_cancelled() { + debug!( + worker.runtime_id = %self.runtime_id, + "Telemetry worker cancelled, shutting down" + ); + return; + } + + let action = self.recv_next_action().await; + debug!( + worker.runtime_id = %self.runtime_id, + action = ?action, + "Received telemetry action" + ); + + let action_result = match self.flavor { + TelemetryWorkerFlavor::Full => self.dispatch_action(action).await, + TelemetryWorkerFlavor::MetricsLogs => { + self.dispatch_metrics_logs_action(action).await + } + }; + + match action_result { + ControlFlow::Continue(()) => {} + ControlFlow::Break(()) => { + debug!( + worker.runtime_id = %self.runtime_id, + worker.restartable = self.config.restartable, + "Telemetry worker received break signal" + ); + if !self.config.restartable { + break; + } + } + }; + } + + debug!( + worker.runtime_id = %self.runtime_id, + "Telemetry worker stopped" + ); + } } #[derive(Debug)] @@ -867,8 +933,9 @@ pub struct TelemetryWorkerHandle { sender: mpsc::Sender, shutdown: Arc, cancellation_token: CancellationToken, - // Used to spawn cancellation tasks - runtime: runtime::Handle, + // Used to spawn cancellation tasks. Should be None when running as a SharedRuntime worker, + // since the runtime is not guaranteed to exist for the lifetime of the worker. + runtime: Option, contexts: MetricContexts, } @@ -926,12 +993,16 @@ impl TelemetryWorkerHandle { } fn cancel_requests_with_deadline(&self, deadline: time::Instant) { + let Some(runtime) = &self.runtime else { + tracing::error!("Cannot schedule cancellation deadline: no runtime handle available"); + return; + }; let token = self.cancellation_token.clone(); let f = async move { tokio::time::sleep_until(deadline.into()).await; token.cancel() }; - self.runtime.spawn(f); + runtime.spawn(f); } pub fn wait_for_shutdown_deadline(&self, deadline: time::Instant) { @@ -1095,10 +1166,15 @@ impl TelemetryWorkerBuilder { } } - /// Build the corresponding worker and it's handle. - /// The runtime handle is wrapped in the worker handle and should be the one used to run the - /// worker task. - pub fn build_worker(self, tokio_runtime: Handle) -> (TelemetryWorkerHandle, TelemetryWorker) { + /// Build the corresponding worker and its handle. + /// + /// The optional runtime handle is stored in the worker handle and should be the one used to run + /// the worker task cancellation deadlines. Pass `None` when the worker will be run via a + /// [`SharedRuntime`](libdd_shared_runtime::SharedRuntime). + pub fn build_worker( + self, + tokio_runtime: Option, + ) -> (TelemetryWorkerHandle, TelemetryWorker) { let (tx, mailbox) = mpsc::channel(5000); let shutdown = Arc::new(InnerTelemetryShutdown { is_shutdown: Mutex::new(false), @@ -1146,6 +1222,7 @@ impl TelemetryWorkerBuilder { ), ]), cancellation_token: token.clone(), + next_action: None, }; ( @@ -1154,6 +1231,7 @@ impl TelemetryWorkerBuilder { shutdown, cancellation_token: token, runtime: tokio_runtime, + contexts, }, worker, @@ -1165,9 +1243,9 @@ impl TelemetryWorkerBuilder { pub fn spawn(self) -> (TelemetryWorkerHandle, JoinHandle<()>) { let tokio_runtime = tokio::runtime::Handle::current(); - let (worker_handle, mut worker) = self.build_worker(tokio_runtime.clone()); + let (worker_handle, worker) = self.build_worker(Some(tokio_runtime.clone())); - let join_handle = tokio_runtime.spawn(async move { worker.run().await }); + let join_handle = tokio_runtime.spawn(async move { worker.run_loop().await }); (worker_handle, join_handle) } @@ -1177,10 +1255,10 @@ impl TelemetryWorkerBuilder { let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; - let (handle, mut worker) = self.build_worker(runtime.handle().clone()); + let (handle, worker) = self.build_worker(Some(runtime.handle().clone())); let notify_shutdown = handle.shutdown.clone(); std::thread::spawn(move || { - runtime.block_on(worker.run()); + runtime.block_on(worker.run_loop()); runtime.shutdown_background(); notify_shutdown.shutdown_finished(); }); @@ -1203,4 +1281,231 @@ mod tests { #[allow(clippy::redundant_closure)] let _ = |h: TelemetryWorkerHandle| is_sync(h); } + + mod reset { + use super::super::*; + use crate::data::{ + metrics::{MetricNamespace, MetricType}, + Configuration, ConfigurationOrigin, Dependency, Endpoint, Integration, Log, LogLevel, + }; + use libdd_shared_runtime::Worker; + + fn build_test_worker() -> (TelemetryWorkerHandle, TelemetryWorker) { + let builder = TelemetryWorkerBuilder::new( + "hostname".to_string(), + "test-service".to_string(), + "rust".to_string(), + "1.0.0".to_string(), + "1.0.0".to_string(), + ); + // build_worker requires a tokio Handle; tests using this must be #[tokio::test] + builder.build_worker(Some(tokio::runtime::Handle::current())) + } + + fn make_log(id: u64, message: &str) -> (LogIdentifier, Log) { + ( + LogIdentifier { identifier: id }, + Log { + message: message.to_string(), + level: LogLevel::Warn, + stack_trace: None, + count: 1, + tags: String::new(), + is_sensitive: false, + is_crash: false, + }, + ) + } + + /// After reset(), pending buffered telemetry and dedupe history is cleared. + #[tokio::test] + async fn test_reset_clears_buffered_data() { + let (handle, mut worker) = build_test_worker(); + + // Populate every data field that reset() should clear. + worker.data.dependencies.insert(Dependency { + name: "dep".to_string(), + version: None, + }); + worker.data.integrations.insert(Integration { + name: "integration".to_string(), + version: None, + enabled: true, + compatible: None, + auto_enabled: None, + }); + worker.data.configurations.insert(Configuration { + name: "cfg".to_string(), + value: "true".to_string(), + origin: ConfigurationOrigin::Code, + config_id: None, + seq_id: None, + }); + worker.data.endpoints.insert(Endpoint { + operation_name: "GET /health".to_string(), + resource_name: "/health".to_string(), + ..Default::default() + }); + let (id, log) = make_log(42, "msg"); + worker.data.logs.get_mut_or_insert(id, log); + + // Register a metric context and add a data point. + let key = handle.register_metric_context( + "test.metric".to_string(), + vec![], + MetricType::Count, + false, + MetricNamespace::Tracers, + ); + worker.data.metric_buckets.add_point(key, 1.0, vec![]); + + worker.reset(); + + let stats = worker.stats(); + assert_eq!( + stats.dependencies_stored, 0, + "dependency dedupe history should be cleared" + ); + assert_eq!( + stats.dependencies_unflushed, 0, + "dependency pending queue should be cleared" + ); + assert_eq!( + stats.integrations_stored, 0, + "integration dedupe history should be cleared" + ); + assert_eq!( + stats.integrations_unflushed, 0, + "integration pending queue should be cleared" + ); + assert_eq!( + stats.configurations_stored, 0, + "configuration dedupe history should be cleared" + ); + assert_eq!( + stats.configurations_unflushed, 0, + "configuration pending queue should be cleared" + ); + assert_eq!(stats.logs, 0, "logs should be cleared"); + assert_eq!( + stats.metric_buckets.buckets, 0, + "metric buckets should be cleared" + ); + assert_eq!( + stats.metric_buckets.series, 0, + "metric series should be cleared" + ); + assert!( + worker.data.endpoints.is_empty(), + "endpoints should be cleared" + ); + assert!(worker.next_action.is_none(), "next_action should be None"); + } + + /// After reset(), actions queued in the mailbox before the fork are discarded. + #[tokio::test] + async fn test_reset_drains_mailbox() { + let (handle, mut worker) = build_test_worker(); + + // Enqueue several actions that should be discarded. + handle + .try_send_msg(TelemetryActions::AddDependency(Dependency { + name: "dep".to_string(), + version: None, + })) + .unwrap(); + let (id, log) = make_log(1, "pre-fork log"); + handle + .try_send_msg(TelemetryActions::AddLog((id, log))) + .unwrap(); + + // Stage one action as if trigger() had already stored it. + worker.next_action = Some(TelemetryActions::Lifecycle(LifecycleAction::Start)); + + worker.reset(); + + // The mailbox must be empty and next_action cleared. + assert!( + worker.mailbox.try_recv().is_err(), + "mailbox should be empty" + ); + assert!(worker.next_action.is_none(), "next_action should be None"); + // None of the queued actions should have been applied to pending state. + let stats = worker.stats(); + assert_eq!( + stats.dependencies_stored, 0, + "queued AddDependency must not be applied" + ); + assert_eq!( + stats.dependencies_unflushed, 0, + "queued AddDependency must not be pending" + ); + assert_eq!(stats.logs, 0, "queued AddLog must be discarded"); + } + + /// After reset(), the worker accepts new telemetry and processes it normally. + #[tokio::test] + async fn test_worker_accepts_new_data_after_reset() { + let (handle, mut worker) = build_test_worker(); + worker.flavor = TelemetryWorkerFlavor::MetricsLogs; + + // Populate state before reset – this data must not survive. + let (id, log) = make_log(99, "pre-fork"); + worker.data.logs.get_mut_or_insert(id, log); + + worker.reset(); + + // Send a new log from the child side. + let (id2, log2) = make_log(1, "post-fork"); + handle + .try_send_msg(TelemetryActions::AddLog((id2, log2))) + .unwrap(); + + // Simulate one trigger() + run() cycle. + worker.trigger().await; + worker.run().await; + + let stats = worker.stats(); + // Only the new post-fork log should be buffered. + assert_eq!(stats.logs, 1, "only post-fork log should be present"); + } + + /// After reset(), lifecycle state needed to keep periodic flushing alive is preserved. + #[tokio::test] + async fn test_reset_preserves_started_and_deadlines() { + let (_handle, mut worker) = build_test_worker(); + + worker.data.started = true; + worker + .deadlines + .schedule_event(LifecycleAction::FlushMetricAggr) + .unwrap(); + worker + .deadlines + .schedule_event(LifecycleAction::FlushData) + .unwrap(); + + let deadlines_before = worker.deadlines.deadlines.clone(); + + worker.reset(); + + assert!(worker.data.started, "started flag should be preserved"); + assert_eq!( + worker.deadlines.deadlines.len(), + deadlines_before.len(), + "scheduled deadlines should be preserved" + ); + for ((_, actual), (_, expected)) in worker + .deadlines + .deadlines + .iter() + .zip(deadlines_before.iter()) + { + assert_eq!( + actual, expected, + "deadline kinds should be preserved across reset" + ); + } + } + } } diff --git a/libdd-telemetry/src/worker/store.rs b/libdd-telemetry/src/worker/store.rs index 3986941de1..e085c8a9b1 100644 --- a/libdd-telemetry/src/worker/store.rs +++ b/libdd-telemetry/src/worker/store.rs @@ -38,6 +38,13 @@ mod queuehashmap { self.items.is_empty() } + /// Clear the map, reusing existing allocations + pub fn clear(&mut self) { + self.table.clear(); + self.items.clear(); + self.popped = 0; + } + // Remove the oldest item in the queue and return it pub fn pop_front(&mut self) -> Option<(K, V)> { let (k, v) = self.items.pop_front()?; @@ -208,6 +215,12 @@ where pub fn len_stored(&self) -> usize { self.items.len() } + + /// Discard pending unflushed items and clear stored dedupe history. + pub fn clear(&mut self) { + self.unflushed.clear(); + self.items.clear(); + } } impl Extend for Store diff --git a/libdd-trace-utils/src/test_utils/mod.rs b/libdd-trace-utils/src/test_utils/mod.rs index da4d2b8cff..e71f65fdd1 100644 --- a/libdd-trace-utils/src/test_utils/mod.rs +++ b/libdd-trace-utils/src/test_utils/mod.rs @@ -433,6 +433,25 @@ pub async fn poll_for_mock_hit( mock_hit } +/// Poll for a mock to be hit at least `min_hits` times. +/// +/// Returns `true` as soon as the mock has been called at least `min_hits` times, +/// or `false` if `poll_attempts` is exhausted before that threshold is reached. +pub async fn poll_for_mock_hits( + mock: &mut Mock<'_>, + poll_attempts: i32, + sleep_interval_ms: u64, + min_hits: usize, +) -> bool { + for _ in 0..poll_attempts { + sleep(Duration::from_millis(sleep_interval_ms)).await; + if mock.calls_async().await >= min_hits { + return true; + } + } + false +} + /// Creates a `SendData` object with the specified size and target endpoint. /// /// This function is a test helper to create a `SendData` object. diff --git a/tools/docker/Dockerfile.build b/tools/docker/Dockerfile.build index 84026b1fe5..fe63849159 100644 --- a/tools/docker/Dockerfile.build +++ b/tools/docker/Dockerfile.build @@ -115,6 +115,8 @@ COPY "spawn_worker/Cargo.toml" "spawn_worker/" COPY "tests/spawn_from_lib/Cargo.toml" "tests/spawn_from_lib/" COPY "datadog-ipc/Cargo.toml" "datadog-ipc/" COPY "datadog-ipc-macros/Cargo.toml" "datadog-ipc-macros/" +COPY "libdd-shared-runtime/Cargo.toml" "libdd-shared-runtime/" +COPY "libdd-shared-runtime-ffi/Cargo.toml" "libdd-shared-runtime-ffi/" COPY "libdd-data-pipeline/Cargo.toml" "libdd-data-pipeline/" COPY "libdd-data-pipeline-ffi/Cargo.toml" "libdd-data-pipeline-ffi/" COPY "bin_tests/Cargo.toml" "bin_tests/"