diff --git a/Cargo.lock b/Cargo.lock index bb673af67..a841408c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,7 +156,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -167,7 +167,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -2069,7 +2069,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -2279,9 +2279,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "fusillade" -version = "15.1.1" +version = "16.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8df645c11fe9c557f76e51d1fa435a167f2a9c271584b109a575248daeb39f6c" +checksum = "98c322e8248970d339960fb800ccc313bc69cd41c828e94ac0256765eb52f505" dependencies = [ "anyhow", "async-trait", @@ -3681,7 +3681,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -5109,7 +5109,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -5180,7 +5180,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -5643,7 +5643,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -6026,7 +6026,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -6942,7 +6942,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] diff --git a/dwctl/Cargo.toml b/dwctl/Cargo.toml index 6e2f0d9b7..16c3f4768 100644 --- a/dwctl/Cargo.toml +++ b/dwctl/Cargo.toml @@ -19,7 +19,7 @@ embedded-db = ["dep:postgresql_embedded"] [dependencies] axum = { version = "0.8", features = ["multipart"] } -fusillade = { version = "15.1.1" } +fusillade = { version = "16.0.0" } tokio = { version = "1.0", features = ["full"] } tokio-stream = { version = "0.1", features = ["sync"] } tokio-util = "0.7" diff --git a/dwctl/src/api/handlers/batches.rs b/dwctl/src/api/handlers/batches.rs index 1cf50f153..4ffcde120 100644 --- a/dwctl/src/api/handlers/batches.rs +++ b/dwctl/src/api/handlers/batches.rs @@ -472,7 +472,7 @@ pub async fn create_batch( } } - let windows = vec![(req.completion_window.clone(), parse_window_to_seconds(&req.completion_window))]; + let windows = vec![(req.completion_window.clone(), None, parse_window_to_seconds(&req.completion_window))]; let states = vec!["pending".to_string(), "claimed".to_string(), "processing".to_string()]; let model_throughputs = batch_model_info.throughputs; @@ -689,7 +689,7 @@ async fn reserve_capacity_for_batch( file_model_counts: &HashMap, model_throughputs: &HashMap, model_ids_by_alias: &HashMap, - windows: &[(String, i64)], + windows: &[(String, Option, i64)], states: &[String], model_filter: &[String], relaxation_factor: f32, @@ -733,7 +733,7 @@ async fn reserve_capacity_for_batch( // Fetch pending counts AFTER locks to avoid stale snapshots let pending_counts = state .request_manager - .get_pending_request_counts_by_model_and_completion_window(windows, states, model_filter, true) + .get_pending_request_counts_by_model_and_window(windows, states, model_filter, true) .await .map_err(|e| Error::Internal { operation: format!("get pending counts: {}", e), @@ -2796,7 +2796,7 @@ mod tests { let model_throughputs = HashMap::from([(alias.clone(), 1000.0_f32)]); let model_ids_by_alias = HashMap::from([(alias.clone(), model_id)]); - let windows = vec![("24h".to_string(), super::parse_window_to_seconds("24h"))]; + let windows = vec![("1h".to_string(), None, super::parse_window_to_seconds("24h"))]; let states = vec!["pending".to_string(), "claimed".to_string(), "processing".to_string()]; let model_filter = vec![alias.clone()]; @@ -2858,7 +2858,7 @@ mod tests { let model_throughputs = HashMap::from([(alias.clone(), 0.0_f32)]); let model_ids_by_alias = HashMap::from([(alias.clone(), model_id)]); - let windows = vec![("1h".to_string(), super::parse_window_to_seconds("1h"))]; + let windows = vec![("1h".to_string(), None, super::parse_window_to_seconds("1h"))]; let states = vec!["pending".to_string(), "claimed".to_string(), "processing".to_string()]; let model_filter = vec![alias.clone()]; @@ -3219,7 +3219,7 @@ mod tests { let file_model_counts = HashMap::from([(alias.clone(), 5_i64)]); let model_throughputs = HashMap::from([(alias.clone(), 0.001_f32)]); let model_ids_by_alias = HashMap::from([(alias.clone(), model_id)]); - let windows = vec![("1h".to_string(), super::parse_window_to_seconds("1h"))]; + let windows = vec![("1h".to_string(), None, super::parse_window_to_seconds("1h"))]; let states = vec!["pending".to_string(), "claimed".to_string(), "processing".to_string()]; let model_filter = vec![alias.clone()]; @@ -3286,7 +3286,7 @@ mod tests { let model_filter = vec![alias.clone()]; // 1h window — strict (factor defaults to 1.0), 5 > 3.6, rejected - let windows_1h = vec![("1h".to_string(), super::parse_window_to_seconds("1h"))]; + let windows_1h = vec![("1h".to_string(), None, super::parse_window_to_seconds("1h"))]; let err = super::reserve_capacity_for_batch( &state, "1h", @@ -3303,7 +3303,7 @@ mod tests { assert!(matches!(err, Error::TooManyRequests { .. }), "1h should be rejected — not relaxed"); // 24h window — factor=10.0, effective capacity = 86400 * 0.001 * 10 = 864, accepted - let windows_24h = vec![("24h".to_string(), super::parse_window_to_seconds("24h"))]; + let windows_24h = vec![("1h".to_string(), None, super::parse_window_to_seconds("24h"))]; let reservation_ids = super::reserve_capacity_for_batch( &state, "24h", diff --git a/dwctl/src/api/handlers/queue.rs b/dwctl/src/api/handlers/queue.rs index c807b2955..5267d12ad 100644 --- a/dwctl/src/api/handlers/queue.rs +++ b/dwctl/src/api/handlers/queue.rs @@ -2,8 +2,12 @@ //! //! Endpoints for querying queue depth and pending request metrics from fusillade. -use axum::{extract::State, response::Json}; +use axum::{ + extract::{Query, State}, + response::Json, +}; use fusillade::Storage; +use serde::Deserialize; use sqlx_pool_router::PoolProvider; use std::collections::HashMap; @@ -15,9 +19,72 @@ use crate::{ errors::Error, }; -/// Nested map of pending request counts: model -> completion_window -> count +/// Strict duration parser for `/demand` window entries. +/// +/// Unlike [`parse_window_to_seconds`] — which is forgiving on purpose for +/// the batch API (zero/negative/malformed input defaults to 24h) — this +/// returns `None` for anything malformed so the handler can reject the +/// request with 400. Zero is accepted; it's a meaningful lower bound +/// (`0s:1h` = "strictly future 0..1h"). +fn parse_demand_duration(raw: &str) -> Option { + let s = raw.trim(); + let (digits, mult): (&str, i64) = if let Some(d) = s.strip_suffix('h') { + (d, 3600) + } else if let Some(d) = s.strip_suffix('m') { + (d, 60) + } else if let Some(d) = s.strip_suffix('s') { + (d, 1) + } else { + return None; + }; + let n: i64 = digits.parse().ok()?; + if n < 0 { + return None; + } + n.checked_mul(mult) +} + +/// Nested map of pending request counts: model -> window_label -> count type PendingCountsByModelAndWindow = HashMap>; +/// Query parameters for the demand endpoint. +#[derive(Debug, Deserialize)] +pub struct DemandQuery { + /// Comma-separated windows, each either `` (shorthand for + /// `0s:`) or `:`. Both `start` and `end` are offsets + /// from `now` and accept the same `` form as batch + /// completion-window strings (`h`, `m`, `s`). Required. + pub window: String, +} + +/// Parse one entry from the `window=` query list. +/// +/// Returns `Ok(None)` for an empty (skipped) entry, `Ok(Some(...))` for a +/// valid entry, or `Err` for malformed input. Shorthand `` returns +/// `start = None` (no lower bound, including overdue — matches the legacy +/// `<= now + N` behaviour of `/pending-request-counts`). Explicit +/// `:` returns `start = Some(...)` and enforces the lower bound +/// strictly. The label is the caller's raw input so scouter can send +/// `window=1h,24h` and still match `"1h"` / `"24h"` keys on the response. +fn parse_demand_window(raw: &str) -> Result, i64)>, String> { + let trimmed = raw.trim(); + if trimmed.is_empty() { + return Ok(None); + } + let (start_secs, end_secs) = match trimmed.split_once(':') { + Some((start, end)) => { + let s = parse_demand_duration(start).ok_or_else(|| format!("malformed window start in {:?}", trimmed))?; + let e = parse_demand_duration(end).ok_or_else(|| format!("malformed window end in {:?}", trimmed))?; + (Some(s), e) + } + None => { + let e = parse_demand_duration(trimmed).ok_or_else(|| format!("malformed window {:?}", trimmed))?; + (None, e) + } + }; + Ok(Some((trimmed.to_string(), start_secs, end_secs))) +} + /// Get pending, claimed, and processing request counts grouped by model and completion window /// /// Returns a nested map showing how many pending requests are queued for each @@ -48,14 +115,14 @@ pub async fn get_pending_request_counts( .batches .allowed_completion_windows .iter() - .map(|window| (window.clone(), parse_window_to_seconds(window))) + .map(|window| (window.clone(), None, parse_window_to_seconds(window))) .collect::>(); let states = vec!["pending".to_string(), "claimed".to_string(), "processing".to_string()]; // Include claimed and processing to get a more complete picture of queue depth let model_filter: Vec = Vec::new(); let counts = state .request_manager - .get_pending_request_counts_by_model_and_completion_window(&windows, &states, &model_filter, false) + .get_pending_request_counts_by_model_and_window(&windows, &states, &model_filter, false) .await .map_err(|e| Error::Internal { operation: format!("get pending request counts: {}", e), @@ -64,6 +131,77 @@ pub async fn get_pending_request_counts( Ok(Json(counts)) } +/// Get pending request demand bucketed by deadline window. +/// +/// Returns, per model, counts of pending/claimed/processing requests whose +/// deadline (`submitted_at + completion_window`) falls within each +/// caller-specified window. Each window is either `` (shorthand for +/// `0s:`, matching the legacy "due within N" semantic) or +/// `:` for a disjoint range. Both bounds are offsets from +/// `now`. +/// +/// Windows can overlap or be disjoint — the caller chooses. This endpoint +/// is deliberately decoupled from `config.batches.allowed_completion_windows` +/// so replica-allocation consumers can pick the lookahead shape they care +/// about independently of whatever completion-window SLAs the batch API +/// exposes to users. +/// +/// Excludes the same categories as `/pending-request-counts`: escalated +/// requests, requests without a template_id, and requests in batches being +/// cancelled. +#[utoipa::path( + get, + path = "/admin/api/v1/monitoring/demand", + params( + ( + "window" = String, + Query, + description = "Comma-separated windows, e.g. `1h,24h` (cumulative) or `0s:1h,1h:24h` (disjoint)", + ), + ), + responses( + (status = 200, description = "Pending request counts by model and window", body = HashMap>), + (status = 400, description = "Missing or malformed window parameter"), + (status = 500, description = "Internal server error"), + ), + tag = "monitoring", +)] +#[tracing::instrument(skip_all)] +pub async fn get_demand( + State(state): State>, + Query(params): Query, + _: RequiresPermission, +) -> Result, Error> { + let windows: Vec<(String, Option, i64)> = params + .window + .split(',') + .map(parse_demand_window) + .collect::, _>>() + .map_err(|message| Error::BadRequest { message })? + .into_iter() + .flatten() + .collect(); + + if windows.is_empty() { + return Err(Error::BadRequest { + message: "window query parameter must list at least one window (e.g. `window=1h,24h` or `window=0s:1h,1h:24h`)".to_string(), + }); + } + + let states = vec!["pending".to_string(), "claimed".to_string(), "processing".to_string()]; + let model_filter: Vec = Vec::new(); + + let counts = state + .request_manager + .get_pending_request_counts_by_model_and_window(&windows, &states, &model_filter, false) + .await + .map_err(|e| Error::Internal { + operation: format!("get demand by window: {}", e), + })?; + + Ok(Json(counts)) +} + #[cfg(test)] mod tests { use super::*; @@ -113,4 +251,100 @@ mod tests { // Should be empty when no requests exist assert_eq!(counts.len(), 0, "Should have no pending requests"); } + + #[sqlx::test] + async fn test_demand_requires_system_permission(pool: sqlx::PgPool) { + let (server, _bg): (TestServer, _) = create_test_app(pool.clone(), false).await; + + let standard_user = create_test_user(&pool, Role::StandardUser).await; + let response = server + .get("/admin/api/v1/monitoring/demand?window=1h,24h") + .add_header(&add_auth_headers(&standard_user)[0].0, &add_auth_headers(&standard_user)[0].1) + .add_header(&add_auth_headers(&standard_user)[1].0, &add_auth_headers(&standard_user)[1].1) + .await; + response.assert_status(axum::http::StatusCode::FORBIDDEN); + + let platform_manager = create_test_user(&pool, Role::PlatformManager).await; + let response = server + .get("/admin/api/v1/monitoring/demand?window=1h,24h") + .add_header(&add_auth_headers(&platform_manager)[0].0, &add_auth_headers(&platform_manager)[0].1) + .add_header(&add_auth_headers(&platform_manager)[1].0, &add_auth_headers(&platform_manager)[1].1) + .await; + response.assert_status_ok(); + } + + #[sqlx::test] + async fn test_demand_rejects_missing_window(pool: PgPool) { + let (server, _bg): (TestServer, _) = create_test_app(pool.clone(), false).await; + let admin = create_test_admin_user(&pool, Role::PlatformManager).await; + + let response = server + .get("/admin/api/v1/monitoring/demand") + .add_header(&add_auth_headers(&admin)[0].0, &add_auth_headers(&admin)[0].1) + .add_header(&add_auth_headers(&admin)[1].0, &add_auth_headers(&admin)[1].1) + .await; + response.assert_status(axum::http::StatusCode::BAD_REQUEST); + } + + #[sqlx::test] + async fn test_demand_rejects_empty_window(pool: PgPool) { + let (server, _bg): (TestServer, _) = create_test_app(pool.clone(), false).await; + let admin = create_test_admin_user(&pool, Role::PlatformManager).await; + + let response = server + .get("/admin/api/v1/monitoring/demand?window=") + .add_header(&add_auth_headers(&admin)[0].0, &add_auth_headers(&admin)[0].1) + .add_header(&add_auth_headers(&admin)[1].0, &add_auth_headers(&admin)[1].1) + .await; + response.assert_status(axum::http::StatusCode::BAD_REQUEST); + } + + #[sqlx::test] + async fn test_demand_accepts_arbitrary_windows(pool: PgPool) { + // Caller-supplied windows don't need to match + // config.batches.allowed_completion_windows — the point of this + // endpoint is to decouple the two. Mixing cumulative (`2h`) and + // disjoint (`1h:72h`) shapes should work in the same request. + let (server, _bg): (TestServer, _) = create_test_app(pool.clone(), false).await; + let admin = create_test_admin_user(&pool, Role::PlatformManager).await; + + let response = server + .get("/admin/api/v1/monitoring/demand?window=15m,2h,1h:72h") + .add_header(&add_auth_headers(&admin)[0].0, &add_auth_headers(&admin)[0].1) + .add_header(&add_auth_headers(&admin)[1].0, &add_auth_headers(&admin)[1].1) + .await; + response.assert_status_ok(); + let counts: HashMap> = response.json(); + assert_eq!(counts.len(), 0, "no pending requests exist in a clean database"); + } + + #[sqlx::test] + async fn test_demand_accepts_zero_start(pool: PgPool) { + // `0s:1h` must parse `0s` as zero seconds (not coerce to 24h like + // the lenient batch-window parser does). Regression guard. + let (server, _bg): (TestServer, _) = create_test_app(pool.clone(), false).await; + let admin = create_test_admin_user(&pool, Role::PlatformManager).await; + + let response = server + .get("/admin/api/v1/monitoring/demand?window=0s:1h") + .add_header(&add_auth_headers(&admin)[0].0, &add_auth_headers(&admin)[0].1) + .add_header(&add_auth_headers(&admin)[1].0, &add_auth_headers(&admin)[1].1) + .await; + response.assert_status_ok(); + } + + #[sqlx::test] + async fn test_demand_rejects_malformed_window(pool: PgPool) { + let (server, _bg): (TestServer, _) = create_test_app(pool.clone(), false).await; + let admin = create_test_admin_user(&pool, Role::PlatformManager).await; + + for bad in ["window=foo", "window=1x", "window=1h,bad", "window=-1h:1h"] { + let response = server + .get(&format!("/admin/api/v1/monitoring/demand?{}", bad)) + .add_header(&add_auth_headers(&admin)[0].0, &add_auth_headers(&admin)[0].1) + .add_header(&add_auth_headers(&admin)[1].0, &add_auth_headers(&admin)[1].1) + .await; + response.assert_status(axum::http::StatusCode::BAD_REQUEST); + } + } } diff --git a/dwctl/src/lib.rs b/dwctl/src/lib.rs index 19a407378..ccb7e9779 100644 --- a/dwctl/src/lib.rs +++ b/dwctl/src/lib.rs @@ -1245,6 +1245,7 @@ pub async fn build_router( "/monitoring/pending-request-counts", get(api::handlers::queue::get_pending_request_counts), ) + .route("/monitoring/demand", get(api::handlers::queue::get_demand)) // Tool sources CRUD .route("/tool-sources", get(api::handlers::tool_sources::list_tool_sources)) .route("/tool-sources", post(api::handlers::tool_sources::create_tool_source)) diff --git a/dwctl/src/openapi/admin.rs b/dwctl/src/openapi/admin.rs index 4e1df36fc..3269f01e1 100644 --- a/dwctl/src/openapi/admin.rs +++ b/dwctl/src/openapi/admin.rs @@ -119,6 +119,7 @@ impl Modify for AdminSecurityAddon { api::handlers::requests::aggregate_requests, api::handlers::requests::aggregate_by_user, api::handlers::queue::get_pending_request_counts, + api::handlers::queue::get_demand, ), components( schemas(