diff --git a/.gitignore b/.gitignore index ab099f4..4f9030e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,6 @@ target .env Cargo.lock -.vscode/ \ No newline at end of file +.vscode/ +.pi/ +test_dir/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..e73695e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,215 @@ +--- +layout: default +title: "Agentic Coding" +--- + +# Agentic Coding: Humans Design, Agents code! + +> If you are an AI agent building apps with PocketFlow-Rust, read this guide carefully. Start small, design at a high level first (`docs/design.md`), then implement and verify. +{: .warning } + +## Agentic Coding Steps + +Agentic coding should be a collaboration between human system design and AI implementation. + +| Steps | Human | AI | Comment | +|:--|:--:|:--:|:--| +| 1. Requirements | ★★★ High | ★☆☆ Low | Humans define the problem and success criteria. | +| 2. Flow | ★★☆ Medium | ★★☆ Medium | Humans define the orchestration; AI fills in details. | +| 3. Utilities | ★★☆ Medium | ★★☆ Medium | Humans provide external APIs; AI helps implement wrappers. | +| 4. Data | ★☆☆ Low | ★★★ High | AI proposes the schema; humans verify it matches the app. | +| 5. Node | ★☆☆ Low | ★★★ High | AI designs nodes around the flow and shared store. | +| 6. Implementation | ★☆☆ Low | ★★★ High | AI implements the flow and nodes from the design. | +| 7. Optimization | ★★☆ Medium | ★★☆ Medium | Iterate on prompts, data shape, and flow structure. | +| 8. Reliability | ★☆☆ Low | ★★★ High | Add validation, retries, logging, and tests. | + +1. **Requirements**: Clarify the user problem, not just the feature list. + - Good for: repetitive tasks, structured transformations, workflow automation, RAG, agent loops. + - Not good for: vague goals without measurable outputs or unstable business decisions. + - Keep it user-centric and small at first. + +2. **Flow Design**: Define the graph at a high level. 
+ - Pick a pattern if it fits: [Agent](./design_pattern/agent.md), [Workflow](./design_pattern/workflow.md), [RAG](./design_pattern/rag.md), [Map Reduce](./design_pattern/mapreduce.md). + - For each node, write a one-line purpose and its next action conditions. + - Draw the flow in mermaid. + - Use the Rust API as source of truth: `Node`, `Flow`, `BatchFlow`, `ProcessState`. + - Example: + ```mermaid + flowchart LR + start[Load Input] --> process[Process] + process --> finish[Finish] + ``` + - If you cannot describe the flow manually, do not automate it yet. + +3. **Utilities**: Identify required external I/O helpers. + - Think of utilities as the body of the agent: file I/O, web requests, LLM calls, DB access, embeddings. + - Keep LLM tasks inside nodes or utilities, but do not confuse them with orchestration. + - Put reusable wrappers in `src/utils/*.rs` and add a small test when practical. + - Prefer returning `anyhow::Result` and keep wrappers narrow and deterministic. + - Example utility shape (real wrapper from `src/utils/llm_wrapper.rs`): + ```rust + use async_trait::async_trait; + + #[async_trait] + pub trait LLMWrapper { + async fn generate(&self, prompt: &str) -> anyhow::Result; + } + + pub struct OpenAIClient { + api_key: String, + model: String, + endpoint: String, + } + + impl OpenAIClient { + pub fn new(api_key: String, model: String, endpoint: String) -> Self { + Self { api_key, model, endpoint } + } + } + + #[async_trait] + impl LLMWrapper for OpenAIClient { + async fn generate(&self, prompt: &str) -> anyhow::Result { + // Use openai_api_rust or reqwest to call the API + todo!() + } + } + ``` + +4. **Data Design**: Design the shared store before coding the nodes. + - The shared store is `Context`. + - Use `Context::set()` / `Context::get()` for shared data. + - Use `metadata` for auxiliary data that should not be treated as primary results. + - Keep keys simple and avoid redundancy. 
+ - Example: + ```rust + use pocketflow_rs::Context; + use serde_json::json; + + let mut context = Context::new(); + context.set("input", json!("hello")); + context.set_metadata("source", json!("user")); + ``` + +5. **Node Design**: Plan each node’s role and state transitions. + - `prepare(&mut context)`: optional, read from `Context` and prepare inputs. + - `execute(&context)`: do compute or remote calls; keep it idempotent when possible. + - `post_process(&mut context, &result)`: write outputs back to `Context` and return `ProcessResult`. + - Define a custom `State` enum implementing `ProcessState` for branching; use `BaseState` when simple. + - Example node shape (from `examples/basic.rs`): + ```rust + use anyhow::Result; + use async_trait::async_trait; + use pocketflow_rs::{Context, Node, ProcessResult, ProcessState}; + use serde_json::Value; + use strum::Display; + + #[derive(Debug, Clone, PartialEq, Default, Display)] + #[strum(serialize_all = "snake_case")] + enum MyState { + #[default] + Default, + Success, + } + + impl ProcessState for MyState { + fn is_default(&self) -> bool { + matches!(self, MyState::Default) + } + } + + struct MyNode; + + #[async_trait] + impl Node for MyNode { + type State = MyState; + + async fn execute(&self, context: &Context) -> Result { + let input = context.get("input").cloned().unwrap_or(Value::Null); + Ok(input) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + match result { + Ok(value) => { + context.set("output", value.clone()); + Ok(ProcessResult::new(MyState::Success, "done".to_string())) + } + Err(e) => { + context.set("error", Value::String(e.to_string())); + Ok(ProcessResult::new(MyState::Default, e.to_string())) + } + } + } + } + ``` + +6. **Implementation**: Build the initial nodes and flows. + - Keep the first pass simple. + - Use `build_flow!` and `build_batch_flow!` instead of hand-wiring infrastructure. 
+ - Example flow assembly: + ```rust + use pocketflow_rs::{build_flow, Flow, BaseState}; + + pub fn create_flow() -> Flow { + build_flow!( + start: ("get_input", GetInputNode), + nodes: [ + ("process", ProcessNode), + ("output", OutputNode) + ], + edges: [ + ("get_input", "process", BaseState::Default), + ("process", "output", BaseState::Default) + ] + ) + } + ``` + - Add logging via `tracing` where it helps debugging. + - Prefer small, composable nodes over large monoliths. + +7. **Optimization**: Improve after the first working version. + - Refine the flow when the bottleneck is logic or structure. + - Refine prompts and context when the bottleneck is model behavior. + - Refine utilities when the bottleneck is I/O or integration. + +8. **Reliability**: Make failures visible and recoverable. + - Validate results in `execute` or `post_process`. + - Use the framework’s retry behavior where available. + - Add tests for utility wrappers and critical node transitions. + - Log failures and important decisions. + +## Example Rust Project Layout + +``` +my_project/ +├── Cargo.toml +├── src/ +│ ├── main.rs +│ ├── flow.rs +│ ├── nodes.rs +│ └── utils/ +│ ├── mod.rs +│ ├── llm_wrapper.rs +│ └── web_search.rs +└── docs/ + └── design.md +``` + +- **`Cargo.toml`**: add `pocketflow_rs`, `serde_json`, `anyhow`, `async-trait`, `tokio`, and `strum` as dependencies. +- **`docs/design.md`**: keep it high-level and Rust-oriented; do not copy Python pseudocode. +- **`src/utils/`**: one file per reusable integration is a good default. +- **`src/nodes.rs`**: node definitions should stay focused and readable. +- **`src/flow.rs`**: assemble the flow graph and state transitions. +- **State enums**: use `strum::Display` with `#[strum(serialize_all = "snake_case")]` for automatic state-to-string conversion. + +## Before Finishing + +- Check names, types, and examples against `src/lib.rs`, `src/node.rs`, and `src/flow.rs`. +- Remove any Python-only syntax or nonexistent APIs. 
+- Keep examples runnable or close to runnable against the current Rust API. +- Ensure state enums implement `ProcessState` and use `strum::Display` for edge matching. diff --git a/Cargo.toml b/Cargo.toml index 2ca1114..0ac36f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,15 +10,16 @@ license = "MIT" name = "pocketflow_rs" path = "src/lib.rs" + [[example]] name = "basic" path = "examples/basic.rs" [workspace] -members = [ - "examples/pocketflow-rs-rag", - "examples/text2sql" -] +members = ["examples/pocketflow-rs-rag", "examples/text2sql", "examples/pi"] + +[workspace.dependencies] +strum = "0.26" [dependencies] anyhow = "1.0" @@ -29,16 +30,19 @@ serde_json = "1.0" thiserror = "1.0" tracing = "0.1" rand = "0.8" -openai_api_rust = { version = "0.1.9", optional = true} +openai_api_rust = { version = "0.1.9", optional = true } regex = "1.11.1" -qdrant-client = {version = "1.14.0", optional = true} -reqwest = { version = "0.12", features = ["json"], optional = true } +strum = { version = "0.26", features = ["derive"] } +qdrant-client = { version = "1.14.0", optional = true } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +clap = { version = "4.4", features = ["derive"] } +directories = "5.0" +uuid = { version = "1.8", features = ["v4", "fast-rng"] } +toml = "0.8" [features] openai = ["dep:openai_api_rust"] -websearch = ["dep:reqwest"] +websearch = [] qdrant = ["dep:qdrant-client"] debug = [] -default = [ - "openai", -] \ No newline at end of file +default = ["openai"] diff --git a/README.md b/README.md index 2cb1c76..79875b7 100644 --- a/README.md +++ b/README.md @@ -4,34 +4,41 @@ A Rust implementation of [PocketFlow](https://github.com/The-Pocket/PocketFlow), a minimalist flow-based programming framework. 
-📋 [Get started quickly with our template →](#template) +🚀 [Get started quickly with our template →](#template) -## Features +## ✨ Features -- Type-safe state transitions using enums -- Macro-based flow construction -- Async node execution and post-processing -- Batch flow support -- Custom state management -- Extensible node system +- 🦀 **Type-safe:** State transitions using Rust enums +- 🏗️ **Macro-based:** Flow construction using `build_flow!` and `build_batch_flow!` +- ⚡ **Async first:** Non-blocking node execution and post-processing +- 📦 **Batch support:** High-performance processing of multiple contexts +- 🧩 **Extensible:** Custom state management and node systems +- 🌳 **Composable:** Support for `SubFlow` nodes to build recursive workflows +- 📊 **Flow Visualization:** Generate Mermaid TD diagrams from flow graphs +- 🛠️ **Utility-rich:** Optional integrations for OpenAI, Qdrant, and web search -## Quick Start +## 🚀 Quick Start ### 0. Setup -```bash -cargo add pocketflow_rs +```toml +[dependencies] +pocketflow_rs = "0.1.0" +strum = { version = "0.26", features = ["derive"] } ``` ### 1. 
Define Custom States ```rust use pocketflow_rs::ProcessState; +use strum::Display; -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] pub enum MyState { Success, Failure, + #[default] Default, } @@ -39,19 +46,8 @@ impl ProcessState for MyState { fn is_default(&self) -> bool { matches!(self, MyState::Default) } - fn to_condition(&self) -> String { - match self { - MyState::Success => "success".to_string(), - MyState::Failure => "failure".to_string(), - MyState::Default => "default".to_string(), - } - } -} -impl Default for MyState { - fn default() -> Self { - MyState::Default - } + // to_condition() is automatically implemented via Display (strum) } ``` @@ -79,12 +75,12 @@ impl Node for MyNode { result: &Result, ) -> Result> { // Your post-processing logic here - Ok(ProcessResult::new(MyState::Success, "success".to_string())) + Ok(ProcessResult::new(MyState::Success, "success")) } } ``` -### 3. Build Flows +### 3. Build & Run Flows ```rust use pocketflow_rs::{build_flow, Context}; @@ -104,110 +100,62 @@ let context = Context::new(); let result = flow.run(context).await?; ``` -### 4. Batch Processing - -```rust -use pocketflow_rs::build_batch_flow; - -let batch_flow = build_batch_flow!( - start: ("start", node1), - nodes: [("next", node2)], - edges: [ - ("start", "next", MyState::Success) - ], - batch_size: 10 -); - -let contexts = vec![Context::new(); 10]; -batch_flow.run_batch(contexts).await?; -``` - -## Advanced Usage +## 🏗️ Advanced Usage -### Custom State Management +### Composition with SubFlow -Define your own states to control flow transitions: +You can nest a `Flow` inside another `Flow` using `SubFlowNode`. This allows for modular and recursive workflow design. 
```rust -#[derive(Debug, Clone, PartialEq)] -pub enum WorkflowState { - Initialized, - Processing, - Completed, - Error, - Default, -} +use pocketflow_rs::{SubFlowNode, Flow}; -impl ProcessState for WorkflowState { - fn is_default(&self) -> bool { - matches!(self, WorkflowState::Default) - } - fn to_condition(&self) -> String { - match self { - WorkflowState::Initialized => "initialized".to_string(), - WorkflowState::Processing => "processing".to_string(), - WorkflowState::Completed => "completed".to_string(), - WorkflowState::Error => "error".to_string(), - WorkflowState::Default => "default".to_string(), - } - } -} +let sub_flow = create_sub_flow(); +let sub_flow_node = SubFlowNode::new( + sub_flow, + |parent_context| parent_context.clone(), // context_builder + |parent_context, result| { /* map sub-flow result back to parent state */ } +); ``` -### Complex Flow Construction +### Flow Visualization -Build complex workflows with multiple nodes and state transitions: +Generate Mermaid TD flowchart strings automatically to visualize your complex workflows. ```rust -let flow = build_flow!( - start: ("start", node1), - nodes: [ - ("process", node2), - ("validate", node3), - ("complete", node4) - ], - edges: [ - ("start", "process", WorkflowState::Initialized), - ("process", "validate", WorkflowState::Processing), - ("validate", "process", WorkflowState::Error), - ("validate", "complete", WorkflowState::Completed) - ] -); +let mermaid = flow.to_mermaid(); +println!("{}", mermaid); ``` -## Available Features +Check out `test_dir/pi_flow.mmd` after running the `pi` agent for a real-world example. 
-The following features are available: (feature for [utility_function](https://the-pocket.github.io/PocketFlow/utility_function/)) +## 🛠️ Available Features -- `openai` (default): Enable OpenAI API integration for LLM capabilities -- `websearch`: Enable web search functionality using Google Custom Search API -- `qdrant`: Enable vector database integration using Qdrant -- `debug`: Enable additional debug logging and information +Customize `pocketflow_rs` by enabling the features you need in your `Cargo.toml`: -To use specific features, add them to your `Cargo.toml`: +| Feature | Description | +|---------|-------------| +| `openai` (default) | OpenAI API integration for LLM capabilities | +| `websearch` | Google Custom Search API integration | +| `qdrant` | Vector database integration using Qdrant | +| `debug` | Enhanced logging and visualization tools | +Example: ```toml -[dependencies] -pocketflow_rs = { version = "0.1.0", features = ["openai", "websearch"] } -``` - -Or use them in the command line: - -```bash -cargo add pocketflow_rs --features "openai websearch" +pocketflow_rs = { version = "0.1.0", features = ["openai", "qdrant"] } ``` -## Examples +## 📂 Examples -Check out the `examples/` directory for more detailed examples: +Check out the `examples/` directory for detailed implementations: -- basic.rs: Basic flow with custom states -- text2sql: Text-to-SQL workflow example -- [pocketflow-rs-rag](./examples/pocketflow-rs-rag/README.md): Retrieval-Augmented Generation (RAG) workflow example +- 🟢 [**basic.rs**](./examples/basic.rs): Basic flow with custom states +- 🤖 [**pi**](./examples/pi/): Modular interactive coding agent with tool-use and history compaction +- 🗃️ [**text2sql**](./examples/text2sql/): Text-to-SQL workflow using OpenAI +- 🔍 [**pocketflow-rs-rag**](./examples/pocketflow-rs-rag/): Retrieval-Augmented Generation (RAG) system -## Template +## 📋 Template -Fork the [PocketFlow-Template-Rust](https://github.com/The-Pocket/PocketFlow-Template-Rust) 
repository and use it as a template for your own project. +Don't start from scratch! Use the [PocketFlow-Template-Rust](https://github.com/The-Pocket/PocketFlow-Template-Rust) to kickstart your project. ## License diff --git a/docs/design.md b/docs/design.md new file mode 100644 index 0000000..8838469 --- /dev/null +++ b/docs/design.md @@ -0,0 +1,78 @@ +# Design Doc: Your Project Name + +> Please DON'T remove notes for AI + +## Requirements + +> Notes for AI: Keep it simple and clear. +> If the requirements are abstract, write concrete user stories + + +## Flow Design + +> Notes for AI: +> 1. Consider the design patterns of agent, map-reduce, rag, and workflow. Apply them if they fit. +> 2. Present a concise, high-level description of the workflow. + +### Applicable Design Pattern: + +1. Map the file summary into chunks, then reduce these chunks into a final summary. +2. Agentic file finder + - *Context*: The entire summary of the file + - *Action*: Find the file + +### Flow high-level Design: + +1. **First Node**: This node is for ... +2. **Second Node**: This node is for ... +3. **Third Node**: This node is for ... + +```mermaid +flowchart TD + firstNode[First Node] --> secondNode[Second Node] + secondNode --> thirdNode[Third Node] +``` +## Utility Functions + +> Notes for AI: +> 1. Understand the utility function definition thoroughly by reviewing the doc. +> 2. Include only the necessary utility functions, based on nodes in the flow. + +1. **Call LLM** (`utils/call_llm.py`) + - *Input*: prompt (str) + - *Output*: response (str) + - Generally used by most nodes for LLM tasks + +2. 
**Embedding** (`utils/get_embedding.py`) + - *Input*: str + - *Output*: a vector of 3072 floats + - Used by the second node to embed text + +## Node Design + +### Shared Store + +> Notes for AI: Try to minimize data redundancy + +The shared store structure is organized as follows: + +```python +shared = { + "key": "value" +} +``` + +### Node Steps + +> Notes for AI: Carefully decide whether to use Batch/Async Node/Flow. + +1. First Node + - *Purpose*: Provide a short explanation of the node’s function + - *Type*: Decide between Regular, Batch, or Async + - *Steps*: + - *prep*: Read "key" from the shared store + - *exec*: Call the utility function + - *post*: Write "key" to the shared store + +2. Second Node + ... diff --git a/examples/basic.rs b/examples/basic.rs index 421763c..15052ba 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -2,8 +2,9 @@ use anyhow::Result; use pocketflow_rs::{Context, Node, ProcessResult, ProcessState, build_flow}; use rand::Rng; use serde_json::Value; - -#[derive(Debug, Clone, PartialEq, Default)] +use strum::Display; +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] enum NumberState { Small, Medium, @@ -16,15 +17,6 @@ impl ProcessState for NumberState { fn is_default(&self) -> bool { matches!(self, NumberState::Default) } - - fn to_condition(&self) -> String { - match self { - NumberState::Small => "small".to_string(), - NumberState::Medium => "medium".to_string(), - NumberState::Large => "large".to_string(), - NumberState::Default => "default".to_string(), - } - } } // A simple node that prints a message @@ -165,6 +157,12 @@ async fn main() -> std::result::Result<(), Box> { // Create context let context = Context::new(); + // Export flow visualization + std::fs::create_dir_all("test_dir")?; + let mermaid = flow.to_mermaid(); + std::fs::write("test_dir/basic_flow.mmd", mermaid)?; + println!("Saved flow visualization to test_dir/basic_flow.mmd"); + // Run the flow println!("Starting 
flow execution..."); flow.run(context).await?; diff --git a/examples/pi/CONFIG_GUIDE.md b/examples/pi/CONFIG_GUIDE.md new file mode 100644 index 0000000..8ae9d67 --- /dev/null +++ b/examples/pi/CONFIG_GUIDE.md @@ -0,0 +1,72 @@ +# PocketFlow-RS 配置指南 + +本项目使用 `config.toml` 管理模型、API 提供商和通用设置。配置文件位于工作区下的 `./.pi/config.toml`。 + +## 文件位置 +- **默认路径**:`./.pi/config.toml` +- **日志路径**:`./.pi/logs/log.jsonl` + +> 提示:当你首次运行 `pi` 命令时,若配置不存在,会自动创建默认配置。 + +## 配置结构 + +```toml +[general] +auto_compact = true + +[providers.openai] +api_base = "https://api.openai.com/v1" +api_key_env = "OPENAI_API_KEY" + +[providers.cerebras] +api_base = "http://127.0.0.1:4000/v1" +api_key_env = "OPENAI_API_KEY" + +[models."gpt-4o"] +provider = "openai" +context_window = 128000 +compact_threshold = 100000 + +[models."cerebras/qwen-3-235b-a22b-instruct-2507"] +provider = "cerebras" +context_window = 8192 +compact_threshold = 6000 +``` + +### 字段说明 + +| 区域 | 字段 | 说明 | +|------|------|------| +| `general` | `auto_compact` | 是否开启自动历史压缩(摘要旧消息以节省上下文) | +| `providers` | `api_base` | 模型提供商的 API 地址 | +| `providers` | `api_key_env` | 读取 API 密钥的环境变量名 | +| `models` | `provider` | 该模型使用的提供商(需匹配上方 `providers` 中的键) | +| `models` | `context_window` | 模型最大上下文长度(token 数) | +| `models` | `compact_threshold` | 超过此长度时触发自动压缩 | + +## 自定义配置示例 + +你可以添加本地大模型支持,例如: + +```toml +[providers.localhost] +api_base = "http://127.0.0.1:8080/v1" +api_key_env = "LOCAL_API_KEY" + +[models."qwen:14b"] +provider = "localhost" +context_window = 32768 +compact_threshold = 24000 +``` + +然后运行: +```bash +pi --model qwen:14b --provider localhost +``` + +--- + +📌 提示:确保相应 `api_key_env` 环境变量已设置,例如: +```bash +export OPENAI_API_KEY=sk-... 
+``` diff --git a/examples/pi/Cargo.toml b/examples/pi/Cargo.toml new file mode 100644 index 0000000..4b26e56 --- /dev/null +++ b/examples/pi/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "pi" +version = "0.1.0" +edition = "2024" + +[dependencies] +pocketflow_rs = { path = "../../", features = ["openai"] } +anyhow = "1.0" +tokio = { version = "1.0", features = ["full"] } +serde_json = "1.0" +clap = { version = "4.4", features = ["derive"] } +uuid = { version = "1.8", features = ["v4", "fast-rng"] } +strum = { version = "0.26", features = ["derive"] } +async-trait = "0.1" +serde = { version = "1.0", features = ["derive"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +toml = "0.8" +tracing = "0.1" diff --git a/examples/pi/src/lib.rs b/examples/pi/src/lib.rs new file mode 100644 index 0000000..c891e4c --- /dev/null +++ b/examples/pi/src/lib.rs @@ -0,0 +1,7 @@ +pub mod nodes; +pub mod state; +pub mod utils; + +pub use nodes::*; +pub use state::*; +pub use utils::*; diff --git a/examples/pi/src/main.rs b/examples/pi/src/main.rs new file mode 100644 index 0000000..d7f218f --- /dev/null +++ b/examples/pi/src/main.rs @@ -0,0 +1,132 @@ +use anyhow::Result; +use clap::Parser; +use pi::{ + AppConfig, AppContext, CheckSizeNode, DoCompactNode, InputNode, LLMReasoningNode, PiLLM, + PiState, SessionManager, ToolExecutionNode, +}; +use pocketflow_rs::{Context, build_flow}; +use serde_json::json; +use std::sync::Arc; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(short, long)] + interactive: bool, + + #[arg(short, long, default_value = "openai")] + provider: String, + + #[arg(short, long, default_value = "gpt-4o")] + model: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + + // Setup directory and SessionManager + let cwd = std::env::current_dir()?; + let session_manager = SessionManager::new(&cwd); + + // Load config + let config = 
AppConfig::load(&cwd)?; + + // Load API Key + let (api_key, mut endpoint): (String, String) = + if let Some(model_conf) = config.models.get(&args.model) { + if let Some(provider_conf) = config.providers.get(&model_conf.provider) { + let key = std::env::var(&provider_conf.api_key_env) + .unwrap_or_else(|_| "dummy_key".to_string()); + (key, provider_conf.api_base.clone()) + } else { + ( + std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()), + "https://api.openai.com/v1".to_string(), + ) + } + } else { + println!( + "Warning: Model '{}' not found in config, falling back to openai env vars.", + args.model + ); + ( + std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "dummy_key".to_string()), + std::env::var("OPENAI_BASE_URL") + .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()), + ) + }; + + if !endpoint.ends_with("/chat/completions") { + endpoint = format!("{}/chat/completions", endpoint.trim_end_matches('/')); + } + + let llm = PiLLM::new(api_key, args.model.clone(), endpoint); + + let app_context = Arc::new(AppContext { + llm, + session_manager, + config, + model_name: args.model.clone(), + }); + + let input_node = InputNode { + app: app_context.clone(), + }; + let check_size_node = CheckSizeNode { + app: app_context.clone(), + }; + let compact_node = DoCompactNode { + app: app_context.clone(), + }; + let llm_node = LLMReasoningNode { + app: app_context.clone(), + }; + let tool_node = ToolExecutionNode { + app: app_context.clone(), + }; + + let flow = build_flow!( + start: ("input", input_node), + nodes: [ + ("check_size", check_size_node), + ("do_compact", compact_node), + ("llm", llm_node), + ("tool", tool_node) + ], + edges: [ + ("input", "check_size", PiState::CheckSize), + ("check_size", "do_compact", PiState::DoCompact), + ("check_size", "llm", PiState::CallLLM), + ("do_compact", "llm", PiState::CallLLM), + ("llm", "tool", PiState::ExecuteTool), + ("llm", "input", PiState::WaitForInput), + ("tool", "check_size", 
PiState::CheckSize) + ] + ); + + let mut context = Context::new(); + + // Load history + let history = app_context.session_manager.load_history(None)?; + if !history.is_empty() { + println!("Loaded {} messages from history.", history.len()); + let val = serde_json::to_value(history)?; + context.set("messages", val); + } else { + context.set("messages", json!([])); + } + // Export flow visualization + std::fs::create_dir_all("test_dir")?; + let mermaid = flow.to_mermaid(); + std::fs::write("test_dir/pi_flow.mmd", mermaid)?; + println!("Saved flow visualization to test_dir/pi_flow.mmd"); + println!("pi agent started. Type 'exit' to quit."); + + match flow.run(context).await { + Ok(_) => println!("Agent shutdown."), + Err(e) => eprintln!("Error running flow: {}", e), + } + + Ok(()) +} diff --git a/examples/pi/src/nodes/check_size.rs b/examples/pi/src/nodes/check_size.rs new file mode 100644 index 0000000..17edb0e --- /dev/null +++ b/examples/pi/src/nodes/check_size.rs @@ -0,0 +1,63 @@ +use crate::state::{AppContext, PiState}; +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{Value, json}; +use std::sync::Arc; + +pub struct CheckSizeNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for CheckSizeNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let messages = context.get("messages").unwrap_or(&json!([])).clone(); + let msgs_str = serde_json::to_string(&messages).unwrap_or_default(); + let estimated_tokens = msgs_str.len() / 4; + + // Handle gracefully if config doesn't perfectly match model + let default_compact_thresh = 100000; + let threshold = self + .app + .config + .models + .get(&self.app.model_name) + .map(|m| m.compact_threshold) + .unwrap_or(default_compact_thresh); + + if self.app.config.general.auto_compact + && estimated_tokens > threshold + && messages.as_array().map(|a| a.len() > 3).unwrap_or(false) + { + println!( + "\n[Auto Compacting History (est {} 
tokens > {})]...", + estimated_tokens, threshold + ); + Ok(json!({ "needs_compact": true, "history_str": msgs_str })) + } else { + Ok(json!({ "needs_compact": false })) + } + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = result.as_ref().unwrap(); + if res.get("needs_compact").and_then(|v| v.as_bool()) == Some(true) { + context.set( + "history_to_compact", + res.get("history_str").unwrap().clone(), + ); + Ok(ProcessResult::new( + PiState::DoCompact, + "do_compact".to_string(), + )) + } else { + Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + } + } +} diff --git a/examples/pi/src/nodes/do_compact.rs b/examples/pi/src/nodes/do_compact.rs new file mode 100644 index 0000000..6e89284 --- /dev/null +++ b/examples/pi/src/nodes/do_compact.rs @@ -0,0 +1,92 @@ +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{Value, json}; +use std::sync::Arc; +use uuid::Uuid; + +pub struct DoCompactNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for DoCompactNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let history_str = context.get("history_to_compact").unwrap().as_str().unwrap(); + + let summary_prompt = json!({ + "role": "user", + "content": format!("Summarize the entire conversation history concisely, retaining all tool outcomes and important context so it can be used to replace the history entirely:\n{}", history_str) + }); + + println!("Sending compaction request to LLM..."); + let mut retries = 0; + let max_retries = 3; + loop { + match self + .app + .llm + .chat_completion(vec![summary_prompt.clone()], Value::Null) + .await + { + Ok(summary_res) => return Ok(summary_res), + Err(e) => { + retries += 1; + if retries > max_retries { + return Err(e); + } + println!( + "[Compaction Failed]: {}. 
Retrying ({}/{}) in 2 seconds...", + e, retries, max_retries + ); + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + } + } + } + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = match result { + Ok(v) => v, + Err(e) => { + println!("[Compaction Failed]: {}", e); + return Ok(ProcessResult::new( + PiState::CallLLM, + "call_llm_fallback".to_string(), + )); + } + }; + + if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { + if let Some(choice) = choices.first() { + let summary_text = choice["message"]["content"].as_str().unwrap_or(""); + + let compact_msg = AgentMessage { + id: Uuid::new_v4().to_string(), + parent_id: None, + role: "system".to_string(), + content: format!("Previous conversation summary:\n{}", summary_text), + name: None, + tool_calls: None, + tool_call_id: None, + clears_history: Some(true), + }; + + self.app.session_manager.append_message(&compact_msg)?; + + let messages = json!([compact_msg]); + context.set("messages", messages); + println!("History compressed successfully."); + } + } + + Ok(ProcessResult::new(PiState::CallLLM, "call_llm".to_string())) + } +} diff --git a/examples/pi/src/nodes/input.rs b/examples/pi/src/nodes/input.rs new file mode 100644 index 0000000..bf76cf0 --- /dev/null +++ b/examples/pi/src/nodes/input.rs @@ -0,0 +1,71 @@ +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{Value, json}; +use std::io::{self, Write}; +use std::sync::Arc; +use uuid::Uuid; + +pub struct InputNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for InputNode { + type State = PiState; + + async fn execute(&self, _context: &Context) -> Result { + print!("> "); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let text = input.trim().to_string(); + + if text == "exit" 
|| text == "quit" { + return Ok(json!({ "command": "exit" })); + } + + let id = Uuid::new_v4().to_string(); + + let msg = AgentMessage { + id: id.clone(), + parent_id: None, + role: "user".to_string(), + content: text, + name: None, + tool_calls: None, + tool_call_id: None, + clears_history: None, + }; + + // Persist immediately + self.app.session_manager.append_message(&msg)?; + + Ok(json!({ "message": msg })) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = result.as_ref().unwrap(); + if res.get("command").and_then(|v| v.as_str()) == Some("exit") { + return Ok(ProcessResult::new( + PiState::Finished, + "finished".to_string(), + )); + } + + let msg_val = res.get("message").unwrap(); + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); + messages.as_array_mut().unwrap().push(msg_val.clone()); + context.set("messages", messages); + + Ok(ProcessResult::new( + PiState::CheckSize, + "check_size".to_string(), + )) + } +} diff --git a/examples/pi/src/nodes/llm_reasoning.rs b/examples/pi/src/nodes/llm_reasoning.rs new file mode 100644 index 0000000..25c3e69 --- /dev/null +++ b/examples/pi/src/nodes/llm_reasoning.rs @@ -0,0 +1,164 @@ +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{Value, json}; +use std::sync::Arc; +use uuid::Uuid; + +pub struct LLMReasoningNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for LLMReasoningNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let messages = context.get("messages").unwrap_or(&json!([])).clone(); + + let tools = json!([ + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read the contents of a file", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string" } + }, + "required": ["path"] + } + } + 
}, + { + "type": "function", + "function": { + "name": "write_file", + "description": "Write contents to a file", + "parameters": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "content": { "type": "string" } + }, + "required": ["path", "content"] + } + } + }, + { + "type": "function", + "function": { + "name": "bash", + "description": "Execute a bash command", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" } + }, + "required": ["command"] + } + } + } + ]); + + let mut openai_messages = Vec::new(); + // Convert AgentMessage to format expected by OpenAI + if let Some(arr) = messages.as_array() { + for m in arr { + let mut mapped = json!({ + "role": m["role"].as_str().unwrap(), + "content": m["content"].as_str().unwrap() + }); + if let Some(calls) = m.get("tool_calls") { + if !calls.is_null() { + mapped + .as_object_mut() + .unwrap() + .insert("tool_calls".to_string(), calls.clone()); + } + } + if let Some(tid) = m.get("tool_call_id") { + if !tid.is_null() { + mapped + .as_object_mut() + .unwrap() + .insert("tool_call_id".to_string(), tid.clone()); + } + } + openai_messages.push(mapped); + } + } + + let response = self.app.llm.chat_completion(openai_messages, tools).await?; + Ok(response) + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let res = match result { + Ok(v) => v, + Err(e) => { + println!("\n[LLM Error]: {}\n", e); + return Ok(ProcessResult::new( + PiState::WaitForInput, + "error".to_string(), + )); + } + }; + + // Ensure choice 0 exists + if let Some(choices) = res.get("choices").and_then(|c| c.as_array()) { + if let Some(choice) = choices.first() { + let msg = choice.get("message").unwrap(); + let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or(""); + let tool_calls = msg.get("tool_calls"); + + let agent_msg = AgentMessage { + id: Uuid::new_v4().to_string(), + parent_id: None, + role: "assistant".to_string(), + 
content: content.to_string(), + name: None, + tool_calls: tool_calls.cloned(), + tool_call_id: None, + clears_history: None, + }; + + // Persist + self.app.session_manager.append_message(&agent_msg)?; + + // Print + if !content.is_empty() { + println!("\nAssistant: {}\n", content); + } + + // Update context + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); + messages + .as_array_mut() + .unwrap() + .push(serde_json::to_value(&agent_msg)?); + context.set("messages", messages); + + if let Some(tc) = tool_calls { + if !tc.is_null() && tc.as_array().map_or(false, |a| !a.is_empty()) { + return Ok(ProcessResult::new( + PiState::ExecuteTool, + "execute_tool".to_string(), + )); + } + } + } + } + + Ok(ProcessResult::new( + PiState::WaitForInput, + "wait_for_input".to_string(), + )) + } +} diff --git a/examples/pi/src/nodes/mod.rs b/examples/pi/src/nodes/mod.rs new file mode 100644 index 0000000..230f3d5 --- /dev/null +++ b/examples/pi/src/nodes/mod.rs @@ -0,0 +1,11 @@ +mod check_size; +mod do_compact; +mod input; +mod llm_reasoning; +mod tool_execution; + +pub use check_size::CheckSizeNode; +pub use do_compact::DoCompactNode; +pub use input::InputNode; +pub use llm_reasoning::LLMReasoningNode; +pub use tool_execution::ToolExecutionNode; diff --git a/examples/pi/src/nodes/tool_execution.rs b/examples/pi/src/nodes/tool_execution.rs new file mode 100644 index 0000000..d828ace --- /dev/null +++ b/examples/pi/src/nodes/tool_execution.rs @@ -0,0 +1,93 @@ +use crate::state::{AppContext, PiState}; +use crate::utils::session_manager::AgentMessage; +use crate::utils::tools::{execute_bash, read_file, write_file}; +use anyhow::Result; +use pocketflow_rs::{Context, Node, ProcessResult}; +use serde_json::{Value, json}; +use std::sync::Arc; + +pub struct ToolExecutionNode { + pub app: Arc, +} + +#[async_trait::async_trait] +impl Node for ToolExecutionNode { + type State = PiState; + + async fn execute(&self, context: &Context) -> Result { + let messages = 
context.get("messages").unwrap().as_array().unwrap(); + let last_msg = messages.last().unwrap(); + + let mut tool_results = Vec::new(); + + if let Some(tool_calls) = last_msg.get("tool_calls").and_then(|tc| tc.as_array()) { + for call in tool_calls { + let id = call["id"].as_str().unwrap().to_string(); + let func = &call["function"]; + let name = func["name"].as_str().unwrap(); + let args_str = func["arguments"].as_str().unwrap(); + let args: Value = serde_json::from_str(args_str)?; + + println!("Executing tool: {} with args: {}", name, args_str); + + let output = match name { + "read_file" => { + let path = args["path"].as_str().unwrap(); + read_file(path) + } + "write_file" => { + let path = args["path"].as_str().unwrap(); + let content = args["content"].as_str().unwrap(); + write_file(path, content) + } + "bash" => { + let command = args["command"].as_str().unwrap(); + execute_bash(command, ".") + } + _ => format!("Unknown tool: {}", name), + }; + + let agent_msg = AgentMessage { + id: uuid::Uuid::new_v4().to_string(), + parent_id: None, + role: "tool".to_string(), + content: output, + name: Some(name.to_string()), + tool_calls: None, + tool_call_id: Some(id), + clears_history: None, + }; + + tool_results.push(agent_msg); + } + } + + Ok(serde_json::to_value(tool_results)?) 
+ } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + let tool_results: Vec = + serde_json::from_value(result.as_ref().unwrap().clone())?; + + let mut messages = context.get("messages").cloned().unwrap_or(json!([])); + + for msg in tool_results { + self.app.session_manager.append_message(&msg)?; + messages + .as_array_mut() + .unwrap() + .push(serde_json::to_value(&msg)?); + } + + context.set("messages", messages); + + Ok(ProcessResult::new( + PiState::CheckSize, + "check_size".to_string(), + )) + } +} diff --git a/examples/pi/src/state.rs b/examples/pi/src/state.rs new file mode 100644 index 0000000..a0875fb --- /dev/null +++ b/examples/pi/src/state.rs @@ -0,0 +1,32 @@ +use crate::utils::config::AppConfig; +use crate::utils::pi_llm::PiLLM; +use crate::utils::session_manager::SessionManager; +use pocketflow_rs::ProcessState; +use strum::Display; + +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] +pub enum PiState { + #[default] + Default, + CheckSize, + DoCompact, + CallLLM, + ExecuteTool, + WaitForInput, + Finished, +} + +impl ProcessState for PiState { + fn is_default(&self) -> bool { + matches!(self, PiState::Default) + } +} + +// Global shared components between nodes +pub struct AppContext { + pub llm: PiLLM, + pub session_manager: SessionManager, + pub config: AppConfig, + pub model_name: String, +} diff --git a/examples/pi/src/utils/config.rs b/examples/pi/src/utils/config.rs new file mode 100644 index 0000000..44b5903 --- /dev/null +++ b/examples/pi/src/utils/config.rs @@ -0,0 +1,104 @@ +use anyhow::{Context, Result}; +use serde::Deserialize; +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +#[derive(Debug, Deserialize, Clone)] +pub struct AppConfig { + #[serde(default)] + pub general: GeneralConfig, + pub providers: HashMap, + pub models: HashMap, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct GeneralConfig { + #[serde(default = 
"default_auto_compact")] + pub auto_compact: bool, +} + +impl Default for GeneralConfig { + fn default() -> Self { + Self { + auto_compact: default_auto_compact(), + } + } +} + +fn default_auto_compact() -> bool { + true +} + +#[derive(Debug, Deserialize, Clone)] +pub struct ProviderConfig { + pub api_base: String, + pub api_key_env: String, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct ModelConfig { + pub provider: String, + pub context_window: usize, + pub compact_threshold: usize, +} + +impl AppConfig { + pub fn load>(workspace: P) -> Result { + let workspace = workspace.as_ref(); + let mut config_dir = workspace.to_path_buf(); + config_dir.push(".pi"); + + if !config_dir.exists() { + fs::create_dir_all(&config_dir)?; + } + + let mut config_path = config_dir.clone(); + config_path.push("config.toml"); + + if !config_path.exists() { + let default_toml = r#"[general] +auto_compact = true + +[providers.openai] +api_base = "https://api.openai.com/v1" +api_key_env = "OPENAI_API_KEY" + +[providers.cerebras] +api_base = "http://127.0.0.1:4000/v1" +api_key_env = "OPENAI_API_KEY" + +[models."gpt-4o"] +provider = "openai" +context_window = 128000 +compact_threshold = 100000 + +[models."cerebras/qwen-3-235b-a22b-instruct-2507"] +provider = "cerebras" +context_window = 8192 +compact_threshold = 6000 +"#; + fs::write(&config_path, default_toml.trim())?; + } + + let content = fs::read_to_string(&config_path) + .with_context(|| format!("Failed to read config file at {:?}", config_path))?; + + let config: AppConfig = + toml::from_str(&content).with_context(|| "Failed to parse config.toml")?; + + Ok(config) + } + + pub fn config_dir>(workspace: P) -> std::path::PathBuf { + let mut path = workspace.as_ref().to_path_buf(); + path.push(".pi"); + path + } + + pub fn logs_dir>(workspace: P) -> std::path::PathBuf { + let mut path = Self::config_dir(workspace); + path.push("logs"); + path + } +} diff --git a/examples/pi/src/utils/mod.rs b/examples/pi/src/utils/mod.rs new file 
mode 100644 index 0000000..a6f209e --- /dev/null +++ b/examples/pi/src/utils/mod.rs @@ -0,0 +1,8 @@ +pub mod config; +pub mod pi_llm; +pub mod session_manager; +pub mod tools; + +pub use config::AppConfig; +pub use pi_llm::PiLLM; +pub use session_manager::SessionManager; diff --git a/examples/pi/src/utils/pi_llm.rs b/examples/pi/src/utils/pi_llm.rs new file mode 100644 index 0000000..a5b4dbd --- /dev/null +++ b/examples/pi/src/utils/pi_llm.rs @@ -0,0 +1,58 @@ +use anyhow::Result; +use reqwest::Client; +use serde_json::{Value, json}; +use tracing::info; + +pub struct PiLLM { + client: Client, + api_key: String, + model: String, + endpoint: String, +} + +impl PiLLM { + pub fn new(api_key: String, model: String, endpoint: String) -> Self { + Self { + client: Client::new(), + api_key, + model, + endpoint, + } + } + + pub async fn chat_completion(&self, messages: Vec, tools: Value) -> Result { + info!("Sending LLM request to {}", self.endpoint); + + let mut body = json!({ + "model": self.model, + "messages": messages, + }); + + if !tools.is_null() && tools.as_array().map(|a| !a.is_empty()).unwrap_or(false) { + body.as_object_mut() + .unwrap() + .insert("tools".to_string(), tools); + } + + let res = self + .client + .post(&self.endpoint) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + let status = res.status(); + let response_json: Value = res.json().await?; + if !status.is_success() { + return Err(anyhow::anyhow!( + "API request failed with status {}: {}", + status, + response_json + )); + } + + Ok(response_json) + } +} diff --git a/examples/pi/src/utils/session_manager.rs b/examples/pi/src/utils/session_manager.rs new file mode 100644 index 0000000..91f750d --- /dev/null +++ b/examples/pi/src/utils/session_manager.rs @@ -0,0 +1,73 @@ +use serde::{Deserialize, Serialize}; +use std::fs::{File, OpenOptions}; +use std::io::{BufRead, BufReader, Write}; +use std::path::{Path, 
PathBuf}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentMessage { + pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_id: Option, + pub role: String, + pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub clears_history: Option, +} + +pub struct SessionManager { + log_path: PathBuf, +} + +impl SessionManager { + pub fn new(workspace: &Path) -> Self { + let mut log_path = workspace.to_path_buf(); + log_path.push(".pi"); + log_path.push("logs"); + if !log_path.exists() { + std::fs::create_dir_all(&log_path).unwrap_or_default(); + } + log_path.push("log.jsonl"); + Self { log_path } + } + + pub fn append_message(&self, message: &AgentMessage) -> anyhow::Result<()> { + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.log_path)?; + let json = serde_json::to_string(message)?; + writeln!(file, "{}", json)?; + Ok(()) + } + + pub fn load_history(&self, _head_id: Option<&str>) -> anyhow::Result> { + if !self.log_path.exists() { + return Ok(Vec::new()); + } + + let file = File::open(&self.log_path)?; + let reader = BufReader::new(file); + + let mut messages = Vec::new(); + for line in reader.lines() { + if let Ok(l) = line { + if let Ok(msg) = serde_json::from_str::(&l) { + if msg.clears_history == Some(true) { + messages.clear(); + } + messages.push(msg); + } + } + } + + // Simple linear history for now. + // In a full implementation, we would rebuild the tree using parent_id up to head_id. 
+ Ok(messages) + } +} diff --git a/examples/pi/src/utils/tools/mod.rs b/examples/pi/src/utils/tools/mod.rs new file mode 100644 index 0000000..2ab1b9b --- /dev/null +++ b/examples/pi/src/utils/tools/mod.rs @@ -0,0 +1,47 @@ +// use anyhow::Result; +use std::fs; +use std::path::Path; +use std::process::Command; + +pub fn read_file(path: &str) -> String { + match fs::read_to_string(path) { + Ok(content) => content, + Err(e) => format!("Error reading file: {}", e), + } +} + +pub fn write_file(path: &str, content: &str) -> String { + let p = Path::new(path); + if let Some(parent) = p.parent() { + let _ = fs::create_dir_all(parent); + } + match fs::write(p, content) { + Ok(_) => format!("Successfully wrote to {}", path), + Err(e) => format!("Error writing file: {}", e), + } +} + +pub fn execute_bash(command: &str, cwd: &str) -> String { + let output = Command::new("bash") + .arg("-c") + .arg(command) + .current_dir(cwd) + .output(); + + match output { + Ok(out) => { + let mut result = String::from_utf8_lossy(&out.stdout).to_string(); + let stderr = String::from_utf8_lossy(&out.stderr).to_string(); + if !stderr.is_empty() { + result.push_str("\n--- STDERR ---\n"); + result.push_str(&stderr); + } + if result.is_empty() { + "Command executed successfully with no output.".to_string() + } else { + result + } + } + Err(e) => format!("Error executing command: {}", e), + } +} diff --git a/examples/pocketflow-rs-rag/Cargo.toml b/examples/pocketflow-rs-rag/Cargo.toml index ad3117d..eab0970 100644 --- a/examples/pocketflow-rs-rag/Cargo.toml +++ b/examples/pocketflow-rs-rag/Cargo.toml @@ -18,6 +18,7 @@ reqwest = { version = "0.12.15", features = ["json"] } uuid = { version = "1.16.0", features = ["v4"] } qdrant-client = "1.14.0" termimad = "0.31.3" +strum ={ workspace = true } [dev-dependencies] tempfile = "3.8" diff --git a/examples/pocketflow-rs-rag/src/state.rs b/examples/pocketflow-rs-rag/src/state.rs index a266e5f..71f27f4 100644 --- a/examples/pocketflow-rs-rag/src/state.rs 
+++ b/examples/pocketflow-rs-rag/src/state.rs @@ -1,6 +1,8 @@ use pocketflow_rs::ProcessState; +use strum::Display; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Display)] +#[strum(serialize_all = "snake_case")] pub enum RagState { // Offline states FileLoadedError, @@ -29,32 +31,6 @@ impl ProcessState for RagState { fn is_default(&self) -> bool { matches!(self, RagState::Default) } - - fn to_condition(&self) -> String { - match self { - // Offline states - RagState::FileLoadedError => "file_loaded_error".to_string(), - RagState::DocumentsLoaded => "documents_loaded".to_string(), - RagState::DocumentsChunked => "documents_chunked".to_string(), - RagState::ChunksEmbedded => "chunks_embedded".to_string(), - RagState::IndexCreated => "index_created".to_string(), - // Offline error states - RagState::DocumentLoadError => "document_load_error".to_string(), - RagState::ChunkingError => "chunking_error".to_string(), - RagState::EmbeddingError => "embedding_error".to_string(), - RagState::IndexCreationError => "index_creation_error".to_string(), - // Online states - RagState::QueryEmbedded => "query_embedded".to_string(), - RagState::DocumentsRetrieved => "documents_retrieved".to_string(), - RagState::AnswerGenerated => "answer_generated".to_string(), - // Online error states - RagState::QueryEmbeddingError => "query_embedding_error".to_string(), - RagState::RetrievalError => "retrieval_error".to_string(), - RagState::GenerationError => "generation_error".to_string(), - RagState::Default => "default".to_string(), - RagState::QueryRewriteError => "query_rewrite_error".to_string(), - } - } } impl Default for RagState { diff --git a/examples/text2sql/src/flow.rs b/examples/text2sql/src/flow.rs index 9eb651f..b8b40ae 100644 --- a/examples/text2sql/src/flow.rs +++ b/examples/text2sql/src/flow.rs @@ -1,3 +1,5 @@ +use std::default; + use anyhow::{Context as AnyhowContext, Result}; use async_trait::async_trait; use chrono::NaiveDate; @@ -7,13 +9,16 
@@ use openai_api_rust::chat::*; use openai_api_rust::*; use pocketflow_rs::{Context, Node, ProcessResult, ProcessState}; use serde_json::{Value, json}; +use strum::Display; use tracing::{error, info}; -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] pub enum SqlExecutorState { SchemaRetrieved, SqlGenerated, SqlExecuted, + #[default] Default, } @@ -21,21 +26,6 @@ impl ProcessState for SqlExecutorState { fn is_default(&self) -> bool { matches!(self, SqlExecutorState::Default) } - - fn to_condition(&self) -> String { - match self { - SqlExecutorState::SchemaRetrieved => "schema_retrieved".to_string(), - SqlExecutorState::SqlGenerated => "sql_generated".to_string(), - SqlExecutorState::SqlExecuted => "sql_executed".to_string(), - SqlExecutorState::Default => "default".to_string(), - } - } -} - -impl Default for SqlExecutorState { - fn default() -> Self { - SqlExecutorState::Default - } } #[derive(Debug, thiserror::Error)] diff --git a/src/flow.rs b/src/flow.rs index 63f0aa2..e4fe7fc 100644 --- a/src/flow.rs +++ b/src/flow.rs @@ -92,6 +92,31 @@ impl Flow { Ok(context.get("result").unwrap_or(&Value::Null).clone()) } + + pub fn to_mermaid(&self) -> String { + let mut mermaid = String::from("flowchart TD\n"); + + // Declare all nodes + let mut nodes: Vec<_> = self.nodes.keys().collect(); + nodes.sort(); // ensure stable output + for node_name in nodes { + mermaid.push_str(&format!(" {}[{}]\n", node_name, node_name)); + } + + // Declare all edges + let mut edge_keys: Vec<_> = self.edges.keys().collect(); + edge_keys.sort(); // ensure stable output + for from in edge_keys { + let edges = &self.edges[from]; + let mut sorted_edges = edges.clone(); + sorted_edges.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1))); + for (to, condition) in sorted_edges { + mermaid.push_str(&format!(" {} -->|{}| {}\n", from, condition, to)); + } + } + + mermaid + } } #[allow(dead_code)] @@ -125,6 +150,10 @@ impl BatchFlow { info!("Batch flow execution completed"); Ok(()) } + 
+ pub fn to_mermaid(&self) -> String { + self.flow.to_mermaid() + } } #[macro_export] @@ -159,13 +188,13 @@ macro_rules! build_flow { )* // Handle edges appropriately $( - build_flow!(@edge g, $edge); + build_flow!(@edge_process g, $edge); )* g }}; - (@edge $g:expr, ($from:expr, $to:expr, $condition:expr)) => { + (@edge_process $g:expr, ($from:expr, $to:expr, $condition:expr)) => { $g.add_edge($from, $to, $condition); }; } @@ -204,7 +233,7 @@ macro_rules! build_batch_flow { )* // Handle edges appropriately $( - build_flow!(@edge g.flow, $edge); + build_flow!(@edge_process g.flow, $edge); )* g }}; @@ -213,13 +242,15 @@ macro_rules! build_batch_flow { #[cfg(test)] mod tests { use super::*; + use crate::node::SubFlowNode; use crate::node::{Node, ProcessResult, ProcessState}; use async_trait::async_trait; use serde_json::json; + use strum::Display; - #[derive(Debug, Clone, PartialEq)] + #[derive(Debug, Clone, PartialEq, Default, Display)] + #[strum(serialize_all = "snake_case")] #[allow(dead_code)] - #[derive(Default)] enum CustomState { Success, Failure, @@ -231,14 +262,6 @@ mod tests { fn is_default(&self) -> bool { matches!(self, CustomState::Default) } - - fn to_condition(&self) -> String { - match self { - CustomState::Success => "success".to_string(), - CustomState::Failure => "failure".to_string(), - CustomState::Default => "default".to_string(), - } - } } struct TestNode { @@ -365,4 +388,75 @@ mod tests { let result = flow3.run(context).await.unwrap(); assert_eq!(result, json!({"data": "test2"})); } + + #[tokio::test] + async fn test_subflow_node() { + // 1. Create subflow + let sub_node = TestNode::new(json!({"sub_result": "from_subflow"}), CustomState::Success); + let sub_flow = build_flow!( + start: ("sub_start", sub_node) + ); + + // 2. 
Create SubFlowNode as a node in parent flow + let subflow_node = SubFlowNode::::new( + sub_flow, + |ctx: &Context| { + // Example: Inherit everything + ctx.clone() + }, + |parent_ctx: &mut Context, result: &Result| { + // Example: Map result to parent context + match result { + Ok(val) => { + parent_ctx.set("result", val.clone()); + Ok(ProcessResult::new( + CustomState::Success, + "subflow ok".to_string(), + )) + } + Err(e) => { + parent_ctx.set("error", json!(e.to_string())); + Ok(ProcessResult::new(CustomState::Failure, e.to_string())) + } + } + }, + ); + + // 3. Create parent flow + let parent_flow = build_flow!( + start: ("run_subflow", subflow_node) + ); + + let context = Context::new(); + let result: serde_json::Value = parent_flow.run(context).await.unwrap(); + + assert_eq!(result, json!({"sub_result": "from_subflow"})); + } + + #[test] + fn test_flow_to_mermaid() { + let node1 = TestNode::new(json!({"data": "test1"}), CustomState::Success); + let node2 = TestNode::new(json!({"data": "test2"}), CustomState::Default); + let end_node = TestNode::new(json!({"final_result": "finished"}), CustomState::Default); + + let flow = build_flow!( + start: ("start", node1), + nodes: [("next", node2), ("end", end_node)], + edges: [ + ("start", "next", CustomState::Success), + ("next", "end", CustomState::Default) + ] + ); + + let mermaid = flow.to_mermaid(); + let expected = "\ +flowchart TD + end[end] + next[next] + start[start] + next -->|default| end + start -->|success| next +"; + assert_eq!(mermaid, expected); + } } diff --git a/src/node.rs b/src/node.rs index 72974ea..b077568 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,15 +1,20 @@ +use crate::flow::Flow; use crate::{Params, context::Context}; use anyhow::Result; use async_trait::async_trait; -use std::collections::HashMap; use std::sync::Arc; +use strum::Display; -pub trait ProcessState: Send + Sync { +pub trait ProcessState: Send + Sync + std::fmt::Display { fn is_default(&self) -> bool; - fn 
to_condition(&self) -> String; + + fn to_condition(&self) -> String { + self.to_string() + } } -#[derive(Debug, Clone, PartialEq, Default)] +#[derive(Debug, Clone, PartialEq, Default, Display)] +#[strum(serialize_all = "snake_case")] pub enum BaseState { Success, Failure, @@ -21,14 +26,6 @@ impl ProcessState for BaseState { fn is_default(&self) -> bool { matches!(self, BaseState::Default) } - - fn to_condition(&self) -> String { - match self { - BaseState::Success => "success".to_string(), - BaseState::Failure => "failure".to_string(), - BaseState::Default => "default".to_string(), - } - } } #[derive(Debug, Clone, PartialEq)] @@ -87,20 +84,20 @@ pub trait BaseNodeTrait: Node {} #[allow(dead_code)] pub struct BaseNode { params: Params, - next_nodes: HashMap>, + // next_nodes: HashMap>, } impl BaseNode { pub fn new(params: Params) -> Self { Self { params, - next_nodes: HashMap::new(), + // next_nodes: HashMap::new(), } } - pub fn add_next(&mut self, action: String, node: Arc) { - self.next_nodes.insert(action, node); - } + // pub fn add_next(&mut self, action: String, node: Arc) { + // self.next_nodes.insert(action, node); + // } } #[async_trait] @@ -141,3 +138,64 @@ impl Node for BatchNode { } impl BaseNodeTrait for BatchNode {} + +pub struct SubFlowNode +where + SubState: ProcessState + Default, + ParentState: ProcessState + Default, +{ + pub sub_flow: Arc>, + pub context_builder: Box Context + Send + Sync>, + #[allow(clippy::type_complexity)] + pub result_mapper: Box< + dyn Fn(&mut Context, &Result) -> Result> + + Send + + Sync, + >, +} + +impl SubFlowNode +where + SubState: ProcessState + Default, + ParentState: ProcessState + Default, +{ + pub fn new( + sub_flow: Flow, + context_builder: impl Fn(&Context) -> Context + Send + Sync + 'static, + result_mapper: impl Fn( + &mut Context, + &Result, + ) -> Result> + + Send + + Sync + + 'static, + ) -> Self { + Self { + sub_flow: Arc::new(sub_flow), + context_builder: Box::new(context_builder), + result_mapper: 
Box::new(result_mapper), + } + } +} + +#[async_trait] +impl Node for SubFlowNode +where + SubState: ProcessState + Default, + ParentState: ProcessState + Default, +{ + type State = ParentState; + + async fn execute(&self, context: &Context) -> Result { + let sub_context = (self.context_builder)(context); + self.sub_flow.run(sub_context).await + } + + async fn post_process( + &self, + context: &mut Context, + result: &Result, + ) -> Result> { + (self.result_mapper)(context, result) + } +}