diff --git a/Cargo.lock b/Cargo.lock index 6e20797..985802e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,6 +136,15 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.1" @@ -225,6 +234,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "croner" version = "2.2.0" @@ -234,6 +252,16 @@ dependencies = [ "chrono", ] +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "csv" version = "1.4.0" @@ -355,6 +383,16 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -571,6 +609,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getopts" version = "0.2.24" @@ -1365,6 +1413,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -1463,6 +1520,36 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + [[package]] name = "rc-box" version = "1.3.0" @@ -1653,15 +1740,18 @@ dependencies = [ "anyhow", "async-trait", "axum", + "base64", "chrono", "futures", "futures-util", "pulldown-cmark", + "rand", "reqwest", "rmcp", "rusqlite", "serde", "serde_json", + "sha2", "sqlite-vec", "teloxide", "tempfile", @@ -1930,6 +2020,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2489,6 +2590,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicase" version = "2.9.0" @@ -2561,6 +2668,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "want" version = "0.3.1" @@ -3122,6 +3235,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zerofrom" version = "0.1.6" diff --git a/Cargo.toml b/Cargo.toml index 35f4b93..42e3250 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,5 +58,10 @@ sqlite-vec = "0.1" # Setup wizard web server (used only by src/bin/setup.rs) axum = "0.8" +# OAuth 2.0 / PKCE helpers (used only by src/bin/setup.rs) +rand = "0.8" +sha2 = "0.10" +base64 = "0.22" + [dev-dependencies] tempfile = "3" diff --git a/README.md b/README.md index 5b617eb..b086364 100644 --- a/README.md +++ b/README.md @@ -261,11 +261,11 @@ skills/ - [x] RAG query rewriting (disambiguates follow-up questions before vector search) - [x] Nightly conversation summarization (LLM-based cron job) - [x] Verbose tool UI (`/verbose` command — live tool call progress in Telegram) +- [x] Google integration tools (Calendar, Email, Drive) ### Planned - [ ] Image upload support -- [ ] Google integration tools (Calendar, Email, Drive) - [ ] Event trigger framework (e.g., on email receive) - [ ] WhatsApp support - [ ] Webhook mode (in addition to polling) diff --git a/config.example.toml b/config.example.toml index 8a0c605..7efcb30 100644 --- a/config.example.toml +++ b/config.example.toml @@ -148,6 +148,30 @@ directory = "skills" # These servers are reached over HTTPS and do not require a local command. # Use `url` instead of `command`; optionally set `auth_token` for Bearer auth. +# Example: Notion MCP server (official HTTP MCP — no Node.js required) +# Docs: https://developers.notion.com/guides/mcp/mcp +# +# Recommended: use the setup wizard (cargo run --bin setup) to obtain an OAuth +# access token via the built-in OAuth 2.0 flow. The wizard also stores the +# refresh token and expiry so the bot can automatically renew the connection. +# +# The bot refreshes the access token automatically when it is within 5 minutes +# of expiry and writes the new credentials back to this file. Manual setup: +# 1. Create a Notion integration at https://www.notion.so/my-integrations +# and note your client_id and client_secret. +# 2. Complete the OAuth flow to obtain an access_token and refresh_token. +# 3. Fill in the fields below. +# +# [[mcp_servers]] +# name = "notion" +# url = "https://mcp.notion.com/mcp" +# auth_token = "your-notion-oauth-access-token" +# refresh_token = "your-notion-oauth-refresh-token" +# token_expires_at = 1234567890 # Unix timestamp; set by the wizard +# token_endpoint = "https://api.notion.com/v1/oauth/token" +# oauth_client_id = "your-notion-client-id" +# oauth_client_secret = "your-notion-client-secret" # omit for public clients + # Example: Exa AI web search (https://mcp.exa.ai) # Get your API key at https://dashboard.exa.ai/api-keys # @@ -161,3 +185,11 @@ directory = "skills" # [[mcp_servers]] # name = "exa" # url = "https://mcp.exa.ai/mcp?exaApiKey=your-exa-api-key" + +# ── Self-Learning (optional; defaults apply if section omitted) ───────────── +# [learning] +# user_model_path = "memory/USER.md" # Honcho-style user model file +# skill_extraction_enabled = true # Auto-generate skills from tool-heavy tasks +# skill_extraction_threshold = 5 # Min tool calls to trigger extraction +# user_model_update_interval = 10 # Update user model every N user messages +# user_model_cron = "0 0 3 * * SUN" # Weekly user model refresh (6-field cron) diff --git a/setup/index.html b/setup/index.html index 2f2b829..aae4f2e 100644 --- a/setup/index.html +++ b/setup/index.html @@ -164,6 +164,36 @@ transition: background 0.15s, color 0.15s; } .btn-guide:hover { background: #f6851b; color: #fff; opacity: 1; } + /* OAuth connect button + row */ + .oauth-connect-row { + display: flex; + align-items: center; + gap: 0.75rem; + margin-left: 1.75rem; + margin-top: 0.55rem; + flex-wrap: wrap; + } + .btn-oauth-connect { + display: inline-flex; + align-items: center; + gap: 0.35rem; + padding: 0.4rem 1rem; + border-radius: 6px; + border: none; + background: #2b6cb0; + color: #fff; + font-size: 0.85rem; + font-weight: 600; + cursor: pointer; + transition: background 0.15s, opacity 0.15s; + } + .btn-oauth-connect:hover { background: #2c5282; } + .btn-oauth-connect:disabled { opacity: 0.55; cursor: not-allowed; } + .oauth-connect-status { + font-size: 0.82rem; + font-weight: 500; + color: #718096; + } /* Modal overlay */ .modal-overlay { display: none; @@ -486,6 +516,58 @@

Meta Threads — Access Token Setup

+ + + " + .into(), + ), + } + }; // lock is released here + + // Exchange the authorization code for an access token. + let redir = redirect_uri(); + let mut token_params = vec![ + ("grant_type", "authorization_code".to_owned()), + ("code", params.code.clone()), + ("redirect_uri", redir), + ("client_id", client_id), + ("code_verifier", code_verifier), + ]; + if let Some(secret) = client_secret { + token_params.push(("client_secret", secret)); + } + + let token_result = st + .http_client + .post(&token_endpoint) + .form(&token_params) + .send() + .await; + + match token_result { + Ok(resp) if resp.status().is_success() => match resp.json::().await { + Ok(tok) => { + // Re-acquire the lock only to write back the access token. + if let Some(session) = st.oauth_sessions.lock().await.get_mut(¶ms.state) { + session.access_token = Some(tok.access_token); + session.refresh_token = tok.refresh_token; + session.expires_in = tok.expires_in; + } + Html(format!( + "Authorized\ +

\ + ✅ {server_name} authorization successful! You can close this window.

\ + \ + " + )) + } + Err(e) => Html(format!( + "

Failed to parse token response: {e}

\ + " + )), + }, + Ok(resp) => { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + Html(format!( + "

Token exchange failed ({status}): {body}

\ + " + )) + } + Err(e) => Html(format!( + "

Token request error: {e}

\ + " + )), + } +} + +/// GET /api/oauth/token?state= +/// +/// Polling endpoint — returns `{ ready: true, token }` once the callback has +/// completed, or `{ ready: false }` while still waiting. +async fn oauth_token_poll( + State(st): State, + Query(params): Query, +) -> Result, StatusCode> { + let sessions = st.oauth_sessions.lock().await; + let session = sessions.get(¶ms.state).ok_or(StatusCode::NOT_FOUND)?; + Ok(Json(OAuthTokenPollResponse { + ready: session.access_token.is_some(), + token: session.access_token.clone(), + refresh_token: session.refresh_token.clone(), + expires_in: session.expires_in, + token_endpoint: Some(session.token_endpoint.clone()), + oauth_client_id: Some(session.client_id.clone()), + oauth_client_secret: session.client_secret.clone(), + })) +} + // ── Config formatting ────────────────────────────────────────────────────────── struct ConfigParams<'a> { @@ -309,33 +698,37 @@ async fn main() -> Result<()> { } // ── Web mode ────────────────────────────────────────────────────────────── - let port: u16 = 8719; let config_path = project_root.join("config.toml"); let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); let state = AppState { config_path, shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))), + oauth_sessions: Arc::new(Mutex::new(HashMap::new())), + http_client: reqwest::Client::new(), }; let app = Router::new() .route("/", get(serve_index)) .route("/api/load-config", get(load_config)) .route("/api/save-config", post(save_config)) + .route("/api/oauth/start", get(oauth_start)) + .route("/oauth/callback", get(oauth_callback)) + .route("/api/oauth/token", get(oauth_token_poll)) .with_state(state); - let addr = format!("127.0.0.1:{port}"); + let addr = format!("127.0.0.1:{SETUP_PORT}"); let listener = tokio::net::TcpListener::bind(&addr) .await .with_context(|| format!("Failed to bind to {addr}"))?; - println!("RustFox setup wizard → http://localhost:{port}"); + println!("RustFox setup wizard → http://localhost:{SETUP_PORT}"); println!("Press Ctrl-C to exit without saving.\n"); // Open the browser after a short delay. tokio::spawn(async move { tokio::time::sleep(tokio::time::Duration::from_millis(400)).await; - let url = format!("http://localhost:{port}"); + let url = format!("http://localhost:{SETUP_PORT}"); // Try xdg-open (Linux), then open (macOS) — ignore errors. let _ = std::process::Command::new("xdg-open").arg(&url).status(); let _ = std::process::Command::new("open").arg(&url).status(); @@ -621,4 +1014,32 @@ allowed_directory = "/tmp" assert!(out.contains("[skills]")); assert!(out.contains(r#"directory = "skills""#)); } + + #[test] + fn test_pkce_verifier_length() { + let v = pkce_verifier(); + // 32 bytes → 43 base64url chars (no padding) + assert_eq!(v.len(), 43); + assert!(v + .chars() + .all(|c| c.is_alphanumeric() || c == '-' || c == '_')); + } + + #[test] + fn test_pkce_challenge_is_base64url() { + let verifier = pkce_verifier(); + let challenge = pkce_challenge(&verifier); + // SHA-256 → 32 bytes → 43 base64url chars + assert_eq!(challenge.len(), 43); + assert!(challenge + .chars() + .all(|c| c.is_alphanumeric() || c == '-' || c == '_')); + } + + #[test] + fn test_random_state_is_32_hex_chars() { + let s = random_state(); + assert_eq!(s.len(), 32); + assert!(s.chars().all(|c| c.is_ascii_hexdigit())); + } } diff --git a/src/config.rs b/src/config.rs index 5dc39fe..c794d65 100644 --- a/src/config.rs +++ b/src/config.rs @@ -22,6 +22,8 @@ pub struct Config { pub embedding: Option, #[serde(default)] pub langsmith: Option, + #[serde(default = "default_learning_config")] + pub learning: LearningConfig, } #[derive(Debug, Deserialize, Clone)] @@ -80,6 +82,24 @@ pub struct McpServerConfig { /// Used with `url`; ignored for stdio servers. #[serde(default)] pub auth_token: Option, + /// OAuth 2.0 refresh token for long-lived connections. + /// When set, the bot will automatically exchange this for a new `auth_token` + /// before the current one expires and persist the updated token to `config.toml`. + #[serde(default)] + pub refresh_token: Option, + /// Unix timestamp (seconds since epoch) at which the current `auth_token` + /// expires. Derived from the `expires_in` field of the token response. + #[serde(default)] + pub token_expires_at: Option, + /// OAuth 2.0 token endpoint used for refresh-token exchanges. + #[serde(default)] + pub token_endpoint: Option, + /// OAuth 2.0 client ID used when authenticating refresh-token requests. + #[serde(default)] + pub oauth_client_id: Option, + /// OAuth 2.0 client secret (if applicable) used alongside `oauth_client_id`. + #[serde(default)] + pub oauth_client_secret: Option, } #[derive(Debug, Deserialize, Clone)] @@ -138,6 +158,25 @@ pub struct LangSmithConfig { pub base_url: String, } +#[derive(Debug, Deserialize, Clone)] +pub struct LearningConfig { + /// Path to the user model file (Honcho-style USER.md). + #[serde(default = "default_user_model_path")] + pub user_model_path: PathBuf, + /// Whether post-task skill extraction is enabled. + #[serde(default = "default_true")] + pub skill_extraction_enabled: bool, + /// Minimum tool calls to trigger skill extraction (default 5). + #[serde(default = "default_skill_extraction_threshold")] + pub skill_extraction_threshold: u32, + /// Message count between user model updates (default 10). + #[serde(default = "default_user_model_update_interval")] + pub user_model_update_interval: usize, + /// Cron expression for weekly user model update (default: Sunday 3am). + #[serde(default = "default_user_model_cron")] + pub user_model_cron: String, +} + fn default_model() -> String { "moonshotai/kimi-k2.5".to_string() } @@ -264,6 +303,36 @@ fn default_langsmith_base_url() -> String { "https://api.smith.langchain.com".to_string() } +fn default_user_model_path() -> PathBuf { + PathBuf::from("memory/USER.md") +} + +fn default_true() -> bool { + true +} + +fn default_skill_extraction_threshold() -> u32 { + 5 +} + +fn default_user_model_update_interval() -> usize { + 10 +} + +fn default_user_model_cron() -> String { + "0 0 3 * * SUN".to_string() +} + +fn default_learning_config() -> LearningConfig { + LearningConfig { + user_model_path: default_user_model_path(), + skill_extraction_enabled: true, + skill_extraction_threshold: default_skill_extraction_threshold(), + user_model_update_interval: default_user_model_update_interval(), + user_model_cron: default_user_model_cron(), + } +} + impl Config { /// Location string from [general], injected into the system prompt. pub fn user_location(&self) -> Option<&str> { diff --git a/src/learning.rs b/src/learning.rs new file mode 100644 index 0000000..ddd63c1 --- /dev/null +++ b/src/learning.rs @@ -0,0 +1,741 @@ +use anyhow::{Context, Result}; +use std::path::Path; +use tracing::{info, warn}; + +use crate::llm::{ChatMessage, LlmClient}; +use crate::skills::loader::load_skills_from_dir; +use crate::skills::SkillRegistry; + +/// Minimum number of conversation messages needed before updating the user model. +const MIN_MESSAGES_FOR_USER_MODEL: usize = 3; + +// ─── Feature 1: Post-task Skill Extractor ─────────────────────────────────── + +/// Analyze a completed agentic loop and, if the conversation used enough tool +/// calls and contains a reusable pattern, auto-generate a new skill in `skills/`. +/// +/// This function performs the extraction work when awaited. Callers that want +/// it to run in the background should spawn it explicitly with `tokio::spawn`. +pub async fn post_task_skill_extractor( + llm: &LlmClient, + skills_dir: &Path, + skills: &tokio::sync::RwLock, + messages: &[ChatMessage], + tool_call_count: u32, +) { + info!( + "Skill extractor triggered: {} tool calls in conversation", + tool_call_count + ); + + match extract_skill_from_conversation(llm, skills_dir, skills, messages).await { + Ok(Some(name)) => info!("Auto-generated skill: {}", name), + Ok(None) => info!("Skill extractor: no reusable pattern found"), + Err(e) => warn!("Skill extractor failed: {:#}", e), + } +} + +/// Ask the LLM to analyze the conversation and decide whether a reusable skill +/// should be created. Returns `Some(skill_name)` if one was written, `None` if +/// the LLM decided the task was too ad-hoc. +async fn extract_skill_from_conversation( + llm: &LlmClient, + skills_dir: &Path, + skills: &tokio::sync::RwLock, + messages: &[ChatMessage], +) -> Result> { + // Build a condensed transcript of the conversation for the LLM. + let transcript = build_transcript(messages); + if transcript.is_empty() { + return Ok(None); + } + + // Collect existing skill names to avoid duplicates. + let existing_names: Vec = { + let reg = skills.read().await; + reg.list().iter().map(|s| s.name.clone()).collect() + }; + + let analysis_prompt = format!( + "You are a skill-extraction engine for an AI assistant called RustFox.\n\ + \n\ + Analyze the following conversation transcript and decide if it contains a \ + **reusable, multi-step workflow** that should be saved as a new skill.\n\ + \n\ + Rules:\n\ + - Only create a skill if the pattern would help the assistant handle SIMILAR \ + future requests more efficiently.\n\ + - Do NOT create skills for: one-off questions, trivial lookups, simple math, \ + greetings, or tasks that are too specific to generalize.\n\ + - The skill name must be lowercase letters, numbers, and hyphens only (1-64 chars).\n\ + - These skills already exist (do NOT duplicate): {existing}\n\ + \n\ + If a skill SHOULD be created, respond with EXACTLY this format (no extra text):\n\ + ```\n\ + SKILL_NAME: \n\ + SKILL_DESCRIPTION: \n\ + SKILL_TAGS: \n\ + SKILL_BODY:\n\ + \n\ + ```\n\ + \n\ + If NO skill should be created, respond with exactly: NO_SKILL\n\ + \n\ + TRANSCRIPT:\n{transcript}", + existing = existing_names.join(", "), + transcript = transcript, + ); + + let analysis_messages = vec![ChatMessage { + role: "user".to_string(), + content: Some(analysis_prompt), + tool_calls: None, + tool_call_id: None, + }]; + + let response = llm.chat(&analysis_messages, &[]).await?; + let raw = response.content.unwrap_or_default(); + + // Strip outer code fences if the model wrapped its entire response. + let content = strip_code_fences(&raw); + + if content.trim() == "NO_SKILL" || !content.contains("SKILL_NAME:") { + return Ok(None); + } + + // Parse the structured response. + let name = extract_line_value(&content, "SKILL_NAME:") + .context("Missing SKILL_NAME in LLM response")?; + let description = extract_line_value(&content, "SKILL_DESCRIPTION:") + .unwrap_or_else(|| "Auto-generated skill".to_string()); + let tags_raw = extract_line_value(&content, "SKILL_TAGS:").unwrap_or_default(); + let tags: Vec = tags_raw + .split(',') + .map(|t| t.trim().to_string()) + .filter(|t| !t.is_empty()) + .collect(); + + let body = extract_body_after(&content, "SKILL_BODY:") + .unwrap_or_else(|| "# Auto-generated skill\n\nNo instructions extracted.".to_string()); + + // Validate the name. + if name.is_empty() + || name.len() > 64 + || !name + .chars() + .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-') + { + warn!("Skill extractor: invalid name '{}', skipping", name); + return Ok(None); + } + + // Skip if a skill with this name already exists. + if existing_names.iter().any(|n| n == &name) { + info!("Skill extractor: skill '{}' already exists, skipping", name); + return Ok(None); + } + + // Build SKILL.md content with frontmatter. + // YAML-single-quote description and tag values so characters like `:`, `#`, + // `]`, or commas don't break the simple frontmatter parser. + fn yaml_single_quoted(value: &str) -> String { + format!("'{}'", value.replace('\'', "''")) + } + let tags_yaml = if tags.is_empty() { + "[]".to_string() + } else { + format!( + "[{}]", + tags.iter() + .map(|tag| yaml_single_quoted(tag)) + .collect::>() + .join(", ") + ) + }; + let skill_content = format!( + "---\nname: {name}\ndescription: {description}\ntags: {tags}\n---\n\n{body}\n", + name = name, + description = yaml_single_quoted(&description), + tags = tags_yaml, + body = body, + ); + + // Write the skill file. + let skill_dir = skills_dir.join(&name); + tokio::fs::create_dir_all(&skill_dir) + .await + .with_context(|| format!("Failed to create skill directory: {}", skill_dir.display()))?; + + let skill_path = skill_dir.join("SKILL.md"); + tokio::fs::write(&skill_path, &skill_content) + .await + .with_context(|| format!("Failed to write skill file: {}", skill_path.display()))?; + + info!("Written auto-generated skill: {}", skill_path.display()); + + // Hot-reload skills. + match load_skills_from_dir(skills_dir).await { + Ok(new_registry) => { + let count = new_registry.len(); + let mut reg = skills.write().await; + *reg = new_registry; + info!("Skills reloaded after extraction: {} skill(s)", count); + } + Err(e) => warn!("Failed to reload skills after extraction: {:#}", e), + } + + Ok(Some(name)) +} + +// ─── Feature 2: Skill Self-Patch ──────────────────────────────────────────── + +/// Safely patch an existing skill's SKILL.md by appending or replacing a section. +/// Creates a `.bak` backup before modifying. Returns a status message. +pub async fn self_patch_skill( + skills_dir: &Path, + skill_name: &str, + patch_content: &str, + skills: &tokio::sync::RwLock, +) -> Result { + // Validate skill_name to prevent path traversal (e.g. "../secret"). + // Only allow safe directory-name characters. + if skill_name.is_empty() + || skill_name.contains('/') + || skill_name.contains('\\') + || skill_name.contains("..") + || skill_name.starts_with('.') + { + anyhow::bail!("Invalid skill name: '{}'", skill_name); + } + + let skill_dir = skills_dir.join(skill_name); + + // Canonicalize to verify the resolved path stays inside skills_dir. + // We canonicalize skills_dir first (it must exist), then check the prefix. + let canonical_skills_dir = tokio::fs::canonicalize(skills_dir) + .await + .with_context(|| format!("Cannot canonicalize skills dir: {}", skills_dir.display()))?; + let canonical_skill_dir = tokio::fs::canonicalize(&skill_dir).await.ok(); + if let Some(ref resolved) = canonical_skill_dir { + if !resolved.starts_with(&canonical_skills_dir) { + anyhow::bail!("Skill path '{}' escapes the skills directory", skill_name); + } + } + + let skill_path = skill_dir.join("SKILL.md"); + + if !skill_path.exists() { + anyhow::bail!("Skill '{}' does not exist", skill_name); + } + + // Read current content. + let current = tokio::fs::read_to_string(&skill_path) + .await + .with_context(|| format!("Failed to read {}", skill_path.display()))?; + + // Create backup. + let backup_path = skill_path.with_extension("md.bak"); + tokio::fs::write(&backup_path, ¤t) + .await + .with_context(|| format!("Failed to create backup: {}", backup_path.display()))?; + + // Apply patch: if the patch contains frontmatter (starts with ---), replace + // the entire file. Otherwise append to the existing content. + let new_content = if patch_content.starts_with("---") { + patch_content.to_string() + } else { + format!("{}\n\n{}", current.trim_end(), patch_content) + }; + + // Validate that the result still contains frontmatter. + if !new_content.starts_with("---") { + anyhow::bail!( + "Patched content would lose YAML frontmatter — aborting. \ + Backup preserved at {}", + backup_path.display() + ); + } + + tokio::fs::write(&skill_path, &new_content) + .await + .with_context(|| format!("Failed to write patched skill: {}", skill_path.display()))?; + + info!( + "Skill '{}' patched ({} → {} bytes)", + skill_name, + current.len(), + new_content.len() + ); + + // Hot-reload skills. + match load_skills_from_dir(skills_dir).await { + Ok(new_registry) => { + let count = new_registry.len(); + let mut reg = skills.write().await; + *reg = new_registry; + info!("Skills reloaded after patch: {} skill(s)", count); + } + Err(e) => warn!("Failed to reload skills after patch: {:#}", e), + } + + Ok(format!( + "Skill '{}' patched successfully. Backup at {}", + skill_name, + backup_path.display() + )) +} + +// ─── Feature 3: User Model ───────────────────────────────────────────────── + +/// Default content for a new USER.md file. +const DEFAULT_USER_MODEL: &str = "\ +--- +name: user-model +description: Long-term user preferences and context learned across sessions. +tags: [user, preferences, context] +--- + +# User Model + + + +user_name: ~ +language: [en] +communication_style: [] +preferences: [] +interests: [] +context: [] +"; + +/// Read the user model file, or return empty string if it doesn't exist. +pub async fn read_user_model(user_model_path: &Path) -> String { + match tokio::fs::read_to_string(user_model_path).await { + Ok(content) => content, + Err(e) => { + // Only warn if the file exists but can't be read (permission errors, etc.) + if user_model_path.exists() { + warn!( + "User model file exists but could not be read ({}): {}", + user_model_path.display(), + e + ); + } + String::new() + } + } +} + +/// Update the user model by summarizing recent conversations through the LLM. +pub async fn update_user_model( + llm: &LlmClient, + memory: &crate::memory::MemoryStore, + user_model_path: &Path, +) { + match update_user_model_inner(llm, memory, user_model_path).await { + Ok(true) => info!("User model updated: {}", user_model_path.display()), + Ok(false) => info!("User model: not enough data to update"), + Err(e) => warn!("User model update failed: {:#}", e), + } +} + +async fn update_user_model_inner( + llm: &LlmClient, + memory: &crate::memory::MemoryStore, + user_model_path: &Path, +) -> Result { + // Load recent conversation messages for context. + let recent = memory + .search_messages("user preferences interests communication", 20) + .await + .unwrap_or_default(); + + if recent.len() < MIN_MESSAGES_FOR_USER_MODEL { + return Ok(false); // Not enough data yet + } + + let conversation_snippets: String = recent + .iter() + .filter_map(|m| m.content.as_ref().map(|c| format!("[{}]: {}", m.role, c))) + .collect::>() + .join("\n"); + + // Read existing model. + let existing = if user_model_path.exists() { + tokio::fs::read_to_string(user_model_path) + .await + .unwrap_or_default() + } else { + DEFAULT_USER_MODEL.to_string() + }; + + let prompt = format!( + "You maintain a concise user profile for an AI assistant.\n\ + \n\ + Current user model:\n```\n{existing}\n```\n\ + \n\ + Recent conversation excerpts:\n```\n{snippets}\n```\n\ + \n\ + Update the user model based on the conversations. Rules:\n\ + - Keep the YAML frontmatter exactly as-is (name, description, tags)\n\ + - Update fields: user_name, language, communication_style, preferences, \ + interests, context\n\ + - Be concise — max 500 words total\n\ + - Only add information the user explicitly stated or strongly implied\n\ + - Do not remove existing valid entries — merge new info\n\ + - Output the COMPLETE updated file (frontmatter + body), nothing else", + existing = existing, + snippets = conversation_snippets, + ); + + let messages = vec![ChatMessage { + role: "user".to_string(), + content: Some(prompt), + tool_calls: None, + tool_call_id: None, + }]; + + let response = llm.chat(&messages, &[]).await?; + let new_content = response.content.unwrap_or_default(); + + // Strict validation: must start with `---` and contain a closing `---` + // delimiter so we don't write malformed or injection-bearing content into + // USER.md (which is later injected into the system prompt). + if !has_valid_frontmatter(&new_content) || new_content.trim().is_empty() { + warn!("User model update returned invalid content, skipping"); + return Ok(false); + } + + // Ensure parent directory exists. + if let Some(parent) = user_model_path.parent() { + tokio::fs::create_dir_all(parent).await.ok(); + } + + tokio::fs::write(user_model_path, &new_content) + .await + .with_context(|| format!("Failed to write user model: {}", user_model_path.display()))?; + + Ok(true) +} + +// ─── Feature 4: Self-Update ───────────────────────────────────────────────── + +/// Run `git fetch`, `git checkout `, `git pull`, and `cargo build --release` +/// in the project root directory. Returns a multi-line status log. +/// +/// Does NOT restart the process — the caller should arrange restart after a +/// successful build (e.g. via `std::process::Command` or systemd). +pub async fn self_update(branch: &str, project_root: &Path) -> Result { + let mut log = String::new(); + + // Step 0: Check for uncommitted changes. + let status_output = run_git_command(project_root, &["status", "--porcelain"]).await?; + if !status_output.trim().is_empty() { + // Stash changes to avoid losing them. + let stash_result = run_git_command( + project_root, + &["stash", "push", "-m", "rustfox-auto-stash-before-update"], + ) + .await?; + log.push_str(&format!( + "⚠ Stashed uncommitted changes: {}\n", + stash_result.trim() + )); + } + + // Step 1: git fetch --all + log.push_str("→ git fetch --all\n"); + let fetch = run_git_command(project_root, &["fetch", "--all"]).await?; + log.push_str(&format!(" {}\n", fetch.trim())); + + // Step 2: git checkout + log.push_str(&format!("→ git checkout {}\n", branch)); + let checkout = run_git_command(project_root, &["checkout", branch]).await?; + log.push_str(&format!(" {}\n", checkout.trim())); + + // Step 3: git pull origin + log.push_str(&format!("→ git pull origin {}\n", branch)); + let pull = run_git_command(project_root, &["pull", "origin", branch]).await?; + log.push_str(&format!(" {}\n", pull.trim())); + + // Step 4: cargo build --release + log.push_str("→ cargo build --release\n"); + let build = run_cargo_build(project_root).await?; + log.push_str(&format!(" {}\n", build.trim())); + + log.push_str("✅ Build successful. Restart to activate the new version."); + + Ok(log) +} + +/// Run a git command in the given directory and return combined stdout+stderr. +async fn run_git_command(dir: &Path, args: &[&str]) -> Result { + let output = tokio::process::Command::new("git") + .args(args) + .current_dir(dir) + .output() + .await + .with_context(|| format!("Failed to execute: git {}", args.join(" ")))?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + let combined = format!("{}{}", stdout, stderr); + + if !output.status.success() { + anyhow::bail!( + "git {} failed (exit {}): {}", + args.join(" "), + output.status.code().unwrap_or(-1), + combined.trim() + ); + } + + Ok(combined) +} + +/// Run `cargo build --release` in the given directory. +async fn run_cargo_build(dir: &Path) -> Result { + let output = tokio::process::Command::new("cargo") + .args(["build", "--release"]) + .current_dir(dir) + .output() + .await + .context("Failed to execute cargo build --release")?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + let combined = format!("{}{}", stdout, stderr); + + if !output.status.success() { + anyhow::bail!( + "cargo build --release failed (exit {}): {}", + output.status.code().unwrap_or(-1), + combined.trim() + ); + } + + Ok(combined) +} + +// ─── Helpers ──────────────────────────────────────────────────────────────── + +/// Build a condensed transcript from conversation messages for analysis. +fn build_transcript(messages: &[ChatMessage]) -> String { + messages + .iter() + .filter(|m| m.role == "user" || m.role == "assistant" || m.role == "tool") + .filter_map(|m| { + m.content + .as_ref() + .map(|c| format!("[{}]: {}", m.role, truncate(c, 500))) + }) + .collect::>() + .join("\n") +} + +/// Truncate a string to at most `max_len` characters. +fn truncate(s: &str, max_len: usize) -> String { + let char_count = s.chars().count(); + if char_count <= max_len { + s.to_string() + } else { + let truncated: String = s.chars().take(max_len).collect(); + format!("{}…", truncated) + } +} + +/// Extract the value after a `KEY:` prefix on a single line. +fn extract_line_value(text: &str, key: &str) -> Option { + for line in text.lines() { + let trimmed = line.trim(); + if let Some(rest) = trimmed.strip_prefix(key) { + let value = rest.trim().to_string(); + if !value.is_empty() { + return Some(value); + } + } + } + None +} + +/// Extract everything after a `KEY:` line as the body text. +fn extract_body_after(text: &str, key: &str) -> Option { + if let Some(pos) = text.find(key) { + let after = &text[pos + key.len()..]; + let body = after.trim().to_string(); + if !body.is_empty() { + return Some(body); + } + } + None +} + +/// Strip outer code fences (```` ``` ```` or ```` ```rust ```` etc.) from a string. +/// If the entire content is wrapped in a single fenced block, return just the +/// inner text; otherwise return the trimmed original. +fn strip_code_fences(s: &str) -> String { + let trimmed = s.trim(); + if let Some(after_open) = trimmed.strip_prefix("```") { + // Skip optional language hint on the opening fence line. + let after_hint = + after_open.trim_start_matches(|c: char| c.is_alphanumeric() || c == '_' || c == '-'); + // The rest must start with a newline for this to be a real fence. + if let Some(inner) = after_hint.strip_prefix('\n') { + if let Some(close_pos) = inner.rfind("```") { + return inner[..close_pos].trim().to_string(); + } + } + } + trimmed.to_string() +} + +/// Return `true` if `content` has a valid YAML frontmatter block, i.e. it +/// starts with `---` and contains a second `---` delimiter on its own line. +fn has_valid_frontmatter(content: &str) -> bool { + let trimmed = content.trim(); + let mut lines = trimmed.lines(); + // The very first line must be exactly "---". + if lines.next() != Some("---") { + return false; + } + // Look for a closing "---" on its own line. + for line in lines { + if line.trim() == "---" { + return true; + } + } + false +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + #[test] + fn test_extract_line_value() { + let text = "SKILL_NAME: my-new-skill\nSKILL_DESCRIPTION: Does things\n"; + assert_eq!( + extract_line_value(text, "SKILL_NAME:"), + Some("my-new-skill".to_string()) + ); + assert_eq!( + extract_line_value(text, "SKILL_DESCRIPTION:"), + Some("Does things".to_string()) + ); + assert_eq!(extract_line_value(text, "MISSING:"), None); + } + + #[test] + fn test_extract_body_after() { + let text = "SKILL_BODY:\n# My Skill\n\nDo things step by step."; + let body = extract_body_after(text, "SKILL_BODY:").unwrap(); + assert!(body.starts_with("# My Skill")); + assert!(body.contains("step by step")); + } + + #[test] + fn test_truncate_short_string() { + assert_eq!(truncate("hello", 10), "hello"); + } + + #[test] + fn test_truncate_long_string() { + let result = truncate("hello world", 5); + assert!(result.starts_with("hello")); + assert!(result.ends_with('…')); + } + + #[test] + fn test_build_transcript_filters_roles() { + let messages = vec![ + ChatMessage { + role: "system".to_string(), + content: Some("System prompt".to_string()), + tool_calls: None, + tool_call_id: None, + }, + ChatMessage { + role: "user".to_string(), + content: Some("Hello".to_string()), + tool_calls: None, + tool_call_id: None, + }, + ChatMessage { + role: "assistant".to_string(), + content: Some("Hi there".to_string()), + tool_calls: None, + tool_call_id: None, + }, + ]; + let transcript = build_transcript(&messages); + assert!(!transcript.contains("System prompt")); + assert!(transcript.contains("[user]: Hello")); + assert!(transcript.contains("[assistant]: Hi there")); + } + + #[tokio::test] + async fn test_read_user_model_nonexistent() { + let dir = tempdir().unwrap(); + let path = dir.path().join("nonexistent.md"); + let content = read_user_model(&path).await; + assert!(content.is_empty()); + } + + #[tokio::test] + async fn test_read_user_model_existing() { + let dir = tempdir().unwrap(); + let path = dir.path().join("USER.md"); + tokio::fs::write(&path, "# Test Model\nuser_name: Alice") + .await + .unwrap(); + let content = read_user_model(&path).await; + assert!(content.contains("Alice")); + } + + #[tokio::test] + async fn test_self_patch_skill_creates_backup() { + let dir = tempdir().unwrap(); + let skill_dir = dir.path().join("test-skill"); + tokio::fs::create_dir_all(&skill_dir).await.unwrap(); + + let skill_path = skill_dir.join("SKILL.md"); + tokio::fs::write( + &skill_path, + "---\nname: test-skill\ndescription: Test\ntags: []\n---\n\n# Original", + ) + .await + .unwrap(); + + let registry = tokio::sync::RwLock::new(SkillRegistry::new()); + + let result = self_patch_skill( + dir.path(), + "test-skill", + "\n## New Section\nAdded.", + ®istry, + ) + .await + .unwrap(); + + assert!(result.contains("patched successfully")); + + // Verify backup exists. + let backup = skill_dir.join("SKILL.md.bak"); + assert!(backup.exists()); + let backup_content = tokio::fs::read_to_string(&backup).await.unwrap(); + assert!(backup_content.contains("# Original")); + + // Verify patch was applied. + let patched = tokio::fs::read_to_string(&skill_path).await.unwrap(); + assert!(patched.contains("# Original")); + assert!(patched.contains("## New Section")); + } + + #[tokio::test] + async fn test_self_patch_nonexistent_skill_fails() { + let dir = tempdir().unwrap(); + let registry = tokio::sync::RwLock::new(SkillRegistry::new()); + let result = self_patch_skill(dir.path(), "no-such-skill", "patch", ®istry).await; + assert!(result.is_err()); + } +} diff --git a/src/main.rs b/src/main.rs index 8a96440..eb193c5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod agent; mod config; mod langsmith; +mod learning; mod llm; mod mcp; mod memory; @@ -80,9 +81,19 @@ async fn main() -> Result<()> { .context("Failed to initialize memory store")?; info!(" Database: {}", config.memory.database_path.display()); - // Initialize MCP connections + // Refresh any expiring OAuth tokens before connecting to MCP servers + let http_client = reqwest::Client::new(); + let mut mcp_server_configs = config.mcp_servers.clone(); + let refreshed = + crate::mcp::refresh_expiring_tokens(&mut mcp_server_configs, &config_path, &http_client) + .await; + if refreshed > 0 { + info!(" Refreshed {refreshed} expiring MCP OAuth token(s) at startup"); + } + + // Initialize MCP connections (using possibly-refreshed configs) let mut mcp_manager = McpManager::new(); - mcp_manager.connect_all(&config.mcp_servers).await; + mcp_manager.connect_all(&mcp_server_configs).await; // Load skills from markdown files let skills = load_skills_from_dir(&config.skills.directory).await?; @@ -165,6 +176,34 @@ async fn main() -> Result<()> { } }); + // Spawn background OAuth token refresh task: checks every 30 minutes. + // `cfgs` is kept across ticks so that updated token_expires_at values + // are remembered and a freshly-rotated refresh token isn't re-used. + { + let mut cfgs = mcp_server_configs.clone(); + let refresh_config_path = config_path.clone(); + let refresh_http_client = http_client.clone(); + tokio::spawn(async move { + // 30-minute interval — tokens expiring within 5 min are always caught + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30 * 60)); + interval.tick().await; // skip first immediate tick + loop { + interval.tick().await; + let refreshed = crate::mcp::refresh_expiring_tokens( + &mut cfgs, + &refresh_config_path, + &refresh_http_client, + ) + .await; + if refreshed > 0 { + tracing::info!( + "Background token refresh: updated {refreshed} MCP OAuth token(s)" + ); + } + } + }); + } + // Register built-in background tasks and start scheduler register_builtin_tasks( &scheduler, @@ -172,6 +211,8 @@ async fn main() -> Result<()> { crate::llm::LlmClient::new(config.openrouter.clone()), config.memory.summarize_cron.clone(), config.memory.summarize_threshold, + config.learning.user_model_cron.clone(), + config.learning.user_model_path.clone(), ) .await?; scheduler.start().await?; diff --git a/src/mcp.rs b/src/mcp.rs index 94da412..41ebf9e 100644 --- a/src/mcp.rs +++ b/src/mcp.rs @@ -8,14 +8,223 @@ use rmcp::{ }, ServiceExt, }; +use serde::Deserialize; use serde_json::Value; use std::collections::HashMap; +use std::path::Path; use tokio::process::Command; -use tracing::{error, info}; +use tracing::{debug, error, info, warn}; use crate::config::McpServerConfig; use crate::llm::{FunctionDefinition, ToolDefinition}; +// ── OAuth token refresh ──────────────────────────────────────────────────────── + +/// Response from the token endpoint when refreshing an access token. +#[derive(Deserialize)] +struct TokenRefreshResponse { + access_token: String, + /// Authorization servers rotate the refresh token on each use. + #[serde(default)] + refresh_token: Option, + /// Lifetime of the new access token in seconds. + #[serde(default)] + expires_in: Option, +} + +/// Returns true when the access token for an HTTP MCP server has expired or +/// will expire within the next 5 minutes. +pub fn token_needs_refresh(config: &McpServerConfig) -> bool { + if config.refresh_token.is_none() || config.token_endpoint.is_none() { + return false; + } + match config.token_expires_at { + None => false, + Some(expires_at) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64; + // Refresh if expiry is within the next 5 minutes (300 seconds) + expires_at - now <= 300 + } + } +} + +/// Exchange a refresh token for a new access token. +/// +/// Uses HTTP Basic Auth with `oauth_client_id` : `oauth_client_secret` when a +/// client ID is present. Returns the updated token fields. +pub async fn refresh_oauth_token( + config: &McpServerConfig, + http_client: &reqwest::Client, +) -> Result<(String, Option, Option)> { + let refresh_token = config + .refresh_token + .as_deref() + .context("No refresh_token in config")?; + let token_endpoint = config + .token_endpoint + .as_deref() + .context("No token_endpoint in config")?; + + let params = [ + ("grant_type", "refresh_token"), + ("refresh_token", refresh_token), + ]; + + let mut request = http_client.post(token_endpoint).form(¶ms); + + // Add HTTP Basic Auth when a client_id is available. + if let Some(client_id) = &config.oauth_client_id { + request = request.basic_auth(client_id, config.oauth_client_secret.as_deref()); + } + + let resp = request + .send() + .await + .context("Token refresh request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!( + "Token refresh failed ({status}) for '{}': {body}", + config.name + ); + } + + let tok: TokenRefreshResponse = resp + .json() + .await + .context("Failed to parse token refresh response")?; + + let new_expires_at = tok.expires_in.map(|secs| { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64; + now + secs as i64 + }); + + Ok((tok.access_token, tok.refresh_token, new_expires_at)) +} + +/// Rewrite the `auth_token`, `refresh_token`, and `token_expires_at` fields for +/// a named `[[mcp_servers]]` entry inside `config.toml`, preserving all other +/// content. +/// +/// Strategy: parse the TOML into a `toml::Value`, update the matching server +/// entry, and serialise back. Comments are lost on round-trip, but the +/// functional data is fully preserved. This is acceptable for a +/// machine-updated file. +pub async fn update_config_tokens( + config_path: &Path, + server_name: &str, + auth_token: &str, + new_refresh_token: Option<&str>, + new_expires_at: Option, +) -> Result<()> { + let content = tokio::fs::read_to_string(config_path) + .await + .with_context(|| format!("Failed to read {}", config_path.display()))?; + + let mut doc: toml::Value = content + .parse() + .with_context(|| format!("Failed to parse TOML from {}", config_path.display()))?; + + if let Some(servers) = doc.get_mut("mcp_servers").and_then(|v| v.as_array_mut()) { + for server in servers.iter_mut() { + if server + .get("name") + .and_then(|v| v.as_str()) + .map(|n| n == server_name) + .unwrap_or(false) + { + if let toml::Value::Table(table) = server { + table.insert( + "auth_token".to_string(), + toml::Value::String(auth_token.to_string()), + ); + if let Some(rt) = new_refresh_token { + table.insert( + "refresh_token".to_string(), + toml::Value::String(rt.to_string()), + ); + } + if let Some(ea) = new_expires_at { + table.insert("token_expires_at".to_string(), toml::Value::Integer(ea)); + } + } + break; + } + } + } + + let new_content = toml::to_string_pretty(&doc).context("Failed to serialise updated config")?; + tokio::fs::write(config_path, new_content) + .await + .with_context(|| format!("Failed to write {}", config_path.display()))?; + + debug!("Persisted refreshed token for MCP server '{server_name}'"); + Ok(()) +} + +/// Refresh tokens for every HTTP MCP server that is near expiry, writing the +/// new credentials back to `config_path`. Returns the number of servers that +/// were refreshed. +pub async fn refresh_expiring_tokens( + configs: &mut [McpServerConfig], + config_path: &Path, + http_client: &reqwest::Client, +) -> usize { + let mut refreshed = 0usize; + for cfg in configs.iter_mut() { + if !token_needs_refresh(cfg) { + continue; + } + info!( + "Access token for MCP server '{}' is expiring; refreshing...", + cfg.name + ); + match refresh_oauth_token(cfg, http_client).await { + Ok((new_token, new_rt, new_exp)) => { + // Persist to disk first; only update in-memory state on success to + // avoid a situation where the runtime uses a token that was never saved. + match update_config_tokens( + config_path, + &cfg.name, + &new_token, + new_rt.as_deref(), + new_exp, + ) + .await + { + Ok(()) => { + cfg.auth_token = Some(new_token); + if new_rt.is_some() { + cfg.refresh_token = new_rt; + } + cfg.token_expires_at = new_exp; + refreshed += 1; + info!("Token refreshed successfully for MCP server '{}'", cfg.name); + } + Err(e) => { + warn!( + "Failed to persist refreshed token for '{}': {e:#}", + cfg.name + ); + } + } + } + Err(e) => { + warn!("Token refresh failed for MCP server '{}': {e:#}", cfg.name); + } + } + } + refreshed +} + /// Represents a connected MCP server with its tools pub struct McpConnection { pub name: String, @@ -53,8 +262,25 @@ impl McpManager { info!("Connecting to HTTP MCP server '{}': {}", config.name, url); - let transport_config = StreamableHttpClientTransportConfig::with_uri(url.to_string()) - .auth_header(config.auth_token.clone().unwrap_or_default()); + let mut transport_config = StreamableHttpClientTransportConfig::with_uri(url.to_string()); + + // Only set the auth header when a non-empty token is provided. + // Using unwrap_or_default() would pass an empty string, causing + // reqwest to send "Authorization: Bearer " (empty token) which + // remote servers (e.g. Notion) reject with 401 invalid_token. + match &config.auth_token { + Some(token) if !token.is_empty() => { + transport_config = transport_config.auth_header(token.clone()); + } + None => { + tracing::debug!( + "HTTP MCP server '{}' has no auth_token configured; \ + requests will be sent without an Authorization header", + config.name + ); + } + _ => {} + } let transport = StreamableHttpClientTransport::from_config(transport_config); @@ -249,3 +475,85 @@ impl McpManager { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::McpServerConfig; + use std::collections::HashMap; + + fn base_config() -> McpServerConfig { + McpServerConfig { + name: "test".to_string(), + command: None, + args: vec![], + env: HashMap::new(), + url: Some("https://example.com/mcp".to_string()), + auth_token: Some("tok".to_string()), + refresh_token: Some("rt".to_string()), + token_expires_at: None, + token_endpoint: Some("https://example.com/token".to_string()), + oauth_client_id: None, + oauth_client_secret: None, + } + } + + #[test] + fn test_no_refresh_without_refresh_token() { + let mut cfg = base_config(); + cfg.refresh_token = None; + assert!(!token_needs_refresh(&cfg)); + } + + #[test] + fn test_no_refresh_without_token_endpoint() { + let mut cfg = base_config(); + cfg.token_endpoint = None; + assert!(!token_needs_refresh(&cfg)); + } + + #[test] + fn test_no_refresh_when_no_expiry() { + let cfg = base_config(); // token_expires_at == None + assert!(!token_needs_refresh(&cfg)); + } + + #[test] + fn test_needs_refresh_when_expired() { + let mut cfg = base_config(); + // Set expiry 1 hour in the past + let past = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64 + - 3600; + cfg.token_expires_at = Some(past); + assert!(token_needs_refresh(&cfg)); + } + + #[test] + fn test_needs_refresh_when_expiring_within_5_min() { + let mut cfg = base_config(); + // Set expiry 60 seconds from now (within 5-minute window) + let soon = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64 + + 60; + cfg.token_expires_at = Some(soon); + assert!(token_needs_refresh(&cfg)); + } + + #[test] + fn test_no_refresh_when_expiry_far_future() { + let mut cfg = base_config(); + // Set expiry 1 hour from now + let far = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64 + + 3600; + cfg.token_expires_at = Some(far); + assert!(!token_needs_refresh(&cfg)); + } +} diff --git a/src/platform/telegram.rs b/src/platform/telegram.rs index e01062f..0bb088a 100644 --- a/src/platform/telegram.rs +++ b/src/platform/telegram.rs @@ -301,18 +301,29 @@ async fn handle_message(bot: Bot, msg: Message, agent: Arc) -> ResponseRe while let Some(token) = rx.recv().await { buffer.push_str(&token); - // When buffer exceeds split threshold, send a NEW message and reset + // When buffer exceeds split threshold, finalize the current message + // and reset so subsequent tokens start a new message. + // + // Previous logic sent the full buffer as a NEW message, then cleared + // the buffer. This caused the new message to visually shrink on the + // next edit (which only contained the small post-split tokens). + // + // Fix: edit/send the current message with its accumulated content + // (finalizing it), then clear the buffer AND current_msg_id so the + // next batch of tokens creates a fresh message. if buffer.len() > TELEGRAM_STREAM_SPLIT { - match stream_bot.send_message(stream_chat_id, &buffer).await { - Ok(new_msg) => { - current_msg_id = Some(new_msg.id); - buffer.clear(); - } - Err(e) => { - tracing::error!(error = %e, "stream_handle: send_message failed at split"); - break; + if let Some(msg_id) = current_msg_id { + if let Err(e) = stream_bot + .edit_message_text(stream_chat_id, msg_id, &buffer) + .await + { + tracing::warn!(error = %e, "stream_handle: edit failed at split"); } + } else if let Err(e) = stream_bot.send_message(stream_chat_id, &buffer).await { + tracing::warn!(error = %e, "stream_handle: send failed at split"); } + buffer.clear(); + current_msg_id = None; last_action = Instant::now(); continue; } diff --git a/src/platform/tool_notifier.rs b/src/platform/tool_notifier.rs index fa062ed..fe74e7f 100644 --- a/src/platform/tool_notifier.rs +++ b/src/platform/tool_notifier.rs @@ -55,6 +55,7 @@ pub fn friendly_tool_name(name: &str) -> String { ("github", "🐙"), ("sqlite", "🗄️"), ("threads", "🧵"), + ("notion", "📝"), ("fetch", "🌐"), ("git", "📦"), ]; diff --git a/src/scheduler/tasks.rs b/src/scheduler/tasks.rs index 62d2200..b0ed0b4 100644 --- a/src/scheduler/tasks.rs +++ b/src/scheduler/tasks.rs @@ -10,6 +10,8 @@ pub async fn register_builtin_tasks( llm: crate::llm::LlmClient, summarize_cron: String, summarize_threshold: usize, + user_model_cron: String, + user_model_path: std::path::PathBuf, ) -> anyhow::Result<()> { // Heartbeat — log that the bot is alive every hour scheduler @@ -41,5 +43,22 @@ pub async fn register_builtin_tasks( .await?; } + // Weekly user model update + { + let memory_clone = _memory.clone(); + let llm_clone = llm.clone(); + let model_path = user_model_path; + scheduler + .add_cron_job(&user_model_cron, "weekly-user-model-update", move || { + let store = memory_clone.clone(); + let llm = llm_clone.clone(); + let path = model_path.clone(); + Box::pin(async move { + crate::learning::update_user_model(&llm, &store, &path).await; + }) + }) + .await?; + } + Ok(()) } diff --git a/src/tools.rs b/src/tools.rs index 08add2b..1297e31 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -182,6 +182,70 @@ pub fn builtin_tool_definitions() -> Vec { }), }, }, + ToolDefinition { + tool_type: "function".to_string(), + function: FunctionDefinition { + name: "try_new_tech".to_string(), + description: "Run a sandboxed experiment with a new technology or approach. Creates a temp project in sandbox/experiments/, writes the code, runs cargo check/test (Rust) or node (JS), and returns results. On success, may auto-generate a skill.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "technology": { + "type": "string", + "description": "Name/description of the technology being tested (e.g. 'serde flatten', 'tokio select')" + }, + "experiment_code": { + "type": "string", + "description": "The source code for the experiment" + }, + "language": { + "type": "string", + "enum": ["rust", "javascript"], + "description": "Programming language for the experiment (default: rust)" + } + }, + "required": ["technology", "experiment_code"] + }), + }, + }, + ToolDefinition { + tool_type: "function".to_string(), + function: FunctionDefinition { + name: "self_update_to_branch".to_string(), + description: "Update the bot to a specific git branch and rebuild. Runs: git fetch, git checkout , git pull, cargo build --release. Use for development without manual SSH. The bot should be restarted after a successful build.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "branch": { + "type": "string", + "description": "Git branch to update to (default: 'main')" + } + }, + "required": [] + }), + }, + }, + ToolDefinition { + tool_type: "function".to_string(), + function: FunctionDefinition { + name: "patch_skill".to_string(), + description: "Patch an existing skill's SKILL.md by appending content or replacing it entirely. Creates a .bak backup. If the patch starts with '---' (frontmatter), it replaces the whole file; otherwise it appends to the existing content. Reloads skills after patching.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "skill_name": { + "type": "string", + "description": "Name of the skill to patch (e.g. 'code-interpreter')" + }, + "patch_content": { + "type": "string", + "description": "Content to append (or full replacement if it starts with ---)" + } + }, + "required": ["skill_name", "patch_content"] + }), + }, + }, ] }