diff --git a/CHANGELOG.md b/CHANGELOG.md index eb52f0697..d6f9882e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- **Home Assistant MQTT bridge** — opt-in publisher for body-tracking, zone, and vital-sign events. Set `MQTT_URL` (or pass `--mqtt-url`) to enable. Publishes presence, posture, activity, vitals alerts, lying-still alerts, and node online/offline events to `ruview/events/{type}` topics. Includes Home Assistant MQTT discovery configs for person-count and activity sensors. Zero impact on existing deployments when not configured. See `docs/integrations/home-assistant.md`. +- **`event_stream` module** — async broadcast event bus (`EventBus`) and state-change tracker (`StateTracker`) for detecting presence, posture, activity, vitals, and node lifecycle transitions. - **`nvsim` crate — deterministic NV-diamond magnetometer pipeline simulator** (ADR-089) — New standalone leaf crate at `v2/crates/nvsim` modeling a forward-only magnetic sensing path: scene → source synthesis (Biot–Savart, dipole, diff --git a/v2/Cargo.lock b/v2/Cargo.lock index 2425594e1..c03fa5d57 100644 --- a/v2/Cargo.lock +++ b/v2/Cargo.lock @@ -231,6 +231,18 @@ dependencies = [ "wait-timeout", ] +[[package]] +name = "async-compression" +version = "0.4.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e79b3f8a79cccc2898f31920fc69f304859b3bd567490f75ebf51ae1c792a9ac" +dependencies = [ + "compression-codecs", + "compression-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -318,7 +330,7 @@ dependencies = [ "sync_wrapper 1.0.2", "tokio", "tokio-tungstenite", - "tower", + "tower 0.5.3", "tower-layer", "tower-service", "tracing", @@ -871,6 +883,23 @@ dependencies = [ "memchr", ] +[[package]] +name = "compression-codecs" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce2548391e9c1929c21bf6aa2680af86fe4c1b33e6cea9ac1cfeec0bd11218cf" +dependencies = [ + "compression-core", + "flate2", + "memchr", +] + +[[package]] +name = "compression-core" +version = "0.4.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc14f565cf027a105f7a44ccf9e5b424348421a1d8952a8fc9d499d313107789" + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -2371,6 +2400,16 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hdrhistogram" +version = "7.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" +dependencies = [ + "byteorder", + "num-traits", +] + [[package]] name = "heapless" version = "0.6.1" @@ -2612,7 +2651,7 @@ dependencies = [ "hyper 0.14.32", "rustls 0.21.12", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", ] [[package]] @@ -3594,10 +3633,10 @@ dependencies = [ "libc", "log", "openssl", - "openssl-probe", + "openssl-probe 0.2.1", "openssl-sys", "schannel", - "security-framework", + "security-framework 3.7.0", "security-framework-sys", "tempfile", ] @@ -3892,13 +3931,35 @@ name = "nvsim" version = "0.3.0" dependencies = [ "approx 0.5.1", + "criterion", + "js-sys", "rand 0.8.5", "rand_chacha 0.3.1", "serde", + "serde-wasm-bindgen", "serde_json", "sha2", "thiserror 1.0.69", "tracing", + "wasm-bindgen", +] + +[[package]] +name = "nvsim-server" +version = "0.3.0" +dependencies = [ + "axum", + "clap", + "futures-util", + "nvsim", + "serde", + "serde_json", + "thiserror 1.0.69", + "tokio", + "tower 0.4.13", + "tower-http 0.5.2", + "tracing", + "tracing-subscriber", ] [[package]] @@ -4106,6 +4167,12 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + [[package]] name = "openssl-probe" version = "0.2.1" @@ -4487,6 +4554,26 @@ dependencies = [ "siphasher 1.0.2", ] +[[package]] +name = "pin-project" +version = "1.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "pin-project-lite" version = "0.2.17" @@ -5232,14 +5319,14 @@ dependencies = [ "percent-encoding", "pin-project-lite", "rustls 0.21.12", - "rustls-pemfile", + "rustls-pemfile 1.0.4", "serde", "serde_json", "serde_urlencoded", "sync_wrapper 0.1.2", "system-configuration", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", "tower-service", "url", "wasm-bindgen", @@ -5278,7 +5365,7 @@ dependencies = [ "sync_wrapper 1.0.2", "tokio", "tokio-native-tls", - "tower", + "tower 0.5.3", "tower-http 0.6.8", "tower-service", "url", @@ -5311,7 +5398,7 @@ dependencies = [ "sync_wrapper 1.0.2", "tokio", "tokio-util", - "tower", + "tower 0.5.3", "tower-http 0.6.8", "tower-service", "url", @@ -5466,6 +5553,24 @@ dependencies = [ "smallvec", ] +[[package]] +name = "rumqttc" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1568e15fab2d546f940ed3a21f48bbbd1c494c90c99c4481339364a497f94a9" +dependencies = [ + "bytes", + "flume", + "futures-util", + "log", + "rustls-native-certs 0.7.3", + "rustls-pemfile 2.2.0", + "rustls-webpki 0.102.8", + "thiserror 1.0.69", + "tokio", + "tokio-rustls 0.25.0", +] + [[package]] name = "rustc-hash" version = "2.1.1" @@ -5548,16 +5653,29 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" +dependencies = [ + "openssl-probe 0.1.6", + "rustls-pemfile 2.2.0", + "rustls-pki-types", + "schannel", + "security-framework 2.11.1", +] + [[package]] name = "rustls-native-certs" version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" dependencies = [ - "openssl-probe", + "openssl-probe 0.2.1", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.7.0", ] [[package]] @@ -5569,6 +5687,15 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.14.0" @@ -5591,10 +5718,10 @@ dependencies = [ "log", "once_cell", "rustls 0.23.37", - "rustls-native-certs", + "rustls-native-certs 0.8.3", "rustls-platform-verifier-android", "rustls-webpki 0.103.9", - "security-framework", + "security-framework 3.7.0", "security-framework-sys", "webpki-root-certs", "windows-sys 0.61.2", @@ -5918,6 +6045,19 @@ dependencies = [ "untrusted", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.11.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + [[package]] name = "security-framework" version = "3.7.0" @@ -7201,6 +7341,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" +dependencies = [ + "rustls 0.22.4", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-serial" version = "5.4.5" @@ -7379,6 +7530,27 @@ dependencies = [ "zip 0.6.6", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "hdrhistogram", + "indexmap 1.9.3", + "pin-project", + "pin-project-lite", + "rand 0.8.5", + "slab", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower" version = "0.5.3" @@ -7401,8 +7573,10 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ + "async-compression", "bitflags 2.11.0", "bytes", + "futures-core", "futures-util", "http 1.4.0", "http-body 1.0.1", @@ -7433,7 +7607,7 @@ dependencies = [ "http-body 1.0.1", "iri-string", "pin-project-lite", - "tower", + "tower 0.5.3", "tower-layer", "tower-service", ] @@ -8415,6 +8589,7 @@ dependencies = [ "chrono", "clap", "futures-util", + "rumqttc", "ruvector-mincut", "serde", "serde_json", diff --git a/v2/Cargo.toml b/v2/Cargo.toml index 113859cd1..8705589d6 100644 --- a/v2/Cargo.toml +++ b/v2/Cargo.toml @@ -73,6 +73,9 @@ hyper = { version = "1.1", features = ["full"] } sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "sqlite", "uuid", "chrono", "json"] } redis = { version = "0.24", features = ["tokio-comp", "connection-manager"] } +# MQTT +rumqttc = "0.24" + # Configuration config = "0.14" dotenvy = "0.15" diff --git a/v2/crates/wifi-densepose-sensing-server/Cargo.toml b/v2/crates/wifi-densepose-sensing-server/Cargo.toml index 0647e8e9d..f8a7322ea 100644 --- a/v2/crates/wifi-densepose-sensing-server/Cargo.toml +++ b/v2/crates/wifi-densepose-sensing-server/Cargo.toml @@ -50,5 +50,8 @@ wifi-densepose-wifiscan = { version = "0.3.0", path = "../wifi-densepose-wifisca # build without vcpkg/openblas (issue #366, #415). wifi-densepose-signal = { version = "0.3.0", path = "../wifi-densepose-signal", default-features = false } +# MQTT bridge for Home Assistant integration +rumqttc = { workspace = true } + [dev-dependencies] tempfile = "3.10" diff --git a/v2/crates/wifi-densepose-sensing-server/src/cli.rs b/v2/crates/wifi-densepose-sensing-server/src/cli.rs index 5fdad82bd..a8eb03770 100644 --- a/v2/crates/wifi-densepose-sensing-server/src/cli.rs +++ b/v2/crates/wifi-densepose-sensing-server/src/cli.rs @@ -102,4 +102,8 @@ pub struct Args { /// Start field model calibration on boot (empty room required) #[arg(long)] pub calibrate: bool, + + /// MQTT broker URL for Home Assistant integration (e.g. mqtt://broker.local:1883) + #[arg(long, env = "SENSING_MQTT_URL")] + pub mqtt_url: Option, } diff --git a/v2/crates/wifi-densepose-sensing-server/src/event_stream.rs b/v2/crates/wifi-densepose-sensing-server/src/event_stream.rs new file mode 100644 index 000000000..ab2a35823 --- /dev/null +++ b/v2/crates/wifi-densepose-sensing-server/src/event_stream.rs @@ -0,0 +1,288 @@ +//! Event stream for home automation integration. +//! +//! Tracks state changes in the sensing pipeline and emits structured events +//! via an async broadcast channel. Used by the MQTT bridge and any future +//! webhook / SSE integrations. + +use std::collections::{HashMap, VecDeque}; +use std::sync::Arc; +use std::time::Instant; +use serde::{Deserialize, Serialize}; +use tokio::sync::{broadcast, RwLock}; + +/// Event types emitted by the sensing pipeline. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "data")] +pub enum EventData { + PresenceChanged { body_count: usize, previous: usize }, + PostureChanged { body_id: usize, from: String, to: String }, + ActivityChanged { from: String, to: String }, + VitalsAlert { metric: String, value: f64, threshold: f64 }, + LyingStillAlert { body_id: usize, duration_minutes: f64 }, + NodeOnline { node_id: u8 }, + NodeOffline { node_id: u8 }, +} + +/// A timestamped event. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Event { + pub id: u64, + pub timestamp: f64, + #[serde(flatten)] + pub data: EventData, +} + +/// Tracks previous state to detect changes. +pub struct StateTracker { + prev_body_count: usize, + prev_postures: HashMap, + prev_activity: String, + lying_since: HashMap, + lying_alerts_fired: HashMap>, + prev_online_nodes: std::collections::HashSet, + prev_hr: f64, + prev_br: f64, +} + +impl StateTracker { + pub fn new() -> Self { + Self { + prev_body_count: 0, + prev_postures: HashMap::new(), + prev_activity: "absent".to_string(), + lying_since: HashMap::new(), + lying_alerts_fired: HashMap::new(), + prev_online_nodes: std::collections::HashSet::new(), + prev_hr: 0.0, + prev_br: 0.0, + } + } + + /// Compare current state to previous, emit events for any changes. + pub fn diff( + &mut self, + body_count: usize, + clusters: &[(usize, String)], + activity: &str, + online_nodes: &std::collections::HashSet, + hr_bpm: f64, + br_bpm: f64, + ) -> Vec { + let mut events = Vec::new(); + + // Presence changed + if body_count != self.prev_body_count { + events.push(EventData::PresenceChanged { + body_count, + previous: self.prev_body_count, + }); + self.prev_body_count = body_count; + } + + // Posture changed per body + let mut current_postures = HashMap::new(); + for (id, pose) in clusters { + current_postures.insert(*id, pose.clone()); + if let Some(prev) = self.prev_postures.get(id) { + if prev != pose { + events.push(EventData::PostureChanged { + body_id: *id, + from: prev.clone(), + to: pose.clone(), + }); + } + } + // Track lying duration + if pose == "lying" { + self.lying_since.entry(*id).or_insert_with(Instant::now); + if let Some(since) = self.lying_since.get(id) { + let minutes = since.elapsed().as_secs_f64() / 60.0; + let thresholds = [5u64, 15, 30, 60]; + let fired = self.lying_alerts_fired.entry(*id).or_default(); + for &t in &thresholds { + if minutes >= t as f64 && !fired.contains(&t) { + events.push(EventData::LyingStillAlert { + body_id: *id, + duration_minutes: minutes, + }); + fired.push(t); + } + } + } + } else { + self.lying_since.remove(id); + self.lying_alerts_fired.remove(id); + } + } + self.prev_postures = current_postures; + + // Activity changed + if activity != self.prev_activity { + events.push(EventData::ActivityChanged { + from: self.prev_activity.clone(), + to: activity.to_string(), + }); + self.prev_activity = activity.to_string(); + } + + // Node online/offline + for &nid in online_nodes { + if !self.prev_online_nodes.contains(&nid) { + events.push(EventData::NodeOnline { node_id: nid }); + } + } + for &nid in &self.prev_online_nodes { + if !online_nodes.contains(&nid) { + events.push(EventData::NodeOffline { node_id: nid }); + } + } + self.prev_online_nodes = online_nodes.clone(); + + // Vitals alerts (HR outside 40-120, BR outside 8-25 when previously normal) + if hr_bpm > 1.0 { + if hr_bpm > 120.0 && self.prev_hr <= 120.0 { + events.push(EventData::VitalsAlert { + metric: "heart_rate_high".to_string(), + value: hr_bpm, + threshold: 120.0, + }); + } + if hr_bpm < 40.0 && self.prev_hr >= 40.0 { + events.push(EventData::VitalsAlert { + metric: "heart_rate_low".to_string(), + value: hr_bpm, + threshold: 40.0, + }); + } + self.prev_hr = hr_bpm; + } + if br_bpm > 1.0 { + if br_bpm > 25.0 && self.prev_br <= 25.0 { + events.push(EventData::VitalsAlert { + metric: "breathing_rate_high".to_string(), + value: br_bpm, + threshold: 25.0, + }); + } + if br_bpm < 8.0 && self.prev_br >= 8.0 { + events.push(EventData::VitalsAlert { + metric: "breathing_rate_low".to_string(), + value: br_bpm, + threshold: 8.0, + }); + } + self.prev_br = br_bpm; + } + + events + } +} + +/// Central event bus for the sensing server. +pub struct EventBus { + tx: broadcast::Sender, + recent: Arc>>, + next_id: Arc, +} + +const RECENT_CAPACITY: usize = 100; +const BUS_CAPACITY: usize = 1000; + +impl EventBus { + pub fn new() -> Self { + let (tx, _) = broadcast::channel(BUS_CAPACITY); + Self { + tx, + recent: Arc::new(RwLock::new(VecDeque::with_capacity(RECENT_CAPACITY))), + next_id: Arc::new(std::sync::atomic::AtomicU64::new(1)), + } + } + + /// Publish an event. Called from the sensing pipeline. + pub async fn publish(&self, data: EventData) { + let id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let event = Event { + id, + timestamp: chrono::Utc::now().timestamp_millis() as f64 / 1000.0, + data, + }; + let _ = self.tx.send(event.clone()); + let mut recent = self.recent.write().await; + recent.push_back(event); + if recent.len() > RECENT_CAPACITY { + recent.pop_front(); + } + } + + /// Subscribe to the event stream. + pub fn subscribe(&self) -> broadcast::Receiver { + self.tx.subscribe() + } + + /// Get recent events (for polling clients). + pub async fn recent_events(&self) -> Vec { + self.recent.read().await.iter().cloned().collect() + } + + /// Get recent events since a given ID. + pub async fn events_since(&self, since_id: u64) -> Vec { + self.recent.read().await.iter().filter(|e| e.id > since_id).cloned().collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_event_bus_publish_subscribe() { + let bus = EventBus::new(); + let mut rx = bus.subscribe(); + + bus.publish(EventData::PresenceChanged { body_count: 2, previous: 0 }).await; + bus.publish(EventData::NodeOnline { node_id: 1 }).await; + bus.publish(EventData::ActivityChanged { from: "absent".into(), to: "active".into() }).await; + + let e1 = rx.recv().await.unwrap(); + assert_eq!(e1.id, 1); + let e2 = rx.recv().await.unwrap(); + assert_eq!(e2.id, 2); + let e3 = rx.recv().await.unwrap(); + assert_eq!(e3.id, 3); + } + + #[tokio::test] + async fn test_event_bus_recent() { + let bus = EventBus::new(); + bus.publish(EventData::NodeOnline { node_id: 5 }).await; + bus.publish(EventData::NodeOffline { node_id: 3 }).await; + + let recent = bus.recent_events().await; + assert_eq!(recent.len(), 2); + assert_eq!(recent[0].id, 1); + assert_eq!(recent[1].id, 2); + } + + #[test] + fn test_state_tracker_no_change() { + let mut tracker = StateTracker::new(); + let nodes = std::collections::HashSet::new(); + let events = tracker.diff(0, &[], "absent", &nodes, 0.0, 0.0); + assert!(events.is_empty(), "no state change should produce no events"); + } + + #[test] + fn test_state_tracker_presence_change() { + let mut tracker = StateTracker::new(); + let nodes = std::collections::HashSet::new(); + let events = tracker.diff(3, &[], "absent", &nodes, 0.0, 0.0); + assert_eq!(events.len(), 1); + match &events[0] { + EventData::PresenceChanged { body_count, previous } => { + assert_eq!(*body_count, 3); + assert_eq!(*previous, 0); + } + _ => panic!("expected PresenceChanged"), + } + } +} diff --git a/v2/crates/wifi-densepose-sensing-server/src/main.rs b/v2/crates/wifi-densepose-sensing-server/src/main.rs index a8b207e47..607605c36 100644 --- a/v2/crates/wifi-densepose-sensing-server/src/main.rs +++ b/v2/crates/wifi-densepose-sensing-server/src/main.rs @@ -12,7 +12,9 @@ mod adaptive_classifier; pub mod cli; pub mod csi; +mod event_stream; mod field_bridge; +mod mqtt_bridge; mod multistatic_bridge; pub mod pose; mod rvf_container; @@ -166,6 +168,10 @@ struct Args { /// Start field model calibration on boot (empty room required) #[arg(long)] calibrate: bool, + + /// MQTT broker URL for Home Assistant integration (e.g. mqtt://broker.local:1883) + #[arg(long, env = "SENSING_MQTT_URL")] + mqtt_url: Option, } // ── Data types ─────────────────────────────────────────────────────────────── @@ -4857,6 +4863,19 @@ async fn main() { } } + // MQTT bridge for Home Assistant (opt-in via --mqtt-url / SENSING_MQTT_URL) + let event_bus = Arc::new(event_stream::EventBus::new()); + if let Some(ref mqtt_url) = args.mqtt_url { + let mqtt_config = mqtt_bridge::MqttConfig { + broker_url: mqtt_url.clone(), + ..Default::default() + }; + mqtt_bridge::start_mqtt_bridge(mqtt_config, event_bus.clone()); + info!("MQTT bridge enabled → {mqtt_url}"); + } else { + info!("MQTT bridge disabled (no --mqtt-url / SENSING_MQTT_URL set)"); + } + // ADR-050: Parse bind address once, use for all listeners let bind_ip: std::net::IpAddr = args.bind_addr.parse() .expect("Invalid --bind-addr (use 127.0.0.1 or 0.0.0.0)"); diff --git a/v2/crates/wifi-densepose-sensing-server/src/mqtt_bridge.rs b/v2/crates/wifi-densepose-sensing-server/src/mqtt_bridge.rs new file mode 100644 index 000000000..53c090ad0 --- /dev/null +++ b/v2/crates/wifi-densepose-sensing-server/src/mqtt_bridge.rs @@ -0,0 +1,212 @@ +//! MQTT bridge for Home Assistant integration. +//! +//! Subscribes to the event bus and publishes events to MQTT. +//! Topics: +//! ruview/events/{event_type} -- event JSON +//! homeassistant/sensor/ruview_*/config -- HA MQTT discovery +//! +//! Opt-in via `--mqtt-url` (or `SENSING_MQTT_URL` env var). +//! If no broker is available, logs a warning and retries automatically. + +use std::sync::Arc; +use std::time::Duration; + +use rumqttc::{AsyncClient, MqttOptions, QoS}; +use tracing::{info, warn, debug}; + +use crate::event_stream::{EventBus, EventData, Event}; + +/// MQTT bridge configuration. +#[derive(Debug, Clone)] +pub struct MqttConfig { + pub broker_url: String, + pub client_id: String, + pub topic_prefix: String, + pub ha_discovery: bool, +} + +impl Default for MqttConfig { + fn default() -> Self { + Self { + broker_url: "mqtt://localhost:1883".to_string(), + client_id: "ruview-sensing".to_string(), + topic_prefix: "ruview".to_string(), + ha_discovery: true, + } + } +} + +/// Parse a `mqtt://host:port` URL into `(host, port)`. +fn parse_mqtt_url(url: &str) -> (String, u16) { + let stripped = url.strip_prefix("mqtt://").unwrap_or(url); + let parts: Vec<&str> = stripped.split(':').collect(); + let host = parts.first().unwrap_or(&"localhost").to_string(); + let port = parts.get(1).and_then(|p| p.parse().ok()).unwrap_or(1883); + (host, port) +} + +/// Start the MQTT bridge as a background task. +/// +/// Returns immediately. If the broker is unreachable the bridge logs a +/// warning and retries every 30 s — it never crashes the server. +pub fn start_mqtt_bridge( + config: MqttConfig, + event_bus: Arc, +) { + tokio::spawn(async move { + let (host, port) = parse_mqtt_url(&config.broker_url); + + let mut mqttoptions = MqttOptions::new(&config.client_id, &host, port); + mqttoptions.set_keep_alive(Duration::from_secs(30)); + mqttoptions.set_clean_session(true); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 100); + + info!("MQTT bridge connecting to {}:{}", host, port); + + // Subscribe to event bus + let mut rx = event_bus.subscribe(); + let prefix = config.topic_prefix.clone(); + + // Publish HA discovery configs if enabled + if config.ha_discovery { + let discovery_client = client.clone(); + let discovery_prefix = prefix.clone(); + tokio::spawn(async move { + // Wait for connection before publishing discovery + tokio::time::sleep(Duration::from_secs(5)).await; + publish_ha_discovery(&discovery_client, &discovery_prefix).await; + }); + } + + // Event forwarding loop + let event_client = client.clone(); + let event_prefix = prefix.clone(); + tokio::spawn(async move { + loop { + match rx.recv().await { + Ok(event) => { + publish_event(&event_client, &event_prefix, &event).await; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + debug!("MQTT bridge lagged by {n} events"); + continue; + } + Err(_) => { + warn!("MQTT bridge: event bus closed"); + break; + } + } + } + }); + + // Drive the MQTT event loop with auto-reconnect + loop { + match eventloop.poll().await { + Ok(_notification) => {} + Err(e) => { + warn!("MQTT connection error: {e} — retrying in 30s"); + tokio::time::sleep(Duration::from_secs(30)).await; + } + } + } + }); +} + +/// Publish a single event to MQTT. +async fn publish_event(client: &AsyncClient, prefix: &str, event: &Event) { + let event_type = match &event.data { + EventData::PresenceChanged { .. } => "presence_changed", + EventData::PostureChanged { .. } => "posture_changed", + EventData::ActivityChanged { .. } => "activity_changed", + EventData::VitalsAlert { .. } => "vitals_alert", + EventData::LyingStillAlert { .. } => "lying_still_alert", + EventData::NodeOnline { .. } => "node_online", + EventData::NodeOffline { .. } => "node_offline", + }; + + let topic = format!("{prefix}/events/{event_type}"); + if let Ok(payload) = serde_json::to_string(event) { + if let Err(e) = client.publish(&topic, QoS::AtMostOnce, false, payload).await { + debug!("MQTT publish failed: {e}"); + } + } +} + +/// Publish Home Assistant MQTT discovery messages. +async fn publish_ha_discovery(client: &AsyncClient, prefix: &str) { + // Person count sensor + let config = serde_json::json!({ + "name": "RuView Person Count", + "unique_id": "ruview_person_count", + "state_topic": format!("{prefix}/events/presence_changed"), + "value_template": "{{ value_json.data.body_count }}", + "device": { + "identifiers": ["ruview_sensing"], + "name": "RuView Sensing", + "manufacturer": "RuView", + "model": "WiFi CSI Sensing Server" + }, + "icon": "mdi:account-group" + }); + + let topic = "homeassistant/sensor/ruview_person_count/config"; + if let Ok(payload) = serde_json::to_string(&config) { + let _ = client.publish(topic, QoS::AtLeastOnce, true, payload).await; + } + + // Activity sensor + let activity_config = serde_json::json!({ + "name": "RuView Activity", + "unique_id": "ruview_activity", + "state_topic": format!("{prefix}/events/activity_changed"), + "value_template": "{{ value_json.data.to }}", + "device": { + "identifiers": ["ruview_sensing"], + "name": "RuView Sensing", + }, + "icon": "mdi:run" + }); + + let topic = "homeassistant/sensor/ruview_activity/config"; + if let Ok(payload) = serde_json::to_string(&activity_config) { + let _ = client.publish(topic, QoS::AtLeastOnce, true, payload).await; + } + + info!("MQTT HA discovery configs published"); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_mqtt_url() { + let (host, port) = parse_mqtt_url("mqtt://test.example.com:1883"); + assert_eq!(host, "test.example.com"); + assert_eq!(port, 1883); + } + + #[test] + fn test_parse_mqtt_url_default_port() { + let (host, port) = parse_mqtt_url("mqtt://broker.local"); + assert_eq!(host, "broker.local"); + assert_eq!(port, 1883); + } + + #[test] + fn test_parse_mqtt_url_no_scheme() { + let (host, port) = parse_mqtt_url("10.0.0.5:1884"); + assert_eq!(host, "10.0.0.5"); + assert_eq!(port, 1884); + } + + #[test] + fn test_default_config() { + let config = MqttConfig::default(); + assert_eq!(config.broker_url, "mqtt://localhost:1883"); + assert_eq!(config.client_id, "ruview-sensing"); + assert_eq!(config.topic_prefix, "ruview"); + assert!(config.ha_discovery); + } +}