diff --git a/AGENTS.md b/AGENTS.md index 4b37b050..fcc41930 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -18,6 +18,7 @@ Agent-assisted contributions are welcome, but should be **supervised** and **rev - **Primary server binary**: `skit` (crate: `streamkit-server`). - **Dev task runner**: `just` (see `justfile`). - **Docs**: Astro + Starlight in `docs/` (sidebar in `docs/astro.config.mjs`). +- **UI tooling**: Bun-first. Use `bun install`, `bunx` (or `bun run` scripts) for UI work—avoid npm/pnpm. ## Workflow expectations diff --git a/Cargo.lock b/Cargo.lock index b06c8d1a..934b85f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4763,6 +4763,7 @@ dependencies = [ "tokio-stream", "tokio-test", "tokio-tungstenite 0.28.0", + "tokio-util 0.7.18", "toml 0.9.11+spec-1.1.0", "tower", "tower-http", diff --git a/apps/skit-cli/src/client.rs b/apps/skit-cli/src/client.rs index 3428a6e9..d775c19a 100644 --- a/apps/skit-cli/src/client.rs +++ b/apps/skit-cli/src/client.rs @@ -16,6 +16,14 @@ use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; use tracing::{debug, error, info}; use url::Url; +/// Represents one multipart input file for oneshot execution. +#[derive(Debug, Clone)] +pub struct InputFile { + pub field: String, + pub path: String, + pub content_type: Option<String>, +} + fn http_base_url(server_url: &str) -> Result<Url, Box<dyn std::error::Error>> { let mut url = Url::parse(server_url)?; match url.scheme() { @@ -167,12 +175,12 @@ fn parse_batch_operations( #[allow(clippy::cognitive_complexity)] pub async fn process_oneshot( pipeline_path: &str, - input_path: &str, + inputs: &[InputFile], output_path: &str, server_url: &str, ) -> Result<(), Box<dyn std::error::Error>> { let client = reqwest::Client::new(); - process_oneshot_with_client(&client, pipeline_path, input_path, output_path, server_url).await + process_oneshot_with_client(&client, pipeline_path, inputs, output_path, server_url).await } /// Process a pipeline using a remote server in oneshot mode with a caller-provided HTTP client.
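For library callers of `streamkit_client`, the new multi-input API looks roughly like this; a minimal sketch mirroring how `main.rs` builds its input list (the file paths, extra field name, and content type are illustrative):

```rust
use streamkit_client::{process_oneshot, InputFile};

// Sketch only: "media" is the default primary field; any extra fields must match
// the `streamkit::http_input` bindings declared in the pipeline YAML.
async fn run() -> Result<(), Box<dyn std::error::Error>> {
    let inputs = vec![
        InputFile { field: "media".into(), path: "speech.opus".into(), content_type: None },
        InputFile { field: "track_b".into(), path: "music.opus".into(), content_type: Some("audio/ogg".into()) },
    ];
    // Same argument order as the signature above: pipeline, inputs, output, server URL.
    process_oneshot("pipeline.yml", &inputs, "out.webm", "http://127.0.0.1:4545").await
}
```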
@@ -192,13 +200,17 @@ pub async fn process_oneshot( pub async fn process_oneshot_with_client( client: &reqwest::Client, pipeline_path: &str, - input_path: &str, + inputs: &[InputFile], output_path: &str, server_url: &str, ) -> Result<(), Box> { + if inputs.is_empty() { + return Err("At least one input file is required".into()); + } + info!( pipeline = %pipeline_path, - input = %input_path, + inputs = inputs.len(), output = %output_path, server = %server_url, "Starting oneshot pipeline processing" @@ -208,31 +220,41 @@ pub async fn process_oneshot_with_client( if !Path::new(pipeline_path).exists() { return Err(format!("Pipeline file not found: {pipeline_path}").into()); } - if !Path::new(input_path).exists() { - return Err(format!("Input file not found: {input_path}").into()); + for input in inputs { + if !Path::new(&input.path).exists() { + return Err(format!("Input file not found: {}", input.path).into()); + } } // Read pipeline configuration debug!("Reading pipeline configuration from {pipeline_path}"); let pipeline_content = fs::read_to_string(pipeline_path).await?; - // Read input media file - debug!("Reading input media file from {input_path}"); - let media_data = fs::read(input_path).await?; - - // Extract filename for the multipart form - let input_filename = Path::new(input_path) - .file_name() - .and_then(|name| name.to_str()) - .unwrap_or("input") - .to_string(); - // Create multipart form - let media_len = media_data.len(); - debug!("Creating multipart form with {media_len} bytes of media data"); - let form = multipart::Form::new() - .text("config", pipeline_content) - .part("media", multipart::Part::bytes(media_data).file_name(input_filename)); + let mut form = multipart::Form::new().text("config", pipeline_content); + for input in inputs { + debug!("Reading input media file from {}", input.path); + let media_data = fs::read(&input.path).await?; + let media_len = media_data.len(); + + let input_filename = Path::new(&input.path) + .file_name() + .and_then(|name| name.to_str()) + .unwrap_or("input") + .to_string(); + + debug!( + "Adding multipart field '{}' with {} bytes (file: {})", + input.field, media_len, input_filename + ); + + let mut part = multipart::Part::bytes(media_data).file_name(input_filename); + if let Some(ct) = &input.content_type { + part = part.mime_str(ct)?; + } + + form = form.part(input.field.clone(), part); + } // Send request to server let url = http_base_url(server_url)?.join("/api/v1/process")?; diff --git a/apps/skit-cli/src/lib.rs b/apps/skit-cli/src/lib.rs index a6725a5d..e8a556ac 100644 --- a/apps/skit-cli/src/lib.rs +++ b/apps/skit-cli/src/lib.rs @@ -17,7 +17,7 @@ pub use client::{ destroy_session, get_config, get_permissions, get_pipeline, get_sample, list_audio_assets, list_node_schemas, list_packet_schemas, list_plugins, list_samples_dynamic, list_samples_oneshot, list_sessions, process_oneshot, save_sample, tune_node, - upload_audio_asset, upload_plugin, watch_events, + upload_audio_asset, upload_plugin, watch_events, InputFile, }; pub use load_test::run_load_test; diff --git a/apps/skit-cli/src/load_test/workers.rs b/apps/skit-cli/src/load_test/workers.rs index d68c0fc9..7c90c589 100644 --- a/apps/skit-cli/src/load_test/workers.rs +++ b/apps/skit-cli/src/load_test/workers.rs @@ -44,7 +44,11 @@ pub async fn oneshot_worker( let result = process_oneshot_with_client( &client, pipeline_path, - input_path, + &[crate::client::InputFile { + field: "media".to_string(), + path: input_path.clone(), + content_type: None, + }], output_path, 
&config.server.url, ) diff --git a/apps/skit-cli/src/main.rs b/apps/skit-cli/src/main.rs index ca28891b..df372625 100644 --- a/apps/skit-cli/src/main.rs +++ b/apps/skit-cli/src/main.rs @@ -2,7 +2,8 @@ // // SPDX-License-Identifier: MPL-2.0 -use clap::{Parser, Subcommand}; +use clap::{ArgAction, Parser, Subcommand}; +use streamkit_client::InputFile; use tracing::{error, info}; #[derive(Parser, Debug)] @@ -12,6 +13,22 @@ struct Cli { command: Commands, } +#[derive(Debug, Clone)] +struct FieldPath { + field: String, + path: String, +} + +fn parse_field_path(s: &str) -> Result { + let mut parts = s.splitn(2, '='); + let field = parts.next().unwrap_or("").trim(); + let path = parts.next().unwrap_or("").trim(); + if field.is_empty() || path.is_empty() { + return Err("expected form name=path".to_string()); + } + Ok(FieldPath { field: field.to_string(), path: path.to_string() }) +} + #[derive(Subcommand, Debug)] enum Commands { /// Process a pipeline using a remote server (oneshot mode) @@ -19,8 +36,11 @@ enum Commands { OneShot { /// Path to the pipeline YAML file pipeline: String, - /// Input media file path + /// Primary input media file path (multipart field defaults to 'media') input: String, + /// Additional input fields in the form name=path (repeatable) + #[arg(long = "input", value_parser = parse_field_path, action = ArgAction::Append)] + extra_input: Vec, /// Output file path output: String, /// Server URL (default: http://127.0.0.1:4545) @@ -329,11 +349,17 @@ async fn main() { let cli = Cli::parse(); match cli.command { - Commands::OneShot { pipeline, input, output, server } => { + Commands::OneShot { pipeline, input, extra_input, output, server } => { info!("Starting StreamKit client - oneshot processing"); + let mut inputs = Vec::new(); + inputs.push(InputFile { field: "media".to_string(), path: input, content_type: None }); + for extra in extra_input { + inputs.push(InputFile { field: extra.field, path: extra.path, content_type: None }); + } + if let Err(e) = - streamkit_client::process_oneshot(&pipeline, &input, &output, &server).await + streamkit_client::process_oneshot(&pipeline, &inputs, &output, &server).await { // Error already logged via tracing above error!(error = %e, "Failed to process oneshot pipeline"); diff --git a/apps/skit-cli/src/shell.rs b/apps/skit-cli/src/shell.rs index 9242440a..103bb0b5 100644 --- a/apps/skit-cli/src/shell.rs +++ b/apps/skit-cli/src/shell.rs @@ -623,7 +623,12 @@ impl Shell { // Use the existing process_oneshot function from client.rs // This makes a multipart HTTP POST to /api/v1/process - crate::client::process_oneshot(pipeline_path, input_path, output_path, &http_url).await?; + let inputs = vec![crate::client::InputFile { + field: "media".to_string(), + path: input_path.to_string(), + content_type: None, + }]; + crate::client::process_oneshot(pipeline_path, &inputs, output_path, &http_url).await?; println!("✅ Oneshot processing completed successfully"); diff --git a/apps/skit/Cargo.toml b/apps/skit/Cargo.toml index 3789b505..3d48b198 100644 --- a/apps/skit/Cargo.toml +++ b/apps/skit/Cargo.toml @@ -48,6 +48,7 @@ anyhow = "1.0" # For HTTP server axum = { version = "0.8", features = ["multipart", "ws"] } tokio = { workspace = true, features = ["full"] } +tokio-util = { workspace = true } tower = "0.5.3" tower-http = { version = "0.6", features = ["cors", "trace", "fs", "set-header"] } tokio-stream = "0.1.18" diff --git a/apps/skit/src/bin/gen-docs-reference.rs b/apps/skit/src/bin/gen-docs-reference.rs index 67b233a8..60b6a557 100644 --- 
a/apps/skit/src/bin/gen-docs-reference.rs +++ b/apps/skit/src/bin/gen-docs-reference.rs @@ -116,7 +116,39 @@ fn add_synthetic_oneshot_nodes(defs: &mut Vec) { Receives binary data from the HTTP request body." .to_string(), ), - param_schema: serde_json::json!({}), + param_schema: serde_json::json!({ + "type": "object", + "additionalProperties": false, + "properties": { + "field": { + "type": "string", + "description": "Multipart field name to bind to this input. Defaults to 'media' when only one http_input node exists; otherwise defaults to the node id." + }, + "fields": { + "type": "array", + "description": "Optional list of multipart fields for this node. When set, the node exposes one output pin per entry (pin name matches the field name). Entries may be strings or objects with { name, required }.", + "items": { + "oneOf": [ + { "type": "string" }, + { + "type": "object", + "additionalProperties": false, + "properties": { + "name": { "type": "string" }, + "required": { "type": "boolean", "default": true } + }, + "required": ["name"] + } + ] + } + }, + "required": { + "type": "boolean", + "description": "If true (default), the request must include this field.", + "default": true + } + } + }), inputs: vec![], outputs: vec![OutputPin { name: "out".to_string(), diff --git a/apps/skit/src/server.rs b/apps/skit/src/server.rs index 8201b841..27d3c6da 100644 --- a/apps/skit/src/server.rs +++ b/apps/skit/src/server.rs @@ -18,13 +18,14 @@ use bytes::Bytes; use multer as raw_multer; use opentelemetry::{global, KeyValue}; use rust_embed::RustEmbed; +use std::collections::{HashMap, HashSet}; use std::convert::Infallible; use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::sync::OnceLock; use std::task::{Context as TaskContext, Poll}; -use std::time::Instant; +use std::time::{Duration, Instant}; use tower::limit::ConcurrencyLimitLayer; use tower::ServiceBuilder; use tower_http::{ @@ -44,12 +45,13 @@ use streamkit_api::Pipeline; use streamkit_api::{ApiPipeline, Event as ApiEvent, EventPayload, MessageType}; use streamkit_core::control::EngineControlMessage; use streamkit_core::error::StreamKitError; -use streamkit_engine::{Engine, OneshotEngineConfig}; +use streamkit_engine::{Engine, OneshotEngineConfig, OneshotInput}; use crate::session::SessionManager; use crate::config::Config; use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; use anyhow::Error as AnyhowError; use futures::{Stream, StreamExt}; @@ -497,7 +499,39 @@ async fn list_node_definitions_handler( Receives binary data from the HTTP request body." .to_string(), ), - param_schema: serde_json::json!({}), + param_schema: serde_json::json!({ + "type": "object", + "additionalProperties": false, + "properties": { + "field": { + "type": "string", + "description": "Multipart field name to bind to this input. Defaults to 'media' when only one http_input node exists; otherwise defaults to the node id." + }, + "fields": { + "type": "array", + "description": "Optional list of multipart fields for this node. When set, the node exposes one output pin per entry (pin name matches the field name). 
Entries may be strings or objects with { name, required }.", + "items": { + "oneOf": [ + { "type": "string" }, + { + "type": "object", + "additionalProperties": false, + "properties": { + "name": { "type": "string" }, + "required": { "type": "boolean", "default": true } + }, + "required": ["name"] + } + ] + } + }, + "required": { + "type": "boolean", + "description": "If true (default), the request must include this field.", + "default": true + } + } + }), inputs: vec![], outputs: vec![OutputPin { name: "out".to_string(), @@ -1283,12 +1317,12 @@ async fn get_pipeline_handler( Ok(Json(api_pipeline)) } -/// Result of parsing multipart request with config and optional media stream -struct MultipartParseResult { - user_pipeline: UserPipeline, - media_stream: MediaStream, - media_content_type: Option, - has_media: bool, +/// Binding between a multipart field and an http_input node. +struct HttpInputBinding { + node_id: String, + field_name: String, + output_pin: String, + required: bool, } /// Extract content-type header and multipart boundary from request headers. @@ -1316,7 +1350,7 @@ async fn parse_config_field( let first_name = first_field.name().map(std::string::ToString::to_string).unwrap_or_default(); if first_name != "config" { return Err(AppError::BadRequest( - "Multipart fields must be ordered: 'config' first, then 'media'".to_string(), + "Multipart fields must be ordered: 'config' first".to_string(), )); } @@ -1327,72 +1361,223 @@ async fn parse_config_field( serde_saphyr::from_slice(&config_bytes).map_err(Into::into) } -/// Parse the multipart request and extract config and media stream. -/// Returns the parsed pipeline config, media stream (possibly empty), content type, and whether media was provided. +/// Build http_input bindings from the pipeline definition. /// -/// This needs to stay relatively monolithic because it combines multipart streaming -/// with a spawned task, and the `Multipart<'_>` lifetime makes further extraction awkward. 
-#[allow(clippy::cognitive_complexity)] -async fn parse_multipart_request( - req: axum::extract::Request, -) -> Result { - let headers = req.headers().clone(); - let boundary = extract_multipart_boundary(&headers)?; - let body_stream = req.into_body().into_data_stream(); - let mut multipart = raw_multer::Multipart::new(body_stream, boundary); - - // Parse the config field - let user_pipeline = parse_config_field(&mut multipart).await?; +/// Defaults: +/// - Single http_input: field name defaults to "media" +/// - Multiple http_input: field names default to the node id +fn determine_http_input_bindings( + pipeline_def: &Pipeline, +) -> Result, AppError> { + // Record which output pins the pipeline references for each http_input node + let mut pins_used: HashMap> = HashMap::new(); + for conn in &pipeline_def.connections { + if let Some(node_def) = pipeline_def.nodes.get(&conn.from_node) { + if node_def.kind == "streamkit::http_input" { + pins_used.entry(conn.from_node.clone()).or_default().insert(conn.from_pin.clone()); + } + } + } - // Setup channels for streaming media field - let (media_tx, media_rx) = tokio::sync::mpsc::channel::>(16); - let (ct_tx, ct_rx) = tokio::sync::oneshot::channel::>(); - let (has_media_tx, has_media_rx) = tokio::sync::oneshot::channel::(); + let http_inputs: Vec<(&String, &streamkit_api::Node)> = pipeline_def + .nodes + .iter() + .filter(|(_, node)| node.kind == "streamkit::http_input") + .collect(); - // Spawn producer task to stream media field - // Note: Must remain inline due to Multipart<'r> lifetime constraints - tokio::spawn(async move { - while let Ok(next) = multipart.next_field().await { - if let Some(mut field) = next { - let fname = field.name().map(std::string::ToString::to_string).unwrap_or_default(); - if fname == "media" { - let _ = has_media_tx.send(true); - let _ = ct_tx.send(field.content_type().map(std::string::ToString::to_string)); - stream_media_field_chunks(&mut field, &media_tx).await; - drop(media_tx); - break; + let default_field = if http_inputs.len() == 1 { Some("media".to_string()) } else { None }; + let mut seen_fields: HashSet = HashSet::new(); + let mut bindings = Vec::new(); + + for (node_id, node_def) in http_inputs { + let mut node_bindings: Vec = Vec::new(); + let mut single_field: Option = None; + let mut single_required = true; + let mut has_fields_param = false; + let mut has_single_field_param = false; + + if let Some(params) = &node_def.params { + if let Some(fields_val) = params.get("fields") { + has_fields_param = true; + let fields = fields_val.as_array().ok_or_else(|| { + AppError::BadRequest( + "streamkit::http_input.params.fields must be an array of strings or objects" + .to_string(), + ) + })?; + + for entry in fields { + let (name, required) = match entry { + serde_json::Value::String(s) => (s.clone(), true), + serde_json::Value::Object(map) => { + let Some(name_val) = map.get("name") else { + return Err(AppError::BadRequest( + "fields entries must include 'name'".to_string(), + )); + }; + let name = name_val + .as_str() + .ok_or_else(|| { + AppError::BadRequest("fields.name must be a string".to_string()) + })? 
+ .trim() + .to_string(); + if name.is_empty() { + return Err(AppError::BadRequest( + "fields.name must not be empty".to_string(), + )); + } + let required = map + .get("required") + .and_then(serde_json::Value::as_bool) + .unwrap_or(true); + (name, required) + }, + _ => { + return Err(AppError::BadRequest( + "fields entries must be strings or objects".to_string(), + )) + }, + }; + + node_bindings.push(HttpInputBinding { + node_id: node_id.clone(), + field_name: name, + output_pin: String::new(), + required, + }); + } + } else if let Some(field_val) = params.get("field").and_then(serde_json::Value::as_str) + { + has_single_field_param = true; + let trimmed = field_val.trim(); + if !trimmed.is_empty() { + single_field = Some(trimmed.to_string()); + } + if let Some(req_val) = params.get("required").and_then(serde_json::Value::as_bool) { + single_required = req_val; } - tracing::warn!("Ignoring unknown multipart field: {}", fname); - } else { - let _ = has_media_tx.send(false); - tracing::debug!("No media field found in multipart"); - break; } } - }); - // Wait to see if media field exists (timeout prevents hanging on slow/broken clients). - // If this times out, fail the request instead of guessing "no media". - let has_media = tokio::time::timeout(std::time::Duration::from_secs(5), has_media_rx) - .await - .map_err(|_| { - AppError::BadRequest("Timed out waiting for multipart media field".to_string()) - })? - .unwrap_or(false); + if has_fields_param && has_single_field_param { + return Err(AppError::BadRequest( + "streamkit::http_input: use either 'field' or 'fields', not both".to_string(), + )); + } + + if has_fields_param && node_bindings.is_empty() { + return Err(AppError::BadRequest( + "streamkit::http_input.params.fields must include at least one field".to_string(), + )); + } + + if node_bindings.is_empty() { + let field_name = + single_field.or_else(|| default_field.clone()).unwrap_or_else(|| node_id.clone()); + node_bindings.push(HttpInputBinding { + node_id: node_id.clone(), + field_name, + output_pin: String::new(), + required: single_required, + }); + } + + // Back-compat: allow implicit 'media' only when no fields array is provided. + if !has_fields_param + && default_field.as_deref() == Some("media") + && !node_bindings.iter().any(|b| b.field_name == "media") + { + node_bindings.push(HttpInputBinding { + node_id: node_id.clone(), + field_name: "media".to_string(), + output_pin: String::new(), + required: false, + }); + } - let media_stream: MediaStream = Box::new(ReceiverStream::new(media_rx).map(|x| x)); - let media_content_type: Option = ct_rx.await.ok().flatten(); + // Decide pin names based on referenced connections. Keep field names for multi-field mode, + // but allow legacy 'out' default when only one pin is referenced (steps format). 
+ let used_pins = pins_used.get(node_id.as_str()).cloned().unwrap_or_default(); + for binding in &mut node_bindings { + let pin_name = if used_pins.contains(&binding.field_name) { + binding.field_name.clone() + } else if used_pins.len() == 1 && !has_fields_param { + // Legacy steps pipelines reference 'out' + used_pins.iter().next().cloned().unwrap_or_else(|| binding.field_name.clone()) + } else { + binding.field_name.clone() + }; + binding.output_pin = pin_name; + } + + for binding in node_bindings { + if !seen_fields.insert(binding.field_name.clone()) { + return Err(AppError::BadRequest(format!( + "Duplicate multipart field name '{field_name}' across http_input nodes", + field_name = binding.field_name + ))); + } + bindings.push(binding); + } + } - Ok(MultipartParseResult { user_pipeline, media_stream, media_content_type, has_media }) + Ok(bindings) } /// Stream all chunks from a media field through the provided channel. async fn stream_media_field_chunks( field: &mut raw_multer::Field<'_>, media_tx: &tokio::sync::mpsc::Sender>, + cancellation_token: Option<&CancellationToken>, ) { let mut chunk_count: usize = 0; let mut total_bytes: usize = 0; + + if let Some(token) = cancellation_token { + loop { + tokio::select! { + () = token.cancelled() => { + tracing::info!( + "Stopped streaming media early after {} chunks ({} bytes) due to cancellation", + chunk_count, + total_bytes + ); + break; + } + chunk_result = field.chunk() => { + match chunk_result { + Ok(Some(chunk)) => { + chunk_count += 1; + total_bytes += chunk.len(); + if media_tx.send(Ok(chunk)).await.is_err() { + tracing::debug!( + "Media consumer dropped after {} chunks ({} bytes)", + chunk_count, + total_bytes + ); + break; + } + }, + Ok(None) => { + tracing::info!( + "Finished streaming media after {} chunks ({} bytes)", + chunk_count, + total_bytes + ); + break; + }, + Err(e) => { + let _ = media_tx.send(Err(axum::Error::new(e))).await; + break; + }, + } + } + } + } + return; + } + loop { match field.chunk().await { Ok(Some(chunk)) => { @@ -1423,44 +1608,88 @@ async fn stream_media_field_chunks( } } -/// Validate that the pipeline has the required nodes based on whether media was provided. +/// Route multipart fields into pre-created channels based on expected names. +async fn route_multipart_fields( + mut multipart: raw_multer::Multipart<'_>, + mut field_senders: HashMap>>, + required_fields: HashSet, + mut required_seen_tx: Option>, + parse_done_tx: tokio::sync::oneshot::Sender>, + cancellation_token: CancellationToken, +) { + let mut seen_required: HashSet = HashSet::new(); + + let result = async { + while let Some(mut field) = multipart + .next_field() + .await + .map_err(|e| AppError::BadRequest(format!("Multipart error: {e}")))? + { + let fname = field.name().map(std::string::ToString::to_string).unwrap_or_default(); + if fname.is_empty() { + continue; + } + + let Some(sender) = field_senders.remove(&fname) else { + let expected = if field_senders.is_empty() { + "none".to_string() + } else { + field_senders.keys().cloned().collect::>().join(", ") + }; + return Err(AppError::BadRequest(format!( + "Unexpected multipart field '{fname}'. 
Expected: {expected}" + ))); + }; + + if required_fields.contains(&fname) { + seen_required.insert(fname.clone()); + if seen_required.len() == required_fields.len() { + if let Some(tx) = required_seen_tx.take() { + let _ = tx.send(()); + } + } + } + + stream_media_field_chunks(&mut field, &sender, Some(&cancellation_token)).await; + } + + if !required_fields.is_empty() && seen_required.len() < required_fields.len() { + let missing: Vec<_> = required_fields.difference(&seen_required).cloned().collect(); + return Err(AppError::BadRequest(format!( + "Missing required multipart field(s): {}", + missing.join(", ") + ))); + } + + Ok(()) + } + .await; + + drop(field_senders); + + if let Some(tx) = required_seen_tx.take() { + let _ = tx.send(()); + } + + let _ = parse_done_tx.send(result); +} + +/// Validate that the pipeline has the required nodes for oneshot processing. /// Returns (has_http_input, has_file_read, has_http_output) for logging purposes. -fn validate_pipeline_nodes( - pipeline_def: &Pipeline, - has_media: bool, -) -> Result<(bool, bool, bool), AppError> { +fn validate_pipeline_nodes(pipeline_def: &Pipeline) -> Result<(bool, bool, bool), AppError> { let has_http_input = pipeline_def.nodes.values().any(|node| node.kind == "streamkit::http_input"); let has_http_output = pipeline_def.nodes.values().any(|node| node.kind == "streamkit::http_output"); let has_file_read = pipeline_def.nodes.values().any(|node| node.kind == "core::file_reader"); - // Validate entry point based on whether media was provided - if has_media { - // HTTP streaming mode: require http_input - if !has_http_input { - return Err(AppError::BadRequest( - "Pipeline must contain one 'streamkit::http_input' node when media is provided" - .to_string(), - )); - } - } else { - // File-based mode: require file_read, disallow http_input - if has_http_input { - return Err(AppError::BadRequest( - "Pipeline cannot contain 'streamkit::http_input' node when no media is provided" - .to_string(), - )); - } - if !has_file_read { - return Err(AppError::BadRequest( - "Pipeline must contain at least one 'core::file_reader' node when no media is provided" - .to_string(), - )); - } + if !has_http_input && !has_file_read { + return Err(AppError::BadRequest( + "Pipeline must contain at least one 'streamkit::http_input' or 'core::file_reader' node for oneshot processing" + .to_string(), + )); } - // Always require http_output for response streaming if !has_http_output { return Err(AppError::BadRequest( "Pipeline must contain one 'streamkit::http_output' node for oneshot processing" @@ -1706,17 +1935,21 @@ async fn process_oneshot_pipeline_handler( )); } - // Parse multipart request to get config and media stream - let parse_result = parse_multipart_request(req).await?; + // Parse multipart: read boundary + config first + let boundary = extract_multipart_boundary(req.headers())?; + let body_stream = req.into_body().into_data_stream(); + let mut multipart = raw_multer::Multipart::new(body_stream, boundary); + let user_pipeline = parse_config_field(&mut multipart).await?; // Compile pipeline definition tracing::debug!("Compiling user pipeline definition"); - let pipeline_def: Pipeline = compile(parse_result.user_pipeline)?; + let pipeline_def: Pipeline = compile(user_pipeline)?; tracing::debug!("Pipeline compilation completed"); + let input_bindings = determine_http_input_bindings(&pipeline_def)?; + // Validate pipeline structure - let (has_http_input, has_file_read, has_http_output) = - validate_pipeline_nodes(&pipeline_def, 
parse_result.has_media)?; + let (has_http_input, has_file_read, has_http_output) = validate_pipeline_nodes(&pipeline_def)?; // Enforce allowed node/plugin kinds for oneshot execution. // @@ -1742,23 +1975,81 @@ async fn process_oneshot_pipeline_handler( } } - // Validate file paths in file-based mode - if !parse_result.has_media { - validate_file_reader_paths(&pipeline_def, &app_state.config.security)?; - } - + // Validate file/script paths + validate_file_reader_paths(&pipeline_def, &app_state.config.security)?; validate_file_writer_paths(&pipeline_def, &app_state.config.security)?; validate_script_paths(&pipeline_def, &app_state.config.security)?; tracing::info!( "Pipeline validation passed: mode={}, has_http_input={}, has_file_read={}, has_http_output={}", - if parse_result.has_media { "http-streaming" } else { "file-based" }, + if has_http_input { "http-streaming" } else { "file-based" }, has_http_input, has_file_read, has_http_output ); tracing::info!(role = %role_name, "Executing oneshot pipeline for role"); + // Prepare multipart routing + let cancel_token = CancellationToken::new(); + let mut field_senders: HashMap>> = + HashMap::new(); + let mut engine_inputs = Vec::new(); + let mut required_fields: HashSet = HashSet::new(); + + let io_capacity = app_state + .config + .engine + .oneshot + .io_channel_capacity + .unwrap_or(streamkit_engine::constants::DEFAULT_ONESHOT_IO_CAPACITY); + + for binding in &input_bindings { + let (tx, rx) = tokio::sync::mpsc::channel::>(io_capacity); + if binding.required { + required_fields.insert(binding.field_name.clone()); + } + field_senders.insert(binding.field_name.clone(), tx); + + let media_stream: MediaStream = Box::new(ReceiverStream::new(rx).map(|x| x)); + engine_inputs.push(OneshotInput { + node_id: binding.node_id.clone(), + output_pin: binding.output_pin.clone(), + stream: media_stream, + content_type: None, + field_name: binding.field_name.clone(), + required: binding.required, + cancellation_token: Some(cancel_token.clone()), + }); + } + + let (required_seen_tx, required_seen_rx) = tokio::sync::oneshot::channel(); + let mut required_seen_tx = Some(required_seen_tx); + if required_fields.is_empty() { + if let Some(tx) = required_seen_tx.take() { + let _ = tx.send(()); + } + } + let (parse_done_tx, parse_done_rx) = tokio::sync::oneshot::channel(); + + // Spawn multipart routing task + let routing_task = tokio::spawn(route_multipart_fields( + multipart, + field_senders, + required_fields.clone(), + required_seen_tx, + parse_done_tx, + cancel_token.clone(), + )); + + // Wait for required fields to appear (prevents hanging on missing uploads) + tokio::time::timeout(Duration::from_secs(5), required_seen_rx) + .await + .map_err(|_| { + cancel_token.cancel(); + AppError::BadRequest("Timed out waiting for required multipart fields".to_string()) + })? 
+ .map_err(|_| AppError::BadRequest("Failed to observe multipart state".into()))?; + // Execute oneshot pipeline tracing::info!("Starting oneshot pipeline execution"); let oneshot_start_time = Instant::now(); @@ -1791,10 +2082,9 @@ async fn process_oneshot_pipeline_handler( .engine .run_oneshot_pipeline( pipeline_def, - parse_result.media_stream, - parse_result.media_content_type, - parse_result.has_media, + engine_inputs, Some(oneshot_config), + Some(cancel_token.clone()), ) .await { @@ -1805,10 +2095,27 @@ async fn process_oneshot_pipeline_handler( Err(e) => { let labels = [KeyValue::new("status", "error")]; oneshot_duration_histogram.record(oneshot_start_time.elapsed().as_secs_f64(), &labels); + cancel_token.cancel(); return Err(e.into()); }, }; + // Ensure multipart routing finished cleanly + match parse_done_rx.await { + Ok(Ok(())) => {}, + Ok(Err(err)) => { + let labels = [KeyValue::new("status", "error")]; + oneshot_duration_histogram.record(oneshot_start_time.elapsed().as_secs_f64(), &labels); + cancel_token.cancel(); + return Err(err); + }, + Err(e) => { + cancel_token.cancel(); + return Err(AppError::BadRequest(format!("Multipart routing task aborted: {e}"))); + }, + } + let _ = routing_task.await; + // Build and return streaming response Ok(build_streaming_response(pipeline_result, oneshot_start_time, oneshot_duration_histogram)) } diff --git a/apps/skit/src/websocket_handlers.rs b/apps/skit/src/websocket_handlers.rs index 237a9f65..e9b76f45 100644 --- a/apps/skit/src/websocket_handlers.rs +++ b/apps/skit/src/websocket_handlers.rs @@ -314,7 +314,39 @@ fn handle_list_nodes(app_state: &AppState, perms: &Permissions) -> ResponsePaylo Receives binary data from the HTTP request body." .to_string(), ), - param_schema: serde_json::json!({}), + param_schema: serde_json::json!({ + "type": "object", + "additionalProperties": false, + "properties": { + "field": { + "type": "string", + "description": "Multipart field name to bind to this input. Defaults to 'media' when only one http_input node exists; otherwise defaults to the node id." + }, + "fields": { + "type": "array", + "description": "Optional list of multipart fields for this node. When set, the node exposes one output pin per entry (pin name matches the field name). Entries may be strings or objects with { name, required }.", + "items": { + "oneOf": [ + { "type": "string" }, + { + "type": "object", + "additionalProperties": false, + "properties": { + "name": { "type": "string" }, + "required": { "type": "boolean", "default": true } + }, + "required": ["name"] + } + ] + } + }, + "required": { + "type": "boolean", + "description": "If true (default), the request must include this field.", + "default": true + } + } + }), inputs: vec![], outputs: vec![OutputPin { name: "out".to_string(), diff --git a/crates/api/src/yaml.rs b/crates/api/src/yaml.rs index e1306763..6b194854 100644 --- a/crates/api/src/yaml.rs +++ b/crates/api/src/yaml.rs @@ -51,6 +51,16 @@ impl NeedsDependency { } } + /// Returns (node, from_pin) where from_pin is parsed from "node.pin" syntax if present. 
+ fn node_and_pin(&self) -> (&str, Option<&str>) { + let label = self.node(); + if let Some((node, pin)) = label.split_once('.') { + (node, Some(pin)) + } else { + (label, None) + } + } + fn mode(&self) -> ConnectionMode { match self { Self::Simple(_) => ConnectionMode::default(), @@ -215,8 +225,8 @@ fn detect_cycles(user_nodes: &IndexMap) -> Result<(), String> let dependencies: Vec<&str> = match &node_def.needs { Needs::None => vec![], - Needs::Single(dep) => vec![dep.node()], - Needs::Multiple(deps) => deps.iter().map(NeedsDependency::node).collect(), + Needs::Single(dep) => vec![dep.node_and_pin().0], + Needs::Multiple(deps) => deps.iter().map(|d| d.node_and_pin().0).collect(), }; for dep_name in dependencies { @@ -274,7 +284,7 @@ fn compile_dag( }; for (idx, dep) in dependencies.iter().enumerate() { - let dep_name = dep.node(); + let (dep_name, from_pin) = dep.node_and_pin(); // Validate that the referenced node exists if !user_nodes.contains_key(dep_name) { @@ -289,7 +299,7 @@ fn compile_dag( connections.push(Connection { from_node: dep_name.to_string(), - from_pin: "out".to_string(), + from_pin: from_pin.unwrap_or("out").to_string(), to_node: node_name.clone(), to_pin, mode: dep.mode(), diff --git a/crates/engine/src/lib.rs b/crates/engine/src/lib.rs index 3d1e1844..1a96d045 100644 --- a/crates/engine/src/lib.rs +++ b/crates/engine/src/lib.rs @@ -35,7 +35,7 @@ mod dynamic_pin_distributor; pub use dynamic_config::DynamicEngineConfig; #[cfg(feature = "dynamic")] pub use dynamic_handle::DynamicEngineHandle; -pub use oneshot::{OneshotEngineConfig, OneshotPipelineResult}; +pub use oneshot::{OneshotEngineConfig, OneshotInput, OneshotPipelineResult}; // Import constants and types (within dynamic module) #[cfg(feature = "dynamic")] diff --git a/crates/engine/src/oneshot.rs b/crates/engine/src/oneshot.rs index cb9b40d1..6e612807 100644 --- a/crates/engine/src/oneshot.rs +++ b/crates/engine/src/oneshot.rs @@ -36,6 +36,7 @@ use streamkit_core::control::NodeControlMessage; use streamkit_core::error::StreamKitError; use streamkit_core::node::ProcessorNode; use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; /// Configuration for oneshot pipeline execution. #[derive(Debug, Clone)] @@ -64,12 +65,30 @@ pub struct OneshotPipelineResult { pub content_type: String, } +/// Binding between a multipart field and a `streamkit::http_input` node. +pub struct OneshotInput<S> { + /// Node id of the `streamkit::http_input` instance to feed. + pub node_id: String, + /// Output pin name to send this stream on (typically matches the multipart field). + pub output_pin: String, + /// Incoming byte stream for this node. + pub stream: S, + /// Optional request content type associated with this stream. + pub content_type: Option<String>, + /// Multipart field name (for logging/debugging). + pub field_name: String, + /// Whether the pipeline marked this input as required. + pub required: bool, + /// Cancellation token to stop reading if the pipeline is cancelled. + pub cancellation_token: Option<CancellationToken>, +} + impl Engine { /// Runs a pipeline as a self-contained, one-shot task from a streaming input.
/// /// Supports two modes: - /// - HTTP streaming mode (`has_http_input=true`): Uses http_input node with media stream - /// - File-based mode (`has_http_input=false`): Uses file_read nodes reading from disk + /// - HTTP streaming mode (`inputs` non-empty): Uses http_input nodes with media streams + /// - File-based mode (`inputs` empty): Uses file_read nodes reading from disk /// /// # Errors /// @@ -86,10 +105,9 @@ impl Engine { pub async fn run_oneshot_pipeline( &self, definition: Pipeline, - mut input_stream: S, - input_content_type: Option, - has_http_input: bool, + inputs: Vec>, config: Option, + cancellation_token: Option, ) -> Result where S: Stream> + Send + Unpin + 'static, @@ -103,7 +121,6 @@ impl Engine { definition.connections.len() ); - // expect is documented in #[doc] Panics section above #[allow(clippy::expect_used)] let registry = { let guard = self @@ -113,15 +130,15 @@ impl Engine { guard.clone() }; - // --- 1. Find the special namespaced input and output nodes --- - let mut input_node_id: Option = None; + // --- 1. Identify key nodes --- let mut output_node_id: Option = None; let mut source_node_ids: Vec = Vec::new(); + let mut http_input_nodes: Vec = Vec::new(); for (name, def) in &definition.nodes { tracing::debug!("Found node '{}' of type '{}'", name, def.kind); if def.kind == "streamkit::http_input" { - input_node_id = Some(name.clone()); + http_input_nodes.push(name.clone()); } if def.kind == "streamkit::http_output" { output_node_id = Some(name.clone()); @@ -131,26 +148,23 @@ impl Engine { } } - // Validate based on mode + let has_http_input = !http_input_nodes.is_empty(); + if has_http_input { - // HTTP streaming mode: require http_input - if input_node_id.is_none() { - tracing::error!("Pipeline validation failed: missing streamkit::http_input node"); + if inputs.is_empty() { + tracing::error!( + "Pipeline validation failed: no input streams provided for http_input nodes" + ); return Err(StreamKitError::Configuration( - "Pipeline must contain one 'streamkit::http_input' node.".to_string(), + "Input streams are required for 'streamkit::http_input' nodes.".to_string(), )); } tracing::info!( - "HTTP streaming mode: input='{}', output='{}'", - // Safe unwrap: just validated input_node_id.is_some() above - { - #[allow(clippy::unwrap_used)] - input_node_id.as_ref().unwrap() - }, - output_node_id.as_ref().map_or("unknown", |s| s.as_str()) + "HTTP streaming mode: {} http_input node(s), output='{}'", + http_input_nodes.len(), + output_node_id.as_deref().unwrap_or("unknown") ); } else { - // File-based mode: ensure we have source nodes if source_node_ids.is_empty() { tracing::error!("Pipeline validation failed: no file_reader nodes found"); return Err(StreamKitError::Configuration( @@ -158,10 +172,19 @@ impl Engine { .to_string(), )); } + if !inputs.is_empty() { + tracing::error!( + "Pipeline validation failed: streams provided but no http_input nodes present" + ); + return Err(StreamKitError::Configuration( + "Multipart streams were provided but the pipeline has no 'streamkit::http_input' nodes." + .to_string(), + )); + } tracing::info!( "File-based mode: {} source node(s), output='{}'", source_node_ids.len(), - output_node_id.as_ref().map_or("unknown", |s| s.as_str()) + output_node_id.as_deref().unwrap_or("unknown") ); } @@ -172,12 +195,107 @@ impl Engine { ) })?; - // --- 2. Create channels for the I/O streams and cancellation token --- - let (input_stream_tx, input_stream_rx) = mpsc::channel(config.io_channel_capacity); + // --- 2. 
I/O channels and cancellation token --- let (output_stream_tx, output_stream_rx) = mpsc::channel(config.io_channel_capacity); - let cancellation_token = tokio_util::sync::CancellationToken::new(); + let cancellation_token = cancellation_token.unwrap_or_default(); tracing::debug!("Created I/O stream channels and cancellation token"); + // --- 2.5. Bind http_input streams --- + let mut nodes: HashMap> = HashMap::new(); + let mut provided_inputs: HashMap>> = HashMap::new(); + let mut first_input_content_type: Option = None; + + for input in inputs { + provided_inputs.entry(input.node_id.clone()).or_default().push(input); + } + + if has_http_input { + for node_id in &http_input_nodes { + let Some(bound_inputs) = provided_inputs.remove(node_id) else { + tracing::error!( + "Pipeline validation failed: no stream provided for http_input node '{}'", + node_id + ); + return Err(StreamKitError::Configuration(format!( + "No stream provided for http_input node '{node_id}'" + ))); + }; + + let mut per_pin_receivers: Vec<(String, mpsc::Receiver, Option)> = + Vec::new(); + + for input in bound_inputs { + if first_input_content_type.is_none() { + first_input_content_type.clone_from(&input.content_type); + } + + let (tx, rx) = mpsc::channel(config.io_channel_capacity); + per_pin_receivers.push(( + input.output_pin.clone(), + rx, + input.content_type.clone(), + )); + + let node_name = node_id.clone(); + let mut stream = input.stream; + let input_stream_tx = tx; + let input_pump_token = + input.cancellation_token.unwrap_or_else(|| cancellation_token.clone()); + let output_pin = input.output_pin.clone(); + + tokio::spawn(async move { + use futures::StreamExt; + let mut chunk_count = 0usize; + tracing::debug!( + "Input stream pump starting for node '{}', pin '{}'", + node_name, + output_pin + ); + loop { + tokio::select! { + () = input_pump_token.cancelled() => { + tracing::info!("Input stream pump for '{}.{}' cancelled after {} chunks", node_name, output_pin, chunk_count); + break; + } + chunk_result = stream.next() => { + match chunk_result { + Some(Ok(chunk)) => { + chunk_count += 1; + if input_stream_tx.send(chunk).await.is_err() { + tracing::warn!("Input node '{}.{}' closed before stream ended.", node_name, output_pin); + break; + } + } + Some(Err(e)) => { + tracing::error!("Error reading from input stream for '{}.{}': {}", node_name, output_pin, e); + break; + } + None => { + tracing::info!("Input stream pump for '{}.{}' finished after {} chunks", node_name, output_pin, chunk_count); + break; + } + } + } + } + } + }); + } + + tracing::debug!("Creating special input node '{}'", node_id); + let input_node = streamkit_nodes::core::bytes_input::BytesInputNode::with_streams( + per_pin_receivers, + ); + nodes.insert(node_id.clone(), Box::new(input_node)); + } + + if !provided_inputs.is_empty() { + let extras = provided_inputs.keys().cloned().collect::>().join(", "); + return Err(StreamKitError::Configuration(format!( + "Unexpected input streams provided for unknown http_input nodes: {extras}" + ))); + } + } + // --- 3. Validate that http_output is connected --- let final_node_id = definition .connections @@ -212,23 +330,7 @@ impl Engine { }; // --- 4. 
Instantiate all nodes for the pipeline --- - let mut nodes: HashMap> = HashMap::new(); - - // Manually create the special input node (only in HTTP streaming mode) - if has_http_input { - // Safe unwrap: validated input_node_id.is_some() when has_http_input is true - #[allow(clippy::unwrap_used)] - let input_id = input_node_id.as_ref().unwrap(); - tracing::debug!("Creating special input node '{}'", input_id); - let input_node = Box::new(streamkit_nodes::core::bytes_input::BytesInputNode::new( - input_stream_rx, - input_content_type.clone(), - )); - nodes.insert(input_id.clone(), input_node); - } - tracing::debug!("Creating special output node '{}'", output_node_id); - // Get output node definition - this should exist since output_node_id was found in pipeline let output_node_def = definition.nodes.get(&output_node_id).ok_or_else(|| { StreamKitError::Configuration(format!( "Output node '{output_node_id}' not found in pipeline definition" @@ -238,17 +340,14 @@ impl Engine { output_stream_tx, output_node_def.params.as_ref(), )?; - // Capture the configured content type before moving the node let configured_content_type = output_node.configured_content_type(); nodes.insert(output_node_id.clone(), Box::new(output_node)); - // Create the final node for insertion into the pipeline tracing::debug!("Adding final node '{}' to pipeline", final_node_id); let final_node_instance = registry.create_node(&final_node_def.kind, final_node_def.params.as_ref())?; nodes.insert(final_node_id.clone(), final_node_instance); - // Create all other standard processing nodes from the main registry. for (name, def) in &definition.nodes { if !nodes.contains_key(name) { tracing::debug!("Creating node '{}' of type '{}'", name, def.kind); @@ -268,16 +367,14 @@ impl Engine { tracing::info!("Created {} nodes total", nodes.len()); - // --- 5. Use the shared helper to wire up and spawn the graph --- + // --- 5. Wire and spawn --- tracing::info!("Wiring up and spawning pipeline graph"); let node_kinds: HashMap = definition.nodes.iter().map(|(name, def)| (name.clone(), def.kind.clone())).collect(); - // Shared audio buffer pool for hot paths (e.g., Opus decode). let audio_pool = self.audio_pool.clone(); - // Oneshot pipelines don't track state, so pass None for state_tx let live_nodes = graph_builder::wire_and_spawn_graph( nodes, &definition.connections, @@ -291,9 +388,7 @@ impl Engine { .await?; tracing::info!("Pipeline graph successfully spawned"); - // --- 5.5. Send Start signals to file_reader nodes --- - // Note: file_reader nodes need Start signals even in HTTP streaming mode - // (e.g., for mixing scenarios where you have both http_input and file_reader) + // --- 5.5. Start file readers (if any) --- if !source_node_ids.is_empty() { tracing::info!( "Sending Start signals to {} file_reader node(s)", @@ -315,62 +410,21 @@ impl Engine { } } - // --- 6. Spawn a task to pump the input stream into the graph (HTTP streaming mode only) --- - if has_http_input { - tracing::debug!("Starting input stream pump task"); - let input_pump_token = cancellation_token.clone(); - tokio::spawn(async move { - use futures::StreamExt; - let mut chunk_count = 0; - tracing::debug!("Input stream pump starting to read from stream"); - loop { - tokio::select! 
{ - // Use () instead of _ for unit type to be explicit - () = input_pump_token.cancelled() => { - tracing::info!("Input stream pump cancelled after {} chunks", chunk_count); - break; - } - chunk_result = input_stream.next() => { - match chunk_result { - Some(Ok(chunk)) => { - chunk_count += 1; - if input_stream_tx.send(chunk).await.is_err() { - tracing::warn!("Input node closed before stream ended."); - break; - } - } - Some(Err(e)) => { - tracing::error!("Error reading from input stream: {}", e); - break; - } - None => { - tracing::info!("Input stream pump finished after {} chunks", chunk_count); - break; - } - } - } - } - } - }); - } - // --- 7. Determine content-type for the response --- tracing::debug!( "Content type sources - configured: {:?}, static: {:?}, input: {:?}", configured_content_type, static_content_type, - input_content_type + first_input_content_type ); - // Priority: configured (from http_output params) > static (final node) > input > default let content_type = configured_content_type .or(static_content_type) - .or(input_content_type) + .or(first_input_content_type) .unwrap_or_else(|| "application/octet-stream".to_string()); tracing::info!("Using content type for response: '{}'", content_type); - // --- 8. Return the result struct --- Ok(OneshotPipelineResult { data_stream: output_stream_rx, content_type }) } } diff --git a/crates/nodes/src/core/bytes_input.rs b/crates/nodes/src/core/bytes_input.rs index 68511c16..5694524b 100644 --- a/crates/nodes/src/core/bytes_input.rs +++ b/crates/nodes/src/core/bytes_input.rs @@ -16,6 +16,11 @@ use tokio::sync::mpsc; /// and sends them out as `Packet::Binary` packets. This node is special-cased /// by the stateless runner to represent the HTTP request body. pub struct BytesInputNode { + streams: Vec, +} + +struct BytesInputStream { + pin: String, stream_rx: mpsc::Receiver, content_type: Option, } @@ -23,8 +28,21 @@ pub struct BytesInputNode { impl BytesInputNode { /// Creates a new BytesInputNode directly with a channel receiver. /// This is a safe, compile-time checked way to provide the input stream. - pub const fn new(stream_rx: mpsc::Receiver, content_type: Option) -> Self { - Self { stream_rx, content_type } + pub fn new( + pin: impl Into, + stream_rx: mpsc::Receiver, + content_type: Option, + ) -> Self { + Self { streams: vec![BytesInputStream { pin: pin.into(), stream_rx, content_type }] } + } + + /// Creates a BytesInputNode with multiple output pins/streams. 
+ pub fn with_streams(streams: Vec<(String, mpsc::Receiver, Option)>) -> Self { + let streams = streams + .into_iter() + .map(|(pin, stream_rx, content_type)| BytesInputStream { pin, stream_rx, content_type }) + .collect(); + Self { streams } } } @@ -36,89 +54,78 @@ impl ProcessorNode for BytesInputNode { } fn output_pins(&self) -> Vec { - vec![OutputPin { - name: "out".to_string(), - // This node produces generic binary data, but we use Any - // to allow flexible connections (e.g., Binary → Text conversion) - produces_type: PacketType::Any, - cardinality: PinCardinality::Broadcast, - }] + self.streams + .iter() + .map(|stream| OutputPin { + name: stream.pin.clone(), + // This node produces generic binary data, but we use Any + // to allow flexible connections (e.g., Binary → Text conversion) + produces_type: PacketType::Any, + cardinality: PinCardinality::Broadcast, + }) + .collect() } - async fn run(mut self: Box, mut context: NodeContext) -> Result<(), StreamKitError> { + async fn run(mut self: Box, context: NodeContext) -> Result<(), StreamKitError> { let node_name = context.output_sender.node_name().to_string(); state_helpers::emit_initializing(&context.state_tx, &node_name); tracing::info!("BytesInputNode starting"); state_helpers::emit_running(&context.state_tx, &node_name); - let mut chunk_count = 0; - let mut reason = "completed".to_string(); + let mut handles = Vec::new(); - // This node's main loop reads from the stream receiver provided at creation. - // If a cancellation token is provided, we'll also listen for cancellation. - if let Some(token) = &context.cancellation_token { - loop { - tokio::select! { - () = token.cancelled() => { - reason = "cancelled".to_string(); - tracing::info!("BytesInputNode cancelled after {} chunks.", chunk_count); - break; - } - chunk = self.stream_rx.recv() => { - match chunk { - Some(chunk) => { - chunk_count += 1; - if context - .output_sender - .send( - "out", - Packet::Binary { - data: chunk, - content_type: self.content_type.clone().map(Cow::Owned), - metadata: None, - }, - ) - .await - .is_err() - { - tracing::debug!("Output channel closed, stopping node"); - break; - } + for mut stream in self.streams { + let mut sender = context.output_sender.clone(); + let state_tx = context.state_tx.clone(); + let node = node_name.clone(); + let cancel = context.cancellation_token.clone(); + handles.push(tokio::spawn(async move { + let mut chunk_count = 0usize; + let mut reason = "completed".to_string(); + loop { + tokio::select! 
{ + () = async { + if let Some(token) = &cancel { + token.cancelled().await; } - None => { - // Stream finished normally - break; + } => { + reason = "cancelled".to_string(); + tracing::info!("BytesInputNode '{}' stream '{}' cancelled after {} chunks.", node, stream.pin, chunk_count); + break; + } + chunk = stream.stream_rx.recv() => { + match chunk { + Some(chunk) => { + chunk_count += 1; + if sender + .send( + &stream.pin, + Packet::Binary { + data: chunk, + content_type: stream.content_type.clone().map(Cow::Owned), + metadata: None, + }, + ) + .await + .is_err() + { + tracing::debug!("Output channel for pin '{}' closed, stopping stream", stream.pin); + break; + } + } + None => break, } } } } - } - } else { - // No cancellation token, use simpler loop - while let Some(chunk) = self.stream_rx.recv().await { - chunk_count += 1; - if context - .output_sender - .send( - "out", - Packet::Binary { - data: chunk, - content_type: self.content_type.clone().map(Cow::Owned), - metadata: None, - }, - ) - .await - .is_err() - { - tracing::debug!("Output channel closed, stopping node"); - break; - } - } + state_helpers::emit_stopped(&state_tx, &node, reason); + tracing::info!("BytesInputNode '{}' stream '{}' finished after {} chunks.", node, stream.pin, chunk_count); + })); + } + + for handle in handles { + let _ = handle.await; } - // The loop exits when the sender is dropped, which happens when the - // upstream (e.g., the HTTP request body stream) has finished. - state_helpers::emit_stopped(&context.state_tx, &node_name, reason); - tracing::info!("BytesInputNode finished sending stream after {} chunks.", chunk_count); Ok(()) } } diff --git a/docs/src/content/docs/guides/creating-pipelines.md b/docs/src/content/docs/guides/creating-pipelines.md index 107c9043..20cc1dec 100644 --- a/docs/src/content/docs/guides/creating-pipelines.md +++ b/docs/src/content/docs/guides/creating-pipelines.md @@ -215,12 +215,12 @@ curl http://localhost:4545/api/v1/sessions//pipeline Use `POST /api/v1/process` with multipart fields: - `config` (YAML, required; must be the first field) -- `media` (optional) +- Upload fields for media (optional): names must match `streamkit::http_input` nodes. Default is `media` when a single `http_input` exists with no params; otherwise use the node id or `params.field`. If `params.fields` is set, only the listed fields are accepted and the legacy `media` field is disabled. Oneshot validation rules: -- If `media` is present: the pipeline must contain `streamkit::http_input` -- If `media` is absent: the pipeline must contain `core::file_reader` and must not contain `streamkit::http_input` +- If uploads are present: the pipeline must contain `streamkit::http_input` (field names must match) +- If uploads are absent: the pipeline must contain `core::file_reader` and must not contain `streamkit::http_input` - Always: the pipeline must contain `streamkit::http_output` ```bash diff --git a/docs/src/content/docs/reference/http-api.md b/docs/src/content/docs/reference/http-api.md index b8eb9896..2d1ee4b6 100644 --- a/docs/src/content/docs/reference/http-api.md +++ b/docs/src/content/docs/reference/http-api.md @@ -70,11 +70,29 @@ Destroy a session: `POST /api/v1/process` accepts multipart: - `config`: pipeline YAML (required; must be the first field) -- `media`: optional binary media payload +- One or more media fields: names must match `streamkit::http_input` nodes **Max body size**: Configurable via `[server].max_body_size` (default: 100 MB). 
-If `media` is provided, the pipeline must include `streamkit::http_input` to receive it. If no media is needed, `streamkit::http_input` can still be used as a trigger (with empty body) or the pipeline can rely solely on `core::file_reader`. Both nodes can be used together (e.g., mixing uploaded audio with a local file). In all cases, `streamkit::http_output` is required. +If one or more media fields are provided, the pipeline must include `streamkit::http_input` nodes to receive them. Each `http_input` can declare: + +- `field`: single field name (default `media` when only one http_input exists, otherwise the node id) +- `required`: whether the field must be present (default `true`) +- `fields`: list of field entries (string or `{ name, required }`), which exposes one output pin per entry so each upload can be routed independently. When `fields` is set, only the listed fields are accepted; the legacy `media` field is disabled. `field` and `fields` are mutually exclusive. + +Unexpected fields are rejected with a `400`; missing required fields also fail the request with a `400` (the server waits up to 5 seconds for them before giving up). + +Example (dual upload mixing sample, real assets + paced playback): + +```bash +curl --no-buffer \ + -F config=@samples/pipelines/oneshot/dual_upload_mixing.yml \ + -F track_a=@samples/audio/system/speech_2m.opus \ + -F "track_b=@samples/audio/system/THE LADY IS A TRAMP.opus" \ + http://127.0.0.1:4545/api/v1/process | ffplay -nodisp -autoexit -f webm -i - +``` + +If no uploads are needed, `streamkit::http_input` can still be used as a trigger (with empty body) or the pipeline can rely solely on `core::file_reader`. Both nodes can be used together (e.g., mixing uploaded audio with a local file). In all cases, `streamkit::http_output` is required. > [!NOTE] > `streamkit::http_input` and `streamkit::http_output` are **oneshot-only marker nodes**. They are available in schema discovery, but they cannot be used in dynamic sessions. diff --git a/docs/src/content/docs/reference/nodes/streamkit-http-input.md b/docs/src/content/docs/reference/nodes/streamkit-http-input.md index 5c3c72b1..5b8b2433 100644 --- a/docs/src/content/docs/reference/nodes/streamkit-http-input.md +++ b/docs/src/content/docs/reference/nodes/streamkit-http-input.md @@ -18,10 +18,15 @@ Synthetic input node for oneshot HTTP pipelines. Receives binary data from the H No inputs. ### Outputs -- `out` produces `Binary` (broadcast) +- Single-field mode: one `Binary` pin named after `field` (defaults to `media` when a single `http_input` exists). +- Multi-field mode: one `Binary` pin per `fields` entry. Pin names match the field names and **no legacy `media` pin is added**. ## Parameters -No parameters. +- `field` (`string`, optional) — Multipart field name to bind to this input. Defaults to `media` when there is only one `http_input` node; otherwise defaults to the node id. +- `fields` (`array`, optional) — List of multipart fields for this node. Each entry can be a string or `{ name, required }`. When set, only these fields are accepted and the legacy `media` field is disabled. `field` and `fields` are mutually exclusive. +- `required` (`boolean`, default: `true`) — When `true`, the request must include this field. Ignored when `fields` is provided (use per-entry `required` instead). + +When `fields` is provided, this node exposes multiple output pins, one per field. Pin names match the field names, allowing you to wire each uploaded stream independently. The legacy `media` pin is not added in this mode.
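+
+For example, a node that accepts two uploads and routes each to its own pin (inside the pipeline's `nodes:` section; node and field names here follow the dual-upload sample shipped in `samples/pipelines/oneshot/`):
+
+```yaml
+uploads:
+  kind: streamkit::http_input
+  params:
+    fields:
+      - name: track_a
+      - name: track_b      # each entry becomes an output pin of the same name
+
+upload_a_demuxer:
+  kind: containers::ogg::demuxer
+  needs: uploads.track_a   # select a specific pin with node.pin syntax
+```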
diff --git a/samples/pipelines/oneshot/dual_upload_mixing.yml b/samples/pipelines/oneshot/dual_upload_mixing.yml new file mode 100644 index 00000000..d02534be --- /dev/null +++ b/samples/pipelines/oneshot/dual_upload_mixing.yml @@ -0,0 +1,91 @@ +# +# skit:input_asset_tags=speech + +name: Dual Upload Mixer +description: Mix two uploaded Ogg/Opus tracks and return Opus/WebM +mode: oneshot +nodes: + # ============================================================ + # INPUTS: Two uploaded audio tracks + # ============================================================ + uploads: + kind: streamkit::http_input + params: + fields: + - name: track_a + - name: track_b + + # ============================================================ + # UPLOAD A: Demux, pace, decode, apply gain + # ============================================================ + upload_a_demuxer: + kind: containers::ogg::demuxer + needs: uploads.track_a + + upload_a_pacer: + kind: core::pacer + params: + buffer_size: 16 + speed: 1 + needs: upload_a_demuxer + + upload_a_decoder: + kind: audio::opus::decoder + needs: upload_a_pacer + + upload_a_gain: + kind: audio::gain + params: + gain: 1.0 + needs: upload_a_decoder + + # ============================================================ + # UPLOAD B: Demux, pace, decode, apply gain + # ============================================================ + upload_b_demuxer: + kind: containers::ogg::demuxer + needs: uploads.track_b + + upload_b_pacer: + kind: core::pacer + params: + buffer_size: 16 + speed: 1 + needs: upload_b_demuxer + + upload_b_decoder: + kind: audio::opus::decoder + needs: upload_b_pacer + + upload_b_gain: + kind: audio::gain + params: + gain: 0.15 + needs: upload_b_decoder + + # ============================================================ + # OUTPUT: Mix streams and encode to WebM + # ============================================================ + mixer: + kind: audio::mixer + params: + num_inputs: 2 + needs: + - upload_a_gain + - upload_b_gain + + opus_encoder: + kind: audio::opus::encoder + needs: mixer + + webm_muxer: + kind: containers::webm::muxer + params: + channels: 1 + chunk_size: 65536 + sample_rate: 48000 + needs: opus_encoder + + http_output: + kind: streamkit::http_output + needs: webm_muxer diff --git a/ui/src/components/converter/AssetSelector.tsx b/ui/src/components/converter/AssetSelector.tsx index 423143e4..c0bd677d 100644 --- a/ui/src/components/converter/AssetSelector.tsx +++ b/ui/src/components/converter/AssetSelector.tsx @@ -94,6 +94,7 @@ interface AssetSelectorProps { selectedAssetId: string; onAssetSelect: (assetId: string) => void; isLoading?: boolean; + groupId?: string; } export const AssetSelector: React.FC = ({ @@ -101,6 +102,7 @@ export const AssetSelector: React.FC = ({ selectedAssetId, onAssetSelect, isLoading = false, + groupId = 'asset', }) => { if (isLoading) { return ( @@ -134,11 +136,12 @@ export const AssetSelector: React.FC = ({ value={selectedAssetId} onValueChange={onAssetSelect} aria-label="Audio asset selection" + name={`asset-selector-${groupId}`} > {assets.map((asset) => ( - - + + diff --git a/ui/src/hooks/usePipeline.ts b/ui/src/hooks/usePipeline.ts index d2085538..2f5994d2 100644 --- a/ui/src/hooks/usePipeline.ts +++ b/ui/src/hooks/usePipeline.ts @@ -106,8 +106,17 @@ function buildPipelineForYaml( .map((e): NeedsDependency | null => { const label = idToLabelMap.get(e.source); if (!label) return null; + const sourceNode = idToNode.get(e.source); + const sourceOutputs = (sourceNode?.data.outputs || []) as Array<{ name: string }>; + const 
defaultOutput = sourceOutputs[0]?.name; + const sourceHandle = e.sourceHandle || defaultOutput; + const annotatePin = + sourceOutputs.length > 1 || + (sourceHandle && defaultOutput && sourceHandle !== defaultOutput); + + const needsLabel = sourceHandle && annotatePin ? `${label}.${sourceHandle}` : label; const mode = (e.data as { mode?: ConnectionMode } | undefined)?.mode; - return mode === 'best_effort' ? { node: label, mode } : label; + return mode === 'best_effort' ? { node: needsLabel, mode } : needsLabel; }) .filter((v): v is NeedsDependency => v !== null); diff --git a/ui/src/services/converter.test.ts b/ui/src/services/converter.test.ts index 95cadfa0..3babf109 100644 --- a/ui/src/services/converter.test.ts +++ b/ui/src/services/converter.test.ts @@ -38,6 +38,7 @@ vi.mock('./base', () => ({ describe('converter service', () => { const MOCK_YAML = 'steps:\n - id: test\n kind: core::passthrough'; const MOCK_FILE = new File(['test content'], 'test.ogg', { type: 'audio/ogg' }); + const MOCK_UPLOAD = [{ field: 'media', file: MOCK_FILE }]; let originalMediaSource: unknown; @@ -69,7 +70,7 @@ describe('converter service', () => { body: mockBody, } as Response); - const result = await convertFile(MOCK_YAML, MOCK_FILE, 'playback'); + const result = await convertFile(MOCK_YAML, MOCK_UPLOAD, 'playback'); expect(result.success).toBe(true); expect(result.useStreaming).toBe(true); @@ -97,7 +98,7 @@ describe('converter service', () => { } as unknown as Response); const abortController = new AbortController(); - const result = await convertFile(MOCK_YAML, MOCK_FILE, 'playback', abortController.signal); + const result = await convertFile(MOCK_YAML, MOCK_UPLOAD, 'playback', abortController.signal); expect(result.success).toBe(true); expect(result.responseStream).toBeDefined(); @@ -126,7 +127,7 @@ describe('converter service', () => { body: mockBody, } as Response); - const result = await convertFile(MOCK_YAML, MOCK_FILE, 'playback'); + const result = await convertFile(MOCK_YAML, MOCK_UPLOAD, 'playback'); expect(result.success).toBe(true); expect(result.useStreaming).toBe(true); @@ -155,7 +156,7 @@ describe('converter service', () => { } as unknown as Response); const abortController = new AbortController(); - const result = await convertFile(MOCK_YAML, MOCK_FILE, 'playback', abortController.signal); + const result = await convertFile(MOCK_YAML, MOCK_UPLOAD, 'playback', abortController.signal); expect(result.responseStream).toBeDefined(); @@ -178,7 +179,7 @@ describe('converter service', () => { blob: vi.fn().mockResolvedValue(mockBlob), } as never); - const result = await convertFile(MOCK_YAML, MOCK_FILE, 'playback'); + const result = await convertFile(MOCK_YAML, MOCK_UPLOAD, 'playback'); expect(result.success).toBe(true); expect(result.useStreaming).toBe(false); @@ -212,7 +213,7 @@ describe('converter service', () => { blob: vi.fn().mockResolvedValue(mockBlob), } as never); - const result = await convertFile(MOCK_YAML, MOCK_FILE, 'download'); + const result = await convertFile(MOCK_YAML, MOCK_UPLOAD, 'download'); expect(result.success).toBe(true); expect(mockLink.href).toBe('blob:download-url'); @@ -244,7 +245,7 @@ describe('converter service', () => { } as never); const file = new File(['content'], 'my-audio.ogg', { type: 'audio/ogg' }); - await convertFile(MOCK_YAML, file, 'download'); + await convertFile(MOCK_YAML, [{ field: 'media', file }], 'download'); expect(mockLink.download).toBe('my-audio_converted.wav'); }); @@ -262,7 +263,7 @@ describe('converter service', () => { }); }); - const 
resultPromise = convertFile(MOCK_YAML, MOCK_FILE, 'playback', abortController.signal); + const resultPromise = convertFile(MOCK_YAML, MOCK_UPLOAD, 'playback', abortController.signal); abortController.abort(); const result = await resultPromise; @@ -280,7 +281,7 @@ describe('converter service', () => { blob: vi.fn().mockResolvedValue(new Blob()), } as never); - await convertFile(MOCK_YAML, MOCK_FILE, 'download', abortController.signal); + await convertFile(MOCK_YAML, MOCK_UPLOAD, 'download', abortController.signal); expect(fetch).toHaveBeenCalledWith( 'http://localhost:4545/api/v1/process', @@ -300,7 +301,7 @@ describe('converter service', () => { text: vi.fn().mockResolvedValue('Invalid pipeline configuration'), } as never); - const result = await convertFile(MOCK_YAML, MOCK_FILE, 'playback'); + const result = await convertFile(MOCK_YAML, MOCK_UPLOAD, 'playback'); expect(result.success).toBe(false); expect(result.error).toContain('Bad Request'); @@ -310,7 +311,7 @@ describe('converter service', () => { it('should handle network errors', async () => { (fetch as ReturnType).mockRejectedValue(new Error('Network error')); - const result = await convertFile(MOCK_YAML, MOCK_FILE, 'playback'); + const result = await convertFile(MOCK_YAML, MOCK_UPLOAD, 'playback'); expect(result.success).toBe(false); expect(result.error).toContain('Network error'); @@ -319,7 +320,7 @@ describe('converter service', () => { it('should handle unknown errors', async () => { (fetch as ReturnType).mockRejectedValue('Unknown error'); - const result = await convertFile(MOCK_YAML, MOCK_FILE, 'playback'); + const result = await convertFile(MOCK_YAML, MOCK_UPLOAD, 'playback'); expect(result.success).toBe(false); expect(result.error).toBe('Unknown error occurred'); @@ -334,7 +335,7 @@ describe('converter service', () => { blob: vi.fn().mockResolvedValue(new Blob()), } as never); - await convertFile(MOCK_YAML, MOCK_FILE, 'download'); + await convertFile(MOCK_YAML, MOCK_UPLOAD, 'download'); expect(fetch).toHaveBeenCalledWith( 'http://localhost:4545/api/v1/process', diff --git a/ui/src/services/converter.ts b/ui/src/services/converter.ts index b319bc96..1756e477 100644 --- a/ui/src/services/converter.ts +++ b/ui/src/services/converter.ts @@ -183,9 +183,11 @@ async function handleDownload( * @param signal - Optional AbortSignal to cancel the request * @returns A promise that resolves when the conversion is complete */ +export type UploadField = { field: string; file: File }; + export async function convertFile( pipelineYaml: string, - mediaFile: File | null, + uploads: UploadField[] | null, mode: OutputMode = 'download', signal?: AbortSignal, options?: ConvertFileOptions @@ -195,15 +197,17 @@ export async function convertFile( const formData = new FormData(); formData.append('config', new Blob([pipelineYaml], { type: 'text/yaml' })); - // Only append media file if provided (not needed for asset-based pipelines) - if (mediaFile) { - formData.append('media', mediaFile); + // Append uploads (multi-field allowed) + const files = uploads ?? 
[]; + for (const upload of files) { + formData.append(upload.field, upload.file); } // Determine the API URL logger.info('Starting conversion:', { - fileName: mediaFile?.name || '(asset-based)', - fileSize: mediaFile?.size || 0, + uploads: files.length, + fileNames: files.map((f) => f.file.name), + fileSizes: files.map((f) => f.file.size), pipelineLength: pipelineYaml.length, }); @@ -244,7 +248,9 @@ export async function convertFile( } // Handle download mode - return handleDownload(response, contentType, mediaFile); + // Use first upload to infer output naming when possible + const primaryFile = files[0]?.file ?? null; + return handleDownload(response, contentType, primaryFile); } catch (error) { logger.error('Conversion error:', error); return { diff --git a/ui/src/utils/yamlPipeline.test.ts b/ui/src/utils/yamlPipeline.test.ts index 03583732..f5bd4281 100644 --- a/ui/src/utils/yamlPipeline.test.ts +++ b/ui/src/utils/yamlPipeline.test.ts @@ -157,4 +157,65 @@ nodes: expect(result.edges).toHaveLength(1); expect((result.edges[0]?.data as { mode?: string } | undefined)?.mode).toBe('best_effort'); }); + + it('parses needs that target specific output pins (e.g., http_input fields)', () => { + setPacketTypeRegistry([ + { + id: 'Binary', + label: 'Binary', + color: '#555555', + display_template: null, + compatibility: { kind: 'any' }, + }, + ]); + + const yaml = ` +mode: oneshot +nodes: + uploads: + kind: streamkit::http_input + params: + fields: + - track_a + - track_b + + upload_a_demuxer: + kind: containers::ogg::demuxer + needs: uploads.track_a + + upload_b_demuxer: + kind: containers::ogg::demuxer + needs: uploads.track_b +`; + + const nodeDefinitions: NodeDefinition[] = [ + { + kind: 'streamkit::http_input', + param_schema: {}, + inputs: [], + outputs: [{ name: 'media', produces_type: 'Binary', cardinality: 'Broadcast' }], + categories: [], + bidirectional: false, + }, + makeSinkNodeDef('containers::ogg::demuxer', ['Binary']), + ]; + + let nextId = 1; + const result = parseYamlToPipeline( + yaml, + nodeDefinitions, + () => {}, + () => {}, + () => `id_${nextId++}`, + () => { + nextId = 1; + } + ); + + expect(result.error).toBeUndefined(); + expect(result.edges).toHaveLength(2); + const handles = result.edges.map((e) => e.sourceHandle); + expect(handles).toContain('track_a'); + expect(handles).toContain('track_b'); + }); }); diff --git a/ui/src/utils/yamlPipeline.ts b/ui/src/utils/yamlPipeline.ts index b4c49559..d7acdefe 100644 --- a/ui/src/utils/yamlPipeline.ts +++ b/ui/src/utils/yamlPipeline.ts @@ -42,6 +42,26 @@ type EditorNodeData = { type NeedsDependency = string | { node: string; mode?: ConnectionMode }; +type ParsedDependency = { + sourceLabel: string; + sourcePin?: string; + mode?: ConnectionMode; +}; + +/** + * Parse a needs dependency into node label + optional source pin + mode. + * Supports "node.pin" shorthand as well as { node, mode } objects. + */ +function parseDependency(dep: NeedsDependency): ParsedDependency { + const mode = typeof dep === 'string' ? undefined : dep.mode; + const raw = typeof dep === 'string' ? dep : dep.node; + + const [sourceLabel, ...rest] = raw.split('.'); + const sourcePin = rest.length > 0 ? rest.join('.') : undefined; + + return { sourceLabel, sourcePin, mode }; +} + type ImportedNodeConfig = { kind: string; params?: Record; @@ -164,6 +184,45 @@ function expandDynamicInputs(nodeDef: NodeDefinition, config: ImportedNodeConfig return nodeInputs; } +/** + * Derive output pins for http_input based on params.fields/field. 
+ * - When fields are provided, create one pin per entry (string or { name }). + * - Ensure a 'media' pin always exists for backward compatibility. + */ +function deriveHttpInputOutputs( + params?: Record +): Array<{ name: string; produces_type: PacketType; cardinality: 'Broadcast' }> { + const normalizeFieldName = (value: unknown): string | null => { + if (typeof value === 'string' && value.trim()) { + return value.trim(); + } + if (value && typeof value === 'object' && 'name' in (value as Record)) { + const name = String((value as Record).name || '').trim(); + return name || null; + } + return null; + }; + + const fields = params?.fields; + const outputs: string[] = Array.isArray(fields) + ? fields + .map((entry) => normalizeFieldName(entry)) + .filter((name): name is string => Boolean(name)) + : []; + + if (outputs.length === 0) { + const single = normalizeFieldName(params?.field); + outputs.push(single ?? 'media'); + } + + const unique = Array.from(new Set(outputs)); + return unique.map((name) => ({ + name, + produces_type: 'Binary' as PacketType, + cardinality: 'Broadcast' as const, + })); +} + /** * Creates ReactFlow nodes from pipeline nodes */ @@ -187,6 +246,10 @@ function createNodesFromPipeline( labelToIdMap.set(label, newId); const nodeInputs = expandDynamicInputs(nodeDef, config); + const nodeOutputs = + config.kind === 'streamkit::http_input' + ? deriveHttpInputOutputs(config.params) + : nodeDef.outputs || []; const newNode: Node = { id: newId, @@ -202,7 +265,7 @@ function createNodesFromPipeline( params: (config.params as Record) || {}, paramSchema: nodeDef.param_schema, inputs: nodeInputs, - outputs: nodeDef.outputs || [], + outputs: nodeOutputs, definition: { bidirectional: nodeDef.bidirectional }, nodeDefinition: nodeDef, onParamChange: handleParamChange, @@ -223,7 +286,8 @@ function createNodesFromPipeline( function validateConnectionCompatibility( sourceNode: Node, targetNode: Node, - needsIndex: number, + sourceHandleName: string, + targetHandleName: string, sourceLabel: string, targetLabel: string, nodes: Node[], @@ -238,12 +302,15 @@ function validateConnectionCompatibility( accepts_types: PacketType[]; }>; - // Get default output pin (first output) - const sourceOutput = sourceOutputs[0]; - // For nodes with multiple needs, use the input pin corresponding to the needs index - const targetInput = targetInputs[needsIndex] || targetInputs[0]; + const sourceOutput = sourceOutputs.find((o) => o.name === sourceHandleName) ?? sourceOutputs[0]; + const targetInput = targetInputs.find((i) => i.name === targetHandleName) ?? targetInputs[0]; - if (!sourceOutput || !targetInput) return; + if (!sourceOutput) { + throw new Error( + `Node "${targetLabel}" references non-existent pin "${sourceHandleName}" on "${sourceLabel}".` + ); + } + if (!targetInput) return; // Resolve the output type (handles Passthrough inference and param-dependent nodes like resampler) const resolvedSourceType = resolveOutputType( @@ -273,7 +340,8 @@ function createEdgeForConnection( targetNode: Node, sourceId: string, targetId: string, - needsIndex: number, + sourceHandleName: string, + targetHandleName: string, mode: ConnectionMode | undefined, nodeByLabel: Map>, newEdges: Edge[] @@ -287,19 +355,19 @@ function createEdgeForConnection( accepts_types: PacketType[]; }>; - const sourceOutput = sourceOutputs[0]; - const targetInput = targetInputs[needsIndex] || targetInputs[0]; + const sourceOutput = sourceOutputs.find((o) => o.name === sourceHandleName) ?? 
sourceOutputs[0]; + const targetInput = targetInputs.find((i) => i.name === targetHandleName) ?? targetInputs[0]; if (!sourceOutput || !targetInput) return; // Get pin names - const sourceHandleName = sourceOutput.name; - const targetHandleName = targetInput.name; + const sourceHandle = sourceOutput.name; + const targetHandle = targetInput.name; // Resolve the output type for this edge (handles Passthrough inference) const resolvedType = resolveOutputType( sourceNode, - sourceHandleName, + sourceHandle, Array.from(nodeByLabel.values()), newEdges ); @@ -309,8 +377,8 @@ function createEdgeForConnection( id: `${sourceId}-${targetId}-${newEdges.length}`, source: sourceId, target: targetId, - sourceHandle: sourceHandleName, - targetHandle: targetHandleName, + sourceHandle: sourceHandle, + targetHandle: targetHandle, data: { resolvedType, ...(mode === 'best_effort' ? { mode } : {}), @@ -338,8 +406,7 @@ function createEdgesFromPipeline( const needs: NeedsDependency[] = Array.isArray(config.needs) ? config.needs : [config.needs]; needs.forEach((dep: NeedsDependency, needsIndex: number) => { - const sourceLabel = typeof dep === 'string' ? dep : dep.node; - const mode: ConnectionMode | undefined = typeof dep === 'string' ? undefined : dep.mode; + const { sourceLabel, sourcePin, mode } = parseDependency(dep); const sourceId = labelToIdMap.get(sourceLabel); @@ -353,11 +420,29 @@ function createEdgesFromPipeline( const sourceNode = nodeByLabel.get(sourceLabel); if (!sourceNode) return; + const sourceOutputs = (sourceNode.data.outputs || []) as Array<{ + name: string; + produces_type: PacketType; + }>; + const targetInputs = (targetNode.data.inputs || []) as Array<{ + name: string; + accepts_types: PacketType[]; + }>; + + const sourceHandleName = sourcePin ?? sourceOutputs[0]?.name; + const targetHandleName = (targetInputs[needsIndex] || targetInputs[0])?.name; + + if (!sourceHandleName) { + throw new Error(`Node "${sourceLabel}" has no output pins to connect to "${label}".`); + } + if (!targetHandleName) return; + // Validate connection compatibility validateConnectionCompatibility( sourceNode, targetNode, - needsIndex, + sourceHandleName, + targetHandleName, sourceLabel, label, Array.from(nodeByLabel.values()), @@ -370,7 +455,8 @@ function createEdgesFromPipeline( targetNode, sourceId, targetId, - needsIndex, + sourceHandleName, + targetHandleName, mode, nodeByLabel, newEdges diff --git a/ui/src/views/ConvertView.tsx b/ui/src/views/ConvertView.tsx index db818ea6..a9a285f6 100644 --- a/ui/src/views/ConvertView.tsx +++ b/ui/src/views/ConvertView.tsx @@ -3,7 +3,10 @@ // SPDX-License-Identifier: MPL-2.0 import styled from '@emotion/styled'; -import React, { useEffect, useRef, useCallback, useState, useMemo } from 'react'; +import { load as loadYaml } from 'js-yaml'; +import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; + +/* eslint-disable max-depth */ import { AssetSelector } from '@/components/converter/AssetSelector'; import { ConversionProgress } from '@/components/converter/ConversionProgress'; @@ -17,13 +20,98 @@ import { MSEAudioPlayer } from '@/components/MSEAudioPlayer'; import { RadioGroupRoot, RadioWithLabel } from '@/components/ui/RadioGroup'; import { useConvertViewState } from '@/hooks/useConvertViewState'; import { useAudioAssets } from '@/services/assets'; -import { convertFile, type OutputMode, getExtensionFromContentType } from '@/services/converter'; +import { + convertFile, + getExtensionFromContentType, + type OutputMode, + type UploadField, +} from 
'@/services/converter'; import { listSamples } from '@/services/samples'; -import { useSchemaStore, ensureSchemasLoaded } from '@/stores/schemaStore'; +import { ensureSchemasLoaded, useSchemaStore } from '@/stores/schemaStore'; import { viewsLogger } from '@/utils/logger'; import { orderSamplePipelinesSystemFirst } from '@/utils/samplePipelineOrdering'; import { injectFileReadNode } from '@/utils/yamlPipeline'; +type HttpInputField = { name: string; required: boolean }; + +const normalizeHttpInputField = (entry: unknown): HttpInputField | null => { + if (typeof entry === 'string' && entry.trim()) { + return { name: entry.trim(), required: true }; + } + if (entry && typeof entry === 'object' && 'name' in (entry as Record)) { + const name = String((entry as Record).name ?? '').trim(); + if (!name) return null; + const required = (entry as Record).required; + return { name, required: typeof required === 'boolean' ? required : true }; + } + return null; +}; + +const isRecord = (value: unknown): value is Record => + Boolean(value) && typeof value === 'object' && !Array.isArray(value); + +const extractFieldsFromNode = ( + label: string, + node: Record, + defaultField: string | null +): HttpInputField[] => { + const params = isRecord(node.params) ? (node.params as Record) : {}; + const fieldsVal = params.fields; + + if (Array.isArray(fieldsVal)) { + return fieldsVal + .map((entry) => normalizeHttpInputField(entry)) + .filter((f): f is HttpInputField => Boolean(f)); + } + + const fieldVal = typeof params.field === 'string' ? params.field.trim() : ''; + if (fieldVal) { + const required = typeof params.required === 'boolean' ? (params.required as boolean) : true; + return [{ name: fieldVal, required }]; + } + + const fallback = defaultField ?? label; + return fallback ? [{ name: fallback, required: defaultField ? false : true }] : []; +}; + +const deriveHttpInputFields = ( + yaml: string +): { fields: HttpInputField[]; hasHttpInput: boolean } => { + try { + const parsed = loadYaml(yaml) as { nodes?: unknown; steps?: unknown } | null; + if (!parsed || typeof parsed !== 'object') return { fields: [], hasHttpInput: false }; + + if (isRecord(parsed.nodes)) { + const httpEntries = Object.entries(parsed.nodes).filter( + ([, node]) => isRecord(node) && node.kind === 'streamkit::http_input' + ); + if (httpEntries.length === 0) return { fields: [], hasHttpInput: false }; + + const defaultField = httpEntries.length === 1 ? 
'media' : null; + const fields = httpEntries.flatMap(([label, node]) => + extractFieldsFromNode(label, node as Record, defaultField) + ); + + const unique = new Map(); + fields.forEach((f) => unique.set(f.name, f)); + return { fields: Array.from(unique.values()), hasHttpInput: true }; + } + + if (Array.isArray(parsed.steps)) { + const hasHttpInput = parsed.steps.some( + (s) => isRecord(s) && typeof s.kind === 'string' && s.kind === 'streamkit::http_input' + ); + if (hasHttpInput) { + return { fields: [{ name: 'media', required: true }], hasHttpInput: true }; + } + } + + return { fields: [], hasHttpInput: false }; + } catch { + return { fields: [], hasHttpInput: false }; + } +}; + const ViewContainer = styled.div` height: 100%; display: flex; @@ -432,10 +520,12 @@ const generateCliCommand = ( templateId: string, isNoInput: boolean, isTTS: boolean, + fields: HttpInputField[], serverUrl: string = 'http://127.0.0.1:4545' ): string => { // Convert template ID to file path (e.g., "oneshot/speech_to_text" -> "samples/pipelines/oneshot/speech_to_text.yml") const configPath = `samples/pipelines/${templateId}.yml`; + const activeFields = fields.length > 0 ? fields : [{ name: 'media', required: true }]; if (isNoInput) { // No input needed - send empty media field @@ -453,6 +543,29 @@ const generateCliCommand = ( ${serverUrl}/api/v1/process -o - | ffplay -f webm -i -`; } + // Multi-upload pipelines + if (activeFields.length > 1) { + // Provide real assets for known dual-upload sample + if ( + templateId.endsWith('oneshot/dual_upload_mixing') && + activeFields.some((f) => f.name === 'track_a') && + activeFields.some((f) => f.name === 'track_b') + ) { + return `curl --no-buffer \\ + -F config=@${configPath} \\ + -F track_a=@samples/audio/system/speech_2m.opus \\ + -F "track_b=@samples/audio/system/THE LADY IS A TRAMP.opus" \\ + ${serverUrl}/api/v1/process | ffplay -nodisp -autoexit -f webm -i -`; + } + + const fieldLines = activeFields.map((f) => ` -F ${f.name}=@your-${f.name}.ogg \\`).join('\n'); + + return `curl --no-buffer \\ + -F config=@${configPath} \\ +${fieldLines} + ${serverUrl}/api/v1/process -o - | ffplay -f webm -i -`; + } + // Standard audio input pipeline return `curl --no-buffer \\ -F config=@${configPath} \\ @@ -519,6 +632,9 @@ const ConvertView: React.FC = () => { setShowTechnicalDetails, } = useConvertViewState(); + const [httpInputFields, setHttpInputFields] = useState([]); + const [hasHttpInput, setHasHttpInput] = useState(false); + const [fieldUploads, setFieldUploads] = useState>({}); // State for CLI command copy button const [cliCopied, setCliCopied] = useState(false); const [msePlaybackError, setMsePlaybackError] = useState(null); @@ -527,8 +643,13 @@ const ConvertView: React.FC = () => { // Generate CLI command based on current template and pipeline type const cliCommand = useMemo(() => { if (!selectedTemplateId) return ''; - return generateCliCommand(selectedTemplateId, isNoInputPipeline, isTTSPipeline); - }, [selectedTemplateId, isNoInputPipeline, isTTSPipeline]); + return generateCliCommand( + selectedTemplateId, + isNoInputPipeline, + isTTSPipeline, + httpInputFields + ); + }, [selectedTemplateId, isNoInputPipeline, isTTSPipeline, httpInputFields]); // Handler for copying CLI command to clipboard const handleCopyCliCommand = useCallback(async () => { @@ -651,16 +772,22 @@ const ConvertView: React.FC = () => { return false; }; - // Filter assets based on pipeline's expected format + // Filter assets based on pipeline's expected format; for multi-field uploads, allow all 
assets so fields can mix const filteredAssets = React.useMemo(() => { - if (!pipelineYaml || inputMode !== 'asset') { + if (!pipelineYaml) { return audioAssets; } const expectedFormats = detectExpectedFormats(pipelineYaml); const inputAssetTags = detectInputAssetTags(pipelineYaml); - // If no specific format detected, show all assets + // Multi-field pipelines: only filter by format (avoid tag-based narrowing so users can mix content) + if (httpInputFields.length > 1) { + if (!expectedFormats) return audioAssets; + return audioAssets.filter((asset) => expectedFormats.includes(asset.format.toLowerCase())); + } + + // Single-field pipelines: apply both format and tag filters if present if (!expectedFormats && !inputAssetTags) { viewsLogger.debug('No specific format required, showing all assets'); return audioAssets; @@ -682,7 +809,7 @@ const ConvertView: React.FC = () => { viewsLogger.debug('Filtered to', tagFiltered.length, 'compatible assets'); return tagFiltered; - }, [audioAssets, pipelineYaml, inputMode]); + }, [audioAssets, pipelineYaml, httpInputFields.length]); // Clear selected asset if it's no longer in the filtered list useEffect(() => { @@ -692,6 +819,20 @@ const ConvertView: React.FC = () => { } }, [filteredAssets, selectedAssetId, setSelectedAssetId]); + // Track http_input fields for multi-upload pipelines + useEffect(() => { + const { fields, hasHttpInput: hasHttp } = deriveHttpInputFields(pipelineYaml); + setHasHttpInput(hasHttp); + setHttpInputFields(fields); + setFieldUploads((prev) => { + const next: Record = {}; + fields.forEach((f) => { + next[f.name] = prev[f.name] ?? null; + }); + return next; + }); + }, [pipelineYaml]); + // Watch for pipeline YAML changes and update transcription/TTS detection useEffect(() => { const isTranscription = checkIfTranscriptionPipeline(pipelineYaml); @@ -797,37 +938,61 @@ const ConvertView: React.FC = () => { } }; - // Helper: Prepare input file based on pipeline type and mode - const prepareInputFile = useCallback((): File | null => { + const prepareUploads = useCallback(async (): Promise => { + const fields = + httpInputFields.length > 0 ? 
httpInputFields : [{ name: 'media', required: true }]; + if (isNoInputPipeline) { - // No input needed - create empty file as placeholder for http_input + // No input needed - create empty placeholder for the first field const blob = new Blob([''], { type: 'application/octet-stream' }); - return new File([blob], 'empty', { type: 'application/octet-stream' }); + const file = new File([blob], 'empty', { type: 'application/octet-stream' }); + return [{ field: fields[0].name, file }]; } if (isTTSPipeline) { - // For TTS pipelines, convert text to a File object if (!textInput.trim()) { return null; } const blob = new Blob([textInput], { type: 'text/plain' }); - return new File([blob], 'input.txt', { type: 'text/plain' }); + const file = new File([blob], 'input.txt', { type: 'text/plain' }); + return [{ field: fields[0].name, file }]; } if (inputMode === 'upload') { + if (fields.length > 1) { + const uploads: UploadField[] = []; + for (const field of fields) { + const file = fieldUploads[field.name]; + if (!file) { + if (field.required) return null; + continue; + } + uploads.push({ field: field.name, file }); + } + return uploads; + } + if (!selectedFile) { return null; } - return selectedFile; + return [{ field: fields[0].name, file: selectedFile }]; } - // Asset mode - ensure asset is selected - if (!selectedAssetId) { - return null; + if (!selectedAssetId || !hasHttpInput) { + return []; } - // YAML is already modified by useEffect, just use it directly - return null; - }, [inputMode, isNoInputPipeline, isTTSPipeline, selectedAssetId, selectedFile, textInput]); + return []; + }, [ + fieldUploads, + httpInputFields, + inputMode, + isNoInputPipeline, + isTTSPipeline, + hasHttpInput, + selectedAssetId, + selectedFile, + textInput, + ]); // Helper: Clean up previous conversion state const cleanupPreviousState = useCallback(() => { @@ -913,8 +1078,8 @@ const ConvertView: React.FC = () => { // eslint-disable-next-line max-statements -- Intentionally co-locates conversion state + error/cancel handling. const handleConvert = async () => { // Determine the input source - const fileToConvert = prepareInputFile(); - if (fileToConvert === null && !selectedAssetId) { + const uploads = await prepareUploads(); + if (uploads === null) { return; // Validation failed } @@ -929,7 +1094,10 @@ const ConvertView: React.FC = () => { setConversionMessage(''); try { - const result = await convertFile(pipelineYaml, fileToConvert, outputMode, controller.signal); + const webmPlayback = outputMode === 'playback' ? 
'auto' : 'blob'; + const result = await convertFile(pipelineYaml, uploads, outputMode, controller.signal, { + webmPlayback, + }); if (result.success) { handleConversionSuccess(result); @@ -1084,8 +1252,8 @@ const ConvertView: React.FC = () => { if (mseFallbackLoading) return; // Determine the input source - const fileToConvert = prepareInputFile(); - if (fileToConvert === null && !selectedAssetId) { + const uploads = await prepareUploads(); + if (uploads === null) { return; } @@ -1108,7 +1276,7 @@ const ConvertView: React.FC = () => { setMseFallbackLoading(true); try { - const result = await convertFile(pipelineYaml, fileToConvert, 'playback', controller.signal, { + const result = await convertFile(pipelineYaml, uploads, 'playback', controller.signal, { webmPlayback: 'blob', }); @@ -1141,20 +1309,28 @@ const ConvertView: React.FC = () => { handleConversionSuccess, mseFallbackLoading, pipelineYaml, - prepareInputFile, - selectedAssetId, + prepareUploads, setAbortController, setConversionMessage, setConversionStatus, ]); + const uploadFields = + httpInputFields.length > 0 ? httpInputFields : [{ name: 'media', required: true }]; + const isMultiUpload = uploadFields.length > 1; + const handleDownloadAudio = () => { if (!audioUrl) return; let outputFileName = 'converted_audio'; - if (inputMode === 'upload' && selectedFile) { - const originalName = selectedFile.name; + const primaryUpload = + isMultiUpload && inputMode === 'upload' + ? (uploadFields.map((f) => fieldUploads[f.name]).find((f): f is File => Boolean(f)) ?? null) + : selectedFile; + + if (inputMode === 'upload' && primaryUpload) { + const originalName = primaryUpload.name; const baseName = originalName.includes('.') ? originalName.substring(0, originalName.lastIndexOf('.')) : originalName; @@ -1194,8 +1370,26 @@ const ConvertView: React.FC = () => { ? true // No input needed for these pipelines : isTTSPipeline ? textInput.trim() !== '' - : (inputMode === 'upload' && selectedFile !== null) || - (inputMode === 'asset' && selectedAssetId !== '')); + : (() => { + if (!hasHttpInput) { + // Pipelines without http_input rely solely on YAML/file_reader; allow run + return true; + } + + if (isMultiUpload) { + return uploadFields.every((f) => (f.required ? !!fieldUploads[f.name] : true)); + } + + if (inputMode === 'upload') { + return selectedFile !== null; + } + + if (inputMode === 'asset') { + return selectedAssetId !== ''; + } + + return false; + })()); return ( @@ -1297,23 +1491,49 @@ const ConvertView: React.FC = () => { ) : ( <> - - handleInputModeChange('upload')} - > - Upload File - - handleInputModeChange('asset')} - > - Select Existing Asset - - - - {inputMode === 'upload' ? ( - + {!isMultiUpload && ( + + handleInputModeChange('upload')} + > + Upload File + + handleInputModeChange('asset')} + disabled={isMultiUpload} + > + Select Existing Asset + + + )} + + {inputMode === 'upload' || isMultiUpload ? ( + isMultiUpload ? ( +
+
+                    This pipeline expects multiple uploads. For each field, choose an upload
+                    or pick an existing asset.
+
+                  {uploadFields.map((field) => (
+
+
+                        {field.name}
+                        {!field.required ? ' (optional)' : ''}
+
+
+                          setFieldUploads((prev) => ({ ...prev, [field.name]: file }))
+                        }
+                      />
+
+                  ))}
+
+                ) : (
+
+                )
              ) : (