diff --git a/.github/workflows/verify-package.yml b/.github/workflows/verify-package.yml index 085ac61d..10ad364a 100644 --- a/.github/workflows/verify-package.yml +++ b/.github/workflows/verify-package.yml @@ -52,7 +52,9 @@ jobs: libcurl4-openssl-dev \ pkg-config \ libsasl2-dev \ - protobuf-compiler + protobuf-compiler \ + musl-tools + sudo ln -sf /usr/bin/musl-gcc /usr/local/bin/x86_64-linux-musl-gcc - name: Cache Cargo uses: Swatinem/rust-cache@v2 diff --git a/Cargo.lock b/Cargo.lock index e174c43f..b8edca1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2215,6 +2215,7 @@ dependencies = [ "lru", "num_cpus", "parking_lot", + "parquet", "petgraph 0.7.1", "proctitle", "prost", @@ -2228,6 +2229,7 @@ dependencies = [ "serde_yaml", "sqlparser", "strum", + "tempfile", "thiserror 2.0.18", "tokio", "tokio-stream", diff --git a/Cargo.toml b/Cargo.toml index 87d4ea03..531601d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "sync", "tim serde = { version = "1.0", features = ["derive"] } serde_yaml = "0.9" serde_json = "1.0" -uuid = { version = "1.0", features = ["v4"] } +uuid = { version = "1.0", features = ["v4", "v7"] } log = "0.4" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -51,6 +51,7 @@ arrow = { version = "55", default-features = false } arrow-array = "55" arrow-ipc = "55" arrow-schema = { version = "55", features = ["serde"] } +parquet = "55" futures = "0.3" serde_json_path = "0.7" xxhash-rust = { version = "0.8", features = ["xxh3"] } @@ -78,3 +79,6 @@ governor = "0.8.0" default = ["incremental-cache", "python"] incremental-cache = ["wasmtime/incremental-cache"] python = [] + +[dev-dependencies] +tempfile = "3.27.0" diff --git a/Makefile b/Makefile index 4daf185b..e914b376 100644 --- a/Makefile +++ b/Makefile @@ -13,12 +13,50 @@ APP_NAME := function-stream VERSION := $(shell grep '^version' Cargo.toml | head -1 | awk -F '"' '{print $$2}') -ARCH := 
$(shell uname -m) -OS := $(shell uname -s | tr '[:upper:]' '[:lower:]') DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") +# 1. Auto-detect system environment & normalize architecture +RAW_ARCH := $(shell uname -m) +# Fix macOS M-series returning arm64 while Rust expects aarch64 +ifeq ($(RAW_ARCH), arm64) + ARCH := aarch64 +else ifeq ($(RAW_ARCH), amd64) + ARCH := x86_64 +else + ARCH := $(RAW_ARCH) +endif + +OS := $(shell uname -s | tr '[:upper:]' '[:lower:]') +OS_NAME := $(shell uname -s) + +# 2. Configure RUSTFLAGS and target triple per platform DIST_ROOT := dist -TARGET_DIR := target/release +ifeq ($(OS_NAME), Linux) + # Linux: static-link musl for a truly self-contained, zero-dependency binary + TRIPLE := $(ARCH)-unknown-linux-musl + STATIC_FLAGS := -C target-feature=+crt-static +else ifeq ($(OS_NAME), Darwin) + # macOS: strip symbols but keep dynamic linking (Apple system restriction) + TRIPLE := $(ARCH)-apple-darwin + STATIC_FLAGS := +else ifneq (,$(findstring MINGW,$(OS_NAME))$(findstring MSYS,$(OS_NAME))) + # Windows (Git Bash / MSYS2): static-link MSVC runtime + TRIPLE := $(ARCH)-pc-windows-msvc + STATIC_FLAGS := -C target-feature=+crt-static +else + # Fallback + TRIPLE := $(ARCH)-unknown-linux-gnu + STATIC_FLAGS := +endif + +# 3. 
Aggressive optimization flags +# opt-level=z : size-oriented, minimize binary footprint +# strip=symbols: remove debug symbol table at link time +# Note: panic=abort is intentionally omitted to preserve stack unwinding +# for better fault tolerance in the streaming runtime +OPTIMIZE_FLAGS := -C opt-level=z -C strip=symbols $(STATIC_FLAGS) + +TARGET_DIR := target/$(TRIPLE)/release PYTHON_ROOT := python WASM_SOURCE := $(PYTHON_ROOT)/functionstream-runtime/target/functionstream-python-runtime.wasm @@ -42,7 +80,7 @@ C_0 := \033[0m log = @printf "$(C_B)[-]$(C_0) %-15s %s\n" "$(1)" "$(2)" success = @printf "$(C_G)[✔]$(C_0) %s\n" "$(1)" -.PHONY: all help build build-lite dist dist-lite clean test env env-clean go-sdk-env go-sdk-build go-sdk-clean docker docker-run docker-push .check-env .build-wasm +.PHONY: all help build build-lite dist dist-lite clean test env env-clean go-sdk-env go-sdk-build go-sdk-clean docker docker-run docker-push .check-env .ensure-target .build-wasm all: build @@ -65,18 +103,42 @@ help: @echo "" @echo " Version: $(VERSION) | Arch: $(ARCH) | OS: $(OS)" -build: .check-env .build-wasm - $(call log,BUILD,Rust Full Features) - @cargo build --release --features python --quiet +# 4. Auto-install missing Rust target toolchain +.ensure-target: + @rustup target list --installed | grep -q "$(TRIPLE)" || \ + (printf "$(C_Y)[!] Auto-installing target toolchain for $(OS_NAME): $(TRIPLE)$(C_0)\n" && \ + rustup target add $(TRIPLE)) + +# 5. 
Build targets (depend on .ensure-target for automatic toolchain setup) +build: .check-env .ensure-target .build-wasm + $(call log,BUILD,Rust Full [$(OS_NAME) / $(TRIPLE)]) + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + --features python \ + --quiet $(call log,BUILD,CLI) - @cargo build --release -p function-stream-cli --quiet + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + -p function-stream-cli \ + --quiet $(call success,Target: $(TARGET_DIR)/$(APP_NAME) $(TARGET_DIR)/cli) -build-lite: .check-env - $(call log,BUILD,Rust Lite No Python) - @cargo build --release --no-default-features --features incremental-cache --quiet +build-lite: .check-env .ensure-target + $(call log,BUILD,Rust Lite [$(OS_NAME) / $(TRIPLE)]) + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + --no-default-features \ + --features incremental-cache \ + --quiet $(call log,BUILD,CLI for dist) - @cargo build --release -p function-stream-cli --quiet + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + -p function-stream-cli \ + --quiet $(call success,Target: $(TARGET_DIR)/$(APP_NAME) $(TARGET_DIR)/cli) .build-wasm: diff --git a/protocol/proto/storage.proto b/protocol/proto/storage.proto index d7caf7bc..20e14862 100644 --- a/protocol/proto/storage.proto +++ b/protocol/proto/storage.proto @@ -52,6 +52,14 @@ message StreamingTableDefinition { // Stored as opaque bytes to avoid coupling storage schema with runtime API protos. bytes fs_program_bytes = 3; string comment = 4; + + // User-specified checkpoint interval from WITH clause (e.g. 'checkpoint.interval' = '5000'). + // 0 or unset means use system default. + uint64 checkpoint_interval_ms = 5; + + // Last globally-committed checkpoint epoch. + // Updated by JobManager after all operators ACK. Used for crash recovery. 
+ uint64 latest_checkpoint_epoch = 6; } // ============================================================================= diff --git a/src/coordinator/execution/executor.rs b/src/coordinator/execution/executor.rs index 0000d0cf..7dc3c0ff 100644 --- a/src/coordinator/execution/executor.rs +++ b/src/coordinator/execution/executor.rs @@ -322,25 +322,34 @@ impl PlanVisitor for Executor { let job_manager: Arc = Arc::clone(&self.job_manager); let job_id = plan.name.clone(); - let job_id = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(job_manager.submit_job(job_id, fs_program.clone())) - }) - .map_err(|e| ExecuteError::Internal(format!("Failed to submit streaming job: {e}")))?; + + let custom_interval: Option = plan + .with_options + .as_ref() + .and_then(|opts| opts.get("checkpoint.interval")) + .and_then(|v| v.parse().ok()); self.catalog_manager .persist_streaming_job( &plan.name, &fs_program, plan.comment.as_deref().unwrap_or(""), + custom_interval.unwrap_or(0), ) .map_err(|e| { - ExecuteError::Internal(format!( - "Streaming job '{}' submitted but persistence failed: {e}", - plan.name - )) + ExecuteError::Internal(format!("Streaming job persistence failed: {e}",)) })?; + let job_id = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(job_manager.submit_job( + job_id, + fs_program, + custom_interval, + None, + )) + }) + .map_err(|e| ExecuteError::Internal(format!("Failed to submit streaming job: {e}")))?; + info!( job_id = %job_id, table = %plan.name, diff --git a/src/coordinator/plan/logical_plan_visitor.rs b/src/coordinator/plan/logical_plan_visitor.rs index 6adc6420..d49d0314 100644 --- a/src/coordinator/plan/logical_plan_visitor.rs +++ b/src/coordinator/plan/logical_plan_visitor.rs @@ -168,10 +168,28 @@ impl LogicalPlanVisitor { let validated_program = self.validate_graph_topology(&final_logical_plan)?; + let streaming_with_options: Option> = + if with_options.is_empty() { + None + } else { + let map: 
std::collections::HashMap = with_options + .iter() + .filter_map(|opt| match opt { + SqlOption::KeyValue { key, value } => Some(( + key.value.clone(), + value.to_string().trim_matches('\'').to_string(), + )), + _ => None, + }) + .collect(); + if map.is_empty() { None } else { Some(map) } + }; + Ok(StreamingTable { name: sink_table_name, comment: comment.clone(), program: validated_program, + with_options: streaming_with_options, }) } diff --git a/src/coordinator/plan/streaming_table_plan.rs b/src/coordinator/plan/streaming_table_plan.rs index 512ec266..e155ba91 100644 --- a/src/coordinator/plan/streaming_table_plan.rs +++ b/src/coordinator/plan/streaming_table_plan.rs @@ -10,6 +10,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use super::{PlanNode, PlanVisitor, PlanVisitorContext, PlanVisitorResult}; use crate::sql::logical_node::logical::LogicalProgram; @@ -19,6 +21,7 @@ pub struct StreamingTable { pub name: String, pub comment: Option, pub program: LogicalProgram, + pub with_options: Option>, } impl PlanNode for StreamingTable { diff --git a/src/runtime/streaming/api/context.rs b/src/runtime/streaming/api/context.rs index f9dc805e..27babd56 100644 --- a/src/runtime/streaming/api/context.rs +++ b/src/runtime/streaming/api/context.rs @@ -10,6 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -19,6 +20,7 @@ use arrow_array::RecordBatch; use crate::runtime::streaming::memory::MemoryPool; use crate::runtime::streaming::network::endpoint::PhysicalSender; use crate::runtime::streaming::protocol::event::{StreamEvent, TrackedEvent}; +use crate::runtime::streaming::state::{IoManager, MemoryController}; #[derive(Debug, Clone)] pub struct TaskContextConfig { @@ -61,9 +63,22 @@ pub struct TaskContext { /// Subtask-level tunables. 
config: TaskContextConfig, + + /// Root directory for operator state persistence (LSM-Tree data/tombstone files). + pub state_dir: PathBuf, + + /// Shared memory controller for state engine back-pressure. + pub memory_controller: Arc, + + /// I/O thread pool handle for background spill/compaction. + pub io_manager: IoManager, + + /// Last globally-committed safe epoch for crash recovery. + safe_epoch: u64, } impl TaskContext { + #[allow(clippy::too_many_arguments)] pub fn new( job_id: String, pipeline_id: u32, @@ -71,6 +86,10 @@ impl TaskContext { parallelism: u32, downstream_senders: Vec, memory_pool: Arc, + memory_controller: Arc, + io_manager: IoManager, + state_dir: PathBuf, + safe_epoch: u64, ) -> Self { let task_name = format!( "Task-[{}]-Pipe[{}]-Sub[{}/{}]", @@ -87,9 +106,18 @@ impl TaskContext { memory_pool, current_watermark: None, config: TaskContextConfig::default(), + state_dir, + memory_controller, + io_manager, + safe_epoch, } } + #[inline] + pub fn latest_safe_epoch(&self) -> u64 { + self.safe_epoch + } + #[inline] pub fn config(&self) -> &TaskContextConfig { &self.config diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index 3082dc56..b0839e4a 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -10,13 +10,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::collections::HashMap; -use std::sync::{Arc, OnceLock, RwLock}; +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex, OnceLock, RwLock}; +use std::time::Duration; use anyhow::{Context, Result, anyhow, bail, ensure}; use tokio::sync::mpsc; +use tokio::task::JoinHandle as TokioJoinHandle; use tokio_stream::wrappers::ReceiverStream; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use protocol::function_stream_graph::{ChainedOperator, FsProgram}; @@ -31,7 +34,10 @@ use crate::runtime::streaming::job::models::{ }; use crate::runtime::streaming::memory::MemoryPool; use crate::runtime::streaming::network::endpoint::{BoxedEventStream, PhysicalSender}; -use crate::runtime::streaming::protocol::control::{ControlCommand, StopMode}; +use crate::runtime::streaming::protocol::control::{ControlCommand, JobMasterEvent, StopMode}; +use crate::runtime::streaming::protocol::event::CheckpointBarrier; +use crate::runtime::streaming::state::{IoManager, IoPool, MemoryController, NoopMetricsCollector}; +use crate::storage::stream_catalog::CatalogManager; #[derive(Debug, Clone)] pub struct StreamingJobSummary { @@ -57,12 +63,39 @@ pub struct StreamingJobDetail { pub program: FsProgram, } +#[derive(Debug, Clone)] +pub struct StateConfig { + pub max_background_spills: usize, + pub max_background_compactions: usize, + pub soft_limit_ratio: f64, + pub checkpoint_interval_ms: u64, +} + +impl Default for StateConfig { + fn default() -> Self { + Self { + max_background_spills: 4, + max_background_compactions: 2, + soft_limit_ratio: 0.7, + checkpoint_interval_ms: 10_000, + } + } +} + static GLOBAL_JOB_MANAGER: OnceLock> = OnceLock::new(); pub struct JobManager { active_jobs: Arc>>, operator_factory: Arc, memory_pool: Arc, + + #[allow(dead_code)] + memory_controller: Arc, + #[allow(dead_code)] + io_manager_client: IoManager, + io_pool: Mutex>, + state_base_dir: PathBuf, + state_config: StateConfig, } struct 
PreparedChain { @@ -85,17 +118,48 @@ impl PipelineRunner { } impl JobManager { - pub fn new(operator_factory: Arc, max_memory_bytes: usize) -> Self { - Self { + pub fn new( + operator_factory: Arc, + max_memory_bytes: usize, + state_base_dir: impl AsRef, + state_config: StateConfig, + ) -> Result { + let soft_limit_bytes = (max_memory_bytes as f64 * state_config.soft_limit_ratio) as usize; + let memory_controller = MemoryController::new(soft_limit_bytes, max_memory_bytes); + + let metrics = Arc::new(NoopMetricsCollector); + let (io_pool, io_manager_client) = IoPool::try_new( + state_config.max_background_spills, + state_config.max_background_compactions, + metrics, + ) + .context("Failed to initialize state engine I/O pool")?; + + Ok(Self { active_jobs: Arc::new(RwLock::new(HashMap::new())), operator_factory, memory_pool: MemoryPool::new(max_memory_bytes), - } + memory_controller, + io_manager_client, + io_pool: Mutex::new(Some(io_pool)), + state_base_dir: state_base_dir.as_ref().to_path_buf(), + state_config, + }) } - pub fn init(factory: Arc, memory_bytes: usize) -> Result<()> { + pub fn init( + factory: Arc, + memory_bytes: usize, + state_base_dir: PathBuf, + state_config: StateConfig, + ) -> Result<()> { GLOBAL_JOB_MANAGER - .set(Arc::new(Self::new(factory, memory_bytes))) + .set(Arc::new(Self::new( + factory, + memory_bytes, + state_base_dir, + state_config, + )?)) .map_err(|_| anyhow!("JobManager singleton already initialized")) } @@ -106,19 +170,44 @@ impl JobManager { .ok_or_else(|| anyhow!("JobManager not initialized. 
Call init() first.")) } - pub async fn submit_job(&self, job_id: String, program: FsProgram) -> Result { + pub fn shutdown(&self) { + if let Some(pool) = self.io_pool.lock().unwrap().take() { + pool.shutdown(); + } + } + + pub async fn submit_job( + &self, + job_id: String, + program: FsProgram, + custom_checkpoint_interval_ms: Option, + recovery_epoch: Option, + ) -> Result { let mut edge_manager = EdgeManager::build(&program.nodes, &program.edges); let mut pipelines = HashMap::with_capacity(program.nodes.len()); + let mut source_control_txs = Vec::new(); + let mut expected_pipeline_ids = HashSet::new(); + + let job_state_dir = self.state_base_dir.join(&job_id); + std::fs::create_dir_all(&job_state_dir).context("Failed to create job state dir")?; + + let (job_master_tx, job_master_rx) = mpsc::channel(256); + + let safe_epoch = recovery_epoch.unwrap_or(0); + for node in &program.nodes { let pipeline_id = node.node_index as u32; - let pipeline = self + let (pipeline, is_source) = self .build_and_spawn_pipeline( job_id.clone(), pipeline_id, &node.operators, &mut edge_manager, + &job_state_dir, + job_master_tx.clone(), + safe_epoch, ) .with_context(|| { format!( @@ -127,9 +216,25 @@ impl JobManager { ) })?; + if is_source { + source_control_txs.push(pipeline.control_tx.clone()); + } + expected_pipeline_ids.insert(pipeline_id); pipelines.insert(pipeline_id, pipeline); } + let interval_ms = + custom_checkpoint_interval_ms.unwrap_or(self.state_config.checkpoint_interval_ms); + + self.spawn_checkpoint_coordinator( + job_id.clone(), + source_control_txs, + job_master_rx, + expected_pipeline_ids, + interval_ms, + safe_epoch + 1, + ); + let graph = PhysicalExecutionGraph { job_id: job_id.clone(), program, @@ -143,7 +248,7 @@ impl JobManager { .map_err(|e| anyhow!("Active jobs lock poisoned: {}", e))?; jobs_guard.insert(job_id.clone(), graph); - info!(job_id = %job_id, "Job submitted successfully."); + info!(job_id = %job_id, interval_ms, recovery_epoch = safe_epoch, "Job 
submitted successfully."); Ok(job_id) } @@ -320,13 +425,17 @@ impl JobManager { .collect()) } + #[allow(clippy::too_many_arguments)] fn build_and_spawn_pipeline( &self, job_id: String, pipeline_id: u32, operators: &[ChainedOperator], edge_manager: &mut EdgeManager, - ) -> Result { + job_state_dir: &Path, + _job_master_tx: mpsc::Sender, + recovery_epoch: u64, + ) -> Result<(PhysicalPipeline, bool)> { let (raw_inboxes, raw_outboxes) = edge_manager.take_endpoints(pipeline_id).with_context(|| { format!( @@ -352,6 +461,8 @@ impl JobManager { ) })?; + let is_source = chain.source.is_some(); + ensure!( chain.source.is_some() || !physical_inboxes.is_empty(), "Topology Error: Pipeline '{}' contains no source and has no upstream inputs (Dead end).", @@ -375,6 +486,10 @@ impl JobManager { parallelism, physical_outboxes, Arc::clone(&self.memory_pool), + Arc::clone(&self.memory_controller), + self.io_manager_client.clone(), + job_state_dir.to_path_buf(), + recovery_epoch, ); let runner = if let Some(source) = chain.source { @@ -392,12 +507,13 @@ impl JobManager { .spawn_worker_thread(job_id, pipeline_id, runner, Arc::clone(&status)) .with_context(|| format!("Failed to spawn OS thread for pipeline {}", pipeline_id))?; - Ok(PhysicalPipeline { + let pipeline = PhysicalPipeline { pipeline_id, handle: Some(handle), status, control_tx, - }) + }; + Ok((pipeline, is_source)) } fn build_operator_chain(&self, operator_configs: &[ChainedOperator]) -> Result { @@ -509,4 +625,97 @@ impl JobManager { warn!(job_id = %job_id, pipeline_id = pipeline_id, "Pipeline failure detected. 
Job degraded."); } } + + // ======================================================================== + // Chandy-Lamport distributed snapshot barrier coordinator + // ======================================================================== + + fn spawn_checkpoint_coordinator( + &self, + job_id: String, + source_control_txs: Vec>, + mut job_master_rx: mpsc::Receiver, + expected_pipeline_ids: HashSet, + interval_ms: u64, + start_epoch: u64, + ) -> TokioJoinHandle<()> { + tokio::spawn(async move { + if interval_ms == 0 { + info!(job_id = %job_id, "Checkpoint disabled for this job"); + return; + } + + let mut interval = tokio::time::interval(Duration::from_millis(interval_ms)); + interval.tick().await; + + let mut current_epoch: u64 = start_epoch; + let mut pending_checkpoints: HashMap> = HashMap::new(); + + loop { + tokio::select! { + _ = interval.tick() => { + info!(job_id = %job_id, epoch = current_epoch, "Triggering global Checkpoint Barrier."); + pending_checkpoints.insert(current_epoch, expected_pipeline_ids.clone()); + + let barrier = CheckpointBarrier { + epoch: current_epoch as u32, + min_epoch: 0, + timestamp: std::time::SystemTime::now(), + then_stop: false, + }; + + for tx in &source_control_txs { + let cmd = ControlCommand::trigger_checkpoint(barrier); + if tx.send(cmd).await.is_err() { + debug!(job_id = %job_id, "Source disconnected. Shutting down coordinator."); + return; + } + } + current_epoch += 1; + } + + Some(event) = job_master_rx.recv() => { + match event { + JobMasterEvent::CheckpointAck { pipeline_id, epoch } => { + if let Some(pending_set) = pending_checkpoints.get_mut(&epoch) { + pending_set.remove(&pipeline_id); + + if pending_set.is_empty() { + info!( + job_id = %job_id, epoch = epoch, + "Checkpoint Epoch is GLOBALLY COMPLETED!" 
+ ); + + if let Some(catalog) = CatalogManager::try_global() { + if let Err(e) = catalog.commit_job_checkpoint(&job_id, epoch) { + error!( + job_id = %job_id, epoch = epoch, + error = %e, + "Failed to commit checkpoint metadata to Catalog" + ); + } + } else { + warn!( + job_id = %job_id, epoch = epoch, + "CatalogManager not available, checkpoint not persisted globally" + ); + } + + pending_checkpoints.remove(&epoch); + } + } + } + JobMasterEvent::CheckpointDecline { pipeline_id, epoch, reason } => { + error!( + job_id = %job_id, epoch = epoch, pipeline_id = pipeline_id, + reason = %reason, "Checkpoint FAILED!" + ); + pending_checkpoints.remove(&epoch); + } + } + } + } + } + }) + } } diff --git a/src/runtime/streaming/job/mod.rs b/src/runtime/streaming/job/mod.rs index 02e0343c..59d5c61f 100644 --- a/src/runtime/streaming/job/mod.rs +++ b/src/runtime/streaming/job/mod.rs @@ -14,4 +14,4 @@ pub mod edge_manager; pub mod job_manager; pub mod models; -pub use job_manager::{JobManager, StreamingJobSummary}; +pub use job_manager::{JobManager, StateConfig, StreamingJobSummary}; diff --git a/src/runtime/streaming/mod.rs b/src/runtime/streaming/mod.rs index 7e0ba57a..0e4e6758 100644 --- a/src/runtime/streaming/mod.rs +++ b/src/runtime/streaming/mod.rs @@ -23,5 +23,6 @@ pub mod memory; pub mod network; pub mod operators; pub mod protocol; +pub mod state; pub use protocol::StreamOutput; diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index 625cdee5..efe0abbb 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -11,15 +11,17 @@ // limitations under the License. 
use crate::sql::common::constants::updating_state_field; -use anyhow::{Result, bail}; -use arrow::compute::max_array; +use anyhow::{Result, anyhow, bail}; +use arrow::compute::{concat_batches, max_array}; use arrow::row::{RowConverter, SortField}; use arrow_array::builder::{ BinaryBuilder, TimestampNanosecondBuilder, UInt32Builder, UInt64Builder, }; use arrow_array::cast::AsArray; use arrow_array::types::UInt64Type; -use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, StructArray}; +use arrow_array::{ + Array, ArrayRef, BinaryArray, BooleanArray, RecordBatch, StructArray, UInt32Array, UInt64Array, +}; use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaBuilder, TimeUnit}; use datafusion::common::{Result as DFResult, ScalarValue}; use datafusion::physical_expr::aggregate::AggregateFunctionExpr; @@ -36,7 +38,7 @@ use std::collections::HashSet; use std::sync::LazyLock; use std::time::{Duration, Instant, SystemTime}; use std::{collections::HashMap, mem, sync::Arc}; -use tracing::{debug, warn}; +use tracing::{debug, info, warn}; // ========================================================================= // ========================================================================= use crate::runtime::streaming::StreamOutput; @@ -44,6 +46,7 @@ use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::operators::{Key, UpdatingCache}; +use crate::runtime::streaming::state::OperatorStateStore; use crate::runtime::util::decode_aggregate; use crate::sql::common::{ CheckpointBarrier, FsSchema, TIMESTAMP_FIELD, UPDATING_META_FIELD, Watermark, to_nanos, @@ -213,10 +216,15 @@ pub struct IncrementalAggregatingFunc { ttl: Duration, key_converter: RowConverter, new_generation: u64, + + state_store: Option>, } static GLOBAL_KEY: LazyLock>> = LazyLock::new(|| Arc::new(Vec::new())); +const KEY_SLIDING_SNAPSHOT: &[u8] = &[0x01]; +const 
KEY_BATCH_SNAPSHOT: &[u8] = &[0x02]; + impl IncrementalAggregatingFunc { fn update_batch( &mut self, @@ -437,40 +445,38 @@ impl IncrementalAggregatingFunc { // ========================================================================= fn checkpoint_sliding(&mut self) -> DFResult>> { - if self.updated_keys.is_empty() { + let keys = self.accumulators.keys(); + if keys.is_empty() { return Ok(None); } let mut states = vec![vec![]; self.sliding_state_schema.schema.fields.len()]; let parser = self.key_converter.parser(); - let mut generation_builder = UInt64Builder::with_capacity(self.updated_keys.len()); - - let mut cols = self - .key_converter - .convert_rows(self.updated_keys.keys().map(|k| { - let (accumulators, generation) = - self.accumulators.get_mut_generation(k.0.as_ref()).unwrap(); - generation_builder.append_value(generation); - - for (state, agg) in accumulators.iter_mut().zip(self.aggregates.iter()) { - let IncrementalState::Sliding { expr, accumulator } = state else { - continue; - }; - let state = accumulator.state().unwrap_or_else(|_| { - let state = accumulator.state().unwrap(); - *accumulator = expr.create_sliding_accumulator().unwrap(); - let states: Vec<_> = - state.iter().map(|s| s.to_array()).try_collect().unwrap(); - accumulator.merge_batch(&states).unwrap(); - state - }); - - for (idx, v) in agg.state_cols.iter().zip(state.into_iter()) { - states[*idx].push(v); - } + let mut generation_builder = UInt64Builder::with_capacity(keys.len()); + + let mut cols = self.key_converter.convert_rows(keys.iter().map(|k| { + let (accumulators, generation) = + self.accumulators.get_mut_generation(k.0.as_ref()).unwrap(); + generation_builder.append_value(generation); + + for (state, agg) in accumulators.iter_mut().zip(self.aggregates.iter()) { + let IncrementalState::Sliding { expr, accumulator } = state else { + continue; + }; + let state = accumulator.state().unwrap_or_else(|_| { + let state = accumulator.state().unwrap(); + *accumulator = 
expr.create_sliding_accumulator().unwrap(); + let states: Vec<_> = state.iter().map(|s| s.to_array()).try_collect().unwrap(); + accumulator.merge_batch(&states).unwrap(); + state + }); + + for (idx, v) in agg.state_cols.iter().zip(state.into_iter()) { + states[*idx].push(v); } - parser.parse(k.0.as_ref()) - }))?; + } + parser.parse(k.0.as_ref()) + }))?; cols.extend( states @@ -482,7 +488,7 @@ impl IncrementalAggregatingFunc { let generations = generation_builder.finish(); self.new_generation = self .new_generation - .max(max_array::(&generations).unwrap()); + .max(max_array::(&generations).unwrap_or(0)); cols.push(Arc::new(generations)); Ok(Some(cols)) @@ -496,12 +502,22 @@ impl IncrementalAggregatingFunc { { return Ok(None); } - if self.updated_keys.is_empty() { + + let keys = self.accumulators.keys(); + + let mut size = 0; + for k in &keys { + for state in self.accumulators.get_mut(k.0.as_ref()).unwrap().iter_mut() { + if let IncrementalState::Batch { data, .. } = state { + size += data.len(); + } + } + } + if size == 0 { return Ok(None); } - let size = self.updated_keys.len(); - let mut rows = Vec::with_capacity(size); + let mut key_bytes_for_rows = Vec::with_capacity(size); let mut accumulator_builder = UInt32Builder::with_capacity(size); let mut args_row_builder = BinaryBuilder::with_capacity(size, size * 4); let mut count_builder = UInt64Builder::with_capacity(size); @@ -509,10 +525,8 @@ impl IncrementalAggregatingFunc { let mut generation_builder = UInt64Builder::with_capacity(size); let now = to_nanos(SystemTime::now()) as i64; - let parser = self.key_converter.parser(); - for k in self.updated_keys.keys() { - let row = parser.parse(&k.0); + for k in keys { for (i, state) in self .accumulators .get_mut(k.0.as_ref()) @@ -520,29 +534,27 @@ impl IncrementalAggregatingFunc { .iter_mut() .enumerate() { - let IncrementalState::Batch { - data, - changed_values, - .. - } = state - else { + let IncrementalState::Batch { data, .. 
} = state else { continue; }; - for vk in changed_values.iter() { - if let Some(count) = data.get(vk) { - accumulator_builder.append_value(i as u32); - args_row_builder.append_value(&*vk.0); - count_builder.append_value(count.count); - generation_builder.append_value(count.generation); - timestamp_builder.append_value(now); - rows.push(row.to_owned()) - } + for (vk, count_data) in data.iter() { + accumulator_builder.append_value(i as u32); + args_row_builder.append_value(&*vk.0); + count_builder.append_value(count_data.count); + generation_builder.append_value(count_data.generation); + timestamp_builder.append_value(now); + key_bytes_for_rows.push(k.0.clone()); } data.retain(|_, v| v.count > 0); } } + let parser = self.key_converter.parser(); + let rows: Vec<_> = key_bytes_for_rows + .iter() + .map(|kb| parser.parse(kb).to_owned()) + .collect(); let mut cols = self.key_converter.convert_rows(rows.into_iter())?; cols.push(Arc::new(accumulator_builder.finish())); cols.push(Arc::new(args_row_builder.finish())); @@ -552,7 +564,7 @@ impl IncrementalAggregatingFunc { let generations = generation_builder.finish(); self.new_generation = self .new_generation - .max(max_array::(&generations).unwrap()); + .max(max_array::(&generations).unwrap_or(0)); cols.push(Arc::new(generations)); Ok(Some(cols)) @@ -710,7 +722,147 @@ impl Operator for IncrementalAggregatingFunc { } async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Updating Aggregate recovering state from LSM-Tree..." 
+ ); + + let mut sliding_batches = Vec::new(); + let mut batch_batches = Vec::new(); + + for key in active_keys { + if key == KEY_SLIDING_SNAPSHOT { + sliding_batches + .extend(store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?); + } else if key == KEY_BATCH_SNAPSHOT { + batch_batches + .extend(store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?); + } + } + + let num_keys = self + .input_schema + .routing_keys() + .map(|k| k.len()) + .unwrap_or(0); + let now = Instant::now(); + + // Restore sliding (reversible) accumulator state + if !sliding_batches.is_empty() { + let combined = concat_batches(&self.sliding_state_schema.schema, &sliding_batches)?; + let key_cols: Vec = combined.columns()[0..num_keys].to_vec(); + let aggregate_states: Vec> = self + .aggregates + .iter() + .map(|agg| { + agg.state_cols + .iter() + .map(|&idx| combined.column(idx).clone()) + .collect() + }) + .collect(); + let gen_col = combined + .column(combined.num_columns() - 1) + .as_any() + .downcast_ref::() + .expect("generation column must be UInt64Array"); + + let rows = self.key_converter.convert_columns(&key_cols)?; + for i in 0..combined.num_rows() { + let key = rows.row(i).as_ref().to_vec(); + let generation = gen_col.value(i); + self.restore_sliding(&key, now, i, &aggregate_states, generation)?; + } + info!( + rows = combined.num_rows(), + "Restored sliding accumulator state." 
+ ); + } + + // Restore batch (non-reversible) detail dictionaries + if !batch_batches.is_empty() { + let combined = concat_batches(&self.batch_state_schema.schema, &batch_batches)?; + let key_cols: Vec = combined.columns()[0..num_keys].to_vec(); + + let acc_idx_col = combined + .column(num_keys) + .as_any() + .downcast_ref::() + .expect("accumulator index column must be UInt32Array"); + let args_col = combined + .column(num_keys + 1) + .as_any() + .downcast_ref::() + .expect("args_row column must be BinaryArray"); + let count_col = combined + .column(num_keys + 2) + .as_any() + .downcast_ref::() + .expect("count column must be UInt64Array"); + // column num_keys+3 is timestamp, skip + let gen_col = combined + .column(num_keys + 4) + .as_any() + .downcast_ref::() + .expect("generation column must be UInt64Array"); + + let rows = self.key_converter.convert_columns(&key_cols)?; + + for i in 0..combined.num_rows() { + let key = rows.row(i).as_ref().to_vec(); + let acc_idx = acc_idx_col.value(i) as usize; + let args_row = args_col.value(i).to_vec(); + let count = count_col.value(i); + let generation = gen_col.value(i); + + if !self.accumulators.contains_key(&key) { + self.accumulators.insert( + Arc::new(key.clone()), + now, + generation, + self.make_accumulators(), + ); + } + + if let Some(accs) = self.accumulators.get_mut(&key) + && let Some(IncrementalState::Batch { + data, + changed_values, + .. + }) = accs.get_mut(acc_idx) + { + let vk = Key(Arc::new(args_row.clone())); + data.insert(vk.clone(), BatchData { count, generation }); + changed_values.insert(vk); + } + } + info!(rows = combined.num_rows(), "Restored batch detail state."); + } + + info!( + groups = self.accumulators.keys().len(), + "Updating Aggregate successfully restored active groups." 
+ ); + } + self.initialize(ctx).await?; + self.state_store = Some(store); Ok(()) } @@ -743,9 +895,52 @@ impl Operator for IncrementalAggregatingFunc { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + let store = self + .state_store + .clone() + .expect("State store not initialized"); + + // Tombstone previous epoch snapshots for disk space reclamation + store + .remove_batches(KEY_SLIDING_SNAPSHOT.to_vec()) + .map_err(|e| anyhow!("{e}"))?; + store + .remove_batches(KEY_BATCH_SNAPSHOT.to_vec()) + .map_err(|e| anyhow!("{e}"))?; + + // Full snapshot of sliding (reversible) accumulator state + if let Some(cols) = self.checkpoint_sliding()? { + let batch = RecordBatch::try_new(self.sliding_state_schema.schema.clone(), cols)?; + store + .put(KEY_SLIDING_SNAPSHOT.to_vec(), batch) + .await + .map_err(|e| anyhow!("{e}"))?; + } + + // Full snapshot of batch (non-reversible) detail state + if let Some(cols) = self.checkpoint_batch()? { + let batch = RecordBatch::try_new(self.batch_state_schema.schema.clone(), cols)?; + store + .put(KEY_BATCH_SNAPSHOT.to_vec(), batch) + .await + .map_err(|e| anyhow!("{e}"))?; + } + + // Flush to Parquet + store + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + + info!( + epoch = barrier.epoch, + "Updating Aggregate snapshotted successfully." 
+ ); + + self.updated_keys.clear(); + Ok(()) } @@ -907,6 +1102,7 @@ impl IncrementalAggregatingConstructor { sliding_state_schema, batch_state_schema, new_generation: 0, + state_store: None, }) } } diff --git a/src/runtime/streaming/operators/grouping/updating_cache.rs b/src/runtime/streaming/operators/grouping/updating_cache.rs index 37f2ba04..34c732fc 100644 --- a/src/runtime/streaming/operators/grouping/updating_cache.rs +++ b/src/runtime/streaming/operators/grouping/updating_cache.rs @@ -64,6 +64,10 @@ impl Iterator for TTLIter<'_, T> { } impl UpdatingCache { + pub fn keys(&self) -> Vec { + self.map.keys().cloned().collect() + } + pub fn with_time_to_idle(ttl: Duration) -> Self { Self { map: HashMap::new(), diff --git a/src/runtime/streaming/operators/joins/join_instance.rs b/src/runtime/streaming/operators/joins/join_instance.rs index 75513542..bfb6c416 100644 --- a/src/runtime/streaming/operators/joins/join_instance.rs +++ b/src/runtime/streaming/operators/joins/join_instance.rs @@ -11,9 +11,8 @@ // limitations under the License. 
use anyhow::{Result, anyhow}; -use arrow::compute::{max, min, partition, sort_to_indices, take}; +use arrow::compute::{concat_batches, max, min, partition, sort_to_indices, take}; use arrow_array::{RecordBatch, TimestampNanosecondArray}; -use datafusion::execution::SendableRecordBatchStream; use datafusion::execution::context::SessionContext; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::ExecutionPlan; @@ -21,80 +20,79 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; use prost::Message; -use std::collections::BTreeMap; +use std::collections::BTreeSet; use std::sync::{Arc, RwLock}; -use std::time::SystemTime; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::warn; +use std::time::UNIX_EPOCH; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; -use crate::sql::common::constants::mem_exec_join_side; -use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos}; +use crate::runtime::streaming::state::OperatorStateStore; +use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; use protocol::function_stream_graph::JoinOperator; #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum JoinSide { - Left, - Right, + Left = 0, + Right = 1, } -impl JoinSide { - #[allow(dead_code)] - fn name(&self) -> &'static str { - match self { - JoinSide::Left => mem_exec_join_side::LEFT, - JoinSide::Right => mem_exec_join_side::RIGHT, - } - } -} +// ============================================================================ +// Lightweight state index: composite key [Side(1B)] + 
[Timestamp(8B BE)] +// ============================================================================ -struct JoinInstance { - left_tx: UnboundedSender, - right_tx: UnboundedSender, - result_stream: SendableRecordBatchStream, +struct InstantStateIndex { + side: JoinSide, + active_timestamps: BTreeSet, } -impl JoinInstance { - fn feed_data(&self, batch: RecordBatch, side: JoinSide) -> Result<()> { - match side { - JoinSide::Left => self - .left_tx - .send(batch) - .map_err(|e| anyhow!("Left send err: {}", e)), - JoinSide::Right => self - .right_tx - .send(batch) - .map_err(|e| anyhow!("Right send err: {}", e)), +impl InstantStateIndex { + fn new(side: JoinSide) -> Self { + Self { + side, + active_timestamps: BTreeSet::new(), } } - async fn close_and_drain(self) -> Result> { - drop(self.left_tx); - drop(self.right_tx); - - let mut outputs = Vec::new(); - let mut stream = self.result_stream; + fn build_key(side: JoinSide, ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(9); + key.push(side as u8); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key + } - while let Some(result_batch) = stream.next().await { - outputs.push(result_batch?); + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 9 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[1..]); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None } - - Ok(outputs) } } +// ============================================================================ +// InstantJoinOperator (persistent state refactor) +// ============================================================================ + pub struct InstantJoinOperator { left_input_schema: FsSchemaRef, right_input_schema: FsSchemaRef, - active_joins: BTreeMap, - left_receiver_hook: Arc>>>, - right_receiver_hook: Arc>>>, + left_schema: FsSchemaRef, + right_schema: FsSchemaRef, + + left_passer: Arc>>, + right_passer: Arc>>, join_exec_plan: Arc, + + left_state: InstantStateIndex, + right_state: InstantStateIndex, + state_store: Option>, 
} impl InstantJoinOperator { @@ -105,32 +103,26 @@ impl InstantJoinOperator { } } - fn get_or_create_join_instance(&mut self, time: SystemTime) -> Result<&mut JoinInstance> { - use std::collections::btree_map::Entry; + async fn compute_pair( + &mut self, + left: RecordBatch, + right: RecordBatch, + ) -> Result> { + self.left_passer.write().unwrap().replace(left); + self.right_passer.write().unwrap().replace(right); - if let Entry::Vacant(e) = self.active_joins.entry(time) { - let (left_tx, left_rx) = unbounded_channel(); - let (right_tx, right_rx) = unbounded_channel(); + self.join_exec_plan.reset().map_err(|e| anyhow!("{e}"))?; - *self.left_receiver_hook.write().unwrap() = Some(left_rx); - *self.right_receiver_hook.write().unwrap() = Some(right_rx); + let mut result_stream = self + .join_exec_plan + .execute(0, SessionContext::new().task_ctx()) + .map_err(|e| anyhow!("{e}"))?; - self.join_exec_plan.reset().map_err(|e| anyhow!("{e}"))?; - let result_stream = self - .join_exec_plan - .execute(0, SessionContext::new().task_ctx()) - .map_err(|e| anyhow!("{e}"))?; - - e.insert(JoinInstance { - left_tx, - right_tx, - result_stream, - }); + let mut outputs = Vec::new(); + while let Some(batch) = result_stream.next().await { + outputs.push(batch.map_err(|e| anyhow!("{e}"))?); } - - self.active_joins - .get_mut(&time) - .ok_or_else(|| anyhow!("join instance missing after insert")) + Ok(outputs) } async fn process_side_internal( @@ -142,6 +134,10 @@ impl InstantJoinOperator { if batch.num_rows() == 0 { return Ok(()); } + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); let time_column = batch .column(self.input_schema(side).timestamp_index) @@ -152,19 +148,28 @@ impl InstantJoinOperator { let min_timestamp = min(time_column).ok_or_else(|| anyhow!("empty timestamp column"))?; let max_timestamp = max(time_column).ok_or_else(|| anyhow!("empty timestamp column"))?; - if let Some(watermark) = ctx.current_watermark() - && watermark > 
from_nanos(min_timestamp as u128) - { - warn!("Dropped late batch from {:?} before watermark", side); - return Ok(()); + if let Some(watermark) = ctx.current_watermark() { + let watermark_nanos = watermark.duration_since(UNIX_EPOCH).unwrap().as_nanos() as i64; + if watermark_nanos > min_timestamp { + warn!("Dropped late batch from {:?} before watermark", side); + return Ok(()); + } } let unkeyed_batch = self.input_schema(side).unkeyed_batch(&batch)?; + let state_index = match side { + JoinSide::Left => &mut self.left_state, + JoinSide::Right => &mut self.right_state, + }; if max_timestamp == min_timestamp { - let time_key = from_nanos(max_timestamp as u128); - let join_instance = self.get_or_create_join_instance(time_key)?; - join_instance.feed_data(unkeyed_batch, side)?; + let ts_nanos = max_timestamp as u64; + let key = InstantStateIndex::build_key(side, ts_nanos); + store + .put(key, unkeyed_batch) + .await + .map_err(|e| anyhow!("{e}"))?; + state_index.active_timestamps.insert(ts_nanos); return Ok(()); } @@ -179,16 +184,21 @@ impl InstantJoinOperator { let typed_timestamps = sorted_timestamps .as_any() .downcast_ref::() - .ok_or_else(|| anyhow!("sorted timestamps downcast failed"))?; + .unwrap(); + let ranges = partition(std::slice::from_ref(&sorted_timestamps)) .unwrap() .ranges(); for range in ranges { let sub_batch = sorted_batch.slice(range.start, range.end - range.start); - let time_key = from_nanos(typed_timestamps.value(range.start) as u128); - let join_instance = self.get_or_create_join_instance(time_key)?; - join_instance.feed_data(sub_batch, side)?; + let ts_nanos = typed_timestamps.value(range.start) as u64; + let key = InstantStateIndex::build_key(side, ts_nanos); + store + .put(key, sub_batch) + .await + .map_err(|e| anyhow!("{e}"))?; + state_index.active_timestamps.insert(ts_nanos); } Ok(()) @@ -201,7 +211,39 @@ impl Operator for InstantJoinOperator { "InstantJoin" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + async fn 
on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + for key in active_keys { + if let Some(ts) = InstantStateIndex::extract_timestamp(&key) { + if key[0] == JoinSide::Left as u8 { + self.left_state.active_timestamps.insert(ts); + } else if key[0] == JoinSide::Right as u8 { + self.right_state.active_timestamps.insert(ts); + } + } + } + + info!( + pipeline_id = ctx.pipeline_id, + restored_left = self.left_state.active_timestamps.len(), + restored_right = self.right_state.active_timestamps.len(), + "Instant Join Operator recovered state." + ); + + self.state_store = Some(store); Ok(()) } @@ -228,24 +270,76 @@ impl Operator for InstantJoinOperator { let Watermark::EventTime(current_time) = watermark else { return Ok(vec![]); }; - let mut emit_outputs = Vec::new(); + let store = self.state_store.clone().unwrap(); + let cutoff_nanos = current_time.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; + + let mut all_active_ts = BTreeSet::new(); + all_active_ts.extend(self.left_state.active_timestamps.iter()); + all_active_ts.extend(self.right_state.active_timestamps.iter()); + + let expired_ts: Vec = all_active_ts + .into_iter() + .filter(|&ts| ts < cutoff_nanos) + .collect(); + + if expired_ts.is_empty() { + return Ok(vec![]); + } - let mut expired_times = Vec::new(); - for key in self.active_joins.keys() { - if *key < current_time { - expired_times.push(*key); + // Phase 1: Harvest — extract all expired timestamp data from LSM-Tree + let mut pending_pairs: Vec<(u64, RecordBatch, RecordBatch)> = + Vec::with_capacity(expired_ts.len()); + + for &ts in &expired_ts { + let left_key = 
InstantStateIndex::build_key(JoinSide::Left, ts); + let right_key = InstantStateIndex::build_key(JoinSide::Right, ts); + + let left_batches = store + .get_batches(&left_key) + .await + .map_err(|e| anyhow!("{e}"))?; + let right_batches = store + .get_batches(&right_key) + .await + .map_err(|e| anyhow!("{e}"))?; + + let left_input = if left_batches.is_empty() { + RecordBatch::new_empty(self.left_schema.schema.clone()) } else { - break; - } + concat_batches(&self.left_schema.schema, left_batches.iter())? + }; + let right_input = if right_batches.is_empty() { + RecordBatch::new_empty(self.right_schema.schema.clone()) + } else { + concat_batches(&self.right_schema.schema, right_batches.iter())? + }; + + pending_pairs.push((ts, left_input, right_input)); } - for time_key in expired_times { - if let Some(join_instance) = self.active_joins.remove(&time_key) { - let joined_batches = join_instance.close_and_drain().await?; - for batch in joined_batches { - emit_outputs.push(StreamOutput::Forward(batch)); - } + // Phase 2: Compute — all data extracted, no store reference held + let mut emit_outputs = Vec::new(); + + for (_, left_input, right_input) in pending_pairs { + if left_input.num_rows() == 0 && right_input.num_rows() == 0 { + continue; } + let results = self.compute_pair(left_input, right_input).await?; + for batch in results { + emit_outputs.push(StreamOutput::Forward(batch)); + } + } + + // Phase 3: Cleanup — tombstone LSM-Tree entries and update in-memory index + for ts in expired_ts { + let left_key = InstantStateIndex::build_key(JoinSide::Left, ts); + let right_key = InstantStateIndex::build_key(JoinSide::Right, ts); + store.remove_batches(left_key).map_err(|e| anyhow!("{e}"))?; + store + .remove_batches(right_key) + .map_err(|e| anyhow!("{e}"))?; + self.left_state.active_timestamps.remove(&ts); + self.right_state.active_timestamps.remove(&ts); } Ok(emit_outputs) @@ -253,13 +347,22 @@ impl Operator for InstantJoinOperator { async fn snapshot_state( &mut self, - 
_barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .unwrap() + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } } +// ============================================================================ +// Constructor +// ============================================================================ + pub struct InstantJoinConstructor; impl InstantJoinConstructor { @@ -268,21 +371,23 @@ impl InstantJoinConstructor { config: JoinOperator, registry: Arc, ) -> anyhow::Result { - let join_physical_plan_node = PhysicalPlanNode::decode(&mut config.join_plan.as_slice())?; - let left_input_schema: Arc = Arc::new(config.left_schema.unwrap().try_into()?); let right_input_schema: Arc = Arc::new(config.right_schema.unwrap().try_into()?); - let left_receiver_hook = Arc::new(RwLock::new(None)); - let right_receiver_hook = Arc::new(RwLock::new(None)); + let left_schema = Arc::new(left_input_schema.schema_without_keys()?); + let right_schema = Arc::new(right_input_schema.schema_without_keys()?); + + let left_passer = Arc::new(RwLock::new(None)); + let right_passer = Arc::new(RwLock::new(None)); let codec = StreamingExtensionCodec { - context: StreamingDecodingContext::LockedJoinStream { - left: left_receiver_hook.clone(), - right: right_receiver_hook.clone(), + context: StreamingDecodingContext::LockedJoinPair { + left: left_passer.clone(), + right: right_passer.clone(), }, }; + let join_physical_plan_node = PhysicalPlanNode::decode(&mut config.join_plan.as_slice())?; let join_exec_plan = join_physical_plan_node.try_into_physical_plan( registry.as_ref(), &RuntimeEnvBuilder::new().build()?, @@ -292,10 +397,14 @@ impl InstantJoinConstructor { Ok(InstantJoinOperator { left_input_schema, right_input_schema, - active_joins: BTreeMap::new(), - left_receiver_hook, - right_receiver_hook, + left_schema, + right_schema, + left_passer, + right_passer, join_exec_plan, + 
left_state: InstantStateIndex::new(JoinSide::Left), + right_state: InstantStateIndex::new(JoinSide::Right), + state_store: None, }) } } diff --git a/src/runtime/streaming/operators/joins/join_with_expiration.rs b/src/runtime/streaming/operators/joins/join_with_expiration.rs index 60bbe7e3..4d579715 100644 --- a/src/runtime/streaming/operators/joins/join_with_expiration.rs +++ b/src/runtime/streaming/operators/joins/join_with_expiration.rs @@ -19,15 +19,16 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::{physical_plan::AsExecutionPlan, protobuf::PhysicalPlanNode}; use futures::StreamExt; use prost::Message; -use std::collections::VecDeque; +use std::collections::BTreeSet; use std::sync::{Arc, RwLock}; -use std::time::{Duration, SystemTime}; -use tracing::warn; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; @@ -35,49 +36,91 @@ use protocol::function_stream_graph::JoinOperator; #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum JoinSide { - Left, - Right, + Left = 0, + Right = 1, } // ============================================================================ +// Persistent state buffer: composite key [Side(1B)] + [Timestamp(8B BE)] // ============================================================================ -struct StateBuffer { - batches: VecDeque<(SystemTime, RecordBatch)>, +struct PersistentStateBuffer { + side: JoinSide, ttl: Duration, + active_timestamps: BTreeSet, } -impl StateBuffer { - fn new(ttl: Duration) -> Self { +impl PersistentStateBuffer { + fn new(side: 
JoinSide, ttl: Duration) -> Self { Self { - batches: VecDeque::new(), + side, ttl, + active_timestamps: BTreeSet::new(), } } - fn insert(&mut self, batch: RecordBatch, time: SystemTime) { - self.batches.push_back((time, batch)); + fn build_key(side: JoinSide, ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(9); + key.push(side as u8); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key } - fn expire(&mut self, current_time: SystemTime) { - let cutoff = current_time - .checked_sub(self.ttl) - .unwrap_or(SystemTime::UNIX_EPOCH); - while let Some((time, _)) = self.batches.front() { - if *time < cutoff { - self.batches.pop_front(); - } else { - break; - } + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 9 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[1..]); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None } } - fn get_all_batches(&self) -> Vec { - self.batches.iter().map(|(_, b)| b.clone()).collect() + async fn insert( + &mut self, + batch: RecordBatch, + time: SystemTime, + store: &Arc, + ) -> Result<()> { + let ts_nanos = time.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; + self.active_timestamps.insert(ts_nanos); + let key = Self::build_key(self.side, ts_nanos); + store.put(key, batch).await.map_err(|e| anyhow!("{e}")) + } + + fn expire(&mut self, current_time: SystemTime, store: &Arc) -> Result<()> { + let cutoff = current_time.checked_sub(self.ttl).unwrap_or(UNIX_EPOCH); + let cutoff_nanos = cutoff.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; + + let expired_ts: Vec = self + .active_timestamps + .iter() + .take_while(|&&ts| ts < cutoff_nanos) + .copied() + .collect(); + + for ts in expired_ts { + let key = Self::build_key(self.side, ts); + store.remove_batches(key).map_err(|e| anyhow!("{e}"))?; + self.active_timestamps.remove(&ts); + } + + Ok(()) + } + + async fn get_all_batches(&self, store: &Arc) -> Result> { + let mut all_batches = Vec::new(); + for &ts in &self.active_timestamps { + let 
key = Self::build_key(self.side, ts); + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + all_batches.extend(batches); + } + Ok(all_batches) } } // ============================================================================ +// JoinWithExpirationOperator // ============================================================================ pub struct JoinWithExpirationOperator { @@ -90,8 +133,9 @@ pub struct JoinWithExpirationOperator { right_passer: Arc>>, join_exec_plan: Arc, - left_state: StateBuffer, - right_state: StateBuffer, + left_state: PersistentStateBuffer, + right_state: PersistentStateBuffer, + state_store: Option>, } impl JoinWithExpirationOperator { @@ -133,18 +177,30 @@ impl JoinWithExpirationOperator { ctx: &mut TaskContext, ) -> Result> { let current_time = ctx.current_watermark().unwrap_or_else(SystemTime::now); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); - self.left_state.expire(current_time); - self.right_state.expire(current_time); + self.left_state.expire(current_time, store)?; + self.right_state.expire(current_time, store)?; match side { - JoinSide::Left => self.left_state.insert(batch.clone(), current_time), - JoinSide::Right => self.right_state.insert(batch.clone(), current_time), + JoinSide::Left => { + self.left_state + .insert(batch.clone(), current_time, store) + .await? + } + JoinSide::Right => { + self.right_state + .insert(batch.clone(), current_time, store) + .await? 
+ } } let opposite_batches = match side { - JoinSide::Left => self.right_state.get_all_batches(), - JoinSide::Right => self.left_state.get_all_batches(), + JoinSide::Left => self.right_state.get_all_batches(store).await?, + JoinSide::Right => self.left_state.get_all_batches(store).await?, }; if opposite_batches.is_empty() { @@ -182,7 +238,39 @@ impl Operator for JoinWithExpirationOperator { "JoinWithExpiration" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + for key in active_keys { + if let Some(ts) = PersistentStateBuffer::extract_timestamp(&key) { + if key[0] == JoinSide::Left as u8 { + self.left_state.active_timestamps.insert(ts); + } else if key[0] == JoinSide::Right as u8 { + self.right_state.active_timestamps.insert(ts); + } + } + } + + info!( + pipeline_id = ctx.pipeline_id, + restored_left = self.left_state.active_timestamps.len(), + restored_right = self.right_state.active_timestamps.len(), + "Join Operator restored state from LSM-Tree." 
+ ); + + self.state_store = Some(store); Ok(()) } @@ -210,9 +298,19 @@ impl Operator for JoinWithExpirationOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + + store + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + + info!(epoch = barrier.epoch, "Join Operator snapshotted state."); Ok(()) } @@ -222,6 +320,7 @@ impl Operator for JoinWithExpirationOperator { } // ============================================================================ +// Constructor // ============================================================================ pub struct JoinWithExpirationConstructor; @@ -273,8 +372,9 @@ impl JoinWithExpirationConstructor { left_passer, right_passer, join_exec_plan, - left_state: StateBuffer::new(ttl), - right_state: StateBuffer::new(ttl), + left_state: PersistentStateBuffer::new(JoinSide::Left, ttl), + right_state: PersistentStateBuffer::new(JoinSide::Right, ttl), + state_store: None, }) } } diff --git a/src/runtime/streaming/operators/windows/session_aggregating_window.rs b/src/runtime/streaming/operators/windows/session_aggregating_window.rs index 4293ea7c..15075964 100644 --- a/src/runtime/streaming/operators/windows/session_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/session_aggregating_window.rs @@ -30,15 +30,17 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; use prost::Message; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; +use tracing::info; use crate::runtime::streaming::StreamOutput; use 
crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::converter::Converter; use crate::sql::common::{ CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos, to_nanos, @@ -170,6 +172,7 @@ impl ActiveSession { } } +#[derive(Clone)] struct SessionWindowResult { window_start: SystemTime, window_end: SystemTime, @@ -389,9 +392,39 @@ pub struct SessionWindowOperator { session_states: HashMap, KeySessionState>, pq_watermark_actions: BTreeMap>>, pq_start_times: BTreeMap>>, + + // LSM-Tree state engine and per-routing-key timestamp index + state_store: Option>, + pending_timestamps: HashMap, BTreeSet>, } impl SessionWindowOperator { + // State key: [RoutingKey bytes] + [8-byte big-endian timestamp] + fn build_state_key(routing_key: &[u8], ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(routing_key.len() + 8); + key.extend_from_slice(routing_key); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key + } + + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() >= 8 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[key.len() - 8..]); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None + } + } + + fn extract_routing_key(key: &[u8]) -> Vec { + if key.len() >= 8 { + key[..key.len() - 8].to_vec() + } else { + Vec::new() + } + } + fn filter_batch_by_time( &self, batch: RecordBatch, @@ -430,6 +463,7 @@ impl SessionWindowOperator { &mut self, sorted_batch: RecordBatch, watermark: Option, + is_recovery_replay: bool, ) -> Result<()> { let partition_ranges = if !self.config.input_schema_ref.has_routing_keys() { std::iter::once(0..sorted_batch.num_rows()).collect::>() @@ -470,6 +504,32 @@ impl SessionWindowOperator { .to_vec() }; + // Write-ahead persistence: skip during recovery replay to avoid duplicate writes + if !is_recovery_replay { + let ts_col = 
key_batch + .column(self.config.input_schema_ref.timestamp_index) + .as_any() + .downcast_ref::() + .unwrap(); + let ts_nanos = ts_col.value(0) as u64; + + let state_key = Self::build_state_key(&row_key, ts_nanos); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + + store + .put(state_key, key_batch.clone()) + .await + .map_err(|e| anyhow!("{e}"))?; + + self.pending_timestamps + .entry(row_key.clone()) + .or_default() + .insert(ts_nanos); + } + let state = self .session_states .entry(row_key.clone()) @@ -529,7 +589,10 @@ impl SessionWindowOperator { Ok(()) } - async fn evaluate_watermark(&mut self, watermark: SystemTime) -> Result> { + async fn evaluate_watermark_with_meta( + &mut self, + watermark: SystemTime, + ) -> Result, Vec)>> { let mut emit_results: Vec<(Vec, Vec)> = Vec::new(); loop { @@ -588,11 +651,7 @@ impl SessionWindowOperator { } } - if emit_results.is_empty() { - return Ok(vec![]); - } - - Ok(vec![self.format_to_arrow(emit_results)?]) + Ok(emit_results) } fn format_to_arrow( @@ -666,10 +725,68 @@ impl Operator for SessionWindowOperator { "SessionWindow" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery & event sourcing: rebuild in-memory sessions from LSM-Tree + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Session Operator recovering active state keys from LSM-Tree..." 
+ ); + + let mut recovered_batches = Vec::new(); + + for key in active_keys { + if let Some(ts) = Self::extract_timestamp(&key) { + let row_key = Self::extract_routing_key(&key); + self.pending_timestamps + .entry(row_key) + .or_default() + .insert(ts); + } + + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + recovered_batches.extend(batches); + } + + // Temporal ordering is critical: replay must preserve watermark/session merge invariants + recovered_batches.sort_by_key(|b| { + b.column(self.config.input_schema_ref.timestamp_index) + .as_any() + .downcast_ref::() + .map(|ts| ts.value(0)) + .unwrap_or(0) + }); + + for batch in recovered_batches { + self.ingest_sorted_batch(batch, None, true).await?; + } + + info!( + pipeline_id = ctx.pipeline_id, + "Session Window Operator successfully replayed events and rebuilt in-memory sessions." + ); + } + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist raw data before in-memory ingestion async fn process_data( &mut self, _input_idx: usize, @@ -685,12 +802,13 @@ impl Operator for SessionWindowOperator { let sorted_batch = self.sort_batch(&filtered_batch)?; - self.ingest_sorted_batch(sorted_batch, watermark_time) + self.ingest_sorted_batch(sorted_batch, watermark_time, false) .await?; Ok(vec![]) } + // Watermark-driven session closure with precise LSM-Tree garbage collection async fn process_watermark( &mut self, watermark: Watermark, @@ -700,18 +818,56 @@ impl Operator for SessionWindowOperator { return Ok(vec![]); }; - let output_batches = self.evaluate_watermark(current_time).await?; - Ok(output_batches - .into_iter() - .map(StreamOutput::Forward) - .collect()) + let completed_sessions = self.evaluate_watermark_with_meta(current_time).await?; + if completed_sessions.is_empty() { + return Ok(vec![]); + } + + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + + // GC: tombstone expired raw data covered by closed sessions + for (row_key, 
session_results) in &completed_sessions { + if let Some(ts_set) = self.pending_timestamps.get_mut(row_key) { + for session_res in session_results { + let start_nanos = to_nanos(session_res.window_start) as u64; + let end_nanos = to_nanos(session_res.window_end - self.config.gap) as u64; + + let expired_ts: Vec = + ts_set.range(start_nanos..=end_nanos).copied().collect(); + + for ts in expired_ts { + let state_key = Self::build_state_key(row_key, ts); + store + .remove_batches(state_key) + .map_err(|e| anyhow!("{e}"))?; + ts_set.remove(&ts); + } + } + } + } + + let output_batch = self.format_to_arrow(completed_sessions)?; + Ok(vec![StreamOutput::Forward(output_batch)]) } async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + + info!( + epoch = barrier.epoch, + "Session Window Operator snapshotted state." + ); Ok(()) } @@ -797,6 +953,8 @@ impl SessionAggregatingWindowConstructor { pq_start_times: BTreeMap::new(), pq_watermark_actions: BTreeMap::new(), row_converter, + state_store: None, + pending_timestamps: HashMap::new(), }) } } diff --git a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs index 73ba4dc9..538e0dad 100644 --- a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs @@ -11,7 +11,7 @@ // limitations under the License. 
use anyhow::{Result, anyhow, bail}; -use arrow::compute::{partition, sort_to_indices, take}; +use arrow::compute::{concat_batches, partition, sort_to_indices, take}; use arrow_array::{Array, PrimitiveArray, RecordBatch, types::TimestampNanosecondType}; use arrow_schema::SchemaRef; use datafusion::common::ScalarValue; @@ -27,20 +27,49 @@ use datafusion_proto::{ }; use futures::StreamExt; use prost::Message; -use std::collections::{BTreeMap, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, VecDeque}; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; +use tracing::info; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark, from_nanos, to_nanos}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; use protocol::function_stream_graph::SlidingWindowAggregateOperator; // ============================================================================ +// Dual-layer state key: [StateType(1B)] + [Timestamp(8B BE)] +// STATE_TYPE_RAW = 0 (raw input data, pending partial aggregation) +// STATE_TYPE_PARTIAL = 1 (pre-aggregated pane results) +// ============================================================================ + +const STATE_TYPE_RAW: u8 = 0; +const STATE_TYPE_PARTIAL: u8 = 1; + +fn build_state_key(state_type: u8, ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(9); + key.push(state_type); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key +} + +fn parse_state_key(key: &[u8]) -> Option<(u8, u64)> { + if key.len() == 9 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[1..9]); + Some((key[0], 
u64::from_be_bytes(ts_bytes))) + } else { + None + } +} + +// ============================================================================ +// RecordBatchTier & TieredRecordBatchHolder // ============================================================================ #[derive(Default, Debug)] @@ -263,6 +292,11 @@ pub struct SlidingWindowOperator { active_bins: BTreeMap, tiered_record_batches: TieredRecordBatchHolder, + + // LSM-Tree state engine with dual-layer index + state_store: Option>, + pending_raw_bins: BTreeSet, + pending_partial_bins: BTreeSet, } impl SlidingWindowOperator { @@ -309,10 +343,77 @@ impl Operator for SlidingWindowOperator { "SlidingWindow" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery: restore dual-layer state (partial panes + raw active bins) + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + let mut raw_recovery_batches = Vec::new(); + + for key in active_keys { + if let Some((state_type, ts_nanos)) = parse_state_key(&key) { + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + if batches.is_empty() { + continue; + } + + if state_type == STATE_TYPE_PARTIAL { + let bin_start = from_nanos(ts_nanos as u128); + for b in batches { + self.tiered_record_batches.insert(b, bin_start)?; + } + self.pending_partial_bins.insert(ts_nanos); + } else if state_type == STATE_TYPE_RAW { + let schema = batches[0].schema(); + let combined = concat_batches(&schema, &batches)?; + raw_recovery_batches.push((ts_nanos, combined)); + } + } + } + + // Temporal ordering guarantees correct DataFusion session replay + 
raw_recovery_batches.sort_by_key(|(ts, _)| *ts); + + for (ts_nanos, batch) in raw_recovery_batches { + let bin_start = from_nanos(ts_nanos as u128); + let slot = self.active_bins.entry(bin_start).or_default(); + Self::ensure_bin_running( + slot, + self.partial_aggregation_plan.clone(), + &self.receiver_hook, + )?; + + slot.sender + .as_ref() + .unwrap() + .send(batch) + .map_err(|e| anyhow!("{e}"))?; + self.pending_raw_bins.insert(ts_nanos); + } + + info!( + pipeline_id = ctx.pipeline_id, + partial_bins = self.pending_partial_bins.len(), + raw_bins = self.pending_raw_bins.len(), + "Sliding Window Operator recovered state." + ); + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist raw data (Type 0) before in-memory computation async fn process_data( &mut self, _input_idx: usize, @@ -340,6 +441,10 @@ impl Operator for SlidingWindowOperator { let partition_ranges = partition(std::slice::from_ref(&sorted_bins))?.ranges(); let watermark = ctx.current_watermark(); + let store = self + .state_store + .clone() + .expect("State store not initialized"); for range in partition_ranges { let bin_start = from_nanos(typed_bin.value(range.start) as u128); @@ -351,8 +456,16 @@ impl Operator for SlidingWindowOperator { } let bin_batch = sorted.slice(range.start, range.end - range.start); - let slot = self.active_bins.entry(bin_start).or_default(); + let bin_start_nanos = to_nanos(bin_start) as u64; + + let key = build_state_key(STATE_TYPE_RAW, bin_start_nanos); + store + .put(key, bin_batch.clone()) + .await + .map_err(|e| anyhow!("{e}"))?; + self.pending_raw_bins.insert(bin_start_nanos); + let slot = self.active_bins.entry(bin_start).or_default(); Self::ensure_bin_running( slot, self.partial_aggregation_plan.clone(), @@ -371,6 +484,7 @@ impl Operator for SlidingWindowOperator { Ok(vec![]) } + // State morphing (Type 0 → Type 1) and dual-layer GC async fn process_watermark( &mut self, watermark: Watermark, @@ -380,6 +494,10 @@ impl Operator for 
SlidingWindowOperator { return Ok(vec![]); }; let watermark_bin = self.bin_start(current_time); + let store = self + .state_store + .clone() + .expect("State store not initialized"); let mut final_outputs = Vec::new(); @@ -398,12 +516,34 @@ impl Operator for SlidingWindowOperator { .remove(&bin_start) .ok_or_else(|| anyhow!("missing active bin"))?; let bin_end = bin_start + self.slide; + let bin_start_nanos = to_nanos(bin_start) as u64; + // Phase 1: drain partial aggregation from DataFusion bin.close_and_drain().await?; - for b in bin.finished_batches { - self.tiered_record_batches.insert(b, bin_start)?; + + // Phase 2: state morphing — persist partial result (Type 1), feed tiered holder + if !bin.finished_batches.is_empty() { + let schema = bin.finished_batches[0].schema(); + let combined_partial = concat_batches(&schema, &bin.finished_batches)?; + + let p_key = build_state_key(STATE_TYPE_PARTIAL, bin_start_nanos); + store + .put(p_key, combined_partial) + .await + .map_err(|e| anyhow!("{e}"))?; + self.pending_partial_bins.insert(bin_start_nanos); + + for b in bin.finished_batches { + self.tiered_record_batches.insert(b, bin_start)?; + } } + // Phase 3: tombstone raw data (Type 0) — no longer needed after partial is saved + let r_key = build_state_key(STATE_TYPE_RAW, bin_start_nanos); + store.remove_batches(r_key).map_err(|e| anyhow!("{e}"))?; + self.pending_raw_bins.remove(&bin_start_nanos); + + // Phase 4: compute final sliding window result let interval_start = bin_end - self.width; let interval_end = bin_end; @@ -436,8 +576,23 @@ impl Operator for SlidingWindowOperator { final_outputs.push(StreamOutput::Forward(batch?)); } - self.tiered_record_batches - .delete_before(bin_end + self.slide - self.width)?; + // Phase 5: GC expired partial bins (Type 1) that fall outside the window + let cutoff_time = bin_end + self.slide - self.width; + self.tiered_record_batches.delete_before(cutoff_time)?; + + let cutoff_nanos = to_nanos(cutoff_time) as u64; + let 
expired_partials: Vec = self + .pending_partial_bins + .iter() + .take_while(|&&ts| ts < cutoff_nanos) + .copied() + .collect(); + + for ts in expired_partials { + let p_key = build_state_key(STATE_TYPE_PARTIAL, ts); + store.remove_batches(p_key).map_err(|e| anyhow!("{e}"))?; + self.pending_partial_bins.remove(&ts); + } } Ok(final_outputs) @@ -445,9 +600,14 @@ impl Operator for SlidingWindowOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } @@ -531,6 +691,9 @@ impl SlidingAggregatingWindowConstructor { final_batches_passer, active_bins: BTreeMap::new(), tiered_record_batches: TieredRecordBatchHolder::new(vec![slide])?, + state_store: None, + pending_raw_bins: BTreeSet::new(), + pending_partial_bins: BTreeSet::new(), }) } } diff --git a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs index de576bf0..7bf3268d 100644 --- a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs @@ -27,17 +27,18 @@ use datafusion_proto::{ }; use futures::StreamExt; use prost::Message; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use std::mem; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::warn; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use 
crate::sql::common::time_utils::print_time; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark, from_nanos, to_nanos}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; @@ -94,9 +95,28 @@ pub struct TumblingWindowOperator { final_batches_passer: Arc>>, active_bins: BTreeMap, + + // LSM-Tree state engine and pending window timestamp index + state_store: Option>, + pending_bins: BTreeSet, } impl TumblingWindowOperator { + // State key: 8-byte big-endian bin_start_nanos + fn build_state_key(ts_nanos: u64) -> Vec { + ts_nanos.to_be_bytes().to_vec() + } + + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 8 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(key); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None + } + } + fn bin_start(&self, timestamp: SystemTime) -> SystemTime { if self.width == Duration::ZERO { return timestamp; @@ -141,10 +161,67 @@ impl Operator for TumblingWindowOperator { "TumblingWindow" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery: replay raw data from LSM-Tree into DataFusion sessions + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Tumbling Window Operator recovering active windows from LSM-Tree..." 
+ ); + + for key in active_keys { + if let Some(ts_nanos) = Self::extract_timestamp(&key) { + let bin_start = from_nanos(ts_nanos as u128); + + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + if batches.is_empty() { + continue; + } + + let slot = self.active_bins.entry(bin_start).or_default(); + Self::ensure_bin_running( + slot, + self.partial_aggregation_plan.clone(), + &self.receiver_hook, + )?; + + let sender = slot.sender.as_ref().unwrap(); + for batch in batches { + sender + .send(batch) + .map_err(|e| anyhow!("recovery channel send: {e}"))?; + } + + self.pending_bins.insert(ts_nanos); + } + } + + info!( + pipeline_id = ctx.pipeline_id, + "Tumbling Window Operator successfully replayed events and rebuilt in-memory state." + ); + } + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist raw data before in-memory computation async fn process_data( &mut self, _input_idx: usize, @@ -171,6 +248,11 @@ impl Operator for TumblingWindowOperator { .ok_or_else(|| anyhow!("binning function must produce TimestampNanosecond"))?; let partition_ranges = partition(std::slice::from_ref(&sorted_bins))?.ranges(); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + for range in partition_ranges { let bin_start = from_nanos(typed_bin.value(range.start) as u128); @@ -186,8 +268,16 @@ impl Operator for TumblingWindowOperator { } let bin_batch = sorted.slice(range.start, range.end - range.start); - let slot = self.active_bins.entry(bin_start).or_default(); + let bin_start_nanos = to_nanos(bin_start) as u64; + + let state_key = Self::build_state_key(bin_start_nanos); + store + .put(state_key, bin_batch.clone()) + .await + .map_err(|e| anyhow!("{e}"))?; + self.pending_bins.insert(bin_start_nanos); + let slot = self.active_bins.entry(bin_start).or_default(); Self::ensure_bin_running( slot, self.partial_aggregation_plan.clone(), @@ -206,6 +296,7 @@ impl Operator for TumblingWindowOperator { Ok(vec![]) } + // 
Watermark-driven window closure with LSM-Tree GC async fn process_watermark( &mut self, watermark: Watermark, @@ -214,6 +305,10 @@ impl Operator for TumblingWindowOperator { let Watermark::EventTime(current_time) = watermark else { return Ok(vec![]); }; + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); let mut final_outputs = Vec::new(); @@ -227,10 +322,8 @@ impl Operator for TumblingWindowOperator { } for bin_start in expired_bins { - let mut bin = self - .active_bins - .remove(&bin_start) - .ok_or_else(|| anyhow!("missing tumbling bin"))?; + let mut bin = self.active_bins.remove(&bin_start).unwrap(); + let bin_start_nanos = to_nanos(bin_start) as u64; bin.close_and_drain().await?; let partial_batches = mem::take(&mut bin.finished_batches); @@ -271,6 +364,13 @@ impl Operator for TumblingWindowOperator { final_outputs.push(StreamOutput::Forward(batch?)); } } + + // Tombstone the raw data — window is fully closed + let state_key = Self::build_state_key(bin_start_nanos); + store + .remove_batches(state_key) + .map_err(|e| anyhow!("{e}"))?; + self.pending_bins.remove(&bin_start_nanos); } Ok(final_outputs) @@ -278,9 +378,14 @@ impl Operator for TumblingWindowOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } @@ -367,6 +472,8 @@ impl TumblingAggregateWindowConstructor { receiver_hook, final_batches_passer, active_bins: BTreeMap::new(), + state_store: None, + pending_bins: BTreeSet::new(), }) } } diff --git a/src/runtime/streaming/operators/windows/window_function.rs b/src/runtime/streaming/operators/windows/window_function.rs index 5e340fec..cf6a198d 100644 --- a/src/runtime/streaming/operators/windows/window_function.rs +++ 
b/src/runtime/streaming/operators/windows/window_function.rs @@ -13,7 +13,6 @@ use anyhow::{Result, anyhow}; use arrow::compute::{max, min}; use arrow_array::RecordBatch; -use datafusion::execution::SendableRecordBatchStream; use datafusion::execution::context::SessionContext; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::ExecutionPlan; @@ -21,57 +20,26 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; use prost::Message; -use std::collections::BTreeMap; +use std::collections::BTreeSet; use std::sync::{Arc, RwLock}; use std::time::SystemTime; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::warn; +use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel}; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::Operator; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::time_utils::print_time; -use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos}; +use crate::sql::common::{ + CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos, to_nanos, +}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; // ============================================================================ -// ============================================================================ - -struct ActiveWindowExec { - sender: Option>, - result_stream: Option, -} - -impl ActiveWindowExec { - fn new( - plan: Arc, - hook: &Arc>>>, - ) -> Result { - let (tx, rx) = unbounded_channel(); - *hook.write().unwrap() = Some(rx); - plan.reset()?; - let result_stream = plan.execute(0, SessionContext::new().task_ctx())?; - Ok(Self { - sender: 
Some(tx), - result_stream: Some(result_stream), - }) - } - - async fn close_and_drain(&mut self) -> Result> { - self.sender.take(); - let mut results = Vec::new(); - if let Some(mut stream) = self.result_stream.take() { - while let Some(batch) = stream.next().await { - results.push(batch?); - } - } - Ok(results) - } -} - -// ============================================================================ +// WindowFunctionOperator: LSM-Tree backed lazy-compute model // ============================================================================ pub struct WindowFunctionOperator { @@ -79,10 +47,28 @@ pub struct WindowFunctionOperator { input_schema_unkeyed: FsSchemaRef, window_exec_plan: Arc, receiver_hook: Arc>>>, - active_execs: BTreeMap, + + // LSM-Tree state engine and lightweight timestamp index + state_store: Option>, + pending_timestamps: BTreeSet, } impl WindowFunctionOperator { + // State key: 8-byte big-endian timestamp (nanos) + fn build_state_key(ts_nanos: u64) -> Vec { + ts_nanos.to_be_bytes().to_vec() + } + + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 8 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(key); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None + } + } + fn filter_and_split_batches( &self, batch: RecordBatch, @@ -137,18 +123,6 @@ impl WindowFunctionOperator { } Ok(batches) } - - fn get_or_create_exec(&mut self, timestamp: SystemTime) -> Result<&mut ActiveWindowExec> { - use std::collections::btree_map::Entry; - match self.active_execs.entry(timestamp) { - Entry::Vacant(v) => { - let new_exec = - ActiveWindowExec::new(self.window_exec_plan.clone(), &self.receiver_hook)?; - Ok(v.insert(new_exec)) - } - Entry::Occupied(o) => Ok(o.into_mut()), - } - } } #[async_trait] @@ -157,10 +131,47 @@ impl Operator for WindowFunctionOperator { "WindowFunction" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery: restore the lightweight timestamp index from LSM-Tree. 
+ // Data stays on disk until process_watermark triggers on-demand compute. + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.memory_controller.clone(), + ctx.io_manager.clone(), + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Window Function Operator recovering active timestamps from LSM-Tree..." + ); + + for key in active_keys { + if let Some(ts_nanos) = Self::extract_timestamp(&key) { + self.pending_timestamps.insert(ts_nanos); + } + } + + info!( + pipeline_id = ctx.pipeline_id, + "Window Function Operator successfully rebuilt in-memory indices." + ); + } + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist data into LSM-Tree, defer computation to watermark async fn process_data( &mut self, _input_idx: usize, @@ -169,19 +180,27 @@ impl Operator for WindowFunctionOperator { ) -> Result> { let current_watermark = ctx.current_watermark(); let split_batches = self.filter_and_split_batches(batch, current_watermark)?; + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); for (sub_batch, timestamp) in split_batches { - let exec = self.get_or_create_exec(timestamp)?; - exec.sender - .as_ref() - .ok_or_else(|| anyhow!("window exec sender missing"))? 
- .send(sub_batch) - .map_err(|e| anyhow!("route batch to plan: {e}"))?; + let ts_nanos = to_nanos(timestamp) as u64; + let key = Self::build_state_key(ts_nanos); + + store + .put(key, sub_batch) + .await + .map_err(|e| anyhow!("{e}"))?; + + self.pending_timestamps.insert(ts_nanos); } Ok(vec![]) } + // On-demand compute & GC: pull data from LSM-Tree, run DataFusion, tombstone async fn process_watermark( &mut self, watermark: Watermark, @@ -190,27 +209,48 @@ impl Operator for WindowFunctionOperator { let Watermark::EventTime(current_time) = watermark else { return Ok(vec![]); }; + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + let current_nanos = to_nanos(current_time) as u64; + + let expired_ts: Vec = self + .pending_timestamps + .iter() + .take_while(|&&ts| ts < current_nanos) + .copied() + .collect(); let mut final_outputs = Vec::new(); - let mut expired_timestamps = Vec::new(); - for &k in self.active_execs.keys() { - if k < current_time { - expired_timestamps.push(k); - } else { - break; - } - } + for ts in expired_ts { + let key = Self::build_state_key(ts); + + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; - for ts in expired_timestamps { - let mut exec = self - .active_execs - .remove(&ts) - .ok_or_else(|| anyhow!("missing window exec"))?; - let result_batches = exec.close_and_drain().await?; - for batch in result_batches { - final_outputs.push(StreamOutput::Forward(batch)); + if !batches.is_empty() { + let (tx, rx) = unbounded_channel(); + *self.receiver_hook.write().unwrap() = Some(rx); + + self.window_exec_plan.reset()?; + let mut stream = self + .window_exec_plan + .execute(0, SessionContext::new().task_ctx())?; + + for batch in batches { + tx.send(batch) + .map_err(|e| anyhow!("Failed to send batch to execution plan: {e}"))?; + } + drop(tx); + + while let Some(res) = stream.next().await { + final_outputs.push(StreamOutput::Forward(res?)); + } } + + store.remove_batches(key).map_err(|e| 
anyhow!("{e}"))?; + self.pending_timestamps.remove(&ts); } Ok(final_outputs) @@ -218,9 +258,14 @@ impl Operator for WindowFunctionOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .snapshot_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; Ok(()) } @@ -275,7 +320,8 @@ impl WindowFunctionConstructor { input_schema_unkeyed, window_exec_plan, receiver_hook, - active_execs: BTreeMap::new(), + state_store: None, + pending_timestamps: BTreeSet::new(), }) } } diff --git a/src/runtime/streaming/protocol/control.rs b/src/runtime/streaming/protocol/control.rs index 3b23cb09..e87ccd3b 100644 --- a/src/runtime/streaming/protocol/control.rs +++ b/src/runtime/streaming/protocol/control.rs @@ -79,3 +79,16 @@ pub enum StopMode { pub fn control_channel(capacity: usize) -> (Sender, Receiver) { mpsc::channel(capacity) } + +#[derive(Debug, Clone)] +pub enum JobMasterEvent { + CheckpointAck { + pipeline_id: u32, + epoch: u64, + }, + CheckpointDecline { + pipeline_id: u32, + epoch: u64, + reason: String, + }, +} diff --git a/src/runtime/streaming/protocol/mod.rs b/src/runtime/streaming/protocol/mod.rs index e91e8d8c..28fd85a4 100644 --- a/src/runtime/streaming/protocol/mod.rs +++ b/src/runtime/streaming/protocol/mod.rs @@ -13,4 +13,6 @@ pub mod control; pub mod event; +#[allow(unused_imports)] +pub use control::{ControlCommand, JobMasterEvent, StopMode}; pub use event::{CheckpointBarrier, StreamOutput, Watermark}; diff --git a/src/runtime/streaming/state/error.rs b/src/runtime/streaming/state/error.rs new file mode 100644 index 00000000..10c3c7c5 --- /dev/null +++ b/src/runtime/streaming/state/error.rs @@ -0,0 +1,37 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+ +use crossbeam_channel::TrySendError; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum StateEngineError { + #[error("I/O error during state persistence: {0}")] + IoError(#[from] std::io::Error), + + #[error("Parquet serialization/deserialization failed: {0}")] + ParquetError(#[from] parquet::errors::ParquetError), + + #[error("Arrow computation failed: {0}")] + ArrowError(#[from] arrow::error::ArrowError), + + #[error("Memory hard limit exceeded and spill channel is full")] + MemoryBackpressureTimeout, + + #[error("Background I/O pool has been shut down or disconnected")] + IoPoolDisconnected, + + #[error("State metadata corrupted: {0}")] + Corruption(String), +} + +pub type Result = std::result::Result; + +impl From> for StateEngineError { + fn from(err: TrySendError) -> Self { + match err { + TrySendError::Full(_) => StateEngineError::MemoryBackpressureTimeout, + TrySendError::Disconnected(_) => StateEngineError::IoPoolDisconnected, + } + } +} diff --git a/src/runtime/streaming/state/io_manager.rs b/src/runtime/streaming/state/io_manager.rs new file mode 100644 index 00000000..9b37da1d --- /dev/null +++ b/src/runtime/streaming/state/io_manager.rs @@ -0,0 +1,151 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+ +#[allow(unused_imports)] +use super::error::StateEngineError; +use super::metrics::StateMetricsCollector; +use super::operator_state::{MemTable, OperatorStateStore, TombstoneMap}; +use crossbeam_channel::{Receiver, Sender, TrySendError, bounded}; +use std::panic::{AssertUnwindSafe, catch_unwind}; +use std::sync::Arc; +use std::thread::{self, JoinHandle}; +use std::time::Instant; + +pub struct SpillJob { + pub store: Arc, + pub epoch: u64, + pub data: MemTable, + pub tombstone_snapshot: TombstoneMap, +} + +pub enum CompactJob { + Minor { store: Arc }, + Major { store: Arc }, +} + +pub struct IoPool { + spill_tx: Option>, + compact_tx: Option>, + worker_handles: Vec>, +} + +impl IoPool { + pub fn try_new( + spill_threads: usize, + compact_threads: usize, + metrics: Arc, + ) -> std::io::Result<(Self, IoManager)> { + let (spill_tx, spill_rx) = bounded::(1024); + let (compact_tx, compact_rx) = bounded::(256); + let mut worker_handles = Vec::with_capacity(spill_threads + compact_threads); + + for i in 0..spill_threads.max(1) { + let rx = spill_rx.clone(); + let m = metrics.clone(); + let handle = thread::Builder::new() + .name(format!("fs-spill-worker-{i}")) + .spawn(move || spill_worker_loop(rx, m))?; + worker_handles.push(handle); + } + + for i in 0..compact_threads.max(1) { + let rx = compact_rx.clone(); + let m = metrics.clone(); + let handle = thread::Builder::new() + .name(format!("fs-compact-worker-{i}")) + .spawn(move || compact_worker_loop(rx, m))?; + worker_handles.push(handle); + } + + let manager = IoManager { + spill_tx: spill_tx.clone(), + compact_tx: compact_tx.clone(), + }; + + Ok(( + Self { + spill_tx: Some(spill_tx), + compact_tx: Some(compact_tx), + worker_handles, + }, + manager, + )) + } + + pub fn shutdown(mut self) { + tracing::info!("Initiating graceful shutdown for IoPool..."); + self.spill_tx.take(); + self.compact_tx.take(); + for handle in self.worker_handles.drain(..) 
{ + if let Err(e) = handle.join() { + tracing::error!("I/O Worker thread panicked during shutdown: {:?}", e); + } + } + tracing::info!("IoPool graceful shutdown completed."); + } +} + +#[derive(Clone)] +pub struct IoManager { + spill_tx: Sender, + compact_tx: Sender, +} + +impl IoManager { + pub fn try_send_spill(&self, job: SpillJob) -> Result<(), TrySendError> { + self.spill_tx.try_send(job) + } + pub fn try_send_compact(&self, job: CompactJob) -> Result<(), TrySendError> { + self.compact_tx.try_send(job) + } + pub fn pending_spills(&self) -> usize { + self.spill_tx.len() + } +} + +fn spill_worker_loop(rx: Receiver, metrics: Arc) { + while let Ok(job) = rx.recv() { + let op_id = job.store.operator_id; + let epoch = job.epoch; + let start = Instant::now(); + + let result = catch_unwind(AssertUnwindSafe(|| { + job.store + .execute_spill_sync(job.epoch, job.data, job.tombstone_snapshot, &metrics) + })); + + let duration_ms = start.elapsed().as_millis(); + metrics.record_spill_duration(op_id, duration_ms); + + match result { + Ok(Ok(())) => tracing::debug!(op_id, epoch, duration_ms, "Spill success"), + Ok(Err(e)) => tracing::error!(op_id, epoch, duration_ms, %e, "Spill I/O Error"), + Err(_) => tracing::error!(op_id, epoch, "CRITICAL: Spill thread PANICKED! 
Recovered."), + } + } +} + +fn compact_worker_loop(rx: Receiver, metrics: Arc) { + while let Ok(job) = rx.recv() { + let (store, is_major) = match job { + CompactJob::Minor { store } => (store, false), + CompactJob::Major { store } => (store, true), + }; + + let op_id = store.operator_id; + let start = Instant::now(); + + let result = catch_unwind(AssertUnwindSafe(|| { + store.execute_compact_sync(is_major, &metrics) + })); + + let duration_ms = start.elapsed().as_millis(); + metrics.record_compaction_duration(op_id, is_major, duration_ms); + + match result { + Ok(Ok(())) => tracing::info!(op_id, is_major, duration_ms, "Compaction success"), + Ok(Err(e)) => tracing::error!(op_id, is_major, duration_ms, %e, "Compaction I/O Error"), + Err(_) => tracing::error!(op_id, is_major, "CRITICAL: Compact thread PANICKED!"), + } + } +} diff --git a/src/runtime/streaming/state/metrics.rs b/src/runtime/streaming/state/metrics.rs new file mode 100644 index 00000000..c6d5ae4e --- /dev/null +++ b/src/runtime/streaming/state/metrics.rs @@ -0,0 +1,18 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. + +pub trait StateMetricsCollector: Send + Sync + 'static { + fn record_memory_usage(&self, operator_id: u32, bytes: usize); + fn record_spill_duration(&self, operator_id: u32, duration_ms: u128); + fn record_compaction_duration(&self, operator_id: u32, is_major: bool, duration_ms: u128); + fn inc_io_errors(&self, operator_id: u32); +} + +/// Default no-op implementation. 
+pub struct NoopMetricsCollector; +impl StateMetricsCollector for NoopMetricsCollector { + fn record_memory_usage(&self, _: u32, _: usize) {} + fn record_spill_duration(&self, _: u32, _: u128) {} + fn record_compaction_duration(&self, _: u32, _: bool, _: u128) {} + fn inc_io_errors(&self, _: u32) {} +} diff --git a/src/runtime/streaming/state/mod.rs b/src/runtime/streaming/state/mod.rs new file mode 100644 index 00000000..ae14ad62 --- /dev/null +++ b/src/runtime/streaming/state/mod.rs @@ -0,0 +1,25 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod error; +mod io_manager; +pub mod metrics; +mod operator_state; + +#[allow(unused_imports)] +pub use error::{Result, StateEngineError}; +#[allow(unused_imports)] +pub use io_manager::{CompactJob, IoManager, IoPool, SpillJob}; +#[allow(unused_imports)] +pub use metrics::{NoopMetricsCollector, StateMetricsCollector}; +#[allow(unused_imports)] +pub use operator_state::{MemoryController, OperatorStateStore}; diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs new file mode 100644 index 00000000..86cb3729 --- /dev/null +++ b/src/runtime/streaming/state/operator_state.rs @@ -0,0 +1,1003 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
/// Process-wide byte accounting for operator state, shared via `Arc`.
///
/// Two thresholds drive behavior:
/// * `soft_limit` — once exceeded, callers should start spilling to disk.
/// * `hard_limit` — admission control: writers must block (and spill) before
///   adding data that would push usage past this bound.
///
/// Counters use `Ordering::Relaxed`: the values are advisory heuristics, not
/// synchronization points, so stale reads are acceptable.
#[derive(Debug)]
pub struct MemoryController {
    current_usage: AtomicUsize,
    hard_limit: usize,
    soft_limit: usize,
}

impl MemoryController {
    /// Creates a controller with the given thresholds, starting at zero usage.
    pub fn new(soft_limit: usize, hard_limit: usize) -> Arc<Self> {
        Arc::new(Self {
            current_usage: AtomicUsize::new(0),
            hard_limit,
            soft_limit,
        })
    }

    /// Returns true if admitting `incoming` additional bytes would exceed the
    /// hard limit. Uses a saturating add so a pathological `incoming` near
    /// `usize::MAX` cannot wrap around and be admitted by mistake.
    pub fn exceeds_hard_limit(&self, incoming: usize) -> bool {
        self.current_usage
            .load(Ordering::Relaxed)
            .saturating_add(incoming)
            > self.hard_limit
    }

    /// Returns true once usage has crossed the soft (spill) threshold.
    pub fn should_spill(&self) -> bool {
        self.current_usage.load(Ordering::Relaxed) > self.soft_limit
    }

    /// Charges `bytes` to the controller.
    pub fn record_inc(&self, bytes: usize) {
        self.current_usage.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Releases `bytes` from the controller.
    ///
    /// Uses a saturating subtraction instead of `fetch_sub`: a double-release
    /// anywhere in the accounting would otherwise wrap the counter to a huge
    /// value and permanently trip both limits, stalling every writer.
    pub fn record_dec(&self, bytes: usize) {
        // fetch_update never fails when the closure always returns Some.
        let _ = self
            .current_usage
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |cur| {
                Some(cur.saturating_sub(bytes))
            });
    }

    /// Current tracked usage in bytes (advisory snapshot).
    pub fn usage_bytes(&self) -> usize {
        self.current_usage.load(Ordering::Relaxed)
    }
}
    // Mutable in-memory write buffer: partition key -> accumulated batches.
    active_table: RwLock<MemTable>,
    // Sealed memtables awaiting spill, tagged with the epoch they belong to.
    immutable_tables: Mutex<VecDeque<(u64, MemTable)>>,

    // On-disk parquet segments holding spilled state.
    data_files: RwLock<Vec<PathBuf>>,
    // On-disk parquet segments holding persisted tombstones.
    tombstone_files: RwLock<Vec<PathBuf>>,
    // In-memory delete markers: key -> epoch at which it was deleted.
    tombstones: RwLock<TombstoneMap>,

    mem_ctrl: Arc<MemoryController>,
    io_manager: IoManager,

    data_dir: PathBuf,
    tombstone_dir: PathBuf,

    // Wakes writers blocked on the hard memory limit once a spill completes.
    spill_notify: Arc<Notify>,
    // Guards so at most one spill / one compaction runs per store at a time.
    is_spilling: AtomicBool,
    is_compacting: AtomicBool,
}

impl OperatorStateStore {
    /// Creates the per-operator store rooted at `base_dir/op_<operator_id>`,
    /// with `data/` and `tombstones/` subdirectories created eagerly.
    ///
    /// # Errors
    /// Returns `StateEngineError::IoError` if either directory cannot be created.
    pub fn new(
        operator_id: u32,
        base_dir: impl AsRef<Path>,
        mem_ctrl: Arc<MemoryController>,
        io_manager: IoManager,
    ) -> Result<Arc<Self>> {
        let op_dir = base_dir.as_ref().join(format!("op_{operator_id}"));
        let data_dir = op_dir.join("data");
        let tombstone_dir = op_dir.join("tombstones");

        fs::create_dir_all(&data_dir).map_err(StateEngineError::IoError)?;
        fs::create_dir_all(&tombstone_dir).map_err(StateEngineError::IoError)?;

        Ok(Arc::new(Self {
            operator_id,
            // Epochs start at 1 so epoch 0 can act as "never deleted".
            current_epoch: AtomicU64::new(1),
            active_table: RwLock::new(HashMap::new()),
            immutable_tables: Mutex::new(VecDeque::new()),
            data_files: RwLock::new(Vec::new()),
            tombstone_files: RwLock::new(Vec::new()),
            tombstones: RwLock::new(HashMap::new()),
            mem_ctrl,
            io_manager,
            data_dir,
            tombstone_dir,
            spill_notify: Arc::new(Notify::new()),
            is_spilling: AtomicBool::new(false),
            is_compacting: AtomicBool::new(false),
        }))
    }

    /// Appends `batch` under `key`, applying memory backpressure.
    ///
    /// If admitting the batch would exceed the hard limit, the caller is
    /// suspended until a background spill frees memory. After insertion, a
    /// soft-limit breach seals the active memtable and schedules a spill.
    ///
    /// NOTE(review): `Notify::notified()` only observes notifications issued
    /// after it is awaited; if the spill completes between `trigger_spill()`
    /// and `notified().await`, this relies on `notify_waiters` racing
    /// correctly — confirm a permit-based wakeup isn't needed here.
    pub async fn put(self: &Arc<Self>, key: PartitionKey, batch: RecordBatch) -> Result<()> {
        let size = batch.get_array_memory_size();
        // Backpressure loop: block the writer while the hard limit would be hit.
        while self.mem_ctrl.exceeds_hard_limit(size) {
            self.trigger_spill();
            self.spill_notify.notified().await;
        }

        self.mem_ctrl.record_inc(size);
        self.active_table
            .write()
            .entry(key)
            .or_default()
            .push(batch);

        // Soft limit crossed: seal the active table at the current epoch and
        // hand it to the background spill path.
        if self.mem_ctrl.should_spill() {
            self.downgrade_active_table(self.current_epoch.load(Ordering::Acquire));
            self.trigger_spill();
        }
        Ok(())
    }
    /// Logically deletes all state stored under `key`.
    ///
    /// Records a tombstone tagged with the current epoch, then eagerly drops
    /// any in-memory copies (active and immutable memtables). On-disk data is
    /// filtered lazily at read and compaction time using the tombstone map.
    pub fn remove_batches(&self, key: PartitionKey) -> Result<()> {
        let current_ep = self.current_epoch.load(Ordering::Acquire);
        // Tombstones occupy memory too; charge key bytes plus fixed overhead.
        let tombstone_mem_size = key.len() + TOMBSTONE_ENTRY_OVERHEAD;

        {
            let mut tb_guard = self.tombstones.write();
            // Charge memory only on first insertion; re-deleting the same key
            // simply refreshes the deletion epoch.
            if tb_guard.insert(key.clone(), current_ep).is_none() {
                self.mem_ctrl.record_inc(tombstone_mem_size);
            }
        }

        // Drop the in-memory copies immediately and release their bytes.
        if let Some(batches) = self.active_table.write().remove(&key) {
            let released: usize = batches.iter().map(|b| b.get_array_memory_size()).sum();
            self.mem_ctrl.record_dec(released);
        }

        let mut imm = self.immutable_tables.lock();
        for (_, table) in imm.iter_mut() {
            if let Some(batches) = table.remove(&key) {
                let released: usize = batches.iter().map(|b| b.get_array_memory_size()).sum();
                self.mem_ctrl.record_dec(released);
            }
        }

        Ok(())
    }

    /// Seals the state at `epoch`: moves the active memtable to the immutable
    /// queue, schedules a spill, and advances the epoch counter to `epoch + 1`.
    pub fn snapshot_epoch(self: &Arc<Self>, epoch: u64) -> Result<()> {
        self.downgrade_active_table(epoch);
        self.trigger_spill();
        self.current_epoch
            .store(epoch.saturating_add(1), Ordering::Release);
        Ok(())
    }

    /// Swaps the active memtable for an empty one and queues the old contents
    /// (tagged with `epoch`) for spilling. No-op if the table is empty.
    fn downgrade_active_table(&self, epoch: u64) {
        let mut active_guard = self.active_table.write();
        if active_guard.is_empty() {
            return;
        }
        let old_active = std::mem::take(&mut *active_guard);
        self.immutable_tables.lock().push_back((epoch, old_active));
    }

    /// Returns every live batch for `key`, merging three tiers:
    /// active memtable, immutable memtables, and on-disk parquet segments.
    ///
    /// Tombstone semantics: a tier tagged with epoch `e` is skipped when the
    /// key was deleted at epoch `del_ep >= e` (i.e. only data written strictly
    /// after the deletion survives). Disk reads run on a blocking thread.
    pub async fn get_batches(&self, key: &[u8]) -> Result<Vec<RecordBatch>> {
        let deleted_epoch = self.tombstones.read().get(key).copied();
        let mut out = Vec::new();

        if let Some(batches) = self.active_table.read().get(key) {
            out.extend(batches.clone());
        }

        for (table_epoch, table) in self.immutable_tables.lock().iter().rev() {
            if let Some(del_ep) = deleted_epoch
                && *table_epoch <= del_ep
            {
                continue;
            }
            if let Some(batches) = table.get(key) {
                out.extend(batches.clone());
            }
        }

        // Snapshot the file list so the blocking task owns it outright.
        let paths: Vec<PathBuf> = self.data_files.read().clone();
        if paths.is_empty() {
            return Ok(out);
        }

        let pk = key.to_vec();
        let merged = tokio::task::spawn_blocking(move || -> Result<Vec<RecordBatch>> {
            let mut acc = Vec::new();
            for path in paths {
                let file_epoch = extract_epoch(&path);
                // Skip whole files shadowed by a newer tombstone.
                if let Some(del_ep) = deleted_epoch
                    && file_epoch <= del_ep
                {
                    continue;
                }

                // NOTE(review): bloom filters are written at spill time, but
                // no key predicate is pushed down to this reader — confirm
                // whether the parquet reader actually prunes row groups via
                // the bloom filter here, or if every file is fully scanned.
                let file = File::open(&path).map_err(StateEngineError::IoError)?;
                let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?;
                for maybe in reader.by_ref() {
                    if let Some(filtered) = filter_and_strip_partition_key(&maybe?, &pk)? {
                        acc.push(filtered);
                    }
                }
            }
            Ok(acc)
        })
        .await
        .map_err(|_| StateEngineError::Corruption("Tokio task panicked".into()))??;

        out.extend(merged);
        Ok(out)
    }

    /// Hands the oldest immutable memtable to the I/O pool for spilling.
    ///
    /// The `is_spilling` flag guarantees a single in-flight spill per store.
    /// If the queue is empty or the channel is full/closed, state is restored
    /// and any writers blocked on the hard limit are woken so they can retry.
    fn trigger_spill(self: &Arc<Self>) {
        if !self.is_spilling.swap(true, Ordering::SeqCst) {
            let target = self.immutable_tables.lock().pop_front();
            let Some((epoch, data)) = target else {
                // Nothing to spill: clear the flag and release waiters.
                self.is_spilling.store(false, Ordering::SeqCst);
                self.spill_notify.notify_waiters();
                return;
            };

            // Snapshot tombstones so the spill writes a consistent view.
            let tombstone_snapshot = self.tombstones.read().clone();
            let job = SpillJob {
                store: self.clone(),
                epoch,
                data,
                tombstone_snapshot,
            };

            match self.io_manager.try_send_spill(job) {
                Ok(()) => {}
                // Channel saturated or workers gone: put the memtable back so
                // no data is lost, and let writers retry later.
                Err(TrySendError::Full(j)) | Err(TrySendError::Disconnected(j)) => {
                    self.immutable_tables.lock().push_front((j.epoch, j.data));
                    self.is_spilling.store(false, Ordering::SeqCst);
                    self.spill_notify.notify_waiters();
                }
            }
        }
    }

    /// Schedules a minor (two-file) compaction if none is already running.
    /// Send failure is ignored: the `is_compacting` flag is reset by the
    /// worker, and compaction is retried opportunistically.
    pub fn trigger_minor_compaction(self: &Arc<Self>) {
        if !self.is_compacting.swap(true, Ordering::SeqCst) {
            let _ = self.io_manager.try_send_compact(CompactJob::Minor {
                store: self.clone(),
            });
        }
    }

    /// Schedules a major (all-files) compaction if none is already running.
    pub fn trigger_major_compaction(self: &Arc<Self>) {
        if !self.is_compacting.swap(true, Ordering::SeqCst) {
            let _ = self.io_manager.try_send_compact(CompactJob::Major {
                store: self.clone(),
            });
        }
    }
    /// Synchronously writes one sealed memtable (plus a tombstone snapshot)
    /// to parquet. Runs on an I/O pool worker thread, never on the runtime.
    ///
    /// Failure handling: if the data write fails, the memtable is rebuilt
    /// from the already key-injected batches and pushed back onto the
    /// immutable queue, so no state is lost. Memory is only released after a
    /// successful write. In all exit paths the spill flag is cleared and
    /// blocked writers are woken.
    pub(crate) fn execute_spill_sync(
        self: &Arc<Self>,
        epoch: u64,
        data: MemTable,
        tombstones: TombstoneMap,
        metrics: &Arc<dyn StateMetricsCollector>,
    ) -> Result<()> {
        let mut batches_to_write = Vec::new();
        let mut size_to_release: usize = 0;
        // Distinct-key count feeds the bloom filter NDV hint.
        let distinct_keys_count = data.len() as u64;

        // Flatten the memtable, embedding the partition key as an extra
        // column so a single parquet file can hold every key.
        for (key, batches) in data {
            for batch in batches {
                size_to_release += batch.get_array_memory_size();
                batches_to_write.push(inject_partition_key(&batch, &key)?);
            }
        }

        if !batches_to_write.is_empty() {
            let path = self.data_dir.join(Self::generate_data_file_name(epoch));
            if let Err(e) =
                write_parquet_with_bloom_atomic(&path, &batches_to_write, distinct_keys_count)
            {
                metrics.inc_io_errors(self.operator_id);
                // Roll back: reconstruct the memtable and requeue it so the
                // data stays readable and can be re-spilled later.
                let restored = restore_memtable_from_injected_batches(batches_to_write)?;
                self.immutable_tables.lock().push_front((epoch, restored));
                self.is_spilling.store(false, Ordering::SeqCst);
                self.spill_notify.notify_waiters();
                return Err(e);
            }
            self.data_files.write().push(path);
        }

        // Persist the tombstone snapshot as its own parquet file so deletions
        // survive restarts (replayed by restore_metadata).
        if !tombstones.is_empty() {
            let mut key_builder = BinaryBuilder::new();
            let mut epoch_builder = UInt64Builder::new();
            let tomb_ndv = tombstones.len() as u64;

            for (key, del_epoch) in tombstones.iter() {
                key_builder.append_value(key);
                epoch_builder.append_value(*del_epoch);
            }

            let schema = Arc::new(Schema::new(vec![
                Field::new("deleted_key", DataType::Binary, false),
                Field::new("deleted_epoch", DataType::UInt64, false),
            ]));
            let batch = RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(key_builder.finish()),
                    Arc::new(epoch_builder.finish()),
                ],
            )?;

            let path = self
                .tombstone_dir
                .join(Self::generate_tombstone_file_name(epoch));
            if let Err(e) = write_parquet_with_bloom_atomic(&path, &[batch], tomb_ndv) {
                metrics.inc_io_errors(self.operator_id);
                return Err(e);
            }
            self.tombstone_files.write().push(path);
        }

        // Data is durable: release the accounted memory and report usage.
        self.mem_ctrl.record_dec(size_to_release);
        metrics.record_memory_usage(self.operator_id, self.mem_ctrl.usage_bytes());

        self.is_spilling.store(false, Ordering::SeqCst);
        self.spill_notify.notify_waiters();

        // Chain: if more sealed memtables queued up meanwhile, keep draining.
        if !self.immutable_tables.lock().is_empty() {
            self.trigger_spill();
        }
        Ok(())
    }
    /// Synchronously merges on-disk data files, dropping tombstoned rows.
    ///
    /// * Minor compaction merges the two oldest files; major merges all.
    /// * The merged file is named with the highest epoch among its inputs
    ///   (the "watermark"), preserving tombstone-vs-file epoch ordering.
    /// * Tombstones at or below the watermark are garbage-collected from
    ///   memory and from persisted tombstone files afterwards.
    ///
    /// The `is_compacting` flag is always cleared on exit, even on error.
    pub(crate) fn execute_compact_sync(
        self: &Arc<Self>,
        is_major: bool,
        metrics: &Arc<dyn StateMetricsCollector>,
    ) -> Result<()> {
        // Inner closure so the flag reset below runs on every exit path.
        let result = (|| -> Result<()> {
            let files_to_merge = {
                let df = self.data_files.read();
                // Nothing to merge with fewer than two files.
                if df.len() < 2 {
                    return Ok(());
                }
                if is_major {
                    df.clone()
                } else {
                    // Minor: merge the two oldest segments only.
                    df.iter().take(2).cloned().collect()
                }
            };

            let tombstone_snapshot = self.tombstones.read().clone();
            // Watermark = newest epoch among merged inputs; the output file
            // inherits it so later tombstones still shadow it correctly.
            let compacted_watermark_epoch = files_to_merge
                .iter()
                .map(|p| extract_epoch(p))
                .max()
                .unwrap_or(0);
            let new_path = self
                .data_dir
                .join(Self::generate_data_file_name(compacted_watermark_epoch));

            let mut all_batches = Vec::new();
            let mut estimated_rows = 0;

            // Read every input file, dropping rows deleted at an epoch >= the
            // file's own epoch (the row predates its deletion).
            for path in &files_to_merge {
                let file_epoch = extract_epoch(path);
                let file = File::open(path).map_err(StateEngineError::IoError)?;
                let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?;
                for batch in reader {
                    let b = batch?;
                    if let Some(filtered) =
                        filter_tombstones_from_batch(&b, &tombstone_snapshot, file_epoch)?
                    {
                        estimated_rows += filtered.num_rows() as u64;
                        all_batches.push(filtered);
                    }
                }
            }

            if !all_batches.is_empty() {
                // Row count stands in for NDV (floor of 100) for the bloom hint.
                if let Err(e) = write_parquet_with_bloom_atomic(
                    &new_path,
                    &all_batches,
                    estimated_rows.max(100),
                ) {
                    metrics.inc_io_errors(self.operator_id);
                    return Err(e);
                }
                let mut df = self.data_files.write();
                df.retain(|p| !files_to_merge.contains(p));
                df.push(new_path);
            } else {
                // Everything was tombstoned: just drop the inputs.
                let mut df = self.data_files.write();
                df.retain(|p| !files_to_merge.contains(p));
            }

            // Best-effort deletion; stale files are harmless once delisted.
            for path in &files_to_merge {
                let _ = fs::remove_file(path);
            }

            // Watermark GC: tombstones fully applied by this compaction can
            // never shadow anything again, so reclaim their memory.
            {
                let mut tg = self.tombstones.write();
                let mut memory_freed = 0;

                tg.retain(|key, deleted_epoch| {
                    if *deleted_epoch <= compacted_watermark_epoch {
                        memory_freed += key.len() + TOMBSTONE_ENTRY_OVERHEAD;
                        false
                    } else {
                        true
                    }
                });

                if memory_freed > 0 {
                    self.mem_ctrl.record_dec(memory_freed);
                    metrics.record_memory_usage(self.operator_id, self.mem_ctrl.usage_bytes());
                }
            }

            // Same GC for the persisted tombstone files.
            {
                let mut tf_guard = self.tombstone_files.write();
                tf_guard.retain(|p| {
                    if extract_epoch(p) <= compacted_watermark_epoch {
                        let _ = fs::remove_file(p);
                        return false;
                    }
                    true
                });
            }

            Ok(())
        })();

        self.is_compacting.store(false, Ordering::SeqCst);
        result
    }
    /// Recovers the store's metadata after a restart, up to `safe_epoch`.
    ///
    /// Steps:
    /// 1. Drop all volatile state newer than `safe_epoch` (active memtable,
    ///    newer immutable memtables, and any files from uncommitted epochs).
    /// 2. Reload the tombstone map from persisted tombstone files, keeping
    ///    the max deletion epoch per key.
    /// 3. Scan data files (partition-key column only) to rebuild the set of
    ///    keys that are still live, i.e. not shadowed by a tombstone.
    ///
    /// Returns the set of live partition keys so upstream routing can be
    /// rebuilt. Heavy file scans run on blocking threads.
    pub async fn restore_metadata(&self, safe_epoch: u64) -> Result<HashSet<PartitionKey>> {
        self.active_table.write().clear();
        self.immutable_tables
            .lock()
            .retain(|(e, _)| *e <= safe_epoch);

        // Delete any file written for an epoch past the recovery point —
        // those writes were never acknowledged in a checkpoint.
        let cleanup_future = |files: &mut Vec<PathBuf>| {
            files.retain(|path| {
                if extract_epoch(path) > safe_epoch {
                    let _ = fs::remove_file(path);
                    false
                } else {
                    true
                }
            });
        };
        cleanup_future(&mut self.data_files.write());
        cleanup_future(&mut self.tombstone_files.write());

        let tomb_paths = self.tombstone_files.read().clone();
        let loaded_tombstones = tokio::task::spawn_blocking(move || -> Result<TombstoneMap> {
            let mut map = HashMap::new();
            for path in tomb_paths {
                let file = File::open(&path).map_err(StateEngineError::IoError)?;
                let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?;
                for batch in reader {
                    let batch = batch?;
                    // Tombstone schema is fixed: (deleted_key, deleted_epoch).
                    let key_col = batch
                        .column(0)
                        .as_any()
                        .downcast_ref::<BinaryArray>()
                        .unwrap();
                    let ep_col = batch
                        .column(1)
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();

                    // Keep the newest deletion epoch seen for each key.
                    for i in 0..key_col.len() {
                        let k = key_col.value(i).to_vec();
                        let e = ep_col.value(i);
                        let current_max = map.get(&k).copied().unwrap_or(0);
                        if e > current_max {
                            map.insert(k, e);
                        }
                    }
                }
            }
            Ok(map)
        })
        .await
        .map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??;

        // Re-account the in-memory footprint of the reloaded tombstones.
        let mut total_tombstone_mem = 0;
        for key in loaded_tombstones.keys() {
            total_tombstone_mem += key.len() + TOMBSTONE_ENTRY_OVERHEAD;
        }
        self.mem_ctrl.record_inc(total_tombstone_mem);
        *self.tombstones.write() = loaded_tombstones.clone();

        let data_paths = self.data_files.read().clone();
        let active_keys = tokio::task::spawn_blocking(move || -> Result<HashSet<PartitionKey>> {
            let mut keys = HashSet::new();
            for path in data_paths {
                let file_epoch = extract_epoch(&path);
                let file = File::open(&path).map_err(StateEngineError::IoError)?;
                let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
                // Project only the partition-key column (always written last
                // by inject_partition_key) to keep the scan cheap.
                let schema = builder.parquet_schema();
                let mask = ProjectionMask::leaves(schema, vec![schema.columns().len() - 1]);
                let reader = builder.with_projection(mask).build()?;

                for batch in reader {
                    let batch = batch?;
                    // After projection the key column is the only column.
                    let key_col = batch
                        .column(0)
                        .as_any()
                        .downcast_ref::<BinaryArray>()
                        .unwrap();
                    for i in 0..key_col.len() {
                        let k = key_col.value(i).to_vec();
                        // A key is live unless its deletion epoch shadows the
                        // file it was read from.
                        let is_active = match loaded_tombstones.get(&k) {
                            Some(del_ep) => *del_ep < file_epoch,
                            None => true,
                        };
                        if is_active {
                            keys.insert(k);
                        }
                    }
                }
            }
            Ok(keys)
        })
        .await
        .map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??;

        self.current_epoch.store(safe_epoch + 1, Ordering::Release);
        Ok(active_keys)
    }

    // ========================================================================
    // UUID-based file name generators
    // ========================================================================

    // File names encode the epoch (parsed back by extract_epoch) plus a
    // time-ordered UUIDv7 so concurrent writes for one epoch never collide.
    fn generate_data_file_name(epoch: u64) -> String {
        format!("data-epoch-{}_uuid-{}.parquet", epoch, Uuid::now_v7())
    }

    fn generate_tombstone_file_name(epoch: u64) -> String {
        format!("tombstone-epoch-{}_uuid-{}.parquet", epoch, Uuid::now_v7())
    }
}
/// Writes `batches` to `path` as a bloom-filtered parquet file, atomically.
///
/// Writes to `<path>.tmp`, fsyncs, then renames into place so readers never
/// observe a partially written file. `ndv` is the expected number of distinct
/// values, used to size the bloom filter. Empty input is a no-op (no file).
///
/// All batches are assumed to share the schema of `batches[0]`.
fn write_parquet_with_bloom_atomic(path: &Path, batches: &[RecordBatch], ndv: u64) -> Result<()> {
    if batches.is_empty() {
        return Ok(());
    }
    let tmp = path.with_extension("tmp");
    {
        let file = File::create(&tmp).map_err(StateEngineError::IoError)?;
        let props = WriterProperties::builder()
            .set_bloom_filter_enabled(true)
            .set_bloom_filter_ndv(ndv)
            .build();

        let mut writer = ArrowWriter::try_new(&file, batches[0].schema(), Some(props))?;
        for b in batches {
            writer.write(b)?;
        }
        writer.close()?;
        // Make the bytes durable before publishing the file via rename.
        file.sync_all().map_err(StateEngineError::IoError)?;
    }
    fs::rename(&tmp, path).map_err(StateEngineError::IoError)?;
    Ok(())
}

/// Parses the epoch out of a file name shaped like
/// `data-epoch-<N>_uuid-<...>.parquet` (see the name generators).
/// Returns 0 for any name that does not match, so unknown files are treated
/// as oldest (epoch 0) rather than erroring.
fn extract_epoch(path: &Path) -> u64 {
    let name = path
        .file_name()
        .unwrap_or_default()
        .to_str()
        .unwrap_or_default();
    if let Some(start) = name.find("-epoch-") {
        // 7 == "-epoch-".len(); the epoch digits run until "_uuid-".
        let after = &name[start + 7..];
        if let Some(end) = after.find("_uuid-") {
            return after[..end].parse().unwrap_or(0);
        }
    }
    0
}

/// Returns a copy of `batch` with the partition key appended as a constant
/// binary column (`__fs_partition_key`), so batches from many keys can share
/// one parquet file and be filtered back apart on read.
fn inject_partition_key(batch: &RecordBatch, key: &[u8]) -> Result<RecordBatch> {
    let mut fields = batch.schema().fields().to_vec();
    fields.push(Arc::new(Field::new(
        PARTITION_KEY_COL,
        DataType::Binary,
        false,
    )));
    let schema = Arc::new(Schema::new(fields));
    // Same key value repeated for every row of the batch.
    let key_array = Arc::new(BinaryArray::from_iter_values(std::iter::repeat_n(
        key,
        batch.num_rows(),
    )));
    let mut cols = batch.columns().to_vec();
    cols.push(key_array as Arc<dyn Array>);
    Ok(RecordBatch::try_new(schema, cols)?)
}

/// Drops rows whose key was tombstoned at an epoch >= `file_epoch` (i.e. the
/// row predates its deletion). Returns `None` if every row is dropped, and
/// the batch untouched if there are no tombstones or no key column.
fn filter_tombstones_from_batch(
    batch: &RecordBatch,
    tombstones: &TombstoneMap,
    file_epoch: u64,
) -> Result<Option<RecordBatch>> {
    if tombstones.is_empty() {
        return Ok(Some(batch.clone()));
    }
    // Batches without the injected key column cannot be matched to tombstones.
    let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else {
        return Ok(Some(batch.clone()));
    };

    let key_col = batch
        .column(idx)
        .as_any()
        .downcast_ref::<BinaryArray>()
        .unwrap();
    let mut mask_builder = BooleanBuilder::with_capacity(batch.num_rows());
    let mut has_valid = false;

    for i in 0..batch.num_rows() {
        let key = key_col.value(i).to_vec();
        // Keep the row only if it was written after the deletion epoch.
        let keep = match tombstones.get(&key) {
            Some(deleted_epoch) => *deleted_epoch < file_epoch,
            None => true,
        };
        mask_builder.append_value(keep);
        if keep {
            has_valid = true;
        }
    }

    if !has_valid {
        return Ok(None);
    }
    let mask = mask_builder.finish();
    Ok(Some(arrow::compute::filter_record_batch(batch, &mask)?))
}
/// Keeps only the rows of `batch` whose injected partition key equals
/// `target_key`, then removes the key column so callers see the original
/// schema. Returns `None` when no rows match, and the batch unchanged when
/// it carries no key column (legacy/foreign batches pass through).
fn filter_and_strip_partition_key(
    batch: &RecordBatch,
    target_key: &[u8],
) -> Result<Option<RecordBatch>> {
    let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else {
        return Ok(Some(batch.clone()));
    };
    let key_col = batch
        .column(idx)
        .as_any()
        .downcast_ref::<BinaryArray>()
        .unwrap();
    let mut mask_builder = BooleanBuilder::with_capacity(batch.num_rows());
    for i in 0..batch.num_rows() {
        mask_builder.append_value(key_col.value(i) == target_key);
    }
    let mask = mask_builder.finish();
    let filtered = arrow::compute::filter_record_batch(batch, &mask)?;
    if filtered.num_rows() == 0 {
        return Ok(None);
    }
    // Project away the key column, keeping every other column in order.
    let mut proj: Vec<usize> = (0..filtered.num_columns()).collect();
    proj.retain(|&i| i != idx);
    Ok(Some(filtered.project(&proj)?))
}

/// Rebuilds a memtable from batches that already carry the injected key
/// column — the inverse of the flatten step in `execute_spill_sync`, used to
/// roll back a failed spill. Assumes each batch holds rows for exactly one
/// key (true for spill output, where the key column is constant per batch).
fn restore_memtable_from_injected_batches(batches: Vec<RecordBatch>) -> Result<MemTable> {
    let mut m = MemTable::new();
    for batch in batches {
        // Spill output always has the key column; absence would be a bug.
        let idx = batch.schema().index_of(PARTITION_KEY_COL).unwrap();
        let key_col = batch
            .column(idx)
            .as_any()
            .downcast_ref::<BinaryArray>()
            .unwrap();
        // The column is constant per batch, so row 0 identifies the key.
        let pk = key_col.value(0).to_vec();
        let mut proj: Vec<usize> = (0..batch.num_columns()).collect();
        proj.retain(|&i| i != idx);
        m.entry(pk).or_default().push(batch.project(&proj)?);
    }
    Ok(m)
}

#[cfg(test)]
mod tests {
    use super::super::io_manager::IoPool;
    use super::super::metrics::NoopMetricsCollector;
    use super::*;
    use arrow_array::Int64Array;
    use tempfile::TempDir;

    // Single-column (value: Int64) schema shared by all test batches.
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]))
    }

    fn make_batch(values: &[i64]) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![Arc::new(Int64Array::from(values.to_vec()))],
        )
        .unwrap()
    }

    // Fresh temp dir, 1 MiB soft / 2 MiB hard limits, single-worker I/O pool.
    // The IoPool must be returned (not dropped) to keep workers alive.
    fn setup() -> (TempDir, Arc<MemoryController>, IoManager, IoPool) {
        let tmp = TempDir::new().unwrap();
        let mem = MemoryController::new(1024 * 1024, 2 * 1024 * 1024);
        let metrics: Arc<dyn StateMetricsCollector> = Arc::new(NoopMetricsCollector);
        let (pool, mgr) = IoPool::try_new(1, 1, metrics).unwrap();
        (tmp, mem, mgr, pool)
    }
    // Basic write/read round trip through the in-memory tier.
    #[tokio::test]
    async fn test_put_and_get() {
        let (tmp, mem, mgr, _pool) = setup();
        let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap();

        let key = b"key-a".to_vec();
        let batch = make_batch(&[10, 20, 30]);
        store.put(key.clone(), batch).await.unwrap();

        let result = store.get_batches(&key).await.unwrap();
        assert_eq!(result.len(), 1);
        let col = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col.values(), &[10, 20, 30]);
    }

    // Batches for the same key accumulate instead of overwriting.
    #[tokio::test]
    async fn test_multiple_puts_same_key() {
        let (tmp, mem, mgr, _pool) = setup();
        let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap();

        let key = b"key-x".to_vec();
        store.put(key.clone(), make_batch(&[1])).await.unwrap();
        store.put(key.clone(), make_batch(&[2])).await.unwrap();

        let result = store.get_batches(&key).await.unwrap();
        assert_eq!(result.len(), 2);
    }

    // Missing keys yield an empty result, not an error.
    #[tokio::test]
    async fn test_get_nonexistent_key() {
        let (tmp, mem, mgr, _pool) = setup();
        let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap();

        let result = store.get_batches(b"no-such-key").await.unwrap();
        assert!(result.is_empty());
    }

    // remove_batches hides data previously stored in the active memtable.
    #[tokio::test]
    async fn test_remove_batches() {
        let (tmp, mem, mgr, _pool) = setup();
        let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap();

        let key = b"key-del".to_vec();
        store.put(key.clone(), make_batch(&[42])).await.unwrap();

        store.remove_batches(key.clone()).unwrap();

        let result = store.get_batches(&key).await.unwrap();
        assert!(result.is_empty());
    }

    // Deletion is scoped to one key; siblings remain readable.
    #[tokio::test]
    async fn test_remove_does_not_affect_other_keys() {
        let (tmp, mem, mgr, _pool) = setup();
        let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap();

        let k1 = b"key-1".to_vec();
        let k2 = b"key-2".to_vec();
        store.put(k1.clone(), make_batch(&[1])).await.unwrap();
        store.put(k2.clone(), make_batch(&[2])).await.unwrap();

        store.remove_batches(k1.clone()).unwrap();

        assert!(store.get_batches(&k1).await.unwrap().is_empty());
        assert_eq!(store.get_batches(&k2).await.unwrap().len(), 1);
    }

    // snapshot_epoch(n) leaves the store at epoch n + 1.
    #[tokio::test]
    async fn test_snapshot_epoch_advances() {
        let (tmp, mem, mgr, _pool) = setup();
        let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap();

        store.put(b"k".to_vec(), make_batch(&[1])).await.unwrap();
        store.snapshot_epoch(5).unwrap();

        assert_eq!(store.current_epoch.load(Ordering::Acquire), 6);
    }

    #[tokio::test]
    async fn test_data_survives_snapshot_via_spill() {
        let (tmp, mem, mgr, _pool) = setup();
        let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap();

        let key = b"persist".to_vec();
        store.put(key.clone(), make_batch(&[99])).await.unwrap();
        store.snapshot_epoch(1).unwrap();

        // snapshot_epoch triggers a spill; wait for the background worker to
        // flush the data to disk so get_batches can read it from parquet files.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        let result = store.get_batches(&key).await.unwrap();
        assert!(!result.is_empty());
        let col = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col.values(), &[99]);
    }

    // A tombstone with epoch > the immutable table's epoch shadows its data.
    #[tokio::test]
    async fn test_tombstone_hides_immutable_data() {
        let (tmp, mem, mgr, _pool) = setup();
        let store = OperatorStateStore::new(1, tmp.path(), mem, mgr).unwrap();

        let key = b"will-die".to_vec();
        store.put(key.clone(), make_batch(&[7])).await.unwrap();

        // Move to immutable via snapshot
        store.snapshot_epoch(1).unwrap();

        // Tombstone at epoch 2 (> immutable epoch 1)
        store.current_epoch.store(2, Ordering::Release);
        store.remove_batches(key.clone()).unwrap();

        let result = store.get_batches(&key).await.unwrap();
        assert!(result.is_empty());
    }

    // inc/dec bookkeeping and the soft-limit predicate.
    #[tokio::test]
    async fn test_memory_controller_tracking() {
        let mem = MemoryController::new(1024, 2048);
        assert_eq!(mem.usage_bytes(), 0);

        mem.record_inc(100);
        assert_eq!(mem.usage_bytes(), 100);

        mem.record_dec(40);
        assert_eq!(mem.usage_bytes(), 60);

        assert!(!mem.should_spill());
        mem.record_inc(1000);
        assert!(mem.should_spill());
    }

    // Hard-limit admission accounts for current usage plus incoming bytes.
    #[tokio::test]
    async fn test_memory_controller_hard_limit() {
        let mem = MemoryController::new(512, 1024);
        assert!(!mem.exceeds_hard_limit(500));
        assert!(mem.exceeds_hard_limit(1025));

        mem.record_inc(800);
        assert!(mem.exceeds_hard_limit(300));
        assert!(!mem.exceeds_hard_limit(200));
    }

    // extract_epoch parses both file kinds and falls back to 0 on mismatch.
    #[test]
    fn test_extract_epoch() {
        let path = PathBuf::from("/tmp/data-epoch-42_uuid-abc123.parquet");
        assert_eq!(extract_epoch(&path), 42);

        let path2 = PathBuf::from("/tmp/tombstone-epoch-100_uuid-def456.parquet");
        assert_eq!(extract_epoch(&path2), 100);

        let path3 = PathBuf::from("/tmp/random-file.parquet");
        assert_eq!(extract_epoch(&path3), 0);
    }

    // inject adds the key column; filter+strip restores the original schema.
    #[test]
    fn test_inject_and_strip_partition_key() {
        let batch = make_batch(&[1, 2, 3]);
        let key = b"pk-test";

        let injected = inject_partition_key(&batch, key).unwrap();
        assert_eq!(injected.num_columns(), 2);
        assert!(injected.schema().index_of(PARTITION_KEY_COL).is_ok());

        let stripped = filter_and_strip_partition_key(&injected, key)
            .unwrap()
            .unwrap();
        assert_eq!(stripped.num_columns(), 1);
        let col = stripped
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col.values(), &[1, 2, 3]);
    }

    // Filtering for a different key yields None (no matching rows).
    #[test]
    fn test_filter_partition_key_mismatch() {
        let batch = make_batch(&[1, 2]);
        let injected = inject_partition_key(&batch, b"pk-a").unwrap();

        let result = filter_and_strip_partition_key(&injected, b"pk-b").unwrap();
        assert!(result.is_none());
    }

    // Injected batches group back into a memtable keyed by partition key.
    #[test]
    fn test_restore_memtable_roundtrip() {
        let batch1 = inject_partition_key(&make_batch(&[10]), b"k1").unwrap();
        let batch2 = inject_partition_key(&make_batch(&[20]), b"k2").unwrap();
        let batch3 = inject_partition_key(&make_batch(&[30]), b"k1").unwrap();

        let restored =
            restore_memtable_from_injected_batches(vec![batch1, batch2, batch3]).unwrap();

        assert_eq!(restored.len(), 2);
        assert_eq!(restored[b"k1".as_ref()].len(), 2);
        assert_eq!(restored[b"k2".as_ref()].len(), 1);
    }

    // Atomic parquet writer round-trips data through disk.
    #[test]
    fn test_write_and_read_parquet() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("test.parquet");

        let batch = make_batch(&[100, 200, 300]);
        write_parquet_with_bloom_atomic(&path, std::slice::from_ref(&batch), 1).unwrap();

        let file = File::open(&path).unwrap();
        let reader = ParquetRecordBatchReaderBuilder::try_new(file)
            .unwrap()
            .build()
            .unwrap();

        let read_batches: Vec<RecordBatch> = reader.map(|r| r.unwrap()).collect();
        assert_eq!(read_batches.len(), 1);
        let col = read_batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col.values(), &[100, 200, 300]);
    }

    // Tombstone filtering honors the file-epoch vs deletion-epoch ordering.
    #[test]
    fn test_filter_tombstones_from_batch() {
        let batch = make_batch(&[1, 2, 3]);
        let key = b"victim";
        let injected = inject_partition_key(&batch, key).unwrap();

        let mut tombstones: TombstoneMap = HashMap::new();
        tombstones.insert(key.to_vec(), 10);

        // file_epoch <= tombstone epoch => fully filtered
        let result = filter_tombstones_from_batch(&injected, &tombstones, 5).unwrap();
        assert!(result.is_none());

        // file_epoch > tombstone epoch => data survives
        let result = filter_tombstones_from_batch(&injected, &tombstones, 15).unwrap();
        assert!(result.is_some());
    }

    // Writing an empty slice must not create any file (atomic no-op).
    #[test]
    fn test_write_empty_batches_is_noop() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("empty.parquet");

        write_parquet_with_bloom_atomic(&path, &[], 0).unwrap();
        assert!(!path.exists());
    }
}
crate::runtime::streaming::job::JobManager; + use crate::runtime::streaming::job::{JobManager, StateConfig}; use std::sync::Arc; let registry = Arc::new(Registry::new()); @@ -168,7 +168,11 @@ fn initialize_job_manager(config: &GlobalConfig) -> Result<()> { .max_memory_bytes .unwrap_or(256 * 1024 * 1024); - JobManager::init(factory, max_memory_bytes).context("JobManager service failed to start")?; + let state_base_dir = std::env::temp_dir().join("function-stream").join("state"); + let state_config = StateConfig::default(); + + JobManager::init(factory, max_memory_bytes, state_base_dir, state_config) + .context("JobManager service failed to start")?; Ok(()) } diff --git a/src/sql/analysis/aggregate_rewriter.rs b/src/sql/analysis/aggregate_rewriter.rs index d7be0db8..ddcb0294 100644 --- a/src/sql/analysis/aggregate_rewriter.rs +++ b/src/sql/analysis/aggregate_rewriter.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use crate::sql::analysis::streaming_window_analzer::StreamingWindowAnalzer; use crate::sql::logical_node::aggregate::StreamWindowAggregateNode; use crate::sql::logical_node::key_calculation::{KeyExtractionNode, KeyExtractionStrategy}; +use crate::sql::logical_node::updating_aggregate::ContinuousAggregateNode; use crate::sql::schema::StreamSchemaProvider; use crate::sql::types::{ QualifiedField, TIMESTAMP_FIELD, WindowBehavior, WindowType, build_df_schema_with_metadata, @@ -70,10 +71,10 @@ impl TreeNodeRewriter for AggregateRewriter<'_> { }) .collect(); - // 3. Dispatch to Updating Aggregate if no windowing is detected. + // 3. Dispatch to ContinuousAggregateNode (UpdatingAggregate) if no windowing is detected. 
let input_window = StreamingWindowAnalzer::get_window(&agg.input)?; if window_exprs.is_empty() && input_window.is_none() { - return self.rewrite_as_updating_aggregate( + return self.rewrite_as_continuous_updating_aggregate( agg.input, key_fields, agg.group_expr, @@ -174,9 +175,9 @@ impl<'a> AggregateRewriter<'a> { })) } - /// [Strategy] Rewrites standard GROUP BY into a non-windowed updating aggregate. + /// [Strategy] Rewrites standard GROUP BY into a ContinuousAggregateNode with retraction semantics. /// Injected max(_timestamp) ensures the streaming pulse (Watermark) continues to propagate. - fn rewrite_as_updating_aggregate( + fn rewrite_as_continuous_updating_aggregate( &self, input: Arc<LogicalPlan>, key_fields: Vec<QualifiedField>, @@ -184,6 +185,7 @@ impl<'a> AggregateRewriter<'a> { mut aggr_expr: Vec<Expr>, schema: Arc<DFSchema>, ) -> Result<Transformed<LogicalPlan>> { + let key_count = key_fields.len(); let keyed_input = self.build_keyed_input(input, &group_expr, &key_fields)?; // Ensure the updating stream maintains time awareness. @@ -207,14 +209,23 @@ impl<'a> AggregateRewriter<'a> { schema.metadata().clone(), )?); - let aggregate = Aggregate::try_new_with_schema( + let base_aggregate = Aggregate::try_new_with_schema( Arc::new(keyed_input), group_expr, aggr_expr, output_schema, )?; - Ok(Transformed::yes(LogicalPlan::Aggregate(aggregate))) + let continuous_node = ContinuousAggregateNode::try_new( + LogicalPlan::Aggregate(base_aggregate), + (0..key_count).collect(), + None, + self.schema_provider.planning_options.ttl, + )?; + + Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(continuous_node), + }))) } /// [Strategy] Reconciles window definitions between the input stream and the current GROUP BY. @@ -232,24 +243,16 @@ impl<'a> AggregateRewriter<'a> { let has_group_window = !window_expr_info.is_empty(); match (input_window, has_group_window) { - // Re-aggregation or subquery with an existing window. 
(Some(i_win), true) => { let (idx, g_win) = window_expr_info.pop().unwrap(); if i_win != g_win { - return plan_err!( - "Inconsistent windowing: input is {:?}, but group by is {:?}", - i_win, - g_win - ); + return plan_err!("Inconsistent windowing detected"); } if let Some(field) = visitor.fields.iter().next() { group_expr[idx] = Expr::Column(field.qualified_column()); Ok(WindowBehavior::InData) } else { - if matches!(i_win, WindowType::Session { .. }) { - return plan_err!("Nested session windows are not supported"); - } group_expr.remove(idx); Ok(WindowBehavior::FromOperator { window: i_win, @@ -259,7 +262,6 @@ impl<'a> AggregateRewriter<'a> { }) } } - // First-time windowing defined in this aggregate. (None, true) => { let (idx, g_win) = window_expr_info.pop().unwrap(); group_expr.remove(idx); @@ -270,9 +272,8 @@ impl<'a> AggregateRewriter<'a> { is_nested: false, }) } - // Passthrough: input is already windowed, no new window in group by. (Some(_), false) => Ok(WindowBehavior::InData), - _ => unreachable!("Dispatched to non-windowed path previously"), + _ => unreachable!("Handled by updating path"), } } } diff --git a/src/storage/stream_catalog/manager.rs b/src/storage/stream_catalog/manager.rs index 3804a95a..471e3cd9 100644 --- a/src/storage/stream_catalog/manager.rs +++ b/src/storage/stream_catalog/manager.rs @@ -17,7 +17,7 @@ use datafusion::common::{Result as DFResult, internal_err, plan_err}; use prost::Message; use protocol::function_stream_graph::FsProgram; use protocol::storage::{self as pb, table_definition}; -use tracing::{info, warn}; +use tracing::{debug, info, warn}; use unicase::UniCase; use crate::sql::common::constants::sql_field; @@ -88,6 +88,7 @@ impl CatalogManager { table_name: &str, fs_program: &FsProgram, comment: &str, + checkpoint_interval_ms: u64, ) -> DFResult<()> { let program_bytes = fs_program.encode_to_vec(); let def = pb::StreamingTableDefinition { @@ -95,11 +96,13 @@ impl CatalogManager { created_at_millis: 
chrono::Utc::now().timestamp_millis(), fs_program_bytes: program_bytes, comment: comment.to_string(), + checkpoint_interval_ms, + latest_checkpoint_epoch: 0, }; let payload = def.encode_to_vec(); let key = Self::build_streaming_job_key(table_name); self.store.put(&key, payload)?; - info!(table = %table_name, "Streaming job definition persisted"); + info!(table = %table_name, interval_ms = checkpoint_interval_ms, "Streaming job definition persisted"); Ok(()) } @@ -110,7 +113,38 @@ impl CatalogManager { Ok(()) } - pub fn load_streaming_job_definitions(&self) -> DFResult> { + /// Persist the globally-completed checkpoint epoch after all operators ACK. + /// Only advances forward; stale epochs are silently ignored. + pub fn commit_job_checkpoint(&self, table_name: &str, epoch: u64) -> DFResult<()> { + let key = Self::build_streaming_job_key(table_name); + + let current_payload = self.store.get(&key)?.ok_or_else(|| { + datafusion::common::DataFusionError::Plan(format!( + "Cannot commit checkpoint: Streaming job '{}' not found in catalog", + table_name + )) + })?; + + let mut def = + pb::StreamingTableDefinition::decode(current_payload.as_slice()).map_err(|e| { + datafusion::common::DataFusionError::Execution(format!( + "Protobuf decode error: {}", + e + )) + })?; + + if epoch > def.latest_checkpoint_epoch { + def.latest_checkpoint_epoch = epoch; + let new_payload = def.encode_to_vec(); + self.store.put(&key, new_payload)?; + debug!(table = %table_name, epoch = epoch, "Checkpoint metadata committed to Catalog"); + } + + Ok(()) + } + + /// Returns (table_name, program, checkpoint_interval_ms, latest_checkpoint_epoch). 
+ pub fn load_streaming_job_definitions(&self) -> DFResult<Vec<(String, FsProgram, u64, u64)>> { let records = self.store.scan_prefix(STREAMING_JOB_KEY_PREFIX)?; let mut out = Vec::with_capacity(records.len()); for (key, payload) in records { @@ -136,7 +170,12 @@ impl CatalogManager { continue; } }; - out.push((def.table_name, program)); + out.push(( + def.table_name, + program, + def.checkpoint_interval_ms, + def.latest_checkpoint_epoch, + )); } Ok(out) } @@ -522,12 +561,28 @@ pub fn restore_streaming_jobs_from_store() { let mut restored = 0usize; let mut failed = 0usize; - for (table_name, fs_program) in definitions { + for (table_name, fs_program, interval_ms, latest_epoch) in definitions { let jm = job_manager.clone(); let name = table_name.clone(); - match rt.block_on(jm.submit_job(name.clone(), fs_program)) { + + let custom_interval = if interval_ms > 0 { + Some(interval_ms) + } else { + None + }; + let recovery_epoch = if latest_epoch > 0 { + Some(latest_epoch) + } else { + None + }; + + match rt.block_on(jm.submit_job(name.clone(), fs_program, custom_interval, recovery_epoch)) + { Ok(job_id) => { - info!(table = %table_name, job_id = %job_id, "Streaming job restored"); + info!( + table = %table_name, job_id = %job_id, + epoch = latest_epoch, "Streaming job restored" + ); restored += 1; } Err(e) => {