From cecd714c595c279ed732b564bc3a177eebe876d2 Mon Sep 17 00:00:00 2001 From: dahai9 Date: Thu, 26 Mar 2026 17:01:22 +0800 Subject: [PATCH 1/9] refactor: simplify state management with strum and optimize README --- Cargo.toml | 17 ++- README.md | 147 +++++++----------------- examples/basic.rs | 14 +-- examples/pocketflow-rs-rag/Cargo.toml | 1 + examples/pocketflow-rs-rag/src/state.rs | 30 +---- examples/text2sql/src/flow.rs | 22 +--- src/flow.rs | 19 +-- src/node.rs | 33 +++--- 8 files changed, 84 insertions(+), 199 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2ca1114..536e696 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,10 +15,10 @@ name = "basic" path = "examples/basic.rs" [workspace] -members = [ - "examples/pocketflow-rs-rag", - "examples/text2sql" -] +members = ["examples/pocketflow-rs-rag", "examples/text2sql"] + +[workspace.dependencies] +strum = "0.26" [dependencies] anyhow = "1.0" @@ -29,9 +29,10 @@ serde_json = "1.0" thiserror = "1.0" tracing = "0.1" rand = "0.8" -openai_api_rust = { version = "0.1.9", optional = true} +openai_api_rust = { version = "0.1.9", optional = true } regex = "1.11.1" -qdrant-client = {version = "1.14.0", optional = true} +strum = { version = "0.26", features = ["derive"] } +qdrant-client = { version = "1.14.0", optional = true } reqwest = { version = "0.12", features = ["json"], optional = true } [features] @@ -39,6 +40,4 @@ openai = ["dep:openai_api_rust"] websearch = ["dep:reqwest"] qdrant = ["dep:qdrant-client"] debug = [] -default = [ - "openai", -] \ No newline at end of file +default = ["openai"] diff --git a/README.md b/README.md index 2cb1c76..c352579 100644 --- a/README.md +++ b/README.md @@ -4,34 +4,39 @@ A Rust implementation of [PocketFlow](https://github.com/The-Pocket/PocketFlow), a minimalist flow-based programming framework. -📋 [Get started quickly with our template →](#template) +🚀 [Get started quickly with our template →](#template) -## Features +## ✨ Features -- Type-safe state transitions using enums -- Macro-based flow construction -- Async node execution and post-processing -- Batch flow support -- Custom state management -- Extensible node system +- 🦀 **Type-safe:** State transitions using Rust enums +- 🏗️ **Macro-based:** Flow construction using `build_flow!` and `build_batch_flow!` +- ⚡ **Async first:** Non-blocking node execution and post-processing +- 📦 **Batch support:** High-performance processing of multiple contexts +- 🧩 **Extensible:** Custom state management and node systems +- 🛠️ **Utility-rich:** Optional integrations for OpenAI, Qdrant, and web search -## Quick Start +## 🚀 Quick Start ### 0. Setup -```bash -cargo add pocketflow_rs +```toml +[dependencies] +pocketflow_rs = "0.1.0" +strum = { version = "0.26", features = ["derive"] } ``` ### 1. Define Custom States ```rust use pocketflow_rs::ProcessState; +use strum::Display; -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] pub enum MyState { Success, Failure, + #[default] Default, } @@ -39,19 +44,6 @@ impl ProcessState for MyState { fn is_default(&self) -> bool { matches!(self, MyState::Default) } - fn to_condition(&self) -> String { - match self { - MyState::Success => "success".to_string(), - MyState::Failure => "failure".to_string(), - MyState::Default => "default".to_string(), - } - } -} - -impl Default for MyState { - fn default() -> Self { - MyState::Default - } } ``` @@ -79,12 +71,12 @@ impl Node for MyNode { result: &Result, ) -> Result> { // Your post-processing logic here - Ok(ProcessResult::new(MyState::Success, "success".to_string())) + Ok(ProcessResult::new(MyState::Success, "success")) } } ``` -### 3. Build Flows +### 3. Build & Run Flows ```rust use pocketflow_rs::{build_flow, Context}; @@ -104,7 +96,11 @@ let context = Context::new(); let result = flow.run(context).await?; ``` -### 4. Batch Processing +## 🏗️ Advanced Usage + +### Batch Processing + +Build high-throughput flows for parallel processing: ```rust use pocketflow_rs::build_batch_flow; @@ -119,95 +115,36 @@ let batch_flow = build_batch_flow!( ); let contexts = vec![Context::new(); 10]; -batch_flow.run_batch(contexts).await?; +let results = batch_flow.run_batch(contexts).await?; ``` -## Advanced Usage - -### Custom State Management - -Define your own states to control flow transitions: +## 🛠️ Available Features -```rust -#[derive(Debug, Clone, PartialEq)] -pub enum WorkflowState { - Initialized, - Processing, - Completed, - Error, - Default, -} - -impl ProcessState for WorkflowState { - fn is_default(&self) -> bool { - matches!(self, WorkflowState::Default) - } - fn to_condition(&self) -> String { - match self { - WorkflowState::Initialized => "initialized".to_string(), - WorkflowState::Processing => "processing".to_string(), - WorkflowState::Completed => "completed".to_string(), - WorkflowState::Error => "error".to_string(), - WorkflowState::Default => "default".to_string(), - } - } -} -``` +Customize `pocketflow_rs` by enabling the features you need in your `Cargo.toml`: -### Complex Flow Construction - -Build complex workflows with multiple nodes and state transitions: - -```rust -let flow = build_flow!( - start: ("start", node1), - nodes: [ - ("process", node2), - ("validate", node3), - ("complete", node4) - ], - edges: [ - ("start", "process", WorkflowState::Initialized), - ("process", "validate", WorkflowState::Processing), - ("validate", "process", WorkflowState::Error), - ("validate", "complete", WorkflowState::Completed) - ] -); -``` - -## Available Features - -The following features are available: (feature for [utility_function](https://the-pocket.github.io/PocketFlow/utility_function/)) - -- `openai` (default): Enable OpenAI API integration for LLM capabilities -- `websearch`: Enable web search functionality using Google Custom Search API -- `qdrant`: Enable vector database integration using Qdrant -- `debug`: Enable additional debug logging and information - -To use specific features, add them to your `Cargo.toml`: +| Feature | Description | +|---------|-------------| +| `openai` (default) | OpenAI API integration for LLM capabilities | +| `websearch` | Google Custom Search API integration | +| `qdrant` | Vector database integration using Qdrant | +| `debug` | Enhanced logging and visualization tools | +Example: ```toml -[dependencies] -pocketflow_rs = { version = "0.1.0", features = ["openai", "websearch"] } -``` - -Or use them in the command line: - -```bash -cargo add pocketflow_rs --features "openai websearch" +pocketflow_rs = { version = "0.1.0", features = ["openai", "qdrant"] } ``` -## Examples +## 📂 Examples -Check out the `examples/` directory for more detailed examples: +Check out the `examples/` directory for detailed implementations: -- basic.rs: Basic flow with custom states -- text2sql: Text-to-SQL workflow example -- [pocketflow-rs-rag](./examples/pocketflow-rs-rag/README.md): Retrieval-Augmented Generation (RAG) workflow example +- 🟢 [**basic.rs**](./examples/basic.rs): Basic flow with custom states +- 🗃️ [**text2sql**](./examples/text2sql/): Text-to-SQL workflow using OpenAI +- 🔍 [**pocketflow-rs-rag**](./examples/pocketflow-rs-rag/): Retrieval-Augmented Generation (RAG) system -## Template +## 📋 Template -Fork the [PocketFlow-Template-Rust](https://github.com/The-Pocket/PocketFlow-Template-Rust) repository and use it as a template for your own project. +Don't start from scratch! Use the [PocketFlow-Template-Rust](https://github.com/The-Pocket/PocketFlow-Template-Rust) to kickstart your project. ## License diff --git a/examples/basic.rs b/examples/basic.rs index 421763c..2973cb2 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -2,8 +2,9 @@ use anyhow::Result; use pocketflow_rs::{Context, Node, ProcessResult, ProcessState, build_flow}; use rand::Rng; use serde_json::Value; - -#[derive(Debug, Clone, PartialEq, Default)] +use strum::Display; +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] enum NumberState { Small, Medium, @@ -16,15 +17,6 @@ impl ProcessState for NumberState { fn is_default(&self) -> bool { matches!(self, NumberState::Default) } - - fn to_condition(&self) -> String { - match self { - NumberState::Small => "small".to_string(), - NumberState::Medium => "medium".to_string(), - NumberState::Large => "large".to_string(), - NumberState::Default => "default".to_string(), - } - } } // A simple node that prints a message diff --git a/examples/pocketflow-rs-rag/Cargo.toml b/examples/pocketflow-rs-rag/Cargo.toml index ad3117d..eab0970 100644 --- a/examples/pocketflow-rs-rag/Cargo.toml +++ b/examples/pocketflow-rs-rag/Cargo.toml @@ -18,6 +18,7 @@ reqwest = { version = "0.12.15", features = ["json"] } uuid = { version = "1.16.0", features = ["v4"] } qdrant-client = "1.14.0" termimad = "0.31.3" +strum ={ workspace = true } [dev-dependencies] tempfile = "3.8" diff --git a/examples/pocketflow-rs-rag/src/state.rs b/examples/pocketflow-rs-rag/src/state.rs index a266e5f..71f27f4 100644 --- a/examples/pocketflow-rs-rag/src/state.rs +++ b/examples/pocketflow-rs-rag/src/state.rs @@ -1,6 +1,8 @@ use pocketflow_rs::ProcessState; +use strum::Display; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Display)] +#[strum(serialize_all = "snake_case")] pub enum RagState { // Offline states FileLoadedError, @@ -29,32 +31,6 @@ impl ProcessState for RagState { fn is_default(&self) -> bool { matches!(self, RagState::Default) } - - fn to_condition(&self) -> String { - match self { - // Offline states - RagState::FileLoadedError => "file_loaded_error".to_string(), - RagState::DocumentsLoaded => "documents_loaded".to_string(), - RagState::DocumentsChunked => "documents_chunked".to_string(), - RagState::ChunksEmbedded => "chunks_embedded".to_string(), - RagState::IndexCreated => "index_created".to_string(), - // Offline error states - RagState::DocumentLoadError => "document_load_error".to_string(), - RagState::ChunkingError => "chunking_error".to_string(), - RagState::EmbeddingError => "embedding_error".to_string(), - RagState::IndexCreationError => "index_creation_error".to_string(), - // Online states - RagState::QueryEmbedded => "query_embedded".to_string(), - RagState::DocumentsRetrieved => "documents_retrieved".to_string(), - RagState::AnswerGenerated => "answer_generated".to_string(), - // Online error states - RagState::QueryEmbeddingError => "query_embedding_error".to_string(), - RagState::RetrievalError => "retrieval_error".to_string(), - RagState::GenerationError => "generation_error".to_string(), - RagState::Default => "default".to_string(), - RagState::QueryRewriteError => "query_rewrite_error".to_string(), - } - } } impl Default for RagState { diff --git a/examples/text2sql/src/flow.rs b/examples/text2sql/src/flow.rs index 9eb651f..b8b40ae 100644 --- a/examples/text2sql/src/flow.rs +++ b/examples/text2sql/src/flow.rs @@ -1,3 +1,5 @@ +use std::default; + use anyhow::{Context as AnyhowContext, Result}; use async_trait::async_trait; use chrono::NaiveDate; @@ -7,13 +9,16 @@ use openai_api_rust::chat::*; use openai_api_rust::*; use pocketflow_rs::{Context, Node, ProcessResult, ProcessState}; use serde_json::{Value, json}; +use strum::Display; use tracing::{error, info}; -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Display)] +#[strum(serialize_all = "snake_case")] pub enum SqlExecutorState { SchemaRetrieved, SqlGenerated, SqlExecuted, + #[default] Default, } @@ -21,21 +26,6 @@ impl ProcessState for SqlExecutorState { fn is_default(&self) -> bool { matches!(self, SqlExecutorState::Default) } - - fn to_condition(&self) -> String { - match self { - SqlExecutorState::SchemaRetrieved => "schema_retrieved".to_string(), - SqlExecutorState::SqlGenerated => "sql_generated".to_string(), - SqlExecutorState::SqlExecuted => "sql_executed".to_string(), - SqlExecutorState::Default => "default".to_string(), - } - } -} - -impl Default for SqlExecutorState { - fn default() -> Self { - SqlExecutorState::Default - } } #[derive(Debug, thiserror::Error)] diff --git a/src/flow.rs b/src/flow.rs index 63f0aa2..bec5465 100644 --- a/src/flow.rs +++ b/src/flow.rs @@ -159,13 +159,13 @@ macro_rules! build_flow { )* // Handle edges appropriately $( - build_flow!(@edge g, $edge); + build_flow!(@edge_process g, $edge); )* g }}; - (@edge $g:expr, ($from:expr, $to:expr, $condition:expr)) => { + (@edge_process $g:expr, ($from:expr, $to:expr, $condition:expr)) => { $g.add_edge($from, $to, $condition); }; } @@ -204,7 +204,7 @@ macro_rules! build_batch_flow { )* // Handle edges appropriately $( - build_flow!(@edge g.flow, $edge); + build_flow!(@edge_process g.flow, $edge); )* g }}; @@ -216,10 +216,11 @@ mod tests { use crate::node::{Node, ProcessResult, ProcessState}; use async_trait::async_trait; use serde_json::json; + use strum::Display; - #[derive(Debug, Clone, PartialEq)] + #[derive(Debug, Clone, PartialEq, Default, Display)] + #[strum(serialize_all = "snake_case")] #[allow(dead_code)] - #[derive(Default)] enum CustomState { Success, Failure, @@ -231,14 +232,6 @@ mod tests { fn is_default(&self) -> bool { matches!(self, CustomState::Default) } - - fn to_condition(&self) -> String { - match self { - CustomState::Success => "success".to_string(), - CustomState::Failure => "failure".to_string(), - CustomState::Default => "default".to_string(), - } - } } struct TestNode { diff --git a/src/node.rs b/src/node.rs index 72974ea..d14b49d 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,15 +1,20 @@ use crate::{Params, context::Context}; use anyhow::Result; use async_trait::async_trait; -use std::collections::HashMap; -use std::sync::Arc; +use strum::Display; +// use std::collections::HashMap; +// use std::sync::Arc; -pub trait ProcessState: Send + Sync { +pub trait ProcessState: Send + Sync + std::fmt::Display { fn is_default(&self) -> bool; - fn to_condition(&self) -> String; + + fn to_condition(&self) -> String { + self.to_string() + } } -#[derive(Debug, Clone, PartialEq, Default)] +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] pub enum BaseState { Success, Failure, @@ -21,14 +26,6 @@ impl ProcessState for BaseState { fn is_default(&self) -> bool { matches!(self, BaseState::Default) } - - fn to_condition(&self) -> String { - match self { - BaseState::Success => "success".to_string(), - BaseState::Failure => "failure".to_string(), - BaseState::Default => "default".to_string(), - } - } } #[derive(Debug, Clone, PartialEq)] @@ -87,20 +84,20 @@ pub trait BaseNodeTrait: Node {} #[allow(dead_code)] pub struct BaseNode { params: Params, - next_nodes: HashMap>, + // next_nodes: HashMap>, } impl BaseNode { pub fn new(params: Params) -> Self { Self { params, - next_nodes: HashMap::new(), + // next_nodes: HashMap::new(), } } - pub fn add_next(&mut self, action: String, node: Arc) { - self.next_nodes.insert(action, node); - } + // pub fn add_next(&mut self, action: String, node: Arc) { + // self.next_nodes.insert(action, node); + // } } #[async_trait] From e1a99f715216788af0aa245eeafdfa25b726b030 Mon Sep 17 00:00:00 2001 From: dahai9 Date: Fri, 27 Mar 2026 11:40:58 +0800 Subject: [PATCH 2/9] docs: rewrite AGENTS.md with Rust-focused agentic coding guidance --- AGENTS.md | 215 +++++++++++++++++++++++++++++++++++++++++++++++++ docs/design.md | 78 ++++++++++++++++++ 2 files changed, 293 insertions(+) create mode 100644 AGENTS.md create mode 100644 docs/design.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..e73695e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,215 @@ +--- +layout: default +title: "Agentic Coding" +--- + +# Agentic Coding: Humans Design, Agents code! + +> If you are an AI agent building apps with PocketFlow-Rust, read this guide carefully. Start small, design at a high level first (`docs/design.md`), then implement and verify. +{: .warning } + +## Agentic Coding Steps + +Agentic coding should be a collaboration between human system design and AI implementation. + +| Steps | Human | AI | Comment | +|:--|:--:|:--:|:--| +| 1. Requirements | ★★★ High | ★☆☆ Low | Humans define the problem and success criteria. | +| 2. Flow | ★★☆ Medium | ★★☆ Medium | Humans define the orchestration; AI fills in details. | +| 3. Utilities | ★★☆ Medium | ★★☆ Medium | Humans provide external APIs; AI helps implement wrappers. | +| 4. Data | ★☆☆ Low | ★★★ High | AI proposes the schema; humans verify it matches the app. | +| 5. Node | ★☆☆ Low | ★★★ High | AI designs nodes around the flow and shared store. | +| 6. Implementation | ★☆☆ Low | ★★★ High | AI implements the flow and nodes from the design. | +| 7. Optimization | ★★☆ Medium | ★★☆ Medium | Iterate on prompts, data shape, and flow structure. | +| 8. Reliability | ★☆☆ Low | ★★★ High | Add validation, retries, logging, and tests. | + +1. **Requirements**: Clarify the user problem, not just the feature list. + - Good for: repetitive tasks, structured transformations, workflow automation, RAG, agent loops. + - Not good for: vague goals without measurable outputs or unstable business decisions. + - Keep it user-centric and small at first. + +2. **Flow Design**: Define the graph at a high level. + - Pick a pattern if it fits: [Agent](./design_pattern/agent.md), [Workflow](./design_pattern/workflow.md), [RAG](./design_pattern/rag.md), [Map Reduce](./design_pattern/mapreduce.md). + - For each node, write a one-line purpose and its next action conditions. + - Draw the flow in mermaid. + - Use the Rust API as source of truth: `Node`, `Flow`, `BatchFlow`, `ProcessState`. + - Example: + ```mermaid + flowchart LR + start[Load Input] --> process[Process] + process --> finish[Finish] + ``` + - If you cannot describe the flow manually, do not automate it yet. + +3. **Utilities**: Identify required external I/O helpers. + - Think of utilities as the body of the agent: file I/O, web requests, LLM calls, DB access, embeddings. + - Keep LLM tasks inside nodes or utilities, but do not confuse them with orchestration. + - Put reusable wrappers in `src/utils/*.rs` and add a small test when practical. + - Prefer returning `anyhow::Result` and keep wrappers narrow and deterministic. + - Example utility shape (real wrapper from `src/utils/llm_wrapper.rs`): + ```rust + use async_trait::async_trait; + + #[async_trait] + pub trait LLMWrapper { + async fn generate(&self, prompt: &str) -> anyhow::Result; + } + + pub struct OpenAIClient { + api_key: String, + model: String, + endpoint: String, + } + + impl OpenAIClient { + pub fn new(api_key: String, model: String, endpoint: String) -> Self { + Self { api_key, model, endpoint } + } + } + + #[async_trait] + impl LLMWrapper for OpenAIClient { + async fn generate(&self, prompt: &str) -> anyhow::Result { + // Use openai_api_rust or reqwest to call the API + todo!() + } + } + ``` + +4. **Data Design**: Design the shared store before coding the nodes. + - The shared store is `Context`. + - Use `Context::set()` / `Context::get()` for shared data. + - Use `metadata` for auxiliary data that should not be treated as primary results. + - Keep keys simple and avoid redundancy. + - Example: + ```rust + use pocketflow_rs::Context; + use serde_json::json; + + let mut context = Context::new(); + context.set("input", json!("hello")); + context.set_metadata("source", json!("user")); + ``` + +5. **Node Design**: Plan each node’s role and state transitions. + - `prepare(&mut context)`: optional, read from `Context` and prepare inputs. + - `execute(&context)`: do compute or remote calls; keep it idempotent when possible. + - `post_process(&mut context, &result)`: write outputs back to `Context` and return `ProcessResult`. + - Define a custom `State` enum implementing `ProcessState` for branching; use `BaseState` when simple. + - Example node shape (from `examples/basic.rs`): + ```rust + use anyhow::Result; + use async_trait::async_trait; + use pocketflow_rs::{Context, Node, ProcessResult, ProcessState}; + use serde_json::Value; + use strum::Display; + + #[derive(Debug, Clone, PartialEq, Default, Display)] + #[strum(serialize_all = "snake_case")] + enum MyState { + #[default] + Default, + Success, + } + + impl ProcessState for MyState { + fn is_default(&self) -> bool { + matches!(self, MyState::Default) + } + } + + struct MyNode; + + #[async_trait] + impl Node for MyNode { + type State = MyState; + + async fn execute(&self, context: &Context) -> Result { + let input = context.get("input").cloned().unwrap_or(Value::Null); + Ok(input) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + match result { + Ok(value) => { + context.set("output", value.clone()); + Ok(ProcessResult::new(MyState::Success, "done".to_string())) + } + Err(e) => { + context.set("error", Value::String(e.to_string())); + Ok(ProcessResult::new(MyState::Default, e.to_string())) + } + } + } + } + ``` + +6. **Implementation**: Build the initial nodes and flows. + - Keep the first pass simple. + - Use `build_flow!` and `build_batch_flow!` instead of hand-wiring infrastructure. + - Example flow assembly: + ```rust + use pocketflow_rs::{build_flow, Flow, BaseState}; + + pub fn create_flow() -> Flow { + build_flow!( + start: ("get_input", GetInputNode), + nodes: [ + ("process", ProcessNode), + ("output", OutputNode) + ], + edges: [ + ("get_input", "process", BaseState::Default), + ("process", "output", BaseState::Default) + ] + ) + } + ``` + - Add logging via `tracing` where it helps debugging. + - Prefer small, composable nodes over large monoliths. + +7. **Optimization**: Improve after the first working version. + - Refine the flow when the bottleneck is logic or structure. + - Refine prompts and context when the bottleneck is model behavior. + - Refine utilities when the bottleneck is I/O or integration. + +8. **Reliability**: Make failures visible and recoverable. + - Validate results in `execute` or `post_process`. + - Use the framework’s retry behavior where available. + - Add tests for utility wrappers and critical node transitions. + - Log failures and important decisions. + +## Example Rust Project Layout + +``` +my_project/ +├── Cargo.toml +├── src/ +│ ├── main.rs +│ ├── flow.rs +│ ├── nodes.rs +│ └── utils/ +│ ├── mod.rs +│ ├── llm_wrapper.rs +│ └── web_search.rs +└── docs/ + └── design.md +``` + +- **`Cargo.toml`**: add `pocketflow_rs`, `serde_json`, `anyhow`, `async-trait`, `tokio`, and `strum` as dependencies. +- **`docs/design.md`**: keep it high-level and Rust-oriented; do not copy Python pseudocode. +- **`src/utils/`**: one file per reusable integration is a good default. +- **`src/nodes.rs`**: node definitions should stay focused and readable. +- **`src/flow.rs`**: assemble the flow graph and state transitions. +- **State enums**: use `strum::Display` with `#[strum(serialize_all = "snake_case")]` for automatic state-to-string conversion. + +## Before Finishing + +- Check names, types, and examples against `src/lib.rs`, `src/node.rs`, and `src/flow.rs`. +- Remove any Python-only syntax or nonexistent APIs. +- Keep examples runnable or close to runnable against the current Rust API. +- Ensure state enums implement `ProcessState` and use `strum::Display` for edge matching. diff --git a/docs/design.md b/docs/design.md new file mode 100644 index 0000000..8838469 --- /dev/null +++ b/docs/design.md @@ -0,0 +1,78 @@ +# Design Doc: Your Project Name + +> Please DON'T remove notes for AI + +## Requirements + +> Notes for AI: Keep it simple and clear. +> If the requirements are abstract, write concrete user stories + + +## Flow Design + +> Notes for AI: +> 1. Consider the design patterns of agent, map-reduce, rag, and workflow. Apply them if they fit. +> 2. Present a concise, high-level description of the workflow. + +### Applicable Design Pattern: + +1. Map the file summary into chunks, then reduce these chunks into a final summary. +2. Agentic file finder + - *Context*: The entire summary of the file + - *Action*: Find the file + +### Flow high-level Design: + +1. **First Node**: This node is for ... +2. **Second Node**: This node is for ... +3. **Third Node**: This node is for ... + +```mermaid +flowchart TD + firstNode[First Node] --> secondNode[Second Node] + secondNode --> thirdNode[Third Node] +``` +## Utility Functions + +> Notes for AI: +> 1. Understand the utility function definition thoroughly by reviewing the doc. +> 2. Include only the necessary utility functions, based on nodes in the flow. + +1. **Call LLM** (`utils/call_llm.py`) + - *Input*: prompt (str) + - *Output*: response (str) + - Generally used by most nodes for LLM tasks + +2. **Embedding** (`utils/get_embedding.py`) + - *Input*: str + - *Output*: a vector of 3072 floats + - Used by the second node to embed text + +## Node Design + +### Shared Store + +> Notes for AI: Try to minimize data redundancy + +The shared store structure is organized as follows: + +```python +shared = { + "key": "value" +} +``` + +### Node Steps + +> Notes for AI: Carefully decide whether to use Batch/Async Node/Flow. + +1. First Node + - *Purpose*: Provide a short explanation of the node’s function + - *Type*: Decide between Regular, Batch, or Async + - *Steps*: + - *prep*: Read "key" from the shared store + - *exec*: Call the utility function + - *post*: Write "key" to the shared store + +2. Second Node + ... From 9b534023856c35d43a32687c5db547f31c19df08 Mon Sep 17 00:00:00 2001 From: dahai9 Date: Fri, 27 Mar 2026 13:40:39 +0800 Subject: [PATCH 3/9] chore: initial pi-mono refactor working implementation before compaction --- .gitignore | 4 +- Cargo.toml | 11 +- src/bin/pi.rs | 382 +++++++++++++++++++++++++++++++++++ src/utils/mod.rs | 3 + src/utils/pi_llm.rs | 50 +++++ src/utils/session_manager.rs | 68 +++++++ src/utils/tools/mod.rs | 47 +++++ xxx.md | 70 +++++++ 8 files changed, 632 insertions(+), 3 deletions(-) create mode 100644 src/bin/pi.rs create mode 100644 src/utils/pi_llm.rs create mode 100644 src/utils/session_manager.rs create mode 100644 src/utils/tools/mod.rs create mode 100644 xxx.md diff --git a/.gitignore b/.gitignore index ab099f4..4f9030e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,6 @@ target .env Cargo.lock -.vscode/ \ No newline at end of file +.vscode/ +.pi/ +test_dir/ diff --git a/Cargo.toml b/Cargo.toml index 536e696..e259317 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,10 @@ license = "MIT" name = "pocketflow_rs" path = "src/lib.rs" +[[bin]] +name = "pi" +path = "src/bin/pi.rs" + [[example]] name = "basic" path = "examples/basic.rs" @@ -33,11 +37,14 @@ openai_api_rust = { version = "0.1.9", optional = true } regex = "1.11.1" strum = { version = "0.26", features = ["derive"] } qdrant-client = { version = "1.14.0", optional = true } -reqwest = { version = "0.12", features = ["json"], optional = true } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +clap = { version = "4.4", features = ["derive"] } +directories = "5.0" +uuid = { version = "1.8", features = ["v4", "fast-rng"] } [features] openai = ["dep:openai_api_rust"] -websearch = ["dep:reqwest"] +websearch = [] qdrant = ["dep:qdrant-client"] debug = [] default = ["openai"] diff --git a/src/bin/pi.rs b/src/bin/pi.rs new file mode 100644 index 0000000..ba708d1 --- /dev/null +++ b/src/bin/pi.rs @@ -0,0 +1,382 @@ +use anyhow::Result; +use clap::Parser; +use pocketflow_rs::{build_flow, Context, Flow, Node, ProcessResult, ProcessState}; +use serde_json::{json, Value}; +use std::io::{self, Write}; +use std::sync::Arc; +use strum::Display; +use uuid::Uuid; +use pocketflow_rs::utils::pi_llm::PiLLM; +use pocketflow_rs::utils::session_manager::{AgentMessage, SessionManager}; +use pocketflow_rs::utils::tools::{execute_bash, read_file, write_file}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(short, long)] + interactive: bool, + + #[arg(short, long, default_value = "openai")] + provider: String, + + #[arg(short, long, default_value = "gpt-4o")] + model: String, +} + +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] +enum PiState { + #[default] + Default, + CallLLM, + ExecuteTool, + WaitForInput, + Finished, +} + +impl ProcessState for PiState { + fn is_default(&self) -> bool { + matches!(self, PiState::Default) + } +} + +// Global shared components between nodes +struct AppContext { + llm: PiLLM, + session_manager: SessionManager, +} + +struct InputNode { + app: Arc, +} + +#[async_trait::async_trait] +impl Node for InputNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + print!("> "); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let text = input.trim().to_string(); + + if text == "exit" || text == "quit" { + return Ok(json!({ "command": "exit" })); + } + + let id = Uuid::new_v4().to_string(); + + let msg = AgentMessage { + id: id.clone(), + parent_id: None, + role: "user".to_string(), + content: text, + name: None, + tool_calls: None, + tool_call_id: None, + }; + + // Persist immediately + self.app.session_manager.append_message(&msg)?; + + Ok(json!({ "message": msg })) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = result.as_ref().unwrap(); + if res.get("command").and_then(|v| v.as_str()) == Some("exit") { + return Ok(ProcessResult::new(PiState::Finished, "finished".to_string())); + } + + let msg_val = res.get("message").unwrap(); + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); + messages.as_array_mut().unwrap().push(msg_val.clone()); + context.set("messages", messages); + + Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + } +} + +struct LLMReasoningNode { + app: Arc, +} + +#[async_trait::async_trait] +impl Node for LLMReasoningNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let messages = context.get("messages").unwrap_or(&json!([])).clone(); + + let tools = json!([ + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read the contents of a file", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string" } + }, + "required": ["path"] + } + } + }, + { + "type": "function", + "function": { + "name": "write_file", + "description": "Write contents to a file", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "content": { "type": "string" } + }, + "required": ["path", "content"] + } + } + }, + { + "type": "function", + "function": { + "name": "bash", + "description": "Execute a bash command", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" } + }, + "required": ["command"] + } + } + } + ]); + + let mut openai_messages = Vec::new(); + // Convert AgentMessage to format expected by OpenAI + if let Some(arr) = messages.as_array() { + for m in arr { + let mut mapped = json!({ + "role": m["role"].as_str().unwrap(), + "content": m["content"].as_str().unwrap() + }); + if let Some(calls) = m.get("tool_calls") { + if !calls.is_null() { + mapped.as_object_mut().unwrap().insert("tool_calls".to_string(), calls.clone()); + } + } + if let Some(tid) = m.get("tool_call_id") { + if !tid.is_null() { + mapped.as_object_mut().unwrap().insert("tool_call_id".to_string(), tid.clone()); + } + } + openai_messages.push(mapped); + } + } + + let response = self.app.llm.chat_completion(openai_messages, tools).await?; + Ok(response) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = match result { + Ok(v) => v, + Err(e) => { + println!("\n[LLM Error]: {}\n", e); + return Ok(ProcessResult::new(PiState::WaitForInput, "error".to_string())); + } + }; + + // Ensure choice 0 exists + if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { + if let Some(choice) = choices.first() { + let msg = choice.get("message").unwrap(); + let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or(""); + let tool_calls = msg.get("tool_calls"); + + let agent_msg = AgentMessage { + id: Uuid::new_v4().to_string(), + parent_id: None, + role: "assistant".to_string(), + content: content.to_string(), + name: None, + tool_calls: tool_calls.cloned(), + tool_call_id: None, + }; + + // Persist + self.app.session_manager.append_message(&agent_msg)?; + + // Print + if !content.is_empty() { + println!("\nAssistant: {}\n", content); + } + + // Update context + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); + messages.as_array_mut().unwrap().push(serde_json::to_value(&agent_msg)?); + context.set("messages", messages); + + if let Some(tc) = tool_calls { + if !tc.is_null() && tc.as_array().map_or(false, |a| !a.is_empty()) { + return Ok(ProcessResult::new(PiState::ExecuteTool, "execute_tool".to_string())); + } + } + } + } + + Ok(ProcessResult::new(PiState::WaitForInput, "wait_for_input".to_string())) + } +} + +struct ToolExecutionNode { + app: Arc, +} + +#[async_trait::async_trait] +impl Node for ToolExecutionNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let messages = context.get("messages").unwrap().as_array().unwrap(); + let last_msg = messages.last().unwrap(); + + let mut tool_results = Vec::new(); + + if let Some(tool_calls) = last_msg.get("tool_calls").and_then(|tc| tc.as_array()) { + for call in tool_calls { + let id = call["id"].as_str().unwrap().to_string(); + let func = &call["function"]; + let name = func["name"].as_str().unwrap(); + let args_str = func["arguments"].as_str().unwrap(); + let args: Value = serde_json::from_str(args_str)?; + + println!("Executing tool: {} with args: {}", name, args_str); + + let output = match name { + "read_file" => { + let path = args["path"].as_str().unwrap(); + read_file(path) + } + "write_file" => { + let path = args["path"].as_str().unwrap(); + let content = args["content"].as_str().unwrap(); + write_file(path, content) + } + "bash" => { + let command = args["command"].as_str().unwrap(); + execute_bash(command, ".") + } + _ => format!("Unknown tool: {}", name), + }; + + let agent_msg = AgentMessage { + id: Uuid::new_v4().to_string(), + parent_id: None, + role: "tool".to_string(), + content: output, + name: Some(name.to_string()), + tool_calls: None, + tool_call_id: Some(id), + }; + + tool_results.push(agent_msg); + } + } + + Ok(serde_json::to_value(tool_results)?) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let tool_results: Vec = serde_json::from_value(result.as_ref().unwrap().clone())?; + + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); + + for msg in tool_results { + self.app.session_manager.append_message(&msg)?; + messages.as_array_mut().unwrap().push(serde_json::to_value(&msg)?); + } + + context.set("messages", messages); + + Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + + // Setup directory and SessionManager + let cwd = std::env::current_dir()?; + let session_manager = SessionManager::new(&cwd); + + // Load API Key + let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()); + let mut endpoint = std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| "https://api.openai.com/v1".to_string()); + if !endpoint.ends_with("/chat/completions") { + endpoint = format!("{}/chat/completions", endpoint.trim_end_matches('/')); + } + + let llm = PiLLM::new(api_key, args.model, endpoint); + + let app_context = Arc::new(AppContext { + llm, + session_manager, + }); + + let input_node = InputNode { app: app_context.clone() }; + let llm_node = LLMReasoningNode { app: app_context.clone() }; + let tool_node = ToolExecutionNode { app: app_context.clone() }; + + let flow = build_flow!( + start: ("input", input_node), + nodes: [ + ("llm", llm_node), + ("tool", tool_node) + ], + edges: [ + ("input", "llm", PiState::CallLLM), + ("llm", "tool", PiState::ExecuteTool), + ("llm", "input", PiState::WaitForInput), + ("tool", "llm", PiState::CallLLM) + // Implicit default stop for PiState::Finished + ] + ); + + let mut context = Context::new(); + + // Load history + let history = app_context.session_manager.load_history(None)?; + if !history.is_empty() { + println!("Loaded {} messages from history.", history.len()); + let val = serde_json::to_value(history)?; + context.set("messages", val); + } else { + context.set("messages", json!([])); + } + + println!("pi agent started. Type 'exit' to quit."); + + match flow.run(context).await { + Ok(_) => println!("Agent shutdown."), + Err(e) => eprintln!("Error running flow: {}", e), + } + + Ok(()) +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 82e3df8..f66d11b 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -4,3 +4,6 @@ pub mod text_chunking; pub mod vector_db; pub mod viz_debug; pub mod web_search; +pub mod pi_llm; +pub mod session_manager; +pub mod tools; diff --git a/src/utils/pi_llm.rs b/src/utils/pi_llm.rs new file mode 100644 index 0000000..a0b8e14 --- /dev/null +++ b/src/utils/pi_llm.rs @@ -0,0 +1,50 @@ +use anyhow::Result; +use reqwest::Client; +use serde_json::{json, Value}; +use tracing::info; + +pub struct PiLLM { + client: Client, + api_key: String, + model: String, + endpoint: String, +} + +impl PiLLM { + pub fn new(api_key: String, model: String, endpoint: String) -> Self { + Self { + client: Client::new(), + api_key, + model, + endpoint, + } + } + + pub async fn chat_completion(&self, messages: Vec, tools: Value) -> Result { + info!("Sending LLM request to {}", self.endpoint); + + let mut body = json!({ + "model": self.model, + "messages": messages, + }); + + if !tools.is_null() && tools.as_array().map(|a| !a.is_empty()).unwrap_or(false) { + body.as_object_mut().unwrap().insert("tools".to_string(), tools); + } + + let res = self.client.post(&self.endpoint) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + let status = res.status(); + let response_json: Value = res.json().await?; + if !status.is_success() { + return Err(anyhow::anyhow!("API request failed with status {}: {}", status, response_json)); + } + + Ok(response_json) + } +} diff --git a/src/utils/session_manager.rs b/src/utils/session_manager.rs new file mode 100644 index 0000000..222ccff --- /dev/null +++ b/src/utils/session_manager.rs @@ -0,0 +1,68 @@ +use serde::{Deserialize, Serialize}; +use std::fs::{File, OpenOptions}; +use std::io::{BufRead, BufReader, Write}; +use std::path::{Path, PathBuf}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentMessage { + pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_id: Option, + pub role: String, + pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +pub struct SessionManager { + log_path: PathBuf, +} + +impl SessionManager { + pub fn new(workspace: &Path) -> Self { + let mut log_path = workspace.to_path_buf(); + log_path.push(".pi"); + log_path.push("logs"); + if !log_path.exists() { + std::fs::create_dir_all(&log_path).unwrap_or_default(); + } + log_path.push("log.jsonl"); + Self { log_path } + } + + pub fn append_message(&self, message: &AgentMessage) -> anyhow::Result<()> { + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.log_path)?; + let json = serde_json::to_string(message)?; + writeln!(file, "{}", json)?; + Ok(()) + } + + pub fn load_history(&self, head_id: Option<&str>) -> anyhow::Result> { + if !self.log_path.exists() { + return Ok(Vec::new()); + } + + let file = File::open(&self.log_path)?; + let reader = BufReader::new(file); + + let mut messages = Vec::new(); + for line in reader.lines() { + if let Ok(l) = line { + if let Ok(msg) = serde_json::from_str::(&l) { + messages.push(msg); + } + } + } + + // Simple linear history for now. + // In a full implementation, we would rebuild the tree using parent_id up to head_id. + Ok(messages) + } +} diff --git a/src/utils/tools/mod.rs b/src/utils/tools/mod.rs new file mode 100644 index 0000000..95d29ee --- /dev/null +++ b/src/utils/tools/mod.rs @@ -0,0 +1,47 @@ +use anyhow::Result; +use std::fs; +use std::path::Path; +use std::process::Command; + +pub fn read_file(path: &str) -> String { + match fs::read_to_string(path) { + Ok(content) => content, + Err(e) => format!("Error reading file: {}", e), + } +} + +pub fn write_file(path: &str, content: &str) -> String { + let p = Path::new(path); + if let Some(parent) = p.parent() { + let _ = fs::create_dir_all(parent); + } + match fs::write(p, content) { + Ok(_) => format!("Successfully wrote to {}", path), + Err(e) => format!("Error writing file: {}", e), + } +} + +pub fn execute_bash(command: &str, cwd: &str) -> String { + let output = Command::new("bash") + .arg("-c") + .arg(command) + .current_dir(cwd) + .output(); + + match output { + Ok(out) => { + let mut result = String::from_utf8_lossy(&out.stdout).to_string(); + let stderr = String::from_utf8_lossy(&out.stderr).to_string(); + if !stderr.is_empty() { + result.push_str("\n--- STDERR ---\n"); + result.push_str(&stderr); + } + if result.is_empty() { + "Command executed successfully with no output.".to_string() + } else { + result + } + } + Err(e) => format!("Error executing command: {}", e), + } +} diff --git a/xxx.md b/xxx.md new file mode 100644 index 0000000..eb3a010 --- /dev/null +++ b/xxx.md @@ -0,0 +1,70 @@ +# 计划: 实现 pi-mono 会话压缩 (Compaction) 与 TOML 配置支持 + +## Context (背景与目标) +目前系统已打通核心交互循环和会话持久化。随着对话的增加,每次请求附带的历史上下文会不断变长,最终超出 LLM 模型的上下文窗口(Context Window)。 +我们需要引入会话压缩(Compaction)机制,并在 `pi` 启动时加载一个 `config.toml` 文件,用于分类配置是否启用自动压缩、每种模型的窗口大小约束、以及对应的大模型服务提供商 (Provider) 设定。 + +由于安全设定约束,**在计划模式下,系统严格禁止直接执行代码提交(`git commit`)等修改系统状态的操作。** 我将在获得您对本计划的批准进入执行模式后,**第一步便为您执行代码提交**,然后再进行后续的代码修改。 + +## Proposed Changes (架构与组件设计) + +### 1. 配置管理设计 (`src/config.rs`) +通过引入 `toml` 和 `serde` 读取全局或局部配置文件(如 `config.toml`)。设计配置类结构如下: + +```rust +#[derive(Debug, Deserialize)] +pub struct AppConfig { + pub general: GeneralConfig, + pub providers: HashMap, + pub models: HashMap, +} + +#[derive(Debug, Deserialize)] +pub struct GeneralConfig { + pub auto_compact: bool, // 是否开启自动压缩 +} + +#[derive(Debug, Deserialize)] +pub struct ProviderConfig { + pub api_base: String, + pub api_key_env: String, // 指定从哪个环境变量读取 Key,提高安全性 +} + +#[derive(Debug, Deserialize)] +pub struct ModelConfig { + pub provider: String, // 关联的 Provider 名称 + pub context_window: usize, // 模型的绝对最大窗口 (按 Token 数,粗略可用字数/4近似) + pub compact_threshold: usize, // 触发压缩的阈值 (如超过 80% 则压缩) +} +``` + +### 2. Append-Only 会话日志的压缩实现 (`src/utils/session_manager.rs`) +保留 `pi-mono` 的 Append-Only JSONL 特性,不直接修改旧日志文件。我们在 `AgentMessage` 中新增一个特殊字段: +- `clears_history: Option` + +当发生压缩时,系统将先前的对话交由 LLM 生成摘要,并写入一条 `role: "system"`, `content: "Previous conversation summary: ..."` 且附带 `clears_history: true` 的新记录。 +在重启应用并调用 `load_history` 恢复列表时,如果读到 `clears_history == Some(true)` 的消息,就将内存中的 `messages` 清空(或保留最初的系统设定),只保留这条 Summary 继续往下构建,完成完美的持久化无损截断。 + +### 3. PiLLM 动态路由支持 (`src/utils/pi_llm.rs`) +修改 `PiLLM` 内部逻辑: +不再硬编码读取 `OPENAI_API_KEY`,而是根据当前的选项,从 `AppConfig` 查找到对应的 `ModelConfig`,接着查找到 `ProviderConfig`,使用对应的 `api_base` 和对应的环境变量(`api_key_env`)进行鉴权与请求分发。 + +### 4. 压缩节点注入 (`src/bin/pi.rs`) +- 在 `LLMReasoningNode` 的 `execute` 开始前加入 Token 容量估算逻辑(如按字符串长度评估或引入 tiktoken 库)。 +- 若超过 `model.compact_threshold` 且 `auto_compact` 为 true,即刻在此流程内部或单独生成一个阻塞调用,请求模型生成摘要("Please summarize the history conversation concisely...")。 +- 成功获得摘要后,实例化一段 `clears_history=true` 的 `AgentMessage` 调用 `append_message` 持久化,重置 Context 中的 `messages`,接着再处理用户当下真实的发问。 + +### 5. `Cargo.toml` 依赖更新 +添加 `toml` crates 支持。 + +## Verification Plan (验证计划) +### Automated Tests +1. 编写对 `AppConfig` TOML 解析的单元测试。 +2. 针对包含 `clears_history: true` 的模拟 `.jsonl` 文件编写 `SessionManager::load_history` 测试,断言数组应当短路清空,仅保留之后的有效长度。 + +### Manual Verification +1. **代码提交流程验证**:执行模式开启后的第一件事就是运行 `git commit -am "chore: initial working implementation before compaction"`,证明我们遵守了指令。 +2. 启动代理程序加载携带多 `models` 和 `providers` 的 `config.toml`。 +3. 把某配置大模型的 `compact_threshold` 改为非常小的值(如 `20`)。 +4. 对话几次,让长度超过该极小阈值,观察控制台是否输出了 "[Auto Compacting History...]" 提示。 +5. 通过查看底层 `log.jsonl`,证实末尾产生了一条标志性的含有 `clears_history: true` 的 Json 行。使用 `cargo run` 重启程序,核实历史条数(`Loaded X messages`)是否明显变少(因为老信息已被抛弃截断)。 From a26b65187a8f5b9cee53c3ea05366af92738a7ec Mon Sep 17 00:00:00 2001 From: dahai9 Date: Fri, 27 Mar 2026 14:43:40 +0800 Subject: [PATCH 4/9] feat: add enhanced config system with paths and defaults for pocketflow-rs --- Cargo.toml | 1 + src/bin/pi.rs | 160 +++++++++++++++++++++++++++++++++-- src/utils/config.rs | 102 ++++++++++++++++++++++ src/utils/mod.rs | 1 + src/utils/session_manager.rs | 7 +- 5 files changed, 263 insertions(+), 8 deletions(-) create mode 100644 src/utils/config.rs diff --git a/Cargo.toml b/Cargo.toml index e259317..fec9591 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "rus clap = { version = "4.4", features = ["derive"] } directories = "5.0" uuid = { version = "1.8", features = ["v4", "fast-rng"] } +toml = "0.8" [features] openai = ["dep:openai_api_rust"] diff --git a/src/bin/pi.rs b/src/bin/pi.rs index ba708d1..aa55592 100644 --- a/src/bin/pi.rs +++ b/src/bin/pi.rs @@ -9,6 +9,7 @@ use uuid::Uuid; use pocketflow_rs::utils::pi_llm::PiLLM; use pocketflow_rs::utils::session_manager::{AgentMessage, SessionManager}; use pocketflow_rs::utils::tools::{execute_bash, read_file, write_file}; +use pocketflow_rs::utils::config::AppConfig; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -28,6 +29,8 @@ struct Args { enum PiState { #[default] Default, + CheckSize, + DoCompact, CallLLM, ExecuteTool, WaitForInput, @@ -44,6 +47,8 @@ impl ProcessState for PiState { struct AppContext { llm: PiLLM, session_manager: SessionManager, + config: AppConfig, + model_name: String, } struct InputNode { @@ -75,6 +80,7 @@ impl Node for InputNode { name: None, tool_calls: None, tool_call_id: None, + clears_history: None, }; // Persist immediately @@ -98,7 +104,7 @@ impl Node for InputNode { messages.as_array_mut().unwrap().push(msg_val.clone()); context.set("messages", messages); - Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + Ok(ProcessResult::new(PiState::CheckSize, "check_size".to_string())) } } @@ -106,6 +112,122 @@ struct LLMReasoningNode { app: Arc, } +struct CheckSizeNode { + app: Arc, +} + +#[async_trait::async_trait] +impl Node for CheckSizeNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let messages = context.get("messages").unwrap_or(&json!([])).clone(); + let msgs_str = serde_json::to_string(&messages).unwrap_or_default(); + let estimated_tokens = msgs_str.len() / 4; + + // Handle gracefully if config doesn't perfectly match model + let default_compact_thresh = 100000; + let threshold = self.app.config.models.get(&self.app.model_name) + .map(|m| m.compact_threshold) + .unwrap_or(default_compact_thresh); + + if self.app.config.general.auto_compact && estimated_tokens > threshold && messages.as_array().map(|a| a.len() > 3).unwrap_or(false) { + println!("\n[Auto Compacting History (est {} tokens > {})]...", estimated_tokens, threshold); + Ok(json!({ "needs_compact": true, "history_str": msgs_str })) + } else { + Ok(json!({ "needs_compact": false })) + } + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = result.as_ref().unwrap(); + if res.get("needs_compact").and_then(|v| v.as_bool()) == Some(true) { + context.set("history_to_compact", res.get("history_str").unwrap().clone()); + Ok(ProcessResult::new(PiState::DoCompact, "do_compact".to_string())) + } else { + Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + } + } +} + +struct DoCompactNode { + app: Arc, +} + +#[async_trait::async_trait] +impl Node for DoCompactNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let history_str = context.get("history_to_compact").unwrap().as_str().unwrap(); + + let summary_prompt = json!({ + "role": "user", + "content": format!("Summarize the entire conversation history concisely, retaining all tool outcomes and important context so it can be used to replace the history entirely:\n{}", history_str) + }); + + println!("Sending compaction request to LLM..."); + let mut retries = 0; + let max_retries = 3; + loop { + match self.app.llm.chat_completion(vec![summary_prompt.clone()], Value::Null).await { + Ok(summary_res) => return Ok(summary_res), + Err(e) => { + retries += 1; + if retries > max_retries { + return Err(e); + } + println!("[Compaction Failed]: {}. Retrying ({}/{}) in 2 seconds...", e, retries, max_retries); + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + } + } + } + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = match result { + Ok(v) => v, + Err(e) => { + println!("[Compaction Failed]: {}", e); + return Ok(ProcessResult::new(PiState::CallLLM, "call_llm_fallback".to_string())); + } + }; + + if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { + if let Some(choice) = choices.first() { + let summary_text = choice["message"]["content"].as_str().unwrap_or(""); + + let compact_msg = AgentMessage { + id: Uuid::new_v4().to_string(), + parent_id: None, + role: "system".to_string(), + content: format!("Previous conversation summary:\n{}", summary_text), + name: None, + tool_calls: None, + tool_call_id: None, + clears_history: Some(true), + }; + + self.app.session_manager.append_message(&compact_msg)?; + + let messages = json!([compact_msg]); + context.set("messages", messages); + println!("History compressed successfully."); + } + } + + Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + } +} + #[async_trait::async_trait] impl Node for LLMReasoningNode { type State = PiState; @@ -213,6 +335,7 @@ impl Node for LLMReasoningNode { name: None, tool_calls: tool_calls.cloned(), tool_call_id: None, + clears_history: None, }; // Persist @@ -289,6 +412,7 @@ impl Node for ToolExecutionNode { name: Some(name.to_string()), tool_calls: None, tool_call_id: Some(id), + clears_history: None, }; tool_results.push(agent_msg); @@ -314,7 +438,7 @@ impl Node for ToolExecutionNode { context.set("messages", messages); - Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + Ok(ProcessResult::new(PiState::CheckSize, "check_size".to_string())) } } @@ -325,36 +449,58 @@ async fn main() -> Result<()> { // Setup directory and SessionManager let cwd = std::env::current_dir()?; let session_manager = SessionManager::new(&cwd); + + // Load config + let config = AppConfig::load(&cwd)?; // Load API Key - let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()); - let mut endpoint = std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| "https://api.openai.com/v1".to_string()); + let (api_key, mut endpoint) = if let Some(model_conf) = config.models.get(&args.model) { + if let Some(provider_conf) = config.providers.get(&model_conf.provider) { + let key = std::env::var(&provider_conf.api_key_env).unwrap_or_else(|_| "dummy_key".to_string()); + (key, provider_conf.api_base.clone()) + } else { + (std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()), "https://api.openai.com/v1".to_string()) + } + } else { + println!("Warning: Model '{}' not found in config, falling back to openai env vars.", args.model); + (std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()), std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| "https://api.openai.com/v1".to_string())) + }; + if !endpoint.ends_with("/chat/completions") { endpoint = format!("{}/chat/completions", endpoint.trim_end_matches('/')); } - let llm = PiLLM::new(api_key, args.model, endpoint); + let llm = PiLLM::new(api_key, args.model.clone(), endpoint); let app_context = Arc::new(AppContext { llm, session_manager, + config, + model_name: args.model.clone(), }); let input_node = InputNode { app: app_context.clone() }; + let check_size_node = CheckSizeNode { app: app_context.clone() }; + let compact_node = DoCompactNode { app: app_context.clone() }; let llm_node = LLMReasoningNode { app: app_context.clone() }; let tool_node = ToolExecutionNode { app: app_context.clone() }; let flow = build_flow!( start: ("input", input_node), nodes: [ + ("check_size", check_size_node), + ("do_compact", compact_node), ("llm", llm_node), ("tool", tool_node) ], edges: [ - ("input", "llm", PiState::CallLLM), + ("input", "check_size", PiState::CheckSize), + ("check_size", "do_compact", PiState::DoCompact), + ("check_size", "llm", PiState::CallLLM), + ("do_compact", "llm", PiState::CallLLM), ("llm", "tool", PiState::ExecuteTool), ("llm", "input", PiState::WaitForInput), - ("tool", "llm", PiState::CallLLM) + ("tool", "check_size", PiState::CheckSize) // Implicit default stop for PiState::Finished ] ); diff --git a/src/utils/config.rs b/src/utils/config.rs new file mode 100644 index 0000000..7d1c1e4 --- /dev/null +++ b/src/utils/config.rs @@ -0,0 +1,102 @@ +use anyhow::{Context, Result}; +use serde::Deserialize; +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +#[derive(Debug, Deserialize, Clone)] +pub struct AppConfig { + #[serde(default)] + pub general: GeneralConfig, + pub providers: HashMap, + pub models: HashMap, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct GeneralConfig { + #[serde(default = "default_auto_compact")] + pub auto_compact: bool, +} + +impl Default for GeneralConfig { + fn default() -> Self { + Self { auto_compact: default_auto_compact() } + } +} + +fn default_auto_compact() -> bool { + true +} + +#[derive(Debug, Deserialize, Clone)] +pub struct ProviderConfig { + pub api_base: String, + pub api_key_env: String, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct ModelConfig { + pub provider: String, + pub context_window: usize, + pub compact_threshold: usize, +} + +impl AppConfig { + pub fn load>(workspace: P) -> Result { + let workspace = workspace.as_ref(); + let mut config_dir = workspace.to_path_buf(); + config_dir.push(".pi"); + + if !config_dir.exists() { + fs::create_dir_all(&config_dir)?; + } + + let mut config_path = config_dir.clone(); + config_path.push("config.toml"); + + if !config_path.exists() { + let default_toml = r#"[general] +auto_compact = true + +[providers.openai] +api_base = "https://api.openai.com/v1" +api_key_env = "OPENAI_API_KEY" + +[providers.cerebras] +api_base = "http://127.0.0.1:4000/v1" +api_key_env = "OPENAI_API_KEY" + +[models."gpt-4o"] +provider = "openai" +context_window = 128000 +compact_threshold = 100000 + +[models."cerebras/qwen-3-235b-a22b-instruct-2507"] +provider = "cerebras" +context_window = 8192 +compact_threshold = 6000 +"#; + fs::write(&config_path, default_toml.trim())?; + } + + let content = fs::read_to_string(&config_path) + .with_context(|| format!("Failed to read config file at {{:?}}", config_path))?; + + let config: AppConfig = toml::from_str(&content) + .with_context(|| "Failed to parse config.toml")?; + + Ok(config) + } + + pub fn config_dir>(workspace: P) -> std::path::PathBuf { + let mut path = workspace.as_ref().to_path_buf(); + path.push(".pi"); + path + } + + pub fn logs_dir>(workspace: P) -> std::path::PathBuf { + let mut path = Self::config_dir(workspace); + path.push("logs"); + path + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index f66d11b..7449038 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -7,3 +7,4 @@ pub mod web_search; pub mod pi_llm; pub mod session_manager; pub mod tools; +pub mod config; diff --git a/src/utils/session_manager.rs b/src/utils/session_manager.rs index 222ccff..830367a 100644 --- a/src/utils/session_manager.rs +++ b/src/utils/session_manager.rs @@ -16,6 +16,8 @@ pub struct AgentMessage { pub tool_calls: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub clears_history: Option, } pub struct SessionManager { @@ -44,7 +46,7 @@ impl SessionManager { Ok(()) } - pub fn load_history(&self, head_id: Option<&str>) -> anyhow::Result> { + pub fn load_history(&self, _head_id: Option<&str>) -> anyhow::Result> { if !self.log_path.exists() { return Ok(Vec::new()); } @@ -56,6 +58,9 @@ impl SessionManager { for line in reader.lines() { if let Ok(l) = line { if let Ok(msg) = serde_json::from_str::(&l) { + if msg.clears_history == Some(true) { + messages.clear(); + } messages.push(msg); } } From 0f651dcbd501f967e2dc69781c58ed2e3c5aa610 Mon Sep 17 00:00:00 2001 From: dahai9 Date: Fri, 27 Mar 2026 15:00:52 +0800 Subject: [PATCH 5/9] docs: add CONFIG_GUIDE.md for configuration explanation --- CONFIG_GUIDE.md | 72 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 CONFIG_GUIDE.md diff --git a/CONFIG_GUIDE.md b/CONFIG_GUIDE.md new file mode 100644 index 0000000..8ae9d67 --- /dev/null +++ b/CONFIG_GUIDE.md @@ -0,0 +1,72 @@ +# PocketFlow-RS 配置指南 + +本项目使用 `config.toml` 管理模型、API 提供商和通用设置。配置文件位于工作区下的 `./.pi/config.toml`。 + +## 文件位置 +- **默认路径**:`./.pi/config.toml` +- **日志路径**:`./.pi/logs/log.jsonl` + +> 提示:当你首次运行 `pi` 命令时,若配置不存在,会自动创建默认配置。 + +## 配置结构 + +```toml +[general] +auto_compact = true + +[providers.openai] +api_base = "https://api.openai.com/v1" +api_key_env = "OPENAI_API_KEY" + +[providers.cerebras] +api_base = "http://127.0.0.1:4000/v1" +api_key_env = "OPENAI_API_KEY" + +[models."gpt-4o"] +provider = "openai" +context_window = 128000 +compact_threshold = 100000 + +[models."cerebras/qwen-3-235b-a22b-instruct-2507"] +provider = "cerebras" +context_window = 8192 +compact_threshold = 6000 +``` + +### 字段说明 + +| 区域 | 字段 | 说明 | +|------|------|------| +| `general` | `auto_compact` | 是否开启自动历史压缩(摘要旧消息以节省上下文) | +| `providers` | `api_base` | 模型提供商的 API 地址 | +| `providers` | `api_key_env` | 读取 API 密钥的环境变量名 | +| `models` | `provider` | 该模型使用的提供商(需匹配上方 `providers` 中的键) | +| `models` | `context_window` | 模型最大上下文长度(token 数) | +| `models` | `compact_threshold` | 超过此长度时触发自动压缩 | + +## 自定义配置示例 + +你可以添加本地大模型支持,例如: + +```toml +[providers.localhost] +api_base = "http://127.0.0.1:8080/v1" +api_key_env = "LOCAL_API_KEY" + +[models."qwen:14b"] +provider = "localhost" +context_window = 32768 +compact_threshold = 24000 +``` + +然后运行: +```bash +pi --model qwen:14b --provider localhost +``` + +--- + +📌 提示:确保相应 `api_key_env` 环境变量已设置,例如: +```bash +export OPENAI_API_KEY=sk-... +``` From cd3f8a002edddb2207b0fe8bec6c2d737af4f686 Mon Sep 17 00:00:00 2001 From: dahai9 Date: Sun, 29 Mar 2026 09:38:08 +0800 Subject: [PATCH 6/9] feat: implement SubFlow Node and Mermaid visualization --- examples/basic.rs | 6 +++ src/bin/pi.rs | 10 +++-- src/flow.rs | 98 ++++++++++++++++++++++++++++++++++++++++++ src/node.rs | 61 +++++++++++++++++++++++++- src/utils/config.rs | 2 +- src/utils/tools/mod.rs | 2 +- 6 files changed, 172 insertions(+), 7 deletions(-) diff --git a/examples/basic.rs b/examples/basic.rs index 2973cb2..15052ba 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -157,6 +157,12 @@ async fn main() -> std::result::Result<(), Box> { // Create context let context = Context::new(); + // Export flow visualization + std::fs::create_dir_all("test_dir")?; + let mermaid = flow.to_mermaid(); + std::fs::write("test_dir/basic_flow.mmd", mermaid)?; + println!("Saved flow visualization to test_dir/basic_flow.mmd"); + // Run the flow println!("Starting flow execution..."); flow.run(context).await?; diff --git a/src/bin/pi.rs b/src/bin/pi.rs index aa55592..3b3aa98 100644 --- a/src/bin/pi.rs +++ b/src/bin/pi.rs @@ -1,6 +1,6 @@ use anyhow::Result; use clap::Parser; -use pocketflow_rs::{build_flow, Context, Flow, Node, ProcessResult, ProcessState}; +use pocketflow_rs::{build_flow, Context, Node, ProcessResult, ProcessState}; use serde_json::{json, Value}; use std::io::{self, Write}; use std::sync::Arc; @@ -59,7 +59,7 @@ struct InputNode { impl Node for InputNode { type State = PiState; - async fn execute(&self, context: &Context) -> Result { + async fn execute(&self, _context: &Context) -> Result { print!("> "); io::stdout().flush()?; let mut input = String::new(); @@ -516,7 +516,11 @@ async fn main() -> Result<()> { } else { context.set("messages", json!([])); } - + // Export flow visualization + std::fs::create_dir_all("test_dir")?; + let mermaid = flow.to_mermaid(); + std::fs::write("test_dir/pi_flow.mmd", mermaid)?; + println!("Saved flow visualization to test_dir/pi_flow.mmd"); println!("pi agent started. Type 'exit' to quit."); match flow.run(context).await { diff --git a/src/flow.rs b/src/flow.rs index bec5465..a04432f 100644 --- a/src/flow.rs +++ b/src/flow.rs @@ -92,6 +92,31 @@ impl Flow { Ok(context.get("result").unwrap_or(&Value::Null).clone()) } + + pub fn to_mermaid(&self) -> String { + let mut mermaid = String::from("flowchart TD\n"); + + // 声明所有节点 + let mut nodes: Vec<_> = self.nodes.keys().collect(); + nodes.sort(); // 确保输出稳定 + for node_name in nodes { + mermaid.push_str(&format!(" {}[{}]\n", node_name, node_name)); + } + + // 声明所有连线 + let mut edge_keys: Vec<_> = self.edges.keys().collect(); + edge_keys.sort(); // 确保输出稳定 + for from in edge_keys { + let edges = &self.edges[from]; + let mut sorted_edges = edges.clone(); + sorted_edges.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1))); + for (to, condition) in sorted_edges { + mermaid.push_str(&format!(" {} -->|{}| {}\n", from, condition, to)); + } + } + + mermaid + } } #[allow(dead_code)] @@ -125,6 +150,10 @@ impl BatchFlow { info!("Batch flow execution completed"); Ok(()) } + + pub fn to_mermaid(&self) -> String { + self.flow.to_mermaid() + } } #[macro_export] @@ -217,6 +246,7 @@ mod tests { use async_trait::async_trait; use serde_json::json; use strum::Display; + use crate::node::SubFlowNode; #[derive(Debug, Clone, PartialEq, Default, Display)] #[strum(serialize_all = "snake_case")] @@ -358,4 +388,72 @@ mod tests { let result = flow3.run(context).await.unwrap(); assert_eq!(result, json!({"data": "test2"})); } + + #[tokio::test] + async fn test_subflow_node() { + // 1. Create subflow + let sub_node = TestNode::new(json!({"sub_result": "from_subflow"}), CustomState::Success); + let sub_flow = build_flow!( + start: ("sub_start", sub_node) + ); + + // 2. Create SubFlowNode as a node in parent flow + let subflow_node = SubFlowNode::::new( + sub_flow, + |ctx: &Context| { + // Example: Inherit everything + ctx.clone() + }, + |parent_ctx: &mut Context, result: &Result| { + // Example: Map result to parent context + match result { + Ok(val) => { + parent_ctx.set("result", val.clone()); + Ok(ProcessResult::new(CustomState::Success, "subflow ok".to_string())) + } + Err(e) => { + parent_ctx.set("error", json!(e.to_string())); + Ok(ProcessResult::new(CustomState::Failure, e.to_string())) + } + } + }, + ); + + // 3. Create parent flow + let parent_flow = build_flow!( + start: ("run_subflow", subflow_node) + ); + + let context = Context::new(); + let result: serde_json::Value = parent_flow.run(context).await.unwrap(); + + assert_eq!(result, json!({"sub_result": "from_subflow"})); + } + + #[test] + fn test_flow_to_mermaid() { + let node1 = TestNode::new(json!({"data": "test1"}), CustomState::Success); + let node2 = TestNode::new(json!({"data": "test2"}), CustomState::Default); + let end_node = TestNode::new(json!({"final_result": "finished"}), CustomState::Default); + + let flow = build_flow!( + start: ("start", node1), + nodes: [("next", node2), ("end", end_node)], + edges: [ + ("start", "next", CustomState::Success), + ("next", "end", CustomState::Default) + ] + ); + + let mermaid = flow.to_mermaid(); + let expected = "\ +flowchart TD + end[end] + next[next] + start[start] + next -->|default| end + start -->|success| next +"; + assert_eq!(mermaid, expected); + } } diff --git a/src/node.rs b/src/node.rs index d14b49d..bc00dc1 100644 --- a/src/node.rs +++ b/src/node.rs @@ -2,8 +2,8 @@ use crate::{Params, context::Context}; use anyhow::Result; use async_trait::async_trait; use strum::Display; -// use std::collections::HashMap; -// use std::sync::Arc; +use std::sync::Arc; +use crate::flow::Flow; pub trait ProcessState: Send + Sync + std::fmt::Display { fn is_default(&self) -> bool; @@ -138,3 +138,60 @@ impl Node for BatchNode { } impl BaseNodeTrait for BatchNode {} + +pub struct SubFlowNode +where + SubState: ProcessState + Default, + ParentState: ProcessState + Default, +{ + pub sub_flow: Arc>, + pub context_builder: Box Context + Send + Sync>, + pub result_mapper: Box< + dyn Fn(&mut Context, &Result) -> Result> + + Send + + Sync, + >, +} + +impl SubFlowNode +where + SubState: ProcessState + Default, + ParentState: ProcessState + Default, +{ + pub fn new( + sub_flow: Flow, + context_builder: impl Fn(&Context) -> Context + Send + Sync + 'static, + result_mapper: impl Fn(&mut Context, &Result) -> Result> + + Send + + Sync + + 'static, + ) -> Self { + Self { + sub_flow: Arc::new(sub_flow), + context_builder: Box::new(context_builder), + result_mapper: Box::new(result_mapper), + } + } +} + +#[async_trait] +impl Node for SubFlowNode +where + SubState: ProcessState + Default, + ParentState: ProcessState + Default, +{ + type State = ParentState; + + async fn execute(&self, context: &Context) -> Result { + let sub_context = (self.context_builder)(context); + self.sub_flow.run(sub_context).await + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + (self.result_mapper)(context, result) + } +} diff --git a/src/utils/config.rs b/src/utils/config.rs index 7d1c1e4..ef032a1 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -80,7 +80,7 @@ compact_threshold = 6000 } let content = fs::read_to_string(&config_path) - .with_context(|| format!("Failed to read config file at {{:?}}", config_path))?; + .with_context(|| format!("Failed to read config file at {:?}", config_path))?; let config: AppConfig = toml::from_str(&content) .with_context(|| "Failed to parse config.toml")?; diff --git a/src/utils/tools/mod.rs b/src/utils/tools/mod.rs index 95d29ee..74d8e03 100644 --- a/src/utils/tools/mod.rs +++ b/src/utils/tools/mod.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +// use anyhow::Result; use std::fs; use std::path::Path; use std::process::Command; From 1137e5abee7057aaf8ee1b297d7696b89059f5e8 Mon Sep 17 00:00:00 2001 From: dahai9 Date: Sun, 29 Mar 2026 09:52:20 +0800 Subject: [PATCH 7/9] refactor: move pi binary to examples/pi as a workspace member --- Cargo.toml | 5 +- examples/pi/CONFIG_GUIDE.md | 72 ++++++++++++++++++++++++ examples/pi/Cargo.toml | 14 +++++ src/bin/pi.rs => examples/pi/src/main.rs | 0 4 files changed, 87 insertions(+), 4 deletions(-) create mode 100644 examples/pi/CONFIG_GUIDE.md create mode 100644 examples/pi/Cargo.toml rename src/bin/pi.rs => examples/pi/src/main.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index fec9591..0ac36f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,16 +10,13 @@ license = "MIT" name = "pocketflow_rs" path = "src/lib.rs" -[[bin]] -name = "pi" -path = "src/bin/pi.rs" [[example]] name = "basic" path = "examples/basic.rs" [workspace] -members = ["examples/pocketflow-rs-rag", "examples/text2sql"] +members = ["examples/pocketflow-rs-rag", "examples/text2sql", "examples/pi"] [workspace.dependencies] strum = "0.26" diff --git a/examples/pi/CONFIG_GUIDE.md b/examples/pi/CONFIG_GUIDE.md new file mode 100644 index 0000000..8ae9d67 --- /dev/null +++ b/examples/pi/CONFIG_GUIDE.md @@ -0,0 +1,72 @@ +# PocketFlow-RS 配置指南 + +本项目使用 `config.toml` 管理模型、API 提供商和通用设置。配置文件位于工作区下的 `./.pi/config.toml`。 + +## 文件位置 +- **默认路径**:`./.pi/config.toml` +- **日志路径**:`./.pi/logs/log.jsonl` + +> 提示:当你首次运行 `pi` 命令时,若配置不存在,会自动创建默认配置。 + +## 配置结构 + +```toml +[general] +auto_compact = true + +[providers.openai] +api_base = "https://api.openai.com/v1" +api_key_env = "OPENAI_API_KEY" + +[providers.cerebras] +api_base = "http://127.0.0.1:4000/v1" +api_key_env = "OPENAI_API_KEY" + +[models."gpt-4o"] +provider = "openai" +context_window = 128000 +compact_threshold = 100000 + +[models."cerebras/qwen-3-235b-a22b-instruct-2507"] +provider = "cerebras" +context_window = 8192 +compact_threshold = 6000 +``` + +### 字段说明 + +| 区域 | 字段 | 说明 | +|------|------|------| +| `general` | `auto_compact` | 是否开启自动历史压缩(摘要旧消息以节省上下文) | +| `providers` | `api_base` | 模型提供商的 API 地址 | +| `providers` | `api_key_env` | 读取 API 密钥的环境变量名 | +| `models` | `provider` | 该模型使用的提供商(需匹配上方 `providers` 中的键) | +| `models` | `context_window` | 模型最大上下文长度(token 数) | +| `models` | `compact_threshold` | 超过此长度时触发自动压缩 | + +## 自定义配置示例 + +你可以添加本地大模型支持,例如: + +```toml +[providers.localhost] +api_base = "http://127.0.0.1:8080/v1" +api_key_env = "LOCAL_API_KEY" + +[models."qwen:14b"] +provider = "localhost" +context_window = 32768 +compact_threshold = 24000 +``` + +然后运行: +```bash +pi --model qwen:14b --provider localhost +``` + +--- + +📌 提示:确保相应 `api_key_env` 环境变量已设置,例如: +```bash +export OPENAI_API_KEY=sk-... +``` diff --git a/examples/pi/Cargo.toml b/examples/pi/Cargo.toml new file mode 100644 index 0000000..87e65f1 --- /dev/null +++ b/examples/pi/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "pi" +version = "0.1.0" +edition = "2024" + +[dependencies] +pocketflow_rs = { path = "../../", features = ["openai"] } +anyhow = "1.0" +tokio = { version = "1.0", features = ["full"] } +serde_json = "1.0" +clap = { version = "4.4", features = ["derive"] } +uuid = { version = "1.8", features = ["v4", "fast-rng"] } +strum = { version = "0.26", features = ["derive"] } +async-trait = "0.1" diff --git a/src/bin/pi.rs b/examples/pi/src/main.rs similarity index 100% rename from src/bin/pi.rs rename to examples/pi/src/main.rs From 92f5a9e4869cd7c7e9c903da400f23aca65b88cc Mon Sep 17 00:00:00 2001 From: dahai9 Date: Sun, 29 Mar 2026 10:13:02 +0800 Subject: [PATCH 8/9] feat: implement SubFlow Node, Mermaid visualization, and modularize pi agent example --- CONFIG_GUIDE.md | 72 --- examples/pi/Cargo.toml | 4 + examples/pi/src/lib.rs | 7 + examples/pi/src/main.rs | 433 +----------------- examples/pi/src/nodes/check_size.rs | 47 ++ examples/pi/src/nodes/do_compact.rs | 81 ++++ examples/pi/src/nodes/input.rs | 65 +++ examples/pi/src/nodes/llm_reasoning.rs | 146 ++++++ examples/pi/src/nodes/mod.rs | 11 + examples/pi/src/nodes/tool_execution.rs | 86 ++++ examples/pi/src/state.rs | 32 ++ {src => examples/pi/src}/utils/config.rs | 0 examples/pi/src/utils/mod.rs | 8 + {src => examples/pi/src}/utils/pi_llm.rs | 0 .../pi/src}/utils/session_manager.rs | 0 {src => examples/pi/src}/utils/tools/mod.rs | 0 src/utils/mod.rs | 4 - xxx.md | 70 --- 18 files changed, 491 insertions(+), 575 deletions(-) delete mode 100644 CONFIG_GUIDE.md create mode 100644 examples/pi/src/lib.rs create mode 100644 examples/pi/src/nodes/check_size.rs create mode 100644 examples/pi/src/nodes/do_compact.rs create mode 100644 examples/pi/src/nodes/input.rs create mode 100644 examples/pi/src/nodes/llm_reasoning.rs create mode 100644 examples/pi/src/nodes/mod.rs create mode 100644 examples/pi/src/nodes/tool_execution.rs create mode 100644 examples/pi/src/state.rs rename {src => examples/pi/src}/utils/config.rs (100%) create mode 100644 examples/pi/src/utils/mod.rs rename {src => examples/pi/src}/utils/pi_llm.rs (100%) rename {src => examples/pi/src}/utils/session_manager.rs (100%) rename {src => examples/pi/src}/utils/tools/mod.rs (100%) delete mode 100644 xxx.md diff --git a/CONFIG_GUIDE.md b/CONFIG_GUIDE.md deleted file mode 100644 index 8ae9d67..0000000 --- a/CONFIG_GUIDE.md +++ /dev/null @@ -1,72 +0,0 @@ -# PocketFlow-RS 配置指南 - -本项目使用 `config.toml` 管理模型、API 提供商和通用设置。配置文件位于工作区下的 `./.pi/config.toml`。 - -## 文件位置 -- **默认路径**:`./.pi/config.toml` -- **日志路径**:`./.pi/logs/log.jsonl` - -> 提示:当你首次运行 `pi` 命令时,若配置不存在,会自动创建默认配置。 - -## 配置结构 - -```toml -[general] -auto_compact = true - -[providers.openai] -api_base = "https://api.openai.com/v1" -api_key_env = "OPENAI_API_KEY" - -[providers.cerebras] -api_base = "http://127.0.0.1:4000/v1" -api_key_env = "OPENAI_API_KEY" - -[models."gpt-4o"] -provider = "openai" -context_window = 128000 -compact_threshold = 100000 - -[models."cerebras/qwen-3-235b-a22b-instruct-2507"] -provider = "cerebras" -context_window = 8192 -compact_threshold = 6000 -``` - -### 字段说明 - -| 区域 | 字段 | 说明 | -|------|------|------| -| `general` | `auto_compact` | 是否开启自动历史压缩(摘要旧消息以节省上下文) | -| `providers` | `api_base` | 模型提供商的 API 地址 | -| `providers` | `api_key_env` | 读取 API 密钥的环境变量名 | -| `models` | `provider` | 该模型使用的提供商(需匹配上方 `providers` 中的键) | -| `models` | `context_window` | 模型最大上下文长度(token 数) | -| `models` | `compact_threshold` | 超过此长度时触发自动压缩 | - -## 自定义配置示例 - -你可以添加本地大模型支持,例如: - -```toml -[providers.localhost] -api_base = "http://127.0.0.1:8080/v1" -api_key_env = "LOCAL_API_KEY" - -[models."qwen:14b"] -provider = "localhost" -context_window = 32768 -compact_threshold = 24000 -``` - -然后运行: -```bash -pi --model qwen:14b --provider localhost -``` - ---- - -📌 提示:确保相应 `api_key_env` 环境变量已设置,例如: -```bash -export OPENAI_API_KEY=sk-... -``` diff --git a/examples/pi/Cargo.toml b/examples/pi/Cargo.toml index 87e65f1..4b26e56 100644 --- a/examples/pi/Cargo.toml +++ b/examples/pi/Cargo.toml @@ -12,3 +12,7 @@ clap = { version = "4.4", features = ["derive"] } uuid = { version = "1.8", features = ["v4", "fast-rng"] } strum = { version = "0.26", features = ["derive"] } async-trait = "0.1" +serde = { version = "1.0", features = ["derive"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +toml = "0.8" +tracing = "0.1" diff --git a/examples/pi/src/lib.rs b/examples/pi/src/lib.rs new file mode 100644 index 0000000..c891e4c --- /dev/null +++ b/examples/pi/src/lib.rs @@ -0,0 +1,7 @@ +pub mod nodes; +pub mod state; +pub mod utils; + +pub use nodes::*; +pub use state::*; +pub use utils::*; diff --git a/examples/pi/src/main.rs b/examples/pi/src/main.rs index 3b3aa98..0db1b84 100644 --- a/examples/pi/src/main.rs +++ b/examples/pi/src/main.rs @@ -1,15 +1,9 @@ use anyhow::Result; use clap::Parser; -use pocketflow_rs::{build_flow, Context, Node, ProcessResult, ProcessState}; -use serde_json::{json, Value}; -use std::io::{self, Write}; +use pocketflow_rs::{build_flow, Context}; +use serde_json::json; use std::sync::Arc; -use strum::Display; -use uuid::Uuid; -use pocketflow_rs::utils::pi_llm::PiLLM; -use pocketflow_rs::utils::session_manager::{AgentMessage, SessionManager}; -use pocketflow_rs::utils::tools::{execute_bash, read_file, write_file}; -use pocketflow_rs::utils::config::AppConfig; +use pi::{InputNode, CheckSizeNode, DoCompactNode, LLMReasoningNode, ToolExecutionNode, AppContext, PiState, PiLLM, SessionManager, AppConfig}; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -24,424 +18,6 @@ struct Args { model: String, } -#[derive(Debug, Clone, PartialEq, Default, Display)] -#[strum(serialize_all = "snake_case")] -enum PiState { - #[default] - Default, - CheckSize, - DoCompact, - CallLLM, - ExecuteTool, - WaitForInput, - Finished, -} - -impl ProcessState for PiState { - fn is_default(&self) -> bool { - matches!(self, PiState::Default) - } -} - -// Global shared components between nodes -struct AppContext { - llm: PiLLM, - session_manager: SessionManager, - config: AppConfig, - model_name: String, -} - -struct InputNode { - app: Arc, -} - -#[async_trait::async_trait] -impl Node for InputNode { - type State = PiState; - - async fn execute(&self, _context: &Context) -> Result { - print!("> "); - io::stdout().flush()?; - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - let text = input.trim().to_string(); - - if text == "exit" || text == "quit" { - return Ok(json!({ "command": "exit" })); - } - - let id = Uuid::new_v4().to_string(); - - let msg = AgentMessage { - id: id.clone(), - parent_id: None, - role: "user".to_string(), - content: text, - name: None, - tool_calls: None, - tool_call_id: None, - clears_history: None, - }; - - // Persist immediately - self.app.session_manager.append_message(&msg)?; - - Ok(json!({ "message": msg })) - } - - async fn post_process( - &self, - context: &mut Context, - result: &Result, - ) -> Result> { - let res = result.as_ref().unwrap(); - if res.get("command").and_then(|v| v.as_str()) == Some("exit") { - return Ok(ProcessResult::new(PiState::Finished, "finished".to_string())); - } - - let msg_val = res.get("message").unwrap(); - let mut messages = context.get("messages").cloned().unwrap_or(json!([])); - messages.as_array_mut().unwrap().push(msg_val.clone()); - context.set("messages", messages); - - Ok(ProcessResult::new(PiState::CheckSize, "check_size".to_string())) - } -} - -struct LLMReasoningNode { - app: Arc, -} - -struct CheckSizeNode { - app: Arc, -} - -#[async_trait::async_trait] -impl Node for CheckSizeNode { - type State = PiState; - - async fn execute(&self, context: &Context) -> Result { - let messages = context.get("messages").unwrap_or(&json!([])).clone(); - let msgs_str = serde_json::to_string(&messages).unwrap_or_default(); - let estimated_tokens = msgs_str.len() / 4; - - // Handle gracefully if config doesn't perfectly match model - let default_compact_thresh = 100000; - let threshold = self.app.config.models.get(&self.app.model_name) - .map(|m| m.compact_threshold) - .unwrap_or(default_compact_thresh); - - if self.app.config.general.auto_compact && estimated_tokens > threshold && messages.as_array().map(|a| a.len() > 3).unwrap_or(false) { - println!("\n[Auto Compacting History (est {} tokens > {})]...", estimated_tokens, threshold); - Ok(json!({ "needs_compact": true, "history_str": msgs_str })) - } else { - Ok(json!({ "needs_compact": false })) - } - } - - async fn post_process( - &self, - context: &mut Context, - result: &Result, - ) -> Result> { - let res = result.as_ref().unwrap(); - if res.get("needs_compact").and_then(|v| v.as_bool()) == Some(true) { - context.set("history_to_compact", res.get("history_str").unwrap().clone()); - Ok(ProcessResult::new(PiState::DoCompact, "do_compact".to_string())) - } else { - Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) - } - } -} - -struct DoCompactNode { - app: Arc, -} - -#[async_trait::async_trait] -impl Node for DoCompactNode { - type State = PiState; - - async fn execute(&self, context: &Context) -> Result { - let history_str = context.get("history_to_compact").unwrap().as_str().unwrap(); - - let summary_prompt = json!({ - "role": "user", - "content": format!("Summarize the entire conversation history concisely, retaining all tool outcomes and important context so it can be used to replace the history entirely:\n{}", history_str) - }); - - println!("Sending compaction request to LLM..."); - let mut retries = 0; - let max_retries = 3; - loop { - match self.app.llm.chat_completion(vec![summary_prompt.clone()], Value::Null).await { - Ok(summary_res) => return Ok(summary_res), - Err(e) => { - retries += 1; - if retries > max_retries { - return Err(e); - } - println!("[Compaction Failed]: {}. Retrying ({}/{}) in 2 seconds...", e, retries, max_retries); - tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; - } - } - } - } - - async fn post_process( - &self, - context: &mut Context, - result: &Result, - ) -> Result> { - let res = match result { - Ok(v) => v, - Err(e) => { - println!("[Compaction Failed]: {}", e); - return Ok(ProcessResult::new(PiState::CallLLM, "call_llm_fallback".to_string())); - } - }; - - if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { - if let Some(choice) = choices.first() { - let summary_text = choice["message"]["content"].as_str().unwrap_or(""); - - let compact_msg = AgentMessage { - id: Uuid::new_v4().to_string(), - parent_id: None, - role: "system".to_string(), - content: format!("Previous conversation summary:\n{}", summary_text), - name: None, - tool_calls: None, - tool_call_id: None, - clears_history: Some(true), - }; - - self.app.session_manager.append_message(&compact_msg)?; - - let messages = json!([compact_msg]); - context.set("messages", messages); - println!("History compressed successfully."); - } - } - - Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) - } -} - -#[async_trait::async_trait] -impl Node for LLMReasoningNode { - type State = PiState; - - async fn execute(&self, context: &Context) -> Result { - let messages = context.get("messages").unwrap_or(&json!([])).clone(); - - let tools = json!([ - { - "type": "function", - "function": { - "name": "read_file", - "description": "Read the contents of a file", - "parameters": { - "type": "object", - "properties": { - "path": { "type": "string" } - }, - "required": ["path"] - } - } - }, - { - "type": "function", - "function": { - "name": "write_file", - "description": "Write contents to a file", - "parameters": { - "type": "object", - "properties": { - "path": { "type": "string" }, - "content": { "type": "string" } - }, - "required": ["path", "content"] - } - } - }, - { - "type": "function", - "function": { - "name": "bash", - "description": "Execute a bash command", - "parameters": { - "type": "object", - "properties": { - "command": { "type": "string" } - }, - "required": ["command"] - } - } - } - ]); - - let mut openai_messages = Vec::new(); - // Convert AgentMessage to format expected by OpenAI - if let Some(arr) = messages.as_array() { - for m in arr { - let mut mapped = json!({ - "role": m["role"].as_str().unwrap(), - "content": m["content"].as_str().unwrap() - }); - if let Some(calls) = m.get("tool_calls") { - if !calls.is_null() { - mapped.as_object_mut().unwrap().insert("tool_calls".to_string(), calls.clone()); - } - } - if let Some(tid) = m.get("tool_call_id") { - if !tid.is_null() { - mapped.as_object_mut().unwrap().insert("tool_call_id".to_string(), tid.clone()); - } - } - openai_messages.push(mapped); - } - } - - let response = self.app.llm.chat_completion(openai_messages, tools).await?; - Ok(response) - } - - async fn post_process( - &self, - context: &mut Context, - result: &Result, - ) -> Result> { - let res = match result { - Ok(v) => v, - Err(e) => { - println!("\n[LLM Error]: {}\n", e); - return Ok(ProcessResult::new(PiState::WaitForInput, "error".to_string())); - } - }; - - // Ensure choice 0 exists - if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { - if let Some(choice) = choices.first() { - let msg = choice.get("message").unwrap(); - let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or(""); - let tool_calls = msg.get("tool_calls"); - - let agent_msg = AgentMessage { - id: Uuid::new_v4().to_string(), - parent_id: None, - role: "assistant".to_string(), - content: content.to_string(), - name: None, - tool_calls: tool_calls.cloned(), - tool_call_id: None, - clears_history: None, - }; - - // Persist - self.app.session_manager.append_message(&agent_msg)?; - - // Print - if !content.is_empty() { - println!("\nAssistant: {}\n", content); - } - - // Update context - let mut messages = context.get("messages").cloned().unwrap_or(json!([])); - messages.as_array_mut().unwrap().push(serde_json::to_value(&agent_msg)?); - context.set("messages", messages); - - if let Some(tc) = tool_calls { - if !tc.is_null() && tc.as_array().map_or(false, |a| !a.is_empty()) { - return Ok(ProcessResult::new(PiState::ExecuteTool, "execute_tool".to_string())); - } - } - } - } - - Ok(ProcessResult::new(PiState::WaitForInput, "wait_for_input".to_string())) - } -} - -struct ToolExecutionNode { - app: Arc, -} - -#[async_trait::async_trait] -impl Node for ToolExecutionNode { - type State = PiState; - - async fn execute(&self, context: &Context) -> Result { - let messages = context.get("messages").unwrap().as_array().unwrap(); - let last_msg = messages.last().unwrap(); - - let mut tool_results = Vec::new(); - - if let Some(tool_calls) = last_msg.get("tool_calls").and_then(|tc| tc.as_array()) { - for call in tool_calls { - let id = call["id"].as_str().unwrap().to_string(); - let func = &call["function"]; - let name = func["name"].as_str().unwrap(); - let args_str = func["arguments"].as_str().unwrap(); - let args: Value = serde_json::from_str(args_str)?; - - println!("Executing tool: {} with args: {}", name, args_str); - - let output = match name { - "read_file" => { - let path = args["path"].as_str().unwrap(); - read_file(path) - } - "write_file" => { - let path = args["path"].as_str().unwrap(); - let content = args["content"].as_str().unwrap(); - write_file(path, content) - } - "bash" => { - let command = args["command"].as_str().unwrap(); - execute_bash(command, ".") - } - _ => format!("Unknown tool: {}", name), - }; - - let agent_msg = AgentMessage { - id: Uuid::new_v4().to_string(), - parent_id: None, - role: "tool".to_string(), - content: output, - name: Some(name.to_string()), - tool_calls: None, - tool_call_id: Some(id), - clears_history: None, - }; - - tool_results.push(agent_msg); - } - } - - Ok(serde_json::to_value(tool_results)?) - } - - async fn post_process( - &self, - context: &mut Context, - result: &Result, - ) -> Result> { - let tool_results: Vec = serde_json::from_value(result.as_ref().unwrap().clone())?; - - let mut messages = context.get("messages").cloned().unwrap_or(json!([])); - - for msg in tool_results { - self.app.session_manager.append_message(&msg)?; - messages.as_array_mut().unwrap().push(serde_json::to_value(&msg)?); - } - - context.set("messages", messages); - - Ok(ProcessResult::new(PiState::CheckSize, "check_size".to_string())) - } -} - #[tokio::main] async fn main() -> Result<()> { let args = Args::parse(); @@ -454,7 +30,7 @@ async fn main() -> Result<()> { let config = AppConfig::load(&cwd)?; // Load API Key - let (api_key, mut endpoint) = if let Some(model_conf) = config.models.get(&args.model) { + let (api_key, mut endpoint): (String, String) = if let Some(model_conf) = config.models.get(&args.model) { if let Some(provider_conf) = config.providers.get(&model_conf.provider) { let key = std::env::var(&provider_conf.api_key_env).unwrap_or_else(|_| "dummy_key".to_string()); (key, provider_conf.api_base.clone()) @@ -501,7 +77,6 @@ async fn main() -> Result<()> { ("llm", "tool", PiState::ExecuteTool), ("llm", "input", PiState::WaitForInput), ("tool", "check_size", PiState::CheckSize) - // Implicit default stop for PiState::Finished ] ); diff --git a/examples/pi/src/nodes/check_size.rs b/examples/pi/src/nodes/check_size.rs new file mode 100644 index 0000000..f31f92f --- /dev/null +++ b/examples/pi/src/nodes/check_size.rs @@ -0,0 +1,47 @@ +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{json, Value}; +use std::sync::Arc; +use crate::state::{AppContext, PiState}; + +pub struct CheckSizeNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for CheckSizeNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let messages = context.get("messages").unwrap_or(&json!([])).clone(); + let msgs_str = serde_json::to_string(&messages).unwrap_or_default(); + let estimated_tokens = msgs_str.len() / 4; + + // Handle gracefully if config doesn't perfectly match model + let default_compact_thresh = 100000; + let threshold = self.app.config.models.get(&self.app.model_name) + .map(|m| m.compact_threshold) + .unwrap_or(default_compact_thresh); + + if self.app.config.general.auto_compact && estimated_tokens > threshold && messages.as_array().map(|a| a.len() > 3).unwrap_or(false) { + println!("\n[Auto Compacting History (est {} tokens > {})]...", estimated_tokens, threshold); + Ok(json!({ "needs_compact": true, "history_str": msgs_str })) + } else { + Ok(json!({ "needs_compact": false })) + } + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = result.as_ref().unwrap(); + if res.get("needs_compact").and_then(|v| v.as_bool()) == Some(true) { + context.set("history_to_compact", res.get("history_str").unwrap().clone()); + Ok(ProcessResult::new(PiState::DoCompact, "do_compact".to_string())) + } else { + Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + } + } +} diff --git a/examples/pi/src/nodes/do_compact.rs b/examples/pi/src/nodes/do_compact.rs new file mode 100644 index 0000000..2446ad7 --- /dev/null +++ b/examples/pi/src/nodes/do_compact.rs @@ -0,0 +1,81 @@ +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{json, Value}; +use std::sync::Arc; +use uuid::Uuid; +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; + +pub struct DoCompactNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for DoCompactNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let history_str = context.get("history_to_compact").unwrap().as_str().unwrap(); + + let summary_prompt = json!({ + "role": "user", + "content": format!("Summarize the entire conversation history concisely, retaining all tool outcomes and important context so it can be used to replace the history entirely:\n{}", history_str) + }); + + println!("Sending compaction request to LLM..."); + let mut retries = 0; + let max_retries = 3; + loop { + match self.app.llm.chat_completion(vec![summary_prompt.clone()], Value::Null).await { + Ok(summary_res) => return Ok(summary_res), + Err(e) => { + retries += 1; + if retries > max_retries { + return Err(e); + } + println!("[Compaction Failed]: {}. Retrying ({}/{}) in 2 seconds...", e, retries, max_retries); + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + } + } + } + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = match result { + Ok(v) => v, + Err(e) => { + println!("[Compaction Failed]: {}", e); + return Ok(ProcessResult::new(PiState::CallLLM, "call_llm_fallback".to_string())); + } + }; + + if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { + if let Some(choice) = choices.first() { + let summary_text = choice["message"]["content"].as_str().unwrap_or(""); + + let compact_msg = AgentMessage { + id: Uuid::new_v4().to_string(), + parent_id: None, + role: "system".to_string(), + content: format!("Previous conversation summary:\n{}", summary_text), + name: None, + tool_calls: None, + tool_call_id: None, + clears_history: Some(true), + }; + + self.app.session_manager.append_message(&compact_msg)?; + + let messages = json!([compact_msg]); + context.set("messages", messages); + println!("History compressed successfully."); + } + } + + Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + } +} diff --git a/examples/pi/src/nodes/input.rs b/examples/pi/src/nodes/input.rs new file mode 100644 index 0000000..79a005c --- /dev/null +++ b/examples/pi/src/nodes/input.rs @@ -0,0 +1,65 @@ +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{json, Value}; +use std::io::{self, Write}; +use std::sync::Arc; +use uuid::Uuid; +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; + +pub struct InputNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for InputNode { + type State = PiState; + + async fn execute(&self, _context: &Context) -> Result { + print!("> "); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let text = input.trim().to_string(); + + if text == "exit" || text == "quit" { + return Ok(json!({ "command": "exit" })); + } + + let id = Uuid::new_v4().to_string(); + + let msg = AgentMessage { + id: id.clone(), + parent_id: None, + role: "user".to_string(), + content: text, + name: None, + tool_calls: None, + tool_call_id: None, + clears_history: None, + }; + + // Persist immediately + self.app.session_manager.append_message(&msg)?; + + Ok(json!({ "message": msg })) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = result.as_ref().unwrap(); + if res.get("command").and_then(|v| v.as_str()) == Some("exit") { + return Ok(ProcessResult::new(PiState::Finished, "finished".to_string())); + } + + let msg_val = res.get("message").unwrap(); + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); + messages.as_array_mut().unwrap().push(msg_val.clone()); + context.set("messages", messages); + + Ok(ProcessResult::new(PiState::CheckSize, "check_size".to_string())) + } +} diff --git a/examples/pi/src/nodes/llm_reasoning.rs b/examples/pi/src/nodes/llm_reasoning.rs new file mode 100644 index 0000000..a871aca --- /dev/null +++ b/examples/pi/src/nodes/llm_reasoning.rs @@ -0,0 +1,146 @@ +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{json, Value}; +use std::sync::Arc; +use uuid::Uuid; +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; + +pub struct LLMReasoningNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for LLMReasoningNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let messages = context.get("messages").unwrap_or(&json!([])).clone(); + + let tools = json!([ + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read the contents of a file", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string" } + }, + "required": ["path"] + } + } + }, + { + "type": "function", + "function": { + "name": "write_file", + "description": "Write contents to a file", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "content": { "type": "string" } + }, + "required": ["path", "content"] + } + } + }, + { + "type": "function", + "function": { + "name": "bash", + "description": "Execute a bash command", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" } + }, + "required": ["command"] + } + } + } + ]); + + let mut openai_messages = Vec::new(); + // Convert AgentMessage to format expected by OpenAI + if let Some(arr) = messages.as_array() { + for m in arr { + let mut mapped = json!({ + "role": m["role"].as_str().unwrap(), + "content": m["content"].as_str().unwrap() + }); + if let Some(calls) = m.get("tool_calls") { + if !calls.is_null() { + mapped.as_object_mut().unwrap().insert("tool_calls".to_string(), calls.clone()); + } + } + if let Some(tid) = m.get("tool_call_id") { + if !tid.is_null() { + mapped.as_object_mut().unwrap().insert("tool_call_id".to_string(), tid.clone()); + } + } + openai_messages.push(mapped); + } + } + + let response = self.app.llm.chat_completion(openai_messages, tools).await?; + Ok(response) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = match result { + Ok(v) => v, + Err(e) => { + println!("\n[LLM Error]: {}\n", e); + return Ok(ProcessResult::new(PiState::WaitForInput, "error".to_string())); + } + }; + + // Ensure choice 0 exists + if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { + if let Some(choice) = choices.first() { + let msg = choice.get("message").unwrap(); + let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or(""); + let tool_calls = msg.get("tool_calls"); + + let agent_msg = AgentMessage { + id: Uuid::new_v4().to_string(), + parent_id: None, + role: "assistant".to_string(), + content: content.to_string(), + name: None, + tool_calls: tool_calls.cloned(), + tool_call_id: None, + clears_history: None, + }; + + // Persist + self.app.session_manager.append_message(&agent_msg)?; + + // Print + if !content.is_empty() { + println!("\nAssistant: {}\n", content); + } + + // Update context + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); + messages.as_array_mut().unwrap().push(serde_json::to_value(&agent_msg)?); + context.set("messages", messages); + + if let Some(tc) = tool_calls { + if !tc.is_null() && tc.as_array().map_or(false, |a| !a.is_empty()) { + return Ok(ProcessResult::new(PiState::ExecuteTool, "execute_tool".to_string())); + } + } + } + } + + Ok(ProcessResult::new(PiState::WaitForInput, "wait_for_input".to_string())) + } +} diff --git a/examples/pi/src/nodes/mod.rs b/examples/pi/src/nodes/mod.rs new file mode 100644 index 0000000..7105536 --- /dev/null +++ b/examples/pi/src/nodes/mod.rs @@ -0,0 +1,11 @@ +mod input; +mod check_size; +mod do_compact; +mod llm_reasoning; +mod tool_execution; + +pub use input::InputNode; +pub use check_size::CheckSizeNode; +pub use do_compact::DoCompactNode; +pub use llm_reasoning::LLMReasoningNode; +pub use tool_execution::ToolExecutionNode; diff --git a/examples/pi/src/nodes/tool_execution.rs b/examples/pi/src/nodes/tool_execution.rs new file mode 100644 index 0000000..7b145a6 --- /dev/null +++ b/examples/pi/src/nodes/tool_execution.rs @@ -0,0 +1,86 @@ +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{json, Value}; +use std::sync::Arc; +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; +use crate::utils::tools::{execute_bash, read_file, write_file}; + +pub struct ToolExecutionNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for ToolExecutionNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let messages = context.get("messages").unwrap().as_array().unwrap(); + let last_msg = messages.last().unwrap(); + + let mut tool_results = Vec::new(); + + if let Some(tool_calls) = last_msg.get("tool_calls").and_then(|tc| tc.as_array()) { + for call in tool_calls { + let id = call["id"].as_str().unwrap().to_string(); + let func = &call["function"]; + let name = func["name"].as_str().unwrap(); + let args_str = func["arguments"].as_str().unwrap(); + let args: Value = serde_json::from_str(args_str)?; + + println!("Executing tool: {} with args: {}", name, args_str); + + let output = match name { + "read_file" => { + let path = args["path"].as_str().unwrap(); + read_file(path) + } + "write_file" => { + let path = args["path"].as_str().unwrap(); + let content = args["content"].as_str().unwrap(); + write_file(path, content) + } + "bash" => { + let command = args["command"].as_str().unwrap(); + execute_bash(command, ".") + } + _ => format!("Unknown tool: {}", name), + }; + + let agent_msg = AgentMessage { + id: uuid::Uuid::new_v4().to_string(), + parent_id: None, + role: "tool".to_string(), + content: output, + name: Some(name.to_string()), + tool_calls: None, + tool_call_id: Some(id), + clears_history: None, + }; + + tool_results.push(agent_msg); + } + } + + Ok(serde_json::to_value(tool_results)?) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let tool_results: Vec = serde_json::from_value(result.as_ref().unwrap().clone())?; + + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); + + for msg in tool_results { + self.app.session_manager.append_message(&msg)?; + messages.as_array_mut().unwrap().push(serde_json::to_value(&msg)?); + } + + context.set("messages", messages); + + Ok(ProcessResult::new(PiState::CheckSize, "check_size".to_string())) + } +} diff --git a/examples/pi/src/state.rs b/examples/pi/src/state.rs new file mode 100644 index 0000000..be2645e --- /dev/null +++ b/examples/pi/src/state.rs @@ -0,0 +1,32 @@ +use pocketflow_rs::ProcessState; +use strum::Display; +use crate::utils::pi_llm::PiLLM; +use crate::utils::session_manager::SessionManager; +use crate::utils::config::AppConfig; + +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] +pub enum PiState { + #[default] + Default, + CheckSize, + DoCompact, + CallLLM, + ExecuteTool, + WaitForInput, + Finished, +} + +impl ProcessState for PiState { + fn is_default(&self) -> bool { + matches!(self, PiState::Default) + } +} + +// Global shared components between nodes +pub struct AppContext { + pub llm: PiLLM, + pub session_manager: SessionManager, + pub config: AppConfig, + pub model_name: String, +} diff --git a/src/utils/config.rs b/examples/pi/src/utils/config.rs similarity index 100% rename from src/utils/config.rs rename to examples/pi/src/utils/config.rs diff --git a/examples/pi/src/utils/mod.rs b/examples/pi/src/utils/mod.rs new file mode 100644 index 0000000..3174eaf --- /dev/null +++ b/examples/pi/src/utils/mod.rs @@ -0,0 +1,8 @@ +pub mod pi_llm; +pub mod session_manager; +pub mod tools; +pub mod config; + +pub use pi_llm::PiLLM; +pub use session_manager::SessionManager; +pub use config::AppConfig; diff --git a/src/utils/pi_llm.rs b/examples/pi/src/utils/pi_llm.rs similarity index 100% rename from src/utils/pi_llm.rs rename to examples/pi/src/utils/pi_llm.rs diff --git a/src/utils/session_manager.rs b/examples/pi/src/utils/session_manager.rs similarity index 100% rename from src/utils/session_manager.rs rename to examples/pi/src/utils/session_manager.rs diff --git a/src/utils/tools/mod.rs b/examples/pi/src/utils/tools/mod.rs similarity index 100% rename from src/utils/tools/mod.rs rename to examples/pi/src/utils/tools/mod.rs diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 7449038..82e3df8 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -4,7 +4,3 @@ pub mod text_chunking; pub mod vector_db; pub mod viz_debug; pub mod web_search; -pub mod pi_llm; -pub mod session_manager; -pub mod tools; -pub mod config; diff --git a/xxx.md b/xxx.md deleted file mode 100644 index eb3a010..0000000 --- a/xxx.md +++ /dev/null @@ -1,70 +0,0 @@ -# 计划: 实现 pi-mono 会话压缩 (Compaction) 与 TOML 配置支持 - -## Context (背景与目标) -目前系统已打通核心交互循环和会话持久化。随着对话的增加,每次请求附带的历史上下文会不断变长,最终超出 LLM 模型的上下文窗口(Context Window)。 -我们需要引入会话压缩(Compaction)机制,并在 `pi` 启动时加载一个 `config.toml` 文件,用于分类配置是否启用自动压缩、每种模型的窗口大小约束、以及对应的大模型服务提供商 (Provider) 设定。 - -由于安全设定约束,**在计划模式下,系统严格禁止直接执行代码提交(`git commit`)等修改系统状态的操作。** 我将在获得您对本计划的批准进入执行模式后,**第一步便为您执行代码提交**,然后再进行后续的代码修改。 - -## Proposed Changes (架构与组件设计) - -### 1. 配置管理设计 (`src/config.rs`) -通过引入 `toml` 和 `serde` 读取全局或局部配置文件(如 `config.toml`)。设计配置类结构如下: - -```rust -#[derive(Debug, Deserialize)] -pub struct AppConfig { - pub general: GeneralConfig, - pub providers: HashMap, - pub models: HashMap, -} - -#[derive(Debug, Deserialize)] -pub struct GeneralConfig { - pub auto_compact: bool, // 是否开启自动压缩 -} - -#[derive(Debug, Deserialize)] -pub struct ProviderConfig { - pub api_base: String, - pub api_key_env: String, // 指定从哪个环境变量读取 Key,提高安全性 -} - -#[derive(Debug, Deserialize)] -pub struct ModelConfig { - pub provider: String, // 关联的 Provider 名称 - pub context_window: usize, // 模型的绝对最大窗口 (按 Token 数,粗略可用字数/4近似) - pub compact_threshold: usize, // 触发压缩的阈值 (如超过 80% 则压缩) -} -``` - -### 2. Append-Only 会话日志的压缩实现 (`src/utils/session_manager.rs`) -保留 `pi-mono` 的 Append-Only JSONL 特性,不直接修改旧日志文件。我们在 `AgentMessage` 中新增一个特殊字段: -- `clears_history: Option` - -当发生压缩时,系统将先前的对话交由 LLM 生成摘要,并写入一条 `role: "system"`, `content: "Previous conversation summary: ..."` 且附带 `clears_history: true` 的新记录。 -在重启应用并调用 `load_history` 恢复列表时,如果读到 `clears_history == Some(true)` 的消息,就将内存中的 `messages` 清空(或保留最初的系统设定),只保留这条 Summary 继续往下构建,完成完美的持久化无损截断。 - -### 3. PiLLM 动态路由支持 (`src/utils/pi_llm.rs`) -修改 `PiLLM` 内部逻辑: -不再硬编码读取 `OPENAI_API_KEY`,而是根据当前的选项,从 `AppConfig` 查找到对应的 `ModelConfig`,接着查找到 `ProviderConfig`,使用对应的 `api_base` 和对应的环境变量(`api_key_env`)进行鉴权与请求分发。 - -### 4. 压缩节点注入 (`src/bin/pi.rs`) -- 在 `LLMReasoningNode` 的 `execute` 开始前加入 Token 容量估算逻辑(如按字符串长度评估或引入 tiktoken 库)。 -- 若超过 `model.compact_threshold` 且 `auto_compact` 为 true,即刻在此流程内部或单独生成一个阻塞调用,请求模型生成摘要("Please summarize the history conversation concisely...")。 -- 成功获得摘要后,实例化一段 `clears_history=true` 的 `AgentMessage` 调用 `append_message` 持久化,重置 Context 中的 `messages`,接着再处理用户当下真实的发问。 - -### 5. `Cargo.toml` 依赖更新 -添加 `toml` crates 支持。 - -## Verification Plan (验证计划) -### Automated Tests -1. 编写对 `AppConfig` TOML 解析的单元测试。 -2. 针对包含 `clears_history: true` 的模拟 `.jsonl` 文件编写 `SessionManager::load_history` 测试,断言数组应当短路清空,仅保留之后的有效长度。 - -### Manual Verification -1. **代码提交流程验证**:执行模式开启后的第一件事就是运行 `git commit -am "chore: initial working implementation before compaction"`,证明我们遵守了指令。 -2. 启动代理程序加载携带多 `models` 和 `providers` 的 `config.toml`。 -3. 把某配置大模型的 `compact_threshold` 改为非常小的值(如 `20`)。 -4. 对话几次,让长度超过该极小阈值,观察控制台是否输出了 "[Auto Compacting History...]" 提示。 -5. 通过查看底层 `log.jsonl`,证实末尾产生了一条标志性的含有 `clears_history: true` 的 Json 行。使用 `cargo run` 重启程序,核实历史条数(`Loaded X messages`)是否明显变少(因为老信息已被抛弃截断)。 From 5b7e873869e07cbe7ec138aceee22985b2630678 Mon Sep 17 00:00:00 2001 From: dahai9 Date: Sun, 29 Mar 2026 10:22:17 +0800 Subject: [PATCH 9/9] docs: update README with SubFlow and Mermaid visualization; prepare for PR --- README.md | 35 +++++++----- examples/pi/src/main.rs | 69 ++++++++++++++++-------- examples/pi/src/nodes/check_size.rs | 34 ++++++++---- examples/pi/src/nodes/do_compact.rs | 35 +++++++----- examples/pi/src/nodes/input.rs | 18 ++++--- examples/pi/src/nodes/llm_reasoning.rs | 40 ++++++++++---- examples/pi/src/nodes/mod.rs | 4 +- examples/pi/src/nodes/tool_execution.rs | 31 ++++++----- examples/pi/src/state.rs | 6 +-- examples/pi/src/utils/config.rs | 16 +++--- examples/pi/src/utils/mod.rs | 4 +- examples/pi/src/utils/pi_llm.rs | 18 +++++-- examples/pi/src/utils/session_manager.rs | 4 +- examples/pi/src/utils/tools/mod.rs | 2 +- src/flow.rs | 13 +++-- src/node.rs | 16 +++--- 16 files changed, 228 insertions(+), 117 deletions(-) diff --git a/README.md b/README.md index c352579..79875b7 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ A Rust implementation of [PocketFlow](https://github.com/The-Pocket/PocketFlow), - ⚡ **Async first:** Non-blocking node execution and post-processing - 📦 **Batch support:** High-performance processing of multiple contexts - 🧩 **Extensible:** Custom state management and node systems +- 🌳 **Composable:** Support for `SubFlow` nodes to build recursive workflows +- 📊 **Flow Visualization:** Generate Mermaid TD diagrams from flow graphs - 🛠️ **Utility-rich:** Optional integrations for OpenAI, Qdrant, and web search ## 🚀 Quick Start @@ -44,6 +46,8 @@ impl ProcessState for MyState { fn is_default(&self) -> bool { matches!(self, MyState::Default) } + + // to_condition() is automatically implemented via Display (strum) } ``` @@ -98,26 +102,32 @@ let result = flow.run(context).await?; ## 🏗️ Advanced Usage -### Batch Processing +### Composition with SubFlow -Build high-throughput flows for parallel processing: +You can nest a `Flow` inside another `Flow` using `SubFlowNode`. This allows for modular and recursive workflow design. ```rust -use pocketflow_rs::build_batch_flow; +use pocketflow_rs::{SubFlowNode, Flow}; -let batch_flow = build_batch_flow!( - start: ("start", node1), - nodes: [("next", node2)], - edges: [ - ("start", "next", MyState::Success) - ], - batch_size: 10 +let sub_flow = create_sub_flow(); +let sub_flow_node = SubFlowNode::new( + sub_flow, + |parent_context| parent_context.clone(), // context_builder + |parent_context, result| { /* map sub-flow result back to parent state */ } ); +``` + +### Flow Visualization + +Generate Mermaid TD flowchart strings automatically to visualize your complex workflows. -let contexts = vec![Context::new(); 10]; -let results = batch_flow.run_batch(contexts).await?; +```rust +let mermaid = flow.to_mermaid(); +println!("{}", mermaid); ``` +Check out `test_dir/pi_flow.mmd` after running the `pi` agent for a real-world example. + ## 🛠️ Available Features Customize `pocketflow_rs` by enabling the features you need in your `Cargo.toml`: @@ -139,6 +149,7 @@ pocketflow_rs = { version = "0.1.0", features = ["openai", "qdrant"] } Check out the `examples/` directory for detailed implementations: - 🟢 [**basic.rs**](./examples/basic.rs): Basic flow with custom states +- 🤖 [**pi**](./examples/pi/): Modular interactive coding agent with tool-use and history compaction - 🗃️ [**text2sql**](./examples/text2sql/): Text-to-SQL workflow using OpenAI - 🔍 [**pocketflow-rs-rag**](./examples/pocketflow-rs-rag/): Retrieval-Augmented Generation (RAG) system diff --git a/examples/pi/src/main.rs b/examples/pi/src/main.rs index 0db1b84..d7f218f 100644 --- a/examples/pi/src/main.rs +++ b/examples/pi/src/main.rs @@ -1,9 +1,12 @@ use anyhow::Result; use clap::Parser; -use pocketflow_rs::{build_flow, Context}; +use pi::{ + AppConfig, AppContext, CheckSizeNode, DoCompactNode, InputNode, LLMReasoningNode, PiLLM, + PiState, SessionManager, ToolExecutionNode, +}; +use pocketflow_rs::{Context, build_flow}; use serde_json::json; use std::sync::Arc; -use pi::{InputNode, CheckSizeNode, DoCompactNode, LLMReasoningNode, ToolExecutionNode, AppContext, PiState, PiLLM, SessionManager, AppConfig}; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -21,26 +24,38 @@ struct Args { #[tokio::main] async fn main() -> Result<()> { let args = Args::parse(); - + // Setup directory and SessionManager let cwd = std::env::current_dir()?; let session_manager = SessionManager::new(&cwd); - + // Load config let config = AppConfig::load(&cwd)?; // Load API Key - let (api_key, mut endpoint): (String, String) = if let Some(model_conf) = config.models.get(&args.model) { - if let Some(provider_conf) = config.providers.get(&model_conf.provider) { - let key = std::env::var(&provider_conf.api_key_env).unwrap_or_else(|_| "dummy_key".to_string()); - (key, provider_conf.api_base.clone()) + let (api_key, mut endpoint): (String, String) = + if let Some(model_conf) = config.models.get(&args.model) { + if let Some(provider_conf) = config.providers.get(&model_conf.provider) { + let key = std::env::var(&provider_conf.api_key_env) + .unwrap_or_else(|_| "dummy_key".to_string()); + (key, provider_conf.api_base.clone()) + } else { + ( + std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()), + "https://api.openai.com/v1".to_string(), + ) + } } else { - (std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()), "https://api.openai.com/v1".to_string()) - } - } else { - println!("Warning: Model '{}' not found in config, falling back to openai env vars.", args.model); - (std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()), std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| "https://api.openai.com/v1".to_string())) - }; + println!( + "Warning: Model '{}' not found in config, falling back to openai env vars.", + args.model + ); + ( + std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()), + std::env::var("OPENAI_BASE_URL") + .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()), + ) + }; if !endpoint.ends_with("/chat/completions") { endpoint = format!("{}/chat/completions", endpoint.trim_end_matches('/')); @@ -55,11 +70,21 @@ async fn main() -> Result<()> { model_name: args.model.clone(), }); - let input_node = InputNode { app: app_context.clone() }; - let check_size_node = CheckSizeNode { app: app_context.clone() }; - let compact_node = DoCompactNode { app: app_context.clone() }; - let llm_node = LLMReasoningNode { app: app_context.clone() }; - let tool_node = ToolExecutionNode { app: app_context.clone() }; + let input_node = InputNode { + app: app_context.clone(), + }; + let check_size_node = CheckSizeNode { + app: app_context.clone(), + }; + let compact_node = DoCompactNode { + app: app_context.clone(), + }; + let llm_node = LLMReasoningNode { + app: app_context.clone(), + }; + let tool_node = ToolExecutionNode { + app: app_context.clone(), + }; let flow = build_flow!( start: ("input", input_node), @@ -81,7 +106,7 @@ async fn main() -> Result<()> { ); let mut context = Context::new(); - + // Load history let history = app_context.session_manager.load_history(None)?; if !history.is_empty() { @@ -91,13 +116,13 @@ async fn main() -> Result<()> { } else { context.set("messages", json!([])); } - // Export flow visualization + // Export flow visualization std::fs::create_dir_all("test_dir")?; let mermaid = flow.to_mermaid(); std::fs::write("test_dir/pi_flow.mmd", mermaid)?; println!("Saved flow visualization to test_dir/pi_flow.mmd"); println!("pi agent started. Type 'exit' to quit."); - + match flow.run(context).await { Ok(_) => println!("Agent shutdown."), Err(e) => eprintln!("Error running flow: {}", e), diff --git a/examples/pi/src/nodes/check_size.rs b/examples/pi/src/nodes/check_size.rs index f31f92f..17edb0e 100644 --- a/examples/pi/src/nodes/check_size.rs +++ b/examples/pi/src/nodes/check_size.rs @@ -1,8 +1,8 @@ +use crate::state::{AppContext, PiState}; use anyhow::Result; use pocketflow_rs::{Context, Node, ProcessResult}; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::sync::Arc; -use crate::state::{AppContext, PiState}; pub struct CheckSizeNode { pub app: Arc, @@ -16,15 +16,25 @@ impl Node for CheckSizeNode { let messages = context.get("messages").unwrap_or(&json!([])).clone(); let msgs_str = serde_json::to_string(&messages).unwrap_or_default(); let estimated_tokens = msgs_str.len() / 4; - + // Handle gracefully if config doesn't perfectly match model let default_compact_thresh = 100000; - let threshold = self.app.config.models.get(&self.app.model_name) + let threshold = self + .app + .config + .models + .get(&self.app.model_name) .map(|m| m.compact_threshold) .unwrap_or(default_compact_thresh); - - if self.app.config.general.auto_compact && estimated_tokens > threshold && messages.as_array().map(|a| a.len() > 3).unwrap_or(false) { - println!("\n[Auto Compacting History (est {} tokens > {})]...", estimated_tokens, threshold); + + if self.app.config.general.auto_compact + && estimated_tokens > threshold + && messages.as_array().map(|a| a.len() > 3).unwrap_or(false) + { + println!( + "\n[Auto Compacting History (est {} tokens > {})]...", + estimated_tokens, threshold + ); Ok(json!({ "needs_compact": true, "history_str": msgs_str })) } else { Ok(json!({ "needs_compact": false })) @@ -38,8 +48,14 @@ impl Node for CheckSizeNode { ) -> Result> { let res = result.as_ref().unwrap(); if res.get("needs_compact").and_then(|v| v.as_bool()) == Some(true) { - context.set("history_to_compact", res.get("history_str").unwrap().clone()); - Ok(ProcessResult::new(PiState::DoCompact, "do_compact".to_string())) + context.set( + "history_to_compact", + res.get("history_str").unwrap().clone(), + ); + Ok(ProcessResult::new( + PiState::DoCompact, + "do_compact".to_string(), + )) } else { Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) } diff --git a/examples/pi/src/nodes/do_compact.rs b/examples/pi/src/nodes/do_compact.rs index 2446ad7..6e89284 100644 --- a/examples/pi/src/nodes/do_compact.rs +++ b/examples/pi/src/nodes/do_compact.rs @@ -1,10 +1,10 @@ +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; use anyhow::Result; use pocketflow_rs::{Context, Node, ProcessResult}; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::sync::Arc; use uuid::Uuid; -use crate::state::{AppContext, PiState}; -use crate::utils::session_manager::AgentMessage; pub struct DoCompactNode { pub app: Arc, @@ -16,24 +16,32 @@ impl Node for DoCompactNode { async fn execute(&self, context: &Context) -> Result { let history_str = context.get("history_to_compact").unwrap().as_str().unwrap(); - + let summary_prompt = json!({ "role": "user", "content": format!("Summarize the entire conversation history concisely, retaining all tool outcomes and important context so it can be used to replace the history entirely:\n{}", history_str) }); - + println!("Sending compaction request to LLM..."); let mut retries = 0; let max_retries = 3; loop { - match self.app.llm.chat_completion(vec![summary_prompt.clone()], Value::Null).await { + match self + .app + .llm + .chat_completion(vec![summary_prompt.clone()], Value::Null) + .await + { Ok(summary_res) => return Ok(summary_res), Err(e) => { retries += 1; if retries > max_retries { return Err(e); } - println!("[Compaction Failed]: {}. Retrying ({}/{}) in 2 seconds...", e, retries, max_retries); + println!( + "[Compaction Failed]: {}. Retrying ({}/{}) in 2 seconds...", + e, retries, max_retries + ); tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; } } @@ -49,14 +57,17 @@ impl Node for DoCompactNode { Ok(v) => v, Err(e) => { println!("[Compaction Failed]: {}", e); - return Ok(ProcessResult::new(PiState::CallLLM, "call_llm_fallback".to_string())); + return Ok(ProcessResult::new( + PiState::CallLLM, + "call_llm_fallback".to_string(), + )); } }; if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { if let Some(choice) = choices.first() { let summary_text = choice["message"]["content"].as_str().unwrap_or(""); - + let compact_msg = AgentMessage { id: Uuid::new_v4().to_string(), parent_id: None, @@ -67,15 +78,15 @@ impl Node for DoCompactNode { tool_call_id: None, clears_history: Some(true), }; - + self.app.session_manager.append_message(&compact_msg)?; - + let messages = json!([compact_msg]); context.set("messages", messages); println!("History compressed successfully."); } } - + Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) } } diff --git a/examples/pi/src/nodes/input.rs b/examples/pi/src/nodes/input.rs index 79a005c..bf76cf0 100644 --- a/examples/pi/src/nodes/input.rs +++ b/examples/pi/src/nodes/input.rs @@ -1,11 +1,11 @@ +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; use anyhow::Result; use pocketflow_rs::{Context, Node, ProcessResult}; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::io::{self, Write}; use std::sync::Arc; use uuid::Uuid; -use crate::state::{AppContext, PiState}; -use crate::utils::session_manager::AgentMessage; pub struct InputNode { pub app: Arc, @@ -27,7 +27,7 @@ impl Node for InputNode { } let id = Uuid::new_v4().to_string(); - + let msg = AgentMessage { id: id.clone(), parent_id: None, @@ -52,7 +52,10 @@ impl Node for InputNode { ) -> Result> { let res = result.as_ref().unwrap(); if res.get("command").and_then(|v| v.as_str()) == Some("exit") { - return Ok(ProcessResult::new(PiState::Finished, "finished".to_string())); + return Ok(ProcessResult::new( + PiState::Finished, + "finished".to_string(), + )); } let msg_val = res.get("message").unwrap(); @@ -60,6 +63,9 @@ impl Node for InputNode { messages.as_array_mut().unwrap().push(msg_val.clone()); context.set("messages", messages); - Ok(ProcessResult::new(PiState::CheckSize, "check_size".to_string())) + Ok(ProcessResult::new( + PiState::CheckSize, + "check_size".to_string(), + )) } } diff --git a/examples/pi/src/nodes/llm_reasoning.rs b/examples/pi/src/nodes/llm_reasoning.rs index a871aca..25c3e69 100644 --- a/examples/pi/src/nodes/llm_reasoning.rs +++ b/examples/pi/src/nodes/llm_reasoning.rs @@ -1,10 +1,10 @@ +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; use anyhow::Result; use pocketflow_rs::{Context, Node, ProcessResult}; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::sync::Arc; use uuid::Uuid; -use crate::state::{AppContext, PiState}; -use crate::utils::session_manager::AgentMessage; pub struct LLMReasoningNode { pub app: Arc, @@ -16,7 +16,7 @@ impl Node for LLMReasoningNode { async fn execute(&self, context: &Context) -> Result { let messages = context.get("messages").unwrap_or(&json!([])).clone(); - + let tools = json!([ { "type": "function", @@ -73,12 +73,18 @@ impl Node for LLMReasoningNode { }); if let Some(calls) = m.get("tool_calls") { if !calls.is_null() { - mapped.as_object_mut().unwrap().insert("tool_calls".to_string(), calls.clone()); + mapped + .as_object_mut() + .unwrap() + .insert("tool_calls".to_string(), calls.clone()); } } if let Some(tid) = m.get("tool_call_id") { if !tid.is_null() { - mapped.as_object_mut().unwrap().insert("tool_call_id".to_string(), tid.clone()); + mapped + .as_object_mut() + .unwrap() + .insert("tool_call_id".to_string(), tid.clone()); } } openai_messages.push(mapped); @@ -98,10 +104,13 @@ impl Node for LLMReasoningNode { Ok(v) => v, Err(e) => { println!("\n[LLM Error]: {}\n", e); - return Ok(ProcessResult::new(PiState::WaitForInput, "error".to_string())); + return Ok(ProcessResult::new( + PiState::WaitForInput, + "error".to_string(), + )); } }; - + // Ensure choice 0 exists if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { if let Some(choice) = choices.first() { @@ -130,17 +139,26 @@ impl Node for LLMReasoningNode { // Update context let mut messages = context.get("messages").cloned().unwrap_or(json!([])); - messages.as_array_mut().unwrap().push(serde_json::to_value(&agent_msg)?); + messages + .as_array_mut() + .unwrap() + .push(serde_json::to_value(&agent_msg)?); context.set("messages", messages); if let Some(tc) = tool_calls { if !tc.is_null() && tc.as_array().map_or(false, |a| !a.is_empty()) { - return Ok(ProcessResult::new(PiState::ExecuteTool, "execute_tool".to_string())); + return Ok(ProcessResult::new( + PiState::ExecuteTool, + "execute_tool".to_string(), + )); } } } } - Ok(ProcessResult::new(PiState::WaitForInput, "wait_for_input".to_string())) + Ok(ProcessResult::new( + PiState::WaitForInput, + "wait_for_input".to_string(), + )) } } diff --git a/examples/pi/src/nodes/mod.rs b/examples/pi/src/nodes/mod.rs index 7105536..230f3d5 100644 --- a/examples/pi/src/nodes/mod.rs +++ b/examples/pi/src/nodes/mod.rs @@ -1,11 +1,11 @@ -mod input; mod check_size; mod do_compact; +mod input; mod llm_reasoning; mod tool_execution; -pub use input::InputNode; pub use check_size::CheckSizeNode; pub use do_compact::DoCompactNode; +pub use input::InputNode; pub use llm_reasoning::LLMReasoningNode; pub use tool_execution::ToolExecutionNode; diff --git a/examples/pi/src/nodes/tool_execution.rs b/examples/pi/src/nodes/tool_execution.rs index 7b145a6..d828ace 100644 --- a/examples/pi/src/nodes/tool_execution.rs +++ b/examples/pi/src/nodes/tool_execution.rs @@ -1,10 +1,10 @@ -use anyhow::Result; -use pocketflow_rs::{Context, Node, ProcessResult}; -use serde_json::{json, Value}; -use std::sync::Arc; use crate::state::{AppContext, PiState}; use crate::utils::session_manager::AgentMessage; use crate::utils::tools::{execute_bash, read_file, write_file}; +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{Value, json}; +use std::sync::Arc; pub struct ToolExecutionNode { pub app: Arc, @@ -17,7 +17,7 @@ impl Node for ToolExecutionNode { async fn execute(&self, context: &Context) -> Result { let messages = context.get("messages").unwrap().as_array().unwrap(); let last_msg = messages.last().unwrap(); - + let mut tool_results = Vec::new(); if let Some(tool_calls) = last_msg.get("tool_calls").and_then(|tc| tc.as_array()) { @@ -29,7 +29,7 @@ impl Node for ToolExecutionNode { let args: Value = serde_json::from_str(args_str)?; println!("Executing tool: {} with args: {}", name, args_str); - + let output = match name { "read_file" => { let path = args["path"].as_str().unwrap(); @@ -70,17 +70,24 @@ impl Node for ToolExecutionNode { context: &mut Context, result: &Result, ) -> Result> { - let tool_results: Vec = serde_json::from_value(result.as_ref().unwrap().clone())?; - + let tool_results: Vec = + serde_json::from_value(result.as_ref().unwrap().clone())?; + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); - + for msg in tool_results { self.app.session_manager.append_message(&msg)?; - messages.as_array_mut().unwrap().push(serde_json::to_value(&msg)?); + messages + .as_array_mut() + .unwrap() + .push(serde_json::to_value(&msg)?); } - + context.set("messages", messages); - Ok(ProcessResult::new(PiState::CheckSize, "check_size".to_string())) + Ok(ProcessResult::new( + PiState::CheckSize, + "check_size".to_string(), + )) } } diff --git a/examples/pi/src/state.rs b/examples/pi/src/state.rs index be2645e..a0875fb 100644 --- a/examples/pi/src/state.rs +++ b/examples/pi/src/state.rs @@ -1,8 +1,8 @@ -use pocketflow_rs::ProcessState; -use strum::Display; +use crate::utils::config::AppConfig; use crate::utils::pi_llm::PiLLM; use crate::utils::session_manager::SessionManager; -use crate::utils::config::AppConfig; +use pocketflow_rs::ProcessState; +use strum::Display; #[derive(Debug, Clone, PartialEq, Default, Display)] #[strum(serialize_all = "snake_case")] diff --git a/examples/pi/src/utils/config.rs b/examples/pi/src/utils/config.rs index ef032a1..44b5903 100644 --- a/examples/pi/src/utils/config.rs +++ b/examples/pi/src/utils/config.rs @@ -20,7 +20,9 @@ pub struct GeneralConfig { impl Default for GeneralConfig { fn default() -> Self { - Self { auto_compact: default_auto_compact() } + Self { + auto_compact: default_auto_compact(), + } } } @@ -46,11 +48,11 @@ impl AppConfig { let workspace = workspace.as_ref(); let mut config_dir = workspace.to_path_buf(); config_dir.push(".pi"); - + if !config_dir.exists() { fs::create_dir_all(&config_dir)?; } - + let mut config_path = config_dir.clone(); config_path.push("config.toml"); @@ -81,10 +83,10 @@ compact_threshold = 6000 let content = fs::read_to_string(&config_path) .with_context(|| format!("Failed to read config file at {:?}", config_path))?; - - let config: AppConfig = toml::from_str(&content) - .with_context(|| "Failed to parse config.toml")?; - + + let config: AppConfig = + toml::from_str(&content).with_context(|| "Failed to parse config.toml")?; + Ok(config) } diff --git a/examples/pi/src/utils/mod.rs b/examples/pi/src/utils/mod.rs index 3174eaf..a6f209e 100644 --- a/examples/pi/src/utils/mod.rs +++ b/examples/pi/src/utils/mod.rs @@ -1,8 +1,8 @@ +pub mod config; pub mod pi_llm; pub mod session_manager; pub mod tools; -pub mod config; +pub use config::AppConfig; pub use pi_llm::PiLLM; pub use session_manager::SessionManager; -pub use config::AppConfig; diff --git a/examples/pi/src/utils/pi_llm.rs b/examples/pi/src/utils/pi_llm.rs index a0b8e14..a5b4dbd 100644 --- a/examples/pi/src/utils/pi_llm.rs +++ b/examples/pi/src/utils/pi_llm.rs @@ -1,6 +1,6 @@ use anyhow::Result; use reqwest::Client; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use tracing::info; pub struct PiLLM { @@ -22,17 +22,21 @@ impl PiLLM { pub async fn chat_completion(&self, messages: Vec, tools: Value) -> Result { info!("Sending LLM request to {}", self.endpoint); - + let mut body = json!({ "model": self.model, "messages": messages, }); if !tools.is_null() && tools.as_array().map(|a| !a.is_empty()).unwrap_or(false) { - body.as_object_mut().unwrap().insert("tools".to_string(), tools); + body.as_object_mut() + .unwrap() + .insert("tools".to_string(), tools); } - let res = self.client.post(&self.endpoint) + let res = self + .client + .post(&self.endpoint) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json") .json(&body) @@ -42,7 +46,11 @@ impl PiLLM { let status = res.status(); let response_json: Value = res.json().await?; if !status.is_success() { - return Err(anyhow::anyhow!("API request failed with status {}: {}", status, response_json)); + return Err(anyhow::anyhow!( + "API request failed with status {}: {}", + status, + response_json + )); } Ok(response_json) diff --git a/examples/pi/src/utils/session_manager.rs b/examples/pi/src/utils/session_manager.rs index 830367a..91f750d 100644 --- a/examples/pi/src/utils/session_manager.rs +++ b/examples/pi/src/utils/session_manager.rs @@ -53,7 +53,7 @@ impl SessionManager { let file = File::open(&self.log_path)?; let reader = BufReader::new(file); - + let mut messages = Vec::new(); for line in reader.lines() { if let Ok(l) = line { @@ -65,7 +65,7 @@ impl SessionManager { } } } - + // Simple linear history for now. // In a full implementation, we would rebuild the tree using parent_id up to head_id. Ok(messages) diff --git a/examples/pi/src/utils/tools/mod.rs b/examples/pi/src/utils/tools/mod.rs index 74d8e03..2ab1b9b 100644 --- a/examples/pi/src/utils/tools/mod.rs +++ b/examples/pi/src/utils/tools/mod.rs @@ -27,7 +27,7 @@ pub fn execute_bash(command: &str, cwd: &str) -> String { .arg(command) .current_dir(cwd) .output(); - + match output { Ok(out) => { let mut result = String::from_utf8_lossy(&out.stdout).to_string(); diff --git a/src/flow.rs b/src/flow.rs index a04432f..e4fe7fc 100644 --- a/src/flow.rs +++ b/src/flow.rs @@ -95,7 +95,7 @@ impl Flow { pub fn to_mermaid(&self) -> String { let mut mermaid = String::from("flowchart TD\n"); - + // 声明所有节点 let mut nodes: Vec<_> = self.nodes.keys().collect(); nodes.sort(); // 确保输出稳定 @@ -114,7 +114,7 @@ impl Flow { mermaid.push_str(&format!(" {} -->|{}| {}\n", from, condition, to)); } } - + mermaid } } @@ -242,11 +242,11 @@ macro_rules! build_batch_flow { #[cfg(test)] mod tests { use super::*; + use crate::node::SubFlowNode; use crate::node::{Node, ProcessResult, ProcessState}; use async_trait::async_trait; use serde_json::json; use strum::Display; - use crate::node::SubFlowNode; #[derive(Debug, Clone, PartialEq, Default, Display)] #[strum(serialize_all = "snake_case")] @@ -409,7 +409,10 @@ mod tests { match result { Ok(val) => { parent_ctx.set("result", val.clone()); - Ok(ProcessResult::new(CustomState::Success, "subflow ok".to_string())) + Ok(ProcessResult::new( + CustomState::Success, + "subflow ok".to_string(), + )) } Err(e) => { parent_ctx.set("error", json!(e.to_string())); @@ -435,7 +438,7 @@ mod tests { let node1 = TestNode::new(json!({"data": "test1"}), CustomState::Success); let node2 = TestNode::new(json!({"data": "test2"}), CustomState::Default); let end_node = TestNode::new(json!({"final_result": "finished"}), CustomState::Default); - + let flow = build_flow!( start: ("start", node1), nodes: [("next", node2), ("end", end_node)], diff --git a/src/node.rs b/src/node.rs index bc00dc1..b077568 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,9 +1,9 @@ +use crate::flow::Flow; use crate::{Params, context::Context}; use anyhow::Result; use async_trait::async_trait; -use strum::Display; use std::sync::Arc; -use crate::flow::Flow; +use strum::Display; pub trait ProcessState: Send + Sync + std::fmt::Display { fn is_default(&self) -> bool; @@ -146,6 +146,7 @@ where { pub sub_flow: Arc>, pub context_builder: Box Context + Send + Sync>, + #[allow(clippy::type_complexity)] pub result_mapper: Box< dyn Fn(&mut Context, &Result) -> Result> + Send @@ -161,10 +162,13 @@ where pub fn new( sub_flow: Flow, context_builder: impl Fn(&Context) -> Context + Send + Sync + 'static, - result_mapper: impl Fn(&mut Context, &Result) -> Result> - + Send - + Sync - + 'static, + result_mapper: impl Fn( + &mut Context, + &Result, + ) -> Result> + + Send + + Sync + + 'static, ) -> Self { Self { sub_flow: Arc::new(sub_flow),