diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index d305a1e1082..6165aee5a97 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -235,8 +235,6 @@ pub(crate) struct PreviousTurnSettings { use crate::exec_policy::ExecPolicyUpdateError; use crate::feedback_tags; -use crate::file_watcher::FileWatcher; -use crate::file_watcher::FileWatcherEvent; use crate::git_info::get_git_repo_root; use crate::guardian::GuardianReviewSessionManager; use crate::hook_runtime::PendingInputHookDisposition; @@ -322,6 +320,8 @@ use crate::skills::injection::ToolMentionKind; use crate::skills::injection::app_id_from_path; use crate::skills::injection::tool_kind_for_path; use crate::skills::resolve_skill_dependencies_for_turn; +use crate::skills_watcher::SkillsWatcher; +use crate::skills_watcher::SkillsWatcherEvent; use crate::state::ActiveTurn; use crate::state::SessionServices; use crate::state::SessionState; @@ -403,7 +403,7 @@ pub(crate) struct CodexSpawnArgs { pub(crate) skills_manager: Arc, pub(crate) plugins_manager: Arc, pub(crate) mcp_manager: Arc, - pub(crate) file_watcher: Arc, + pub(crate) skills_watcher: Arc, pub(crate) conversation_history: InitialHistory, pub(crate) session_source: SessionSource, pub(crate) agent_control: AgentControl, @@ -456,7 +456,7 @@ impl Codex { skills_manager, plugins_manager, mcp_manager, - file_watcher, + skills_watcher, conversation_history, session_source, agent_control, @@ -644,7 +644,7 @@ impl Codex { skills_manager, plugins_manager, mcp_manager.clone(), - file_watcher, + skills_watcher, agent_control, ) .await @@ -1297,13 +1297,13 @@ impl Session { self.out_of_band_elicitation_paused.send_replace(paused); } - fn start_file_watcher_listener(self: &Arc) { - let mut rx = self.services.file_watcher.subscribe(); + fn start_skills_watcher_listener(self: &Arc) { + let mut rx = self.services.skills_watcher.subscribe(); let weak_sess = Arc::downgrade(self); tokio::spawn(async move { loop { match rx.recv().await { - Ok(FileWatcherEvent::SkillsChanged { .. }) => { + Ok(SkillsWatcherEvent::SkillsChanged { .. }) => { let Some(sess) = weak_sess.upgrade() else { break; }; @@ -1439,7 +1439,7 @@ impl Session { skills_manager: Arc, plugins_manager: Arc, mcp_manager: Arc, - file_watcher: Arc, + skills_watcher: Arc, agent_control: AgentControl, ) -> anyhow::Result> { debug!( @@ -1858,7 +1858,7 @@ impl Session { skills_manager, plugins_manager: Arc::clone(&plugins_manager), mcp_manager: Arc::clone(&mcp_manager), - file_watcher, + skills_watcher, agent_control, network_proxy, network_approval: Arc::clone(&network_approval), @@ -1937,7 +1937,7 @@ impl Session { } // Start the watcher after SessionConfigured so it cannot emit earlier events. - sess.start_file_watcher_listener(); + sess.start_skills_watcher_listener(); // Construct sandbox_state before MCP startup so it can be sent to each // MCP server immediately after it becomes ready (avoiding blocking). let sandbox_state = SandboxState { diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index e560cd9c7f1..a81970756b8 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -80,7 +80,7 @@ pub(crate) async fn run_codex_thread_interactive( skills_manager: Arc::clone(&parent_session.services.skills_manager), plugins_manager: Arc::clone(&parent_session.services.plugins_manager), mcp_manager: Arc::clone(&parent_session.services.mcp_manager), - file_watcher: Arc::clone(&parent_session.services.file_watcher), + skills_watcher: Arc::clone(&parent_session.services.skills_watcher), conversation_history: initial_history.unwrap_or(InitialHistory::New), session_source: SessionSource::SubAgent(subagent_source), agent_control: parent_session.services.agent_control.clone(), diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index 5d6769e0eac..1d08b6cfc76 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -2494,7 +2494,7 @@ async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() { skills_manager, plugins_manager, mcp_manager, - Arc::new(FileWatcher::noop()), + Arc::new(SkillsWatcher::noop()), AgentControl::default(), ) .await; @@ -2593,7 +2593,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { .expect("create environment"), ); - let file_watcher = Arc::new(FileWatcher::noop()); + let skills_watcher = Arc::new(SkillsWatcher::noop()); let services = SessionServices { mcp_connection_manager: Arc::new(RwLock::new( McpConnectionManager::new_mcp_connection_manager_for_tests( @@ -2627,7 +2627,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { skills_manager, plugins_manager, mcp_manager, - file_watcher, + skills_watcher, agent_control, network_proxy: None, network_approval: Arc::clone(&network_approval), @@ -3428,7 +3428,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( .expect("create environment"), ); - let file_watcher = Arc::new(FileWatcher::noop()); + let skills_watcher = Arc::new(SkillsWatcher::noop()); let services = SessionServices { mcp_connection_manager: Arc::new(RwLock::new( McpConnectionManager::new_mcp_connection_manager_for_tests( @@ -3462,7 +3462,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( skills_manager, plugins_manager, mcp_manager, - file_watcher, + skills_watcher, agent_control, network_proxy: None, network_approval: Arc::clone(&network_approval), diff --git a/codex-rs/core/src/codex_tests_guardian.rs b/codex-rs/core/src/codex_tests_guardian.rs index af0fccc9ac2..e9050024ebe 100644 --- a/codex-rs/core/src/codex_tests_guardian.rs +++ b/codex-rs/core/src/codex_tests_guardian.rs @@ -435,7 +435,7 @@ async fn guardian_subagent_does_not_inherit_parent_exec_policy_rules() { true, )); let mcp_manager = Arc::new(McpManager::new(Arc::clone(&plugins_manager))); - let file_watcher = Arc::new(FileWatcher::noop()); + let skills_watcher = Arc::new(SkillsWatcher::noop()); let CodexSpawnOk { codex, .. } = Codex::spawn(CodexSpawnArgs { config, @@ -444,7 +444,7 @@ async fn guardian_subagent_does_not_inherit_parent_exec_policy_rules() { skills_manager, plugins_manager, mcp_manager, - file_watcher, + skills_watcher, conversation_history: InitialHistory::New, session_source: SessionSource::SubAgent(SubAgentSource::Other( GUARDIAN_REVIEWER_NAME.to_string(), diff --git a/codex-rs/core/src/file_watcher.rs b/codex-rs/core/src/file_watcher.rs index 7b2cbd76b0f..f8f0e4b11fe 100644 --- a/codex-rs/core/src/file_watcher.rs +++ b/codex-rs/core/src/file_watcher.rs @@ -1,13 +1,15 @@ -//! Watches skill roots for changes and broadcasts coarse-grained -//! `FileWatcherEvent`s that higher-level components react to on the next turn. +//! Watches subscribed files or directories and routes coarse-grained change +//! notifications to the subscribers that own matching watched paths. +use std::collections::BTreeSet; use std::collections::HashMap; -use std::collections::HashSet; use std::path::Path; use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; use std::sync::RwLock; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use std::time::Duration; use notify::Event; @@ -16,96 +18,258 @@ use notify::RecommendedWatcher; use notify::RecursiveMode; use notify::Watcher; use tokio::runtime::Handle; -use tokio::sync::broadcast; +use tokio::sync::Mutex as AsyncMutex; +use tokio::sync::Notify; use tokio::sync::mpsc; use tokio::time::Instant; use tokio::time::sleep_until; use tracing::warn; -use crate::config::Config; -use crate::skills::SkillsManager; - #[derive(Debug, Clone, PartialEq, Eq)] -pub enum FileWatcherEvent { - SkillsChanged { paths: Vec }, +/// Coalesced file change notification for a subscriber. +pub struct FileWatcherEvent { + /// Changed paths delivered in sorted order with duplicates removed. + pub paths: Vec, +} + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +/// Path subscription registered by a [`FileWatcherSubscriber`]. +pub struct WatchPath { + /// Root path to watch. + pub path: PathBuf, + /// Whether events below `path` should match recursively. + pub recursive: bool, } +type SubscriberId = u64; + +#[derive(Default)] struct WatchState { - skills_root_ref_counts: HashMap, + next_subscriber_id: SubscriberId, + path_ref_counts: HashMap, + subscribers: HashMap, } -struct FileWatcherInner { - watcher: RecommendedWatcher, - watched_paths: HashMap, +struct SubscriberState { + watched_paths: HashMap, + tx: WatchSender, +} + +/// Receives coalesced change notifications for a single subscriber. +pub struct Receiver { + inner: Arc, +} + +struct WatchSender { + inner: Arc, } -const WATCHER_THROTTLE_INTERVAL: Duration = Duration::from_secs(10); +struct ReceiverInner { + changed_paths: AsyncMutex>, + notify: Notify, + sender_count: AtomicUsize, +} -/// Coalesces bursts of paths and emits at most once per interval. -struct ThrottledPaths { - pending: HashSet, - next_allowed_at: Instant, +impl Receiver { + /// Waits for the next batch of changed paths, or returns `None` once the + /// corresponding subscriber has been removed and no more events can arrive. + pub async fn recv(&mut self) -> Option { + loop { + let notified = self.inner.notify.notified(); + { + let mut changed_paths = self.inner.changed_paths.lock().await; + if !changed_paths.is_empty() { + return Some(FileWatcherEvent { + paths: std::mem::take(&mut *changed_paths).into_iter().collect(), + }); + } + if self.inner.sender_count.load(Ordering::Acquire) == 0 { + return None; + } + } + notified.await; + } + } } -impl ThrottledPaths { - fn new(now: Instant) -> Self { +impl WatchSender { + async fn add_changed_paths(&self, paths: &[PathBuf]) { + if paths.is_empty() { + return; + } + + let mut changed_paths = self.inner.changed_paths.lock().await; + let previous_len = changed_paths.len(); + changed_paths.extend(paths.iter().cloned()); + if changed_paths.len() != previous_len { + self.inner.notify.notify_one(); + } + } +} + +impl Clone for WatchSender { + fn clone(&self) -> Self { + self.inner.sender_count.fetch_add(1, Ordering::Relaxed); Self { - pending: HashSet::new(), - next_allowed_at: now, + inner: Arc::clone(&self.inner), } } +} - fn add(&mut self, paths: Vec) { - self.pending.extend(paths); +impl Drop for WatchSender { + fn drop(&mut self) { + if self.inner.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 { + self.inner.notify.notify_waiters(); + } } +} + +fn watch_channel() -> (WatchSender, Receiver) { + let inner = Arc::new(ReceiverInner { + changed_paths: AsyncMutex::new(BTreeSet::new()), + notify: Notify::new(), + sender_count: AtomicUsize::new(1), + }); + ( + WatchSender { + inner: Arc::clone(&inner), + }, + Receiver { inner }, + ) +} - fn next_deadline(&self, now: Instant) -> Option { - (!self.pending.is_empty() && now < self.next_allowed_at).then_some(self.next_allowed_at) +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +struct PathWatchCounts { + non_recursive: usize, + recursive: usize, +} + +impl PathWatchCounts { + fn increment(&mut self, recursive: bool, amount: usize) { + if recursive { + self.recursive += amount; + } else { + self.non_recursive += amount; + } } - fn take_ready(&mut self, now: Instant) -> Option> { - if self.pending.is_empty() || now < self.next_allowed_at { - return None; + fn decrement(&mut self, recursive: bool, amount: usize) { + if recursive { + self.recursive = self.recursive.saturating_sub(amount); + } else { + self.non_recursive = self.non_recursive.saturating_sub(amount); } - Some(self.take_with_next_allowed(now)) } - fn take_pending(&mut self, now: Instant) -> Option> { - if self.pending.is_empty() { - return None; + fn effective_mode(self) -> Option { + if self.recursive > 0 { + Some(RecursiveMode::Recursive) + } else if self.non_recursive > 0 { + Some(RecursiveMode::NonRecursive) + } else { + None } - Some(self.take_with_next_allowed(now)) } - fn take_with_next_allowed(&mut self, now: Instant) -> Vec { - let mut paths: Vec = self.pending.drain().collect(); - paths.sort_unstable_by(|a, b| a.as_os_str().cmp(b.as_os_str())); - self.next_allowed_at = now + WATCHER_THROTTLE_INTERVAL; - paths + fn is_empty(self) -> bool { + self.non_recursive == 0 && self.recursive == 0 } } -pub(crate) struct FileWatcher { - inner: Option>, - state: Arc>, - tx: broadcast::Sender, +struct FileWatcherInner { + watcher: RecommendedWatcher, + watched_paths: HashMap, +} + +/// Coalesces bursts of watch notifications and emits at most once per interval. +pub struct ThrottledWatchReceiver { + rx: Receiver, + interval: Duration, + next_allowed: Option, +} + +impl ThrottledWatchReceiver { + /// Creates a throttling wrapper around a raw watcher [`Receiver`]. + pub fn new(rx: Receiver, interval: Duration) -> Self { + Self { + rx, + interval, + next_allowed: None, + } + } + + /// Receives the next event, enforcing the configured minimum delay after + /// the previous emission. + pub async fn recv(&mut self) -> Option { + if let Some(next_allowed) = self.next_allowed { + sleep_until(next_allowed).await; + } + + let event = self.rx.recv().await; + if event.is_some() { + self.next_allowed = Some(Instant::now() + self.interval); + } + event + } +} + +/// Handle used to register watched paths for one logical consumer. +pub struct FileWatcherSubscriber { + id: SubscriberId, + file_watcher: Arc, +} + +impl FileWatcherSubscriber { + /// Registers the provided paths for this subscriber and returns an RAII + /// guard that unregisters them on drop. + pub fn register_paths(&self, watched_paths: Vec) -> WatchRegistration { + let watched_paths = dedupe_watched_paths(watched_paths); + self.file_watcher.register_paths(self.id, &watched_paths); + + WatchRegistration { + file_watcher: Arc::downgrade(&self.file_watcher), + subscriber_id: self.id, + watched_paths, + } + } + + #[cfg(test)] + pub(crate) fn register_path(&self, path: PathBuf, recursive: bool) -> WatchRegistration { + self.register_paths(vec![WatchPath { path, recursive }]) + } } -pub(crate) struct WatchRegistration { +impl Drop for FileWatcherSubscriber { + fn drop(&mut self) { + self.file_watcher.remove_subscriber(self.id); + } +} + +/// RAII guard for a set of active path registrations. +pub struct WatchRegistration { file_watcher: std::sync::Weak, - roots: Vec, + subscriber_id: SubscriberId, + watched_paths: Vec, } impl Drop for WatchRegistration { fn drop(&mut self) { if let Some(file_watcher) = self.file_watcher.upgrade() { - file_watcher.unregister_roots(&self.roots); + file_watcher.unregister_paths(self.subscriber_id, &self.watched_paths); } } } +/// Multi-subscriber file watcher built on top of `notify`. +pub struct FileWatcher { + inner: Option>, + state: Arc>, +} + impl FileWatcher { - pub(crate) fn new(_codex_home: PathBuf) -> notify::Result { + /// Creates a live filesystem watcher and starts its background event loop + /// on the current Tokio runtime. + pub fn new() -> notify::Result { let (raw_tx, raw_rx) = mpsc::unbounded_channel(); let raw_tx_clone = raw_tx; let watcher = notify::recommended_watcher(move |res| { @@ -115,109 +279,101 @@ impl FileWatcher { watcher, watched_paths: HashMap::new(), }; - let (tx, _) = broadcast::channel(128); - let state = Arc::new(RwLock::new(WatchState { - skills_root_ref_counts: HashMap::new(), - })); + let state = Arc::new(RwLock::new(WatchState::default())); let file_watcher = Self { inner: Some(Mutex::new(inner)), - state: Arc::clone(&state), - tx: tx.clone(), + state, }; - file_watcher.spawn_event_loop(raw_rx, state, tx); + file_watcher.spawn_event_loop(raw_rx); Ok(file_watcher) } - pub(crate) fn noop() -> Self { - let (tx, _) = broadcast::channel(1); + /// Creates an inert watcher that only supports test-driven synthetic + /// notifications. + pub fn noop() -> Self { Self { inner: None, - state: Arc::new(RwLock::new(WatchState { - skills_root_ref_counts: HashMap::new(), - })), - tx, + state: Arc::new(RwLock::new(WatchState::default())), } } - pub(crate) fn subscribe(&self) -> broadcast::Receiver { - self.tx.subscribe() + /// Adds a new subscriber and returns both its registration handle and its + /// dedicated event receiver. + pub fn add_subscriber(self: &Arc) -> (FileWatcherSubscriber, Receiver) { + let (tx, rx) = watch_channel(); + let mut state = self + .state + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let subscriber_id = state.next_subscriber_id; + state.next_subscriber_id += 1; + state.subscribers.insert( + subscriber_id, + SubscriberState { + watched_paths: HashMap::new(), + tx, + }, + ); + + let subscriber = FileWatcherSubscriber { + id: subscriber_id, + file_watcher: self.clone(), + }; + (subscriber, rx) } - pub(crate) fn register_config( - self: &Arc, - config: &Config, - skills_manager: &SkillsManager, - ) -> WatchRegistration { - let deduped_roots: HashSet = skills_manager - .skill_roots_for_config(config) - .into_iter() - .map(|root| root.path) - .collect(); - let mut registered_roots: Vec = deduped_roots.into_iter().collect(); - registered_roots.sort_unstable_by(|a, b| a.as_os_str().cmp(b.as_os_str())); - for root in ®istered_roots { - self.register_skills_root(root.clone()); - } + fn register_paths(&self, subscriber_id: SubscriberId, watched_paths: &[WatchPath]) { + let mut state = self + .state + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut inner_guard: Option> = None; - WatchRegistration { - file_watcher: Arc::downgrade(self), - roots: registered_roots, + for watched_path in watched_paths { + { + let Some(subscriber) = state.subscribers.get_mut(&subscriber_id) else { + return; + }; + *subscriber + .watched_paths + .entry(watched_path.clone()) + .or_default() += 1; + } + + let counts = state + .path_ref_counts + .entry(watched_path.path.clone()) + .or_default(); + let previous_mode = counts.effective_mode(); + counts.increment(watched_path.recursive, /*amount*/ 1); + let next_mode = counts.effective_mode(); + if previous_mode != next_mode { + self.reconfigure_watch(&watched_path.path, next_mode, &mut inner_guard); + } } } // Bridge `notify`'s callback-based events into the Tokio runtime and - // broadcast coarse-grained change signals to subscribers. - fn spawn_event_loop( - &self, - mut raw_rx: mpsc::UnboundedReceiver>, - state: Arc>, - tx: broadcast::Sender, - ) { + // notify the matching subscribers. + fn spawn_event_loop(&self, mut raw_rx: mpsc::UnboundedReceiver>) { if let Ok(handle) = Handle::try_current() { + let state = Arc::clone(&self.state); handle.spawn(async move { - let now = Instant::now(); - let mut skills = ThrottledPaths::new(now); - loop { - let now = Instant::now(); - let next_deadline = skills.next_deadline(now); - let timer_deadline = next_deadline - .unwrap_or_else(|| now + Duration::from_secs(60 * 60 * 24 * 365)); - let timer = sleep_until(timer_deadline); - tokio::pin!(timer); - - tokio::select! { - res = raw_rx.recv() => { - match res { - Some(Ok(event)) => { - let skills_paths = classify_event(&event, &state); - let now = Instant::now(); - skills.add(skills_paths); - - if let Some(paths) = skills.take_ready(now) { - let _ = tx.send(FileWatcherEvent::SkillsChanged { paths }); - } - } - Some(Err(err)) => { - warn!("file watcher error: {err}"); - } - None => { - // Flush any pending changes before shutdown so subscribers - // see the latest state. - let now = Instant::now(); - if let Some(paths) = skills.take_pending(now) { - let _ = tx.send(FileWatcherEvent::SkillsChanged { paths }); - } - break; - } + match raw_rx.recv().await { + Some(Ok(event)) => { + if !is_mutating_event(&event) { + continue; } - } - _ = &mut timer => { - let now = Instant::now(); - if let Some(paths) = skills.take_ready(now) { - let _ = tx.send(FileWatcherEvent::SkillsChanged { paths }); + if event.paths.is_empty() { + continue; } + Self::notify_subscribers(&state, &event.paths).await; } + Some(Err(err)) => { + warn!("file watcher error: {err}"); + } + None => break, } } }); @@ -226,127 +382,195 @@ impl FileWatcher { } } - fn register_skills_root(&self, root: PathBuf) { - let mut state = self - .state - .write() - .unwrap_or_else(std::sync::PoisonError::into_inner); - let count = state - .skills_root_ref_counts - .entry(root.clone()) - .or_insert(0); - *count += 1; - if *count == 1 { - self.watch_path(root, RecursiveMode::Recursive); - } - } - - fn unregister_roots(&self, roots: &[PathBuf]) { + fn unregister_paths(&self, subscriber_id: SubscriberId, watched_paths: &[WatchPath]) { let mut state = self .state .write() .unwrap_or_else(std::sync::PoisonError::into_inner); let mut inner_guard: Option> = None; - for root in roots { - let mut should_unwatch = false; - if let Some(count) = state.skills_root_ref_counts.get_mut(root) { - if *count > 1 { - *count -= 1; - } else { - state.skills_root_ref_counts.remove(root); - should_unwatch = true; + for watched_path in watched_paths { + { + let Some(subscriber) = state.subscribers.get_mut(&subscriber_id) else { + return; + }; + let Some(subscriber_count) = subscriber.watched_paths.get_mut(watched_path) else { + continue; + }; + *subscriber_count = subscriber_count.saturating_sub(1); + if *subscriber_count == 0 { + subscriber.watched_paths.remove(watched_path); } } - - if !should_unwatch { - continue; - } - let Some(inner) = &self.inner else { + let Some(counts) = state.path_ref_counts.get_mut(&watched_path.path) else { continue; }; - if inner_guard.is_none() { - let guard = inner - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - inner_guard = Some(guard); + let previous_mode = counts.effective_mode(); + counts.decrement(watched_path.recursive, /*amount*/ 1); + let next_mode = counts.effective_mode(); + if counts.is_empty() { + state.path_ref_counts.remove(&watched_path.path); } + if previous_mode != next_mode { + self.reconfigure_watch(&watched_path.path, next_mode, &mut inner_guard); + } + } + } - let Some(guard) = inner_guard.as_mut() else { + fn remove_subscriber(&self, subscriber_id: SubscriberId) { + let mut state = self + .state + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let Some(subscriber) = state.subscribers.remove(&subscriber_id) else { + return; + }; + + let mut inner_guard: Option> = None; + for (watched_path, count) in subscriber.watched_paths { + let Some(path_counts) = state.path_ref_counts.get_mut(&watched_path.path) else { continue; }; - if guard.watched_paths.remove(root).is_none() { - continue; + let previous_mode = path_counts.effective_mode(); + path_counts.decrement(watched_path.recursive, count); + let next_mode = path_counts.effective_mode(); + if path_counts.is_empty() { + state.path_ref_counts.remove(&watched_path.path); } - if let Err(err) = guard.watcher.unwatch(root) { - warn!("failed to unwatch {}: {err}", root.display()); + if previous_mode != next_mode { + self.reconfigure_watch(&watched_path.path, next_mode, &mut inner_guard); } } } - fn watch_path(&self, path: PathBuf, mode: RecursiveMode) { + fn reconfigure_watch<'a>( + &'a self, + path: &Path, + next_mode: Option, + inner_guard: &mut Option>, + ) { let Some(inner) = &self.inner else { return; }; - if !path.exists() { + if inner_guard.is_none() { + let guard = inner + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *inner_guard = Some(guard); + } + let Some(guard) = inner_guard.as_mut() else { + return; + }; + + let existing_mode = guard.watched_paths.get(path).copied(); + if existing_mode == next_mode { return; } - let watch_path = path; - let mut guard = inner - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - if let Some(existing) = guard.watched_paths.get(&watch_path) { - if *existing == RecursiveMode::Recursive || *existing == mode { - return; - } - if let Err(err) = guard.watcher.unwatch(&watch_path) { - warn!("failed to unwatch {}: {err}", watch_path.display()); + + if existing_mode.is_some() { + if let Err(err) = guard.watcher.unwatch(path) { + warn!("failed to unwatch {}: {err}", path.display()); } + guard.watched_paths.remove(path); } - if let Err(err) = guard.watcher.watch(&watch_path, mode) { - warn!("failed to watch {}: {err}", watch_path.display()); + + let Some(next_mode) = next_mode else { + return; + }; + if !path.exists() { return; } - guard.watched_paths.insert(watch_path, mode); - } -} -fn classify_event(event: &Event, state: &RwLock) -> Vec { - if !matches!( - event.kind, - EventKind::Create(_) | EventKind::Modify(_) | EventKind::Remove(_) - ) { - return Vec::new(); + if let Err(err) = guard.watcher.watch(path, next_mode) { + warn!("failed to watch {}: {err}", path.display()); + return; + } + guard.watched_paths.insert(path.to_path_buf(), next_mode); } - let mut skills_paths = Vec::new(); - let skills_roots = match state.read() { - Ok(state) => state - .skills_root_ref_counts - .keys() - .cloned() - .collect::>(), - Err(err) => { - let state = err.into_inner(); + async fn notify_subscribers(state: &RwLock, event_paths: &[PathBuf]) { + let subscribers_to_notify: Vec<(WatchSender, Vec)> = { + let state = state + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); state - .skills_root_ref_counts - .keys() - .cloned() - .collect::>() - } - }; + .subscribers + .values() + .filter_map(|subscriber| { + let changed_paths: Vec = event_paths + .iter() + .filter(|event_path| { + subscriber.watched_paths.keys().any(|watched_path| { + watch_path_matches_event(watched_path, event_path) + }) + }) + .cloned() + .collect(); + (!changed_paths.is_empty()).then_some((subscriber.tx.clone(), changed_paths)) + }) + .collect() + }; - for path in &event.paths { - if is_skills_path(path, &skills_roots) { - skills_paths.push(path.clone()); + for (subscriber, changed_paths) in subscribers_to_notify { + subscriber.add_changed_paths(&changed_paths).await; } } - skills_paths + #[cfg(test)] + pub(crate) async fn send_paths_for_test(&self, paths: Vec) { + Self::notify_subscribers(&self.state, &paths).await; + } + + #[cfg(test)] + pub(crate) fn spawn_event_loop_for_test( + &self, + raw_rx: mpsc::UnboundedReceiver>, + ) { + self.spawn_event_loop(raw_rx); + } + + #[cfg(test)] + pub(crate) fn watch_counts_for_test(&self, path: &Path) -> Option<(usize, usize)> { + let state = self + .state + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + state + .path_ref_counts + .get(path) + .map(|counts| (counts.non_recursive, counts.recursive)) + } } -fn is_skills_path(path: &Path, roots: &HashSet) -> bool { - roots.iter().any(|root| path.starts_with(root)) +fn is_mutating_event(event: &Event) -> bool { + matches!( + event.kind, + EventKind::Create(_) | EventKind::Modify(_) | EventKind::Remove(_) + ) +} + +fn dedupe_watched_paths(mut watched_paths: Vec) -> Vec { + watched_paths.sort_unstable_by(|a, b| { + a.path + .as_os_str() + .cmp(b.path.as_os_str()) + .then(a.recursive.cmp(&b.recursive)) + }); + watched_paths.dedup(); + watched_paths +} + +fn watch_path_matches_event(watched_path: &WatchPath, event_path: &Path) -> bool { + if event_path == watched_path.path { + return true; + } + if watched_path.path.starts_with(event_path) { + return true; + } + if !event_path.starts_with(&watched_path.path) { + return false; + } + watched_path.recursive || event_path.parent() == Some(watched_path.path.as_path()) } #[cfg(test)] diff --git a/codex-rs/core/src/file_watcher_tests.rs b/codex-rs/core/src/file_watcher_tests.rs index 995e7f7cea4..53f6391f576 100644 --- a/codex-rs/core/src/file_watcher_tests.rs +++ b/codex-rs/core/src/file_watcher_tests.rs @@ -1,13 +1,13 @@ use super::*; -use notify::EventKind; use notify::event::AccessKind; use notify::event::AccessMode; use notify::event::CreateKind; use notify::event::ModifyKind; -use notify::event::RemoveKind; use pretty_assertions::assert_eq; use tokio::time::timeout; +const TEST_THROTTLE_INTERVAL: Duration = Duration::from_millis(50); + fn path(name: &str) -> PathBuf { PathBuf::from(name) } @@ -20,147 +20,202 @@ fn notify_event(kind: EventKind, paths: Vec) -> Event { event } -#[test] -fn throttles_and_coalesces_within_interval() { - let start = Instant::now(); - let mut throttled = ThrottledPaths::new(start); +#[tokio::test] +async fn throttled_receiver_coalesces_within_interval() { + let (tx, rx) = watch_channel(); + let mut throttled = ThrottledWatchReceiver::new(rx, TEST_THROTTLE_INTERVAL); - throttled.add(vec![path("a")]); - let first = throttled.take_ready(start).expect("first emit"); - assert_eq!(first, vec![path("a")]); + tx.add_changed_paths(&[path("a")]).await; + let first = timeout(Duration::from_secs(1), throttled.recv()) + .await + .expect("first emit timeout"); + assert_eq!( + first, + Some(FileWatcherEvent { + paths: vec![path("a")], + }) + ); - throttled.add(vec![path("b"), path("c")]); - assert_eq!(throttled.take_ready(start), None); + tx.add_changed_paths(&[path("b"), path("c")]).await; + let blocked = timeout(TEST_THROTTLE_INTERVAL / 2, throttled.recv()).await; + assert_eq!(blocked.is_err(), true); - let second = throttled - .take_ready(start + WATCHER_THROTTLE_INTERVAL) - .expect("coalesced emit"); - assert_eq!(second, vec![path("b"), path("c")]); + let second = timeout(TEST_THROTTLE_INTERVAL * 2, throttled.recv()) + .await + .expect("second emit timeout"); + assert_eq!( + second, + Some(FileWatcherEvent { + paths: vec![path("b"), path("c")], + }) + ); } -#[test] -fn flushes_pending_on_shutdown() { - let start = Instant::now(); - let mut throttled = ThrottledPaths::new(start); - - throttled.add(vec![path("a")]); - let _ = throttled.take_ready(start).expect("first emit"); +#[tokio::test] +async fn throttled_receiver_flushes_pending_on_shutdown() { + let (tx, rx) = watch_channel(); + let mut throttled = ThrottledWatchReceiver::new(rx, TEST_THROTTLE_INTERVAL); - throttled.add(vec![path("b")]); - assert_eq!(throttled.take_ready(start), None); + tx.add_changed_paths(&[path("a")]).await; + let first = timeout(Duration::from_secs(1), throttled.recv()) + .await + .expect("first emit timeout"); + assert_eq!( + first, + Some(FileWatcherEvent { + paths: vec![path("a")], + }) + ); - let flushed = throttled - .take_pending(start) - .expect("shutdown flush emits pending paths"); - assert_eq!(flushed, vec![path("b")]); -} + tx.add_changed_paths(&[path("b")]).await; + drop(tx); -#[test] -fn classify_event_filters_to_skills_roots() { - let root = path("/tmp/skills"); - let state = RwLock::new(WatchState { - skills_root_ref_counts: HashMap::from([(root.clone(), 1)]), - }); - let event = notify_event( - EventKind::Create(CreateKind::Any), - vec![ - root.join("demo/SKILL.md"), - path("/tmp/other/not-a-skill.txt"), - ], + let second = timeout(Duration::from_secs(1), throttled.recv()) + .await + .expect("shutdown flush timeout"); + assert_eq!( + second, + Some(FileWatcherEvent { + paths: vec![path("b")], + }) ); - let classified = classify_event(&event, &state); - assert_eq!(classified, vec![root.join("demo/SKILL.md")]); + let closed = timeout(Duration::from_secs(1), throttled.recv()) + .await + .expect("closed recv timeout"); + assert_eq!(closed, None); } #[test] -fn classify_event_supports_multiple_roots_without_prefix_false_positives() { - let root_a = path("/tmp/skills"); - let root_b = path("/tmp/workspace/.codex/skills"); - let state = RwLock::new(WatchState { - skills_root_ref_counts: HashMap::from([(root_a.clone(), 1), (root_b.clone(), 1)]), - }); - let event = notify_event( - EventKind::Modify(ModifyKind::Any), - vec![ - root_a.join("alpha/SKILL.md"), - path("/tmp/skills-extra/not-under-skills.txt"), - root_b.join("beta/SKILL.md"), - ], +fn is_mutating_event_filters_non_mutating_event_kinds() { + assert_eq!( + is_mutating_event(¬ify_event( + EventKind::Create(CreateKind::Any), + vec![path("/tmp/created")] + )), + true + ); + assert_eq!( + is_mutating_event(¬ify_event( + EventKind::Modify(ModifyKind::Any), + vec![path("/tmp/modified")] + )), + true ); - - let classified = classify_event(&event, &state); assert_eq!( - classified, - vec![root_a.join("alpha/SKILL.md"), root_b.join("beta/SKILL.md")] + is_mutating_event(¬ify_event( + EventKind::Access(AccessKind::Open(AccessMode::Any)), + vec![path("/tmp/accessed")] + )), + false ); } #[test] -fn classify_event_ignores_non_mutating_event_kinds() { - let root = path("/tmp/skills"); - let state = RwLock::new(WatchState { - skills_root_ref_counts: HashMap::from([(root.clone(), 1)]), - }); - let path = root.join("demo/SKILL.md"); +fn register_dedupes_by_path_and_scope() { + let watcher = Arc::new(FileWatcher::noop()); + let (subscriber, _rx) = watcher.add_subscriber(); + let _first = subscriber.register_path(path("/tmp/skills"), false); + let _second = subscriber.register_path(path("/tmp/skills"), false); + let _third = subscriber.register_path(path("/tmp/skills"), true); + let _fourth = subscriber.register_path(path("/tmp/other-skills"), true); - let access_event = notify_event( - EventKind::Access(AccessKind::Open(AccessMode::Any)), - vec![path.clone()], + assert_eq!( + watcher.watch_counts_for_test(&path("/tmp/skills")), + Some((2, 1)) + ); + assert_eq!( + watcher.watch_counts_for_test(&path("/tmp/other-skills")), + Some((0, 1)) ); - assert_eq!(classify_event(&access_event, &state), Vec::::new()); - - let any_event = notify_event(EventKind::Any, vec![path.clone()]); - assert_eq!(classify_event(&any_event, &state), Vec::::new()); - - let other_event = notify_event(EventKind::Other, vec![path]); - assert_eq!(classify_event(&other_event, &state), Vec::::new()); } #[test] -fn register_skills_root_dedupes_state_entries() { - let watcher = FileWatcher::noop(); - let root = path("/tmp/skills"); - watcher.register_skills_root(root.clone()); - watcher.register_skills_root(root); - watcher.register_skills_root(path("/tmp/other-skills")); - - let state = watcher.state.read().expect("state lock"); - assert_eq!(state.skills_root_ref_counts.len(), 2); +fn watch_registration_drop_unregisters_paths() { + let watcher = Arc::new(FileWatcher::noop()); + let (subscriber, _rx) = watcher.add_subscriber(); + let registration = subscriber.register_path(path("/tmp/skills"), true); + + drop(registration); + + assert_eq!(watcher.watch_counts_for_test(&path("/tmp/skills")), None); } #[test] -fn watch_registration_drop_unregisters_roots() { +fn subscriber_drop_unregisters_paths() { let watcher = Arc::new(FileWatcher::noop()); - let root = path("/tmp/skills"); - watcher.register_skills_root(root.clone()); - let registration = WatchRegistration { - file_watcher: Arc::downgrade(&watcher), - roots: vec![root], + let registration = { + let (subscriber, _rx) = watcher.add_subscriber(); + subscriber.register_path(path("/tmp/skills"), true) }; + assert_eq!(watcher.watch_counts_for_test(&path("/tmp/skills")), None); drop(registration); +} + +#[tokio::test] +async fn receiver_closes_when_subscriber_drops() { + let watcher = Arc::new(FileWatcher::noop()); + let (subscriber, mut rx) = watcher.add_subscriber(); + + drop(subscriber); + + let closed = timeout(Duration::from_secs(1), rx.recv()) + .await + .expect("closed recv timeout"); + assert_eq!(closed, None); +} + +#[test] +fn recursive_registration_downgrades_to_non_recursive_after_drop() { + let temp_dir = tempfile::tempdir().expect("temp dir"); + let root = temp_dir.path().join("watched-dir"); + std::fs::create_dir(&root).expect("create root"); + + let watcher = Arc::new(FileWatcher::new().expect("watcher")); + let (subscriber, _rx) = watcher.add_subscriber(); + let non_recursive = subscriber.register_path(root.clone(), false); + let recursive = subscriber.register_path(root.clone(), true); + + { + let inner = watcher.inner.as_ref().expect("watcher inner"); + let inner = inner.lock().expect("inner lock"); + assert_eq!( + inner.watched_paths.get(&root), + Some(&RecursiveMode::Recursive) + ); + } + + drop(recursive); + + { + let inner = watcher.inner.as_ref().expect("watcher inner"); + let inner = inner.lock().expect("inner lock"); + assert_eq!( + inner.watched_paths.get(&root), + Some(&RecursiveMode::NonRecursive) + ); + } - let state = watcher.state.read().expect("state lock"); - assert_eq!(state.skills_root_ref_counts.len(), 0); + drop(non_recursive); } #[test] fn unregister_holds_state_lock_until_unwatch_finishes() { let temp_dir = tempfile::tempdir().expect("temp dir"); - let root = temp_dir.path().join("skills"); + let root = temp_dir.path().join("watched-dir"); std::fs::create_dir(&root).expect("create root"); - let watcher = Arc::new(FileWatcher::new(temp_dir.path().to_path_buf()).expect("watcher")); - watcher.register_skills_root(root.clone()); + let watcher = Arc::new(FileWatcher::new().expect("watcher")); + let (unregister_subscriber, _unregister_rx) = watcher.add_subscriber(); + let (register_subscriber, _register_rx) = watcher.add_subscriber(); + let registration = unregister_subscriber.register_path(root.clone(), true); let inner = watcher.inner.as_ref().expect("watcher inner"); let inner_guard = inner.lock().expect("inner lock"); - let unregister_watcher = Arc::clone(&watcher); - let unregister_root = root.clone(); let unregister_thread = std::thread::spawn(move || { - unregister_watcher.unregister_roots(&[unregister_root]); + drop(registration); }); let state_lock_observed = (0..100).any(|_| { @@ -172,75 +227,128 @@ fn unregister_holds_state_lock_until_unwatch_finishes() { }); assert_eq!(state_lock_observed, true); - let register_watcher = Arc::clone(&watcher); let register_root = root.clone(); let register_thread = std::thread::spawn(move || { - register_watcher.register_skills_root(register_root); + let registration = register_subscriber.register_path(register_root, false); + (register_subscriber, registration) }); drop(inner_guard); unregister_thread.join().expect("unregister join"); - register_thread.join().expect("register join"); + let (register_subscriber, non_recursive) = register_thread.join().expect("register join"); - let state = watcher.state.read().expect("state lock"); - assert_eq!(state.skills_root_ref_counts.get(&root), Some(&1)); - drop(state); + assert_eq!(watcher.watch_counts_for_test(&root), Some((1, 0))); let inner = watcher.inner.as_ref().expect("watcher inner"); let inner = inner.lock().expect("inner lock"); assert_eq!( inner.watched_paths.get(&root), - Some(&RecursiveMode::Recursive) + Some(&RecursiveMode::NonRecursive) ); + drop(inner); + + drop(non_recursive); + drop(register_subscriber); } #[tokio::test] -async fn spawn_event_loop_flushes_pending_changes_on_shutdown() { - let watcher = FileWatcher::noop(); - let root = path("/tmp/skills"); - { - let mut state = watcher.state.write().expect("state lock"); - state.skills_root_ref_counts.insert(root.clone(), 1); - } +async fn matching_subscribers_are_notified() { + let watcher = Arc::new(FileWatcher::noop()); + let (skills_subscriber, skills_rx) = watcher.add_subscriber(); + let (plugins_subscriber, plugins_rx) = watcher.add_subscriber(); + let _skills = skills_subscriber.register_path(path("/tmp/skills"), true); + let _plugins = plugins_subscriber.register_path(path("/tmp/plugins"), true); + let mut skills_rx = ThrottledWatchReceiver::new(skills_rx, TEST_THROTTLE_INTERVAL); + let mut plugins_rx = ThrottledWatchReceiver::new(plugins_rx, TEST_THROTTLE_INTERVAL); + + watcher + .send_paths_for_test(vec![path("/tmp/skills/rust/SKILL.md")]) + .await; + + let skills_event = timeout(Duration::from_secs(1), skills_rx.recv()) + .await + .expect("skills change timeout") + .expect("skills change"); + assert_eq!( + skills_event, + FileWatcherEvent { + paths: vec![path("/tmp/skills/rust/SKILL.md")], + } + ); - let (raw_tx, raw_rx) = mpsc::unbounded_channel(); - let (tx, mut rx) = broadcast::channel(8); - watcher.spawn_event_loop(raw_rx, Arc::clone(&watcher.state), tx); + let plugins_event = timeout(TEST_THROTTLE_INTERVAL, plugins_rx.recv()).await; + assert_eq!(plugins_event.is_err(), true); +} - raw_tx - .send(Ok(notify_event( - EventKind::Create(CreateKind::File), - vec![root.join("a/SKILL.md")], - ))) - .expect("send first event"); - let first = timeout(Duration::from_secs(2), rx.recv()) +#[tokio::test] +async fn non_recursive_watch_ignores_grandchildren() { + let watcher = Arc::new(FileWatcher::noop()); + let (subscriber, rx) = watcher.add_subscriber(); + let _registration = subscriber.register_path(path("/tmp/skills"), false); + let mut rx = ThrottledWatchReceiver::new(rx, TEST_THROTTLE_INTERVAL); + + watcher + .send_paths_for_test(vec![path("/tmp/skills/nested/SKILL.md")]) + .await; + + let event = timeout(TEST_THROTTLE_INTERVAL, rx.recv()).await; + assert_eq!(event.is_err(), true); +} + +#[tokio::test] +async fn ancestor_events_notify_child_watches() { + let watcher = Arc::new(FileWatcher::noop()); + let (subscriber, rx) = watcher.add_subscriber(); + let _registration = subscriber.register_path(path("/tmp/skills/rust/SKILL.md"), false); + let mut rx = ThrottledWatchReceiver::new(rx, TEST_THROTTLE_INTERVAL); + + watcher.send_paths_for_test(vec![path("/tmp/skills")]).await; + + let event = timeout(Duration::from_secs(1), rx.recv()) .await - .expect("first watcher event") - .expect("broadcast recv first"); + .expect("ancestor event timeout") + .expect("ancestor event"); assert_eq!( - first, - FileWatcherEvent::SkillsChanged { - paths: vec![root.join("a/SKILL.md")] + event, + FileWatcherEvent { + paths: vec![path("/tmp/skills")], } ); +} + +#[tokio::test] +async fn spawn_event_loop_filters_non_mutating_events() { + let watcher = Arc::new(FileWatcher::noop()); + let (subscriber, rx) = watcher.add_subscriber(); + let _registration = subscriber.register_path(path("/tmp/skills"), true); + let mut rx = ThrottledWatchReceiver::new(rx, TEST_THROTTLE_INTERVAL); + let (raw_tx, raw_rx) = mpsc::unbounded_channel(); + watcher.spawn_event_loop_for_test(raw_rx); raw_tx .send(Ok(notify_event( - EventKind::Remove(RemoveKind::File), - vec![root.join("b/SKILL.md")], + EventKind::Access(AccessKind::Open(AccessMode::Any)), + vec![path("/tmp/skills/SKILL.md")], ))) - .expect("send second event"); - drop(raw_tx); + .expect("send access event"); + let blocked = timeout(TEST_THROTTLE_INTERVAL, rx.recv()).await; + assert_eq!(blocked.is_err(), true); - let second = timeout(Duration::from_secs(2), rx.recv()) + raw_tx + .send(Ok(notify_event( + EventKind::Create(CreateKind::File), + vec![path("/tmp/skills/SKILL.md")], + ))) + .expect("send create event"); + let event = timeout(Duration::from_secs(1), rx.recv()) .await - .expect("second watcher event") - .expect("broadcast recv second"); + .expect("create event timeout") + .expect("create event"); assert_eq!( - second, - FileWatcherEvent::SkillsChanged { - paths: vec![root.join("b/SKILL.md")] + event, + FileWatcherEvent { + paths: vec![path("/tmp/skills/SKILL.md")], } ); } diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 1c920805021..d3ffdd5657d 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -39,7 +39,7 @@ pub mod exec; pub mod exec_env; mod exec_policy; pub mod external_agent_config; -mod file_watcher; +pub mod file_watcher; mod flags; pub mod git_info; mod guardian; @@ -72,6 +72,7 @@ pub mod sandboxing; mod session_prefix; mod session_startup_prewarm; mod shell_detect; +mod skills_watcher; mod stream_events_utils; pub mod test_support; mod text_encoding; @@ -173,6 +174,7 @@ pub use client_common::Prompt; pub use client_common::REVIEW_PROMPT; pub use client_common::ResponseEvent; pub use client_common::ResponseStream; +pub use codex_sandboxing::get_platform_sandbox; pub use compact::content_items_to_text; pub use event_mapping::parse_turn_item; pub use exec_policy::ExecPolicyError; diff --git a/codex-rs/core/src/skills_watcher.rs b/codex-rs/core/src/skills_watcher.rs new file mode 100644 index 00000000000..fc0b86d90de --- /dev/null +++ b/codex-rs/core/src/skills_watcher.rs @@ -0,0 +1,116 @@ +//! Skills-specific watcher built on top of the generic [`FileWatcher`]. + +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use tokio::runtime::Handle; +use tokio::sync::broadcast; +use tracing::warn; + +use crate::config::Config; +use crate::file_watcher::FileWatcher; +use crate::file_watcher::FileWatcherSubscriber; +use crate::file_watcher::Receiver; +use crate::file_watcher::ThrottledWatchReceiver; +use crate::file_watcher::WatchPath; +use crate::file_watcher::WatchRegistration; +use crate::skills::SkillsManager; + +#[cfg(not(test))] +const WATCHER_THROTTLE_INTERVAL: Duration = Duration::from_secs(10); +#[cfg(test)] +const WATCHER_THROTTLE_INTERVAL: Duration = Duration::from_millis(50); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SkillsWatcherEvent { + SkillsChanged { paths: Vec }, +} + +pub(crate) struct SkillsWatcher { + subscriber: FileWatcherSubscriber, + tx: broadcast::Sender, +} + +impl SkillsWatcher { + pub(crate) fn new(file_watcher: &Arc) -> Self { + let (subscriber, rx) = file_watcher.add_subscriber(); + let (tx, _) = broadcast::channel(128); + let skills_watcher = Self { + subscriber, + tx: tx.clone(), + }; + Self::spawn_event_loop(rx, tx); + skills_watcher + } + + pub(crate) fn noop() -> Self { + Self::new(&Arc::new(FileWatcher::noop())) + } + + pub(crate) fn subscribe(&self) -> broadcast::Receiver { + self.tx.subscribe() + } + + pub(crate) fn register_config( + &self, + config: &Config, + skills_manager: &SkillsManager, + ) -> WatchRegistration { + let roots = skills_manager + .skill_roots_for_config(config) + .into_iter() + .map(|root| WatchPath { + path: root.path, + recursive: true, + }) + .collect(); + self.subscriber.register_paths(roots) + } + + fn spawn_event_loop(rx: Receiver, tx: broadcast::Sender) { + let mut rx = ThrottledWatchReceiver::new(rx, WATCHER_THROTTLE_INTERVAL); + if let Ok(handle) = Handle::try_current() { + handle.spawn(async move { + while let Some(event) = rx.recv().await { + let _ = tx.send(SkillsWatcherEvent::SkillsChanged { paths: event.paths }); + } + }); + } else { + warn!("skills watcher listener skipped: no Tokio runtime available"); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use tokio::time::Duration; + use tokio::time::timeout; + + #[tokio::test] + async fn forwards_file_watcher_events() { + let file_watcher = Arc::new(FileWatcher::noop()); + let skills_watcher = SkillsWatcher::new(&file_watcher); + let mut rx = skills_watcher.subscribe(); + let _registration = skills_watcher + .subscriber + .register_path(PathBuf::from("/tmp/skill"), true); + + file_watcher + .send_paths_for_test(vec![PathBuf::from("/tmp/skill/SKILL.md")]) + .await; + + let event = timeout(Duration::from_secs(2), rx.recv()) + .await + .expect("skills watcher event") + .expect("broadcast recv"); + assert_eq!( + event, + SkillsWatcherEvent::SkillsChanged { + paths: vec![PathBuf::from("/tmp/skill/SKILL.md")], + } + ); + } +} diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index ceab67f1c76..f12f7ef9711 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -8,12 +8,12 @@ use crate::analytics_client::AnalyticsEventsClient; use crate::client::ModelClient; use crate::config::StartedNetworkProxy; use crate::exec_policy::ExecPolicyManager; -use crate::file_watcher::FileWatcher; use crate::mcp::McpManager; use crate::mcp_connection_manager::McpConnectionManager; use crate::models_manager::manager::ModelsManager; use crate::plugins::PluginsManager; use crate::skills::SkillsManager; +use crate::skills_watcher::SkillsWatcher; use crate::state_db::StateDbHandle; use crate::tools::code_mode::CodeModeService; use crate::tools::network_approval::NetworkApprovalService; @@ -54,7 +54,7 @@ pub(crate) struct SessionServices { pub(crate) skills_manager: Arc, pub(crate) plugins_manager: Arc, pub(crate) mcp_manager: Arc, - pub(crate) file_watcher: Arc, + pub(crate) skills_watcher: Arc, pub(crate) agent_control: AgentControl, pub(crate) network_proxy: Option, pub(crate) network_approval: Arc, diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index 642def55590..6ce55c62298 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -12,7 +12,6 @@ use crate::config::Config; use crate::error::CodexErr; use crate::error::Result as CodexResult; use crate::file_watcher::FileWatcher; -use crate::file_watcher::FileWatcherEvent; use crate::mcp::McpManager; use crate::models_manager::collaboration_mode_presets::CollaborationModesConfig; use crate::models_manager::manager::ModelsManager; @@ -24,6 +23,8 @@ use crate::rollout::RolloutRecorder; use crate::rollout::truncation; use crate::shell_snapshot::ShellSnapshot; use crate::skills::SkillsManager; +use crate::skills_watcher::SkillsWatcher; +use crate::skills_watcher::SkillsWatcherEvent; use crate::tasks::interrupted_turn_history_marker; use codex_app_server_protocol::ThreadHistoryBuilder; use codex_app_server_protocol::TurnStatus; @@ -83,32 +84,33 @@ impl Drop for TempCodexHomeGuard { } } -fn build_file_watcher(codex_home: PathBuf, skills_manager: Arc) -> Arc { +fn build_skills_watcher(skills_manager: Arc) -> Arc { if should_use_test_thread_manager_behavior() && let Ok(handle) = Handle::try_current() && handle.runtime_flavor() == RuntimeFlavor::CurrentThread { // The real watcher spins background tasks that can starve the // current-thread test runtime and cause event waits to time out. - warn!("using noop file watcher under current-thread test runtime"); - return Arc::new(FileWatcher::noop()); + warn!("using noop skills watcher under current-thread test runtime"); + return Arc::new(SkillsWatcher::noop()); } - let file_watcher = match FileWatcher::new(codex_home) { + let file_watcher = match FileWatcher::new() { Ok(file_watcher) => Arc::new(file_watcher), Err(err) => { warn!("failed to initialize file watcher: {err}"); Arc::new(FileWatcher::noop()) } }; + let skills_watcher = Arc::new(SkillsWatcher::new(&file_watcher)); - let mut rx = file_watcher.subscribe(); + let mut rx = skills_watcher.subscribe(); let skills_manager = Arc::clone(&skills_manager); if let Ok(handle) = Handle::try_current() { handle.spawn(async move { loop { match rx.recv().await { - Ok(FileWatcherEvent::SkillsChanged { .. }) => { + Ok(SkillsWatcherEvent::SkillsChanged { .. }) => { skills_manager.clear_cache(); } Err(broadcast::error::RecvError::Closed) => break, @@ -117,10 +119,10 @@ fn build_file_watcher(codex_home: PathBuf, skills_manager: Arc) - } }); } else { - warn!("file watcher listener skipped: no Tokio runtime available"); + warn!("skills watcher listener skipped: no Tokio runtime available"); } - file_watcher + skills_watcher } /// Represents a newly created Codex thread (formerly called a conversation), including the first event @@ -201,7 +203,7 @@ pub(crate) struct ThreadManagerState { skills_manager: Arc, plugins_manager: Arc, mcp_manager: Arc, - file_watcher: Arc, + skills_watcher: Arc, session_source: SessionSource, // Captures submitted ops for testing purpose when test mode is enabled. ops_log: Option, @@ -233,7 +235,7 @@ impl ThreadManager { config.bundled_skills_enabled(), restriction_product, )); - let file_watcher = build_file_watcher(codex_home.clone(), Arc::clone(&skills_manager)); + let skills_watcher = build_skills_watcher(Arc::clone(&skills_manager)); Self { state: Arc::new(ThreadManagerState { threads: Arc::new(RwLock::new(HashMap::new())), @@ -248,7 +250,7 @@ impl ThreadManager { skills_manager, plugins_manager, mcp_manager, - file_watcher, + skills_watcher, auth_manager, session_source, ops_log: should_use_test_thread_manager_behavior() @@ -299,7 +301,7 @@ impl ThreadManager { /*bundled_skills_enabled*/ true, restriction_product, )); - let file_watcher = build_file_watcher(codex_home.clone(), Arc::clone(&skills_manager)); + let skills_watcher = build_skills_watcher(Arc::clone(&skills_manager)); Self { state: Arc::new(ThreadManagerState { threads: Arc::new(RwLock::new(HashMap::new())), @@ -312,7 +314,7 @@ impl ThreadManager { skills_manager, plugins_manager, mcp_manager, - file_watcher, + skills_watcher, auth_manager, session_source: SessionSource::Exec, ops_log: should_use_test_thread_manager_behavior() @@ -342,10 +344,6 @@ impl ThreadManager { self.state.mcp_manager.clone() } - pub fn subscribe_file_watcher(&self) -> broadcast::Receiver { - self.state.file_watcher.subscribe() - } - pub fn get_models_manager(&self) -> Arc { self.state.models_manager.clone() } @@ -838,7 +836,7 @@ impl ThreadManagerState { user_shell_override: Option, ) -> CodexResult { let watch_registration = self - .file_watcher + .skills_watcher .register_config(&config, self.skills_manager.as_ref()); let CodexSpawnOk { codex, thread_id, .. @@ -849,7 +847,7 @@ impl ThreadManagerState { skills_manager: Arc::clone(&self.skills_manager), plugins_manager: Arc::clone(&self.plugins_manager), mcp_manager: Arc::clone(&self.mcp_manager), - file_watcher: Arc::clone(&self.file_watcher), + skills_watcher: Arc::clone(&self.skills_watcher), conversation_history: initial_history, session_source, agent_control,