From adbadc29750087201362e73bf497ed377d775140 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Fri, 29 Nov 2024 16:08:23 -0700 Subject: [PATCH 01/73] feat: `ResourceRecord` --- crates/learner/src/database/mod.rs | 12 +++--- crates/learner/src/database/record.rs | 58 +++++++++++++++++++++++++++ crates/learner/src/lib.rs | 2 + 3 files changed, 67 insertions(+), 5 deletions(-) create mode 100644 crates/learner/src/database/record.rs diff --git a/crates/learner/src/database/mod.rs b/crates/learner/src/database/mod.rs index 77bd3e8..3ccdb42 100644 --- a/crates/learner/src/database/mod.rs +++ b/crates/learner/src/database/mod.rs @@ -57,7 +57,7 @@ use tokio_rusqlite::Connection; use super::*; mod instruction; -// pub mod models; +pub mod record; #[cfg(test)] mod tests; pub use self::instruction::{ @@ -133,10 +133,12 @@ impl Database { // Initialize schema conn .call(|conn| { - Ok(conn.execute_batch(include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/migrations/init.sql" - )))?) + Ok( + conn.execute_batch(include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/migrations/v1.sql" + )))?, + ) }) .await?; diff --git a/crates/learner/src/database/record.rs b/crates/learner/src/database/record.rs new file mode 100644 index 0000000..89fe63a --- /dev/null +++ b/crates/learner/src/database/record.rs @@ -0,0 +1,58 @@ +use super::*; + +/// A complete view of a resource with all associated data +#[derive(Debug)] +pub struct ResourceRecord { + pub resource: R, + pub state: ResourceState, + pub tags: Vec, + pub storage: Option, + pub retrieval: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ReadStatus { + Unread, + Reading { + progress: f32, + // last_read: DateTime, // Track when reading sessions occur + // total_time: Duration, // Accumulate reading time + }, + Completed { + finished_at: DateTime, + // times_referenced: u32, // Track how often it's been revisited + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResourceState { + 
pub read_status: ReadStatus, + pub starred: bool, + pub rating: Option, + pub last_accessed: Option>, + pub notes: Option, + pub citation_key: Option, + // pub importance: Option, // Different from rating - how crucial is this? + pub tags_updated_at: Option>, // Track tag management +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetrievalData { + pub source: Option, + pub source_identifier: Option, + pub urls: BTreeMap, + pub doi: Option, + pub last_checked: Option>, // When we last verified URLs + pub access_type: Option, // "open", "subscription", "institutional" + pub verified: bool, // Whether we've confirmed this data +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StorageData { + pub files: BTreeMap, + pub original_filenames: BTreeMap, + pub added_at: BTreeMap>, + pub file_sizes: BTreeMap, // Track file sizes + pub checksums: BTreeMap, // For integrity checking + pub last_verified: DateTime, // When we last checked files exist +} diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index 21d81f7..723f1ee 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -148,6 +148,7 @@ #![feature(str_from_utf16_endian)] use std::{ + collections::BTreeMap, fmt::Display, path::{Path, PathBuf}, }; @@ -174,6 +175,7 @@ pub mod resource; use crate::{ database::*, error::*, + prelude::*, resource::{Author, Paper}, retriever::*, }; From b00b385e65cbd5dcb6db4836606d9d790a52364a Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Fri, 29 Nov 2024 16:08:28 -0700 Subject: [PATCH 02/73] Create v1.sql --- crates/learner/migrations/v1.sql | 63 ++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 crates/learner/migrations/v1.sql diff --git a/crates/learner/migrations/v1.sql b/crates/learner/migrations/v1.sql new file mode 100644 index 0000000..f65418d --- /dev/null +++ b/crates/learner/migrations/v1.sql @@ -0,0 +1,63 @@ +PRAGMA foreign_keys = ON; + +-- Version tracking +CREATE TABLE IF 
NOT EXISTS schema_version ( + version INTEGER PRIMARY KEY NOT NULL, + applied_at TEXT NOT NULL DEFAULT (datetime('now')) +) STRICT; + +-- Core resource storage +CREATE TABLE IF NOT EXISTS resources ( + id INTEGER PRIMARY KEY, + type TEXT NOT NULL, -- Resource type identifier (e.g., "paper", "book") + title TEXT, -- Denormalized for common queries + metadata JSON NOT NULL, -- Complete resource data + searchable_text TEXT, -- For full-text search + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +) STRICT; + +-- Resource state tracking +CREATE TABLE IF NOT EXISTS resource_states ( + resource_id INTEGER PRIMARY KEY, + read_status TEXT NOT NULL DEFAULT 'unread', -- 'unread', 'reading', 'completed' + rating INTEGER CHECK (rating BETWEEN 1 AND 5), -- Optional 1-5 rating + starred BOOLEAN NOT NULL DEFAULT 0, + last_accessed TEXT, + updated_at TEXT NOT NULL DEFAULT (datetime('now')), + FOREIGN KEY (resource_id) REFERENCES resources(id) ON DELETE CASCADE +) STRICT; + +-- Tag management +CREATE TABLE IF NOT EXISTS tags ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + created_at TEXT NOT NULL DEFAULT (datetime('now')) +) STRICT; + +CREATE TABLE IF NOT EXISTS resource_tags ( + resource_id INTEGER NOT NULL, + tag_id INTEGER NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + PRIMARY KEY (resource_id, tag_id), + FOREIGN KEY (resource_id) REFERENCES resources(id) ON DELETE CASCADE, + FOREIGN KEY (tag_id) REFERENCES tags(id) ON DELETE CASCADE +) STRICT; + +-- Full-text search +CREATE VIRTUAL TABLE IF NOT EXISTS resources_fts USING fts5( + title, + searchable_text, + content=resources, + content_rowid=id, + tokenize='unicode61 remove_diacritics 1' +); + +-- FTS triggers +CREATE TRIGGER resources_ai AFTER INSERT ON resources BEGIN + INSERT INTO resources_fts(rowid, title, searchable_text) + VALUES (new.id, new.title, new.searchable_text); +END; + +-- Set initial version +INSERT INTO schema_version 
(version) VALUES (1); \ No newline at end of file From beafd293cab00f114413778f7760e6636fd2acff Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Fri, 29 Nov 2024 16:08:42 -0700 Subject: [PATCH 03/73] refactor: `Resource` trait --- crates/learner/src/resource/mod.rs | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index 3f6f5c7..cfda9fe 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -38,7 +38,7 @@ //! # } //! ``` -use serde_json::{Map, Value}; +use serde_json::Value; use super::*; @@ -92,8 +92,16 @@ pub trait Resource: Serialize + for<'de> Deserialize<'de> { /// /// Returns [`LearnerError::InvalidResource`] if the resource cannot be serialized /// to a JSON object. - fn fields(&self) -> Result> { - serde_json::to_value(self)?.as_object().cloned().ok_or_else(|| LearnerError::InvalidResource) + fn fields(&self) -> Result> { + let mut output = BTreeMap::new(); + let map = serde_json::to_value(self)? 
+ .as_object() + .cloned() + .ok_or_else(|| LearnerError::InvalidResource)?; + map.into_iter().for_each(|(k, v)| { + let _v = output.insert(k, v).unwrap(); + }); + Ok(output) } } @@ -110,10 +118,12 @@ pub trait Resource: Serialize + for<'de> Deserialize<'de> { /// # Examples /// /// ```rust +/// use std::collections::BTreeMap; +/// /// use learner::resource::ResourceConfig; -/// use serde_json::{json, Map}; +/// use serde_json::json; /// -/// let mut fields = Map::new(); +/// let mut fields = BTreeMap::new(); /// fields.insert("title".into(), json!("Understanding Type Systems")); /// fields.insert("university".into(), json!("Tech University")); /// @@ -124,13 +134,13 @@ pub struct ResourceConfig { /// The type identifier for this resource configuration pub type_name: String, /// Map of field names to their values - pub fields: Map, + pub fields: BTreeMap, } impl Resource for ResourceConfig { fn resource_type(&self) -> String { self.type_name.clone() } - fn fields(&self) -> Result> { Ok(self.fields.clone()) } + fn fields(&self) -> Result> { Ok(self.fields.clone()) } } #[cfg(test)] @@ -142,7 +152,7 @@ mod tests { #[test] fn test_thesis_resource() -> Result<()> { // Create a thesis resource - let mut fields = Map::new(); + let mut fields = BTreeMap::new(); fields.insert("title".into(), json!("Understanding Quantum Computing Effects")); fields.insert("author".into(), json!(["Alice Researcher", "Bob Scientist"])); fields.insert("university".into(), json!("Tech University")); From 38c1045b05a04d2d03bcc21e90e2aceb5954215c Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Fri, 29 Nov 2024 16:14:40 -0700 Subject: [PATCH 04/73] revert: back to `init.sql` for now --- crates/learner/src/database/mod.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/crates/learner/src/database/mod.rs b/crates/learner/src/database/mod.rs index 3ccdb42..d1aa1d6 100644 --- a/crates/learner/src/database/mod.rs +++ b/crates/learner/src/database/mod.rs @@ -133,12 
+133,10 @@ impl Database { // Initialize schema conn .call(|conn| { - Ok( - conn.execute_batch(include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/migrations/v1.sql" - )))?, - ) + Ok(conn.execute_batch(include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/migrations/init.sql" + )))?) }) .await?; From 4f7b1f1f59210db3fde2e1d8338baea0df36632e Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Fri, 29 Nov 2024 16:22:06 -0700 Subject: [PATCH 05/73] fix: doctest --- crates/learner/src/resource/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index cfda9fe..b8475c6 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -13,6 +13,8 @@ //! # Examples //! //! ```rust,no_run +//! use std::collections::BTreeMap; +//! //! use learner::{ //! resource::{Paper, Resource, ResourceConfig}, //! Learner, @@ -29,7 +31,7 @@ //! println!("Paper type: {}", paper.resource_type()); //! //! // Or create a custom resource type at runtime -//! let mut fields = serde_json::Map::new(); +//! let mut fields = BTreeMap::new(); //! fields.insert("title".into(), json!("My Thesis")); //! fields.insert("university".into(), json!("Tech University")); //! 
From d0bfe1ee292bd805b6beedb84253e4afbe13f9da Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Fri, 29 Nov 2024 17:14:19 -0700 Subject: [PATCH 06/73] refactor: `retriever` module --- crates/learner/config/retrievers/arxiv.toml | 1 + crates/learner/config/retrievers/doi.toml | 1 + crates/learner/config/retrievers/iacr.toml | 1 + crates/learner/src/resource/mod.rs | 9 + crates/learner/src/retriever/config.rs | 105 +++++++ crates/learner/src/retriever/mod.rs | 269 +----------------- .../src/retriever/{ => response}/json.rs | 0 crates/learner/src/retriever/response/mod.rs | 141 +++++++++ .../src/retriever/{ => response}/xml.rs | 0 .../tests/workflows/build_retriever.rs | 4 +- 10 files changed, 267 insertions(+), 264 deletions(-) create mode 100644 crates/learner/src/retriever/config.rs rename crates/learner/src/retriever/{ => response}/json.rs (100%) create mode 100644 crates/learner/src/retriever/response/mod.rs rename crates/learner/src/retriever/{ => response}/xml.rs (100%) diff --git a/crates/learner/config/retrievers/arxiv.toml b/crates/learner/config/retrievers/arxiv.toml index b47dbbb..aa09b6a 100644 --- a/crates/learner/config/retrievers/arxiv.toml +++ b/crates/learner/config/retrievers/arxiv.toml @@ -2,6 +2,7 @@ base_url = "http://export.arxiv.org" endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}&max_results=1" name = "arxiv" pattern = "(?:^|https?://arxiv\\.org/(?:abs|pdf)/)(\\d{4}\\.\\d{4,5}|[a-zA-Z-]+(?:\\.[A-Z]{2})?/\\d{7})(?:\\.pdf)?$" +resource_type = "paper" source = "arxiv" [response_format] diff --git a/crates/learner/config/retrievers/doi.toml b/crates/learner/config/retrievers/doi.toml index f363850..130af72 100644 --- a/crates/learner/config/retrievers/doi.toml +++ b/crates/learner/config/retrievers/doi.toml @@ -2,6 +2,7 @@ base_url = "https://api.crossref.org/works" endpoint_template = "https://api.crossref.org/works/{identifier}" name = "doi" pattern = "(?:^|https?://doi\\.org/)(10\\.\\d{4,9}/[-._;()/:\\w]+)$" 
+resource_type = "paper" source = "doi" [response_format] diff --git a/crates/learner/config/retrievers/iacr.toml b/crates/learner/config/retrievers/iacr.toml index bbc5eb1..2ad208e 100644 --- a/crates/learner/config/retrievers/iacr.toml +++ b/crates/learner/config/retrievers/iacr.toml @@ -2,6 +2,7 @@ base_url = "https://eprint.iacr.org" endpoint_template = "https://eprint.iacr.org/oai?verb=GetRecord&identifier=oai:eprint.iacr.org:{identifier}&metadataPrefix=oai_dc" name = "iacr" pattern = "(?:^|https?://eprint\\.iacr\\.org/)(\\d{4}/\\d+)(?:\\.pdf)?$" +resource_type = "paper" source = "iacr" [response_format] diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index b8475c6..c163a8a 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -50,6 +50,15 @@ mod shared; pub use paper::*; pub use shared::*; +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ResourceType { + Paper, + Book, + // We can add more built-in types as needed + Custom(PathBuf), // For user-defined resource types via config +} + /// Core trait that defines the behavior of a resource in the system. /// /// This trait provides a common interface for all resource types, whether they are diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs new file mode 100644 index 0000000..6a95ff8 --- /dev/null +++ b/crates/learner/src/retriever/config.rs @@ -0,0 +1,105 @@ +use response::ResponseFormat; + +use super::*; + +/// Configuration for a specific paper source retriever. +/// +/// This struct defines how to interact with a particular paper source's API, +/// including URL patterns, authentication, and response parsing rules. 
+/// +/// # Examples +/// +/// Example TOML configuration: +/// +/// ```toml +/// name = "arxiv" +/// base_url = "http://export.arxiv.org/api/query" +/// pattern = "^\\d{4}\\.\\d{4,5}$" +/// source = "arxiv" +/// endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}" +/// +/// [response_format] +/// type = "xml" +/// strip_namespaces = true +/// +/// [response_format.field_maps] +/// title = { path = "entry/title" } +/// abstract = { path = "entry/summary" } +/// publication_date = { path = "entry/published" } +/// authors = { path = "entry/author/name" } +/// ``` +#[derive(Debug, Clone, Deserialize)] +pub struct RetrieverConfig { + /// Name of this retriever configuration + pub name: String, + /// The type of resource this retriever should yield + pub resource_type: ResourceType, + /// Base URL for API requests + pub base_url: String, + /// Regex pattern for matching and extracting paper identifiers + #[serde(deserialize_with = "deserialize_regex")] + pub pattern: Regex, + /// Source identifier for papers from this retriever + pub source: String, + /// Template for constructing API endpoint URLs + pub endpoint_template: String, + /// Format and parsing configuration for API responses + pub response_format: ResponseFormat, + /// Optional HTTP headers for API requests + #[serde(default)] + pub headers: HashMap, +} + +impl RetrieverConfig { + /// Extracts the canonical identifier from an input string. + /// + /// Uses the configured regex pattern to extract the standardized + /// identifier from various input formats (URLs, DOIs, etc.). 
+ /// + /// # Arguments + /// + /// * `input` - Input string containing a paper identifier + /// + /// # Returns + /// + /// Returns a Result containing either: + /// - The extracted identifier as a string slice + /// - A LearnerError if the input doesn't match the pattern + pub fn extract_identifier<'a>(&self, input: &'a str) -> Result<&'a str> { + self + .pattern + .captures(input) + .and_then(|cap| cap.get(1)) + .map(|m| m.as_str()) + .ok_or(LearnerError::InvalidIdentifier) + } + + pub async fn retrieve_paper(&self, input: &str) -> Result { + let identifier = self.extract_identifier(input)?; + let url = self.endpoint_template.replace("{identifier}", identifier); + + debug!("Fetching from {} via: {}", self.name, url); + + let client = reqwest::Client::new(); + let mut request = client.get(&url); + + // Add any configured headers + for (key, value) in &self.headers { + request = request.header(key, value); + } + + let response = request.send().await?; + let data = response.bytes().await?; + + trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); + + let response_processor = match &self.response_format { + ResponseFormat::Xml(config) => config as &dyn ResponseProcessor, + ResponseFormat::Json(config) => config as &dyn ResponseProcessor, + }; + let mut paper = response_processor.process_response(&data).await?; + paper.source = self.source.clone(); + paper.source_identifier = identifier.to_string(); + Ok(paper) + } +} diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 11c0726..f3cf5e7 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -67,9 +67,13 @@ use std::collections::HashMap; use super::*; +use crate::resource::ResourceType; -pub mod json; -pub mod xml; +mod config; +mod response; + +pub use config::*; +pub use response::*; /// Main entry point for paper retrieval operations. 
/// @@ -122,189 +126,6 @@ impl Retriever { pub fn is_empty(&self) -> bool { self.configs.is_empty() } } -/// Configuration for a specific paper source retriever. -/// -/// This struct defines how to interact with a particular paper source's API, -/// including URL patterns, authentication, and response parsing rules. -/// -/// # Examples -/// -/// Example TOML configuration: -/// -/// ```toml -/// name = "arxiv" -/// base_url = "http://export.arxiv.org/api/query" -/// pattern = "^\\d{4}\\.\\d{4,5}$" -/// source = "arxiv" -/// endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}" -/// -/// [response_format] -/// type = "xml" -/// strip_namespaces = true -/// -/// [response_format.field_maps] -/// title = { path = "entry/title" } -/// abstract = { path = "entry/summary" } -/// publication_date = { path = "entry/published" } -/// authors = { path = "entry/author/name" } -/// ``` -#[derive(Debug, Clone, Deserialize)] -pub struct RetrieverConfig { - /// Name of this retriever configuration - pub name: String, - /// Base URL for API requests - pub base_url: String, - /// Regex pattern for matching and extracting paper identifiers - #[serde(deserialize_with = "deserialize_regex")] - pub pattern: Regex, - /// Source identifier for papers from this retriever - pub source: String, - /// Template for constructing API endpoint URLs - pub endpoint_template: String, - /// Format and parsing configuration for API responses - pub response_format: ResponseFormat, - /// Optional HTTP headers for API requests - #[serde(default)] - pub headers: HashMap, -} - -/// Available response format handlers. -/// -/// Specifies how to parse and extract paper metadata from API responses -/// in different formats. 
-/// -/// # Examples -/// -/// XML configuration: -/// ```toml -/// [response_format] -/// type = "xml" -/// strip_namespaces = true -/// -/// [response_format.field_maps] -/// title = { path = "entry/title" } -/// ``` -/// -/// JSON configuration: -/// ```toml -/// [response_format] -/// type = "json" -/// -/// [response_format.field_maps] -/// title = { path = "message/title/0" } -/// ``` -#[derive(Debug, Clone, Deserialize)] -#[serde(tag = "type")] -pub enum ResponseFormat { - /// XML response parser configuration - #[serde(rename = "xml")] - Xml(xml::XmlConfig), - /// JSON response parser configuration - #[serde(rename = "json")] - Json(json::JsonConfig), -} - -/// Field mapping configuration. -/// -/// Defines how to extract and transform specific fields from API responses. -/// -/// # Examples -/// -/// ```toml -/// [field_maps.title] -/// path = "entry/title" -/// transform = { type = "replace", pattern = "\\s+", replacement = " " } -/// ``` -#[derive(Debug, Clone, Deserialize)] -pub struct FieldMap { - /// Path to field in response (e.g., JSON path or XPath) - pub path: String, - /// Optional transformation to apply to extracted value - #[serde(default)] - pub transform: Option, -} - -/// Available field value transformations. -/// -/// Transformations that can be applied to extracted field values -/// before constructing the final Paper object. 
-/// -/// # Examples -/// -/// ```toml -/// # Clean up whitespace -/// transform = { type = "replace", pattern = "\\s+", replacement = " " } -/// -/// # Convert date format -/// transform = { type = "date", from_format = "%Y-%m-%d", to_format = "%Y-%m-%dT00:00:00Z" } -/// -/// # Construct full URL -/// transform = { type = "url", base = "https://example.com/", suffix = ".pdf" } -/// ``` -#[derive(Debug, Clone, Deserialize)] -#[serde(tag = "type")] -pub enum Transform { - /// Replace text using regex pattern - Replace { - /// Regular expression pattern to match - pattern: String, - /// Text to replace matched patterns with - replacement: String, - }, - /// Convert between date formats - Date { - /// Source date format string using chrono syntax (e.g., "%Y-%m-%d") - from_format: String, - /// Target date format string using chrono syntax (e.g., "%Y-%m-%dT%H:%M:%SZ") - to_format: String, - }, - /// Construct URL from parts - Url { - /// Base URL template, may contain {value} placeholder - base: String, - /// Optional suffix to append to the URL (e.g., ".pdf") - suffix: Option, - }, -} - -/// Trait for processing API responses into Paper objects. -/// -/// Implementors of this trait handle the conversion of raw API response data -/// into structured Paper metadata. The trait is implemented separately for -/// different response formats (XML, JSON) to provide a unified interface for -/// paper retrieval. -/// -/// # Examples -/// -/// ```no_run -/// # use learner::{retriever::ResponseProcessor, resource::Paper}; -/// # use learner::error::LearnerError; -/// struct CustomProcessor; -/// -/// #[async_trait::async_trait] -/// impl ResponseProcessor for CustomProcessor { -/// async fn process_response(&self, data: &[u8]) -> Result { -/// // Parse response data and construct Paper -/// todo!() -/// } -/// } -/// ``` -#[async_trait] -pub trait ResponseProcessor: Send + Sync { - /// Process raw response data into a Paper object. 
- /// - /// # Arguments - /// - /// * `data` - Raw bytes from the API response - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - A fully populated Paper object - /// - A LearnerError if parsing fails - async fn process_response(&self, data: &[u8]) -> Result; -} - impl Retriever { /// Creates a new empty retriever with no configurations. /// @@ -577,84 +398,6 @@ impl Retriever { } } -impl RetrieverConfig { - /// Extracts the canonical identifier from an input string. - /// - /// Uses the configured regex pattern to extract the standardized - /// identifier from various input formats (URLs, DOIs, etc.). - /// - /// # Arguments - /// - /// * `input` - Input string containing a paper identifier - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - The extracted identifier as a string slice - /// - A LearnerError if the input doesn't match the pattern - pub fn extract_identifier<'a>(&self, input: &'a str) -> Result<&'a str> { - self - .pattern - .captures(input) - .and_then(|cap| cap.get(1)) - .map(|m| m.as_str()) - .ok_or(LearnerError::InvalidIdentifier) - } - - /// Retrieves a paper using this configuration. - /// - /// This method: - /// 1. Extracts the canonical identifier - /// 2. Constructs the API URL - /// 3. Makes the HTTP request - /// 4. 
Processes the response - /// - /// # Arguments - /// - /// * `input` - Paper identifier or URL - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - The retrieved Paper object - /// - A LearnerError if any step fails - /// - /// # Errors - /// - /// This method will return an error if: - /// - The identifier cannot be extracted - /// - The HTTP request fails - /// - The response cannot be parsed - pub async fn retrieve_paper(&self, input: &str) -> Result { - let identifier = self.extract_identifier(input)?; - let url = self.endpoint_template.replace("{identifier}", identifier); - - debug!("Fetching from {} via: {}", self.name, url); - - let client = reqwest::Client::new(); - let mut request = client.get(&url); - - // Add any configured headers - for (key, value) in &self.headers { - request = request.header(key, value); - } - - let response = request.send().await?; - let data = response.bytes().await?; - - trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); - - let response_processor = match &self.response_format { - ResponseFormat::Xml(config) => config as &dyn ResponseProcessor, - ResponseFormat::Json(config) => config as &dyn ResponseProcessor, - }; - let mut paper = response_processor.process_response(&data).await?; - paper.source = self.source.clone(); - paper.source_identifier = identifier.to_string(); - Ok(paper) - } -} - /// Custom deserializer for converting string patterns into Regex objects. 
/// /// Used with serde's derive functionality to automatically deserialize diff --git a/crates/learner/src/retriever/json.rs b/crates/learner/src/retriever/response/json.rs similarity index 100% rename from crates/learner/src/retriever/json.rs rename to crates/learner/src/retriever/response/json.rs diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs new file mode 100644 index 0000000..73f938d --- /dev/null +++ b/crates/learner/src/retriever/response/mod.rs @@ -0,0 +1,141 @@ +use super::*; + +pub mod json; +pub mod xml; + +/// Available response format handlers. +/// +/// Specifies how to parse and extract paper metadata from API responses +/// in different formats. +/// +/// # Examples +/// +/// XML configuration: +/// ```toml +/// [response_format] +/// type = "xml" +/// strip_namespaces = true +/// +/// [response_format.field_maps] +/// title = { path = "entry/title" } +/// ``` +/// +/// JSON configuration: +/// ```toml +/// [response_format] +/// type = "json" +/// +/// [response_format.field_maps] +/// title = { path = "message/title/0" } +/// ``` +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type")] +pub enum ResponseFormat { + /// XML response parser configuration + #[serde(rename = "xml")] + Xml(xml::XmlConfig), + /// JSON response parser configuration + #[serde(rename = "json")] + Json(json::JsonConfig), +} + +/// Field mapping configuration. +/// +/// Defines how to extract and transform specific fields from API responses. +/// +/// # Examples +/// +/// ```toml +/// [field_maps.title] +/// path = "entry/title" +/// transform = { type = "replace", pattern = "\\s+", replacement = " " } +/// ``` +#[derive(Debug, Clone, Deserialize)] +pub struct FieldMap { + /// Path to field in response (e.g., JSON path or XPath) + pub path: String, + /// Optional transformation to apply to extracted value + #[serde(default)] + pub transform: Option, +} + +/// Available field value transformations. 
+/// +/// Transformations that can be applied to extracted field values +/// before constructing the final Paper object. +/// +/// # Examples +/// +/// ```toml +/// # Clean up whitespace +/// transform = { type = "replace", pattern = "\\s+", replacement = " " } +/// +/// # Convert date format +/// transform = { type = "date", from_format = "%Y-%m-%d", to_format = "%Y-%m-%dT00:00:00Z" } +/// +/// # Construct full URL +/// transform = { type = "url", base = "https://example.com/", suffix = ".pdf" } +/// ``` +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type")] +pub enum Transform { + /// Replace text using regex pattern + Replace { + /// Regular expression pattern to match + pattern: String, + /// Text to replace matched patterns with + replacement: String, + }, + /// Convert between date formats + Date { + /// Source date format string using chrono syntax (e.g., "%Y-%m-%d") + from_format: String, + /// Target date format string using chrono syntax (e.g., "%Y-%m-%dT%H:%M:%SZ") + to_format: String, + }, + /// Construct URL from parts + Url { + /// Base URL template, may contain {value} placeholder + base: String, + /// Optional suffix to append to the URL (e.g., ".pdf") + suffix: Option, + }, +} + +/// Trait for processing API responses into Paper objects. +/// +/// Implementors of this trait handle the conversion of raw API response data +/// into structured Paper metadata. The trait is implemented separately for +/// different response formats (XML, JSON) to provide a unified interface for +/// paper retrieval. 
+/// +/// # Examples +/// +/// ```no_run +/// # use learner::{retriever::ResponseProcessor, resource::Paper}; +/// # use learner::error::LearnerError; +/// struct CustomProcessor; +/// +/// #[async_trait::async_trait] +/// impl ResponseProcessor for CustomProcessor { +/// async fn process_response(&self, data: &[u8]) -> Result { +/// // Parse response data and construct Paper +/// todo!() +/// } +/// } +/// ``` +#[async_trait] +pub trait ResponseProcessor: Send + Sync { + /// Process raw response data into a Paper object. + /// + /// # Arguments + /// + /// * `data` - Raw bytes from the API response + /// + /// # Returns + /// + /// Returns a Result containing either: + /// - A fully populated Paper object + /// - A LearnerError if parsing fails + async fn process_response(&self, data: &[u8]) -> Result; +} diff --git a/crates/learner/src/retriever/xml.rs b/crates/learner/src/retriever/response/xml.rs similarity index 100% rename from crates/learner/src/retriever/xml.rs rename to crates/learner/src/retriever/response/xml.rs diff --git a/crates/learner/tests/workflows/build_retriever.rs b/crates/learner/tests/workflows/build_retriever.rs index 29d6995..8477eb4 100644 --- a/crates/learner/tests/workflows/build_retriever.rs +++ b/crates/learner/tests/workflows/build_retriever.rs @@ -1,6 +1,8 @@ use std::fs::read_to_string; -use learner::retriever::{ResponseFormat, RetrieverConfig, Transform}; +use learner::retriever::{ResponseFormat, Transform}; + +use super::*; #[test] fn test_arxiv_config_deserialization() { From f29b9c035fd6bf8b92161be9336639750da5c8d9 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 07:14:18 -0700 Subject: [PATCH 07/73] feat: improved `ResourceConfig` --- crates/learner/config/resources/book.toml | 72 +++ crates/learner/config/resources/paper.toml | 73 +++ crates/learner/config/resources/thesis.toml | 80 +++- crates/learner/src/database/record.rs | 6 +- crates/learner/src/error.rs | 15 +- crates/learner/src/lib.rs | 3 +- 
crates/learner/src/resource/mod.rs | 502 +++++++++++++------- crates/learner/src/resource/paper.rs | 4 - crates/learner/src/retriever/config.rs | 33 +- crates/learner/src/retriever/mod.rs | 1 - 10 files changed, 576 insertions(+), 213 deletions(-) create mode 100644 crates/learner/config/resources/book.toml create mode 100644 crates/learner/config/resources/paper.toml diff --git a/crates/learner/config/resources/book.toml b/crates/learner/config/resources/book.toml new file mode 100644 index 0000000..1607b12 --- /dev/null +++ b/crates/learner/config/resources/book.toml @@ -0,0 +1,72 @@ +type_name = "book" + +description = "A published book, including textbooks, monographs, and edited volumes" + +[[fields]] +description = "The book's full title" +field_type = "string" +name = "title" +required = true +validation = { min_length = 1, max_length = 500 } + +[[fields]] +description = "The book's authors" +field_type = "array" +name = "authors" +required = false + +[[fields]] +description = "The book's editors, if any" +field_type = "array" +name = "editors" +required = false + +[[fields]] +description = "International Standard Book Number" +field_type = "string" +name = "isbn" +required = false +validation = { pattern = '''^(?:978|979)?[- ]?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]$''' } + +[[fields]] +description = "The book's publisher" +field_type = "string" +name = "publisher" +required = false + +[[fields]] +description = "When the book was published" +field_type = "datetime" +name = "publication_date" +required = false + +[[fields]] +description = "The edition number or description" +field_type = "string" +name = "edition" +required = false + +[[fields]] +description = "Total number of pages" +field_type = "integer" +name = "pages" +required = false + +[[fields]] +default = "en" +description = "The book's primary language" +field_type = "string" +name = "language" +required = false + +[[fields]] +description = "Subject categories or classifications" +field_type = 
"array" +name = "subjects" +required = false + +[[fields]] +description = "Book summary or description" +field_type = "string" +name = "summary" +required = false diff --git a/crates/learner/config/resources/paper.toml b/crates/learner/config/resources/paper.toml new file mode 100644 index 0000000..d971af1 --- /dev/null +++ b/crates/learner/config/resources/paper.toml @@ -0,0 +1,73 @@ +type_name = "paper" + +description = "A scholarly paper or article published in an academic context" + +[[fields]] +description = "The full title of the paper" +field_type = "string" +name = "title" +required = true +validation = { min_length = 1, max_length = 500 } + +[[fields]] +description = "The paper's authors with their affiliations" +field_type = "array" +name = "authors" +required = true +validation = { min_items = 1 } + +[[fields]] +description = "The paper's abstract or summary" +field_type = "string" +name = "abstract" +required = false + +[[fields]] +description = "When the paper was published or last updated" +field_type = "datetime" +name = "publication_date" +required = false + +[[fields]] +description = "Digital Object Identifier" +field_type = "string" +name = "doi" +required = false +validation = { pattern = "^10\\.\\d{4,9}/[-._;()/:a-zA-Z0-9]+$" } + +[[fields]] +description = "Keywords or subject areas" +field_type = "array" +name = "keywords" +required = false + +[[fields]] +description = "Journal where the paper was published" +field_type = "string" +name = "journal" +required = false + +[[fields]] +description = "Journal volume number" +field_type = "string" +name = "volume" +required = false + +[[fields]] +description = "Journal issue number" +field_type = "string" +name = "issue" +required = false + +[[fields]] +description = "Page range in the journal" +field_type = "string" +name = "pages" +required = false + +[[fields]] +default = true +description = "Whether the paper underwent peer review" +field_type = "boolean" +name = "peer_reviewed" +required = false 
diff --git a/crates/learner/config/resources/thesis.toml b/crates/learner/config/resources/thesis.toml index 131a58d..fe1a457 100644 --- a/crates/learner/config/resources/thesis.toml +++ b/crates/learner/config/resources/thesis.toml @@ -1,11 +1,73 @@ type_name = "thesis" -[fields] -abstract = { type = "string", required = false } -author = { type = "array", items = "string", required = true } -committee = { type = "array", items = "string", required = false } -defense_date = { type = "string", format = "date-time", required = false } -department = { type = "string", required = false, default = "Computer Science" } -keywords = { type = "array", items = "string", required = false } -title = { type = "string", required = true } -university = { type = "string", required = true } +description = "A master's thesis or doctoral dissertation" + +[[fields]] +description = "The full title of the thesis" +field_type = "string" +name = "title" +required = true +validation = { min_length = 1, max_length = 500 } + +[[fields]] +description = "The thesis author" +field_type = "string" +name = "author" +required = true + +[[fields]] +description = "Thesis abstract or summary" +field_type = "string" +name = "abstract" +required = false + +[[fields]] +description = "The degree type (e.g., PhD, MSc)" +field_type = "string" +name = "degree" +required = true +validation = { enum_values = ["PhD", "DPhil", "MSc", "MA", "MPhil", "MEng"] } + +[[fields]] +description = "The degree-granting institution" +field_type = "string" +name = "institution" +required = true + +[[fields]] +description = "The academic department" +field_type = "string" +name = "department" +required = false + +[[fields]] +description = "When the degree was awarded" +field_type = "datetime" +name = "completion_date" +required = true + +[[fields]] +description = "Thesis advisors or supervisors" +field_type = "array" +name = "advisors" +required = true +validation = { min_items = 1 } + +[[fields]] +description = "Committee 
members beyond advisors" +field_type = "array" +name = "committee" +required = false + +[[fields]] +description = "Keywords or subject areas" +field_type = "array" +name = "keywords" +required = false + +[[fields]] +description = "Digital Object Identifier if available" +field_type = "string" +name = "doi" +required = false +validation = { pattern = "^10\\.\\d{4,9}/[-._;()/:a-zA-Z0-9]+$" } diff --git a/crates/learner/src/database/record.rs b/crates/learner/src/database/record.rs index 89fe63a..a39285d 100644 --- a/crates/learner/src/database/record.rs +++ b/crates/learner/src/database/record.rs @@ -1,9 +1,11 @@ +use resource::ResourceConfig; + use super::*; /// A complete view of a resource with all associated data #[derive(Debug)] -pub struct ResourceRecord { - pub resource: R, +pub struct ResourceRecord { + pub resource: ResourceConfig, pub state: ResourceState, pub tags: Vec, pub storage: Option, diff --git a/crates/learner/src/error.rs b/crates/learner/src/error.rs index 0aa3867..c4a2bb9 100644 --- a/crates/learner/src/error.rs +++ b/crates/learner/src/error.rs @@ -251,17 +251,6 @@ pub enum LearnerError { #[error(transparent)] SerdeJson(#[from] serde_json::Error), - /// Indicates a resource failed to serialize into a valid structure. - /// - /// This error occurs when attempting to serialize a resource type - /// into JSON and the result is not a simple object structure. This - /// typically happens when: - /// - The resource type contains complex nested structures - /// - The resource serializes to a JSON array instead of an object - /// - The resource serializes to a primitive value - /// - /// The error helps ensure that resources maintain a flat, searchable - /// structure that can be properly stored and queried in the database. 
- #[error("A resource must serialize into a flat Rust struct or JSON object.")] - InvalidResource, + #[error("Failed to be a valid resource due to: {0}")] + InvalidResource(String), } diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index 723f1ee..fd0e67d 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -217,8 +217,7 @@ pub const IACR_CONFIG: &str = include_str!("../config/retrievers/iacr.toml"); /// ``` pub mod prelude { pub use crate::{ - database::DatabaseInstruction, error::LearnerError, resource::Resource, - retriever::ResponseProcessor, + database::DatabaseInstruction, error::LearnerError, retriever::ResponseProcessor, }; } diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index c163a8a..a17f037 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -1,46 +1,4 @@ -//! Resource abstraction and configuration for the learner library. -//! -//! This module provides the core abstractions for working with different types of academic -//! and research resources. It defines: -//! -//! - A [`Resource`] trait that all resource types must implement -//! - A flexible [`ResourceConfig`] for runtime-configured resource types -//! - Common utility types and functions for resource management -//! -//! The design allows for both statically defined resource types (like papers and books) -//! and dynamically configured resources that can be defined through configuration files. -//! -//! # Examples -//! -//! ```rust,no_run -//! use std::collections::BTreeMap; -//! -//! use learner::{ -//! resource::{Paper, Resource, ResourceConfig}, -//! Learner, -//! }; -//! use serde_json::json; -//! -//! # async fn example() -> Result<(), Box> { -//! // Using a built-in resource type -//! let learner = Learner::builder().build().await?; -//! let paper = learner.retriever.get_paper("2301.07041").await?; -//! -//! // Access resource fields -//! let fields = paper.fields()?; -//! 
println!("Paper type: {}", paper.resource_type()); -//! -//! // Or create a custom resource type at runtime -//! let mut fields = BTreeMap::new(); -//! fields.insert("title".into(), json!("My Thesis")); -//! fields.insert("university".into(), json!("Tech University")); -//! -//! let thesis = ResourceConfig { type_name: "thesis".to_string(), fields }; -//! # Ok(()) -//! # } -//! ``` - -use serde_json::Value; +use std::{collections::HashSet, str::FromStr}; use super::*; @@ -49,167 +7,349 @@ mod shared; pub use paper::*; pub use shared::*; +use toml::Value; #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum ResourceType { - Paper, - Book, - // We can add more built-in types as needed - Custom(PathBuf), // For user-defined resource types via config +pub struct ResourceConfig { + /// The type identifier for this resource + pub type_name: String, + /// Optional description of this resource type + #[serde(default)] + pub description: Option, + /// Field definitions with optional metadata + #[serde(default)] + pub fields: Vec, } -/// Core trait that defines the behavior of a resource in the system. -/// -/// This trait provides a common interface for all resource types, whether they are -/// statically defined (like [`Paper`]) or dynamically configured through [`ResourceConfig`]. -/// It requires that implementing types can be serialized and deserialized, which enables -/// persistent storage and retrieval. 
-/// -/// The trait provides two key capabilities: -/// - Identification of the resource type -/// - Access to the resource's fields in a uniform way -/// -/// # Examples -/// -/// ```rust -/// # use serde::{Serialize, Deserialize}; -/// # use learner::resource::Resource; -/// #[derive(Serialize, Deserialize)] -/// struct Book { -/// title: String, -/// author: String, -/// isbn: String, -/// } -/// -/// impl Resource for Book { -/// fn resource_type(&self) -> String { "book".to_string() } -/// } -/// ``` -pub trait Resource: Serialize + for<'de> Deserialize<'de> { - /// Returns the type identifier for this resource. - /// - /// This identifier is used to distinguish between different types of resources - /// in the system. For example, "paper", "book", or "thesis". - fn resource_type(&self) -> String; - - /// Returns a map of field names to their values for this resource. - /// - /// This method provides a uniform way to access a resource's fields regardless - /// of the concrete type. The default implementation uses serde to serialize - /// the resource to JSON and extract its fields. - /// - /// # Errors - /// - /// Returns [`LearnerError::InvalidResource`] if the resource cannot be serialized - /// to a JSON object. - fn fields(&self) -> Result> { - let mut output = BTreeMap::new(); - let map = serde_json::to_value(self)? 
- .as_object() - .cloned() - .ok_or_else(|| LearnerError::InvalidResource)?; - map.into_iter().for_each(|(k, v)| { - let _v = output.insert(k, v).unwrap(); - }); - Ok(output) - } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FieldDefinition { + /// Name of the field + pub name: String, + /// Type of the field (should be a TOML Value) + pub field_type: String, + /// Whether this field must be present + #[serde(default)] + pub required: bool, + /// Human-readable description + #[serde(default)] + pub description: Option, + /// Default value if field is absent + #[serde(default)] + pub default: Option, + /// Optional validation rules + #[serde(default)] + pub validation: Option, } -/// A dynamically configured resource type. -/// -/// This struct enables the creation of new resource types at runtime through -/// configuration files. It provides a flexible way to extend the system without -/// requiring code changes. -/// -/// The type consists of: -/// - A type identifier string -/// - A map of field names to their values -/// -/// # Examples -/// -/// ```rust -/// use std::collections::BTreeMap; -/// -/// use learner::resource::ResourceConfig; -/// use serde_json::json; -/// -/// let mut fields = BTreeMap::new(); -/// fields.insert("title".into(), json!("Understanding Type Systems")); -/// fields.insert("university".into(), json!("Tech University")); -/// -/// let thesis = ResourceConfig { type_name: "thesis".to_string(), fields }; -/// ``` #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ResourceConfig { - /// The type identifier for this resource configuration - pub type_name: String, - /// Map of field names to their values - pub fields: BTreeMap, +pub struct ValidationRules { + // String validations + pub pattern: Option, // Regex pattern to match + pub min_length: Option, // Minimum string length + pub max_length: Option, // Maximum string length + + // Numeric validations + pub minimum: Option, // Minimum value + pub maximum: Option, 
// Maximum value + pub multiple_of: Option, // Must be multiple of this value + + // Array validations + pub min_items: Option, // Minimum array length + pub max_items: Option, // Maximum array length + pub unique_items: Option, // Whether items must be unique + + // General validations + pub enum_values: Option>, // List of allowed values +} + +impl ResourceConfig { + /// Validates a set of values against this resource configuration + pub fn validate(&self, values: &toml::value::Table) -> Result { + // Check required fields + for field in &self.fields { + if field.required { + if !values.contains_key(&field.name) { + return Err(LearnerError::InvalidResource(format!( + "Missing required field: {}", + field.name + ))); + } + } + } + + // Validate each provided field + for (name, value) in values { + if let Some(field) = self.fields.iter().find(|f| f.name == *name) { + // Validate field value against its definition + self.validate_field(field, value)?; + } + } + + Ok(true) + } + + /// Validates a single field value against its definition + fn validate_field(&self, field: &FieldDefinition, value: &toml::Value) -> Result<()> { + // First validate that the provided value matches the declared type + match (field.field_type.as_str(), value) { + // String validation - handles both basic type checking and string-specific rules + ("string", toml::Value::String(v)) => { + if let Some(rules) = &field.validation { + // Length constraints + if let Some(min_length) = rules.min_length { + if v.len() < min_length { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be at least {} characters", + field.name, min_length + ))); + } + } + if let Some(max_length) = rules.max_length { + if v.len() > max_length { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' cannot exceed {} characters", + field.name, max_length + ))); + } + } + + // Pattern matching via regex + if let Some(pattern) = &rules.pattern { + dbg!(&pattern); + let re = Regex::new(pattern) 
+ .map_err(|_| LearnerError::InvalidResource("Invalid regex pattern".into()))?; + if !re.is_match(v) { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must match pattern: {}", + field.name, pattern + ))); + } + } + + // Enumerated values check + if let Some(allowed) = &rules.enum_values { + if !allowed.contains(v) { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be one of: {:?}", + field.name, allowed + ))); + } + } + } + Ok(()) + }, + + // Numeric validations - handle both integer and float values + ("integer", toml::Value::Integer(v)) => { + if let Some(rules) = &field.validation { + validate_numeric(field, *v as f64, rules)?; + } + Ok(()) + }, + + ("float", toml::Value::Float(v)) => { + if let Some(rules) = &field.validation { + validate_numeric(field, *v, rules)?; + } + Ok(()) + }, + + // Array validation - handles array-specific rules + ("array", toml::Value::Array(v)) => { + if let Some(rules) = &field.validation { + if let Some(min_items) = rules.min_items { + if v.len() < min_items { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must have at least {} items", + field.name, min_items + ))); + } + } + + if let Some(max_items) = rules.max_items { + if v.len() > max_items { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' cannot exceed {} items", + field.name, max_items + ))); + } + } + + if rules.unique_items == Some(true) { + let mut seen = HashSet::new(); + for item in v { + let item_str = toml::to_string(item).map_err(|_| { + LearnerError::InvalidResource("Failed to serialize array item".into()) + })?; + if !seen.insert(item_str) { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' contains duplicate items", + field.name + ))); + } + } + } + } + Ok(()) + }, + + // Simple type validations - just ensure type matches + ("boolean", toml::Value::Boolean(_)) => Ok(()), + ("datetime", toml::Value::Datetime(_)) => Ok(()), + ("table", toml::Value::Table(_)) => Ok(()), + + // 
Type mismatch - provide a clear error message + _ => Err(LearnerError::InvalidResource(format!( + "Field '{}' expected type '{}' but got '{}'", + field.name, + field.field_type, + match value { + toml::Value::String(_) => "string", + toml::Value::Integer(_) => "integer", + toml::Value::Float(_) => "float", + toml::Value::Boolean(_) => "boolean", + toml::Value::Datetime(_) => "datetime", + toml::Value::Array(_) => "array", + toml::Value::Table(_) => "table", + } + ))), + } + } } -impl Resource for ResourceConfig { - fn resource_type(&self) -> String { self.type_name.clone() } +fn validate_numeric(field: &FieldDefinition, value: f64, rules: &ValidationRules) -> Result<()> { + if let Some(min) = rules.minimum { + if value < min { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be at least {}", + field.name, min + ))); + } + } + + if let Some(max) = rules.maximum { + if value > max { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' cannot exceed {}", + field.name, max + ))); + } + } + + if let Some(multiple) = rules.multiple_of { + let ratio = value / multiple; + if (ratio - ratio.round()).abs() > f64::EPSILON { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be a multiple of {}", + field.name, multiple + ))); + } + } + + Ok(()) +} +// Convert from chrono DateTime to TOML Datetime +pub fn chrono_to_toml_datetime(dt: DateTime) -> toml::value::Datetime { + // TOML datetime is stored as seconds since Unix epoch + toml::value::Datetime::from_str(&dt.to_rfc3339()).unwrap() +} - fn fields(&self) -> Result> { Ok(self.fields.clone()) } +// Convert from TOML Datetime to chrono DateTime +pub fn toml_to_chrono_datetime(dt: toml::value::Datetime) -> DateTime { + // Create DateTime from Unix timestamp + DateTime::parse_from_rfc3339(&dt.to_string()).unwrap().to_utc() } #[cfg(test)] mod tests { - use serde_json::json; + use chrono::{TimeZone, Utc}; use super::*; #[test] - fn test_thesis_resource() -> Result<()> { - // 
Create a thesis resource - let mut fields = BTreeMap::new(); - fields.insert("title".into(), json!("Understanding Quantum Computing Effects")); - fields.insert("author".into(), json!(["Alice Researcher", "Bob Scientist"])); - fields.insert("university".into(), json!("Tech University")); - fields.insert("department".into(), json!("Computer Science")); - fields.insert("defense_date".into(), json!("2024-06-15T14:00:00Z")); - fields.insert( - "committee".into(), - json!(["Prof. Carol Chair", "Dr. David Member", "Dr. Eve External"]), - ); - fields - .insert("keywords".into(), json!(["quantum computing", "decoherence", "error correction"])); - - let thesis = ResourceConfig { type_name: "thesis".to_string(), fields }; - - // Test resource_type - assert_eq!(thesis.resource_type(), "thesis"); - - // Test fields method - let fields = thesis.fields()?; - - // Verify we can access specific fields with proper types - assert!(fields.get("title").unwrap().is_string()); - assert!(fields.get("author").unwrap().as_array().unwrap().len() == 2); - - // Test JSON serialization/deserialization roundtrip - let serialized = serde_json::to_string(&thesis)?; - let deserialized: ResourceConfig = serde_json::from_str(&serialized)?; - assert_eq!(thesis.fields.get("title"), deserialized.fields.get("title")); + fn test_paper_configuration() -> Result<()> { + // Load the paper configuration + let config = include_str!("../../config/resources/paper.toml"); + let config: ResourceConfig = toml::from_str(config)?; + + let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); + + // Create a valid paper + let paper_values = toml::value::Table::from_iter([ + ("title".into(), toml::Value::String("Understanding Quantum Computing".into())), + ( + "authors".into(), + toml::Value::Array(vec![toml::Value::Table(toml::value::Table::from_iter([ + ("name".into(), toml::Value::String("Alice Researcher".into())), + ("affiliation".into(), toml::Value::String("Tech 
University".into())), + ]))]), + ), + ("publication_date".into(), toml::Value::Datetime(date)), + ("doi".into(), toml::Value::String("10.1234/example.123".into())), + ]); + + // Validate the paper + assert!(config.validate(&paper_values)?); + + // Test required field validation + let invalid_paper = toml::value::Table::from_iter([ + ("authors".into(), toml::Value::Array(vec![])), // Missing title + ]); + assert!(config.validate(&invalid_paper).is_err()); Ok(()) } #[test] - fn test_thesis_from_toml() -> Result<()> { - let toml_str = include_str!("../../config/resources/thesis.toml"); - let config: ResourceConfig = toml::from_str(toml_str)?; - dbg!(&config); - - assert_eq!(config.resource_type(), "thesis"); - - // Test that we can access the field definitions - let fields = config.fields()?; - dbg!(&fields); - assert!(fields.contains_key("title")); - assert!(fields.contains_key("author")); - assert!(fields.contains_key("university")); + fn test_book_configuration() -> Result<()> { + let config = include_str!("../../config/resources/book.toml"); + let config: ResourceConfig = toml::from_str(config)?; + + let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); + + let book_values = toml::value::Table::from_iter([ + ("title".into(), toml::Value::String("Advanced Quantum Computing".into())), + ( + "authors".into(), + toml::Value::Array(vec![ + toml::Value::String("Alice Writer".into()), + toml::Value::String("Bob Author".into()), + ]), + ), + ("isbn".into(), toml::Value::String("978-0-12-345678-9".into())), + ("publisher".into(), toml::Value::String("Academic Press".into())), + ("publication_date".into(), toml::Value::Datetime(date)), + ]); + + assert!(config.validate(&book_values)?); + Ok(()) + } + + #[test] + fn test_thesis_configuration() -> Result<()> { + let config = include_str!("../../config/resources/thesis.toml"); + let config: ResourceConfig = toml::from_str(config)?; + + let date = 
chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); + + let thesis_values = toml::value::Table::from_iter([ + ("title".into(), toml::Value::String("Novel Approaches to Quantum Error Correction".into())), + ("author".into(), toml::Value::String("Alice Researcher".into())), + ("degree".into(), toml::Value::String("PhD".into())), + ("institution".into(), toml::Value::String("Tech University".into())), + ("completion_date".into(), toml::Value::Datetime(date)), + ( + "advisors".into(), + toml::Value::Array(vec![toml::Value::String("Prof. Bob Supervisor".into())]), + ), + ]); + + assert!(config.validate(&thesis_values)?); + + // Test degree enum validation + let mut invalid_thesis = thesis_values.clone(); + invalid_thesis.insert("degree".into(), toml::Value::String("InvalidDegree".into())); + assert!(config.validate(&invalid_thesis).is_err()); Ok(()) } diff --git a/crates/learner/src/resource/paper.rs b/crates/learner/src/resource/paper.rs index 2b04d98..744cc6f 100644 --- a/crates/learner/src/resource/paper.rs +++ b/crates/learner/src/resource/paper.rs @@ -188,7 +188,3 @@ impl Paper { PathBuf::from(format!("{}.pdf", formatted_title)) } } - -impl Resource for Paper { - fn resource_type(&self) -> String { "paper".to_string() } -} diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 6a95ff8..6f9f048 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,3 +1,4 @@ +use resource::ResourceConfig; use response::ResponseFormat; use super::*; @@ -33,7 +34,7 @@ pub struct RetrieverConfig { /// Name of this retriever configuration pub name: String, /// The type of resource this retriever should yield - pub resource_type: ResourceType, + pub resource: ResourceConfig, /// Base URL for API requests pub base_url: String, /// Regex pattern for matching and extracting paper identifiers @@ -102,4 +103,34 @@ impl RetrieverConfig { paper.source_identifier = 
identifier.to_string(); Ok(paper) } + + pub async fn retrieve_resource(&self, input: &str) -> Result { + let identifier = self.extract_identifier(input)?; + + // Send request and get response + let url = self.endpoint_template.replace("{identifier}", identifier); + debug!("Fetching from {} via: {}", self.name, url); + + let client = reqwest::Client::new(); + let mut request = client.get(&url); + + // Add any configured headers + for (key, value) in &self.headers { + request = request.header(key, value); + } + + let response = request.send().await?; + let data = response.bytes().await?; + trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); + + // Process the response into a generic Value first + let response_processor = match &self.response_format { + ResponseFormat::Xml(config) => config as &dyn ResponseProcessor, + ResponseFormat::Json(config) => config as &dyn ResponseProcessor, + }; + + todo!(); + + // Ok(resource) + } } diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index f3cf5e7..34f30f4 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -67,7 +67,6 @@ use std::collections::HashMap; use super::*; -use crate::resource::ResourceType; mod config; mod response; From b86f074e57a710bb38cba7d91b39bb1a0d48503f Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 07:28:57 -0700 Subject: [PATCH 08/73] fix: tests --- crates/learner/config/retrievers/arxiv.toml | 2 +- crates/learner/config/retrievers/doi.toml | 2 +- crates/learner/config/retrievers/iacr.toml | 2 +- crates/learner/src/retriever/config.rs | 20 ++++++++++++++++++++ 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/crates/learner/config/retrievers/arxiv.toml b/crates/learner/config/retrievers/arxiv.toml index aa09b6a..9d474e5 100644 --- a/crates/learner/config/retrievers/arxiv.toml +++ b/crates/learner/config/retrievers/arxiv.toml @@ -2,7 +2,7 @@ base_url = "http://export.arxiv.org" 
endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}&max_results=1" name = "arxiv" pattern = "(?:^|https?://arxiv\\.org/(?:abs|pdf)/)(\\d{4}\\.\\d{4,5}|[a-zA-Z-]+(?:\\.[A-Z]{2})?/\\d{7})(?:\\.pdf)?$" -resource_type = "paper" +resource = "config/resources/paper.toml" source = "arxiv" [response_format] diff --git a/crates/learner/config/retrievers/doi.toml b/crates/learner/config/retrievers/doi.toml index 130af72..4970498 100644 --- a/crates/learner/config/retrievers/doi.toml +++ b/crates/learner/config/retrievers/doi.toml @@ -2,7 +2,7 @@ base_url = "https://api.crossref.org/works" endpoint_template = "https://api.crossref.org/works/{identifier}" name = "doi" pattern = "(?:^|https?://doi\\.org/)(10\\.\\d{4,9}/[-._;()/:\\w]+)$" -resource_type = "paper" +resource = "config/resources/paper.toml" source = "doi" [response_format] diff --git a/crates/learner/config/retrievers/iacr.toml b/crates/learner/config/retrievers/iacr.toml index 2ad208e..9dac909 100644 --- a/crates/learner/config/retrievers/iacr.toml +++ b/crates/learner/config/retrievers/iacr.toml @@ -2,7 +2,7 @@ base_url = "https://eprint.iacr.org" endpoint_template = "https://eprint.iacr.org/oai?verb=GetRecord&identifier=oai:eprint.iacr.org:{identifier}&metadataPrefix=oai_dc" name = "iacr" pattern = "(?:^|https?://eprint\\.iacr\\.org/)(\\d{4}/\\d+)(?:\\.pdf)?$" -resource_type = "paper" +resource = "config/resources/paper.toml" source = "iacr" [response_format] diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 6f9f048..cd7e17a 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -34,6 +34,7 @@ pub struct RetrieverConfig { /// Name of this retriever configuration pub name: String, /// The type of resource this retriever should yield + #[serde(deserialize_with = "load_resource_config")] pub resource: ResourceConfig, /// Base URL for API requests pub base_url: String, @@ -134,3 +135,22 @@ impl 
RetrieverConfig { // Ok(resource) } } + +fn load_resource_config<'de, D>(deserializer: D) -> std::result::Result +where D: serde::Deserializer<'de> { + #[derive(Deserialize)] + #[serde(untagged)] + enum ResourceConfigRef { + Inline(ResourceConfig), + Path(PathBuf), + } + + let config_ref = ResourceConfigRef::deserialize(deserializer)?; + match config_ref { + ResourceConfigRef::Inline(config) => Ok(config), + ResourceConfigRef::Path(path) => { + let content = std::fs::read_to_string(&path).map_err(serde::de::Error::custom)?; + toml::from_str(&content).map_err(serde::de::Error::custom) + }, + } +} From e9f2251332eea4c6342757c2998dc0df78eb2fbf Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 07:34:20 -0700 Subject: [PATCH 09/73] cleanup some ci --- .release-plz.toml | 2 +- crates/learner/src/retriever/config.rs | 24 ++++++++++++++++++++++++ crates/sdk/Cargo.toml | 12 +++++++++--- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/.release-plz.toml b/.release-plz.toml index 6154bc2..27f2376 100644 --- a/.release-plz.toml +++ b/.release-plz.toml @@ -13,4 +13,4 @@ publish = true [[package]] name = "learner-sdk" -publish = false +publish = true diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index cd7e17a..0575083 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -76,6 +76,30 @@ impl RetrieverConfig { .ok_or(LearnerError::InvalidIdentifier) } + /// Retrieves a paper using this configuration. + /// + /// This method: + /// 1. Extracts the canonical identifier + /// 2. Constructs the API URL + /// 3. Makes the HTTP request + /// 4. 
Processes the response + /// + /// # Arguments + /// + /// * `input` - Paper identifier or URL + /// + /// # Returns + /// + /// Returns a Result containing either: + /// - The retrieved Paper object + /// - A LearnerError if any step fails + /// + /// # Errors + /// + /// This method will return an error if: + /// - The identifier cannot be extracted + /// - The HTTP request fails + /// - The response cannot be parsed pub async fn retrieve_paper(&self, input: &str) -> Result { let identifier = self.extract_identifier(input)?; let url = self.endpoint_template.replace("{identifier}", identifier); diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 0315d1d..0888278 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -1,7 +1,13 @@ [package] -edition = "2021" -name = "learner-sdk" -version = "0.1.0" +authors.workspace = true +description = "A simple SDK for making things to learn stuff" +edition.workspace = true +keywords.workspace = true +license.workspace = true +name = "learner-sdk" +readme.workspace = true +repository.workspace = true +version = "0.1.0" [dependencies] clap = { workspace = true } From eed91069f487bcfa006acb961badbb714084689f Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 07:45:43 -0700 Subject: [PATCH 10/73] WIP: improving SDK --- Cargo.lock | 2 + crates/sdk/Cargo.toml | 2 + crates/sdk/src/validate.rs | 220 +++++++++++++++++++++++++++---------- 3 files changed, 168 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 741c94b..1f588f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1197,6 +1197,8 @@ version = "0.1.0" dependencies = [ "clap", "learner", + "regex", + "reqwest", "tempfile", "tokio", "toml", diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 0888278..398c03c 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -12,6 +12,8 @@ version = "0.1.0" [dependencies] clap = { workspace = true } learner = { workspace = true } +regex = { workspace = true } 
+reqwest = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true } toml = { workspace = true } diff --git a/crates/sdk/src/validate.rs b/crates/sdk/src/validate.rs index 5d4f80e..2b34e9b 100644 --- a/crates/sdk/src/validate.rs +++ b/crates/sdk/src/validate.rs @@ -8,10 +8,14 @@ use learner::{ use super::*; pub fn validate_resource(path: &PathBuf) { + info!("Validating resource configuration from: {}", path.display()); + + // Read and parse the configuration let config_str = match read_to_string(path) { Ok(str) => str, Err(e) => { - error!("Failed to read config to string due to: {e:?}"); + error!("Failed to read config file: {}", e); + error!("Please ensure the file exists and has proper permissions"); return; }, }; @@ -19,22 +23,102 @@ pub fn validate_resource(path: &PathBuf) { let resource: ResourceConfig = match toml::from_str(&config_str) { Ok(config) => config, Err(e) => { - error!("Failed to parse config to string due to: {e:?}"); + error!("Failed to parse TOML configuration: {}", e); + error!("Common issues:"); + error!("- Missing or malformed fields"); + error!("- Incorrect data types"); + error!("- TOML syntax errors"); return; }, }; - info!("Resource type: {}", resource.type_name); + info!("Found resource type: {}", resource.type_name); + if let Some(desc) = &resource.description { + info!("Description: {}", desc); + } + + // Validate field definitions + info!("Validating {} field definitions...", resource.fields.len()); + for field in &resource.fields { + // Check field type validity + match field.field_type.as_str() { + "string" | "integer" | "float" | "boolean" | "datetime" | "array" | "table" => { + info!("Field '{}' ({}):", field.name, field.field_type); + if let Some(desc) = &field.description { + info!(" Description: {}", desc); + } + info!(" Required: {}", field.required); + + // Validate default values match declared type + if let Some(default) = &field.default { + match (field.field_type.as_str(), default) { + ("string", 
toml::Value::String(_)) + | ("integer", toml::Value::Integer(_)) + | ("float", toml::Value::Float(_)) + | ("boolean", toml::Value::Boolean(_)) + | ("datetime", toml::Value::Datetime(_)) + | ("array", toml::Value::Array(_)) + | ("table", toml::Value::Table(_)) => { + info!(" Default value: valid"); + }, + _ => { + error!(" Default value type doesn't match field type!"); + error!(" Expected {}, got {}", field.field_type, default.type_str()); + }, + } + } + + // Validate validation rules + if let Some(rules) = &field.validation { + info!(" Validation rules:"); + match field.field_type.as_str() { + "string" => { + if let Some(pattern) = &rules.pattern { + match regex::Regex::new(pattern) { + Ok(_) => info!(" - Valid regex pattern"), + Err(e) => error!(" - Invalid regex pattern: {}", e), + } + } + if let Some(min) = rules.min_length { + info!(" - Minimum length: {}", min); + } + if let Some(max) = rules.max_length { + info!(" - Maximum length: {}", max); + } + }, + "array" => { + if let Some(min) = rules.min_items { + info!(" - Minimum items: {}", min); + } + if let Some(max) = rules.max_items { + info!(" - Maximum items: {}", max); + } + if rules.unique_items == Some(true) { + info!(" - Items must be unique"); + } + }, + _ => {}, + } + } + }, + invalid_type => { + error!("Field '{}' has invalid type: {}", field.name, invalid_type); + error!("Valid types are: string, integer, float, boolean, datetime, array, table"); + }, + } + } - // Check all required fields are present - debug!("All config fields are:\n{:#?}", resource.fields()); + info!("Resource configuration validation complete!"); } pub async fn validate_retriever(path: &PathBuf, input: &Option) { + info!("Validating retriever configuration from: {}", path.display()); + let config_str = match read_to_string(path) { Ok(str) => str, Err(e) => { - error!("Failed to read config to string due to: {e:?}"); + error!("Failed to read config file: {}", e); + error!("Please ensure the file exists and has proper 
permissions"); return; }, }; @@ -42,73 +126,97 @@ pub async fn validate_retriever(path: &PathBuf, input: &Option) { let retriever: RetrieverConfig = match toml::from_str(&config_str) { Ok(config) => config, Err(e) => { - error!("Failed to parse config to string due to: {e:?}"); + error!("Failed to parse TOML configuration: {}", e); + error!("Common issues:"); + error!("- Missing required fields"); + error!("- Incorrect field types"); + error!("- Malformed URLs or patterns"); return; }, }; + // Validate basic configuration + info!("Validating retriever '{}'", retriever.name); + + // Check URL validity + if let Err(e) = reqwest::Url::parse(&retriever.base_url) { + error!("Invalid base URL: {}", e); + return; + } + + // Validate endpoint template + if !retriever.endpoint_template.contains("{identifier}") { + error!("Endpoint template must contain {{identifier}} placeholder"); + return; + } + + // Check response format configuration match &retriever.response_format { ResponseFormat::Xml(config) => { - debug!("Retriever is configured for: XML\n{config:#?}") + info!("XML configuration:"); + info!("- Namespace stripping: {}", config.strip_namespaces); + info!("Validating field mappings:"); + for (field, map) in &config.field_maps { + info!("- {}: {}", field, map.path); + if let Some(transform) = &map.transform { + info!(" with transformation: {:?}", transform); + } + } }, ResponseFormat::Json(config) => { - debug!("Retriever is configured for: JSON\n{config:#?}") + info!("JSON configuration:"); + info!("Validating field mappings:"); + for (field, map) in &config.field_maps { + info!("- {}: {}", field, map.path); + if let Some(transform) = &map.transform { + info!(" with transformation: {:?}", transform); + } + } }, } + // Test pattern matching if input provided if let Some(input) = input { - info!("Attempting to match against pattern..."); + info!("Testing identifier pattern matching..."); match retriever.extract_identifier(input) { - Ok(identifier) => info!("Retriever 
extracted input into: {identifier}"), - Err(e) => { - error!("Retriever failed to extract input due to: {e:?}"); - return; - }, - } + Ok(identifier) => { + info!("✓ Successfully extracted identifier: {}", identifier); - info!("Attempting to fetch paper using retriever..."); - let paper = match retriever.retrieve_paper(input).await { - Ok(paper) => { - info!("Paper retrieved!\n{paper:#?}"); - paper - }, - Err(e) => { - error!("Retriever failed to retriever paper due to: {e:?}"); - return; - }, - }; - - if paper.pdf_url.is_some() { - info!("Attempting to download associated pdf"); - let tempdir = tempfile::tempdir().unwrap(); - match paper.download_pdf(tempdir.path()).await { - Ok(filename) => { - let pdf_filepath = tempdir.path().join(filename); - if pdf_filepath.exists() { - let bytes = std::fs::read(path).unwrap(); - if bytes.is_empty() { - error!("PDF download was empty."); + // Try fetching + info!("Testing retrieval..."); + match retriever.retrieve_paper(input).await { + Ok(paper) => { + info!("✓ Successfully retrieved paper:"); + info!("Title: {}", paper.title); + info!("Authors: {:?}", paper.authors); + + // Test PDF download if available + if let Some(url) = &paper.pdf_url { + info!("Testing PDF download from: {}", url); + let tempdir = tempfile::tempdir().unwrap(); + match paper.download_pdf(tempdir.path()).await { + Ok(filename) => { + let pdf_path = tempdir.path().join(filename); + if pdf_path.exists() { + let metadata = std::fs::metadata(&pdf_path).unwrap(); + info!("✓ PDF downloaded successfully ({} bytes)", metadata.len()); + } else { + error!("PDF download failed - file not created"); + } + }, + Err(e) => error!("PDF download failed: {}", e), + } } else { - info!("Non-empty PDF downloaded successfully."); + warn!("No PDF URL available"); } - } else { - error!("PDF path did not end up getting written.") - } - }, - Err(e) => { - error!("PDF was unable to be downloaded due to: {e:?}") - }, - } - } else { - warn!( - "PDF URL was not determined. 
Please check your configuration against the server response." - ); + }, + Err(e) => error!("Retrieval failed: {}", e), + } + }, + Err(e) => error!("Pattern matching failed: {}", e), } } else { - warn!( - "No input string provided to further debug your `RetrieverConfig`. If you want to test \ - identifier pattern matching and online fetching, please pass in an input string with an \ - additional input, e.g., `2301.07041`." - ); + info!("No test input provided - skipping retrieval tests"); + info!("To test retrieval, provide an identifier like: 2301.07041"); } } From 3d314fce322c9cf97bddf6ec23f3964aafbd422d Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 08:31:44 -0700 Subject: [PATCH 11/73] WIP: fixing tests/adjusting `Learner` --- crates/learner/config/resources/book.toml | 2 +- crates/learner/config/resources/paper.toml | 2 +- crates/learner/config/resources/thesis.toml | 2 +- crates/learner/config/retrievers/arxiv.toml | 2 +- crates/learner/src/configuration.rs | 44 ++++++ crates/learner/src/lib.rs | 53 ++++++- crates/learner/src/resource/mod.rs | 23 ++- crates/learner/src/retriever/config.rs | 7 +- crates/learner/src/retriever/mod.rs | 131 ++---------------- crates/learner/tests/lib.rs | 1 + .../tests/workflows/build_retriever.rs | 6 +- crates/sdk/src/validate.rs | 2 +- 12 files changed, 139 insertions(+), 136 deletions(-) create mode 100644 crates/learner/src/configuration.rs diff --git a/crates/learner/config/resources/book.toml b/crates/learner/config/resources/book.toml index 1607b12..8b14ebb 100644 --- a/crates/learner/config/resources/book.toml +++ b/crates/learner/config/resources/book.toml @@ -1,4 +1,4 @@ -type_name = "book" +name = "book" description = "A published book, including textbooks, monographs, and edited volumes" diff --git a/crates/learner/config/resources/paper.toml b/crates/learner/config/resources/paper.toml index d971af1..bf95c4d 100644 --- a/crates/learner/config/resources/paper.toml +++ 
b/crates/learner/config/resources/paper.toml @@ -1,4 +1,4 @@ -type_name = "paper" +name = "paper" description = "A scholarly paper or article published in an academic context" diff --git a/crates/learner/config/resources/thesis.toml b/crates/learner/config/resources/thesis.toml index fe1a457..37b5653 100644 --- a/crates/learner/config/resources/thesis.toml +++ b/crates/learner/config/resources/thesis.toml @@ -1,4 +1,4 @@ -type_name = "thesis" +name = "thesis" description = "A master's thesis or doctoral dissertation" diff --git a/crates/learner/config/retrievers/arxiv.toml b/crates/learner/config/retrievers/arxiv.toml index 9d474e5..f68f629 100644 --- a/crates/learner/config/retrievers/arxiv.toml +++ b/crates/learner/config/retrievers/arxiv.toml @@ -2,7 +2,7 @@ base_url = "http://export.arxiv.org" endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}&max_results=1" name = "arxiv" pattern = "(?:^|https?://arxiv\\.org/(?:abs|pdf)/)(\\d{4}\\.\\d{4,5}|[a-zA-Z-]+(?:\\.[A-Z]{2})?/\\d{7})(?:\\.pdf)?$" -resource = "config/resources/paper.toml" +resource = "paper.toml" source = "arxiv" [response_format] diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs new file mode 100644 index 0000000..e22b920 --- /dev/null +++ b/crates/learner/src/configuration.rs @@ -0,0 +1,44 @@ +use super::*; + +pub trait Identifiable { + fn name(&self) -> String; +} + +pub trait Configurable: Sized { + type Config: Identifiable + for<'de> Deserialize<'de>; + fn insert(&mut self, config_name: String, config: Self::Config); + + fn with_config(mut self, config: Self::Config) { self.insert(config.name(), config); } + + fn with_config_str(mut self, toml_str: &str) -> Result { + let config: Self::Config = toml::from_str(toml_str)?; + self.insert(config.name(), config); + Ok(self) + } + + fn with_config_file(self, path: impl AsRef) -> Result { + let content = std::fs::read_to_string(path)?; + self.with_config_str(&content) + } + + fn 
with_config_dir(self, dir: impl AsRef) -> Result { + let dir = dir.as_ref(); + dbg!(&dir); + if !dir.is_dir() { + return Err(LearnerError::Path(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Config directory not found", + ))); + } + + let mut configurable = self; + for entry in std::fs::read_dir(dir)? { + let entry = entry?; + let path = entry.path(); + if path.extension().is_some_and(|ext| ext == "toml") { + configurable = configurable.with_config_file(path)?; + } + } + Ok(configurable) + } +} diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index fd0e67d..fc0791c 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -158,6 +158,7 @@ use chrono::{DateTime, Utc}; use lazy_static::lazy_static; use regex::Regex; use reqwest::Url; +use resource::{ResourceConfig, Resources}; use serde::{Deserialize, Serialize}; use tracing::{debug, trace, warn}; #[cfg(test)] @@ -166,6 +167,7 @@ use {tempfile::tempdir, tracing_test::traced_test}; pub mod database; pub mod retriever; +pub mod configuration; pub mod error; pub mod format; pub mod llm; @@ -187,6 +189,13 @@ pub const DOI_CONFIG: &str = include_str!("../config/retrievers/doi.toml"); /// IACR default configuration pub const IACR_CONFIG: &str = include_str!("../config/retrievers/iacr.toml"); +/// Paper default configuration +pub const PAPER_CONFIG: &str = include_str!("../config/resources/paper.toml"); +/// Book default configuration +pub const BOOK_CONFIG: &str = include_str!("../config/resources/book.toml"); +/// Thesis default configuration +pub const THESIS_CONFIG: &str = include_str!("../config/resources/thesis.toml"); + /// Common traits and types for ergonomic imports. 
/// /// This module provides a convenient way to import frequently used traits @@ -217,7 +226,10 @@ pub const IACR_CONFIG: &str = include_str!("../config/retrievers/iacr.toml"); /// ``` pub mod prelude { pub use crate::{ - database::DatabaseInstruction, error::LearnerError, retriever::ResponseProcessor, + configuration::{Configurable, Identifiable}, + database::DatabaseInstruction, + error::LearnerError, + retriever::ResponseProcessor, }; } @@ -253,6 +265,10 @@ pub struct Config { /// The path to load retriever configs from. #[serde(default = "Config::default_retrievers_path")] pub retrievers_path: PathBuf, + + /// The path to load resource configs from. + #[serde(default = "Config::default_resources_path")] + pub resources_path: PathBuf, } // TODO: We should really let the database storage path be set prior to opening. We need a slightly @@ -284,6 +300,8 @@ pub struct Learner { pub database: Database, /// Paper retrieval system pub retriever: Retriever, + /// Resources to use + pub resources: Resources, } /// Builder for creating configured Learner instances. @@ -341,6 +359,14 @@ impl Config { Self::default_path().unwrap_or_else(|_| PathBuf::from(".")).join("retrievers") } + /// Returns the default path for resource configuration files. + /// + /// The path is constructed as `{config_dir}/resources` where + /// config_dir is determined by [`default_path()`](Config::default_path). + pub fn default_resources_path() -> PathBuf { + Self::default_path().unwrap_or_else(|_| PathBuf::from(".")).join("resources") + } + /// Loads existing configuration or creates new with defaults. /// /// Looks for configuration file at the default path.
If not found, @@ -405,6 +431,14 @@ impl Config { let config = Self::default(); config.save()?; + // Write example resource configs + let resources_dir = &config.resources_path; + std::fs::create_dir_all(resources_dir)?; + + std::fs::write(resources_dir.join("paper.toml"), PAPER_CONFIG)?; + std::fs::write(resources_dir.join("book.toml"), BOOK_CONFIG)?; + std::fs::write(resources_dir.join("thesis.toml"), THESIS_CONFIG)?; + // Write example retriever configs let retrievers_dir = &config.retrievers_path; std::fs::create_dir_all(retrievers_dir)?; @@ -444,6 +478,16 @@ impl Config { self } + /// Sets the path for resource configuration files. + /// + /// # Arguments + /// + /// * `resources_path` - Directory where resource TOML configs are stored + pub fn with_resources_path(mut self, resources_path: &Path) -> Self { + self.resources_path = resources_path.to_path_buf(); + self + } + /// Sets the path for paper document storage. /// /// # Arguments @@ -461,6 +505,7 @@ impl Default for Config { database_path: Database::default_path(), storage_path: Database::default_storage_path(), retrievers_path: Self::default_retrievers_path(), + resources_path: Self::default_resources_path(), } } } @@ -537,6 +582,7 @@ impl LearnerBuilder { }; // Ensure paths exist + std::fs::create_dir_all(&config.resources_path)?; std::fs::create_dir_all(&config.retrievers_path)?; if let Some(parent) = config.database_path.parent() { std::fs::create_dir_all(parent)?; @@ -547,8 +593,9 @@ impl LearnerBuilder { database.set_storage_path(&config.storage_path).await?; let retriever = Retriever::new().with_config_dir(&config.retrievers_path)?; + let resources = Resources::new().with_config_dir(&config.resources_path)?; - Ok(Learner { config, database, retriever }) + Ok(Learner { config, database, retriever, resources }) } } @@ -713,11 +760,13 @@ mod tests { let storage_dir = tempdir().unwrap(); let config = Config::default() .with_database_path(&database_dir.path().join("learner.db"))
.with_resources_path(&config_dir.path().join("config/resources/")) .with_retrievers_path(&config_dir.path().join("config/retrievers/")) .with_storage_path(storage_dir.path()); let learner = Learner::builder().with_path(config_dir.path()).with_config(config).build().await.unwrap(); + assert_eq!(learner.config.resources_path, config_dir.path().join("config/resources/")); assert_eq!(learner.config.retrievers_path, config_dir.path().join("config/retrievers/")); assert_eq!(learner.config.database_path, database_dir.path().join("learner.db")); assert_eq!(learner.database.get_storage_path().await.unwrap(), storage_dir.path()); diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index a17f037..bc50481 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -9,10 +9,27 @@ pub use paper::*; pub use shared::*; use toml::Value; +#[derive(Debug, Clone, Default)] +pub struct Resources { + resource_configs: BTreeMap, +} + +impl Resources { + pub fn new() -> Self { Self::default() } +} + +impl Configurable for Resources { + type Config = ResourceConfig; + + fn insert(&mut self, config_name: String, config: Self::Config) { + self.resource_configs.insert(config_name, config); + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ResourceConfig { /// The type identifier for this resource - pub type_name: String, + pub name: String, /// Optional description of this resource type #[serde(default)] pub description: Option, @@ -21,6 +38,10 @@ pub struct ResourceConfig { pub fields: Vec, } +impl Identifiable for ResourceConfig { + fn name(&self) -> String { self.name.clone() } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FieldDefinition { /// Name of the field diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 0575083..cb8c7d9 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,6 +1,3 @@ 
-use resource::ResourceConfig; -use response::ResponseFormat; - use super::*; /// Configuration for a specific paper source retriever. @@ -52,6 +49,10 @@ pub struct RetrieverConfig { pub headers: HashMap, } +impl Identifiable for RetrieverConfig { + fn name(&self) -> String { self.name.clone() } +} + impl RetrieverConfig { /// Extracts the canonical identifier from an input string. /// diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 34f30f4..0fbf8d9 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -101,6 +101,14 @@ pub struct Retriever { configs: HashMap, } +impl Configurable for Retriever { + type Config = RetrieverConfig; + + fn insert(&mut self, config_name: String, config: Self::Config) { + self.configs.insert(config_name, config); + } +} + impl Retriever { /// Checks whether the retreivers map is empty. /// @@ -137,129 +145,6 @@ impl Retriever { /// ``` pub fn new() -> Self { Self::default() } - /// Adds a retriever configuration to this instance. - /// - /// This method configures support for a new paper source using the provided - /// configuration. Multiple configurations can be added to support different sources. - /// - /// # Arguments - /// - /// * `config` - Configuration for the paper source - /// - /// # Examples - /// - /// ```no_run - /// # use learner::retriever::{Retriever, RetrieverConfig}; - /// # fn example(config: RetrieverConfig) { - /// let retriever = Retriever::new().with_config(config); - /// # } - /// ``` - pub fn with_config(mut self, config: RetrieverConfig) { - self.configs.insert(config.name.clone(), config); - } - - /// Adds a retriever configuration from a TOML string. - /// - /// Parses the provided TOML string into a RetrieverConfig and adds it - /// to this instance. 
- /// - /// # Arguments - /// - /// * `toml_str` - TOML configuration string - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - The updated Retriever instance - /// - A LearnerError if parsing fails - /// - /// # Examples - /// - /// ```no_run - /// # use learner::retriever::Retriever; - /// let toml = r#" - /// name = "arxiv" - /// base_url = "http://export.arxiv.org/api/query" - /// pattern = "^\\d{4}\\.\\d{4,5}$" - /// source = "arxiv" - /// endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}" - /// "#; - /// - /// let retriever = Retriever::new().with_config_str(toml)?; - /// # Ok::<(), Box>(()) - /// ``` - pub fn with_config_str(mut self, toml_str: &str) -> Result { - let config: RetrieverConfig = toml::from_str(toml_str)?; - self.configs.insert(config.name.clone(), config); - Ok(self) - } - - /// Adds a retriever configuration from a TOML file. - /// - /// # Arguments - /// - /// * `path` - Path to TOML configuration file - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - The updated Retriever instance - /// - A LearnerError if reading or parsing fails - /// - /// # Examples - /// - /// ```no_run - /// # use learner::retriever::Retriever; - /// let retriever = Retriever::new().with_config_file("config/arxiv.toml")?; - /// # Ok::<(), Box>(()) - /// ``` - pub fn with_config_file(self, path: impl AsRef) -> Result { - let content = std::fs::read_to_string(path)?; - self.with_config_str(&content) - } - - /// Adds multiple configurations from a directory of TOML files. - /// - /// This method loads all .toml files from the specified directory and - /// adds them as configurations. 
- /// - /// # Arguments - /// - /// * `dir` - Path to directory containing TOML configuration files - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - The updated Retriever instance - /// - A LearnerError if directory access or parsing fails - /// - /// # Examples - /// - /// ```no_run - /// # use learner::retriever::Retriever; - /// let retriever = Retriever::new().with_config_dir("config/")?; - /// # Ok::<(), Box>(()) - /// ``` - pub fn with_config_dir(self, dir: impl AsRef) -> Result { - let dir = dir.as_ref(); - if !dir.is_dir() { - return Err(LearnerError::Path(std::io::Error::new( - std::io::ErrorKind::NotFound, - "Config directory not found", - ))); - } - - let mut retriever = self; - for entry in std::fs::read_dir(dir)? { - let entry = entry?; - let path = entry.path(); - if path.extension().is_some_and(|ext| ext == "toml") { - retriever = retriever.with_config_file(path)?; - } - } - Ok(retriever) - } - /// Attempts to retrieve a paper using any matching configuration. 
/// /// This method tries to match the input against all configured retrievers diff --git a/crates/learner/tests/lib.rs b/crates/learner/tests/lib.rs index cec2da9..3e73d55 100644 --- a/crates/learner/tests/lib.rs +++ b/crates/learner/tests/lib.rs @@ -28,6 +28,7 @@ pub async fn create_test_learner() -> (Learner, TempDir, TempDir, TempDir) { let config = Config::default() .with_database_path(&database_dir.path().join("learner.db")) .with_retrievers_path(Path::new("config/retrievers/")) + .with_resources_path(Path::new("config/resources/")) .with_storage_path(storage_dir.path()); let learner = Learner::builder().with_path(config_dir.path()).with_config(config).build().await.unwrap(); diff --git a/crates/learner/tests/workflows/build_retriever.rs b/crates/learner/tests/workflows/build_retriever.rs index 8477eb4..71c61cb 100644 --- a/crates/learner/tests/workflows/build_retriever.rs +++ b/crates/learner/tests/workflows/build_retriever.rs @@ -4,11 +4,13 @@ use learner::retriever::{ResponseFormat, Transform}; use super::*; -#[test] -fn test_arxiv_config_deserialization() { +#[tokio::test] +async fn test_arxiv_config_deserialization() { let config_str = read_to_string("config/retrievers/arxiv.toml").expect("Failed to read config file"); + // let (learner, _config_dir, _database_dir, _storage_dir) = create_test_learner().await; + let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); // Verify basic fields diff --git a/crates/sdk/src/validate.rs b/crates/sdk/src/validate.rs index 2b34e9b..7ebc8f4 100644 --- a/crates/sdk/src/validate.rs +++ b/crates/sdk/src/validate.rs @@ -32,7 +32,7 @@ pub fn validate_resource(path: &PathBuf) { }, }; - info!("Found resource type: {}", resource.type_name); + info!("Found resource type: {}", resource.name); if let Some(desc) = &resource.description { info!("Description: {}", desc); } From 894c3e10f3ff8d25bc321b5aea5d7108fd79df31 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 
09:24:57 -0700 Subject: [PATCH 12/73] refactor + fix tests --- crates/learner/config/retrievers/arxiv.toml | 2 +- crates/learner/src/configuration.rs | 9 +- crates/learner/src/lib.rs | 4 +- crates/learner/src/resource/mod.rs | 32 +-- crates/learner/src/retriever/mod.rs | 237 ++++++++++++++++-- crates/learner/tests/lib.rs | 1 - .../tests/workflows/build_retriever.rs | 188 -------------- crates/learner/tests/workflows/mod.rs | 1 - 8 files changed, 234 insertions(+), 240 deletions(-) delete mode 100644 crates/learner/tests/workflows/build_retriever.rs diff --git a/crates/learner/config/retrievers/arxiv.toml b/crates/learner/config/retrievers/arxiv.toml index f68f629..9d474e5 100644 --- a/crates/learner/config/retrievers/arxiv.toml +++ b/crates/learner/config/retrievers/arxiv.toml @@ -2,7 +2,7 @@ base_url = "http://export.arxiv.org" endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}&max_results=1" name = "arxiv" pattern = "(?:^|https?://arxiv\\.org/(?:abs|pdf)/)(\\d{4}\\.\\d{4,5}|[a-zA-Z-]+(?:\\.[A-Z]{2})?/\\d{7})(?:\\.pdf)?$" -resource = "paper.toml" +resource = "config/resources/paper.toml" source = "arxiv" [response_format] diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index e22b920..43fddec 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -1,3 +1,5 @@ +use std::ops::{Index, IndexMut}; + use super::*; pub trait Identifiable { @@ -6,13 +8,13 @@ pub trait Identifiable { pub trait Configurable: Sized { type Config: Identifiable + for<'de> Deserialize<'de>; - fn insert(&mut self, config_name: String, config: Self::Config); + fn as_map(&mut self) -> &mut BTreeMap; - fn with_config(mut self, config: Self::Config) { self.insert(config.name(), config); } + fn with_config(mut self, config: Self::Config) { self.as_map().insert(config.name(), config); } fn with_config_str(mut self, toml_str: &str) -> Result { let config: Self::Config = toml::from_str(toml_str)?; - 
self.insert(config.name(), config); + self.as_map().insert(config.name(), config); Ok(self) } @@ -23,7 +25,6 @@ pub trait Configurable: Sized { fn with_config_dir(self, dir: impl AsRef) -> Result { let dir = dir.as_ref(); - dbg!(&dir); if !dir.is_dir() { return Err(LearnerError::Path(std::io::Error::new( std::io::ErrorKind::NotFound, diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index fc0791c..6fa978e 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -299,7 +299,7 @@ pub struct Learner { /// Database connection and operations pub database: Database, /// Paper retrieval system - pub retriever: Retriever, + pub retriever: Retrievers, /// Resources to use pub resources: Resources, } @@ -592,7 +592,7 @@ impl LearnerBuilder { let database = Database::open(&config.database_path).await?; database.set_storage_path(&config.storage_path).await?; - let retriever = Retriever::new().with_config_dir(&config.retrievers_path)?; + let retriever = Retrievers::new().with_config_dir(&config.retrievers_path)?; let resources = Resources::new().with_config_dir(&config.resources_path)?; Ok(Learner { config, database, retriever, resources }) diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index bc50481..2b266fd 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -11,7 +11,7 @@ use toml::Value; #[derive(Debug, Clone, Default)] pub struct Resources { - resource_configs: BTreeMap, + configs: BTreeMap, } impl Resources { @@ -21,9 +21,7 @@ impl Resources { impl Configurable for Resources { type Config = ResourceConfig; - fn insert(&mut self, config_name: String, config: Self::Config) { - self.resource_configs.insert(config_name, config); - } + fn as_map(&mut self) -> &mut BTreeMap { &mut self.configs } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -283,15 +281,14 @@ pub fn toml_to_chrono_datetime(dt: toml::value::Datetime) -> DateTime { #[cfg(test)] mod tests { - use 
chrono::{TimeZone, Utc}; + use chrono::TimeZone; use super::*; #[test] - fn test_paper_configuration() -> Result<()> { - // Load the paper configuration + fn validate_paper_configuration() { let config = include_str!("../../config/resources/paper.toml"); - let config: ResourceConfig = toml::from_str(config)?; + let config: ResourceConfig = toml::from_str(config).unwrap(); let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); @@ -310,21 +307,19 @@ mod tests { ]); // Validate the paper - assert!(config.validate(&paper_values)?); + assert!(config.validate(&paper_values).unwrap()); // Test required field validation let invalid_paper = toml::value::Table::from_iter([ ("authors".into(), toml::Value::Array(vec![])), // Missing title ]); assert!(config.validate(&invalid_paper).is_err()); - - Ok(()) } #[test] - fn test_book_configuration() -> Result<()> { + fn validate_book_configuration() { let config = include_str!("../../config/resources/book.toml"); - let config: ResourceConfig = toml::from_str(config)?; + let config: ResourceConfig = toml::from_str(config).unwrap(); let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); @@ -342,14 +337,13 @@ mod tests { ("publication_date".into(), toml::Value::Datetime(date)), ]); - assert!(config.validate(&book_values)?); - Ok(()) + assert!(config.validate(&book_values).unwrap()); } #[test] - fn test_thesis_configuration() -> Result<()> { + fn validate_thesis_configuration() { let config = include_str!("../../config/resources/thesis.toml"); - let config: ResourceConfig = toml::from_str(config)?; + let config: ResourceConfig = toml::from_str(config).unwrap(); let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); @@ -365,13 +359,11 @@ mod tests { ), ]); - assert!(config.validate(&thesis_values)?); + assert!(config.validate(&thesis_values).unwrap()); // Test degree enum validation let mut invalid_thesis = 
thesis_values.clone(); invalid_thesis.insert("degree".into(), toml::Value::String("InvalidDegree".into())); assert!(config.validate(&invalid_thesis).is_err()); - - Ok(()) } } diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 0fbf8d9..62900b6 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -8,7 +8,7 @@ //! //! The retriever system consists of several key components: //! -//! - [`Retriever`]: Main entry point for paper retrieval operations +//! - [`Retrievers`]: Main entry point for paper retrieval operations //! - [`RetrieverConfig`]: Configuration for specific paper sources //! - [`ResponseFormat`]: Format-specific parsing logic (XML/JSON) //! - [`ResponseProcessor`]: Trait for processing API responses @@ -35,12 +35,15 @@ //! Configure and use a retriever: //! //! ```no_run -//! use learner::retriever::{Retriever, RetrieverConfig}; +//! use learner::{ +//! prelude::*, +//! retriever::{RetrieverConfig, Retrievers}, +//! }; //! //! # async fn example() -> Result<(), Box> { //! // Create a new retriever //! let retriever = -//! Retriever::new().with_config_file("config/arxiv.toml")?.with_config_file("config/doi.toml")?; +//! Retrievers::new().with_config_file("config/arxiv.toml")?.with_config_file("config/doi.toml")?; //! //! // Retrieve a paper (automatically detects source) //! let paper = retriever.get_paper("10.1145/1327452.1327492").await?; @@ -52,10 +55,11 @@ //! Load multiple configurations: //! //! ```no_run -//! # use learner::retriever::Retriever; +//! # use learner::retriever::Retrievers; +//! # use learner::prelude::*; //! # async fn example() -> Result<(), Box> { //! // Load all TOML configs from a directory -//! let retriever = Retriever::new().with_config_dir("config/")?; +//! let retriever = Retrievers::new().with_config_dir("config/")?; //! //! // Retriever will automatically match source based on input format //! 
let arxiv_paper = retriever.get_paper("2301.07041").await?; @@ -84,9 +88,10 @@ pub use response::*; /// # Examples /// /// ```no_run -/// # use learner::retriever::Retriever; +/// # use learner::retriever::Retrievers; +/// # use learner::prelude::*; /// # async fn example() -> Result<(), Box> { -/// let retriever = Retriever::new().with_config_dir("config/")?; +/// let retriever = Retrievers::new().with_config_dir("config/")?; /// /// // Retrieve papers from different sources /// let paper1 = retriever.get_paper("2301.07041").await?; // arXiv @@ -96,20 +101,18 @@ pub use response::*; /// # } /// ``` #[derive(Default, Debug, Clone)] -pub struct Retriever { +pub struct Retrievers { /// The collection of configurations used for this [`Retriever`]. - configs: HashMap, + configs: BTreeMap, } -impl Configurable for Retriever { +impl Configurable for Retrievers { type Config = RetrieverConfig; - fn insert(&mut self, config_name: String, config: Self::Config) { - self.configs.insert(config_name, config); - } + fn as_map(&mut self) -> &mut BTreeMap { &mut self.configs } } -impl Retriever { +impl Retrievers { /// Checks whether the retreivers map is empty. /// /// This is useful for handling the case where no retreivers are specified and @@ -118,11 +121,11 @@ impl Retriever { /// # Examples /// /// ```no_run - /// # use learner::retriever::Retriever; + /// # use learner::retriever::Retrievers; /// # use learner::error::LearnerError; /// /// # fn check_is_empty() -> Result<(), LearnerError> { - /// let retriever = Retriever::new(); + /// let retriever = Retrievers::new(); /// /// if retriever.is_empty() { /// return Err(LearnerError::Config("No retriever configured.".to_string())); @@ -133,15 +136,15 @@ impl Retriever { pub fn is_empty(&self) -> bool { self.configs.is_empty() } } -impl Retriever { +impl Retrievers { /// Creates a new empty retriever with no configurations. 
/// /// # Examples /// /// ```no_run - /// use learner::retriever::Retriever; + /// use learner::retriever::Retrievers; /// - /// let retriever = Retriever::new(); + /// let retriever = Retrievers::new(); /// ``` pub fn new() -> Self { Self::default() } @@ -170,9 +173,10 @@ impl Retriever { /// # Examples /// /// ```no_run - /// # use learner::retriever::Retriever; + /// # use learner::retriever::Retrievers; + /// # use learner::prelude::*; /// # async fn example() -> Result<(), Box> { - /// let retriever = Retriever::new().with_config_dir("config/")?; + /// let retriever = Retrievers::new().with_config_dir("config/")?; /// /// // Retrieve from different sources /// let paper1 = retriever.get_paper("2301.07041").await?; @@ -226,9 +230,10 @@ impl Retriever { /// # Examples /// /// ``` - /// # use learner::retriever::Retriever; + /// # use learner::retriever::Retrievers; + /// # use learner::prelude::*; /// # async fn example() -> Result<(), Box> { - /// let retriever = Retriever::new().with_config_dir("config/")?; + /// let retriever = Retrievers::new().with_config_dir("config/")?; /// /// // Sanitize an arXiv URL /// let (source, id) = retriever.sanitize_identifier("https://arxiv.org/abs/2301.07041")?; @@ -322,3 +327,189 @@ fn apply_transform(value: &str, transform: &Transform) -> Result { Ok(format!("{}{}", base.replace("{value}", value), suffix.as_deref().unwrap_or(""))), } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_arxiv_config() { + let config_str = include_str!("../../config/retrievers/arxiv.toml"); + + let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + + // Verify basic fields + assert_eq!(retriever.name, "arxiv"); + assert_eq!(retriever.base_url, "http://export.arxiv.org"); + assert_eq!(retriever.source, "arxiv"); + + // Test pattern matching + assert!(retriever.pattern.is_match("2301.07041")); + assert!(retriever.pattern.is_match("math.AG/0601001")); + 
assert!(retriever.pattern.is_match("https://arxiv.org/abs/2301.07041")); + assert!(retriever.pattern.is_match("https://arxiv.org/pdf/2301.07041")); + assert!(retriever.pattern.is_match("https://arxiv.org/abs/math.AG/0601001")); + assert!(retriever.pattern.is_match("https://arxiv.org/abs/math/0404443")); + + // Test identifier extraction + assert_eq!(retriever.extract_identifier("2301.07041").unwrap(), "2301.07041"); + assert_eq!( + retriever.extract_identifier("https://arxiv.org/abs/2301.07041").unwrap(), + "2301.07041" + ); + assert_eq!(retriever.extract_identifier("math.AG/0601001").unwrap(), "math.AG/0601001"); + + // Verify response format + + if let ResponseFormat::Xml(config) = &retriever.response_format { + assert!(config.strip_namespaces); + + // Verify field mappings + let field_maps = &config.field_maps; + assert!(field_maps.contains_key("title")); + assert!(field_maps.contains_key("abstract")); + assert!(field_maps.contains_key("authors")); + assert!(field_maps.contains_key("publication_date")); + assert!(field_maps.contains_key("pdf_url")); + + // Verify PDF transform + if let Some(map) = field_maps.get("pdf_url") { + match &map.transform { + Some(Transform::Replace { pattern, replacement }) => { + assert_eq!(pattern, "/abs/"); + assert_eq!(replacement, "/pdf/"); + }, + _ => panic!("Expected Replace transform for pdf_url"), + } + } else { + panic!("Missing pdf_url field map"); + } + } else { + panic!("Expected an XML configuration, but did not get one.") + } + + // Verify headers + assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); + } + + #[test] + fn test_doi_config_deserialization() { + let config_str = include_str!("../../config/retrievers/doi.toml"); + + let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + + // Verify basic fields + assert_eq!(retriever.name, "doi"); + assert_eq!(retriever.base_url, "https://api.crossref.org/works"); + assert_eq!(retriever.source, "doi"); + + // Test 
pattern matching + let test_cases = [ + ("10.1145/1327452.1327492", true), + ("https://doi.org/10.1145/1327452.1327492", true), + ("invalid-doi", false), + ("https://wrong.url/10.1145/1327452.1327492", false), + ]; + + for (input, expected) in test_cases { + assert_eq!( + retriever.pattern.is_match(input), + expected, + "Pattern match failed for input: {}", + input + ); + } + + // Test identifier extraction + assert_eq!( + retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), + "10.1145/1327452.1327492" + ); + assert_eq!( + retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), + "10.1145/1327452.1327492" + ); + + // Verify response format + match &retriever.response_format { + ResponseFormat::Json(config) => { + // Verify field mappings + let field_maps = &config.field_maps; + assert!(field_maps.contains_key("title")); + assert!(field_maps.contains_key("abstract")); + assert!(field_maps.contains_key("authors")); + assert!(field_maps.contains_key("publication_date")); + assert!(field_maps.contains_key("pdf_url")); + assert!(field_maps.contains_key("doi")); + }, + _ => panic!("Expected JSON response format"), + } + } + + #[test] + fn test_iacr_config_deserialization() { + let config_str = include_str!("../../config/retrievers/iacr.toml"); + + let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + + // Verify basic fields + assert_eq!(retriever.name, "iacr"); + assert_eq!(retriever.base_url, "https://eprint.iacr.org"); + assert_eq!(retriever.source, "iacr"); + + // Test pattern matching + let test_cases = [ + ("2016/260", true), + ("2023/123", true), + ("https://eprint.iacr.org/2016/260", true), + ("https://eprint.iacr.org/2016/260.pdf", true), + ("invalid/format", false), + ("https://wrong.url/2016/260", false), + ]; + + for (input, expected) in test_cases { + assert_eq!( + retriever.pattern.is_match(input), + expected, + "Pattern match failed for input: {}", + input + ); + } + + // Test 
identifier extraction + assert_eq!(retriever.extract_identifier("2016/260").unwrap(), "2016/260"); + assert_eq!( + retriever.extract_identifier("https://eprint.iacr.org/2016/260").unwrap(), + "2016/260" + ); + assert_eq!( + retriever.extract_identifier("https://eprint.iacr.org/2016/260.pdf").unwrap(), + "2016/260" + ); + + // Verify response format + if let ResponseFormat::Xml(config) = &retriever.response_format { + assert!(config.strip_namespaces); + + // Verify field mappings + let field_maps = &config.field_maps; + assert!(field_maps.contains_key("title")); + assert!(field_maps.contains_key("abstract")); + assert!(field_maps.contains_key("authors")); + assert!(field_maps.contains_key("publication_date")); + assert!(field_maps.contains_key("pdf_url")); + + // Verify OAI-PMH paths + if let Some(map) = field_maps.get("title") { + assert!(map.path.contains(&"OAI-PMH/GetRecord/record/metadata/dc/title".to_string())); + } else { + panic!("Missing title field map"); + } + } else { + panic!("Expected an XML configuration, but did not get one.") + } + + // Verify headers + assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); + } +} diff --git a/crates/learner/tests/lib.rs b/crates/learner/tests/lib.rs index 3e73d55..0a88dd5 100644 --- a/crates/learner/tests/lib.rs +++ b/crates/learner/tests/lib.rs @@ -20,7 +20,6 @@ mod workflows; pub type TestResult = Result>; -// #[tokio::test] pub async fn create_test_learner() -> (Learner, TempDir, TempDir, TempDir) { let config_dir = tempdir().unwrap(); let database_dir = tempdir().unwrap(); diff --git a/crates/learner/tests/workflows/build_retriever.rs b/crates/learner/tests/workflows/build_retriever.rs deleted file mode 100644 index 71c61cb..0000000 --- a/crates/learner/tests/workflows/build_retriever.rs +++ /dev/null @@ -1,188 +0,0 @@ -use std::fs::read_to_string; - -use learner::retriever::{ResponseFormat, Transform}; - -use super::*; - -#[tokio::test] -async fn test_arxiv_config_deserialization() { - let 
config_str = - read_to_string("config/retrievers/arxiv.toml").expect("Failed to read config file"); - - // let (learner, _config_dir, _database_dir, _storage_dir) = create_test_learner().await; - - let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); - - // Verify basic fields - assert_eq!(retriever.name, "arxiv"); - assert_eq!(retriever.base_url, "http://export.arxiv.org"); - assert_eq!(retriever.source, "arxiv"); - - // Test pattern matching - assert!(retriever.pattern.is_match("2301.07041")); - assert!(retriever.pattern.is_match("math.AG/0601001")); - assert!(retriever.pattern.is_match("https://arxiv.org/abs/2301.07041")); - assert!(retriever.pattern.is_match("https://arxiv.org/pdf/2301.07041")); - assert!(retriever.pattern.is_match("https://arxiv.org/abs/math.AG/0601001")); - assert!(retriever.pattern.is_match("https://arxiv.org/abs/math/0404443")); - - // Test identifier extraction - assert_eq!(retriever.extract_identifier("2301.07041").unwrap(), "2301.07041"); - assert_eq!( - retriever.extract_identifier("https://arxiv.org/abs/2301.07041").unwrap(), - "2301.07041" - ); - assert_eq!(retriever.extract_identifier("math.AG/0601001").unwrap(), "math.AG/0601001"); - - // Verify response format - - if let ResponseFormat::Xml(config) = &retriever.response_format { - assert!(config.strip_namespaces); - - // Verify field mappings - let field_maps = &config.field_maps; - assert!(field_maps.contains_key("title")); - assert!(field_maps.contains_key("abstract")); - assert!(field_maps.contains_key("authors")); - assert!(field_maps.contains_key("publication_date")); - assert!(field_maps.contains_key("pdf_url")); - - // Verify PDF transform - if let Some(map) = field_maps.get("pdf_url") { - match &map.transform { - Some(Transform::Replace { pattern, replacement }) => { - assert_eq!(pattern, "/abs/"); - assert_eq!(replacement, "/pdf/"); - }, - _ => panic!("Expected Replace transform for pdf_url"), - } - } else { - panic!("Missing 
pdf_url field map"); - } - } else { - panic!("Expected an XML configuration, but did not get one.") - } - - // Verify headers - assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); -} - -#[test] -fn test_doi_config_deserialization() { - let config_str = - read_to_string("config/retrievers/doi.toml").expect("Failed to read config file"); - - let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); - - // Verify basic fields - assert_eq!(retriever.name, "doi"); - assert_eq!(retriever.base_url, "https://api.crossref.org/works"); - assert_eq!(retriever.source, "doi"); - - // Test pattern matching - let test_cases = [ - ("10.1145/1327452.1327492", true), - ("https://doi.org/10.1145/1327452.1327492", true), - ("invalid-doi", false), - ("https://wrong.url/10.1145/1327452.1327492", false), - ]; - - for (input, expected) in test_cases { - assert_eq!( - retriever.pattern.is_match(input), - expected, - "Pattern match failed for input: {}", - input - ); - } - - // Test identifier extraction - assert_eq!( - retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), - "10.1145/1327452.1327492" - ); - assert_eq!( - retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), - "10.1145/1327452.1327492" - ); - - // Verify response format - match &retriever.response_format { - ResponseFormat::Json(config) => { - // Verify field mappings - let field_maps = &config.field_maps; - assert!(field_maps.contains_key("title")); - assert!(field_maps.contains_key("abstract")); - assert!(field_maps.contains_key("authors")); - assert!(field_maps.contains_key("publication_date")); - assert!(field_maps.contains_key("pdf_url")); - assert!(field_maps.contains_key("doi")); - }, - _ => panic!("Expected JSON response format"), - } -} - -#[test] -fn test_iacr_config_deserialization() { - let config_str = - read_to_string("config/retrievers/iacr.toml").expect("Failed to read config file"); - - let retriever: 
RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); - - // Verify basic fields - assert_eq!(retriever.name, "iacr"); - assert_eq!(retriever.base_url, "https://eprint.iacr.org"); - assert_eq!(retriever.source, "iacr"); - - // Test pattern matching - let test_cases = [ - ("2016/260", true), - ("2023/123", true), - ("https://eprint.iacr.org/2016/260", true), - ("https://eprint.iacr.org/2016/260.pdf", true), - ("invalid/format", false), - ("https://wrong.url/2016/260", false), - ]; - - for (input, expected) in test_cases { - assert_eq!( - retriever.pattern.is_match(input), - expected, - "Pattern match failed for input: {}", - input - ); - } - - // Test identifier extraction - assert_eq!(retriever.extract_identifier("2016/260").unwrap(), "2016/260"); - assert_eq!(retriever.extract_identifier("https://eprint.iacr.org/2016/260").unwrap(), "2016/260"); - assert_eq!( - retriever.extract_identifier("https://eprint.iacr.org/2016/260.pdf").unwrap(), - "2016/260" - ); - - // Verify response format - if let ResponseFormat::Xml(config) = &retriever.response_format { - assert!(config.strip_namespaces); - - // Verify field mappings - let field_maps = &config.field_maps; - assert!(field_maps.contains_key("title")); - assert!(field_maps.contains_key("abstract")); - assert!(field_maps.contains_key("authors")); - assert!(field_maps.contains_key("publication_date")); - assert!(field_maps.contains_key("pdf_url")); - - // Verify OAI-PMH paths - if let Some(map) = field_maps.get("title") { - assert!(map.path.contains(&"OAI-PMH/GetRecord/record/metadata/dc/title".to_string())); - } else { - panic!("Missing title field map"); - } - } else { - panic!("Expected an XML configuration, but did not get one.") - } - - // Verify headers - assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); -} diff --git a/crates/learner/tests/workflows/mod.rs b/crates/learner/tests/workflows/mod.rs index 3adb178..cf9b3dd 100644 --- 
a/crates/learner/tests/workflows/mod.rs +++ b/crates/learner/tests/workflows/mod.rs @@ -2,6 +2,5 @@ use learner::retriever::RetrieverConfig; use super::*; -mod build_retriever; mod database_operations; mod paper_retrieval; From 9c0200901f91aaa64ed66e35038fcafa19de444f Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 09:48:56 -0700 Subject: [PATCH 13/73] cleanup --- .gitignore | 5 ++++- README.md | 6 +++--- justfile | 1 + 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 656f9a0..8f39a5b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,7 @@ *.fls *.log *.out -*.synctex.gz \ No newline at end of file +*.synctex.gz + +# SDK stuff +config/* \ No newline at end of file diff --git a/README.md b/README.md index fb83d15..6024d93 100644 --- a/README.md +++ b/README.md @@ -326,15 +326,15 @@ This repository now supplies a very basic SDK for validating a `Retriever` and a To work with this SDK, use: ``` # Setup -just setup-sdk +just setup-sdk # sets up a `config/` dir at repo root with defaults # Validations just validate-retriever # optionally supply url/identifer just validate-resource # Examples -just validate-retriever crates/learner/config/retrievers/arxiv.toml 2301.07041 -just validate-resource crates/learner/config/resources/thesis.toml +just validate-retriever config/retrievers/arxiv.toml 2301.07041 +just validate-resource config/resources/thesis.toml ``` diff --git a/justfile b/justfile index d7b6746..535493a 100644 --- a/justfile +++ b/justfile @@ -153,6 +153,7 @@ debug: # Setup SDK setup-sdk: cargo install --path crates/sdk --debug + cp -r crates/learner/config . 
# Validate a retriever config validate-retriever path input="": From 0137881bf2ed1e4083ab552872542d17d0824692 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 16:57:15 -0700 Subject: [PATCH 14/73] WIP: environment managing --- README.md | 4 +- crates/learner/config/retrievers/arxiv.toml | 2 +- crates/learner/config/retrievers/doi.toml | 2 +- crates/learner/config/retrievers/iacr.toml | 2 +- crates/learner/src/environment.rs | 62 +++++++++++++++++++++ crates/learner/src/lib.rs | 1 + crates/learner/src/retriever/config.rs | 30 +++++++++- crates/sdk/src/main.rs | 19 ++++++- justfile | 1 - 9 files changed, 113 insertions(+), 10 deletions(-) create mode 100644 crates/learner/src/environment.rs diff --git a/README.md b/README.md index 6024d93..31ddcb6 100644 --- a/README.md +++ b/README.md @@ -333,8 +333,8 @@ just validate-retriever # optionally supply url/identife just validate-resource # Examples -just validate-retriever config/retrievers/arxiv.toml 2301.07041 -just validate-resource config/resources/thesis.toml +just validate-retriever crates/learner/config/retrievers/arxiv.toml 2301.07041 +just validate-resource crates/learner/config/resources/thesis.toml ``` diff --git a/crates/learner/config/retrievers/arxiv.toml b/crates/learner/config/retrievers/arxiv.toml index 9d474e5..4c72a86 100644 --- a/crates/learner/config/retrievers/arxiv.toml +++ b/crates/learner/config/retrievers/arxiv.toml @@ -2,7 +2,7 @@ base_url = "http://export.arxiv.org" endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}&max_results=1" name = "arxiv" pattern = "(?:^|https?://arxiv\\.org/(?:abs|pdf)/)(\\d{4}\\.\\d{4,5}|[a-zA-Z-]+(?:\\.[A-Z]{2})?/\\d{7})(?:\\.pdf)?$" -resource = "config/resources/paper.toml" +resource = "paper" source = "arxiv" [response_format] diff --git a/crates/learner/config/retrievers/doi.toml b/crates/learner/config/retrievers/doi.toml index 4970498..adad43c 100644 --- a/crates/learner/config/retrievers/doi.toml +++ 
b/crates/learner/config/retrievers/doi.toml @@ -2,7 +2,7 @@ base_url = "https://api.crossref.org/works" endpoint_template = "https://api.crossref.org/works/{identifier}" name = "doi" pattern = "(?:^|https?://doi\\.org/)(10\\.\\d{4,9}/[-._;()/:\\w]+)$" -resource = "config/resources/paper.toml" +resource = "paper" source = "doi" [response_format] diff --git a/crates/learner/config/retrievers/iacr.toml b/crates/learner/config/retrievers/iacr.toml index 9dac909..9ceee72 100644 --- a/crates/learner/config/retrievers/iacr.toml +++ b/crates/learner/config/retrievers/iacr.toml @@ -2,7 +2,7 @@ base_url = "https://eprint.iacr.org" endpoint_template = "https://eprint.iacr.org/oai?verb=GetRecord&identifier=oai:eprint.iacr.org:{identifier}&metadataPrefix=oai_dc" name = "iacr" pattern = "(?:^|https?://eprint\\.iacr\\.org/)(\\d{4}/\\d+)(?:\\.pdf)?$" -resource = "config/resources/paper.toml" +resource = "paper" source = "iacr" [response_format] diff --git a/crates/learner/src/environment.rs b/crates/learner/src/environment.rs new file mode 100644 index 0000000..31f7cfc --- /dev/null +++ b/crates/learner/src/environment.rs @@ -0,0 +1,62 @@ +use std::sync::OnceLock; + +use super::*; + +// In environment.rs +#[derive(Debug, Clone)] +pub struct Environment { + config_dir: PathBuf, + resources_dir: PathBuf, + retrievers_dir: PathBuf, +} + +impl Environment { + pub fn global() -> &'static Environment { + static INSTANCE: OnceLock = OnceLock::new(); + INSTANCE.get_or_init(|| Environment { + config_dir: Config::default_path().unwrap_or_else(|_| PathBuf::from(".")), + resources_dir: Config::default_resources_path(), + retrievers_dir: Config::default_retrievers_path(), + }) + } + + pub fn set_global(config_dir: PathBuf) -> Result<()> { + static INSTANCE: OnceLock = OnceLock::new(); + + let env = Environment { + config_dir: config_dir.clone(), + resources_dir: config_dir.join("resources"), + retrievers_dir: config_dir.join("retrievers"), + }; + + INSTANCE + .set(env) + .map_err(|_| 
LearnerError::Config("Global environment already initialized".into())) + } + + // Add getters since we want to access these paths + pub fn config_dir() -> PathBuf { Self::global().config_dir.clone() } + + pub fn resources_dir() -> PathBuf { Self::global().resources_dir.clone() } + + pub fn retrievers_dir() -> PathBuf { Self::global().retrievers_dir.clone() } + + pub fn resolve_resource_path(resource: &str) -> PathBuf { + // Add .toml if needed + let resource_file = if !resource.ends_with(".toml") { + format!("{}.toml", resource) + } else { + resource.to_string() + }; + Self::global().resources_dir.join(resource_file) + } + + pub fn resolve_retriever_path(retriever: &str) -> PathBuf { + let retriever_file = if !retriever.ends_with(".toml") { + format!("{}.toml", retriever) + } else { + retriever.to_string() + }; + Self::global().retrievers_dir.join(retriever_file) + } +} diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index 6fa978e..a6e1ac3 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -168,6 +168,7 @@ pub mod database; pub mod retriever; pub mod configuration; +pub mod environment; pub mod error; pub mod format; pub mod llm; diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index cb8c7d9..6dbbf82 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,3 +1,5 @@ +use environment::Environment; + use super::*; /// Configuration for a specific paper source retriever. 
@@ -167,14 +169,36 @@ where D: serde::Deserializer<'de> { #[serde(untagged)] enum ResourceConfigRef { Inline(ResourceConfig), - Path(PathBuf), + Path(String), } let config_ref = ResourceConfigRef::deserialize(deserializer)?; match config_ref { ResourceConfigRef::Inline(config) => Ok(config), - ResourceConfigRef::Path(path) => { - let content = std::fs::read_to_string(&path).map_err(serde::de::Error::custom)?; + ResourceConfigRef::Path(resource_name) => { + // Add .toml extension if not present + let resource_file = if !resource_name.ends_with(".toml") { + format!("{}.toml", resource_name) + } else { + resource_name + }; + + // First try using Environment to resolve path + let env_path = Environment::resolve_resource_path(&resource_file); + if env_path.exists() { + let content = std::fs::read_to_string(&env_path).map_err(serde::de::Error::custom)?; + return toml::from_str(&content).map_err(serde::de::Error::custom); + } + + // Fallback for tests + let fallback_path = PathBuf::from("config/resources").join(&resource_file); + let content = std::fs::read_to_string(&fallback_path).map_err(|_| { + serde::de::Error::custom(format!( + "Resource not found at either {} or {}", + env_path.display(), + fallback_path.display() + )) + })?; toml::from_str(&content).map_err(serde::de::Error::custom) }, } diff --git a/crates/sdk/src/main.rs b/crates/sdk/src/main.rs index dcd70fa..af61b5c 100644 --- a/crates/sdk/src/main.rs +++ b/crates/sdk/src/main.rs @@ -3,7 +3,7 @@ mod validate; use std::path::PathBuf; use clap::{Parser, Subcommand}; -use learner::prelude::*; +use learner::{environment::Environment, prelude::*}; use tracing::{debug, error, info, warn}; #[derive(Parser)] @@ -42,6 +42,23 @@ async fn main() { let cli = LearnerSdk::parse(); + // Get the path from the command + let path = match &cli.command { + Commands::ValidateRetriever { path, .. 
} | Commands::ValidateResource { path } => path, + }; + + // // Set up environment from the config directory in the path + // if let Some(config_dir) = path.parent().and_then(|p| p.parent()) { + // debug!("Setting config directory to: {}", config_dir.display()); + // if let Err(e) = Environment::set_global(config_dir.to_path_buf()) { + // error!("Failed to set global environment: {}", e); + // return; + // } + // } else { + // error!("Could not determine config directory from path: {}", path.display()); + // return; + // } + match &cli.command { Commands::ValidateRetriever { path, input } => { info!("Validating retriever..."); diff --git a/justfile b/justfile index 535493a..d7b6746 100644 --- a/justfile +++ b/justfile @@ -153,7 +153,6 @@ debug: # Setup SDK setup-sdk: cargo install --path crates/sdk --debug - cp -r crates/learner/config . # Validate a retriever config validate-retriever path input="": From f94de5e2b68fecadc8144e1cea699b3fedbc7a39 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 17:23:13 -0700 Subject: [PATCH 15/73] fix: environment + sdk --- crates/learner/src/environment.rs | 104 +++++++++++++++++-------- crates/learner/src/lib.rs | 1 + crates/learner/src/retriever/config.rs | 24 +++--- crates/sdk/src/main.rs | 96 +++++++++++++++++------ crates/sdk/src/validate.rs | 2 +- 5 files changed, 159 insertions(+), 68 deletions(-) diff --git a/crates/learner/src/environment.rs b/crates/learner/src/environment.rs index 31f7cfc..953a7fa 100644 --- a/crates/learner/src/environment.rs +++ b/crates/learner/src/environment.rs @@ -2,7 +2,9 @@ use std::sync::OnceLock; use super::*; -// In environment.rs +// Global singleton instance +static INSTANCE: OnceLock = OnceLock::new(); + #[derive(Debug, Clone)] pub struct Environment { config_dir: PathBuf, @@ -10,53 +12,89 @@ pub struct Environment { retrievers_dir: PathBuf, } +/// Builder for constructing Environment instances with custom paths. 
+/// This allows flexible configuration while maintaining the standard structure. +#[derive(Default)] +pub struct EnvironmentBuilder { + // Base configuration directory is required + config_dir: Option, + // Optional custom paths for subdirectories + resources_dir: Option, + retrievers_dir: Option, +} + impl Environment { + /// Starts building a new Environment instance. + /// This is the entry point for custom environment configuration. + pub fn builder() -> EnvironmentBuilder { EnvironmentBuilder::default() } + + /// Creates a new Environment directly from paths. + /// Used internally after validation by the builder. + fn new( + config_dir: PathBuf, + resources_dir: Option, + retrievers_dir: Option, + ) -> Self { + Self { + // Use provided subdirectory paths or default to standard locations + resources_dir: resources_dir.unwrap_or_else(|| config_dir.join("resources")), + retrievers_dir: retrievers_dir.unwrap_or_else(|| config_dir.join("retrievers")), + config_dir, + } + } + pub fn global() -> &'static Environment { - static INSTANCE: OnceLock = OnceLock::new(); - INSTANCE.get_or_init(|| Environment { - config_dir: Config::default_path().unwrap_or_else(|_| PathBuf::from(".")), - resources_dir: Config::default_resources_path(), - retrievers_dir: Config::default_retrievers_path(), + INSTANCE.get_or_init(|| { + Self::new(Config::default_path().unwrap_or_else(|_| PathBuf::from(".")), None, None) }) } - pub fn set_global(config_dir: PathBuf) -> Result<()> { - static INSTANCE: OnceLock = OnceLock::new(); - - let env = Environment { - config_dir: config_dir.clone(), - resources_dir: config_dir.join("resources"), - retrievers_dir: config_dir.join("retrievers"), - }; - + pub fn set_global(env: Environment) -> Result<()> { INSTANCE .set(env) .map_err(|_| LearnerError::Config("Global environment already initialized".into())) } - // Add getters since we want to access these paths + pub fn resolve_resource_path(name: &str) -> PathBuf { + let filename = + if 
!name.ends_with(".toml") { format!("{}.toml", name) } else { name.to_string() }; + Self::global().resources_dir.join(filename) + } + + pub fn resolve_retriever_path(name: &str) -> PathBuf { + let filename = + if !name.ends_with(".toml") { format!("{}.toml", name) } else { name.to_string() }; + Self::global().retrievers_dir.join(filename) + } + pub fn config_dir() -> PathBuf { Self::global().config_dir.clone() } pub fn resources_dir() -> PathBuf { Self::global().resources_dir.clone() } pub fn retrievers_dir() -> PathBuf { Self::global().retrievers_dir.clone() } +} + +impl EnvironmentBuilder { + pub fn config_dir(mut self, path: impl Into) -> Self { + self.config_dir = Some(path.into()); + self + } + + pub fn resources_dir(mut self, path: impl Into) -> Self { + self.resources_dir = Some(path.into()); + self + } + + pub fn retrievers_dir(mut self, path: impl Into) -> Self { + self.retrievers_dir = Some(path.into()); + self + } + + pub fn build(self) -> Result { + let config_dir = self + .config_dir + .ok_or_else(|| LearnerError::Config("Configuration directory must be specified".into()))?; - pub fn resolve_resource_path(resource: &str) -> PathBuf { - // Add .toml if needed - let resource_file = if !resource.ends_with(".toml") { - format!("{}.toml", resource) - } else { - resource.to_string() - }; - Self::global().resources_dir.join(resource_file) - } - - pub fn resolve_retriever_path(retriever: &str) -> PathBuf { - let retriever_file = if !retriever.ends_with(".toml") { - format!("{}.toml", retriever) - } else { - retriever.to_string() - }; - Self::global().retrievers_dir.join(retriever_file) + Ok(Environment::new(config_dir, self.resources_dir, self.retrievers_dir)) } } diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index a6e1ac3..a980eb2 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -293,6 +293,7 @@ pub struct Config { /// # Ok(()) /// # } /// ``` +// TODO: Add an `Environment` in here. 
#[derive(Debug, Clone)] pub struct Learner { /// Active configuration diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 6dbbf82..312627f 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -176,22 +176,23 @@ where D: serde::Deserializer<'de> { match config_ref { ResourceConfigRef::Inline(config) => Ok(config), ResourceConfigRef::Path(resource_name) => { - // Add .toml extension if not present - let resource_file = if !resource_name.ends_with(".toml") { - format!("{}.toml", resource_name) - } else { - resource_name - }; - - // First try using Environment to resolve path - let env_path = Environment::resolve_resource_path(&resource_file); + // Try loading from the global environment path + let env_path = Environment::resolve_resource_path(&resource_name); + if env_path.exists() { let content = std::fs::read_to_string(&env_path).map_err(serde::de::Error::custom)?; return toml::from_str(&content).map_err(serde::de::Error::custom); } - // Fallback for tests - let fallback_path = PathBuf::from("config/resources").join(&resource_file); + // If global path doesn't exist, try the local fallback + // This is mainly useful for development and testing + let fallback_path = + PathBuf::from("config/resources").join(if resource_name.ends_with(".toml") { + resource_name.to_string() + } else { + format!("{}.toml", resource_name) + }); + let content = std::fs::read_to_string(&fallback_path).map_err(|_| { serde::de::Error::custom(format!( "Resource not found at either {} or {}", @@ -199,6 +200,7 @@ where D: serde::Deserializer<'de> { fallback_path.display() )) })?; + toml::from_str(&content).map_err(serde::de::Error::custom) }, } diff --git a/crates/sdk/src/main.rs b/crates/sdk/src/main.rs index af61b5c..21a0c29 100644 --- a/crates/sdk/src/main.rs +++ b/crates/sdk/src/main.rs @@ -1,6 +1,6 @@ mod validate; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use clap::{Parser, 
Subcommand}; use learner::{environment::Environment, prelude::*}; @@ -30,8 +30,40 @@ enum Commands { }, } +/// Attempts to find the root config directory by walking up the path. +/// Returns the config directory and its relation to the input path. +fn find_config_dir(path: &Path) -> Option<(PathBuf, String)> { + // Convert path to absolute for clearer error messages + let abs_path = path.canonicalize().ok()?; + let mut current = abs_path.as_path(); + + // Walk up the directory tree + while let Some(parent) = current.parent() { + // Check if this is a config directory by looking for expected structure + if parent.ends_with("config") + && parent.join("resources").is_dir() + && parent.join("retrievers").is_dir() + { + // Calculate the relationship to the original path + let relation = if abs_path.starts_with(parent) { + format!( + "Found config directory {} levels up from input path", + abs_path.strip_prefix(parent).ok()?.components().count() - 1 + ) + } else { + "Found config directory".to_string() + }; + + return Some((parent.to_path_buf(), relation)); + } + current = parent; + } + None +} + #[tokio::main] async fn main() { + // Set up logging with a clean format focused on user feedback tracing_subscriber::fmt() .without_time() .with_file(false) @@ -47,35 +79,53 @@ async fn main() { Commands::ValidateRetriever { path, .. 
} | Commands::ValidateResource { path } => path, }; - // // Set up environment from the config directory in the path - // if let Some(config_dir) = path.parent().and_then(|p| p.parent()) { - // debug!("Setting config directory to: {}", config_dir.display()); - // if let Err(e) = Environment::set_global(config_dir.to_path_buf()) { - // error!("Failed to set global environment: {}", e); - // return; - // } - // } else { - // error!("Could not determine config directory from path: {}", path.display()); - // return; - // } + // First check if the input path exists + if !path.exists() { + error!("Input path does not exist: {}", path.display()); + error!("Please provide a valid path to a configuration file"); + return; + } + // Try to find the config directory + let (config_dir, message) = match find_config_dir(path) { + Some((dir, msg)) => (dir, msg), + None => { + error!("Could not find a valid configuration directory!"); + error!("Looking for a directory named 'config' containing:"); + error!(" - resources/ directory"); + error!(" - retrievers/ directory"); + error!("Input path was: {}", path.display()); + error!("Tip: Make sure you're running this command from a location where"); + error!(" the config directory structure is accessible"); + return; + }, + }; + + // Initialize the environment + info!("{}", message); + debug!("Using config directory: {}", config_dir.display()); + + if let Err(e) = + Environment::builder().config_dir(&config_dir).build().and_then(Environment::set_global) + { + error!("Failed to initialize environment: {}", e); + error!("This might indicate a problem with the config directory structure"); + return; + } + + // Proceed with validation based on command match &cli.command { Commands::ValidateRetriever { path, input } => { - info!("Validating retriever..."); - if !path.exists() { - error!("Path to retriever config was invalid.\nPath used: {path:?}"); - return; + info!("Validating retriever configuration..."); + debug!("Config file: {}", 
path.display()); + if let Some(input) = input { + debug!("Testing with input: {}", input); } - debug!("Validating retriever config at {:?}", path); validate::validate_retriever(path, input).await; }, Commands::ValidateResource { path } => { - info!("Validating resource..."); - if !path.exists() { - error!("Path to resource config was invalid.\nPath used: {path:?}"); - return; - } - debug!("Validating resource config at {:?}", path); + info!("Validating resource configuration..."); + debug!("Config file: {}", path.display()); validate::validate_resource(path); }, } diff --git a/crates/sdk/src/validate.rs b/crates/sdk/src/validate.rs index 7ebc8f4..12899a4 100644 --- a/crates/sdk/src/validate.rs +++ b/crates/sdk/src/validate.rs @@ -216,7 +216,7 @@ pub async fn validate_retriever(path: &PathBuf, input: &Option) { Err(e) => error!("Pattern matching failed: {}", e), } } else { - info!("No test input provided - skipping retrieval tests"); + warn!("No test input provided - skipping retrieval tests"); info!("To test retrieval, provide an identifier like: 2301.07041"); } } From b0d1d61ec5818e8d994c9c13e3a44649e5818362 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 17:36:53 -0700 Subject: [PATCH 16/73] improve sdk --- Cargo.lock | 1 + crates/sdk/Cargo.toml | 1 + crates/sdk/src/main.rs | 24 ++++- crates/sdk/src/validate.rs | 205 +++++++++++++++++++++++++------------ justfile | 92 +++++++++++++++-- 5 files changed, 250 insertions(+), 73 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1f588f1..a92d180 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1196,6 +1196,7 @@ name = "learner-sdk" version = "0.1.0" dependencies = [ "clap", + "console", "learner", "regex", "reqwest", diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 398c03c..8463bca 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -11,6 +11,7 @@ version = "0.1.0" [dependencies] clap = { workspace = true } +console = { workspace = true } learner = { workspace = true 
} regex = { workspace = true } reqwest = { workspace = true } diff --git a/crates/sdk/src/main.rs b/crates/sdk/src/main.rs index 21a0c29..67d59a3 100644 --- a/crates/sdk/src/main.rs +++ b/crates/sdk/src/main.rs @@ -2,7 +2,7 @@ mod validate; use std::path::{Path, PathBuf}; -use clap::{Parser, Subcommand}; +use clap::{ArgAction, Parser, Subcommand}; use learner::{environment::Environment, prelude::*}; use tracing::{debug, error, info, warn}; @@ -11,6 +11,16 @@ use tracing::{debug, error, info, warn}; struct LearnerSdk { #[command(subcommand)] command: Commands, + + /// Verbose mode (-v, -vv, -vvv) for different levels of logging detail + #[arg( + short, + long, + action = ArgAction::Count, + global = true, + help = "Increase logging verbosity" + )] + verbose: u8, } #[derive(Subcommand)] @@ -63,8 +73,16 @@ fn find_config_dir(path: &Path) -> Option<(PathBuf, String)> { #[tokio::main] async fn main() { - // Set up logging with a clean format focused on user feedback + let cli = LearnerSdk::parse(); + let filter = match cli.verbose { + 0 => "error", + 1 => "warn", + 2 => "info", + 3 => "debug", + _ => "trace", + }; tracing_subscriber::fmt() + .with_env_filter(filter) .without_time() .with_file(false) .with_line_number(false) @@ -72,8 +90,6 @@ async fn main() { .with_max_level(tracing::Level::TRACE) .init(); - let cli = LearnerSdk::parse(); - // Get the path from the command let path = match &cli.command { Commands::ValidateRetriever { path, .. 
} | Commands::ValidateResource { path } => path, diff --git a/crates/sdk/src/validate.rs b/crates/sdk/src/validate.rs index 12899a4..f444092 100644 --- a/crates/sdk/src/validate.rs +++ b/crates/sdk/src/validate.rs @@ -1,5 +1,6 @@ use std::fs::read_to_string; +use console::{style, Term}; // For better formatted output use learner::{ resource::ResourceConfig, retriever::{ResponseFormat, RetrieverConfig}, @@ -7,8 +8,30 @@ use learner::{ use super::*; +/// Formats a validation header with consistent styling +fn print_validation_header(message: &str) { + println!("\n{}", style(message).bold().cyan()); +} + +/// Formats a validation section header +fn print_section_header(message: &str) { + println!("\n{}", style(message).bold()); +} + +/// Prints a success message with a checkmark +fn print_success(message: &str) { + println!("{} {}", style("✓").bold().green(), message); +} + +/// Formats field information consistently +fn print_field_info(name: &str, field_type: &str, indent: usize) { + let indent_str = " ".repeat(indent); + println!("{}Field '{}' ({})", indent_str, style(name).bold(), style(field_type).italic()); +} + pub fn validate_resource(path: &PathBuf) { - info!("Validating resource configuration from: {}", path.display()); + print_validation_header("Resource Configuration Validation"); + info!("Loading configuration from: {}", path.display()); // Read and parse the configuration let config_str = match read_to_string(path) { @@ -28,28 +51,36 @@ pub fn validate_resource(path: &PathBuf) { error!("- Missing or malformed fields"); error!("- Incorrect data types"); error!("- TOML syntax errors"); + error!("\nDetailed error: {}", e); return; }, }; - info!("Found resource type: {}", resource.name); + // Resource overview + print_section_header("Resource Overview"); + println!("Name: {}", style(&resource.name).bold()); if let Some(desc) = &resource.description { - info!("Description: {}", desc); + println!("Description: {}", desc); } - // Validate field definitions - 
info!("Validating {} field definitions...", resource.fields.len()); + // Field validation + print_section_header(&format!("Field Definitions ({})", style(resource.fields.len()).bold())); + for field in &resource.fields { - // Check field type validity match field.field_type.as_str() { "string" | "integer" | "float" | "boolean" | "datetime" | "array" | "table" => { - info!("Field '{}' ({}):", field.name, field.field_type); + print_field_info(&field.name, &field.field_type, 0); + + // Show field metadata if let Some(desc) = &field.description { - info!(" Description: {}", desc); + println!(" Description: {}", desc); + } + + if field.required { + println!(" {}", style("Required: true").yellow()); } - info!(" Required: {}", field.required); - // Validate default values match declared type + // Validate and show default values if let Some(default) = &field.default { match (field.field_type.as_str(), default) { ("string", toml::Value::String(_)) @@ -59,61 +90,74 @@ pub fn validate_resource(path: &PathBuf) { | ("datetime", toml::Value::Datetime(_)) | ("array", toml::Value::Array(_)) | ("table", toml::Value::Table(_)) => { - info!(" Default value: valid"); + print_success("Default value has correct type"); }, _ => { - error!(" Default value type doesn't match field type!"); - error!(" Expected {}, got {}", field.field_type, default.type_str()); + println!(" {}: Default value type mismatch", style("ERROR").red()); + println!( + " Expected {}, got {}", + style(&field.field_type).bold(), + style(default.type_str()).red() + ); }, } } - // Validate validation rules + // Show validation rules with better formatting if let Some(rules) = &field.validation { - info!(" Validation rules:"); + println!(" Validation Rules:"); match field.field_type.as_str() { "string" => { if let Some(pattern) = &rules.pattern { match regex::Regex::new(pattern) { - Ok(_) => info!(" - Valid regex pattern"), - Err(e) => error!(" - Invalid regex pattern: {}", e), + Ok(_) => + println!(" {} Pattern: {}", 
style("✓").green(), style(pattern).italic()), + Err(e) => println!(" {} Invalid regex pattern: {}", style("✗").red(), e), } } if let Some(min) = rules.min_length { - info!(" - Minimum length: {}", min); + println!(" Minimum length: {}", min); } if let Some(max) = rules.max_length { - info!(" - Maximum length: {}", max); + println!(" Maximum length: {}", max); } }, "array" => { if let Some(min) = rules.min_items { - info!(" - Minimum items: {}", min); + println!(" Minimum items: {}", min); } if let Some(max) = rules.max_items { - info!(" - Maximum items: {}", max); + println!(" Maximum items: {}", max); } if rules.unique_items == Some(true) { - info!(" - Items must be unique"); + println!(" {}", style("Items must be unique").yellow()); } }, _ => {}, } } + println!(); // Add spacing between fields }, invalid_type => { - error!("Field '{}' has invalid type: {}", field.name, invalid_type); - error!("Valid types are: string, integer, float, boolean, datetime, array, table"); + println!("\n{}", style("ERROR").red().bold()); + println!( + "Field '{}' has invalid type: {}", + style(&field.name).bold(), + style(invalid_type).red() + ); + println!("Valid types are: string, integer, float, boolean, datetime, array, table"); }, } } - info!("Resource configuration validation complete!"); + print_success("Resource configuration validation complete!"); } pub async fn validate_retriever(path: &PathBuf, input: &Option) { - info!("Validating retriever configuration from: {}", path.display()); + print_validation_header("Retriever Configuration Validation"); + info!("Loading configuration from: {}", path.display()); + // Read and parse configuration let config_str = match read_to_string(path) { Ok(str) => str, Err(e) => { @@ -131,92 +175,127 @@ pub async fn validate_retriever(path: &PathBuf, input: &Option) { error!("- Missing required fields"); error!("- Incorrect field types"); error!("- Malformed URLs or patterns"); + error!("\nDetailed error: {}", e); return; }, }; - // Validate 
basic configuration - info!("Validating retriever '{}'", retriever.name); + // Basic configuration overview + print_section_header("Basic Configuration"); + println!("Name: {}", style(&retriever.name).bold()); + println!("Source: {}", retriever.source); - // Check URL validity - if let Err(e) = reqwest::Url::parse(&retriever.base_url) { - error!("Invalid base URL: {}", e); - return; + // URL validation + print_section_header("URL Configuration"); + match reqwest::Url::parse(&retriever.base_url) { + Ok(url) => { + print_success(&format!("Valid base URL: {}", style(url).green())); + }, + Err(e) => { + println!("{} Invalid base URL: {}", style("✗").red(), e); + return; + }, } - // Validate endpoint template - if !retriever.endpoint_template.contains("{identifier}") { - error!("Endpoint template must contain {{identifier}} placeholder"); + // Endpoint template validation + if retriever.endpoint_template.contains("{identifier}") { + print_success("Endpoint template contains required {identifier} placeholder"); + } else { + println!( + "{} Endpoint template must contain {{identifier}} placeholder", + style("✗").red().bold() + ); + println!("Current template: {}", style(&retriever.endpoint_template).italic()); return; } - // Check response format configuration + // Response format validation + print_section_header("Response Format Configuration"); match &retriever.response_format { ResponseFormat::Xml(config) => { - info!("XML configuration:"); - info!("- Namespace stripping: {}", config.strip_namespaces); - info!("Validating field mappings:"); + println!("Format: {}", style("XML").cyan()); + println!( + "Namespace handling: {}", + if config.strip_namespaces { + style("Stripping enabled").green() + } else { + style("Preserving namespaces").yellow() + } + ); + + println!("\nField Mappings:"); for (field, map) in &config.field_maps { - info!("- {}: {}", field, map.path); + println!("- {}: {}", style(field).bold(), map.path); if let Some(transform) = &map.transform { - 
info!(" with transformation: {:?}", transform); + println!(" Transform: {}", style(format!("{:?}", transform)).italic()); } } }, ResponseFormat::Json(config) => { - info!("JSON configuration:"); - info!("Validating field mappings:"); + println!("Format: {}", style("JSON").cyan()); + println!("\nField Mappings:"); for (field, map) in &config.field_maps { - info!("- {}: {}", field, map.path); + println!("- {}: {}", style(field).bold(), map.path); if let Some(transform) = &map.transform { - info!(" with transformation: {:?}", transform); + println!(" Transform: {}", style(format!("{:?}", transform)).italic()); } } }, } - // Test pattern matching if input provided + // Live testing if input provided if let Some(input) = input { - info!("Testing identifier pattern matching..."); + print_section_header("Live Testing"); + println!("Testing with input: {}", style(input).cyan()); + match retriever.extract_identifier(input) { Ok(identifier) => { - info!("✓ Successfully extracted identifier: {}", identifier); + print_success(&format!("Extracted identifier: {}", style(identifier).green())); - // Try fetching - info!("Testing retrieval..."); + // Paper retrieval test + println!("\nAttempting paper retrieval..."); match retriever.retrieve_paper(input).await { Ok(paper) => { - info!("✓ Successfully retrieved paper:"); - info!("Title: {}", paper.title); - info!("Authors: {:?}", paper.authors); + print_success("Paper retrieved successfully"); + println!("\nPaper Details:"); + println!("Title: {}", style(&paper.title).bold()); + println!( + "Authors: {}", + paper.authors.iter().map(|a| a.name.clone()).collect::>().join(", ") + ); - // Test PDF download if available + // PDF download test if let Some(url) = &paper.pdf_url { - info!("Testing PDF download from: {}", url); + println!("\nTesting PDF download capability..."); let tempdir = tempfile::tempdir().unwrap(); match paper.download_pdf(tempdir.path()).await { Ok(filename) => { let pdf_path = tempdir.path().join(filename); if 
pdf_path.exists() { let metadata = std::fs::metadata(&pdf_path).unwrap(); - info!("✓ PDF downloaded successfully ({} bytes)", metadata.len()); + print_success(&format!( + "PDF downloaded successfully ({} bytes)", + style(metadata.len()).green() + )); } else { - error!("PDF download failed - file not created"); + println!("{} PDF download failed - file not created", style("✗").red()); } }, - Err(e) => error!("PDF download failed: {}", e), + Err(e) => println!("{} PDF download failed: {}", style("✗").red(), e), } } else { - warn!("No PDF URL available"); + println!("{} No PDF URL available for testing", style("!").yellow()); } }, - Err(e) => error!("Retrieval failed: {}", e), + Err(e) => println!("{} Retrieval failed: {}", style("✗").red(), e), } }, - Err(e) => error!("Pattern matching failed: {}", e), + Err(e) => println!("{} Pattern matching failed: {}", style("✗").red(), e), } } else { - warn!("No test input provided - skipping retrieval tests"); - info!("To test retrieval, provide an identifier like: 2301.07041"); + println!("\n{} No test input provided - skipping retrieval tests", style("!").yellow()); + println!("Tip: Provide an identifier (like '2301.07041') to test live retrieval"); } + + print_success("Retriever configuration validation complete!"); } diff --git a/justfile b/justfile index d7b6746..3f62247 100644 --- a/justfile +++ b/justfile @@ -150,19 +150,99 @@ debug: @just header "Installing learnerd in debug mode" cargo install --path crates/learnerd --features tui --debug +# Display usage information for resource validation +help-resource: + @echo "Validate a resource configuration file" + @echo + @echo "Usage:" + @echo " just validate-resource path/to/resource.toml" + @echo + @echo "Examples:" + @echo " just validate-resource crates/learner/config/resources/paper.toml" + @echo " just validate-resource crates/learner/config/resources/book.toml" + @echo " just validate-resource crates/learner/config/resources/thesis.toml" + +# Display usage information 
for retriever validation +help-retriever: + @echo "Validate a retriever configuration file and optionally test with an identifier" + @echo + @echo "Usage:" + @echo " just validate-retriever path/to/retriever.toml [identifier]" + @echo + @echo "Examples:" + @echo " just validate-retriever crates/learner/config/retrievers/arxiv.toml" + @echo " just validate-retriever crates/learner/config/retrievers/arxiv.toml 2301.07041" + @echo " just validate-retriever crates/learner/config/retrievers/doi.toml \"10.1145/1327452.1327492\"" + @echo " just validate-retriever crates/learner/config/retrievers/iacr.toml \"2023/123\"" + # Setup SDK setup-sdk: + @echo "Installing learner-sdk..." cargo install --path crates/sdk --debug + @echo "Run 'just help-resource' or 'just help-retriever' for usage information" -# Validate a retriever config -validate-retriever path input="": - @just header "Validating retriever config" - learner-sdk validate-retriever {{path}} {{input}} +# Check if a path argument is provided for resource validation +[private] +_check-resource-args args: + #!/usr/bin/env bash + if [ -z "{{args}}" ]; then + printf "{{error}}Error: Missing required path argument{{reset}}\n" + echo + just help-resource + exit 1 + fi + +# Check if required arguments are provided for retriever validation +[private] +_check-retriever-args args: + #!/usr/bin/env bash + if [ -z "{{args}}" ]; then + printf "{{error}}Error: Missing required path argument{{reset}}\n" + echo + just help-retriever + exit 1 + fi # Validate a resource config -validate-resource path: +validate-resource +args="": (_check-resource-args args) + #!/usr/bin/env bash + if [ ! 
-f "{{args}}" ]; then + printf "{{error}}Error: File not found: {{args}}{{reset}}\n" + echo + just help-resource + exit 1 + fi @just header "Validating resource config" - learner-sdk validate-resource {{path}} + learner-sdk validate-resource {{args}} + +# Validate a retriever config +validate-retriever +args="": (_check-retriever-args args) + #!/usr/bin/env bash + # Split args into path and input + read -r path input <<< "{{args}}" + if [ ! -f "$path" ]; then + printf "{{error}}Error: File not found: $path{{reset}}\n" + echo + just help-retriever + exit 1 + fi + @just header "Validating retriever config" + learner-sdk validate-retriever "$path" $input + +# Run all example validations (useful for CI) +validate-examples: + #!/usr/bin/env bash + echo "==> Validating all example configurations" + echo + echo "Resource Configurations:" + just validate-resource crates/learner/config/resources/paper.toml + just validate-resource crates/learner/config/resources/book.toml + just validate-resource crates/learner/config/resources/thesis.toml + echo + echo "Retriever Configurations:" + just validate-retriever crates/learner/config/retrievers/arxiv.toml "2301.07041" + just validate-retriever crates/learner/config/retrievers/doi.toml "10.1145/1327452.1327492" + just validate-retriever crates/learner/config/retrievers/iacr.toml "2023/123" # Show your relevant environment information info: From bc42b5ae360c2ef0125434da3d13953106a8e1d9 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 17:37:54 -0700 Subject: [PATCH 17/73] Update README.md --- README.md | 86 ++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 31ddcb6..cc94cf3 100644 --- a/README.md +++ b/README.md @@ -321,22 +321,88 @@ just build-all # build all targets > [!TIP] > Running `just setup` and `just ci` locally is a quick way to get up to speed and see that the repo is working on your system! 
-## SDK -This repository now supplies a very basic SDK for validating a `Retriever` and a `Resource` TOML configurations. -To work with this SDK, use: +## Learner SDK + +The Learner SDK provides command-line tools for validating and testing your resource and retriever configurations. + +### Installation + +```bash +# From the repository root +just setup-sdk ``` -# Setup -just setup-sdk # sets up a `config/` dir at repo root with defaults -# Validations -just validate-retriever # optionally supply url/identifer -just validate-resource +### Validating Resource Configurations + +Resource configurations define the structure and validation rules for different types of academic materials (papers, books, theses, etc.). + +```bash +# View usage information +just help-resource -# Examples -just validate-retriever crates/learner/config/retrievers/arxiv.toml 2301.07041 +# Validate example configurations +just validate-resource crates/learner/config/resources/paper.toml +just validate-resource crates/learner/config/resources/book.toml just validate-resource crates/learner/config/resources/thesis.toml ``` +The validator will check: +- TOML syntax and structure +- Field type correctness +- Default value compatibility +- Validation rule syntax (regex patterns, etc.) +- Required field presence + +### Validating Retriever Configurations + +Retriever configurations define how to fetch papers from different sources (arXiv, DOI, IACR, etc.). 
+ +```bash +# View usage information +just help-retriever + +# Validate configuration syntax only +just validate-retriever crates/learner/config/retrievers/arxiv.toml + +# Validate and test live retrieval +just validate-retriever crates/learner/config/retrievers/arxiv.toml "2301.07041" +just validate-retriever crates/learner/config/retrievers/doi.toml "10.1145/1327452.1327492" +just validate-retriever crates/learner/config/retrievers/iacr.toml "2023/123" +``` + +The validator will check: +- TOML syntax and structure +- URL validity +- Response format configuration +- Field mappings +- Pattern matching +- Live paper retrieval (when identifier provided) +- PDF download capability + +### Running All Examples + +To run all example validations (useful for testing your setup or in CI): + +```bash +just validate-examples +``` + +This will validate all provided example configurations and test live retrieval for each supported source. + +## Creating Your Own Configurations + +Use the example configurations in `crates/learner/config/` as templates: + +- Resources + - `resources/paper.toml` - Academic papers + - `resources/book.toml` - Books and monographs + - `resources/thesis.toml` - Theses and dissertations + +- Retrievers + - `retrievers/arxiv.toml` - arXiv papers + - `retrievers/doi.toml` - DOI-based papers + - `retrievers/iacr.toml` - IACR papers + ## License From 9f4c822babb31c377214dbe82ee3dfb4dc06b952 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 19:32:07 -0700 Subject: [PATCH 18/73] simplify stuff --- crates/learner/src/environment.rs | 100 ------------------------- crates/learner/src/lib.rs | 12 +-- crates/learner/src/retriever/config.rs | 54 +------------ crates/sdk/src/main.rs | 11 +-- crates/sdk/src/validate.rs | 4 +- 5 files changed, 11 insertions(+), 170 deletions(-) delete mode 100644 crates/learner/src/environment.rs diff --git a/crates/learner/src/environment.rs b/crates/learner/src/environment.rs deleted file mode 100644 index 
953a7fa..0000000 --- a/crates/learner/src/environment.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::sync::OnceLock; - -use super::*; - -// Global singleton instance -static INSTANCE: OnceLock = OnceLock::new(); - -#[derive(Debug, Clone)] -pub struct Environment { - config_dir: PathBuf, - resources_dir: PathBuf, - retrievers_dir: PathBuf, -} - -/// Builder for constructing Environment instances with custom paths. -/// This allows flexible configuration while maintaining the standard structure. -#[derive(Default)] -pub struct EnvironmentBuilder { - // Base configuration directory is required - config_dir: Option, - // Optional custom paths for subdirectories - resources_dir: Option, - retrievers_dir: Option, -} - -impl Environment { - /// Starts building a new Environment instance. - /// This is the entry point for custom environment configuration. - pub fn builder() -> EnvironmentBuilder { EnvironmentBuilder::default() } - - /// Creates a new Environment directly from paths. - /// Used internally after validation by the builder. 
- fn new( - config_dir: PathBuf, - resources_dir: Option, - retrievers_dir: Option, - ) -> Self { - Self { - // Use provided subdirectory paths or default to standard locations - resources_dir: resources_dir.unwrap_or_else(|| config_dir.join("resources")), - retrievers_dir: retrievers_dir.unwrap_or_else(|| config_dir.join("retrievers")), - config_dir, - } - } - - pub fn global() -> &'static Environment { - INSTANCE.get_or_init(|| { - Self::new(Config::default_path().unwrap_or_else(|_| PathBuf::from(".")), None, None) - }) - } - - pub fn set_global(env: Environment) -> Result<()> { - INSTANCE - .set(env) - .map_err(|_| LearnerError::Config("Global environment already initialized".into())) - } - - pub fn resolve_resource_path(name: &str) -> PathBuf { - let filename = - if !name.ends_with(".toml") { format!("{}.toml", name) } else { name.to_string() }; - Self::global().resources_dir.join(filename) - } - - pub fn resolve_retriever_path(name: &str) -> PathBuf { - let filename = - if !name.ends_with(".toml") { format!("{}.toml", name) } else { name.to_string() }; - Self::global().retrievers_dir.join(filename) - } - - pub fn config_dir() -> PathBuf { Self::global().config_dir.clone() } - - pub fn resources_dir() -> PathBuf { Self::global().resources_dir.clone() } - - pub fn retrievers_dir() -> PathBuf { Self::global().retrievers_dir.clone() } -} - -impl EnvironmentBuilder { - pub fn config_dir(mut self, path: impl Into) -> Self { - self.config_dir = Some(path.into()); - self - } - - pub fn resources_dir(mut self, path: impl Into) -> Self { - self.resources_dir = Some(path.into()); - self - } - - pub fn retrievers_dir(mut self, path: impl Into) -> Self { - self.retrievers_dir = Some(path.into()); - self - } - - pub fn build(self) -> Result { - let config_dir = self - .config_dir - .ok_or_else(|| LearnerError::Config("Configuration directory must be specified".into()))?; - - Ok(Environment::new(config_dir, self.resources_dir, self.retrievers_dir)) - } -} diff --git 
a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index a980eb2..9f26eca 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -85,11 +85,6 @@ //! - Document storage management //! - Command pattern operations //! -//! - [`clients`]: API clients for paper sources -//! - Source-specific implementations -//! - Response parsing and validation -//! - Error handling and retry logic -//! //! - [`retriever`]: Configurable paper retrieval system //! - Automatic source detection //! - XML and JSON response handling @@ -168,7 +163,7 @@ pub mod database; pub mod retriever; pub mod configuration; -pub mod environment; + pub mod error; pub mod format; pub mod llm; @@ -575,7 +570,7 @@ impl LearnerBuilder { pub async fn build(self) -> Result { let config = if let Some(config) = self.config { config - } else if let Some(path) = self.config_path { + } else if let Some(path) = &self.config_path { let config_file = path.join("config.toml"); let content = std::fs::read_to_string(config_file)?; toml::from_str(&content).map_err(|e| LearnerError::Config(e.to_string()))? 
@@ -593,7 +588,6 @@ impl LearnerBuilder { let database = Database::open(&config.database_path).await?; database.set_storage_path(&config.storage_path).await?; - let retriever = Retrievers::new().with_config_dir(&config.retrievers_path)?; let resources = Resources::new().with_config_dir(&config.resources_path)?; @@ -755,6 +749,7 @@ impl Learner { mod tests { use super::*; + #[traced_test] #[tokio::test] async fn test_learner_creation() { let config_dir = tempdir().unwrap(); @@ -765,6 +760,7 @@ mod tests { .with_resources_path(&config_dir.path().join("config/resources/")) .with_retrievers_path(&config_dir.path().join("config/retrievers/")) .with_storage_path(storage_dir.path()); + let learner = Learner::builder().with_path(config_dir.path()).with_config(config).build().await.unwrap(); diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 312627f..f7ff1ae 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,5 +1,3 @@ -use environment::Environment; - use super::*; /// Configuration for a specific paper source retriever. 
@@ -32,9 +30,9 @@ use super::*; pub struct RetrieverConfig { /// Name of this retriever configuration pub name: String, + // TODO (autoparallel): Ultimately this will have to peer into the `Resources` to be useful /// The type of resource this retriever should yield - #[serde(deserialize_with = "load_resource_config")] - pub resource: ResourceConfig, + pub resource: String, /// Base URL for API requests pub base_url: String, /// Regex pattern for matching and extracting paper identifiers @@ -132,6 +130,7 @@ impl RetrieverConfig { Ok(paper) } + #[allow(missing_docs)] pub async fn retrieve_resource(&self, input: &str) -> Result { let identifier = self.extract_identifier(input)?; @@ -152,56 +151,11 @@ impl RetrieverConfig { trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); // Process the response into a generic Value first - let response_processor = match &self.response_format { + let _response_processor = match &self.response_format { ResponseFormat::Xml(config) => config as &dyn ResponseProcessor, ResponseFormat::Json(config) => config as &dyn ResponseProcessor, }; todo!(); - - // Ok(resource) - } -} - -fn load_resource_config<'de, D>(deserializer: D) -> std::result::Result -where D: serde::Deserializer<'de> { - #[derive(Deserialize)] - #[serde(untagged)] - enum ResourceConfigRef { - Inline(ResourceConfig), - Path(String), - } - - let config_ref = ResourceConfigRef::deserialize(deserializer)?; - match config_ref { - ResourceConfigRef::Inline(config) => Ok(config), - ResourceConfigRef::Path(resource_name) => { - // Try loading from the global environment path - let env_path = Environment::resolve_resource_path(&resource_name); - - if env_path.exists() { - let content = std::fs::read_to_string(&env_path).map_err(serde::de::Error::custom)?; - return toml::from_str(&content).map_err(serde::de::Error::custom); - } - - // If global path doesn't exist, try the local fallback - // This is mainly useful for development and testing - let fallback_path = - 
PathBuf::from("config/resources").join(if resource_name.ends_with(".toml") { - resource_name.to_string() - } else { - format!("{}.toml", resource_name) - }); - - let content = std::fs::read_to_string(&fallback_path).map_err(|_| { - serde::de::Error::custom(format!( - "Resource not found at either {} or {}", - env_path.display(), - fallback_path.display() - )) - })?; - - toml::from_str(&content).map_err(serde::de::Error::custom) - }, } } diff --git a/crates/sdk/src/main.rs b/crates/sdk/src/main.rs index 67d59a3..704a45a 100644 --- a/crates/sdk/src/main.rs +++ b/crates/sdk/src/main.rs @@ -3,8 +3,7 @@ mod validate; use std::path::{Path, PathBuf}; use clap::{ArgAction, Parser, Subcommand}; -use learner::{environment::Environment, prelude::*}; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, info}; #[derive(Parser)] #[command(author, version, about, long_about = None)] @@ -121,14 +120,6 @@ async fn main() { info!("{}", message); debug!("Using config directory: {}", config_dir.display()); - if let Err(e) = - Environment::builder().config_dir(&config_dir).build().and_then(Environment::set_global) - { - error!("Failed to initialize environment: {}", e); - error!("This might indicate a problem with the config directory structure"); - return; - } - // Proceed with validation based on command match &cli.command { Commands::ValidateRetriever { path, input } => { diff --git a/crates/sdk/src/validate.rs b/crates/sdk/src/validate.rs index f444092..5a39892 100644 --- a/crates/sdk/src/validate.rs +++ b/crates/sdk/src/validate.rs @@ -1,6 +1,6 @@ use std::fs::read_to_string; -use console::{style, Term}; // For better formatted output +use console::style; use learner::{ resource::ResourceConfig, retriever::{ResponseFormat, RetrieverConfig}, @@ -265,7 +265,7 @@ pub async fn validate_retriever(path: &PathBuf, input: &Option) { ); // PDF download test - if let Some(url) = &paper.pdf_url { + if paper.pdf_url.is_some() { println!("\nTesting PDF download 
capability..."); let tempdir = tempfile::tempdir().unwrap(); match paper.download_pdf(tempdir.path()).await { From f0e88c271e7dd1833efc76943339e79ac28242e9 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 30 Nov 2024 19:50:37 -0700 Subject: [PATCH 19/73] WIP: simplifying things --- crates/learner/src/lib.rs | 38 ++++++++++++++++--- crates/learner/src/resource/mod.rs | 7 ++++ crates/learner/src/retriever/config.rs | 10 ++++- crates/learner/tests/llm/mod.rs | 2 +- .../workflows/database_operations/add.rs | 10 ++--- .../workflows/database_operations/remove.rs | 6 +-- crates/learnerd/src/commands/add.rs | 6 +-- 7 files changed, 60 insertions(+), 19 deletions(-) diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index 9f26eca..915d318 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -153,7 +153,7 @@ use chrono::{DateTime, Utc}; use lazy_static::lazy_static; use regex::Regex; use reqwest::Url; -use resource::{ResourceConfig, Resources}; +use resource::{Resource, ResourceConfig, Resources}; use serde::{Deserialize, Serialize}; use tracing::{debug, trace, warn}; #[cfg(test)] @@ -292,13 +292,13 @@ pub struct Config { #[derive(Debug, Clone)] pub struct Learner { /// Active configuration - pub config: Config, + pub config: Config, /// Database connection and operations - pub database: Database, + pub database: Database, /// Paper retrieval system - pub retriever: Retrievers, + pub retrievers: Retrievers, /// Resources to use - pub resources: Resources, + pub resources: Resources, } /// Builder for creating configured Learner instances. 
@@ -591,7 +591,7 @@ impl LearnerBuilder { let retriever = Retrievers::new().with_config_dir(&config.retrievers_path)?; let resources = Resources::new().with_config_dir(&config.resources_path)?; - Ok(Learner { config, database, retriever, resources }) + Ok(Learner { config, database, retrievers: retriever, resources }) } } @@ -743,6 +743,32 @@ impl Learner { /// # } /// ``` pub async fn init() -> Result { Self::with_config(Config::init()?).await } + + pub async fn retreive(&mut self, input: &str) -> Result { + let mut matches = Vec::new(); + + // Find all configs that match the input + for (name, config) in self.retrievers.as_map().iter() { + if config.pattern.is_match(input) { + matches.push((name, config)); + } + } + + match matches.len() { + 0 => Err(LearnerError::InvalidIdentifier), + 1 => { + let resource_config = self.resources.as_map().get(matches[0].0); + if let Some(resource_config) = resource_config { + Ok(matches[0].1.retrieve_resource(input, resource_config).await?) + } else { + todo!("Error because that resource wasn't available.") + } + }, + _ => Err(LearnerError::AmbiguousIdentifier( + matches.into_iter().map(|(n, c)| n.to_string()).collect(), + )), + } + } } #[cfg(test)] diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index 2b266fd..f73345a 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -9,6 +9,13 @@ pub use paper::*; pub use shared::*; use toml::Value; +// TODO (autoparallel): We almost need something like `Resource` to be given by these +// `ResourceConfig`s. Or, even renaming these like `ResourceTemplates` or something so a `Resource` +// has to fit into the `ResourceTemplate` (now that I type this out, `ResourceConfig` is still a +// reasonable name). But when we want to retrieve a resource, we need to actually get back a +// resource. 
Perhaps its just: +pub type Resource = BTreeMap; + #[derive(Debug, Clone, Default)] pub struct Resources { configs: BTreeMap, diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index f7ff1ae..13831c6 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,3 +1,5 @@ +use resource::Resource; + use super::*; /// Configuration for a specific paper source retriever. @@ -130,8 +132,14 @@ impl RetrieverConfig { Ok(paper) } + // TODO: perhaps this just isn't even implemented here and is instead implemented on `Learner`. + // Could consider an `api.rs` module to extend more learner functionality there. #[allow(missing_docs)] - pub async fn retrieve_resource(&self, input: &str) -> Result { + pub async fn retrieve_resource( + &self, + input: &str, + resource_config: &ResourceConfig, + ) -> Result { let identifier = self.extract_identifier(input)?; // Send request and get response diff --git a/crates/learner/tests/llm/mod.rs b/crates/learner/tests/llm/mod.rs index d2e9d47..aa4658b 100644 --- a/crates/learner/tests/llm/mod.rs +++ b/crates/learner/tests/llm/mod.rs @@ -8,7 +8,7 @@ use super::*; async fn test_download_then_send_pdf() -> Result<(), Box> { // Download a PDF let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = learner.retriever.get_paper("https://eprint.iacr.org/2016/260").await?; + let paper = learner.retrievers.get_paper("https://eprint.iacr.org/2016/260").await?; // let paper = Paper::new("https://eprint.iacr.org/2016/260").await.unwrap(); // paper.download_pdf(dir.path()).await.unwrap(); diff --git a/crates/learner/tests/workflows/database_operations/add.rs b/crates/learner/tests/workflows/database_operations/add.rs index e7afa68..5fb307d 100644 --- a/crates/learner/tests/workflows/database_operations/add.rs +++ b/crates/learner/tests/workflows/database_operations/add.rs @@ -79,7 +79,7 @@ mod document_operations { #[tokio::test] async 
fn test_add_complete_paper() -> TestResult<()> { let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = learner.retriever.get_paper("https://arxiv.org/abs/2301.07041").await?; + let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; let papers = Add::complete(&paper).execute(&mut learner.database).await?; assert_eq!(papers.len(), 1); @@ -102,7 +102,7 @@ mod document_operations { #[tokio::test] async fn test_add_paper_then_document() -> TestResult<()> { let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = learner.retriever.get_paper("https://arxiv.org/abs/2301.07041").await?; + let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; // First add paper only Add::paper(&paper).execute(&mut learner.database).await?; @@ -127,7 +127,7 @@ mod document_operations { #[tokio::test] async fn test_chain_document_addition() -> TestResult<()> { let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = learner.retriever.get_paper("https://arxiv.org/abs/2301.07041").await?; + let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; let papers = Add::paper(&paper).with_document().execute(&mut learner.database).await?; assert_eq!(papers.len(), 1); @@ -146,8 +146,8 @@ mod document_operations { let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; // Add multiple papers without documents - let paper1 = learner.retriever.get_paper("https://arxiv.org/abs/2301.07041").await?; - let paper2 = learner.retriever.get_paper("https://eprint.iacr.org/2016/260").await?; + let paper1 = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; + let paper2 = learner.retrievers.get_paper("https://eprint.iacr.org/2016/260").await?; Add::paper(&paper1).execute(&mut learner.database).await?; Add::paper(&paper2).execute(&mut learner.database).await?; diff --git 
a/crates/learner/tests/workflows/database_operations/remove.rs b/crates/learner/tests/workflows/database_operations/remove.rs index 00ab6b6..1762ec2 100644 --- a/crates/learner/tests/workflows/database_operations/remove.rs +++ b/crates/learner/tests/workflows/database_operations/remove.rs @@ -60,7 +60,7 @@ mod basic_operations { async fn test_remove_complete_paper() -> TestResult<()> { let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = learner.retriever.get_paper("https://arxiv.org/abs/2301.07041").await?; + let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; // Add paper with document Add::complete(&paper).execute(&mut learner.database).await?; @@ -145,7 +145,7 @@ mod dry_run { async fn test_dry_run_with_complete_paper() -> TestResult<()> { let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = learner.retriever.get_paper("https://arxiv.org/abs/2301.07041").await?; + let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; Add::complete(&paper).execute(&mut learner.database).await?; let would_remove = Remove::by_source(&paper.source, &paper.source_identifier) @@ -358,7 +358,7 @@ mod recovery { let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; // Add paper without document - let paper = learner.retriever.get_paper("https://arxiv.org/abs/2301.07041").await?; + let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; Add::paper(&paper).execute(&mut learner.database).await?; // Remove it diff --git a/crates/learnerd/src/commands/add.rs b/crates/learnerd/src/commands/add.rs index 90ef7c2..069ad5a 100644 --- a/crates/learnerd/src/commands/add.rs +++ b/crates/learnerd/src/commands/add.rs @@ -24,21 +24,21 @@ pub struct AddArgs { pub async fn add(interaction: &mut I, add_args: AddArgs) -> Result { let AddArgs { identifier, pdf, no_pdf } = add_args; - if 
interaction.learner().retriever.is_empty() { + if interaction.learner().retrievers.is_empty() { return Err(LearnerdError::Learner(LearnerError::Config( "No retriever configured.".to_string(), ))); } let (source, sanitized_identifier) = - interaction.learner().retriever.sanitize_identifier(&identifier)?; + interaction.learner().retrievers.sanitize_identifier(&identifier)?; let papers = Query::by_source(&source, &sanitized_identifier) .execute(&mut interaction.learner().database) .await?; if papers.is_empty() { interaction.reply(ResponseContent::Info(&format!("Fetching paper: {}", identifier)))?; - let paper = interaction.learner().retriever.get_paper(&identifier).await?; + let paper = interaction.learner().retrievers.get_paper(&identifier).await?; interaction.reply(ResponseContent::Paper(&paper))?; let with_pdf = paper.pdf_url.is_some() From afe3385c43ac134b05f14e53830300e28d5c087f Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 07:17:06 -0700 Subject: [PATCH 20/73] WIP: working through generic retriever stuff --- crates/learner/config/retrievers/doi.toml | 2 +- crates/learner/src/lib.rs | 29 +- crates/learner/src/resource/mod.rs | 92 +++--- crates/learner/src/retriever/config.rs | 72 +---- crates/learner/src/retriever/mod.rs | 122 +++---- crates/learner/src/retriever/response/json.rs | 300 +++++++----------- crates/learner/src/retriever/response/mod.rs | 9 +- crates/learner/src/retriever/response/xml.rs | 154 ++++----- .../tests/workflows/paper_retrieval.rs | 147 +++++---- 9 files changed, 414 insertions(+), 513 deletions(-) diff --git a/crates/learner/config/retrievers/doi.toml b/crates/learner/config/retrievers/doi.toml index adad43c..d035e11 100644 --- a/crates/learner/config/retrievers/doi.toml +++ b/crates/learner/config/retrievers/doi.toml @@ -9,7 +9,7 @@ source = "doi" type = "json" [response_format.field_maps.title] -path = "message/title" +path = "message/title/0" [response_format.field_maps.abstract] path = "message/abstract" diff --git 
a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index 915d318..b9f1003 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -754,20 +754,21 @@ impl Learner { } } - match matches.len() { - 0 => Err(LearnerError::InvalidIdentifier), - 1 => { - let resource_config = self.resources.as_map().get(matches[0].0); - if let Some(resource_config) = resource_config { - Ok(matches[0].1.retrieve_resource(input, resource_config).await?) - } else { - todo!("Error because that resource wasn't available.") - } - }, - _ => Err(LearnerError::AmbiguousIdentifier( - matches.into_iter().map(|(n, c)| n.to_string()).collect(), - )), - } + todo!("Finish this") + // match matches.len() { + // 0 => Err(LearnerError::InvalidIdentifier), + // 1 => { + // let resource_config = self.resources.as_map().get(matches[0].0); + // if let Some(resource_config) = resource_config { + // Ok(matches[0].1.retrieve_resource(input, resource_config).await?) + // } else { + // todo!("Error because that resource wasn't available.") + // } + // }, + // _ => Err(LearnerError::AmbiguousIdentifier( + // matches.into_iter().map(|(n, c)| n.to_string()).collect(), + // )), + // } } } diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index f73345a..13834d4 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -90,7 +90,7 @@ pub struct ValidationRules { impl ResourceConfig { /// Validates a set of values against this resource configuration - pub fn validate(&self, values: &toml::value::Table) -> Result { + pub fn validate(&self, values: &Resource) -> Result { // Check required fields for field in &self.fields { if field.required { @@ -299,28 +299,24 @@ mod tests { let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - // Create a valid paper - let paper_values = toml::value::Table::from_iter([ - ("title".into(), toml::Value::String("Understanding Quantum Computing".into())), - ( 
- "authors".into(), - toml::Value::Array(vec![toml::Value::Table(toml::value::Table::from_iter([ - ("name".into(), toml::Value::String("Alice Researcher".into())), - ("affiliation".into(), toml::Value::String("Tech University".into())), - ]))]), - ), - ("publication_date".into(), toml::Value::Datetime(date)), - ("doi".into(), toml::Value::String("10.1234/example.123".into())), - ]); + // Create a valid paper resource + let mut paper_resource = BTreeMap::new(); + paper_resource.insert("title".into(), Value::String("Understanding Quantum Computing".into())); - // Validate the paper - assert!(config.validate(&paper_values).unwrap()); + // Create the author table using TOML's Map type + let author = { + let mut map = toml::map::Map::new(); + map.insert("name".into(), Value::String("Alice Researcher".into())); + map.insert("affiliation".into(), Value::String("Tech University".into())); + map + }; + + paper_resource.insert("authors".into(), Value::Array(vec![Value::Table(author)])); + paper_resource.insert("publication_date".into(), Value::Datetime(date)); + paper_resource.insert("doi".into(), Value::String("10.1234/example.123".into())); - // Test required field validation - let invalid_paper = toml::value::Table::from_iter([ - ("authors".into(), toml::Value::Array(vec![])), // Missing title - ]); - assert!(config.validate(&invalid_paper).is_err()); + // Validate the paper + assert!(config.validate(&paper_resource).unwrap()); } #[test] @@ -330,21 +326,18 @@ mod tests { let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - let book_values = toml::value::Table::from_iter([ - ("title".into(), toml::Value::String("Advanced Quantum Computing".into())), - ( - "authors".into(), - toml::Value::Array(vec![ - toml::Value::String("Alice Writer".into()), - toml::Value::String("Bob Author".into()), - ]), - ), - ("isbn".into(), toml::Value::String("978-0-12-345678-9".into())), - ("publisher".into(), toml::Value::String("Academic 
Press".into())), - ("publication_date".into(), toml::Value::Datetime(date)), - ]); - - assert!(config.validate(&book_values).unwrap()); + // Create a valid book resource + let mut book_resource = BTreeMap::new(); + book_resource.insert("title".into(), Value::String("Advanced Quantum Computing".into())); + book_resource.insert( + "authors".into(), + Value::Array(vec![Value::String("Alice Writer".into()), Value::String("Bob Author".into())]), + ); + book_resource.insert("isbn".into(), Value::String("978-0-12-345678-9".into())); + book_resource.insert("publisher".into(), Value::String("Academic Press".into())); + book_resource.insert("publication_date".into(), Value::Datetime(date)); + + assert!(config.validate(&book_resource).unwrap()); } #[test] @@ -354,23 +347,22 @@ mod tests { let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - let thesis_values = toml::value::Table::from_iter([ - ("title".into(), toml::Value::String("Novel Approaches to Quantum Error Correction".into())), - ("author".into(), toml::Value::String("Alice Researcher".into())), - ("degree".into(), toml::Value::String("PhD".into())), - ("institution".into(), toml::Value::String("Tech University".into())), - ("completion_date".into(), toml::Value::Datetime(date)), - ( - "advisors".into(), - toml::Value::Array(vec![toml::Value::String("Prof. 
Bob Supervisor".into())]), - ), - ]); + // Create a valid thesis resource + let mut thesis_resource = BTreeMap::new(); + thesis_resource + .insert("title".into(), Value::String("Novel Approaches to Quantum Error Correction".into())); + thesis_resource.insert("author".into(), Value::String("Alice Researcher".into())); + thesis_resource.insert("degree".into(), Value::String("PhD".into())); + thesis_resource.insert("institution".into(), Value::String("Tech University".into())); + thesis_resource.insert("completion_date".into(), Value::Datetime(date)); + thesis_resource + .insert("advisors".into(), Value::Array(vec![Value::String("Prof. Bob Supervisor".into())])); - assert!(config.validate(&thesis_values).unwrap()); + assert!(config.validate(&thesis_resource).unwrap()); // Test degree enum validation - let mut invalid_thesis = thesis_values.clone(); - invalid_thesis.insert("degree".into(), toml::Value::String("InvalidDegree".into())); + let mut invalid_thesis = thesis_resource.clone(); + invalid_thesis.insert("degree".into(), Value::String("InvalidDegree".into())); assert!(config.validate(&invalid_thesis).is_err()); } } diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 13831c6..af3588a 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,4 +1,5 @@ use resource::Resource; +use toml::Value; use super::*; @@ -79,66 +80,13 @@ impl RetrieverConfig { .ok_or(LearnerError::InvalidIdentifier) } - /// Retrieves a paper using this configuration. - /// - /// This method: - /// 1. Extracts the canonical identifier - /// 2. Constructs the API URL - /// 3. Makes the HTTP request - /// 4. 
Processes the response - /// - /// # Arguments - /// - /// * `input` - Paper identifier or URL - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - The retrieved Paper object - /// - A LearnerError if any step fails - /// - /// # Errors - /// - /// This method will return an error if: - /// - The identifier cannot be extracted - /// - The HTTP request fails - /// - The response cannot be parsed - pub async fn retrieve_paper(&self, input: &str) -> Result { - let identifier = self.extract_identifier(input)?; - let url = self.endpoint_template.replace("{identifier}", identifier); - - debug!("Fetching from {} via: {}", self.name, url); - - let client = reqwest::Client::new(); - let mut request = client.get(&url); - - // Add any configured headers - for (key, value) in &self.headers { - request = request.header(key, value); - } - - let response = request.send().await?; - let data = response.bytes().await?; - - trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); - - let response_processor = match &self.response_format { - ResponseFormat::Xml(config) => config as &dyn ResponseProcessor, - ResponseFormat::Json(config) => config as &dyn ResponseProcessor, - }; - let mut paper = response_processor.process_response(&data).await?; - paper.source = self.source.clone(); - paper.source_identifier = identifier.to_string(); - Ok(paper) - } - // TODO: perhaps this just isn't even implemented here and is instead implemented on `Learner`. // Could consider an `api.rs` module to extend more learner functionality there. 
#[allow(missing_docs)] pub async fn retrieve_resource( &self, input: &str, - resource_config: &ResourceConfig, + resource_config: ResourceConfig, ) -> Result { let identifier = self.extract_identifier(input)?; @@ -158,12 +106,22 @@ impl RetrieverConfig { let data = response.bytes().await?; trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); - // Process the response into a generic Value first - let _response_processor = match &self.response_format { + // Process the response using configured processor + let processor = match &self.response_format { ResponseFormat::Xml(config) => config as &dyn ResponseProcessor, ResponseFormat::Json(config) => config as &dyn ResponseProcessor, }; - todo!(); + // Process response and get resource + let mut resource = processor.process_response(&data, &resource_config)?; + + // Add source metadata + resource.insert("source".into(), Value::String(self.source.clone())); + resource.insert("source_identifier".into(), Value::String(identifier.to_string())); + + // Validate full resource against config + resource_config.validate(&resource)?; + + Ok(resource) } } diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 62900b6..0e82443 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -194,13 +194,14 @@ impl Retrievers { } } - match matches.len() { - 0 => Err(LearnerError::InvalidIdentifier), - 1 => matches[0].retrieve_paper(input).await, - _ => Err(LearnerError::AmbiguousIdentifier( - matches.into_iter().map(|c| c.name.clone()).collect(), - )), - } + todo!("Fix this") + // match matches.len() { + // 0 => Err(LearnerError::InvalidIdentifier), + // 1 => matches[0].retrieve_paper(input).await, + // _ => Err(LearnerError::AmbiguousIdentifier( + // matches.into_iter().map(|c| c.name.clone()).collect(), + // )), + // } } /// Sanitizes and normalizes a paper identifier using configured retrieval patterns. 
@@ -392,59 +393,60 @@ mod tests { assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); } - #[test] - fn test_doi_config_deserialization() { - let config_str = include_str!("../../config/retrievers/doi.toml"); - - let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); - - // Verify basic fields - assert_eq!(retriever.name, "doi"); - assert_eq!(retriever.base_url, "https://api.crossref.org/works"); - assert_eq!(retriever.source, "doi"); - - // Test pattern matching - let test_cases = [ - ("10.1145/1327452.1327492", true), - ("https://doi.org/10.1145/1327452.1327492", true), - ("invalid-doi", false), - ("https://wrong.url/10.1145/1327452.1327492", false), - ]; - - for (input, expected) in test_cases { - assert_eq!( - retriever.pattern.is_match(input), - expected, - "Pattern match failed for input: {}", - input - ); - } - - // Test identifier extraction - assert_eq!( - retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), - "10.1145/1327452.1327492" - ); - assert_eq!( - retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), - "10.1145/1327452.1327492" - ); - - // Verify response format - match &retriever.response_format { - ResponseFormat::Json(config) => { - // Verify field mappings - let field_maps = &config.field_maps; - assert!(field_maps.contains_key("title")); - assert!(field_maps.contains_key("abstract")); - assert!(field_maps.contains_key("authors")); - assert!(field_maps.contains_key("publication_date")); - assert!(field_maps.contains_key("pdf_url")); - assert!(field_maps.contains_key("doi")); - }, - _ => panic!("Expected JSON response format"), - } - } + // TODO: Fix this + // #[test] + // fn test_doi_config_deserialization() { + // let config_str = include_str!("../../config/retrievers/doi.toml"); + + // let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + + // // Verify basic fields + // assert_eq!(retriever.name, 
"doi"); + // assert_eq!(retriever.base_url, "https://api.crossref.org/works"); + // assert_eq!(retriever.source, "doi"); + + // // Test pattern matching + // let test_cases = [ + // ("10.1145/1327452.1327492", true), + // ("https://doi.org/10.1145/1327452.1327492", true), + // ("invalid-doi", false), + // ("https://wrong.url/10.1145/1327452.1327492", false), + // ]; + + // for (input, expected) in test_cases { + // assert_eq!( + // retriever.pattern.is_match(input), + // expected, + // "Pattern match failed for input: {}", + // input + // ); + // } + + // // Test identifier extraction + // assert_eq!( + // retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), + // "10.1145/1327452.1327492" + // ); + // assert_eq!( + // retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), + // "10.1145/1327452.1327492" + // ); + + // // Verify response format + // match &retriever.response_format { + // ResponseFormat::Json(config) => { + // // Verify field mappings + // let field_maps = &config.field_maps; + // assert!(field_maps.contains_key("title")); + // assert!(field_maps.contains_key("abstract")); + // assert!(field_maps.contains_key("authors")); + // assert!(field_maps.contains_key("publication_date")); + // assert!(field_maps.contains_key("pdf_url")); + // assert!(field_maps.contains_key("doi")); + // }, + // _ => panic!("Expected JSON response format"), + // } + // } #[test] fn test_iacr_config_deserialization() { diff --git a/crates/learner/src/retriever/response/json.rs b/crates/learner/src/retriever/response/json.rs index bfd558f..a600b81 100644 --- a/crates/learner/src/retriever/response/json.rs +++ b/crates/learner/src/retriever/response/json.rs @@ -1,7 +1,7 @@ //! JSON response parser implementation. //! -//! This module handles parsing of JSON API responses into Paper objects using -//! configurable field mappings. It supports flexible path-based field extraction +//! 
This module handles parsing of JSON API responses into resources using +//! configurable field mappings. It supports path-based field extraction //! with optional transformations. //! //! # Example Configuration @@ -12,218 +12,164 @@ //! //! [response_format.field_maps] //! title = { path = "message/title/0" } -//! abstract = { path = "message/abstract" } -//! publication_date = { path = "message/published-print/date-parts/0" } -//! authors = { path = "message/author" } +//! summary = { path = "message/abstract" } +//! created_at = { path = "message/published/date-time" } +//! contributors = { path = "message/contributors" } //! ``` -use serde_json::Value; +use resource::chrono_to_toml_datetime; +use serde_json; +use toml::{self, Value as TomlValue}; use super::*; /// Configuration for processing JSON API responses. /// -/// Provides field mapping rules to extract paper metadata from JSON responses +/// Provides field mapping rules to extract resource fields from JSON responses /// using path-based access patterns. -/// -/// # Examples -/// -/// ```no_run -/// # use std::collections::HashMap; -/// # use learner::retriever::{json::JsonConfig, FieldMap}; -/// let config = JsonConfig { -/// field_maps: HashMap::from([("title".to_string(), FieldMap { -/// path: "message/title/0".to_string(), -/// transform: None, -/// })]), -/// }; -/// ``` #[derive(Debug, Clone, Deserialize)] pub struct JsonConfig { - /// JSON path mappings for paper metadata fields pub field_maps: HashMap, } -#[async_trait] +// TODO: Refactor this impl ResponseProcessor for JsonConfig { - /// Processes a JSON API response into a Paper object. - /// - /// Extracts paper metadata from the JSON response using configured field mappings. - /// Required fields (title, abstract, publication date, authors) must be present - /// and valid. 
- /// - /// # Arguments - /// - /// * `data` - Raw JSON response bytes - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - A populated Paper object - /// - A LearnerError if parsing fails or required fields are missing - /// - /// # Errors - /// - /// This method will return an error if: - /// - JSON parsing fails - /// - Required fields are missing - /// - Field values are invalid or cannot be transformed - async fn process_response(&self, data: &[u8]) -> Result { - let json: Value = serde_json::from_slice(data) + fn process_response(&self, data: &[u8], resource_config: &ResourceConfig) -> Result { + // Parse raw JSON data + let json: serde_json::Value = serde_json::from_slice(data) .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; trace!("Processing JSON response: {}", serde_json::to_string_pretty(&json).unwrap()); - let title = self.extract_field(&json, "title")?; - let abstract_text = self.extract_field(&json, "abstract")?; - let publication_date = - chrono::DateTime::parse_from_rfc3339(&self.extract_field(&json, "publication_date")?) - .map(|dt| dt.with_timezone(&Utc)) - .map_err(|e| LearnerError::ApiError(format!("Invalid date format: {}", e)))?; - - let authors = if let Some(map) = self.field_maps.get("authors") { - self.extract_authors(&json, map)? 
- } else { - return Err(LearnerError::ApiError("Missing authors mapping".to_string())); - }; - - let pdf_url = self.field_maps.get("pdf_url").and_then(|map| { - self.get_by_path(&json, &map.path).map(|url| { - if let Some(transform) = &map.transform { - apply_transform(&url, transform).ok().unwrap_or_else(|| url.clone()) - } else { - url.clone() + let mut resource = BTreeMap::new(); + + // Process each field according to resource configuration + for field_def in &resource_config.fields { + if let Some(field_map) = self.field_maps.get(&field_def.name) { + // Extract raw value if present + if let Some(value) = self.extract_value(&json, field_map, &field_def.field_type)? { + resource.insert(field_def.name.clone(), value); + } else if field_def.required { + return Err(LearnerError::ApiError(format!( + "Required field '{}' not found in response", + field_def.name + ))); + } else if let Some(default) = &field_def.default { + resource.insert(field_def.name.clone(), default.clone()); } - }) - }); - - let doi = self - .field_maps - .get("doi") - .and_then(|map| self.get_by_path(&json, &map.path)) - .map(String::from); + } + } - Ok(Paper { - title, - authors, - abstract_text, - publication_date, - source: String::new(), - source_identifier: String::new(), - pdf_url, - doi, - }) + Ok(resource) } } impl JsonConfig { - /// Extracts a single field value using configured mapping. 
- /// - /// # Errors - /// - /// Returns error if: - /// - Field mapping is missing - /// - Field value cannot be found - /// - Value transformation fails - fn extract_field(&self, json: &Value, field: &str) -> Result { - let map = self - .field_maps - .get(field) - .ok_or_else(|| LearnerError::ApiError(format!("Missing field mapping for {}", field)))?; + /// Recursively converts a JSON value into a TOML value + fn json_to_toml_value(&self, value: &serde_json::Value) -> Option { + match value { + // For JSON objects, recursively convert all their fields + serde_json::Value::Object(obj) => { + let mut map = toml::map::Map::new(); + for (key, val) in obj { + if let Some(converted) = self.json_to_toml_value(val) { + map.insert(key.clone(), converted); + } + } + Some(TomlValue::Table(map)) + }, + + // For arrays, recursively convert all elements + serde_json::Value::Array(arr) => { + let values: Vec<_> = arr.iter().filter_map(|item| self.json_to_toml_value(item)).collect(); + Some(TomlValue::Array(values)) + }, + + // Direct conversions for primitive types + serde_json::Value::String(s) => Some(TomlValue::String(s.clone())), + serde_json::Value::Number(n) => + if n.is_i64() { + n.as_i64().map(TomlValue::Integer) + } else { + n.as_f64().map(TomlValue::Float) + }, + serde_json::Value::Bool(b) => Some(TomlValue::Boolean(*b)), + serde_json::Value::Null => None, + } + } - let value = self - .get_by_path(json, &map.path) - .ok_or_else(|| LearnerError::ApiError(format!("No content found for {}", field)))?; + /// Extracts and converts a value from the JSON response according to the field type + fn extract_value( + &self, + json: &serde_json::Value, + field_map: &FieldMap, + field_type: &str, + ) -> Result> { + // Get the value at the specified path + if let Some(value) = self.get_path_value(json, &field_map.path) { + // Apply any transformations if it's a string + let transformed_value = if let Some(transform) = &field_map.transform { + if let Some(str_val) = value.as_str() { 
+ let transformed = apply_transform(str_val, transform)?; + serde_json::Value::String(transformed) + } else { + value.clone() + } + } else { + value.clone() + }; - if let Some(transform) = &map.transform { - apply_transform(&value, transform) + // Convert the value based on the expected field type + match field_type { + "string" => transformed_value + .as_str() + .map(|s| TomlValue::String(s.to_string())) + .ok_or_else(|| { + LearnerError::ApiError(format!("Expected string value for field type 'string'")) + }) + .map(Some), + "datetime" => transformed_value + .as_str() + .ok_or_else(|| LearnerError::ApiError("Expected string for datetime".into())) + .and_then(|s| { + DateTime::parse_from_rfc3339(s) + .map_err(|e| LearnerError::ApiError(format!("Invalid datetime format: {}", e))) + }) + .map(|dt| Some(TomlValue::Datetime(chrono_to_toml_datetime(dt.with_timezone(&Utc))))), + "array" => Ok(self.json_to_toml_value(&transformed_value)), + "table" => Ok(self.json_to_toml_value(&transformed_value)), + unsupported => + Err(LearnerError::ApiError(format!("Unsupported field type: {}", unsupported))), + } } else { - Ok(value) + Ok(None) } } - /// Retrieves a value from JSON using slash-separated path. - /// - /// Supports both object key and array index access: - /// - "message/title" -> object access - /// - "authors/0/name" -> array access - /// - /// Handles string, array, and number values with appropriate conversion. 
- fn get_by_path(&self, json: &Value, path: &str) -> Option { - let mut current = json; + /// Gets a string value from JSON using a path + fn get_by_path(&self, json: &serde_json::Value, path: &str) -> Option { + self.get_path_value(json, path).and_then(|value| match value { + serde_json::Value::String(s) => Some(s.clone()), + serde_json::Value::Number(n) => Some(n.to_string()), + serde_json::Value::Array(arr) if !arr.is_empty() => arr[0].as_str().map(String::from), + _ => value.as_str().map(String::from), + }) + } + /// Navigates JSON structure using a path + fn get_path_value<'a>( + &self, + json: &'a serde_json::Value, + path: &str, + ) -> Option<&'a serde_json::Value> { + let mut current = json; for part in path.split('/') { current = if let Ok(index) = part.parse::() { - // Handle numeric indices for arrays current.as_array()?.get(index)? } else { - // Handle regular object keys current.get(part)? }; } - - match current { - Value::String(s) => Some(s.clone()), - Value::Array(arr) if !arr.is_empty() => arr[0].as_str().map(String::from), - Value::Number(n) => Some(n.to_string()), - _ => current.as_str().map(String::from), - } - } - - /// Extracts and processes author information from JSON. - /// - /// Handles author objects with given/family name fields and optional - /// affiliation information. Expects authors as an array matching the - /// configured path. - /// - /// # Errors - /// - /// Returns error if no valid authors are found in the response. 
- fn extract_authors(&self, json: &Value, map: &FieldMap) -> Result> { - let authors = if let Some(Value::Array(arr)) = get_path_value(json, &map.path) { - arr - .iter() - .filter_map(|author| { - let name = match (author.get("given"), author.get("family")) { - (Some(given), Some(family)) => { - format!("{} {}", given.as_str().unwrap_or(""), family.as_str().unwrap_or("")) - }, - (Some(given), None) => given.as_str()?.to_string(), - (None, Some(family)) => family.as_str()?.to_string(), - (None, None) => return None, - }; - - let affiliation = author - .get("affiliation") - .and_then(|a| a.as_array()) - .and_then(|arr| arr.first()) - .and_then(|aff| aff.get("name")) - .and_then(|n| n.as_str()) - .map(String::from); - - Some(Author { name, affiliation, email: None }) - }) - .collect() - } else { - Vec::new() - }; - - if authors.is_empty() { - Err(LearnerError::ApiError("No authors found".to_string())) - } else { - Ok(authors) - } - } -} - -/// Helper function to navigate JSON structure using path. -/// -/// Similar to get_by_path but returns raw JSON Value instead of -/// converted string. -fn get_path_value<'a>(json: &'a Value, path: &str) -> Option<&'a Value> { - let mut current = json; - for part in path.split('/') { - current = current.get(part)?; + Some(current) } - Some(current) } diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index 73f938d..98c82bf 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -124,7 +124,7 @@ pub enum Transform { /// } /// } /// ``` -#[async_trait] +// #[async_trait] pub trait ResponseProcessor: Send + Sync { /// Process raw response data into a Paper object. 
/// @@ -137,5 +137,10 @@ pub trait ResponseProcessor: Send + Sync { /// Returns a Result containing either: /// - A fully populated Paper object /// - A LearnerError if parsing fails - async fn process_response(&self, data: &[u8]) -> Result; + fn process_response( + &self, + data: &[u8], + // retriever_config: RetrieverConfig, + resource_config: &ResourceConfig, + ) -> Result; } diff --git a/crates/learner/src/retriever/response/xml.rs b/crates/learner/src/retriever/response/xml.rs index f797f90..4f7f39e 100644 --- a/crates/learner/src/retriever/response/xml.rs +++ b/crates/learner/src/retriever/response/xml.rs @@ -19,6 +19,8 @@ //! ``` use quick_xml::{events::Event, Reader}; +use resource::chrono_to_toml_datetime; +use toml::Value; use super::*; @@ -49,107 +51,77 @@ pub struct XmlConfig { pub field_maps: HashMap, } -#[async_trait] impl ResponseProcessor for XmlConfig { - /// Processes an XML API response into a Paper object. - /// - /// Extracts paper metadata from the XML response using configured field mappings. - /// Handles namespace stripping if enabled and validates required fields. 
- /// - /// # Arguments - /// - /// * `data` - Raw XML response bytes - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - A populated Paper object - /// - A LearnerError if parsing fails or required fields are missing - /// - /// # Errors - /// - /// This method will return an error if: - /// - XML parsing fails - /// - Required fields are missing - /// - Field values are invalid or cannot be transformed - async fn process_response(&self, data: &[u8]) -> Result { + fn process_response( + &self, + data: &[u8], + // retriever_config: &RetrieverConfig, + resource_config: &ResourceConfig, + ) -> Result { + // Handle namespace stripping let xml = if self.strip_namespaces { strip_xml_namespaces(&String::from_utf8_lossy(data)) } else { String::from_utf8_lossy(data).to_string() }; + // Extract raw XML content into path -> string mapping let content = self.extract_content(&xml)?; - - // Helper function to extract and transform field - let get_field = |name: &str| -> Result { - let map = self - .field_maps - .get(name) - .ok_or_else(|| LearnerError::ApiError(format!("Missing field mapping for {}", name)))?; - - let value = content - .get(&map.path) - .ok_or_else(|| LearnerError::ApiError(format!("No content found for {}", name)))?; - - if let Some(transform) = &map.transform { - apply_transform(value, transform) - } else { - Ok(value.clone()) - } - }; - - let title = get_field("title")?; - let abstract_text = get_field("abstract")?; - let publication_date = chrono::DateTime::parse_from_rfc3339(&get_field("publication_date")?) 
- .map(|dt| dt.with_timezone(&Utc)) - .map_err(|e| LearnerError::ApiError(format!("Invalid date format: {}", e)))?; - - // Extract authors - let authors = if let Some(map) = self.field_maps.get("authors") { - let names: Vec = content - .get(&map.path) - .map(|s| { - s.split(';') - .map(|name| Author { - name: name.trim().to_string(), - affiliation: None, - email: None, - }) - .collect() - }) - .unwrap_or_default(); - if names.is_empty() { - return Err(LearnerError::ApiError("No authors found".to_string())); - } - names - } else { - return Err(LearnerError::ApiError("Missing authors mapping".to_string())); - }; - - // Optional fields - let pdf_url = self.field_maps.get("pdf_url").and_then(|map| { - content.get(&map.path).map(|url| { - if let Some(transform) = &map.transform { - apply_transform(url, transform).ok().unwrap_or_else(|| url.clone()) - } else { - url.clone() + let mut resource = BTreeMap::new(); + + // Process each field according to the resource configuration + for field_def in &resource_config.fields { + // Look up the field mapping from retriever config + if let Some(field_map) = self.field_maps.get(&field_def.name) { + // Try to get the raw value using configured path + if let Some(raw_value) = content.get(&field_map.path) { + // Apply any configured transformations + let transformed_value = if let Some(transform) = &field_map.transform { + apply_transform(raw_value, transform)? 
+ } else { + raw_value.clone() + }; + + // Convert string to appropriate TOML type based on field definition + let value = match field_def.field_type.as_str() { + "string" => Value::String(transformed_value), + "datetime" => { + let dt = DateTime::parse_from_rfc3339(&transformed_value).map_err(|e| { + LearnerError::ApiError(format!( + "Invalid date format for field '{}': {}", + field_def.name, e + )) + })?; + Value::Datetime(chrono_to_toml_datetime(dt.with_timezone(&Utc))) + }, + "array" => { + // For arrays, split on semicolon and create string array + let values = + transformed_value.split(';').map(|s| Value::String(s.trim().to_string())).collect(); + Value::Array(values) + }, + // Add other type conversions as needed + unsupported => + return Err(LearnerError::ApiError(format!( + "Unsupported field type '{}' for field '{}'", + unsupported, field_def.name + ))), + }; + resource.insert(field_def.name.clone(), value); + } else if field_def.required { + // Field was required but not found in response + return Err(LearnerError::ApiError(format!( + "Required field '{}' not found in response", + field_def.name + ))); + } else if let Some(default) = &field_def.default { + // Use default value if available + resource.insert(field_def.name.clone(), default.clone()); } - }) - }); - - let doi = self.field_maps.get("doi").and_then(|map| content.get(&map.path)).map(String::from); + } + } - Ok(Paper { - title, - authors, - abstract_text, - publication_date, - source: String::new(), - source_identifier: String::new(), - pdf_url, - doi, - }) + Ok(resource) } } diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index 48065ac..7c1c328 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -1,25 +1,38 @@ use std::fs; +use learner::resource::ResourceConfig; + use super::*; #[tokio::test] async fn test_arxiv_retriever_integration() -> TestResult<()> { - 
let config_str = fs::read_to_string("config/retrievers/arxiv.toml").expect( + let ret_config_str = fs::read_to_string("config/retrievers/arxiv.toml").expect( + "Failed to read config + file", + ); + let res_config_str = fs::read_to_string("config/resources/paper.toml").expect( "Failed to read config file", ); - let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); + let retriever: RetrieverConfig = toml::from_str(&ret_config_str).expect("Failed to parse config"); + let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); // Test with a real arXiv paper - let paper = retriever.retrieve_paper("2301.07041").await?; - - assert!(!paper.title.is_empty()); - assert!(!paper.authors.is_empty()); - assert!(!paper.abstract_text.is_empty()); - assert!(paper.pdf_url.is_some()); - assert_eq!(paper.source, "arxiv"); - assert_eq!(paper.source_identifier, "2301.07041"); + let paper = retriever.retrieve_resource("2301.07041", resource).await?; + + dbg!(&paper); + + assert_eq!( + paper.get("title").unwrap().as_str().unwrap(), + "Verifiable Fully Homomorphic Encryption" + ); + // assert!(!paper.title.is_empty()); + // assert!(!paper.authors.is_empty()); + // assert!(!paper.abstract_text.is_empty()); + // assert!(paper.pdf_url.is_some()); + // assert_eq!(paper.source, "arxiv"); + // assert_eq!(paper.source_identifier, "2301.07041"); Ok(()) } @@ -33,16 +46,17 @@ async fn test_arxiv_pdf_from_paper() -> TestResult<()> { let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); - // Test with a real arXiv paper - let paper = retriever.retrieve_paper("2301.07041").await?; - let dir = tempdir()?; - paper.download_pdf(dir.path()).await?; - let path = dir.into_path().join(paper.filename()); - assert!(path.exists()); - let pdf_content = PDFContentBuilder::new().path(path).analyze()?; - assert!(pdf_content.pages[0].text.contains("arXiv:2301.07041v2")); - - Ok(()) + todo!() + // 
// Test with a real arXiv paper + // let paper = retriever.retrieve_paper("2301.07041").await?; + // let dir = tempdir()?; + // paper.download_pdf(dir.path()).await?; + // let path = dir.into_path().join(paper.filename()); + // assert!(path.exists()); + // let pdf_content = PDFContentBuilder::new().path(path).analyze()?; + // assert!(pdf_content.pages[0].text.contains("arXiv:2301.07041v2")); + + // Ok(()) } #[tokio::test] @@ -52,15 +66,15 @@ async fn test_iacr_retriever_integration() { let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); - // Test with a real IACR paper - let paper = retriever.retrieve_paper("2016/260").await.unwrap(); + // // Test with a real IACR paper + // let paper = retriever.retrieve_paper("2016/260").await.unwrap(); - assert!(!paper.title.is_empty()); - assert!(!paper.authors.is_empty()); - assert!(!paper.abstract_text.is_empty()); - assert!(paper.pdf_url.is_some()); - assert_eq!(paper.source, "iacr"); - assert_eq!(paper.source_identifier, "2016/260"); + // assert!(!paper.title.is_empty()); + // assert!(!paper.authors.is_empty()); + // assert!(!paper.abstract_text.is_empty()); + // assert!(paper.pdf_url.is_some()); + // assert_eq!(paper.source, "iacr"); + // assert_eq!(paper.source_identifier, "2016/260"); } #[traced_test] @@ -71,35 +85,46 @@ async fn test_iacr_pdf_from_paper() -> TestResult<()> { let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); - // Test with a real arXiv paper - let paper = retriever.retrieve_paper("2016/260").await?; - let dir = tempdir()?; - paper.download_pdf(dir.path()).await?; - let path = dir.into_path().join(paper.filename()); - assert!(path.exists()); - let pdf_content = PDFContentBuilder::new().path(path).analyze()?; - assert!(pdf_content.pages[0].text.contains("On the Size")); - - Ok(()) + todo!() + // // Test with a real arXiv paper + // let paper = retriever.retrieve_paper("2016/260").await?; + // let dir = tempdir()?; + // 
paper.download_pdf(dir.path()).await?; + // let path = dir.into_path().join(paper.filename()); + // assert!(path.exists()); + // let pdf_content = PDFContentBuilder::new().path(path).analyze()?; + // assert!(pdf_content.pages[0].text.contains("On the Size")); + + // Ok(()) } #[tokio::test] -async fn test_doi_retriever_integration() { - let config_str = - fs::read_to_string("config/retrievers/doi.toml").expect("Failed to read config file"); +#[traced_test] +async fn test_doi_retriever_integration() -> TestResult<()> { + let ret_config_str = fs::read_to_string("config/retrievers/doi.toml").expect( + "Failed to read config + file", + ); + let res_config_str = fs::read_to_string("config/resources/paper.toml").expect( + "Failed to read config + file", + ); - let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); + let retriever: RetrieverConfig = toml::from_str(&ret_config_str).expect("Failed to parse config"); + let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); // Test with a real DOI paper - let paper = retriever.retrieve_paper("10.1145/1327452.1327492").await.unwrap(); - - assert!(!paper.title.is_empty()); - assert!(!paper.authors.is_empty()); - assert!(!paper.abstract_text.is_empty()); - assert!(paper.pdf_url.is_some()); - assert_eq!(paper.source, "doi"); - assert_eq!(paper.source_identifier, "10.1145/1327452.1327492"); - assert!(paper.doi.is_some()); + let paper = retriever.retrieve_resource("10.1145/1327452.1327492", resource).await?; + + dbg!(&paper); + // assert!(!paper.title.is_empty()); + // assert!(!paper.authors.is_empty()); + // assert!(!paper.abstract_text.is_empty()); + // assert!(paper.pdf_url.is_some()); + // assert_eq!(paper.source, "doi"); + // assert_eq!(paper.source_identifier, "10.1145/1327452.1327492"); + // assert!(paper.doi.is_some()); + Ok(()) } #[ignore = "This PDF downloads properly but it does not parse correctly with Lopdf due to: `Error: \ @@ -111,15 
+136,15 @@ async fn test_doi_pdf_from_paper() -> TestResult<()> { fs::read_to_string("config/retrievers/doi.toml").expect("Failed to read config file"); let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); - + todo!() // Test with a real arXiv paper - let paper = retriever.retrieve_paper("10.1145/1327452.1327492").await?; - let dir = tempdir()?; - paper.download_pdf(dir.path()).await?; - let path = dir.into_path().join(paper.filename()); - assert!(path.exists()); - let pdf_content = PDFContentBuilder::new().path(path).analyze()?; - assert!(pdf_content.pages[0].text.contains("arXiv:2301.07041v2")); - - Ok(()) + // let paper = retriever.retrieve_paper("10.1145/1327452.1327492").await?; + // let dir = tempdir()?; + // paper.download_pdf(dir.path()).await?; + // let path = dir.into_path().join(paper.filename()); + // assert!(path.exists()); + // let pdf_content = PDFContentBuilder::new().path(path).analyze()?; + // assert!(pdf_content.pages[0].text.contains("arXiv:2301.07041v2")); + + // Ok(()) } From 52c8ceef62dadddd6261e68d38e7b8e90ca5903c Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 07:36:30 -0700 Subject: [PATCH 21/73] WIP: improving resource/retriever --- crates/learner/config/resources/paper.toml | 27 ++- crates/learner/src/resource/mod.rs | 10 + crates/learner/src/retriever/response/json.rs | 187 ++++++++++++------ 3 files changed, 167 insertions(+), 57 deletions(-) diff --git a/crates/learner/config/resources/paper.toml b/crates/learner/config/resources/paper.toml index bf95c4d..4b21984 100644 --- a/crates/learner/config/resources/paper.toml +++ b/crates/learner/config/resources/paper.toml @@ -1,7 +1,8 @@ -name = "paper" - +# Core resource definition description = "A scholarly paper or article published in an academic context" +name = "paper" +# Title field - simple string type [[fields]] description = "The full title of the paper" field_type = "string" @@ -9,6 +10,7 @@ name = "title" required = 
true validation = { min_length = 1, max_length = 500 } +# Authors field - array of complex author objects [[fields]] description = "The paper's authors with their affiliations" field_type = "array" @@ -16,18 +18,30 @@ name = "authors" required = true validation = { min_items = 1 } +[fields.type_definition.element_type] +field_type = "table" +fields = [ + { name = "name", field_type = "string", description = "Author's full name", required = true }, + { name = "affiliation", field_type = "string", description = "Author's institutional affiliation", required = false }, + { name = "email", field_type = "string", description = "Author's contact email", required = false }, +] +name = "author" + +# Abstract field - simple string [[fields]] description = "The paper's abstract or summary" field_type = "string" name = "abstract" required = false +# Publication date - datetime type [[fields]] description = "When the paper was published or last updated" field_type = "datetime" name = "publication_date" required = false +# DOI field - string with pattern validation [[fields]] description = "Digital Object Identifier" field_type = "string" @@ -35,12 +49,20 @@ name = "doi" required = false validation = { pattern = "^10\\.\\d{4,9}/[-._;()/:a-zA-Z0-9]+$" } +# Keywords field - array of strings [[fields]] description = "Keywords or subject areas" field_type = "array" name = "keywords" required = false +[fields.type_definition.element_type] +description = "A single keyword or subject area" +field_type = "string" +name = "keyword" # Add this! 
+required = false + +# Publication details - all simple strings [[fields]] description = "Journal where the paper was published" field_type = "string" @@ -65,6 +87,7 @@ field_type = "string" name = "pages" required = false +# Peer review flag - boolean with default [[fields]] default = true description = "Whether the paper underwent peer review" diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index 13834d4..99078bd 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -65,6 +65,16 @@ pub struct FieldDefinition { /// Optional validation rules #[serde(default)] pub validation: Option, + + pub type_definition: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypeDefinition { + // For array types, defines the structure of elements + pub element_type: Option>, + // For table types, defines the fields + pub fields: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/learner/src/retriever/response/json.rs b/crates/learner/src/retriever/response/json.rs index a600b81..5474abd 100644 --- a/crates/learner/src/retriever/response/json.rs +++ b/crates/learner/src/retriever/response/json.rs @@ -17,7 +17,7 @@ //! contributors = { path = "message/contributors" } //! ``` -use resource::chrono_to_toml_datetime; +use resource::{chrono_to_toml_datetime, FieldDefinition, TypeDefinition}; use serde_json; use toml::{self, Value as TomlValue}; @@ -46,8 +46,8 @@ impl ResponseProcessor for JsonConfig { // Process each field according to resource configuration for field_def in &resource_config.fields { if let Some(field_map) = self.field_maps.get(&field_def.name) { - // Extract raw value if present - if let Some(value) = self.extract_value(&json, field_map, &field_def.field_type)? { + // Extract raw value if present, now passing the full field definition + if let Some(value) = self.extract_value(&json, field_map, field_def)? 
{ resource.insert(field_def.name.clone(), value); } else if field_def.required { return Err(LearnerError::ApiError(format!( @@ -65,27 +65,119 @@ impl ResponseProcessor for JsonConfig { } impl JsonConfig { - /// Recursively converts a JSON value into a TOML value - fn json_to_toml_value(&self, value: &serde_json::Value) -> Option { - match value { - // For JSON objects, recursively convert all their fields - serde_json::Value::Object(obj) => { + /// Converts a JSON value into a TOML value, respecting type definitions + fn json_to_toml_value( + &self, + value: &serde_json::Value, + field_type: &str, + type_definition: Option<&TypeDefinition>, + ) -> Result> { + match field_type { + // Handle array types with potential element type definitions + "array" => { + let array = + value.as_array().ok_or_else(|| LearnerError::ApiError("Expected array value".into()))?; + + // Get element type definition if available + let element_def = type_definition.and_then(|def| def.element_type.as_ref()); + + // Convert each array element according to its type definition + let values: Result> = array + .iter() + .map(|item| { + if let Some(def) = element_def { + self.json_to_toml_value(item, &def.field_type, def.type_definition.as_ref()) + } else { + // For simple arrays without type definitions, do basic conversion + Ok(self.convert_simple_value(item)) + } + }) + .filter_map(|r| r.transpose()) + .collect(); + + Ok(Some(TomlValue::Array(values?))) + }, + + // Handle table types with field definitions + "table" => { let mut map = toml::map::Map::new(); - for (key, val) in obj { - if let Some(converted) = self.json_to_toml_value(val) { - map.insert(key.clone(), converted); + + // If we have field definitions, follow them for the table structure + if let Some(type_def) = type_definition { + if let Some(fields) = &type_def.fields { + for field_def in fields { + if let Some(field_map) = self.field_maps.get(&field_def.name) { + if let Some(field_value) = self.get_path_value(value, 
&field_map.path) { + if let Some(converted) = self.json_to_toml_value( + field_value, + &field_def.field_type, + field_def.type_definition.as_ref(), + )? { + map.insert(field_def.name.clone(), converted); + } + } + } + } + } + } else { + // For tables without type definitions, convert all fields + let obj = value + .as_object() + .ok_or_else(|| LearnerError::ApiError("Expected object value".into()))?; + for (key, val) in obj { + if let Some(converted) = self.convert_simple_value(val) { + map.insert(key.clone(), converted); + } } } - Some(TomlValue::Table(map)) - }, - // For arrays, recursively convert all elements - serde_json::Value::Array(arr) => { - let values: Vec<_> = arr.iter().filter_map(|item| self.json_to_toml_value(item)).collect(); - Some(TomlValue::Array(values)) + Ok(Some(TomlValue::Table(map))) }, - // Direct conversions for primitive types + // Handle primitive types + "string" | "datetime" | "boolean" => self.convert_primitive_value(value, field_type), + + // Handle unsupported types + unsupported => + Err(LearnerError::ApiError(format!("Unsupported field type: {}", unsupported))), + } + } + + /// Converts a primitive JSON value to a TOML value + fn convert_primitive_value( + &self, + value: &serde_json::Value, + field_type: &str, + ) -> Result> { + match field_type { + "string" => value + .as_str() + .map(|s| TomlValue::String(s.to_string())) + .ok_or_else(|| LearnerError::ApiError("Expected string value".into())) + .map(Some), + + "datetime" => value + .as_str() + .ok_or_else(|| LearnerError::ApiError("Expected string for datetime".into())) + .and_then(|s| { + DateTime::parse_from_rfc3339(s) + .map_err(|e| LearnerError::ApiError(format!("Invalid datetime: {}", e))) + }) + .map(|dt| Some(TomlValue::Datetime(chrono_to_toml_datetime(dt.with_timezone(&Utc))))), + + "boolean" => value + .as_bool() + .map(TomlValue::Boolean) + .ok_or_else(|| LearnerError::ApiError("Expected boolean value".into())) + .map(Some), + + _ => 
Ok(self.convert_simple_value(value)), + } + } + + /// Basic conversion for simple JSON values + fn convert_simple_value(&self, value: &serde_json::Value) -> Option { + match value { serde_json::Value::String(s) => Some(TomlValue::String(s.clone())), serde_json::Value::Number(n) => if n.is_i64() { @@ -94,20 +186,31 @@ impl JsonConfig { n.as_f64().map(TomlValue::Float) }, serde_json::Value::Bool(b) => Some(TomlValue::Boolean(*b)), + serde_json::Value::Array(arr) => { + let values: Vec<_> = + arr.iter().filter_map(|item| self.convert_simple_value(item)).collect(); + Some(TomlValue::Array(values)) + }, + serde_json::Value::Object(obj) => { + let map = obj + .iter() + .filter_map(|(k, v)| self.convert_simple_value(v).map(|val| (k.clone(), val))) + .collect(); + Some(TomlValue::Table(map)) + }, serde_json::Value::Null => None, } } - /// Extracts and converts a value from the JSON response according to the field type + /// Updates extract_value to use the full field definition fn extract_value( &self, json: &serde_json::Value, field_map: &FieldMap, - field_type: &str, + field_def: &FieldDefinition, ) -> Result> { - // Get the value at the specified path if let Some(value) = self.get_path_value(json, &field_map.path) { - // Apply any transformations if it's a string + // Apply transformations if configured let transformed_value = if let Some(transform) = &field_map.transform { if let Some(str_val) = value.as_str() { let transformed = apply_transform(str_val, transform)?; @@ -119,44 +222,18 @@ impl JsonConfig { value.clone() }; - // Convert the value based on the expected field type - match field_type { - "string" => transformed_value - .as_str() - .map(|s| TomlValue::String(s.to_string())) - .ok_or_else(|| { - LearnerError::ApiError(format!("Expected string value for field type 'string'")) - }) - .map(Some), - "datetime" => transformed_value - .as_str() - .ok_or_else(|| LearnerError::ApiError("Expected string for datetime".into())) - .and_then(|s| { - 
DateTime::parse_from_rfc3339(s) - .map_err(|e| LearnerError::ApiError(format!("Invalid datetime format: {}", e))) - }) - .map(|dt| Some(TomlValue::Datetime(chrono_to_toml_datetime(dt.with_timezone(&Utc))))), - "array" => Ok(self.json_to_toml_value(&transformed_value)), - "table" => Ok(self.json_to_toml_value(&transformed_value)), - unsupported => - Err(LearnerError::ApiError(format!("Unsupported field type: {}", unsupported))), - } + // Convert using type definition + self.json_to_toml_value( + &transformed_value, + &field_def.field_type, + field_def.type_definition.as_ref(), + ) } else { Ok(None) } } - /// Gets a string value from JSON using a path - fn get_by_path(&self, json: &serde_json::Value, path: &str) -> Option { - self.get_path_value(json, path).and_then(|value| match value { - serde_json::Value::String(s) => Some(s.clone()), - serde_json::Value::Number(n) => Some(n.to_string()), - serde_json::Value::Array(arr) if !arr.is_empty() => arr[0].as_str().map(String::from), - _ => value.as_str().map(String::from), - }) - } - - /// Navigates JSON structure using a path + // get_path_value remains the same as it's already working well fn get_path_value<'a>( &self, json: &'a serde_json::Value, From 8dc803d71753052d7bd0d396a8ce1d37080c23e1 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 13:07:14 -0700 Subject: [PATCH 22/73] WIP: working DOI transform --- crates/learner/config/retrievers/doi.toml | 3 + crates/learner/src/configuration.rs | 2 - crates/learner/src/resource/mod.rs | 1 - crates/learner/src/retriever/mod.rs | 143 +++++++++----- crates/learner/src/retriever/response/json.rs | 184 ++++++++---------- crates/learner/src/retriever/response/mod.rs | 4 + 6 files changed, 174 insertions(+), 163 deletions(-) diff --git a/crates/learner/config/retrievers/doi.toml b/crates/learner/config/retrievers/doi.toml index d035e11..0d5103e 100644 --- a/crates/learner/config/retrievers/doi.toml +++ b/crates/learner/config/retrievers/doi.toml @@ -21,6 +21,9 
@@ type = "Replace" [response_format.field_maps.authors] path = "message/author" +[response_format.field_maps.authors.transform] +fields = ["given", "family"] +type = "CombineFields" [response_format.field_maps.publication_date] path = "message/created/date-time" diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index 43fddec..0426fc1 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -1,5 +1,3 @@ -use std::ops::{Index, IndexMut}; - use super::*; pub trait Identifiable { diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index 99078bd..344dd38 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -151,7 +151,6 @@ impl ResourceConfig { // Pattern matching via regex if let Some(pattern) = &rules.pattern { - dbg!(&pattern); let re = Regex::new(pattern) .map_err(|_| LearnerError::InvalidResource("Invalid regex pattern".into()))?; if !re.is_match(v) { diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 0e82443..711a2f0 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -76,6 +76,7 @@ mod config; mod response; pub use config::*; +use json::get_path_value; pub use response::*; /// Main entry point for paper retrieval operations. 
@@ -326,6 +327,40 @@ fn apply_transform(value: &str, transform: &Transform) -> Result { .map(|dt| dt.format(to_format).to_string()), Transform::Url { base, suffix } => Ok(format!("{}{}", base.replace("{value}", value), suffix.as_deref().unwrap_or(""))), + Transform::CombineFields { fields } => { + let json: serde_json::Value = serde_json::from_str(value) + .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; + + // Handle both single objects and arrays + let result = if let Some(array) = json.as_array() { + // Create array of objects with combined fields + let combined: Vec> = array + .iter() + .filter_map(|obj| { + let mut map = serde_json::Map::new(); + let parts: Vec<_> = + fields.iter().filter_map(|field| obj.get(field)).filter_map(|v| v.as_str()).collect(); + if !parts.is_empty() { + map.insert("name".to_string(), serde_json::Value::String(parts.join(" "))); + Some(map) + } else { + None + } + }) + .collect(); + serde_json::Value::Array(combined.into_iter().map(serde_json::Value::Object).collect()) + } else if let Some(obj) = json.as_object() { + // Handle single object + let parts: Vec<_> = + fields.iter().filter_map(|field| obj.get(field)).filter_map(|v| v.as_str()).collect(); + serde_json::Value::String(parts.join(" ")) + } else { + return Err(LearnerError::ApiError("Expected object or array for CombineFields".into())); + }; + + serde_json::to_string(&result) + .map_err(|e| LearnerError::ApiError(format!("Failed to serialize result: {}", e))) + }, } } @@ -394,59 +429,61 @@ mod tests { } // TODO: Fix this - // #[test] - // fn test_doi_config_deserialization() { - // let config_str = include_str!("../../config/retrievers/doi.toml"); - - // let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); - - // // Verify basic fields - // assert_eq!(retriever.name, "doi"); - // assert_eq!(retriever.base_url, "https://api.crossref.org/works"); - // assert_eq!(retriever.source, "doi"); - - // // Test pattern 
matching - // let test_cases = [ - // ("10.1145/1327452.1327492", true), - // ("https://doi.org/10.1145/1327452.1327492", true), - // ("invalid-doi", false), - // ("https://wrong.url/10.1145/1327452.1327492", false), - // ]; - - // for (input, expected) in test_cases { - // assert_eq!( - // retriever.pattern.is_match(input), - // expected, - // "Pattern match failed for input: {}", - // input - // ); - // } - - // // Test identifier extraction - // assert_eq!( - // retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), - // "10.1145/1327452.1327492" - // ); - // assert_eq!( - // retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), - // "10.1145/1327452.1327492" - // ); - - // // Verify response format - // match &retriever.response_format { - // ResponseFormat::Json(config) => { - // // Verify field mappings - // let field_maps = &config.field_maps; - // assert!(field_maps.contains_key("title")); - // assert!(field_maps.contains_key("abstract")); - // assert!(field_maps.contains_key("authors")); - // assert!(field_maps.contains_key("publication_date")); - // assert!(field_maps.contains_key("pdf_url")); - // assert!(field_maps.contains_key("doi")); - // }, - // _ => panic!("Expected JSON response format"), - // } - // } + #[test] + fn test_doi_config_deserialization() { + let config_str = include_str!("../../config/retrievers/doi.toml"); + + let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + + dbg!(&retriever); + + // Verify basic fields + assert_eq!(retriever.name, "doi"); + assert_eq!(retriever.base_url, "https://api.crossref.org/works"); + assert_eq!(retriever.source, "doi"); + + // Test pattern matching + let test_cases = [ + ("10.1145/1327452.1327492", true), + ("https://doi.org/10.1145/1327452.1327492", true), + ("invalid-doi", false), + ("https://wrong.url/10.1145/1327452.1327492", false), + ]; + + for (input, expected) in test_cases { + assert_eq!( + 
retriever.pattern.is_match(input), + expected, + "Pattern match failed for input: {}", + input + ); + } + + // Test identifier extraction + assert_eq!( + retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), + "10.1145/1327452.1327492" + ); + assert_eq!( + retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), + "10.1145/1327452.1327492" + ); + + // Verify response format + match &retriever.response_format { + ResponseFormat::Json(config) => { + // Verify field mappings + let field_maps = &config.field_maps; + assert!(field_maps.contains_key("title")); + assert!(field_maps.contains_key("abstract")); + assert!(field_maps.contains_key("authors")); + assert!(field_maps.contains_key("publication_date")); + assert!(field_maps.contains_key("pdf_url")); + assert!(field_maps.contains_key("doi")); + }, + _ => panic!("Expected JSON response format"), + } + } #[test] fn test_iacr_config_deserialization() { diff --git a/crates/learner/src/retriever/response/json.rs b/crates/learner/src/retriever/response/json.rs index 5474abd..a04ef01 100644 --- a/crates/learner/src/retriever/response/json.rs +++ b/crates/learner/src/retriever/response/json.rs @@ -1,32 +1,9 @@ -//! JSON response parser implementation. -//! -//! This module handles parsing of JSON API responses into resources using -//! configurable field mappings. It supports path-based field extraction -//! with optional transformations. -//! -//! # Example Configuration -//! -//! ```toml -//! [response_format] -//! type = "json" -//! -//! [response_format.field_maps] -//! title = { path = "message/title/0" } -//! summary = { path = "message/abstract" } -//! created_at = { path = "message/published/date-time" } -//! contributors = { path = "message/contributors" } -//! ``` - use resource::{chrono_to_toml_datetime, FieldDefinition, TypeDefinition}; use serde_json; use toml::{self, Value as TomlValue}; use super::*; -/// Configuration for processing JSON API responses. 
-/// -/// Provides field mapping rules to extract resource fields from JSON responses -/// using path-based access patterns. #[derive(Debug, Clone, Deserialize)] pub struct JsonConfig { pub field_maps: HashMap, @@ -39,12 +16,14 @@ impl ResponseProcessor for JsonConfig { let json: serde_json::Value = serde_json::from_slice(data) .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; + dbg!(&self); trace!("Processing JSON response: {}", serde_json::to_string_pretty(&json).unwrap()); let mut resource = BTreeMap::new(); // Process each field according to resource configuration for field_def in &resource_config.fields { + dbg!(&field_def); if let Some(field_map) = self.field_maps.get(&field_def.name) { // Extract raw value if present, now passing the full field definition if let Some(value) = self.extract_value(&json, field_map, field_def)? { @@ -89,7 +68,7 @@ impl JsonConfig { self.json_to_toml_value(item, &def.field_type, def.type_definition.as_ref()) } else { // For simple arrays without type definitions, do basic conversion - Ok(self.convert_simple_value(item)) + Ok(convert_simple_value(item)) } }) .filter_map(|r| r.transpose()) @@ -107,7 +86,7 @@ impl JsonConfig { if let Some(fields) = &type_def.fields { for field_def in fields { if let Some(field_map) = self.field_maps.get(&field_def.name) { - if let Some(field_value) = self.get_path_value(value, &field_map.path) { + if let Some(field_value) = get_path_value(value, &field_map.path) { if let Some(converted) = self.json_to_toml_value( field_value, &field_def.field_type, @@ -125,7 +104,7 @@ impl JsonConfig { .as_object() .ok_or_else(|| LearnerError::ApiError("Expected object value".into()))?; for (key, val) in obj { - if let Some(converted) = self.convert_simple_value(val) { + if let Some(converted) = convert_simple_value(val) { map.insert(key.clone(), converted); } } @@ -135,7 +114,7 @@ impl JsonConfig { }, // Handle primitive types - "string" | "datetime" | "boolean" => 
self.convert_primitive_value(value, field_type), + "string" | "datetime" | "boolean" => convert_primitive_value(value, field_type), // Handle unsupported types unsupported => @@ -143,65 +122,6 @@ impl JsonConfig { } } - /// Converts a primitive JSON value to a TOML value - fn convert_primitive_value( - &self, - value: &serde_json::Value, - field_type: &str, - ) -> Result> { - match field_type { - "string" => value - .as_str() - .map(|s| TomlValue::String(s.to_string())) - .ok_or_else(|| LearnerError::ApiError("Expected string value".into())) - .map(Some), - - "datetime" => value - .as_str() - .ok_or_else(|| LearnerError::ApiError("Expected string for datetime".into())) - .and_then(|s| { - DateTime::parse_from_rfc3339(s) - .map_err(|e| LearnerError::ApiError(format!("Invalid datetime: {}", e))) - }) - .map(|dt| Some(TomlValue::Datetime(chrono_to_toml_datetime(dt.with_timezone(&Utc))))), - - "boolean" => value - .as_bool() - .map(TomlValue::Boolean) - .ok_or_else(|| LearnerError::ApiError("Expected boolean value".into())) - .map(Some), - - _ => Ok(self.convert_simple_value(value)), - } - } - - /// Basic conversion for simple JSON values - fn convert_simple_value(&self, value: &serde_json::Value) -> Option { - match value { - serde_json::Value::String(s) => Some(TomlValue::String(s.clone())), - serde_json::Value::Number(n) => - if n.is_i64() { - n.as_i64().map(TomlValue::Integer) - } else { - n.as_f64().map(TomlValue::Float) - }, - serde_json::Value::Bool(b) => Some(TomlValue::Boolean(*b)), - serde_json::Value::Array(arr) => { - let values: Vec<_> = - arr.iter().filter_map(|item| self.convert_simple_value(item)).collect(); - Some(TomlValue::Array(values)) - }, - serde_json::Value::Object(obj) => { - let map = obj - .iter() - .filter_map(|(k, v)| self.convert_simple_value(v).map(|val| (k.clone(), val))) - .collect(); - Some(TomlValue::Table(map)) - }, - serde_json::Value::Null => None, - } - } - /// Updates extract_value to use the full field definition fn 
extract_value( &self, @@ -209,15 +129,10 @@ impl JsonConfig { field_map: &FieldMap, field_def: &FieldDefinition, ) -> Result> { - if let Some(value) = self.get_path_value(json, &field_map.path) { + if let Some(value) = get_path_value(json, &field_map.path) { // Apply transformations if configured let transformed_value = if let Some(transform) = &field_map.transform { - if let Some(str_val) = value.as_str() { - let transformed = apply_transform(str_val, transform)?; - serde_json::Value::String(transformed) - } else { - value.clone() - } + serde_json::from_str(&apply_transform(&serde_json::to_string(&value)?, transform)?)? } else { value.clone() }; @@ -232,21 +147,76 @@ impl JsonConfig { Ok(None) } } +} + +/// Converts a primitive JSON value to a TOML value +fn convert_primitive_value( + value: &serde_json::Value, + field_type: &str, +) -> Result> { + match field_type { + "string" => value + .as_str() + .map(|s| TomlValue::String(s.to_string())) + .ok_or_else(|| LearnerError::ApiError("Expected string value".into())) + .map(Some), + + "datetime" => value + .as_str() + .ok_or_else(|| LearnerError::ApiError("Expected string for datetime".into())) + .and_then(|s| { + DateTime::parse_from_rfc3339(s) + .map_err(|e| LearnerError::ApiError(format!("Invalid datetime: {}", e))) + }) + .map(|dt| Some(TomlValue::Datetime(chrono_to_toml_datetime(dt.with_timezone(&Utc))))), + + "boolean" => value + .as_bool() + .map(TomlValue::Boolean) + .ok_or_else(|| LearnerError::ApiError("Expected boolean value".into())) + .map(Some), + + _ => Ok(convert_simple_value(value)), + } +} - // get_path_value remains the same as it's already working well - fn get_path_value<'a>( - &self, - json: &'a serde_json::Value, - path: &str, - ) -> Option<&'a serde_json::Value> { - let mut current = json; - for part in path.split('/') { - current = if let Ok(index) = part.parse::() { - current.as_array()?.get(index)? 
+pub fn get_path_value<'a>( + json: &'a serde_json::Value, + path: &str, +) -> Option<&'a serde_json::Value> { + let mut current = json; + for part in path.split('/') { + current = if let Ok(index) = part.parse::() { + current.as_array()?.get(index)? + } else { + current.get(part)? + }; + } + Some(current) +} + +/// Basic conversion for simple JSON values +fn convert_simple_value(value: &serde_json::Value) -> Option { + match value { + serde_json::Value::String(s) => Some(TomlValue::String(s.clone())), + serde_json::Value::Number(n) => + if n.is_i64() { + n.as_i64().map(TomlValue::Integer) } else { - current.get(part)? - }; - } - Some(current) + n.as_f64().map(TomlValue::Float) + }, + serde_json::Value::Bool(b) => Some(TomlValue::Boolean(*b)), + serde_json::Value::Array(arr) => { + let values: Vec<_> = arr.iter().filter_map(|item| convert_simple_value(item)).collect(); + Some(TomlValue::Array(values)) + }, + serde_json::Value::Object(obj) => { + let map = obj + .iter() + .filter_map(|(k, v)| convert_simple_value(v).map(|val| (k.clone(), val))) + .collect(); + Some(TomlValue::Table(map)) + }, + serde_json::Value::Null => None, } } diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index 98c82bf..ed51d18 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -100,6 +100,10 @@ pub enum Transform { /// Optional suffix to append to the URL (e.g., ".pdf") suffix: Option, }, + // New transform for combining fields + CombineFields { + fields: Vec, + }, } /// Trait for processing API responses into Paper objects. 
From c9a239df36159b4d2f683bdebb895395e0633414 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 13:13:40 -0700 Subject: [PATCH 23/73] Update json.rs --- crates/learner/src/retriever/response/json.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/learner/src/retriever/response/json.rs b/crates/learner/src/retriever/response/json.rs index a04ef01..392f1fc 100644 --- a/crates/learner/src/retriever/response/json.rs +++ b/crates/learner/src/retriever/response/json.rs @@ -16,14 +16,12 @@ impl ResponseProcessor for JsonConfig { let json: serde_json::Value = serde_json::from_slice(data) .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; - dbg!(&self); trace!("Processing JSON response: {}", serde_json::to_string_pretty(&json).unwrap()); let mut resource = BTreeMap::new(); // Process each field according to resource configuration for field_def in &resource_config.fields { - dbg!(&field_def); if let Some(field_map) = self.field_maps.get(&field_def.name) { // Extract raw value if present, now passing the full field definition if let Some(value) = self.extract_value(&json, field_map, field_def)? 
{ From 1a989c45f7af9ec501d6ee4f7c5506374be32c97 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 13:32:00 -0700 Subject: [PATCH 24/73] improved DOI --- crates/learner/config/retrievers/doi.toml | 5 +-- crates/learner/src/retriever/mod.rs | 33 +++++++++++++++----- crates/learner/src/retriever/response/mod.rs | 9 +++++- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/crates/learner/config/retrievers/doi.toml b/crates/learner/config/retrievers/doi.toml index 0d5103e..193ea3f 100644 --- a/crates/learner/config/retrievers/doi.toml +++ b/crates/learner/config/retrievers/doi.toml @@ -22,8 +22,9 @@ type = "Replace" [response_format.field_maps.authors] path = "message/author" [response_format.field_maps.authors.transform] -fields = ["given", "family"] -type = "CombineFields" +fields = ["given", "family"] +inner_paths = [{ new_key_name = "affiliation", path = "affiliation/0/name" }] +type = "CombineFields" [response_format.field_maps.publication_date] path = "message/created/date-time" diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 711a2f0..12c7a5d 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -327,7 +327,7 @@ fn apply_transform(value: &str, transform: &Transform) -> Result { .map(|dt| dt.format(to_format).to_string()), Transform::Url { base, suffix } => Ok(format!("{}{}", base.replace("{value}", value), suffix.as_deref().unwrap_or(""))), - Transform::CombineFields { fields } => { + Transform::CombineFields { fields, inner_paths } => { let json: serde_json::Value = serde_json::from_str(value) .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; @@ -338,24 +338,43 @@ fn apply_transform(value: &str, transform: &Transform) -> Result { .iter() .filter_map(|obj| { let mut map = serde_json::Map::new(); + dbg!(&obj); + + // Handle the name fields combination let parts: Vec<_> = fields.iter().filter_map(|field| 
obj.get(field)).filter_map(|v| v.as_str()).collect(); + if !parts.is_empty() { map.insert("name".to_string(), serde_json::Value::String(parts.join(" "))); + + // Handle any additional inner paths + if let Some(paths) = inner_paths { + dbg!(paths); + for path in paths { + if let Some(inner_val) = get_path_value(obj, &path.path) { + dbg!(inner_val); + if let Some(str_val) = inner_val.as_str() { + // Use the last component of the path as the field name + let field_name = path.new_key_name.clone(); + map.insert( + field_name.to_string(), + serde_json::Value::String(str_val.to_string()), + ); + } + } + } + } + Some(map) } else { None } }) .collect(); + serde_json::Value::Array(combined.into_iter().map(serde_json::Value::Object).collect()) - } else if let Some(obj) = json.as_object() { - // Handle single object - let parts: Vec<_> = - fields.iter().filter_map(|field| obj.get(field)).filter_map(|v| v.as_str()).collect(); - serde_json::Value::String(parts.join(" ")) } else { - return Err(LearnerError::ApiError("Expected object or array for CombineFields".into())); + return Err(LearnerError::ApiError("Expected array for CombineFields".into())); }; serde_json::to_string(&result) diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index ed51d18..331863a 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -102,10 +102,17 @@ pub enum Transform { }, // New transform for combining fields CombineFields { - fields: Vec, + fields: Vec, // Fields to combine for name + inner_paths: Option>, // Additional paths to collect }, } +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct InnerPath { + pub new_key_name: String, + pub path: String, +} + /// Trait for processing API responses into Paper objects. 
/// /// Implementors of this trait handle the conversion of raw API response data From bf57318b87f807551f4d3dc0da32446ad086f439 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 13:37:24 -0700 Subject: [PATCH 25/73] debugging arxiv and iacr --- crates/learner/src/retriever/response/xml.rs | 2 ++ .../tests/workflows/paper_retrieval.rs | 25 +++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/crates/learner/src/retriever/response/xml.rs b/crates/learner/src/retriever/response/xml.rs index 4f7f39e..66bcca4 100644 --- a/crates/learner/src/retriever/response/xml.rs +++ b/crates/learner/src/retriever/response/xml.rs @@ -65,6 +65,8 @@ impl ResponseProcessor for XmlConfig { String::from_utf8_lossy(data).to_string() }; + trace!("Processing XML response: {:#?}", &xml); + // Extract raw XML content into path -> string mapping let content = self.extract_content(&xml)?; let mut resource = BTreeMap::new(); diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index 7c1c328..ccedcd3 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -4,6 +4,7 @@ use learner::resource::ResourceConfig; use super::*; +#[traced_test] #[tokio::test] async fn test_arxiv_retriever_integration() -> TestResult<()> { let ret_config_str = fs::read_to_string("config/retrievers/arxiv.toml").expect( @@ -27,6 +28,7 @@ async fn test_arxiv_retriever_integration() -> TestResult<()> { paper.get("title").unwrap().as_str().unwrap(), "Verifiable Fully Homomorphic Encryption" ); + todo!("This needs cleaned up."); // assert!(!paper.title.is_empty()); // assert!(!paper.authors.is_empty()); // assert!(!paper.abstract_text.is_empty()); @@ -59,22 +61,35 @@ async fn test_arxiv_pdf_from_paper() -> TestResult<()> { // Ok(()) } +#[traced_test] #[tokio::test] -async fn test_iacr_retriever_integration() { - let config_str = - 
fs::read_to_string("config/retrievers/iacr.toml").expect("Failed to read config file"); +async fn test_iacr_retriever_integration() -> TestResult<()> { + let ret_config_str = fs::read_to_string("config/retrievers/iacr.toml").expect( + "Failed to read config + file", + ); + let res_config_str = fs::read_to_string("config/resources/paper.toml").expect( + "Failed to read config + file", + ); - let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); + let retriever: RetrieverConfig = toml::from_str(&ret_config_str).expect("Failed to parse config"); + let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); // // Test with a real IACR paper - // let paper = retriever.retrieve_paper("2016/260").await.unwrap(); + let paper = retriever.retrieve_resource("2016/260", resource).await.unwrap(); + dbg!(&paper); + + todo!("This likely needs cleaned up"); // assert!(!paper.title.is_empty()); // assert!(!paper.authors.is_empty()); // assert!(!paper.abstract_text.is_empty()); // assert!(paper.pdf_url.is_some()); // assert_eq!(paper.source, "iacr"); // assert_eq!(paper.source_identifier, "2016/260"); + + Ok(()) } #[traced_test] From 51e89b5bcf09b89368b8f6c10212a16318888061 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 16:26:39 -0700 Subject: [PATCH 26/73] switch to JSON value --- crates/learner/config/resources/book.toml | 3 +- crates/learner/config/resources/paper.toml | 3 +- crates/learner/config/resources/thesis.toml | 3 +- crates/learner/src/lib.rs | 1 + crates/learner/src/resource/mod.rs | 197 ++++--- crates/learner/src/retriever/config.rs | 3 +- crates/learner/src/retriever/mod.rs | 540 +++++++++--------- crates/learner/src/retriever/response/json.rs | 223 +------- crates/learner/src/retriever/response/xml.rs | 183 +++--- 9 files changed, 530 insertions(+), 626 deletions(-) diff --git a/crates/learner/config/resources/book.toml b/crates/learner/config/resources/book.toml index 
8b14ebb..5a3cd5b 100644 --- a/crates/learner/config/resources/book.toml +++ b/crates/learner/config/resources/book.toml @@ -36,9 +36,10 @@ required = false [[fields]] description = "When the book was published" -field_type = "datetime" +field_type = "string" name = "publication_date" required = false +validation = { datetime = true } [[fields]] description = "The edition number or description" diff --git a/crates/learner/config/resources/paper.toml b/crates/learner/config/resources/paper.toml index 4b21984..24e7087 100644 --- a/crates/learner/config/resources/paper.toml +++ b/crates/learner/config/resources/paper.toml @@ -37,9 +37,10 @@ required = false # Publication date - datetime type [[fields]] description = "When the paper was published or last updated" -field_type = "datetime" +field_type = "string" name = "publication_date" required = false +validation = { datetime = true } # DOI field - string with pattern validation [[fields]] diff --git a/crates/learner/config/resources/thesis.toml b/crates/learner/config/resources/thesis.toml index 37b5653..4601a55 100644 --- a/crates/learner/config/resources/thesis.toml +++ b/crates/learner/config/resources/thesis.toml @@ -42,9 +42,10 @@ required = false [[fields]] description = "When the degree was awarded" -field_type = "datetime" +field_type = "string" name = "completion_date" required = true +validation = { datetime = true } [[fields]] description = "Thesis advisors or supervisors" diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index b9f1003..70d65d2 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -155,6 +155,7 @@ use regex::Regex; use reqwest::Url; use resource::{Resource, ResourceConfig, Resources}; use serde::{Deserialize, Serialize}; +use serde_json::Value; use tracing::{debug, trace, warn}; #[cfg(test)] use {tempfile::tempdir, tracing_test::traced_test}; diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index 344dd38..125de37 
100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -6,14 +6,10 @@ mod paper; mod shared; pub use paper::*; +use serde_json::Value; pub use shared::*; -use toml::Value; -// TODO (autoparallel): We almost need something like `Resource` to be given by these -// `ResourceConfig`s. Or, even renaming these like `ResourceTemplates` or something so a `Resource` -// has to fit into the `ResourceTemplate` (now that I type this out, `ResourceConfig` is still a -// reasonable name). But when we want to retrieve a resource, we need to actually get back a -// resource. Perhaps its just: +// Type alias for clarity and consistency pub type Resource = BTreeMap; #[derive(Debug, Clone, Default)] @@ -51,7 +47,7 @@ impl Identifiable for ResourceConfig { pub struct FieldDefinition { /// Name of the field pub name: String, - /// Type of the field (should be a TOML Value) + /// Type of the field (should be a JSON Value type) pub field_type: String, /// Whether this field must be present #[serde(default)] @@ -77,7 +73,7 @@ pub struct TypeDefinition { pub fields: Option>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct ValidationRules { // String validations pub pattern: Option, // Regex pattern to match @@ -96,6 +92,7 @@ pub struct ValidationRules { // General validations pub enum_values: Option>, // List of allowed values + pub datetime: Option, // Validates RFC3339 format } impl ResourceConfig { @@ -103,13 +100,11 @@ impl ResourceConfig { pub fn validate(&self, values: &Resource) -> Result { // Check required fields for field in &self.fields { - if field.required { - if !values.contains_key(&field.name) { - return Err(LearnerError::InvalidResource(format!( - "Missing required field: {}", - field.name - ))); - } + if field.required && !values.contains_key(&field.name) { + return Err(LearnerError::InvalidResource(format!( + "Missing required field: {}", + field.name + ))); } } 
@@ -125,11 +120,10 @@ impl ResourceConfig { } /// Validates a single field value against its definition - fn validate_field(&self, field: &FieldDefinition, value: &toml::Value) -> Result<()> { - // First validate that the provided value matches the declared type + fn validate_field(&self, field: &FieldDefinition, value: &Value) -> Result<()> { match (field.field_type.as_str(), value) { // String validation - handles both basic type checking and string-specific rules - ("string", toml::Value::String(v)) => { + ("string", Value::String(v)) => { if let Some(rules) = &field.validation { // Length constraints if let Some(min_length) = rules.min_length { @@ -161,6 +155,16 @@ impl ResourceConfig { } } + // Datetime validation if specified + if rules.datetime == Some(true) { + if DateTime::parse_from_rfc3339(v).is_err() { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be a valid RFC3339 datetime", + field.name + ))); + } + } + // Enumerated values check if let Some(allowed) = &rules.enum_values { if !allowed.contains(v) { @@ -174,23 +178,18 @@ impl ResourceConfig { Ok(()) }, - // Numeric validations - handle both integer and float values - ("integer", toml::Value::Integer(v)) => { + // Numeric validations - handle both number types + ("number", Value::Number(n)) => { if let Some(rules) = &field.validation { - validate_numeric(field, *v as f64, rules)?; - } - Ok(()) - }, - - ("float", toml::Value::Float(v)) => { - if let Some(rules) = &field.validation { - validate_numeric(field, *v, rules)?; + if let Some(num) = n.as_f64() { + validate_numeric(field, num, rules)?; + } } Ok(()) }, // Array validation - handles array-specific rules - ("array", toml::Value::Array(v)) => { + ("array", Value::Array(v)) => { if let Some(rules) = &field.validation { if let Some(min_items) = rules.min_items { if v.len() < min_items { @@ -213,7 +212,7 @@ impl ResourceConfig { if rules.unique_items == Some(true) { let mut seen = HashSet::new(); for item in v { - let item_str 
= toml::to_string(item).map_err(|_| { + let item_str = serde_json::to_string(item).map_err(|_| { LearnerError::InvalidResource("Failed to serialize array item".into()) })?; if !seen.insert(item_str) { @@ -229,9 +228,9 @@ impl ResourceConfig { }, // Simple type validations - just ensure type matches - ("boolean", toml::Value::Boolean(_)) => Ok(()), - ("datetime", toml::Value::Datetime(_)) => Ok(()), - ("table", toml::Value::Table(_)) => Ok(()), + ("boolean", Value::Bool(_)) => Ok(()), + ("object", Value::Object(_)) => Ok(()), + ("null", Value::Null) => Ok(()), // Type mismatch - provide a clear error message _ => Err(LearnerError::InvalidResource(format!( @@ -239,13 +238,12 @@ impl ResourceConfig { field.name, field.field_type, match value { - toml::Value::String(_) => "string", - toml::Value::Integer(_) => "integer", - toml::Value::Float(_) => "float", - toml::Value::Boolean(_) => "boolean", - toml::Value::Datetime(_) => "datetime", - toml::Value::Array(_) => "array", - toml::Value::Table(_) => "table", + Value::String(_) => "string", + Value::Number(_) => "number", + Value::Bool(_) => "boolean", + Value::Array(_) => "array", + Value::Object(_) => "object", + Value::Null => "null", } ))), } @@ -283,21 +281,20 @@ fn validate_numeric(field: &FieldDefinition, value: f64, rules: &ValidationRules Ok(()) } -// Convert from chrono DateTime to TOML Datetime -pub fn chrono_to_toml_datetime(dt: DateTime) -> toml::value::Datetime { - // TOML datetime is stored as seconds since Unix epoch - toml::value::Datetime::from_str(&dt.to_rfc3339()).unwrap() -} -// Convert from TOML Datetime to chrono DateTime -pub fn toml_to_chrono_datetime(dt: toml::value::Datetime) -> DateTime { - // Create DateTime from Unix timestamp - DateTime::parse_from_rfc3339(&dt.to_string()).unwrap().to_utc() -} +/// Convert DateTime to RFC3339 string for JSON storage +pub fn datetime_to_json(dt: DateTime) -> String { dt.to_rfc3339() } +/// Parse RFC3339 string from JSON into DateTime +pub fn 
datetime_from_json(s: &str) -> Result> { + DateTime::parse_from_rfc3339(s) + .map(|dt| dt.with_timezone(&Utc)) + .map_err(|e| LearnerError::InvalidResource(format!("Invalid datetime format: {}", e))) +} #[cfg(test)] mod tests { use chrono::TimeZone; + use serde_json::json; use super::*; @@ -306,26 +303,30 @@ mod tests { let config = include_str!("../../config/resources/paper.toml"); let config: ResourceConfig = toml::from_str(config).unwrap(); - let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); + let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); // Create a valid paper resource - let mut paper_resource = BTreeMap::new(); - paper_resource.insert("title".into(), Value::String("Understanding Quantum Computing".into())); - - // Create the author table using TOML's Map type - let author = { - let mut map = toml::map::Map::new(); - map.insert("name".into(), Value::String("Alice Researcher".into())); - map.insert("affiliation".into(), Value::String("Tech University".into())); - map - }; - - paper_resource.insert("authors".into(), Value::Array(vec![Value::Table(author)])); - paper_resource.insert("publication_date".into(), Value::Datetime(date)); - paper_resource.insert("doi".into(), Value::String("10.1234/example.123".into())); + let paper_resource = BTreeMap::from([ + ("title".into(), json!("Understanding Quantum Computing")), + ( + "authors".into(), + json!([{ + "name": "Alice Researcher", + "affiliation": "Tech University" + }]), + ), + ("publication_date".into(), json!(date)), + ("doi".into(), json!("10.1234/example.123")), + ]); // Validate the paper assert!(config.validate(&paper_resource).unwrap()); + + // Test required field validation + let invalid_paper = BTreeMap::from([ + ("authors".into(), json!([])), // Missing title + ]); + assert!(config.validate(&invalid_paper).is_err()); } #[test] @@ -333,18 +334,15 @@ mod tests { let config = 
include_str!("../../config/resources/book.toml"); let config: ResourceConfig = toml::from_str(config).unwrap(); - let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); + let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - // Create a valid book resource - let mut book_resource = BTreeMap::new(); - book_resource.insert("title".into(), Value::String("Advanced Quantum Computing".into())); - book_resource.insert( - "authors".into(), - Value::Array(vec![Value::String("Alice Writer".into()), Value::String("Bob Author".into())]), - ); - book_resource.insert("isbn".into(), Value::String("978-0-12-345678-9".into())); - book_resource.insert("publisher".into(), Value::String("Academic Press".into())); - book_resource.insert("publication_date".into(), Value::Datetime(date)); + let book_resource = BTreeMap::from([ + ("title".into(), json!("Advanced Quantum Computing")), + ("authors".into(), json!(["Alice Writer", "Bob Author"])), + ("isbn".into(), json!("978-0-12-345678-9")), + ("publisher".into(), json!("Academic Press")), + ("publication_date".into(), json!(date)), + ]); assert!(config.validate(&book_resource).unwrap()); } @@ -354,24 +352,47 @@ mod tests { let config = include_str!("../../config/resources/thesis.toml"); let config: ResourceConfig = toml::from_str(config).unwrap(); - let date = chrono_to_toml_datetime(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); + let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - // Create a valid thesis resource - let mut thesis_resource = BTreeMap::new(); - thesis_resource - .insert("title".into(), Value::String("Novel Approaches to Quantum Error Correction".into())); - thesis_resource.insert("author".into(), Value::String("Alice Researcher".into())); - thesis_resource.insert("degree".into(), Value::String("PhD".into())); - thesis_resource.insert("institution".into(), Value::String("Tech 
University".into())); - thesis_resource.insert("completion_date".into(), Value::Datetime(date)); - thesis_resource - .insert("advisors".into(), Value::Array(vec![Value::String("Prof. Bob Supervisor".into())])); + let thesis_resource = BTreeMap::from([ + ("title".into(), json!("Novel Approaches to Quantum Error Correction")), + ("author".into(), json!("Alice Researcher")), + ("degree".into(), json!("PhD")), + ("institution".into(), json!("Tech University")), + ("completion_date".into(), json!(date)), + ("advisors".into(), json!(["Prof. Bob Supervisor"])), + ]); assert!(config.validate(&thesis_resource).unwrap()); // Test degree enum validation let mut invalid_thesis = thesis_resource.clone(); - invalid_thesis.insert("degree".into(), Value::String("InvalidDegree".into())); + invalid_thesis.insert("degree".into(), json!("InvalidDegree")); assert!(config.validate(&invalid_thesis).is_err()); } + + #[test] + fn test_datetime_validation() { + let mut config = ResourceConfig { + name: "test".into(), + description: None, + fields: vec![FieldDefinition { + name: "timestamp".into(), + field_type: "string".into(), + required: true, + description: None, + default: None, + validation: Some(ValidationRules { datetime: Some(true), ..Default::default() }), + type_definition: None, + }], + }; + + let valid_resource = BTreeMap::from([("timestamp".into(), json!("2024-01-01T00:00:00Z"))]); + assert!(config.validate(&valid_resource).unwrap()); + + let invalid_resource = BTreeMap::from([ + ("timestamp".into(), json!("2024-01-01")), // Not RFC3339 + ]); + assert!(config.validate(&invalid_resource).is_err()); + } } diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index af3588a..7b1b3c8 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,5 +1,4 @@ use resource::Resource; -use toml::Value; use super::*; @@ -49,7 +48,7 @@ pub struct RetrieverConfig { pub response_format: ResponseFormat, /// Optional 
HTTP headers for API requests #[serde(default)] - pub headers: HashMap, + pub headers: BTreeMap, } impl Identifiable for RetrieverConfig { diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 12c7a5d..77b191a 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -68,8 +68,6 @@ //! # } //! ``` -use std::collections::HashMap; - use super::*; mod config; @@ -303,271 +301,273 @@ where D: serde::Deserializer<'de> { Regex::new(&s).map_err(serde::de::Error::custom) } -/// Applies a transformation to a string value based on the transform type. -/// -/// Handles three types of transformations: -/// - Regular expression replacements -/// - Date format conversions -/// - URL construction -/// -/// # Errors -/// -/// Returns a LearnerError if: -/// - Regex pattern is invalid -/// - Date parsing fails -/// - Date format is invalid -fn apply_transform(value: &str, transform: &Transform) -> Result { - match transform { - Transform::Replace { pattern, replacement } => Regex::new(pattern) - .map_err(|e| LearnerError::ApiError(format!("Invalid regex: {}", e))) - .map(|re| re.replace_all(value, replacement.as_str()).into_owned()), - Transform::Date { from_format, to_format } => - chrono::NaiveDateTime::parse_from_str(value, from_format) - .map_err(|e| LearnerError::ApiError(format!("Invalid date: {}", e))) - .map(|dt| dt.format(to_format).to_string()), - Transform::Url { base, suffix } => - Ok(format!("{}{}", base.replace("{value}", value), suffix.as_deref().unwrap_or(""))), - Transform::CombineFields { fields, inner_paths } => { - let json: serde_json::Value = serde_json::from_str(value) - .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; - - // Handle both single objects and arrays - let result = if let Some(array) = json.as_array() { - // Create array of objects with combined fields - let combined: Vec> = array - .iter() - .filter_map(|obj| { - let mut map = 
serde_json::Map::new(); - dbg!(&obj); - - // Handle the name fields combination - let parts: Vec<_> = - fields.iter().filter_map(|field| obj.get(field)).filter_map(|v| v.as_str()).collect(); - - if !parts.is_empty() { - map.insert("name".to_string(), serde_json::Value::String(parts.join(" "))); - - // Handle any additional inner paths - if let Some(paths) = inner_paths { - dbg!(paths); - for path in paths { - if let Some(inner_val) = get_path_value(obj, &path.path) { - dbg!(inner_val); - if let Some(str_val) = inner_val.as_str() { - // Use the last component of the path as the field name - let field_name = path.new_key_name.clone(); - map.insert( - field_name.to_string(), - serde_json::Value::String(str_val.to_string()), - ); - } - } - } - } - - Some(map) - } else { - None - } - }) - .collect(); - - serde_json::Value::Array(combined.into_iter().map(serde_json::Value::Object).collect()) - } else { - return Err(LearnerError::ApiError("Expected array for CombineFields".into())); - }; - - serde_json::to_string(&result) - .map_err(|e| LearnerError::ApiError(format!("Failed to serialize result: {}", e))) - }, - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn validate_arxiv_config() { - let config_str = include_str!("../../config/retrievers/arxiv.toml"); - - let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); - - // Verify basic fields - assert_eq!(retriever.name, "arxiv"); - assert_eq!(retriever.base_url, "http://export.arxiv.org"); - assert_eq!(retriever.source, "arxiv"); - - // Test pattern matching - assert!(retriever.pattern.is_match("2301.07041")); - assert!(retriever.pattern.is_match("math.AG/0601001")); - assert!(retriever.pattern.is_match("https://arxiv.org/abs/2301.07041")); - assert!(retriever.pattern.is_match("https://arxiv.org/pdf/2301.07041")); - assert!(retriever.pattern.is_match("https://arxiv.org/abs/math.AG/0601001")); - assert!(retriever.pattern.is_match("https://arxiv.org/abs/math/0404443")); 
- - // Test identifier extraction - assert_eq!(retriever.extract_identifier("2301.07041").unwrap(), "2301.07041"); - assert_eq!( - retriever.extract_identifier("https://arxiv.org/abs/2301.07041").unwrap(), - "2301.07041" - ); - assert_eq!(retriever.extract_identifier("math.AG/0601001").unwrap(), "math.AG/0601001"); - - // Verify response format - - if let ResponseFormat::Xml(config) = &retriever.response_format { - assert!(config.strip_namespaces); - - // Verify field mappings - let field_maps = &config.field_maps; - assert!(field_maps.contains_key("title")); - assert!(field_maps.contains_key("abstract")); - assert!(field_maps.contains_key("authors")); - assert!(field_maps.contains_key("publication_date")); - assert!(field_maps.contains_key("pdf_url")); - - // Verify PDF transform - if let Some(map) = field_maps.get("pdf_url") { - match &map.transform { - Some(Transform::Replace { pattern, replacement }) => { - assert_eq!(pattern, "/abs/"); - assert_eq!(replacement, "/pdf/"); - }, - _ => panic!("Expected Replace transform for pdf_url"), - } - } else { - panic!("Missing pdf_url field map"); - } - } else { - panic!("Expected an XML configuration, but did not get one.") - } - - // Verify headers - assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); - } - - // TODO: Fix this - #[test] - fn test_doi_config_deserialization() { - let config_str = include_str!("../../config/retrievers/doi.toml"); - - let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); - - dbg!(&retriever); - - // Verify basic fields - assert_eq!(retriever.name, "doi"); - assert_eq!(retriever.base_url, "https://api.crossref.org/works"); - assert_eq!(retriever.source, "doi"); - - // Test pattern matching - let test_cases = [ - ("10.1145/1327452.1327492", true), - ("https://doi.org/10.1145/1327452.1327492", true), - ("invalid-doi", false), - ("https://wrong.url/10.1145/1327452.1327492", false), - ]; - - for (input, expected) in test_cases { - 
assert_eq!( - retriever.pattern.is_match(input), - expected, - "Pattern match failed for input: {}", - input - ); - } - - // Test identifier extraction - assert_eq!( - retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), - "10.1145/1327452.1327492" - ); - assert_eq!( - retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), - "10.1145/1327452.1327492" - ); - - // Verify response format - match &retriever.response_format { - ResponseFormat::Json(config) => { - // Verify field mappings - let field_maps = &config.field_maps; - assert!(field_maps.contains_key("title")); - assert!(field_maps.contains_key("abstract")); - assert!(field_maps.contains_key("authors")); - assert!(field_maps.contains_key("publication_date")); - assert!(field_maps.contains_key("pdf_url")); - assert!(field_maps.contains_key("doi")); - }, - _ => panic!("Expected JSON response format"), - } - } - - #[test] - fn test_iacr_config_deserialization() { - let config_str = include_str!("../../config/retrievers/iacr.toml"); - - let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); - - // Verify basic fields - assert_eq!(retriever.name, "iacr"); - assert_eq!(retriever.base_url, "https://eprint.iacr.org"); - assert_eq!(retriever.source, "iacr"); - - // Test pattern matching - let test_cases = [ - ("2016/260", true), - ("2023/123", true), - ("https://eprint.iacr.org/2016/260", true), - ("https://eprint.iacr.org/2016/260.pdf", true), - ("invalid/format", false), - ("https://wrong.url/2016/260", false), - ]; - - for (input, expected) in test_cases { - assert_eq!( - retriever.pattern.is_match(input), - expected, - "Pattern match failed for input: {}", - input - ); - } - - // Test identifier extraction - assert_eq!(retriever.extract_identifier("2016/260").unwrap(), "2016/260"); - assert_eq!( - retriever.extract_identifier("https://eprint.iacr.org/2016/260").unwrap(), - "2016/260" - ); - assert_eq!( - 
retriever.extract_identifier("https://eprint.iacr.org/2016/260.pdf").unwrap(), - "2016/260" - ); - - // Verify response format - if let ResponseFormat::Xml(config) = &retriever.response_format { - assert!(config.strip_namespaces); - - // Verify field mappings - let field_maps = &config.field_maps; - assert!(field_maps.contains_key("title")); - assert!(field_maps.contains_key("abstract")); - assert!(field_maps.contains_key("authors")); - assert!(field_maps.contains_key("publication_date")); - assert!(field_maps.contains_key("pdf_url")); - - // Verify OAI-PMH paths - if let Some(map) = field_maps.get("title") { - assert!(map.path.contains(&"OAI-PMH/GetRecord/record/metadata/dc/title".to_string())); - } else { - panic!("Missing title field map"); - } - } else { - panic!("Expected an XML configuration, but did not get one.") - } - - // Verify headers - assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); - } -} +// / Applies a transformation to a string value based on the transform type. 
+// / +// / Handles three types of transformations: +// / - Regular expression replacements +// / - Date format conversions +// / - URL construction +// / +// / # Errors +// / +// / Returns a LearnerError if: +// / - Regex pattern is invalid +// / - Date parsing fails +// / - Date format is invalid + +// fn apply_transform(value: &str, transform: &Transform) -> Result { +// match transform { +// Transform::Replace { pattern, replacement } => Regex::new(pattern) +// .map_err(|e| LearnerError::ApiError(format!("Invalid regex: {}", e))) +// .map(|re| re.replace_all(value, replacement.as_str()).into_owned()), +// Transform::Date { from_format, to_format } => +// chrono::NaiveDateTime::parse_from_str(value, from_format) +// .map_err(|e| LearnerError::ApiError(format!("Invalid date: {}", e))) +// .map(|dt| dt.format(to_format).to_string()), +// Transform::Url { base, suffix } => +// Ok(format!("{}{}", base.replace("{value}", value), suffix.as_deref().unwrap_or(""))), +// Transform::CombineFields { fields, inner_paths } => { +// let json: serde_json::Value = serde_json::from_str(value) +// .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; + +// // Handle both single objects and arrays +// let result = if let Some(array) = json.as_array() { +// // Create array of objects with combined fields +// let combined: Vec> = array +// .iter() +// .filter_map(|obj| { +// let mut map = serde_json::Map::new(); +// dbg!(&obj); + +// // Handle the name fields combination +// let parts: Vec<_> = +// fields.iter().filter_map(|field| obj.get(field)).filter_map(|v| +// v.as_str()).collect(); + +// if !parts.is_empty() { +// map.insert("name".to_string(), serde_json::Value::String(parts.join(" "))); + +// // Handle any additional inner paths +// if let Some(paths) = inner_paths { +// dbg!(paths); +// for path in paths { +// if let Some(inner_val) = get_path_value(obj, &path.path) { +// dbg!(inner_val); +// if let Some(str_val) = inner_val.as_str() { +// // Use 
the last component of the path as the field name +// let field_name = path.new_key_name.clone(); +// map.insert( +// field_name.to_string(), +// serde_json::Value::String(str_val.to_string()), +// ); +// } +// } +// } +// } + +// Some(map) +// } else { +// None +// } +// }) +// .collect(); + +// serde_json::Value::Array(combined.into_iter().map(serde_json::Value::Object).collect()) +// } else { +// return Err(LearnerError::ApiError("Expected array for CombineFields".into())); +// }; + +// serde_json::to_string(&result) +// .map_err(|e| LearnerError::ApiError(format!("Failed to serialize result: {}", e))) +// }, +// } +// } + +// #[cfg(test)] +// mod tests { +// use super::*; + +// #[test] +// fn validate_arxiv_config() { +// let config_str = include_str!("../../config/retrievers/arxiv.toml"); + +// let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + +// // Verify basic fields +// assert_eq!(retriever.name, "arxiv"); +// assert_eq!(retriever.base_url, "http://export.arxiv.org"); +// assert_eq!(retriever.source, "arxiv"); + +// // Test pattern matching +// assert!(retriever.pattern.is_match("2301.07041")); +// assert!(retriever.pattern.is_match("math.AG/0601001")); +// assert!(retriever.pattern.is_match("https://arxiv.org/abs/2301.07041")); +// assert!(retriever.pattern.is_match("https://arxiv.org/pdf/2301.07041")); +// assert!(retriever.pattern.is_match("https://arxiv.org/abs/math.AG/0601001")); +// assert!(retriever.pattern.is_match("https://arxiv.org/abs/math/0404443")); + +// // Test identifier extraction +// assert_eq!(retriever.extract_identifier("2301.07041").unwrap(), "2301.07041"); +// assert_eq!( +// retriever.extract_identifier("https://arxiv.org/abs/2301.07041").unwrap(), +// "2301.07041" +// ); +// assert_eq!(retriever.extract_identifier("math.AG/0601001").unwrap(), "math.AG/0601001"); + +// // Verify response format + +// if let ResponseFormat::Xml(config) = &retriever.response_format { +// 
assert!(config.strip_namespaces); + +// // Verify field mappings +// let field_maps = &config.field_maps; +// assert!(field_maps.contains_key("title")); +// assert!(field_maps.contains_key("abstract")); +// assert!(field_maps.contains_key("authors")); +// assert!(field_maps.contains_key("publication_date")); +// assert!(field_maps.contains_key("pdf_url")); + +// // Verify PDF transform +// if let Some(map) = field_maps.get("pdf_url") { +// match &map.transform { +// Some(Transform::Replace { pattern, replacement }) => { +// assert_eq!(pattern, "/abs/"); +// assert_eq!(replacement, "/pdf/"); +// }, +// _ => panic!("Expected Replace transform for pdf_url"), +// } +// } else { +// panic!("Missing pdf_url field map"); +// } +// } else { +// panic!("Expected an XML configuration, but did not get one.") +// } + +// // Verify headers +// assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); +// } + +// // TODO: Fix this +// #[test] +// fn test_doi_config_deserialization() { +// let config_str = include_str!("../../config/retrievers/doi.toml"); + +// let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + +// dbg!(&retriever); + +// // Verify basic fields +// assert_eq!(retriever.name, "doi"); +// assert_eq!(retriever.base_url, "https://api.crossref.org/works"); +// assert_eq!(retriever.source, "doi"); + +// // Test pattern matching +// let test_cases = [ +// ("10.1145/1327452.1327492", true), +// ("https://doi.org/10.1145/1327452.1327492", true), +// ("invalid-doi", false), +// ("https://wrong.url/10.1145/1327452.1327492", false), +// ]; + +// for (input, expected) in test_cases { +// assert_eq!( +// retriever.pattern.is_match(input), +// expected, +// "Pattern match failed for input: {}", +// input +// ); +// } + +// // Test identifier extraction +// assert_eq!( +// retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), +// "10.1145/1327452.1327492" +// ); +// assert_eq!( +// 
retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), +// "10.1145/1327452.1327492" +// ); + +// // Verify response format +// match &retriever.response_format { +// ResponseFormat::Json(config) => { +// // Verify field mappings +// let field_maps = &config.field_maps; +// assert!(field_maps.contains_key("title")); +// assert!(field_maps.contains_key("abstract")); +// assert!(field_maps.contains_key("authors")); +// assert!(field_maps.contains_key("publication_date")); +// assert!(field_maps.contains_key("pdf_url")); +// assert!(field_maps.contains_key("doi")); +// }, +// _ => panic!("Expected JSON response format"), +// } +// } + +// #[test] +// fn test_iacr_config_deserialization() { +// let config_str = include_str!("../../config/retrievers/iacr.toml"); + +// let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + +// // Verify basic fields +// assert_eq!(retriever.name, "iacr"); +// assert_eq!(retriever.base_url, "https://eprint.iacr.org"); +// assert_eq!(retriever.source, "iacr"); + +// // Test pattern matching +// let test_cases = [ +// ("2016/260", true), +// ("2023/123", true), +// ("https://eprint.iacr.org/2016/260", true), +// ("https://eprint.iacr.org/2016/260.pdf", true), +// ("invalid/format", false), +// ("https://wrong.url/2016/260", false), +// ]; + +// for (input, expected) in test_cases { +// assert_eq!( +// retriever.pattern.is_match(input), +// expected, +// "Pattern match failed for input: {}", +// input +// ); +// } + +// // Test identifier extraction +// assert_eq!(retriever.extract_identifier("2016/260").unwrap(), "2016/260"); +// assert_eq!( +// retriever.extract_identifier("https://eprint.iacr.org/2016/260").unwrap(), +// "2016/260" +// ); +// assert_eq!( +// retriever.extract_identifier("https://eprint.iacr.org/2016/260.pdf").unwrap(), +// "2016/260" +// ); + +// // Verify response format +// if let ResponseFormat::Xml(config) = &retriever.response_format { +// 
assert!(config.strip_namespaces); + +// // Verify field mappings +// let field_maps = &config.field_maps; +// assert!(field_maps.contains_key("title")); +// assert!(field_maps.contains_key("abstract")); +// assert!(field_maps.contains_key("authors")); +// assert!(field_maps.contains_key("publication_date")); +// assert!(field_maps.contains_key("pdf_url")); + +// // Verify OAI-PMH paths +// if let Some(map) = field_maps.get("title") { +// assert!(map.path.contains(&"OAI-PMH/GetRecord/record/metadata/dc/title".to_string())); +// } else { +// panic!("Missing title field map"); +// } +// } else { +// panic!("Expected an XML configuration, but did not get one.") +// } + +// // Verify headers +// assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); +// } +// } diff --git a/crates/learner/src/retriever/response/json.rs b/crates/learner/src/retriever/response/json.rs index 392f1fc..d89a2dc 100644 --- a/crates/learner/src/retriever/response/json.rs +++ b/crates/learner/src/retriever/response/json.rs @@ -1,180 +1,43 @@ -use resource::{chrono_to_toml_datetime, FieldDefinition, TypeDefinition}; -use serde_json; -use toml::{self, Value as TomlValue}; +use resource::{datetime_to_json, FieldDefinition, TypeDefinition}; +use serde_json::{self, Number}; use super::*; #[derive(Debug, Clone, Deserialize)] pub struct JsonConfig { - pub field_maps: HashMap, + pub field_maps: BTreeMap, } // TODO: Refactor this impl ResponseProcessor for JsonConfig { fn process_response(&self, data: &[u8], resource_config: &ResourceConfig) -> Result { - // Parse raw JSON data - let json: serde_json::Value = serde_json::from_slice(data) - .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; - - trace!("Processing JSON response: {}", serde_json::to_string_pretty(&json).unwrap()); - - let mut resource = BTreeMap::new(); - - // Process each field according to resource configuration - for field_def in &resource_config.fields { - if let Some(field_map) = 
self.field_maps.get(&field_def.name) { - // Extract raw value if present, now passing the full field definition - if let Some(value) = self.extract_value(&json, field_map, field_def)? { - resource.insert(field_def.name.clone(), value); - } else if field_def.required { - return Err(LearnerError::ApiError(format!( - "Required field '{}' not found in response", - field_def.name - ))); - } else if let Some(default) = &field_def.default { - resource.insert(field_def.name.clone(), default.clone()); - } - } - } - - Ok(resource) - } -} - -impl JsonConfig { - /// Converts a JSON value into a TOML value, respecting type definitions - fn json_to_toml_value( - &self, - value: &serde_json::Value, - field_type: &str, - type_definition: Option<&TypeDefinition>, - ) -> Result> { - match field_type { - // Handle array types with potential element type definitions - "array" => { - let array = - value.as_array().ok_or_else(|| LearnerError::ApiError("Expected array value".into()))?; - - // Get element type definition if available - let element_def = type_definition.and_then(|def| def.element_type.as_ref()); - - // Convert each array element according to its type definition - let values: Result> = array - .iter() - .map(|item| { - if let Some(def) = element_def { - self.json_to_toml_value(item, &def.field_type, def.type_definition.as_ref()) - } else { - // For simple arrays without type definitions, do basic conversion - Ok(convert_simple_value(item)) - } - }) - .filter_map(|r| r.transpose()) - .collect(); - - Ok(Some(TomlValue::Array(values?))) - }, - - // Handle table types with field definitions - "table" => { - let mut map = toml::map::Map::new(); - - // If we have field definitions, follow them for the table structure - if let Some(type_def) = type_definition { - if let Some(fields) = &type_def.fields { - for field_def in fields { - if let Some(field_map) = self.field_maps.get(&field_def.name) { - if let Some(field_value) = get_path_value(value, &field_map.path) { - if let 
Some(converted) = self.json_to_toml_value( - field_value, - &field_def.field_type, - field_def.type_definition.as_ref(), - )? { - map.insert(field_def.name.clone(), converted); - } - } - } - } - } - } else { - // For tables without type definitions, convert all fields - let obj = value - .as_object() - .ok_or_else(|| LearnerError::ApiError("Expected object value".into()))?; - for (key, val) in obj { - if let Some(converted) = convert_simple_value(val) { - map.insert(key.clone(), converted); - } - } - } - - Ok(Some(TomlValue::Table(map))) - }, - - // Handle primitive types - "string" | "datetime" | "boolean" => convert_primitive_value(value, field_type), - - // Handle unsupported types - unsupported => - Err(LearnerError::ApiError(format!("Unsupported field type: {}", unsupported))), - } - } - - /// Updates extract_value to use the full field definition - fn extract_value( - &self, - json: &serde_json::Value, - field_map: &FieldMap, - field_def: &FieldDefinition, - ) -> Result> { - if let Some(value) = get_path_value(json, &field_map.path) { - // Apply transformations if configured - let transformed_value = if let Some(transform) = &field_map.transform { - serde_json::from_str(&apply_transform(&serde_json::to_string(&value)?, transform)?)? 
- } else { - value.clone() - }; - - // Convert using type definition - self.json_to_toml_value( - &transformed_value, - &field_def.field_type, - field_def.type_definition.as_ref(), - ) - } else { - Ok(None) - } - } -} - -/// Converts a primitive JSON value to a TOML value -fn convert_primitive_value( - value: &serde_json::Value, - field_type: &str, -) -> Result> { - match field_type { - "string" => value - .as_str() - .map(|s| TomlValue::String(s.to_string())) - .ok_or_else(|| LearnerError::ApiError("Expected string value".into())) - .map(Some), - - "datetime" => value - .as_str() - .ok_or_else(|| LearnerError::ApiError("Expected string for datetime".into())) - .and_then(|s| { - DateTime::parse_from_rfc3339(s) - .map_err(|e| LearnerError::ApiError(format!("Invalid datetime: {}", e))) - }) - .map(|dt| Some(TomlValue::Datetime(chrono_to_toml_datetime(dt.with_timezone(&Utc))))), - - "boolean" => value - .as_bool() - .map(TomlValue::Boolean) - .ok_or_else(|| LearnerError::ApiError("Expected boolean value".into())) - .map(Some), - - _ => Ok(convert_simple_value(value)), + todo!() + // // Parse raw JSON data + // let json: serde_json::Value = serde_json::from_slice(data) + // .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; + + // trace!("Processing JSON response: {}", serde_json::to_string_pretty(&json).unwrap()); + + // let mut resource = BTreeMap::new(); + + // // Process each field according to resource configuration + // for field_def in &resource_config.fields { + // if let Some(field_map) = self.field_maps.get(&field_def.name) { + // // Extract raw value if present, now passing the full field definition + // if let Some(value) = self.extract_value(&json, field_map, field_def)? 
{ + // resource.insert(field_def.name.clone(), value); + // } else if field_def.required { + // return Err(LearnerError::ApiError(format!( + // "Required field '{}' not found in response", + // field_def.name + // ))); + // } else if let Some(default) = &field_def.default { + // resource.insert(field_def.name.clone(), default.clone()); + // } + // } + // } + + // Ok(resource) } } @@ -192,29 +55,3 @@ pub fn get_path_value<'a>( } Some(current) } - -/// Basic conversion for simple JSON values -fn convert_simple_value(value: &serde_json::Value) -> Option { - match value { - serde_json::Value::String(s) => Some(TomlValue::String(s.clone())), - serde_json::Value::Number(n) => - if n.is_i64() { - n.as_i64().map(TomlValue::Integer) - } else { - n.as_f64().map(TomlValue::Float) - }, - serde_json::Value::Bool(b) => Some(TomlValue::Boolean(*b)), - serde_json::Value::Array(arr) => { - let values: Vec<_> = arr.iter().filter_map(|item| convert_simple_value(item)).collect(); - Some(TomlValue::Array(values)) - }, - serde_json::Value::Object(obj) => { - let map = obj - .iter() - .filter_map(|(k, v)| convert_simple_value(v).map(|val| (k.clone(), val))) - .collect(); - Some(TomlValue::Table(map)) - }, - serde_json::Value::Null => None, - } -} diff --git a/crates/learner/src/retriever/response/xml.rs b/crates/learner/src/retriever/response/xml.rs index 66bcca4..4863836 100644 --- a/crates/learner/src/retriever/response/xml.rs +++ b/crates/learner/src/retriever/response/xml.rs @@ -19,8 +19,6 @@ //! 
``` use quick_xml::{events::Event, Reader}; -use resource::chrono_to_toml_datetime; -use toml::Value; use super::*; @@ -48,7 +46,7 @@ pub struct XmlConfig { #[serde(default)] pub strip_namespaces: bool, /// XML path mappings for paper metadata fields - pub field_maps: HashMap, + pub field_maps: BTreeMap, } impl ResponseProcessor for XmlConfig { @@ -58,72 +56,73 @@ impl ResponseProcessor for XmlConfig { // retriever_config: &RetrieverConfig, resource_config: &ResourceConfig, ) -> Result { - // Handle namespace stripping - let xml = if self.strip_namespaces { - strip_xml_namespaces(&String::from_utf8_lossy(data)) - } else { - String::from_utf8_lossy(data).to_string() - }; - - trace!("Processing XML response: {:#?}", &xml); - - // Extract raw XML content into path -> string mapping - let content = self.extract_content(&xml)?; - let mut resource = BTreeMap::new(); - - // Process each field according to the resource configuration - for field_def in &resource_config.fields { - // Look up the field mapping from retriever config - if let Some(field_map) = self.field_maps.get(&field_def.name) { - // Try to get the raw value using configured path - if let Some(raw_value) = content.get(&field_map.path) { - // Apply any configured transformations - let transformed_value = if let Some(transform) = &field_map.transform { - apply_transform(raw_value, transform)? 
- } else { - raw_value.clone() - }; - - // Convert string to appropriate TOML type based on field definition - let value = match field_def.field_type.as_str() { - "string" => Value::String(transformed_value), - "datetime" => { - let dt = DateTime::parse_from_rfc3339(&transformed_value).map_err(|e| { - LearnerError::ApiError(format!( - "Invalid date format for field '{}': {}", - field_def.name, e - )) - })?; - Value::Datetime(chrono_to_toml_datetime(dt.with_timezone(&Utc))) - }, - "array" => { - // For arrays, split on semicolon and create string array - let values = - transformed_value.split(';').map(|s| Value::String(s.trim().to_string())).collect(); - Value::Array(values) - }, - // Add other type conversions as needed - unsupported => - return Err(LearnerError::ApiError(format!( - "Unsupported field type '{}' for field '{}'", - unsupported, field_def.name - ))), - }; - resource.insert(field_def.name.clone(), value); - } else if field_def.required { - // Field was required but not found in response - return Err(LearnerError::ApiError(format!( - "Required field '{}' not found in response", - field_def.name - ))); - } else if let Some(default) = &field_def.default { - // Use default value if available - resource.insert(field_def.name.clone(), default.clone()); - } - } - } + todo!() + // // Handle namespace stripping + // let xml = if self.strip_namespaces { + // strip_xml_namespaces(&String::from_utf8_lossy(data)) + // } else { + // String::from_utf8_lossy(data).to_string() + // }; + + // trace!("Processing XML response: {:#?}", &xml); + + // // Extract raw XML content into path -> string mapping + // let content = self.extract_content(&xml)?; + // let mut resource = BTreeMap::new(); + + // // Process each field according to the resource configuration + // for field_def in &resource_config.fields { + // // Look up the field mapping from retriever config + // if let Some(field_map) = self.field_maps.get(&field_def.name) { + // // Try to get the raw value using 
configured path + // if let Some(raw_value) = content.get(&field_map.path) { + // // Apply any configured transformations + // let transformed_value = if let Some(transform) = &field_map.transform { + // apply_transform(raw_value, transform)? + // } else { + // raw_value.clone() + // }; + + // // Convert string to appropriate TOML type based on field definition + // let value = match field_def.field_type.as_str() { + // "string" => Value::String(transformed_value), + // "datetime" => { + // let dt = DateTime::parse_from_rfc3339(&transformed_value).map_err(|e| { + // LearnerError::ApiError(format!( + // "Invalid date format for field '{}': {}", + // field_def.name, e + // )) + // })?; + // Value::String(chrono_to_toml_datetime(dt.with_timezone(&Utc))) + // }, + // "array" => { + // // For arrays, split on semicolon and create string array + // let values = + // transformed_value.split(';').map(|s| + // Value::String(s.trim().to_string())).collect(); Value::Array(values) + // }, + // // Add other type conversions as needed + // unsupported => + // return Err(LearnerError::ApiError(format!( + // "Unsupported field type '{}' for field '{}'", + // unsupported, field_def.name + // ))), + // }; + // resource.insert(field_def.name.clone(), value); + // } else if field_def.required { + // // Field was required but not found in response + // return Err(LearnerError::ApiError(format!( + // "Required field '{}' not found in response", + // field_def.name + // ))); + // } else if let Some(default) = &field_def.default { + // // Use default value if available + // resource.insert(field_def.name.clone(), default.clone()); + // } + // } + // } - Ok(resource) + // Ok(resource) } } @@ -140,9 +139,53 @@ impl XmlConfig { /// # Returns /// /// Returns a HashMap mapping XML paths to their text content. 
- fn extract_content(&self, xml: &str) -> Result> { + fn extract_content(&self, xml: &str) -> Result> { + // let ser_xml: Vec<(String, Value)> = quick_xml::de::from_str(xml).unwrap(); + // quick_xml::de:: + // dbg!(ser_xml); + + let mut reader = Reader::from_str(xml); + + let mut map = BTreeMap::new(); + + let mut current_key = Vec::new(); + + while let Ok(event) = reader.read_event() { + match event { + Event::Start(ref e) => { + let tag = String::from_utf8_lossy(e.trim_ascii()).to_string(); + current_key.push(tag); + }, + Event::Text(e) => { + let value = e.unescape().unwrap_or_default().trim().to_string(); + if !value.is_empty() { + let key = current_key.join("."); + map + .entry(key) + .and_modify(|existing| { + if let Value::Array(arr) = existing { + arr.push(Value::String(value.clone())); + } else { + *existing = Value::Array(vec![existing.clone(), Value::String(value.clone())]); + } + }) + .or_insert(Value::String(value)); + } + }, + Event::End(_) => { + current_key.pop(); + }, + Event::Eof => break, + _ => (), + } + } + + dbg!(map); + + //////////////////////////////////////////////////// + let mut reader = Reader::from_str(xml); - let mut content = HashMap::new(); + let mut content = BTreeMap::new(); let mut path_stack = Vec::new(); let mut buf = Vec::new(); From 542c36f6fe86f4a81a8e5d5d008b2367e9cac6b3 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 18:00:03 -0700 Subject: [PATCH 27/73] fixed retrievers --- crates/learner/config/resources/paper.toml | 13 +- crates/learner/config/retrievers/arxiv.toml | 4 +- crates/learner/config/retrievers/doi.toml | 2 +- crates/learner/config/retrievers/iacr.toml | 2 +- crates/learner/src/resource/mod.rs | 6 +- crates/learner/src/retriever/config.rs | 2 +- crates/learner/src/retriever/mod.rs | 83 ------ crates/learner/src/retriever/response/json.rs | 48 +--- crates/learner/src/retriever/response/mod.rs | 233 +++++++++++++++++ crates/learner/src/retriever/response/xml.rs | 247 +++++++----------- 
.../tests/workflows/paper_retrieval.rs | 11 +- 11 files changed, 353 insertions(+), 298 deletions(-) diff --git a/crates/learner/config/resources/paper.toml b/crates/learner/config/resources/paper.toml index 24e7087..6bc0e05 100644 --- a/crates/learner/config/resources/paper.toml +++ b/crates/learner/config/resources/paper.toml @@ -34,13 +34,16 @@ field_type = "string" name = "abstract" required = false -# Publication date - datetime type [[fields]] -description = "When the paper was published or last updated" -field_type = "string" -name = "publication_date" +description = "Publication and update history" +field_type = "array" +name = "publication_dates" required = false -validation = { datetime = true } + +[fields.type_definition.element_type] +field_type = "string" +name = "date" +validation = { datetime = true } # DOI field - string with pattern validation [[fields]] diff --git a/crates/learner/config/retrievers/arxiv.toml b/crates/learner/config/retrievers/arxiv.toml index 4c72a86..05493cd 100644 --- a/crates/learner/config/retrievers/arxiv.toml +++ b/crates/learner/config/retrievers/arxiv.toml @@ -16,9 +16,9 @@ path = "feed/entry/title" path = "feed/entry/summary" [response_format.field_maps.authors] -path = "feed/entry/author/name" +path = "feed/entry/author" -[response_format.field_maps.publication_date] +[response_format.field_maps.publication_dates] path = "feed/entry/published" [response_format.field_maps.pdf_url] diff --git a/crates/learner/config/retrievers/doi.toml b/crates/learner/config/retrievers/doi.toml index 193ea3f..0cbf529 100644 --- a/crates/learner/config/retrievers/doi.toml +++ b/crates/learner/config/retrievers/doi.toml @@ -26,7 +26,7 @@ fields = ["given", "family"] inner_paths = [{ new_key_name = "affiliation", path = "affiliation/0/name" }] type = "CombineFields" -[response_format.field_maps.publication_date] +[response_format.field_maps.publication_dates] path = "message/created/date-time" [response_format.field_maps.pdf_url] diff --git 
a/crates/learner/config/retrievers/iacr.toml b/crates/learner/config/retrievers/iacr.toml index 9ceee72..832ed91 100644 --- a/crates/learner/config/retrievers/iacr.toml +++ b/crates/learner/config/retrievers/iacr.toml @@ -18,7 +18,7 @@ path = "OAI-PMH/GetRecord/record/metadata/dc/description" [response_format.field_maps.authors] path = "OAI-PMH/GetRecord/record/metadata/dc/creator" -[response_format.field_maps.publication_date] +[response_format.field_maps.publication_dates] path = "OAI-PMH/GetRecord/record/metadata/dc/date" [response_format.field_maps.pdf_url] diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index 125de37..fba51b6 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -97,10 +97,10 @@ pub struct ValidationRules { impl ResourceConfig { /// Validates a set of values against this resource configuration - pub fn validate(&self, values: &Resource) -> Result { + pub fn validate(&self, resource: &Resource) -> Result { // Check required fields for field in &self.fields { - if field.required && !values.contains_key(&field.name) { + if field.required && !resource.contains_key(&field.name) { return Err(LearnerError::InvalidResource(format!( "Missing required field: {}", field.name @@ -109,7 +109,7 @@ impl ResourceConfig { } // Validate each provided field - for (name, value) in values { + for (name, value) in resource { if let Some(field) = self.fields.iter().find(|f| f.name == *name) { // Validate field value against its definition self.validate_field(field, value)?; diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 7b1b3c8..2325f56 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -85,7 +85,7 @@ impl RetrieverConfig { pub async fn retrieve_resource( &self, input: &str, - resource_config: ResourceConfig, + resource_config: &ResourceConfig, ) -> Result { let identifier = 
self.extract_identifier(input)?; diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 77b191a..c54a279 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -74,7 +74,6 @@ mod config; mod response; pub use config::*; -use json::get_path_value; pub use response::*; /// Main entry point for paper retrieval operations. @@ -301,88 +300,6 @@ where D: serde::Deserializer<'de> { Regex::new(&s).map_err(serde::de::Error::custom) } -// / Applies a transformation to a string value based on the transform type. -// / -// / Handles three types of transformations: -// / - Regular expression replacements -// / - Date format conversions -// / - URL construction -// / -// / # Errors -// / -// / Returns a LearnerError if: -// / - Regex pattern is invalid -// / - Date parsing fails -// / - Date format is invalid - -// fn apply_transform(value: &str, transform: &Transform) -> Result { -// match transform { -// Transform::Replace { pattern, replacement } => Regex::new(pattern) -// .map_err(|e| LearnerError::ApiError(format!("Invalid regex: {}", e))) -// .map(|re| re.replace_all(value, replacement.as_str()).into_owned()), -// Transform::Date { from_format, to_format } => -// chrono::NaiveDateTime::parse_from_str(value, from_format) -// .map_err(|e| LearnerError::ApiError(format!("Invalid date: {}", e))) -// .map(|dt| dt.format(to_format).to_string()), -// Transform::Url { base, suffix } => -// Ok(format!("{}{}", base.replace("{value}", value), suffix.as_deref().unwrap_or(""))), -// Transform::CombineFields { fields, inner_paths } => { -// let json: serde_json::Value = serde_json::from_str(value) -// .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; - -// // Handle both single objects and arrays -// let result = if let Some(array) = json.as_array() { -// // Create array of objects with combined fields -// let combined: Vec> = array -// .iter() -// .filter_map(|obj| { -// let mut 
map = serde_json::Map::new(); -// dbg!(&obj); - -// // Handle the name fields combination -// let parts: Vec<_> = -// fields.iter().filter_map(|field| obj.get(field)).filter_map(|v| -// v.as_str()).collect(); - -// if !parts.is_empty() { -// map.insert("name".to_string(), serde_json::Value::String(parts.join(" "))); - -// // Handle any additional inner paths -// if let Some(paths) = inner_paths { -// dbg!(paths); -// for path in paths { -// if let Some(inner_val) = get_path_value(obj, &path.path) { -// dbg!(inner_val); -// if let Some(str_val) = inner_val.as_str() { -// // Use the last component of the path as the field name -// let field_name = path.new_key_name.clone(); -// map.insert( -// field_name.to_string(), -// serde_json::Value::String(str_val.to_string()), -// ); -// } -// } -// } -// } - -// Some(map) -// } else { -// None -// } -// }) -// .collect(); - -// serde_json::Value::Array(combined.into_iter().map(serde_json::Value::Object).collect()) -// } else { -// return Err(LearnerError::ApiError("Expected array for CombineFields".into())); -// }; - -// serde_json::to_string(&result) -// .map_err(|e| LearnerError::ApiError(format!("Failed to serialize result: {}", e))) -// }, -// } -// } - // #[cfg(test)] // mod tests { // use super::*; diff --git a/crates/learner/src/retriever/response/json.rs b/crates/learner/src/retriever/response/json.rs index d89a2dc..c9eb322 100644 --- a/crates/learner/src/retriever/response/json.rs +++ b/crates/learner/src/retriever/response/json.rs @@ -1,5 +1,4 @@ -use resource::{datetime_to_json, FieldDefinition, TypeDefinition}; -use serde_json::{self, Number}; +use serde_json::{self}; use super::*; @@ -11,47 +10,10 @@ pub struct JsonConfig { // TODO: Refactor this impl ResponseProcessor for JsonConfig { fn process_response(&self, data: &[u8], resource_config: &ResourceConfig) -> Result { - todo!() - // // Parse raw JSON data - // let json: serde_json::Value = serde_json::from_slice(data) - // .map_err(|e| 
LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; + // Parse raw JSON data + let json: serde_json::Value = serde_json::from_slice(data) + .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; - // trace!("Processing JSON response: {}", serde_json::to_string_pretty(&json).unwrap()); - - // let mut resource = BTreeMap::new(); - - // // Process each field according to resource configuration - // for field_def in &resource_config.fields { - // if let Some(field_map) = self.field_maps.get(&field_def.name) { - // // Extract raw value if present, now passing the full field definition - // if let Some(value) = self.extract_value(&json, field_map, field_def)? { - // resource.insert(field_def.name.clone(), value); - // } else if field_def.required { - // return Err(LearnerError::ApiError(format!( - // "Required field '{}' not found in response", - // field_def.name - // ))); - // } else if let Some(default) = &field_def.default { - // resource.insert(field_def.name.clone(), default.clone()); - // } - // } - // } - - // Ok(resource) - } -} - -pub fn get_path_value<'a>( - json: &'a serde_json::Value, - path: &str, -) -> Option<&'a serde_json::Value> { - let mut current = json; - for part in path.split('/') { - current = if let Ok(index) = part.parse::() { - current.as_array()?.get(index)? - } else { - current.get(part)? 
- }; + dbg!(process_json_value(dbg!(&json), &self.field_maps, resource_config)) } - Some(current) } diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index 331863a..d543290 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -1,3 +1,6 @@ +use resource::FieldDefinition; +use serde_json::Map; + use super::*; pub mod json; @@ -155,3 +158,233 @@ pub trait ResponseProcessor: Send + Sync { resource_config: &ResourceConfig, ) -> Result; } + +/// Process a JSON value according to field mappings and resource configuration +fn process_json_value( + json: &Value, + field_maps: &BTreeMap, + resource_config: &ResourceConfig, +) -> Result { + let mut resource = Resource::new(); + + for field_def in &resource_config.fields { + if let Some(field_map) = field_maps.get(&field_def.name) { + if let Some(value) = extract_mapped_value(json, field_map, field_def)? { + resource.insert(field_def.name.clone(), value); + } else if field_def.required { + return Err(LearnerError::ApiError(format!( + "Required field '{}' not found in response", + field_def.name + ))); + } else if let Some(default) = &field_def.default { + resource.insert(field_def.name.clone(), default.clone()); + } + } + } + + Ok(resource) +} + +/// Extract and transform a value from JSON using a field mapping +fn extract_mapped_value( + json: &Value, + field_map: &FieldMap, + field_def: &FieldDefinition, +) -> Result> { + let path_components: Vec<&str> = field_map.path.split('/').collect(); + + // Extract raw value using path + let raw_value = get_path_value(json, &path_components)?; + + // If no value found, return None + let Some(raw_value) = raw_value else { + return Ok(None); + }; + + // First apply any explicit transforms + let value = if let Some(transform) = &field_map.transform { + apply_transform(&raw_value, transform)? 
+ } else { + raw_value.clone() + }; + + // Then attempt type coercion based on field definition + let coerced = coerce_to_type(&value, field_def)?; + Ok(Some(coerced)) +} + +fn coerce_to_type(value: &Value, field_def: &FieldDefinition) -> Result { + match field_def.field_type.as_str() { + "array" => { + let arr = match value { + // Single value -> wrap in array + Value::String(_) | Value::Object(_) | Value::Number(_) => vec![value.clone()], + // Already an array + Value::Array(arr) => arr.clone(), + _ => return Ok(value.clone()), // Can't coerce, return as-is + }; + + // If we have inner type info, try to coerce each element + if let Some(ref type_def) = field_def.type_definition { + if let Some(ref element_def) = type_def.element_type { + let coerced: Vec = + arr.into_iter().map(|v| coerce_to_type(&v, element_def)).collect::>()?; + Ok(Value::Array(coerced)) + } else { + Ok(Value::Array(arr)) + } + } else { + Ok(Value::Array(arr)) + } + }, + "object" => { + // If we have field definitions, ensure object has required structure + if let Some(ref type_def) = field_def.type_definition { + if let Some(fields) = &type_def.fields { + let mut obj = Map::new(); + match value { + // Convert string to {name: string} if that's the structure we want + Value::String(s) if fields.len() == 1 && fields[0].name == "name" => { + obj.insert("name".to_string(), Value::String(s.clone())); + Ok(Value::Object(obj)) + }, + Value::Object(m) => { + // Copy over matching fields with coercion + for field in fields { + if let Some(v) = m.get(&field.name) { + obj.insert(field.name.clone(), coerce_to_type(v, field)?); + } + } + Ok(Value::Object(obj)) + }, + _ => Ok(value.clone()), + } + } else { + Ok(value.clone()) + } + } else { + Ok(value.clone()) + } + }, + // Add other type coercions as needed + _ => Ok(value.clone()), + } +} + +/// Get a value from JSON using a path +// Change return type to owned Value +fn get_path_value(json: &Value, path: &[&str]) -> Result> { + let mut current = 
json.clone(); + + for &component in path { + match current { + Value::Object(map) => + if let Some(value) = map.get(component) { + current = value.clone(); + } else { + return Ok(None); + }, + Value::Array(arr) => { + // If component is numeric, use it as array index + if let Ok(index) = component.parse::() { + if let Some(value) = arr.get(index) { + current = value.clone(); + } else { + return Ok(None); + } + } else { + // Otherwise collect matching values from array elements + let values: Vec = arr + .iter() + .filter_map(|item| match item { + Value::Object(map) => map.get(component).cloned(), + _ => None, + }) + .collect(); + + if values.is_empty() { + return Ok(None); + } else if values.len() == 1 { + current = values[0].clone(); + } else { + return Ok(Some(Value::Array(values))); + } + } + }, + _ => return Ok(None), + } + } + + Ok(Some(current)) +} + +/// Apply a transform to a JSON value +fn apply_transform(value: &Value, transform: &Transform) -> Result { + match transform { + Transform::Replace { pattern, replacement } => { + let text = value.as_str().ok_or_else(|| { + LearnerError::ApiError("Replace transform requires string input".to_string()) + })?; + let re = + Regex::new(pattern).map_err(|e| LearnerError::ApiError(format!("Invalid regex: {}", e)))?; + Ok(Value::String(re.replace_all(text, replacement.as_str()).into_owned())) + }, + + Transform::Date { from_format, to_format } => { + let text = value.as_str().ok_or_else(|| { + LearnerError::ApiError("Date transform requires string input".to_string()) + })?; + let dt = chrono::NaiveDateTime::parse_from_str(text, from_format) + .map_err(|e| LearnerError::ApiError(format!("Invalid date: {}", e)))?; + Ok(Value::String(dt.format(to_format).to_string())) + }, + + Transform::Url { base, suffix } => { + let text = value + .as_str() + .ok_or_else(|| LearnerError::ApiError("URL transform requires string input".to_string()))?; + Ok(Value::String(format!( + "{}{}", + base.replace("{value}", text), + 
suffix.as_deref().unwrap_or("") + ))) + }, + + Transform::CombineFields { fields, inner_paths } => { + let arr = value.as_array().ok_or_else(|| { + LearnerError::ApiError("CombineFields transform requires array input".to_string()) + })?; + + let combined: Vec = arr + .iter() + .filter_map(|item| { + let obj = item.as_object()?; + let mut result = Map::new(); + + // Combine name fields + let parts: Vec<_> = + fields.iter().filter_map(|field| obj.get(field)).filter_map(|v| v.as_str()).collect(); + + if !parts.is_empty() { + result.insert("name".to_string(), Value::String(parts.join(" "))); + + // Add any additional fields + if let Some(paths) = inner_paths { + for path in paths { + if let Ok(Some(inner_val)) = get_path_value(item, &[&path.path]) { + result.insert(path.new_key_name.clone(), inner_val.clone()); + } + } + } + + Some(Value::Object(result)) + } else { + None + } + }) + .collect(); + + Ok(Value::Array(combined)) + }, + } +} diff --git a/crates/learner/src/retriever/response/xml.rs b/crates/learner/src/retriever/response/xml.rs index 4863836..a12129a 100644 --- a/crates/learner/src/retriever/response/xml.rs +++ b/crates/learner/src/retriever/response/xml.rs @@ -50,168 +50,107 @@ pub struct XmlConfig { } impl ResponseProcessor for XmlConfig { - fn process_response( - &self, - data: &[u8], - // retriever_config: &RetrieverConfig, - resource_config: &ResourceConfig, - ) -> Result { - todo!() - // // Handle namespace stripping - // let xml = if self.strip_namespaces { - // strip_xml_namespaces(&String::from_utf8_lossy(data)) - // } else { - // String::from_utf8_lossy(data).to_string() - // }; - - // trace!("Processing XML response: {:#?}", &xml); - - // // Extract raw XML content into path -> string mapping - // let content = self.extract_content(&xml)?; - // let mut resource = BTreeMap::new(); - - // // Process each field according to the resource configuration - // for field_def in &resource_config.fields { - // // Look up the field mapping from retriever 
config - // if let Some(field_map) = self.field_maps.get(&field_def.name) { - // // Try to get the raw value using configured path - // if let Some(raw_value) = content.get(&field_map.path) { - // // Apply any configured transformations - // let transformed_value = if let Some(transform) = &field_map.transform { - // apply_transform(raw_value, transform)? - // } else { - // raw_value.clone() - // }; - - // // Convert string to appropriate TOML type based on field definition - // let value = match field_def.field_type.as_str() { - // "string" => Value::String(transformed_value), - // "datetime" => { - // let dt = DateTime::parse_from_rfc3339(&transformed_value).map_err(|e| { - // LearnerError::ApiError(format!( - // "Invalid date format for field '{}': {}", - // field_def.name, e - // )) - // })?; - // Value::String(chrono_to_toml_datetime(dt.with_timezone(&Utc))) - // }, - // "array" => { - // // For arrays, split on semicolon and create string array - // let values = - // transformed_value.split(';').map(|s| - // Value::String(s.trim().to_string())).collect(); Value::Array(values) - // }, - // // Add other type conversions as needed - // unsupported => - // return Err(LearnerError::ApiError(format!( - // "Unsupported field type '{}' for field '{}'", - // unsupported, field_def.name - // ))), - // }; - // resource.insert(field_def.name.clone(), value); - // } else if field_def.required { - // // Field was required but not found in response - // return Err(LearnerError::ApiError(format!( - // "Required field '{}' not found in response", - // field_def.name - // ))); - // } else if let Some(default) = &field_def.default { - // // Use default value if available - // resource.insert(field_def.name.clone(), default.clone()); - // } - // } - // } - - // Ok(resource) + fn process_response(&self, data: &[u8], resource_config: &ResourceConfig) -> Result { + // Handle namespace stripping + let xml = if self.strip_namespaces { + 
strip_xml_namespaces(&String::from_utf8_lossy(data)) + } else { + String::from_utf8_lossy(data).to_string() + }; + + trace!("Processing XML response: {:#?}", &xml); + + // Extract raw XML content into JSON equivalent + let json = convert_to_json(&xml); + dbg!(process_json_value(&json, &self.field_maps, resource_config)) } } -impl XmlConfig { - /// Extracts field values from XML content using path-based navigation. - /// - /// Builds a map of path -> value pairs by walking the XML tree and - /// tracking element paths. Handles nested elements and text content. - /// - /// # Arguments - /// - /// * `xml` - XML content as string - /// - /// # Returns - /// - /// Returns a HashMap mapping XML paths to their text content. - fn extract_content(&self, xml: &str) -> Result> { - // let ser_xml: Vec<(String, Value)> = quick_xml::de::from_str(xml).unwrap(); - // quick_xml::de:: - // dbg!(ser_xml); - - let mut reader = Reader::from_str(xml); - - let mut map = BTreeMap::new(); - - let mut current_key = Vec::new(); - - while let Ok(event) = reader.read_event() { - match event { - Event::Start(ref e) => { - let tag = String::from_utf8_lossy(e.trim_ascii()).to_string(); - current_key.push(tag); - }, - Event::Text(e) => { - let value = e.unescape().unwrap_or_default().trim().to_string(); - if !value.is_empty() { - let key = current_key.join("."); - map - .entry(key) - .and_modify(|existing| { - if let Value::Array(arr) = existing { - arr.push(Value::String(value.clone())); - } else { - *existing = Value::Array(vec![existing.clone(), Value::String(value.clone())]); - } - }) - .or_insert(Value::String(value)); - } - }, - Event::End(_) => { - current_key.pop(); - }, - Event::Eof => break, - _ => (), - } - } +use serde_json::{Map, Value}; + +pub fn convert_to_json(xml: &str) -> Value { + let mut reader = Reader::from_str(xml); + let mut stack = Vec::new(); + let mut current = Map::new(); + + while let Ok(event) = reader.read_event() { + match event { + Event::Start(ref e) => { + let 
tag = String::from_utf8_lossy(e.name().as_ref()).to_string(); - dbg!(map); - - //////////////////////////////////////////////////// - - let mut reader = Reader::from_str(xml); - let mut content = BTreeMap::new(); - let mut path_stack = Vec::new(); - let mut buf = Vec::new(); - - while let Ok(event) = reader.read_event_into(&mut buf) { - match event { - Event::Start(e) => { - path_stack.push(String::from_utf8_lossy(e.name().as_ref()).into_owned()); - }, - Event::Text(e) => - if let Ok(text) = e.unescape() { - let text = text.trim(); - if !text.is_empty() { - content.insert(path_stack.join("/"), text.to_string()); + // Create new object for this element + let mut new_obj = Map::new(); + + // Handle attributes + for attr in e.attributes().flatten() { + if let Ok(key) = String::from_utf8(attr.key.as_ref().to_vec()) { + if let Ok(value) = attr.unescape_value() { + new_obj.insert(format!("@{}", key), Value::String(value.into_owned())); } + } + } + + // Add this element to its parent + match current.get_mut(&tag) { + Some(Value::Array(_)) => { + // Element already exists as array, push onto it later + stack.push((tag, current, true)); }, - Event::End(_) => { - path_stack.pop(); - }, - Event::Eof => break, - _ => (), - } - buf.clear(); - } + Some(_) => { + // Element exists but not as array, convert to array + let existing = current.remove(&tag).unwrap(); + current.insert(tag.clone(), Value::Array(vec![existing])); + stack.push((tag, current, true)); + }, + None => { + // First occurrence of this element + stack.push((tag, current, false)); + }, + } + + current = new_obj; + }, + Event::Text(e) => { + if let Ok(txt) = e.unescape() { + let text = txt.trim(); + if !text.is_empty() { + if current.is_empty() { + // No attributes, just text content + current.insert("$text".to_string(), Value::String(text.to_string())); + } else { + // Has attributes, add text alongside them + current.insert("$text".to_string(), Value::String(text.to_string())); + } + } + } + }, + Event::End(_) 
=> { + if let Some((tag, mut parent, is_array)) = stack.pop() { + // Simplify if only text content + let value = if current.len() == 1 && current.contains_key("$text") { + current.remove("$text").unwrap() + } else { + Value::Object(current) + }; + + // Add to parent according to array status + if is_array { + if let Some(Value::Array(arr)) = parent.get_mut(&tag) { + arr.push(value); + } + } else { + parent.insert(tag, value); + } - Ok(content) + current = parent; + } + }, + Event::Eof => break, + _ => (), + } } + + dbg!(Value::Object(current)) } /// Removes XML namespace declarations and prefixes from content. diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index ccedcd3..c912ec6 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -20,9 +20,10 @@ async fn test_arxiv_retriever_integration() -> TestResult<()> { let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); // Test with a real arXiv paper - let paper = retriever.retrieve_resource("2301.07041", resource).await?; + let paper = retriever.retrieve_resource("2301.07041", &resource).await?; dbg!(&paper); + assert!(resource.validate(&paper)?); assert_eq!( paper.get("title").unwrap().as_str().unwrap(), @@ -77,8 +78,8 @@ async fn test_iacr_retriever_integration() -> TestResult<()> { let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); // // Test with a real IACR paper - let paper = retriever.retrieve_resource("2016/260", resource).await.unwrap(); - + let paper = retriever.retrieve_resource("2016/260", &resource).await.unwrap(); + assert!(resource.validate(&paper)?); dbg!(&paper); todo!("This likely needs cleaned up"); @@ -129,8 +130,8 @@ async fn test_doi_retriever_integration() -> TestResult<()> { let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); // 
Test with a real DOI paper - let paper = retriever.retrieve_resource("10.1145/1327452.1327492", resource).await?; - + let paper = retriever.retrieve_resource("10.1145/1327452.1327492", &resource).await?; + assert!(resource.validate(&paper)?); dbg!(&paper); // assert!(!paper.title.is_empty()); // assert!(!paper.authors.is_empty()); From d7f0fb6b45e18a320769719ff20f91c1f0e35f9f Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 18:07:18 -0700 Subject: [PATCH 28/73] update todo --- crates/learner/tests/workflows/paper_retrieval.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index c912ec6..6fc271f 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -82,7 +82,7 @@ async fn test_iacr_retriever_integration() -> TestResult<()> { assert!(resource.validate(&paper)?); dbg!(&paper); - todo!("This likely needs cleaned up"); + todo!("This isn't actually validating properly because right now the authors isn't right."); // assert!(!paper.title.is_empty()); // assert!(!paper.authors.is_empty()); // assert!(!paper.abstract_text.is_empty()); From 667e264bc8b6b1ccf3bb45b922af9caf489809da Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 19:14:55 -0700 Subject: [PATCH 29/73] working doi --- crates/learner/config/retrievers/doi.toml | 25 ++- crates/learner/config/retrievers/iacr.toml | 1 + crates/learner/src/retriever/response/mod.rs | 157 ++++++++++++++----- 3 files changed, 143 insertions(+), 40 deletions(-) diff --git a/crates/learner/config/retrievers/doi.toml b/crates/learner/config/retrievers/doi.toml index 0cbf529..aad2319 100644 --- a/crates/learner/config/retrievers/doi.toml +++ b/crates/learner/config/retrievers/doi.toml @@ -9,7 +9,18 @@ source = "doi" type = "json" [response_format.field_maps.title] -path = "message/title/0" +path = "message" 
+[response_format.field_maps.title.transform] +sources = [ + { type = "path", value = "title/0" }, + { type = "path", value = "subtitle/0" }, +] +type = "Compose" + +[response_format.field_maps.title.transform.format] +delimiter = ": " +type = "Join" + [response_format.field_maps.abstract] path = "message/abstract" @@ -22,9 +33,15 @@ type = "Replace" [response_format.field_maps.authors] path = "message/author" [response_format.field_maps.authors.transform] -fields = ["given", "family"] -inner_paths = [{ new_key_name = "affiliation", path = "affiliation/0/name" }] -type = "CombineFields" +sources = [ + { type = "key_value", value = { key = "family", path = "family" } }, + { type = "key_value", value = { key = "given", path = "given" } }, +] +type = "Compose" + +[response_format.field_maps.authors.transform.format] +template = { name = "{given} {family}" } +type = "ArrayOfObjects" [response_format.field_maps.publication_dates] path = "message/created/date-time" diff --git a/crates/learner/config/retrievers/iacr.toml b/crates/learner/config/retrievers/iacr.toml index 832ed91..64bd55e 100644 --- a/crates/learner/config/retrievers/iacr.toml +++ b/crates/learner/config/retrievers/iacr.toml @@ -18,6 +18,7 @@ path = "OAI-PMH/GetRecord/record/metadata/dc/description" [response_format.field_maps.authors] path = "OAI-PMH/GetRecord/record/metadata/dc/creator" + [response_format.field_maps.publication_dates] path = "OAI-PMH/GetRecord/record/metadata/dc/date" diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index d543290..8cffcea 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -103,17 +103,40 @@ pub enum Transform { /// Optional suffix to append to the URL (e.g., ".pdf") suffix: Option, }, - // New transform for combining fields - CombineFields { - fields: Vec, // Fields to combine for name - inner_paths: Option>, // Additional paths to collect + Compose { + 
/// List of field paths or direct values to combine + sources: Vec, + /// How to format the combined result + format: ComposeFormat, }, } -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct InnerPath { - pub new_key_name: String, - pub path: String, +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", content = "value")] +pub enum Source { + /// Path to a field to extract + #[serde(rename = "path")] + Path(String), + /// A literal string value + #[serde(rename = "literal")] + Literal(String), + /// A field mapping with a new key name + #[serde(rename = "key_value")] + KeyValue { key: String, path: String }, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type")] +pub enum ComposeFormat { + /// Join fields with a delimiter + Join { delimiter: String }, + /// Create an object with key-value pairs + Object, + /// Create an array of objects with specified structure + ArrayOfObjects { + /// How to structure each object + template: BTreeMap, + }, } /// Trait for processing API responses into Paper objects. 
@@ -350,41 +373,103 @@ fn apply_transform(value: &Value, transform: &Transform) -> Result { ))) }, - Transform::CombineFields { fields, inner_paths } => { - let arr = value.as_array().ok_or_else(|| { - LearnerError::ApiError("CombineFields transform requires array input".to_string()) - })?; - - let combined: Vec = arr + Transform::Compose { sources, format } => { + // Extract values from each source + let values: Vec = sources .iter() - .filter_map(|item| { - let obj = item.as_object()?; - let mut result = Map::new(); + .filter_map(|source| match source { + Source::Path(path) => { + let components: Vec<&str> = path.split('/').collect(); + get_path_value(value, &components).ok().flatten() + }, + Source::Literal(text) => Some(Value::String(text.clone())), + Source::KeyValue { key: _, path } => { + let components: Vec<&str> = path.split('/').collect(); + get_path_value(value, &components).ok().flatten() + }, + }) + .collect(); - // Combine name fields - let parts: Vec<_> = - fields.iter().filter_map(|field| obj.get(field)).filter_map(|v| v.as_str()).collect(); + // Apply the format to the collected values + match format { + ComposeFormat::Join { delimiter } => { + // Convert values to strings and join + let strings: Vec = values + .iter() + .filter_map(|v| match v { + Value::String(s) => Some(s.clone()), + Value::Array(arr) if arr.len() == 1 => arr[0].as_str().map(|s| s.to_string()), + _ => None, + }) + .collect(); + Ok(Value::String(strings.join(delimiter))) + }, - if !parts.is_empty() { - result.insert("name".to_string(), Value::String(parts.join(" "))); + ComposeFormat::Object => { + let mut obj = Map::new(); + for (source, value) in sources.iter().zip(values.iter()) { + if let Source::KeyValue { key, .. 
} = source { + obj.insert(key.clone(), value.clone()); + } + } + Ok(Value::Object(obj)) + }, - // Add any additional fields - if let Some(paths) = inner_paths { - for path in paths { - if let Ok(Some(inner_val)) = get_path_value(item, &[&path.path]) { - result.insert(path.new_key_name.clone(), inner_val.clone()); - } + ComposeFormat::ArrayOfObjects { template } => { + match value { + // Handle single string -> array of objects + Value::String(s) => { + let mut obj = Map::new(); + for (key, template_value) in template { + let value = template_value.replace("{value}", s); + obj.insert(key.clone(), Value::String(value)); } - } + Ok(Value::Array(vec![Value::Object(obj)])) + }, - Some(Value::Object(result)) - } else { - None - } - }) - .collect(); + // Handle array -> array of objects + Value::Array(arr) => { + dbg!(&arr); + let objects: Vec = arr + .iter() + .filter_map(|item| { + dbg!(&item); + let mut obj = Map::new(); + for (key, template_value) in template { + let value = match item { + Value::String(s) => template_value.replace("{value}", s), + Value::Object(obj) => { + dbg!(obj); + let mut keys_and_vals = Vec::new(); + sources.iter().for_each(|source| { + if let Source::KeyValue { key, path } = source { + if let Some(val) = obj.get(path) { + keys_and_vals.push((key, val)) + } + } + }); + dbg!(&key); + keys_and_vals.into_iter().fold(template_value.clone(), |acc, (k, v)| { + let replacement = format!("{{{k}}}"); + acc.replace(&replacement, v.as_str().unwrap_or_default()) + }) + }, + _ => return None, + }; + obj.insert(key.clone(), Value::String(value)); + } + Some(Value::Object(obj)) + }) + .collect(); + Ok(Value::Array(objects)) + }, - Ok(Value::Array(combined)) + _ => Err(LearnerError::ApiError( + "ArrayOfObjects transform requires string or array input".to_string(), + )), + } + }, + } }, } } From 0cdfaff386570f636826ae0ed2ab6eeea6bc7dda Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 1 Dec 2024 19:27:21 -0700 Subject: [PATCH 30/73] working iacr --- 
crates/learner/config/retrievers/iacr.toml | 8 ++++++-- crates/learner/src/retriever/response/mod.rs | 8 ++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/crates/learner/config/retrievers/iacr.toml b/crates/learner/config/retrievers/iacr.toml index 64bd55e..c90f1fa 100644 --- a/crates/learner/config/retrievers/iacr.toml +++ b/crates/learner/config/retrievers/iacr.toml @@ -16,8 +16,12 @@ path = "OAI-PMH/GetRecord/record/metadata/dc/title" path = "OAI-PMH/GetRecord/record/metadata/dc/description" [response_format.field_maps.authors] -path = "OAI-PMH/GetRecord/record/metadata/dc/creator" - +path = "OAI-PMH/GetRecord/record/metadata/dc" +[response_format.field_maps.authors.transform] +sources = [{ type = "key_value", value = { key = "name", path = "creator" } }] +type = "Compose" +[response_format.field_maps.authors.transform.format] +type = "Object" [response_format.field_maps.publication_dates] path = "OAI-PMH/GetRecord/record/metadata/dc/date" diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index 8cffcea..131f9e6 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -343,6 +343,7 @@ fn get_path_value(json: &Value, path: &[&str]) -> Result> { /// Apply a transform to a JSON value fn apply_transform(value: &Value, transform: &Transform) -> Result { + dbg!(&value); match transform { Transform::Replace { pattern, replacement } => { let text = value.as_str().ok_or_else(|| { @@ -390,6 +391,8 @@ fn apply_transform(value: &Value, transform: &Transform) -> Result { }) .collect(); + dbg!(&values); + // Apply the format to the collected values match format { ComposeFormat::Join { delimiter } => { @@ -406,12 +409,17 @@ fn apply_transform(value: &Value, transform: &Transform) -> Result { }, ComposeFormat::Object => { + dbg!("inside here"); let mut obj = Map::new(); + dbg!(&sources); for (source, value) in sources.iter().zip(values.iter()) { + 
dbg!(&source); if let Source::KeyValue { key, .. } = source { + dbg!(key); obj.insert(key.clone(), value.clone()); } } + dbg!(&obj); Ok(Value::Object(obj)) }, From 17d37be43d4a0931ce64daa94771ac60f8244db9 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Mon, 2 Dec 2024 06:33:28 -0700 Subject: [PATCH 31/73] WIP: get retrievaldata --- crates/learner/config/retrievers/arxiv.toml | 6 +- crates/learner/src/database/record.rs | 24 +- crates/learner/src/retriever/config.rs | 23 +- crates/learner/src/retriever/mod.rs | 511 +++++++----------- crates/learner/src/retriever/response/mod.rs | 7 +- .../tests/workflows/paper_retrieval.rs | 15 +- 6 files changed, 237 insertions(+), 349 deletions(-) diff --git a/crates/learner/config/retrievers/arxiv.toml b/crates/learner/config/retrievers/arxiv.toml index 05493cd..f9bacbe 100644 --- a/crates/learner/config/retrievers/arxiv.toml +++ b/crates/learner/config/retrievers/arxiv.toml @@ -9,6 +9,8 @@ source = "arxiv" strip_namespaces = true type = "xml" + +# TODO: Could flatten out the `field_maps`? 
[response_format.field_maps.title] path = "feed/entry/title" @@ -21,10 +23,10 @@ path = "feed/entry/author" [response_format.field_maps.publication_dates] path = "feed/entry/published" -[response_format.field_maps.pdf_url] +[retrieval_data.urls] path = "feed/entry/id" -[response_format.field_maps.pdf_url.transform] +[retrieval_data.urls.transform] pattern = "/abs/" replacement = "/pdf/" type = "Replace" diff --git a/crates/learner/src/database/record.rs b/crates/learner/src/database/record.rs index a39285d..44334fb 100644 --- a/crates/learner/src/database/record.rs +++ b/crates/learner/src/database/record.rs @@ -5,17 +5,19 @@ use super::*; /// A complete view of a resource with all associated data #[derive(Debug)] pub struct ResourceRecord { - pub resource: ResourceConfig, - pub state: ResourceState, - pub tags: Vec, - pub storage: Option, - pub retrieval: Option, + pub resource: Resource, + pub resource_config: ResourceConfig, + pub state: ResourceState, + pub tags: Vec, + pub storage: Option, + pub retrieval: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ReadStatus { - Unread, - Reading { +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub enum Progress { + #[default] + Unopened, + Opened { progress: f32, // last_read: DateTime, // Track when reading sessions occur // total_time: Duration, // Accumulate reading time @@ -26,9 +28,9 @@ pub enum ReadStatus { }, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct ResourceState { - pub read_status: ReadStatus, + pub read_status: Progress, pub starred: bool, pub rating: Option, pub last_accessed: Option>, diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 2325f56..5d48407 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,3 +1,4 @@ +use record::{ResourceRecord, ResourceState, RetrievalData}; use resource::Resource; use 
super::*; @@ -31,9 +32,9 @@ use super::*; #[derive(Debug, Clone, Deserialize)] pub struct RetrieverConfig { /// Name of this retriever configuration - pub name: String, + pub name: String, + // TODO (autoparallel): Ultimately this will have to peer into the `Resources` to be useful - /// The type of resource this retriever should yield pub resource: String, /// Base URL for API requests pub base_url: String, @@ -44,11 +45,13 @@ pub struct RetrieverConfig { pub source: String, /// Template for constructing API endpoint URLs pub endpoint_template: String, - /// Format and parsing configuration for API responses + // TODO: This is now more like "how to get the thing to map into the resource" pub response_format: ResponseFormat, /// Optional HTTP headers for API requests #[serde(default)] pub headers: BTreeMap, + + pub retrieval_data: BTreeMap>, } impl Identifiable for RetrieverConfig { @@ -86,7 +89,7 @@ impl RetrieverConfig { &self, input: &str, resource_config: &ResourceConfig, - ) -> Result { + ) -> Result { let identifier = self.extract_identifier(input)?; // Send request and get response @@ -112,7 +115,7 @@ impl RetrieverConfig { }; // Process response and get resource - let mut resource = processor.process_response(&data, &resource_config)?; + let mut resource = processor.process_response(&data, resource_config)?; // Add source metadata resource.insert("source".into(), Value::String(self.source.clone())); @@ -120,7 +123,13 @@ impl RetrieverConfig { // Validate full resource against config resource_config.validate(&resource)?; - - Ok(resource) + Ok(ResourceRecord { + resource, + resource_config: resource_config.clone(), + retrieval: None, + state: ResourceState::default(), + storage: None, + tags: Vec::new(), + }) } } diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index c54a279..6a5b130 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -1,73 +1,3 @@ -//! 
Paper retrieval and metadata extraction framework. -//! -//! This module provides a flexible, configuration-driven system for retrieving academic papers -//! and their metadata from various sources. It supports both XML and JSON-based APIs through -//! a common interface, with configurable field mapping and transformation capabilities. -//! -//! # Architecture -//! -//! The retriever system consists of several key components: -//! -//! - [`Retrievers`]: Main entry point for paper retrieval operations -//! - [`RetrieverConfig`]: Configuration for specific paper sources -//! - [`ResponseFormat`]: Format-specific parsing logic (XML/JSON) -//! - [`ResponseProcessor`]: Trait for processing API responses -//! -//! # Features -//! -//! - Configuration-driven paper retrieval -//! - Support for multiple paper sources -//! - Flexible field mapping -//! - Custom field transformations -//! - Automatic source detection -//! -//! # Configuration -//! -//! Retrievers are configured using TOML files that specify: -//! -//! - API endpoints and authentication -//! - Field mapping rules -//! - Response format handling -//! - Identifier patterns -//! -//! # Examples -//! -//! Configure and use a retriever: -//! -//! ```no_run -//! use learner::{ -//! prelude::*, -//! retriever::{RetrieverConfig, Retrievers}, -//! }; -//! -//! # async fn example() -> Result<(), Box> { -//! // Create a new retriever -//! let retriever = -//! Retrievers::new().with_config_file("config/arxiv.toml")?.with_config_file("config/doi.toml")?; -//! -//! // Retrieve a paper (automatically detects source) -//! let paper = retriever.get_paper("10.1145/1327452.1327492").await?; -//! println!("Retrieved paper: {}", paper.title); -//! # Ok(()) -//! # } -//! ``` -//! -//! Load multiple configurations: -//! -//! ```no_run -//! # use learner::retriever::Retrievers; -//! # use learner::prelude::*; -//! # async fn example() -> Result<(), Box> { -//! // Load all TOML configs from a directory -//! 
let retriever = Retrievers::new().with_config_dir("config/")?; -//! -//! // Retriever will automatically match source based on input format -//! let arxiv_paper = retriever.get_paper("2301.07041").await?; -//! let doi_paper = retriever.get_paper("10.1145/1327452.1327492").await?; -//! # Ok(()) -//! # } -//! ``` - use super::*; mod config; @@ -76,28 +6,6 @@ mod response; pub use config::*; pub use response::*; -/// Main entry point for paper retrieval operations. -/// -/// The `Retriever` struct manages a collection of paper source configurations and -/// provides a unified interface for retrieving papers from any configured source. -/// It automatically detects the appropriate source based on the input identifier -/// format. -/// -/// # Examples -/// -/// ```no_run -/// # use learner::retriever::Retrievers; -/// # use learner::prelude::*; -/// # async fn example() -> Result<(), Box> { -/// let retriever = Retrievers::new().with_config_dir("config/")?; -/// -/// // Retrieve papers from different sources -/// let paper1 = retriever.get_paper("2301.07041").await?; // arXiv -/// let paper2 = retriever.get_paper("2023/123").await?; // IACR -/// let paper3 = retriever.get_paper("10.1145/1327452.1327492").await?; // DOI -/// # Ok(()) -/// # } -/// ``` #[derive(Default, Debug, Clone)] pub struct Retrievers { /// The collection of configurations used for this [`Retriever`]. @@ -146,42 +54,7 @@ impl Retrievers { /// ``` pub fn new() -> Self { Self::default() } - /// Attempts to retrieve a paper using any matching configuration. - /// - /// This method tries to match the input against all configured retrievers - /// and uses the first matching configuration to fetch the paper. 
- /// - /// # Arguments - /// - /// * `input` - Paper identifier or URL - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - The retrieved Paper object - /// - A LearnerError if no matching configuration is found or retrieval fails - /// - /// # Errors - /// - /// This method will return an error if: - /// - No configuration matches the input format - /// - Multiple configurations match ambiguously - /// - Paper retrieval fails - /// - /// # Examples - /// - /// ```no_run - /// # use learner::retriever::Retrievers; - /// # use learner::prelude::*; - /// # async fn example() -> Result<(), Box> { - /// let retriever = Retrievers::new().with_config_dir("config/")?; - /// - /// // Retrieve from different sources - /// let paper1 = retriever.get_paper("2301.07041").await?; - /// let paper2 = retriever.get_paper("10.1145/1327452.1327492").await?; - /// # Ok(()) - /// # } - /// ``` + #[deprecated] pub async fn get_paper(&self, input: &str) -> Result { let mut matches = Vec::new(); @@ -202,6 +75,13 @@ impl Retrievers { // } } + pub async fn get_resource(&self, input: &str) -> Result { + todo!( + "Arguably, we don't even need this. We could instead just have this handled by `Learner` so \ + the API is simpler" + ) + } + /// Sanitizes and normalizes a paper identifier using configured retrieval patterns. /// /// This function processes an input string (which could be a URL, DOI, arXiv ID, etc.) 
@@ -300,191 +180,190 @@ where D: serde::Deserializer<'de> { Regex::new(&s).map_err(serde::de::Error::custom) } -// #[cfg(test)] -// mod tests { -// use super::*; - -// #[test] -// fn validate_arxiv_config() { -// let config_str = include_str!("../../config/retrievers/arxiv.toml"); - -// let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); - -// // Verify basic fields -// assert_eq!(retriever.name, "arxiv"); -// assert_eq!(retriever.base_url, "http://export.arxiv.org"); -// assert_eq!(retriever.source, "arxiv"); - -// // Test pattern matching -// assert!(retriever.pattern.is_match("2301.07041")); -// assert!(retriever.pattern.is_match("math.AG/0601001")); -// assert!(retriever.pattern.is_match("https://arxiv.org/abs/2301.07041")); -// assert!(retriever.pattern.is_match("https://arxiv.org/pdf/2301.07041")); -// assert!(retriever.pattern.is_match("https://arxiv.org/abs/math.AG/0601001")); -// assert!(retriever.pattern.is_match("https://arxiv.org/abs/math/0404443")); - -// // Test identifier extraction -// assert_eq!(retriever.extract_identifier("2301.07041").unwrap(), "2301.07041"); -// assert_eq!( -// retriever.extract_identifier("https://arxiv.org/abs/2301.07041").unwrap(), -// "2301.07041" -// ); -// assert_eq!(retriever.extract_identifier("math.AG/0601001").unwrap(), "math.AG/0601001"); - -// // Verify response format - -// if let ResponseFormat::Xml(config) = &retriever.response_format { -// assert!(config.strip_namespaces); - -// // Verify field mappings -// let field_maps = &config.field_maps; -// assert!(field_maps.contains_key("title")); -// assert!(field_maps.contains_key("abstract")); -// assert!(field_maps.contains_key("authors")); -// assert!(field_maps.contains_key("publication_date")); -// assert!(field_maps.contains_key("pdf_url")); - -// // Verify PDF transform -// if let Some(map) = field_maps.get("pdf_url") { -// match &map.transform { -// Some(Transform::Replace { pattern, replacement }) => { -// 
assert_eq!(pattern, "/abs/"); -// assert_eq!(replacement, "/pdf/"); -// }, -// _ => panic!("Expected Replace transform for pdf_url"), -// } -// } else { -// panic!("Missing pdf_url field map"); -// } -// } else { -// panic!("Expected an XML configuration, but did not get one.") -// } - -// // Verify headers -// assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); -// } - -// // TODO: Fix this -// #[test] -// fn test_doi_config_deserialization() { -// let config_str = include_str!("../../config/retrievers/doi.toml"); - -// let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); - -// dbg!(&retriever); - -// // Verify basic fields -// assert_eq!(retriever.name, "doi"); -// assert_eq!(retriever.base_url, "https://api.crossref.org/works"); -// assert_eq!(retriever.source, "doi"); - -// // Test pattern matching -// let test_cases = [ -// ("10.1145/1327452.1327492", true), -// ("https://doi.org/10.1145/1327452.1327492", true), -// ("invalid-doi", false), -// ("https://wrong.url/10.1145/1327452.1327492", false), -// ]; - -// for (input, expected) in test_cases { -// assert_eq!( -// retriever.pattern.is_match(input), -// expected, -// "Pattern match failed for input: {}", -// input -// ); -// } - -// // Test identifier extraction -// assert_eq!( -// retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), -// "10.1145/1327452.1327492" -// ); -// assert_eq!( -// retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), -// "10.1145/1327452.1327492" -// ); - -// // Verify response format -// match &retriever.response_format { -// ResponseFormat::Json(config) => { -// // Verify field mappings -// let field_maps = &config.field_maps; -// assert!(field_maps.contains_key("title")); -// assert!(field_maps.contains_key("abstract")); -// assert!(field_maps.contains_key("authors")); -// assert!(field_maps.contains_key("publication_date")); -// assert!(field_maps.contains_key("pdf_url")); -// 
assert!(field_maps.contains_key("doi")); -// }, -// _ => panic!("Expected JSON response format"), -// } -// } - -// #[test] -// fn test_iacr_config_deserialization() { -// let config_str = include_str!("../../config/retrievers/iacr.toml"); - -// let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); - -// // Verify basic fields -// assert_eq!(retriever.name, "iacr"); -// assert_eq!(retriever.base_url, "https://eprint.iacr.org"); -// assert_eq!(retriever.source, "iacr"); - -// // Test pattern matching -// let test_cases = [ -// ("2016/260", true), -// ("2023/123", true), -// ("https://eprint.iacr.org/2016/260", true), -// ("https://eprint.iacr.org/2016/260.pdf", true), -// ("invalid/format", false), -// ("https://wrong.url/2016/260", false), -// ]; - -// for (input, expected) in test_cases { -// assert_eq!( -// retriever.pattern.is_match(input), -// expected, -// "Pattern match failed for input: {}", -// input -// ); -// } - -// // Test identifier extraction -// assert_eq!(retriever.extract_identifier("2016/260").unwrap(), "2016/260"); -// assert_eq!( -// retriever.extract_identifier("https://eprint.iacr.org/2016/260").unwrap(), -// "2016/260" -// ); -// assert_eq!( -// retriever.extract_identifier("https://eprint.iacr.org/2016/260.pdf").unwrap(), -// "2016/260" -// ); - -// // Verify response format -// if let ResponseFormat::Xml(config) = &retriever.response_format { -// assert!(config.strip_namespaces); - -// // Verify field mappings -// let field_maps = &config.field_maps; -// assert!(field_maps.contains_key("title")); -// assert!(field_maps.contains_key("abstract")); -// assert!(field_maps.contains_key("authors")); -// assert!(field_maps.contains_key("publication_date")); -// assert!(field_maps.contains_key("pdf_url")); - -// // Verify OAI-PMH paths -// if let Some(map) = field_maps.get("title") { -// assert!(map.path.contains(&"OAI-PMH/GetRecord/record/metadata/dc/title".to_string())); -// } else { -// panic!("Missing 
title field map"); -// } -// } else { -// panic!("Expected an XML configuration, but did not get one.") -// } - -// // Verify headers -// assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); -// } -// } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_arxiv_config() { + let config_str = include_str!("../../config/retrievers/arxiv.toml"); + + let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + + // Verify basic fields + assert_eq!(retriever.name, "arxiv"); + assert_eq!(retriever.base_url, "http://export.arxiv.org"); + assert_eq!(retriever.source, "arxiv"); + + // Test pattern matching + assert!(retriever.pattern.is_match("2301.07041")); + assert!(retriever.pattern.is_match("math.AG/0601001")); + assert!(retriever.pattern.is_match("https://arxiv.org/abs/2301.07041")); + assert!(retriever.pattern.is_match("https://arxiv.org/pdf/2301.07041")); + assert!(retriever.pattern.is_match("https://arxiv.org/abs/math.AG/0601001")); + assert!(retriever.pattern.is_match("https://arxiv.org/abs/math/0404443")); + + // Test identifier extraction + assert_eq!(retriever.extract_identifier("2301.07041").unwrap(), "2301.07041"); + assert_eq!( + retriever.extract_identifier("https://arxiv.org/abs/2301.07041").unwrap(), + "2301.07041" + ); + assert_eq!(retriever.extract_identifier("math.AG/0601001").unwrap(), "math.AG/0601001"); + + // Verify response format + + if let ResponseFormat::Xml(config) = &retriever.response_format { + assert!(config.strip_namespaces); + + // Verify field mappings + let field_maps = &config.field_maps; + assert!(field_maps.contains_key("title")); + assert!(field_maps.contains_key("abstract")); + assert!(field_maps.contains_key("authors")); + assert!(field_maps.contains_key("publication_date")); + assert!(field_maps.contains_key("pdf_url")); + + // Verify PDF transform + if let Some(map) = field_maps.get("pdf_url") { + match &map.transform { + Some(Transform::Replace { pattern, 
replacement }) => { + assert_eq!(pattern, "/abs/"); + assert_eq!(replacement, "/pdf/"); + }, + _ => panic!("Expected Replace transform for pdf_url"), + } + } else { + panic!("Missing pdf_url field map"); + } + } else { + panic!("Expected an XML configuration, but did not get one.") + } + + // Verify headers + assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); + } + + #[test] + fn test_doi_config_deserialization() { + let config_str = include_str!("../../config/retrievers/doi.toml"); + + let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + + dbg!(&retriever); + + // Verify basic fields + assert_eq!(retriever.name, "doi"); + assert_eq!(retriever.base_url, "https://api.crossref.org/works"); + assert_eq!(retriever.source, "doi"); + + // Test pattern matching + let test_cases = [ + ("10.1145/1327452.1327492", true), + ("https://doi.org/10.1145/1327452.1327492", true), + ("invalid-doi", false), + ("https://wrong.url/10.1145/1327452.1327492", false), + ]; + + for (input, expected) in test_cases { + assert_eq!( + retriever.pattern.is_match(input), + expected, + "Pattern match failed for input: {}", + input + ); + } + + // Test identifier extraction + assert_eq!( + retriever.extract_identifier("10.1145/1327452.1327492").unwrap(), + "10.1145/1327452.1327492" + ); + assert_eq!( + retriever.extract_identifier("https://doi.org/10.1145/1327452.1327492").unwrap(), + "10.1145/1327452.1327492" + ); + + // Verify response format + match &retriever.response_format { + ResponseFormat::Json(config) => { + // Verify field mappings + let field_maps = &config.field_maps; + assert!(field_maps.contains_key("title")); + assert!(field_maps.contains_key("abstract")); + assert!(field_maps.contains_key("authors")); + assert!(field_maps.contains_key("publication_date")); + assert!(field_maps.contains_key("pdf_url")); + assert!(field_maps.contains_key("doi")); + }, + _ => panic!("Expected JSON response format"), + } + } + + #[test] + 
fn test_iacr_config_deserialization() { + let config_str = include_str!("../../config/retrievers/iacr.toml"); + + let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + + // Verify basic fields + assert_eq!(retriever.name, "iacr"); + assert_eq!(retriever.base_url, "https://eprint.iacr.org"); + assert_eq!(retriever.source, "iacr"); + + // Test pattern matching + let test_cases = [ + ("2016/260", true), + ("2023/123", true), + ("https://eprint.iacr.org/2016/260", true), + ("https://eprint.iacr.org/2016/260.pdf", true), + ("invalid/format", false), + ("https://wrong.url/2016/260", false), + ]; + + for (input, expected) in test_cases { + assert_eq!( + retriever.pattern.is_match(input), + expected, + "Pattern match failed for input: {}", + input + ); + } + + // Test identifier extraction + assert_eq!(retriever.extract_identifier("2016/260").unwrap(), "2016/260"); + assert_eq!( + retriever.extract_identifier("https://eprint.iacr.org/2016/260").unwrap(), + "2016/260" + ); + assert_eq!( + retriever.extract_identifier("https://eprint.iacr.org/2016/260.pdf").unwrap(), + "2016/260" + ); + + // Verify response format + if let ResponseFormat::Xml(config) = &retriever.response_format { + assert!(config.strip_namespaces); + + // Verify field mappings + let field_maps = &config.field_maps; + assert!(field_maps.contains_key("title")); + assert!(field_maps.contains_key("abstract")); + assert!(field_maps.contains_key("authors")); + assert!(field_maps.contains_key("publication_date")); + assert!(field_maps.contains_key("pdf_url")); + + // Verify OAI-PMH paths + if let Some(map) = field_maps.get("title") { + assert!(map.path.contains(&"OAI-PMH/GetRecord/record/metadata/dc/title".to_string())); + } else { + panic!("Missing title field map"); + } + } else { + panic!("Expected an XML configuration, but did not get one.") + } + + // Verify headers + assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); + } +} diff --git 
a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index 131f9e6..b6f34df 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -174,12 +174,7 @@ pub trait ResponseProcessor: Send + Sync { /// Returns a Result containing either: /// - A fully populated Paper object /// - A LearnerError if parsing fails - fn process_response( - &self, - data: &[u8], - // retriever_config: RetrieverConfig, - resource_config: &ResourceConfig, - ) -> Result; + fn process_response(&self, data: &[u8], resource_config: &ResourceConfig) -> Result; } /// Process a JSON value according to field mappings and resource configuration diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index 6fc271f..b24f3bc 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -23,12 +23,12 @@ async fn test_arxiv_retriever_integration() -> TestResult<()> { let paper = retriever.retrieve_resource("2301.07041", &resource).await?; dbg!(&paper); - assert!(resource.validate(&paper)?); + // assert!(resource.validate(&paper)?); - assert_eq!( - paper.get("title").unwrap().as_str().unwrap(), - "Verifiable Fully Homomorphic Encryption" - ); + // assert_eq!( + // paper.get("title").unwrap().as_str().unwrap(), + // "Verifiable Fully Homomorphic Encryption" + // ); todo!("This needs cleaned up."); // assert!(!paper.title.is_empty()); // assert!(!paper.authors.is_empty()); @@ -79,7 +79,8 @@ async fn test_iacr_retriever_integration() -> TestResult<()> { // // Test with a real IACR paper let paper = retriever.retrieve_resource("2016/260", &resource).await.unwrap(); - assert!(resource.validate(&paper)?); + // assert!(resource.validate(&paper)?); // TODO: validation already happens internally, to be fair + // that validation may not be working totally right dbg!(&paper); todo!("This isn't actually 
validating properly because right now the authors isn't right."); @@ -131,7 +132,7 @@ async fn test_doi_retriever_integration() -> TestResult<()> { // Test with a real DOI paper let paper = retriever.retrieve_resource("10.1145/1327452.1327492", &resource).await?; - assert!(resource.validate(&paper)?); + // assert!(resource.validate(&paper)?); dbg!(&paper); // assert!(!paper.title.is_empty()); // assert!(!paper.authors.is_empty()); From 7807e05fefc05b336b13321bc360464440964d53 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Wed, 4 Dec 2024 06:10:06 -0700 Subject: [PATCH 32/73] WIP: working on configs --- Cargo.lock | 233 +++++++++++++++++- Cargo.toml | 2 +- crates/learner/Cargo.toml | 1 + crates/learner/config_new/base_resource.toml | 12 + crates/learner/config_new/paper.toml | 12 + crates/learner/config_new/paper_record.toml | 25 ++ crates/learner/src/configuration.rs | 190 ++++++++++++++ crates/learner/src/retriever/config.rs | 5 +- crates/learner/src/retriever/response/json.rs | 2 +- crates/learner/src/retriever/response/mod.rs | 10 +- crates/learner/src/retriever/response/xml.rs | 2 +- 11 files changed, 482 insertions(+), 12 deletions(-) create mode 100644 crates/learner/config_new/base_resource.toml create mode 100644 crates/learner/config_new/paper.toml create mode 100644 crates/learner/config_new/paper_record.toml diff --git a/Cargo.lock b/Cargo.lock index a92d180..f898136 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,12 @@ version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +[[package]] +name = "arraydeque" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" + [[package]] name = "assert_cmd" version = "2.0.16" @@ -184,6 +190,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base64" +version = "0.21.7" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -195,6 +207,9 @@ name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +dependencies = [ + "serde", +] [[package]] name = "block-buffer" @@ -354,6 +369,25 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "config" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68578f196d2a33ff61b27fae256c3164f65e36382648e30666dde05b8cc9dfdf" +dependencies = [ + "async-trait", + "convert_case", + "json5", + "nom", + "pathdiff", + "ron", + "rust-ini", + "serde", + "serde_json", + "toml", + "yaml-rust2", +] + [[package]] name = "console" version = "0.15.8" @@ -367,12 +401,50 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom", + "once_cell", + "tiny-keccak", +] + +[[package]] +name = "convert_case" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec182b0ca2f35d8fc196cf3404988fd8b8c739a4d270ff118a398feb0cbec1ca" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = 
"cpufeatures" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -441,6 +513,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -562,6 +640,15 @@ dependencies = [ "syn", ] +[[package]] +name = "dlv-list" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "442039f5147480ba31067cb00ada1adae6892028e40e45fc5de7b7df6dcc1b5f" +dependencies = [ + "const-random", +] + [[package]] name = "doc-comment" version = "0.3.3" @@ -794,6 +881,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash", + "allocator-api2", ] [[package]] @@ -807,6 +895,15 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hashlink" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "hashlink" version = "0.9.1" @@ -1158,6 +1255,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "json5" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b0db21af676c1ce64250b5f40f3ce2cf27e4e47cb91ed91eb6fe9350b430c1" +dependencies = [ + "pest", + "pest_derive", + "serde", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1171,6 +1279,7 @@ dependencies = [ "anyhow", "async-trait", "chrono", + "config", "dirs", "futures", "lazy_static", @@ -1452,6 +1561,16 @@ version = "0.2.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-multimap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49203cdcae0030493bad186b28da2fa25645fa276a51b6fec8010d281e02ef79" +dependencies = [ + "dlv-list", + "hashbrown 0.14.5", +] + [[package]] name = "overload" version = "0.1.1" @@ -1487,12 +1606,63 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pathdiff" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" + [[package]] name = "percent-encoding" version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pest" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879952a81a83930934cbf1786752d6dedc3b1f29e8f8fb2ad1d0a36f377cf442" +dependencies = [ + "memchr", + "thiserror 1.0.69", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d214365f632b123a47fd913301e14c946c61d1c183ee245fa76eb752e59a02dd" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb55586734301717aea2ac313f50b2eb8f60d2fc3dc01d190eefa2e625f60c4e" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pest_meta" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b75da2a70cf4d9cb76833c990ac9cd3923c9a8905a8929789ce347c84564d03d" +dependencies = [ + "once_cell", + "pest", + "sha2", +] + [[package]] name = "pin-project-lite" version = "0.2.15" @@ -1793,7 +1963,7 @@ version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "futures-core", "futures-util", @@ -1844,6 +2014,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ron" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" +dependencies = [ + "base64 0.21.7", + "bitflags", + "serde", + "serde_derive", +] + [[package]] name = "rusqlite" version = "0.32.1" @@ -1854,11 +2036,21 @@ dependencies = [ "chrono", "fallible-iterator", "fallible-streaming-iterator", - "hashlink", + "hashlink 0.9.1", "libsqlite3-sys", "smallvec", ] +[[package]] +name = "rust-ini" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e0698206bcb8882bf2a9ecb4c1e7785db57ff052297085a6efd4fe42302068a" +dependencies = [ + "cfg-if", + "ordered-multimap", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -2038,6 +2230,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2297,6 +2500,15 @@ dependencies = [ "time-core", ] +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + [[package]] name = "tinystr" 
version = "0.7.6" @@ -2542,6 +2754,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "unicode-ident" version = "1.0.14" @@ -2986,6 +3204,17 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +[[package]] +name = "yaml-rust2" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8902160c4e6f2fb145dbe9d6760a75e3c9522d8bf796ed7047c85919ac7115f8" +dependencies = [ + "arraydeque", + "encoding_rs", + "hashlink 0.8.4", +] + [[package]] name = "yansi" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 5a58e04..3dc2c55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,9 +28,9 @@ tokio = { version = "1.41", features = [ ] } tracing = { version = "0.1" } - # learner dependencies async-trait = { version = "0.1" } +config = { version = "0.14.1" } dirs = { version = "5.0" } futures = { version = "0.3.31" } lazy_static = { version = "1.5" } diff --git a/crates/learner/Cargo.toml b/crates/learner/Cargo.toml index f841e76..7a03874 100644 --- a/crates/learner/Cargo.toml +++ b/crates/learner/Cargo.toml @@ -13,6 +13,7 @@ version = "0.9.1" [dependencies] async-trait = { workspace = true } chrono = { workspace = true } +config = { workspace = true } dirs = { workspace = true } futures = { workspace = true } lazy_static = { workspace = true } diff --git a/crates/learner/config_new/base_resource.toml b/crates/learner/config_new/base_resource.toml new file mode 100644 index 0000000..5222b4c --- /dev/null +++ b/crates/learner/config_new/base_resource.toml @@ -0,0 +1,12 @@ +description = "Base 
configuration for all academic resources" +name = "base_resource" + +[item] +required_fields = ["title", "authors"] +resource_type = "base" + +# Define complete field definitions as inline tables +abstract_text = { name = "abstract", field_type = "string", required = false } +authors = { name = "authors", field_type = "array", required = true, validation = { min_items = 1 } } +publication_date = { name = "publication_date", field_type = "string", required = true, validation = { datetime = true } } +title = { name = "title", field_type = "string", required = true, validation = { min_length = 1, max_length = 500 } } diff --git a/crates/learner/config_new/paper.toml b/crates/learner/config_new/paper.toml new file mode 100644 index 0000000..565bba1 --- /dev/null +++ b/crates/learner/config_new/paper.toml @@ -0,0 +1,12 @@ +# paper.toml +description = "Academic paper configuration" +extends = ["base_resource"] +name = "paper" + +[item] +required_fields = ["title", "authors", "publication_date", "abstract_text"] +resource_type = "paper" + +[item.abstract_text] +field_type = "string" +required = true diff --git a/crates/learner/config_new/paper_record.toml b/crates/learner/config_new/paper_record.toml new file mode 100644 index 0000000..1e5096c --- /dev/null +++ b/crates/learner/config_new/paper_record.toml @@ -0,0 +1,25 @@ +# paper_record.toml +description = "Record configuration for academic papers" +extends = ["paper"] +name = "paper_record" + +[item.resource] +required_fields = ["title", "authors", "publication_date", "abstract_text"] +resource_type = "paper" + +[item.state_tracking] +allow_notes = true +progress_tracking = true +rating_system = 5 +track_access_time = true + +[item.storage] +required_files = ["pdf"] +track_checksums = true +track_file_history = true + +[item.retrieval] +access_types = ["open", "subscription"] +track_urls = true +url_types = ["pdf", "html"] +verify_access = true diff --git a/crates/learner/src/configuration.rs 
b/crates/learner/src/configuration.rs index 0426fc1..9f8f65a 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -1,5 +1,98 @@ +use resource::FieldDefinition; +use serde::de::DeserializeOwned; + use super::*; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + /// Name of this configuration + pub name: String, + /// Optional description + #[serde(default)] + pub description: Option, + #[serde(default)] + pub extends: Option>, + + #[serde(default)] + pub additional_fields: BTreeMap, + /// The specific configuration type + pub item: T, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Resource { + /// Required fields for any academic resource + pub title: FieldDefinition, + pub authors: FieldDefinition, + pub publication_date: FieldDefinition, + pub abstract_text: Option, + + /// Resource-type specific requirements + pub resource_type: String, // paper, book, thesis, etc. + pub required_fields: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Record { + /// The resource type this record manages + pub resource: Resource, + + /// State tracking configuration + pub state_tracking: State, + + /// Storage configuration + pub storage: Storage, + + /// Retrieval configuration + pub retrieval: Retrieval, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Retriever { + /// The record type this retriever populates + pub record_type: String, + + /// API configuration + pub base_url: String, + pub endpoint_template: String, + pub pattern: String, + #[serde(default)] + pub headers: BTreeMap, + + /// How to process responses + pub response_format: ResponseFormat, + + /// Field mappings + pub resource_mappings: BTreeMap, + pub record_mappings: BTreeMap, +} + +/// Configuration for state tracking +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct State { + pub progress_tracking: bool, + pub rating_system: Option, + pub allow_notes: bool, + pub 
track_access_time: bool, +} + +/// Configuration for storage +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Storage { + pub required_files: Vec, + pub track_checksums: bool, + pub track_file_history: bool, +} + +/// Configuration for retrieval metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Retrieval { + pub track_urls: bool, + pub verify_access: bool, + pub url_types: Vec, + pub access_types: Vec, +} + pub trait Identifiable { fn name(&self) -> String; } @@ -41,3 +134,100 @@ pub trait Configurable: Sized { Ok(configurable) } } + +/// Helper for managing configurations with inheritance +pub struct ConfigurationManager { + builder: config::ConfigBuilder, + loaded_configs: BTreeMap, +} + +impl ConfigurationManager { + pub fn new() -> Self { + Self { builder: config::Config::builder(), loaded_configs: BTreeMap::new() } + } + + // TODO: Remove unwraps + pub fn load_config(&mut self, path: impl AsRef) -> Result> + where T: Serialize + DeserializeOwned + std::fmt::Debug { + let path = path.as_ref(); + let content = + std::fs::read_to_string(path).map_err(|e| config::ConfigError::Foreign(Box::new(e))).unwrap(); + + // Try to parse and provide detailed error information + match toml::from_str::>(&content) { + Ok(config) => { + let value = serde_json::to_value(&config) + .map_err(|e| config::ConfigError::Foreign(Box::new(e))) + .unwrap(); + self.loaded_configs.insert(config.name.clone(), value); + Ok(config) + }, + Err(e) => { + println!("Failed to parse configuration file: {}", path.display()); + println!("Error: {}", e); + println!("\nExpected structure for {} configuration:", std::any::type_name::()); + panic!() + // Print example structure if we're parsing a Resource + // if std::any::type_name::() == std::any::type_name::() { + // Resource::print_example_structure(); + // } + // Err(config::ConfigError::Foreign(Box::new(e))) + }, + } + } + + fn merge_configs(&self, base: Value, override_with: Value) -> Result { + use 
serde_json::Value::*; + + match (base, override_with) { + (Object(mut base_map), Object(override_map)) => { + for (k, v) in override_map { + match base_map.get(&k) { + Some(base_value) => { + let merged = self.merge_configs(base_value.clone(), v)?; + base_map.insert(k, merged); + }, + None => { + base_map.insert(k, v); + }, + } + } + Ok(Object(base_map)) + }, + // Arrays could be merged if needed + (Array(mut base_arr), Array(override_arr)) => { + // For now, just append new items + base_arr.extend(override_arr); + Ok(Array(base_arr)) + }, + // For all other cases, override takes precedence + (_, override_with) => Ok(override_with), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_extension() { + let mut manager = ConfigurationManager::new(); + + // Load configurations in order + let base_resource: Config = + dbg!(manager.load_config("config_new/base_resource.toml").unwrap()); + + // let paper: Config = dbg!(manager.load_config("config_new/paper.toml").unwrap()); + + // let paper_record: Config = + // dbg!(manager.load_config("config_new/paper_record.toml").unwrap()); + + // The paper_record now has all fields from base_resource and paper, + // plus its own record-specific configuration + + // assert_eq!(paper_record.item.resource.resource_type, "paper"); + // assert!(paper_record.item.resource.required_fields.contains(&"abstract_text".to_string())); + // assert!(paper_record.item.state_tracking.progress_tracking); + } +} diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 5d48407..50ea68b 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -50,8 +50,9 @@ pub struct RetrieverConfig { /// Optional HTTP headers for API requests #[serde(default)] pub headers: BTreeMap, - - pub retrieval_data: BTreeMap>, + // TODO: need to have these be associated somehow, actually resource should probably be in record + // pub resource_config: 
ResourceConfig, + // pub record_config: RecordConfig, } impl Identifiable for RetrieverConfig { diff --git a/crates/learner/src/retriever/response/json.rs b/crates/learner/src/retriever/response/json.rs index c9eb322..8e9dcdf 100644 --- a/crates/learner/src/retriever/response/json.rs +++ b/crates/learner/src/retriever/response/json.rs @@ -2,7 +2,7 @@ use serde_json::{self}; use super::*; -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct JsonConfig { pub field_maps: BTreeMap, } diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index b6f34df..b6b1d24 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -31,7 +31,7 @@ pub mod xml; /// [response_format.field_maps] /// title = { path = "message/title/0" } /// ``` -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ResponseFormat { /// XML response parser configuration @@ -53,7 +53,7 @@ pub enum ResponseFormat { /// path = "entry/title" /// transform = { type = "replace", pattern = "\\s+", replacement = " " } /// ``` -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct FieldMap { /// Path to field in response (e.g., JSON path or XPath) pub path: String, @@ -79,7 +79,7 @@ pub struct FieldMap { /// # Construct full URL /// transform = { type = "url", base = "https://example.com/", suffix = ".pdf" } /// ``` -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum Transform { /// Replace text using regex pattern @@ -111,7 +111,7 @@ pub enum Transform { }, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", content = "value")] pub enum Source { /// Path to a field to extract @@ -125,7 +125,7 @@ pub enum Source { KeyValue { key: String, 
path: String }, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ComposeFormat { /// Join fields with a delimiter diff --git a/crates/learner/src/retriever/response/xml.rs b/crates/learner/src/retriever/response/xml.rs index a12129a..df1fe35 100644 --- a/crates/learner/src/retriever/response/xml.rs +++ b/crates/learner/src/retriever/response/xml.rs @@ -40,7 +40,7 @@ use super::*; /// })]), /// }; /// ``` -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct XmlConfig { /// Whether to remove XML namespace declarations and prefixes #[serde(default)] From 40e3cfdfe43e1d622af96fb66fd5139e3f3d873b Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Wed, 4 Dec 2024 06:16:24 -0700 Subject: [PATCH 33/73] WIP: simpler field definition --- crates/learner/config_new/base_resource.toml | 8 +++---- crates/learner/src/configuration.rs | 25 +++++++++++++++++++- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/crates/learner/config_new/base_resource.toml b/crates/learner/config_new/base_resource.toml index 5222b4c..c861032 100644 --- a/crates/learner/config_new/base_resource.toml +++ b/crates/learner/config_new/base_resource.toml @@ -6,7 +6,7 @@ required_fields = ["title", "authors"] resource_type = "base" # Define complete field definitions as inline tables -abstract_text = { name = "abstract", field_type = "string", required = false } -authors = { name = "authors", field_type = "array", required = true, validation = { min_items = 1 } } -publication_date = { name = "publication_date", field_type = "string", required = true, validation = { datetime = true } } -title = { name = "title", field_type = "string", required = true, validation = { min_length = 1, max_length = 500 } } +abstract_text = { field_type = "string", required = false } +authors = { field_type = "array", required = true, validation = { min_items = 1 } } +publication_date = { field_type = 
"string", required = true, validation = { datetime = true } } +title = { field_type = "string", required = true, validation = { min_length = 1, max_length = 500 } } diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index 9f8f65a..11a5564 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -1,8 +1,31 @@ -use resource::FieldDefinition; +use resource::{TypeDefinition, ValidationRules}; +// use resource::FieldDefinition; use serde::de::DeserializeOwned; use super::*; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FieldDefinition { + // /// Name of the field + // pub name: String, + /// Type of the field (should be a JSON Value type) + pub field_type: String, + /// Whether this field must be present + #[serde(default)] + pub required: bool, + /// Human-readable description + #[serde(default)] + pub description: Option, + /// Default value if field is absent + #[serde(default)] + pub default: Option, + /// Optional validation rules + #[serde(default)] + pub validation: Option, + + pub type_definition: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { /// Name of this configuration From 0844a8dbbd01e550cf2cc66ef6fbb49f3fbcc341 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Wed, 4 Dec 2024 06:44:59 -0700 Subject: [PATCH 34/73] WIP: clean `Resource` --- crates/learner/config_new/base_resource.toml | 3 +-- crates/learner/src/configuration.rs | 28 +++++++++++++------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/crates/learner/config_new/base_resource.toml b/crates/learner/config_new/base_resource.toml index c861032..6183111 100644 --- a/crates/learner/config_new/base_resource.toml +++ b/crates/learner/config_new/base_resource.toml @@ -2,8 +2,7 @@ description = "Base configuration for all academic resources" name = "base_resource" [item] -required_fields = ["title", "authors"] -resource_type = "base" +resource_type = "base" # 
Define complete field definitions as inline tables abstract_text = { field_type = "string", required = false } diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index 11a5564..47dc752 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -42,17 +42,27 @@ pub struct Config { pub item: T, } +// #[derive(Debug, Clone, Serialize, Deserialize)] +// pub struct Resource { +// /// Required fields for any academic resource +// pub title: FieldDefinition, +// pub authors: FieldDefinition, +// pub publication_date: FieldDefinition, +// pub abstract_text: Option, + +// /// Resource-type specific requirements +// pub resource_type: String, // paper, book, thesis, etc. +// pub required_fields: Vec, +// } + +// TODO: use this (may have to change back the fielddefinition now) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Resource { - /// Required fields for any academic resource - pub title: FieldDefinition, - pub authors: FieldDefinition, - pub publication_date: FieldDefinition, - pub abstract_text: Option, - - /// Resource-type specific requirements - pub resource_type: String, // paper, book, thesis, etc. 
- pub required_fields: Vec, + pub resource_type: String, + /// Field definitions with optional metadata + #[serde(default)] + #[serde(flatten)] + pub fields: BTreeMap, } #[derive(Debug, Clone, Serialize, Deserialize)] From 5d9287ae53ecf25f49ca63eed36f9a3601d8c2c3 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 7 Dec 2024 05:31:14 -0700 Subject: [PATCH 35/73] simplifying --- crates/learner/config_new/base_resource.toml | 11 ----------- crates/learner/config_new/paper.toml | 15 +++++---------- crates/learner/src/configuration.rs | 15 +++++---------- 3 files changed, 10 insertions(+), 31 deletions(-) delete mode 100644 crates/learner/config_new/base_resource.toml diff --git a/crates/learner/config_new/base_resource.toml b/crates/learner/config_new/base_resource.toml deleted file mode 100644 index 6183111..0000000 --- a/crates/learner/config_new/base_resource.toml +++ /dev/null @@ -1,11 +0,0 @@ -description = "Base configuration for all academic resources" -name = "base_resource" - -[item] -resource_type = "base" - -# Define complete field definitions as inline tables -abstract_text = { field_type = "string", required = false } -authors = { field_type = "array", required = true, validation = { min_items = 1 } } -publication_date = { field_type = "string", required = true, validation = { datetime = true } } -title = { field_type = "string", required = true, validation = { min_length = 1, max_length = 500 } } diff --git a/crates/learner/config_new/paper.toml b/crates/learner/config_new/paper.toml index 565bba1..af838a9 100644 --- a/crates/learner/config_new/paper.toml +++ b/crates/learner/config_new/paper.toml @@ -1,12 +1,7 @@ -# paper.toml -description = "Academic paper configuration" -extends = ["base_resource"] +description = "Configuration for a paper" name = "paper" -[item] -required_fields = ["title", "authors", "publication_date", "abstract_text"] -resource_type = "paper" - -[item.abstract_text] -field_type = "string" -required = true +abstract_text = { 
field_type = "string", required = false } +authors = { field_type = "array", required = true, validation = { min_items = 1 } } +publication_date = { field_type = "string", required = true, validation = { datetime = true } } +title = { field_type = "string", required = true, validation = { min_length = 1, max_length = 500 } } diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index 47dc752..7f8ac88 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -6,8 +6,6 @@ use super::*; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FieldDefinition { - // /// Name of the field - // pub name: String, /// Type of the field (should be a JSON Value type) pub field_type: String, /// Whether this field must be present @@ -39,6 +37,7 @@ pub struct Config { #[serde(default)] pub additional_fields: BTreeMap, /// The specific configuration type + #[serde(flatten)] pub item: T, } @@ -58,11 +57,10 @@ pub struct Config { // TODO: use this (may have to change back the fielddefinition now) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Resource { - pub resource_type: String, /// Field definitions with optional metadata #[serde(default)] #[serde(flatten)] - pub fields: BTreeMap, + pub fields: BTreeMap, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -248,13 +246,10 @@ mod tests { let mut manager = ConfigurationManager::new(); // Load configurations in order - let base_resource: Config = - dbg!(manager.load_config("config_new/base_resource.toml").unwrap()); - - // let paper: Config = dbg!(manager.load_config("config_new/paper.toml").unwrap()); + let paper: Config = dbg!(manager.load_config("config_new/paper.toml").unwrap()); - // let paper_record: Config = - // dbg!(manager.load_config("config_new/paper_record.toml").unwrap()); + let paper_record: Config = + dbg!(manager.load_config("config_new/paper_record.toml").unwrap()); // The paper_record now has all fields from base_resource and 
paper, // plus its own record-specific configuration From c47ae26d0bd70f2bb749d544ba89074cd490bb0a Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 7 Dec 2024 06:28:43 -0700 Subject: [PATCH 36/73] organizing -- commit before i break everything --- crates/learner/config_new/arxiv.toml | 35 ++++ crates/learner/config_new/paper_record.toml | 25 --- crates/learner/src/configuration.rs | 195 +++++-------------- crates/learner/src/database/mod.rs | 2 +- crates/learner/src/lib.rs | 1 + crates/learner/src/{database => }/record.rs | 28 +-- crates/learner/src/retriever/config.rs | 56 ++---- crates/learner/src/retriever/mod.rs | 13 +- crates/learner/src/retriever/response/xml.rs | 4 +- crates/learner/tests/workflows/mod.rs | 2 +- 10 files changed, 124 insertions(+), 237 deletions(-) create mode 100644 crates/learner/config_new/arxiv.toml delete mode 100644 crates/learner/config_new/paper_record.toml rename crates/learner/src/{database => }/record.rs (82%) diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml new file mode 100644 index 0000000..f9bacbe --- /dev/null +++ b/crates/learner/config_new/arxiv.toml @@ -0,0 +1,35 @@ +base_url = "http://export.arxiv.org" +endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}&max_results=1" +name = "arxiv" +pattern = "(?:^|https?://arxiv\\.org/(?:abs|pdf)/)(\\d{4}\\.\\d{4,5}|[a-zA-Z-]+(?:\\.[A-Z]{2})?/\\d{7})(?:\\.pdf)?$" +resource = "paper" +source = "arxiv" + +[response_format] +strip_namespaces = true +type = "xml" + + +# TODO: Could flatten out the `field_maps`? 
+[response_format.field_maps.title] +path = "feed/entry/title" + +[response_format.field_maps.abstract] +path = "feed/entry/summary" + +[response_format.field_maps.authors] +path = "feed/entry/author" + +[response_format.field_maps.publication_dates] +path = "feed/entry/published" + +[retrieval_data.urls] +path = "feed/entry/id" + +[retrieval_data.urls.transform] +pattern = "/abs/" +replacement = "/pdf/" +type = "Replace" + +[headers] +Accept = "application/xml" diff --git a/crates/learner/config_new/paper_record.toml b/crates/learner/config_new/paper_record.toml deleted file mode 100644 index 1e5096c..0000000 --- a/crates/learner/config_new/paper_record.toml +++ /dev/null @@ -1,25 +0,0 @@ -# paper_record.toml -description = "Record configuration for academic papers" -extends = ["paper"] -name = "paper_record" - -[item.resource] -required_fields = ["title", "authors", "publication_date", "abstract_text"] -resource_type = "paper" - -[item.state_tracking] -allow_notes = true -progress_tracking = true -rating_system = 5 -track_access_time = true - -[item.storage] -required_files = ["pdf"] -track_checksums = true -track_file_history = true - -[item.retrieval] -access_types = ["open", "subscription"] -track_urls = true -url_types = ["pdf", "html"] -verify_access = true diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index 7f8ac88..3675174 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -1,9 +1,22 @@ use resource::{TypeDefinition, ValidationRules}; -// use resource::FieldDefinition; use serde::de::DeserializeOwned; use super::*; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + /// Name of this configuration + pub name: String, + /// Optional description + #[serde(default)] + pub description: Option, + #[serde(default)] + pub additional_fields: BTreeMap, + /// The specific configuration type + #[serde(flatten)] + pub item: T, +} + #[derive(Debug, Clone, Serialize, 
Deserialize)] pub struct FieldDefinition { /// Type of the field (should be a JSON Value type) @@ -24,37 +37,6 @@ pub struct FieldDefinition { pub type_definition: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Config { - /// Name of this configuration - pub name: String, - /// Optional description - #[serde(default)] - pub description: Option, - #[serde(default)] - pub extends: Option>, - - #[serde(default)] - pub additional_fields: BTreeMap, - /// The specific configuration type - #[serde(flatten)] - pub item: T, -} - -// #[derive(Debug, Clone, Serialize, Deserialize)] -// pub struct Resource { -// /// Required fields for any academic resource -// pub title: FieldDefinition, -// pub authors: FieldDefinition, -// pub publication_date: FieldDefinition, -// pub abstract_text: Option, - -// /// Resource-type specific requirements -// pub resource_type: String, // paper, book, thesis, etc. -// pub required_fields: Vec, -// } - -// TODO: use this (may have to change back the fielddefinition now) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Resource { /// Field definitions with optional metadata @@ -63,67 +45,7 @@ pub struct Resource { pub fields: BTreeMap, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Record { - /// The resource type this record manages - pub resource: Resource, - - /// State tracking configuration - pub state_tracking: State, - - /// Storage configuration - pub storage: Storage, - - /// Retrieval configuration - pub retrieval: Retrieval, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Retriever { - /// The record type this retriever populates - pub record_type: String, - - /// API configuration - pub base_url: String, - pub endpoint_template: String, - pub pattern: String, - #[serde(default)] - pub headers: BTreeMap, - - /// How to process responses - pub response_format: ResponseFormat, - - /// Field mappings - pub resource_mappings: BTreeMap, - pub record_mappings: BTreeMap, -} 
- -/// Configuration for state tracking -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct State { - pub progress_tracking: bool, - pub rating_system: Option, - pub allow_notes: bool, - pub track_access_time: bool, -} - -/// Configuration for storage -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Storage { - pub required_files: Vec, - pub track_checksums: bool, - pub track_file_history: bool, -} - -/// Configuration for retrieval metadata -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Retrieval { - pub track_urls: bool, - pub verify_access: bool, - pub url_types: Vec, - pub access_types: Vec, -} - +// TODO: These two traits can probably be removed pub trait Identifiable { fn name(&self) -> String; } @@ -166,74 +88,49 @@ pub trait Configurable: Sized { } } -/// Helper for managing configurations with inheritance pub struct ConfigurationManager { builder: config::ConfigBuilder, loaded_configs: BTreeMap, + // Track config paths for loading extends + config_paths: PathBuf, } impl ConfigurationManager { - pub fn new() -> Self { - Self { builder: config::Config::builder(), loaded_configs: BTreeMap::new() } + pub fn new(config_path: impl AsRef) -> Self { + Self { + builder: config::Config::builder(), + loaded_configs: BTreeMap::new(), + config_paths: config_path.as_ref().to_path_buf(), + } } - // TODO: Remove unwraps pub fn load_config(&mut self, path: impl AsRef) -> Result> - where T: Serialize + DeserializeOwned + std::fmt::Debug { + where T: DeserializeOwned + std::fmt::Debug { let path = path.as_ref(); - let content = - std::fs::read_to_string(path).map_err(|e| config::ConfigError::Foreign(Box::new(e))).unwrap(); + let content = std::fs::read_to_string(path)?; - // Try to parse and provide detailed error information - match toml::from_str::>(&content) { - Ok(config) => { - let value = serde_json::to_value(&config) - .map_err(|e| config::ConfigError::Foreign(Box::new(e))) - .unwrap(); - 
self.loaded_configs.insert(config.name.clone(), value); - Ok(config) - }, - Err(e) => { - println!("Failed to parse configuration file: {}", path.display()); - println!("Error: {}", e); - println!("\nExpected structure for {} configuration:", std::any::type_name::()); - panic!() - // Print example structure if we're parsing a Resource - // if std::any::type_name::() == std::any::type_name::() { - // Resource::print_example_structure(); - // } - // Err(config::ConfigError::Foreign(Box::new(e))) - }, - } - } + // Parse into toml::Value first + let mut raw_config: toml::Value = toml::from_str(&content)?; - fn merge_configs(&self, base: Value, override_with: Value) -> Result { - use serde_json::Value::*; + // If this is a Retriever config, handle resource reference + if std::any::type_name::() == std::any::type_name::() { + if let Some(toml::Value::String(resource_name)) = raw_config.get("resource") { + // Load the referenced resource + let resource_path = self.config_paths.join(format!("{resource_name}.toml")); + let resource_config: Config = self.load_config(&resource_path)?; - match (base, override_with) { - (Object(mut base_map), Object(override_map)) => { - for (k, v) in override_map { - match base_map.get(&k) { - Some(base_value) => { - let merged = self.merge_configs(base_value.clone(), v)?; - base_map.insert(k, merged); - }, - None => { - base_map.insert(k, v); - }, - } + // Replace the string reference with the actual resource + if let Some(table) = raw_config.as_table_mut() { + // TODO: Fix unwrap + table.insert("resource".into(), toml::Value::try_from(resource_config.item).unwrap()); } - Ok(Object(base_map)) - }, - // Arrays could be merged if needed - (Array(mut base_arr), Array(override_arr)) => { - // For now, just append new items - base_arr.extend(override_arr); - Ok(Array(base_arr)) - }, - // For all other cases, override takes precedence - (_, override_with) => Ok(override_with), + } } + + // Convert to final type through intermediate JSON 
representation + let json_value = serde_json::to_value(&raw_config)?; + let typed_config: Config = serde_json::from_value(json_value)?; + Ok(typed_config) } } @@ -243,13 +140,13 @@ mod tests { #[test] fn test_config_extension() { - let mut manager = ConfigurationManager::new(); + let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); // Load configurations in order let paper: Config = dbg!(manager.load_config("config_new/paper.toml").unwrap()); - let paper_record: Config = - dbg!(manager.load_config("config_new/paper_record.toml").unwrap()); + let arxiv_retriever: Config = + dbg!(manager.load_config("config_new/arxiv.toml").unwrap()); // The paper_record now has all fields from base_resource and paper, // plus its own record-specific configuration diff --git a/crates/learner/src/database/mod.rs b/crates/learner/src/database/mod.rs index d1aa1d6..d5cd6d5 100644 --- a/crates/learner/src/database/mod.rs +++ b/crates/learner/src/database/mod.rs @@ -57,7 +57,7 @@ use tokio_rusqlite::Connection; use super::*; mod instruction; -pub mod record; + #[cfg(test)] mod tests; pub use self::instruction::{ diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index 70d65d2..9dd3ba2 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -169,6 +169,7 @@ pub mod error; pub mod format; pub mod llm; pub mod pdf; +pub mod record; pub mod resource; use crate::{ diff --git a/crates/learner/src/database/record.rs b/crates/learner/src/record.rs similarity index 82% rename from crates/learner/src/database/record.rs rename to crates/learner/src/record.rs index 44334fb..51a66fd 100644 --- a/crates/learner/src/database/record.rs +++ b/crates/learner/src/record.rs @@ -1,16 +1,19 @@ -use resource::ResourceConfig; - use super::*; -/// A complete view of a resource with all associated data -#[derive(Debug)] -pub struct ResourceRecord { - pub resource: Resource, - pub resource_config: ResourceConfig, - pub state: ResourceState, - pub tags: Vec, - 
pub storage: Option, - pub retrieval: Option, +// TODO: Might want to put `Config`, etc. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Record { + /// The resource type this record manages + pub resource: Resource, + + /// State tracking configuration + pub state: State, + + /// Storage configuration + pub storage: StorageData, + + /// Retrieval configuration + pub retrieval: RetrievalData, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -29,7 +32,7 @@ pub enum Progress { } #[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct ResourceState { +pub struct State { pub read_status: Progress, pub starred: bool, pub rating: Option, @@ -37,6 +40,7 @@ pub struct ResourceState { pub notes: Option, pub citation_key: Option, // pub importance: Option, // Different from rating - how crucial is this? + pub tags: Vec, pub tags_updated_at: Option>, // Track tag management } diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 50ea68b..f79705f 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,36 +1,9 @@ -use record::{ResourceRecord, ResourceState, RetrievalData}; use resource::Resource; use super::*; -/// Configuration for a specific paper source retriever. -/// -/// This struct defines how to interact with a particular paper source's API, -/// including URL patterns, authentication, and response parsing rules. 
-/// -/// # Examples -/// -/// Example TOML configuration: -/// -/// ```toml -/// name = "arxiv" -/// base_url = "http://export.arxiv.org/api/query" -/// pattern = "^\\d{4}\\.\\d{4,5}$" -/// source = "arxiv" -/// endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}" -/// -/// [response_format] -/// type = "xml" -/// strip_namespaces = true -/// -/// [response_format.field_maps] -/// title = { path = "entry/title" } -/// abstract = { path = "entry/summary" } -/// publication_date = { path = "entry/published" } -/// authors = { path = "entry/author/name" } -/// ``` #[derive(Debug, Clone, Deserialize)] -pub struct RetrieverConfig { +pub struct Retriever { /// Name of this retriever configuration pub name: String, @@ -51,15 +24,15 @@ pub struct RetrieverConfig { #[serde(default)] pub headers: BTreeMap, // TODO: need to have these be associated somehow, actually resource should probably be in record - // pub resource_config: ResourceConfig, - // pub record_config: RecordConfig, + pub resource_mappings: BTreeMap, + pub record_mappings: BTreeMap, } -impl Identifiable for RetrieverConfig { +impl Identifiable for Retriever { fn name(&self) -> String { self.name.clone() } } -impl RetrieverConfig { +impl Retriever { /// Extracts the canonical identifier from an input string. 
/// /// Uses the configured regex pattern to extract the standardized @@ -90,7 +63,7 @@ impl RetrieverConfig { &self, input: &str, resource_config: &ResourceConfig, - ) -> Result { + ) -> Result { let identifier = self.extract_identifier(input)?; // Send request and get response @@ -124,13 +97,14 @@ impl RetrieverConfig { // Validate full resource against config resource_config.validate(&resource)?; - Ok(ResourceRecord { - resource, - resource_config: resource_config.clone(), - retrieval: None, - state: ResourceState::default(), - storage: None, - tags: Vec::new(), - }) + Ok(resource) + // Ok(Record { + // resource, + // resource_config: resource_config.clone(), + // retrieval: None, + // state: ResourceState::default(), + // storage: None, + // tags: Vec::new(), + // }) } } diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 6a5b130..a455646 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -6,14 +6,15 @@ mod response; pub use config::*; pub use response::*; +// TODO: This should be `BTreeMap>` #[derive(Default, Debug, Clone)] pub struct Retrievers { - /// The collection of configurations used for this [`Retriever`]. - configs: BTreeMap, + /// The collection of configurations used for this [`Retrievers`]. 
+ configs: BTreeMap, } impl Configurable for Retrievers { - type Config = RetrieverConfig; + type Config = Retriever; fn as_map(&mut self) -> &mut BTreeMap { &mut self.configs } } @@ -188,7 +189,7 @@ mod tests { fn validate_arxiv_config() { let config_str = include_str!("../../config/retrievers/arxiv.toml"); - let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + let retriever: Retriever = toml::from_str(config_str).expect("Failed to parse config"); // Verify basic fields assert_eq!(retriever.name, "arxiv"); @@ -248,7 +249,7 @@ mod tests { fn test_doi_config_deserialization() { let config_str = include_str!("../../config/retrievers/doi.toml"); - let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + let retriever: Retriever = toml::from_str(config_str).expect("Failed to parse config"); dbg!(&retriever); @@ -304,7 +305,7 @@ mod tests { fn test_iacr_config_deserialization() { let config_str = include_str!("../../config/retrievers/iacr.toml"); - let retriever: RetrieverConfig = toml::from_str(config_str).expect("Failed to parse config"); + let retriever: Retriever = toml::from_str(config_str).expect("Failed to parse config"); // Verify basic fields assert_eq!(retriever.name, "iacr"); diff --git a/crates/learner/src/retriever/response/xml.rs b/crates/learner/src/retriever/response/xml.rs index df1fe35..360314b 100644 --- a/crates/learner/src/retriever/response/xml.rs +++ b/crates/learner/src/retriever/response/xml.rs @@ -45,8 +45,8 @@ pub struct XmlConfig { /// Whether to remove XML namespace declarations and prefixes #[serde(default)] pub strip_namespaces: bool, - /// XML path mappings for paper metadata fields - pub field_maps: BTreeMap, + // / XML path mappings for paper metadata fields + // pub field_maps: BTreeMap, } impl ResponseProcessor for XmlConfig { diff --git a/crates/learner/tests/workflows/mod.rs b/crates/learner/tests/workflows/mod.rs index cf9b3dd..eff2b73 100644 --- 
a/crates/learner/tests/workflows/mod.rs +++ b/crates/learner/tests/workflows/mod.rs @@ -1,4 +1,4 @@ -use learner::retriever::RetrieverConfig; +use learner::retriever::Retriever; use super::*; From ec2087af07185e64251fdc1a10c717f235c809cf Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 7 Dec 2024 06:56:04 -0700 Subject: [PATCH 37/73] WIP: still simplifying --- crates/learner/config_new/arxiv.toml | 31 +++++----- crates/learner/src/lib.rs | 20 +++---- crates/learner/src/retriever/config.rs | 33 ++++++----- crates/learner/src/retriever/mod.rs | 30 +++++----- crates/learner/src/retriever/response/json.rs | 19 ------- crates/learner/src/retriever/response/mod.rs | 48 ++-------------- crates/learner/src/retriever/response/xml.rs | 57 ++++--------------- 7 files changed, 76 insertions(+), 162 deletions(-) delete mode 100644 crates/learner/src/retriever/response/json.rs diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml index f9bacbe..149e5b8 100644 --- a/crates/learner/config_new/arxiv.toml +++ b/crates/learner/config_new/arxiv.toml @@ -1,35 +1,36 @@ +name = "arxiv" + +description = "Retriever template for getting a paper from arXiv" + base_url = "http://export.arxiv.org" endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}&max_results=1" -name = "arxiv" pattern = "(?:^|https?://arxiv\\.org/(?:abs|pdf)/)(\\d{4}\\.\\d{4,5}|[a-zA-Z-]+(?:\\.[A-Z]{2})?/\\d{7})(?:\\.pdf)?$" resource = "paper" source = "arxiv" -[response_format] -strip_namespaces = true -type = "xml" +response_format = { type = "xml", strip_namespaces = true } -# TODO: Could flatten out the `field_maps`? -[response_format.field_maps.title] +# # TODO: Could flatten out the `field_maps`? 
+[resource_mappings.title] path = "feed/entry/title" -[response_format.field_maps.abstract] +[resource_mappings.abstract] path = "feed/entry/summary" -[response_format.field_maps.authors] +[resource_mappings.authors] path = "feed/entry/author" -[response_format.field_maps.publication_dates] +[resource_mappings.publication_dates] path = "feed/entry/published" -[retrieval_data.urls] -path = "feed/entry/id" +# [retrieval_data.urls] +# path = "feed/entry/id" -[retrieval_data.urls.transform] -pattern = "/abs/" -replacement = "/pdf/" -type = "Replace" +# [retrieval_data.urls.transform] +# pattern = "/abs/" +# replacement = "/pdf/" +# type = "Replace" [headers] Accept = "application/xml" diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index 9dd3ba2..f300460 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -227,7 +227,6 @@ pub mod prelude { configuration::{Configurable, Identifiable}, database::DatabaseInstruction, error::LearnerError, - retriever::ResponseProcessor, }; } @@ -590,10 +589,11 @@ impl LearnerBuilder { let database = Database::open(&config.database_path).await?; database.set_storage_path(&config.storage_path).await?; - let retriever = Retrievers::new().with_config_dir(&config.retrievers_path)?; + todo!("This needs fixed now"); + // let retriever = Retrievers::new().with_config_dir(&config.retrievers_path)?; let resources = Resources::new().with_config_dir(&config.resources_path)?; - Ok(Learner { config, database, retrievers: retriever, resources }) + Ok(Learner { config, database, retrievers: Retrievers::new(), resources }) } } @@ -747,14 +747,14 @@ impl Learner { pub async fn init() -> Result { Self::with_config(Config::init()?).await } pub async fn retreive(&mut self, input: &str) -> Result { - let mut matches = Vec::new(); + // let mut matches = Vec::new(); - // Find all configs that match the input - for (name, config) in self.retrievers.as_map().iter() { - if config.pattern.is_match(input) { - 
matches.push((name, config)); - } - } + // // Find all configs that match the input + // for (name, config) in self.retrievers.as_map().iter() { + // if config.pattern.is_match(input) { + // matches.push((name, config)); + // } + // } todo!("Finish this") // match matches.len() { diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index f79705f..a4cd46b 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,14 +1,14 @@ -use resource::Resource; +// use resource::Resource; use super::*; +// TODO: fix all the stuff that had to do with `Retriever.name` + #[derive(Debug, Clone, Deserialize)] pub struct Retriever { - /// Name of this retriever configuration - pub name: String, + pub resource: Resource, - // TODO (autoparallel): Ultimately this will have to peer into the `Resources` to be useful - pub resource: String, + // TODO: Should own a `Record` /// Base URL for API requests pub base_url: String, /// Regex pattern for matching and extracting paper identifiers @@ -19,18 +19,21 @@ pub struct Retriever { /// Template for constructing API endpoint URLs pub endpoint_template: String, // TODO: This is now more like "how to get the thing to map into the resource" + // #[serde(flatten)] pub response_format: ResponseFormat, /// Optional HTTP headers for API requests #[serde(default)] pub headers: BTreeMap, // TODO: need to have these be associated somehow, actually resource should probably be in record + #[serde(default)] pub resource_mappings: BTreeMap, + #[serde(default)] pub record_mappings: BTreeMap, } -impl Identifiable for Retriever { - fn name(&self) -> String { self.name.clone() } -} +// impl Identifiable for Retriever { +// fn name(&self) -> String { self.name.clone() } +// } impl Retriever { /// Extracts the canonical identifier from an input string. 
@@ -68,7 +71,7 @@ impl Retriever { // Send request and get response let url = self.endpoint_template.replace("{identifier}", identifier); - debug!("Fetching from {} via: {}", self.name, url); + // debug!("Fetching from {} via: {}", self.name, url); let client = reqwest::Client::new(); let mut request = client.get(&url); @@ -80,16 +83,18 @@ impl Retriever { let response = request.send().await?; let data = response.bytes().await?; - trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); + + // trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); // Process the response using configured processor - let processor = match &self.response_format { - ResponseFormat::Xml(config) => config as &dyn ResponseProcessor, - ResponseFormat::Json(config) => config as &dyn ResponseProcessor, + let json = match &self.response_format { + ResponseFormat::Xml { strip_namespaces } => xml::convert_to_json(&data, *strip_namespaces), + ResponseFormat::Json => serde_json::from_slice(&data)?, }; // Process response and get resource - let mut resource = processor.process_response(&data, resource_config)?; + // TODO: this should probably be a method + let mut resource = process_json_value(&json, &self.resource_mappings, resource_config)?; // Add source metadata resource.insert("source".into(), Value::String(self.source.clone())); diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index a455646..fa65793 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -13,11 +13,11 @@ pub struct Retrievers { configs: BTreeMap, } -impl Configurable for Retrievers { - type Config = Retriever; +// impl Configurable for Retrievers { +// type Config = Retriever; - fn as_map(&mut self) -> &mut BTreeMap { &mut self.configs } -} +// fn as_map(&mut self) -> &mut BTreeMap { &mut self.configs } +// } impl Retrievers { /// Checks whether the retreivers map is empty. 
@@ -192,7 +192,7 @@ mod tests { let retriever: Retriever = toml::from_str(config_str).expect("Failed to parse config"); // Verify basic fields - assert_eq!(retriever.name, "arxiv"); + // assert_eq!(retriever.name, "arxiv"); assert_eq!(retriever.base_url, "http://export.arxiv.org"); assert_eq!(retriever.source, "arxiv"); @@ -214,11 +214,11 @@ mod tests { // Verify response format - if let ResponseFormat::Xml(config) = &retriever.response_format { - assert!(config.strip_namespaces); + if let ResponseFormat::Xml { strip_namespaces } = &retriever.response_format { + assert!(strip_namespaces); // Verify field mappings - let field_maps = &config.field_maps; + let field_maps = &retriever.resource_mappings; assert!(field_maps.contains_key("title")); assert!(field_maps.contains_key("abstract")); assert!(field_maps.contains_key("authors")); @@ -254,7 +254,7 @@ mod tests { dbg!(&retriever); // Verify basic fields - assert_eq!(retriever.name, "doi"); + // assert_eq!(retriever.name, "doi"); assert_eq!(retriever.base_url, "https://api.crossref.org/works"); assert_eq!(retriever.source, "doi"); @@ -287,9 +287,9 @@ mod tests { // Verify response format match &retriever.response_format { - ResponseFormat::Json(config) => { + ResponseFormat::Json => { // Verify field mappings - let field_maps = &config.field_maps; + let field_maps = &retriever.record_mappings; assert!(field_maps.contains_key("title")); assert!(field_maps.contains_key("abstract")); assert!(field_maps.contains_key("authors")); @@ -308,7 +308,7 @@ mod tests { let retriever: Retriever = toml::from_str(config_str).expect("Failed to parse config"); // Verify basic fields - assert_eq!(retriever.name, "iacr"); + // assert_eq!(retriever.name, "iacr"); assert_eq!(retriever.base_url, "https://eprint.iacr.org"); assert_eq!(retriever.source, "iacr"); @@ -343,11 +343,11 @@ mod tests { ); // Verify response format - if let ResponseFormat::Xml(config) = &retriever.response_format { - assert!(config.strip_namespaces); + if let 
ResponseFormat::Xml { strip_namespaces } = &retriever.response_format { + assert!(strip_namespaces); // Verify field mappings - let field_maps = &config.field_maps; + let field_maps = &retriever.resource_mappings; assert!(field_maps.contains_key("title")); assert!(field_maps.contains_key("abstract")); assert!(field_maps.contains_key("authors")); diff --git a/crates/learner/src/retriever/response/json.rs b/crates/learner/src/retriever/response/json.rs deleted file mode 100644 index 8e9dcdf..0000000 --- a/crates/learner/src/retriever/response/json.rs +++ /dev/null @@ -1,19 +0,0 @@ -use serde_json::{self}; - -use super::*; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JsonConfig { - pub field_maps: BTreeMap, -} - -// TODO: Refactor this -impl ResponseProcessor for JsonConfig { - fn process_response(&self, data: &[u8], resource_config: &ResourceConfig) -> Result { - // Parse raw JSON data - let json: serde_json::Value = serde_json::from_slice(data) - .map_err(|e| LearnerError::ApiError(format!("Failed to parse JSON: {}", e)))?; - - dbg!(process_json_value(dbg!(&json), &self.field_maps, resource_config)) - } -} diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index b6b1d24..c2b0c70 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -3,7 +3,6 @@ use serde_json::Map; use super::*; -pub mod json; pub mod xml; /// Available response format handlers. @@ -36,10 +35,13 @@ pub mod xml; pub enum ResponseFormat { /// XML response parser configuration #[serde(rename = "xml")] - Xml(xml::XmlConfig), + Xml { + #[serde(default)] + strip_namespaces: bool, + }, /// JSON response parser configuration #[serde(rename = "json")] - Json(json::JsonConfig), + Json, } /// Field mapping configuration. @@ -139,46 +141,8 @@ pub enum ComposeFormat { }, } -/// Trait for processing API responses into Paper objects. 
-/// -/// Implementors of this trait handle the conversion of raw API response data -/// into structured Paper metadata. The trait is implemented separately for -/// different response formats (XML, JSON) to provide a unified interface for -/// paper retrieval. -/// -/// # Examples -/// -/// ```no_run -/// # use learner::{retriever::ResponseProcessor, resource::Paper}; -/// # use learner::error::LearnerError; -/// struct CustomProcessor; -/// -/// #[async_trait::async_trait] -/// impl ResponseProcessor for CustomProcessor { -/// async fn process_response(&self, data: &[u8]) -> Result { -/// // Parse response data and construct Paper -/// todo!() -/// } -/// } -/// ``` -// #[async_trait] -pub trait ResponseProcessor: Send + Sync { - /// Process raw response data into a Paper object. - /// - /// # Arguments - /// - /// * `data` - Raw bytes from the API response - /// - /// # Returns - /// - /// Returns a Result containing either: - /// - A fully populated Paper object - /// - A LearnerError if parsing fails - fn process_response(&self, data: &[u8], resource_config: &ResourceConfig) -> Result; -} - /// Process a JSON value according to field mappings and resource configuration -fn process_json_value( +pub fn process_json_value( json: &Value, field_maps: &BTreeMap, resource_config: &ResourceConfig, diff --git a/crates/learner/src/retriever/response/xml.rs b/crates/learner/src/retriever/response/xml.rs index 360314b..99afc98 100644 --- a/crates/learner/src/retriever/response/xml.rs +++ b/crates/learner/src/retriever/response/xml.rs @@ -19,57 +19,20 @@ //! ``` use quick_xml::{events::Event, Reader}; +use serde_json::{Map, Value}; use super::*; -/// Configuration for processing XML API responses. -/// -/// Provides field mapping rules and namespace handling options to extract -/// paper metadata from XML responses using path-based access patterns. 
-/// -/// # Examples -/// -/// ```no_run -/// # use std::collections::HashMap; -/// # use learner::retriever::{xml::XmlConfig, FieldMap}; -/// let config = XmlConfig { -/// strip_namespaces: true, -/// field_maps: HashMap::from([("title".to_string(), FieldMap { -/// path: "entry/title".to_string(), -/// transform: None, -/// })]), -/// }; -/// ``` -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct XmlConfig { - /// Whether to remove XML namespace declarations and prefixes - #[serde(default)] - pub strip_namespaces: bool, - // / XML path mappings for paper metadata fields - // pub field_maps: BTreeMap, -} - -impl ResponseProcessor for XmlConfig { - fn process_response(&self, data: &[u8], resource_config: &ResourceConfig) -> Result { - // Handle namespace stripping - let xml = if self.strip_namespaces { - strip_xml_namespaces(&String::from_utf8_lossy(data)) - } else { - String::from_utf8_lossy(data).to_string() - }; - - trace!("Processing XML response: {:#?}", &xml); - - // Extract raw XML content into JSON equivalent - let json = convert_to_json(&xml); - dbg!(process_json_value(&json, &self.field_maps, resource_config)) - } -} - -use serde_json::{Map, Value}; +pub fn convert_to_json(data: &[u8], strip_namespaces: bool) -> Value { + // Handle namespace stripping + let xml = if strip_namespaces { + strip_xml_namespaces(&String::from_utf8_lossy(data)) + } else { + String::from_utf8_lossy(data).to_string() + }; -pub fn convert_to_json(xml: &str) -> Value { - let mut reader = Reader::from_str(xml); + trace!("Processing XML response: {:#?}", &xml); + let mut reader = Reader::from_str(&xml); let mut stack = Vec::new(); let mut current = Map::new(); From 0874fb2fbea7520da6b2757b69a50959009aa5a4 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 7 Dec 2024 07:05:21 -0700 Subject: [PATCH 38/73] WIP: will now try to get tests to pass --- crates/learner/src/configuration.rs | 4 ++-- crates/learner/src/record.rs | 6 ++--- crates/learner/src/retriever/config.rs | 
24 ++++++++++++------- crates/learner/src/retriever/mod.rs | 2 +- .../tests/workflows/paper_retrieval.rs | 12 +++++----- 5 files changed, 27 insertions(+), 21 deletions(-) diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index 3675174..d755c67 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -38,7 +38,7 @@ pub struct FieldDefinition { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Resource { +pub struct ResourceTemplate { /// Field definitions with optional metadata #[serde(default)] #[serde(flatten)] @@ -117,7 +117,7 @@ impl ConfigurationManager { if let Some(toml::Value::String(resource_name)) = raw_config.get("resource") { // Load the referenced resource let resource_path = self.config_paths.join(format!("{resource_name}.toml")); - let resource_config: Config = self.load_config(&resource_path)?; + let resource_config: Config = self.load_config(&resource_path)?; // Replace the string reference with the actual resource if let Some(table) = raw_config.as_table_mut() { diff --git a/crates/learner/src/record.rs b/crates/learner/src/record.rs index 51a66fd..f1ddc2b 100644 --- a/crates/learner/src/record.rs +++ b/crates/learner/src/record.rs @@ -39,12 +39,12 @@ pub struct State { pub last_accessed: Option>, pub notes: Option, pub citation_key: Option, - // pub importance: Option, // Different from rating - how crucial is this? 
+ pub importance: Option, pub tags: Vec, - pub tags_updated_at: Option>, // Track tag management + pub tags_updated_at: Option>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RetrievalData { pub source: Option, pub source_identifier: Option, diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index a4cd46b..a1e9715 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,34 +1,40 @@ // use resource::Resource; +use configuration::ResourceTemplate; +use record::RetrievalData; + use super::*; // TODO: fix all the stuff that had to do with `Retriever.name` #[derive(Debug, Clone, Deserialize)] pub struct Retriever { - pub resource: Resource, + pub resource: ResourceTemplate, + #[serde(skip)] + #[serde(default)] + pub retrieval_data: RetrievalData, // TODO: Should own a `Record` /// Base URL for API requests - pub base_url: String, + pub base_url: String, /// Regex pattern for matching and extracting paper identifiers #[serde(deserialize_with = "deserialize_regex")] - pub pattern: Regex, + pub pattern: Regex, /// Source identifier for papers from this retriever - pub source: String, + pub source: String, /// Template for constructing API endpoint URLs - pub endpoint_template: String, + pub endpoint_template: String, // TODO: This is now more like "how to get the thing to map into the resource" // #[serde(flatten)] - pub response_format: ResponseFormat, + pub response_format: ResponseFormat, /// Optional HTTP headers for API requests #[serde(default)] - pub headers: BTreeMap, + pub headers: BTreeMap, // TODO: need to have these be associated somehow, actually resource should probably be in record #[serde(default)] - pub resource_mappings: BTreeMap, + pub resource_mappings: BTreeMap, #[serde(default)] - pub record_mappings: BTreeMap, + pub retrieval_mappings: BTreeMap, } // impl Identifiable for Retriever { diff 
--git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index fa65793..5a7cdc8 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -289,7 +289,7 @@ mod tests { match &retriever.response_format { ResponseFormat::Json => { // Verify field mappings - let field_maps = &retriever.record_mappings; + let field_maps = &retriever.resource_mappings; assert!(field_maps.contains_key("title")); assert!(field_maps.contains_key("abstract")); assert!(field_maps.contains_key("authors")); diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index b24f3bc..11d9f93 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -16,7 +16,7 @@ async fn test_arxiv_retriever_integration() -> TestResult<()> { file", ); - let retriever: RetrieverConfig = toml::from_str(&ret_config_str).expect("Failed to parse config"); + let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); // Test with a real arXiv paper @@ -47,7 +47,7 @@ async fn test_arxiv_pdf_from_paper() -> TestResult<()> { file", ); - let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); + let retriever: Retriever = toml::from_str(&config_str).expect("Failed to parse config"); todo!() // // Test with a real arXiv paper @@ -74,7 +74,7 @@ async fn test_iacr_retriever_integration() -> TestResult<()> { file", ); - let retriever: RetrieverConfig = toml::from_str(&ret_config_str).expect("Failed to parse config"); + let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); // // Test with a real IACR paper @@ -100,7 +100,7 @@ async fn 
test_iacr_pdf_from_paper() -> TestResult<()> { let config_str = fs::read_to_string("config/retrievers/iacr.toml").expect("Failed to read config file"); - let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); + let retriever: Retriever = toml::from_str(&config_str).expect("Failed to parse config"); todo!() // // Test with a real arXiv paper @@ -127,7 +127,7 @@ async fn test_doi_retriever_integration() -> TestResult<()> { file", ); - let retriever: RetrieverConfig = toml::from_str(&ret_config_str).expect("Failed to parse config"); + let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); // Test with a real DOI paper @@ -152,7 +152,7 @@ async fn test_doi_pdf_from_paper() -> TestResult<()> { let config_str = fs::read_to_string("config/retrievers/doi.toml").expect("Failed to read config file"); - let retriever: RetrieverConfig = toml::from_str(&config_str).expect("Failed to parse config"); + let retriever: Retriever = toml::from_str(&config_str).expect("Failed to parse config"); todo!() // Test with a real arXiv paper // let paper = retriever.retrieve_paper("10.1145/1327452.1327492").await?; From 29316769c572eeb40efb7fca99c94cad611abc62 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 7 Dec 2024 07:38:29 -0700 Subject: [PATCH 39/73] WIP: deserializing properly again --- crates/learner/config_new/arxiv.toml | 16 +-- crates/learner/src/configuration.rs | 103 ++++++++++++------- crates/learner/src/resource/mod.rs | 1 + crates/learner/src/retriever/config.rs | 17 ++- crates/learner/src/retriever/response/mod.rs | 3 +- crates/sdk/src/validate.rs | 2 +- 6 files changed, 87 insertions(+), 55 deletions(-) diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml index 149e5b8..044ffb6 100644 --- a/crates/learner/config_new/arxiv.toml +++ 
b/crates/learner/config_new/arxiv.toml @@ -12,17 +12,17 @@ response_format = { type = "xml", strip_namespaces = true } # # TODO: Could flatten out the `field_maps`? -[resource_mappings.title] -path = "feed/entry/title" +# [resource_mappings.title] +# path = "feed/entry/title" -[resource_mappings.abstract] -path = "feed/entry/summary" +# [resource_mappings.abstract] +# path = "feed/entry/summary" -[resource_mappings.authors] -path = "feed/entry/author" +# [resource_mappings.authors] +# path = "feed/entry/author" -[resource_mappings.publication_dates] -path = "feed/entry/published" +# [resource_mappings.publication_dates] +# path = "feed/entry/published" # [retrieval_data.urls] # path = "feed/entry/id" diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index d755c67..f9dd464 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -1,4 +1,4 @@ -use resource::{TypeDefinition, ValidationRules}; +use resource::{FieldDefinition, TypeDefinition, ValidationRules}; use serde::de::DeserializeOwned; use super::*; @@ -17,32 +17,51 @@ pub struct Config { pub item: T, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FieldDefinition { - /// Type of the field (should be a JSON Value type) - pub field_type: String, - /// Whether this field must be present - #[serde(default)] - pub required: bool, - /// Human-readable description - #[serde(default)] - pub description: Option, - /// Default value if field is absent - #[serde(default)] - pub default: Option, - /// Optional validation rules - #[serde(default)] - pub validation: Option, - - pub type_definition: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] +// #[derive(Debug, Clone, Serialize, Deserialize)] +// pub struct FieldDefinition { +// /// Type of the field (should be a JSON Value type) +// pub field_type: String, +// /// Whether this field must be present +// #[serde(default)] +// pub required: bool, +// /// Human-readable 
description +// #[serde(default)] +// pub description: Option, +// /// Default value if field is absent +// #[serde(default)] +// pub default: Option, +// /// Optional validation rules +// #[serde(default)] +// pub validation: Option, + +// pub type_definition: Option, +// } + +#[derive(Debug, Clone, Serialize)] pub struct ResourceTemplate { /// Field definitions with optional metadata #[serde(default)] - #[serde(flatten)] - pub fields: BTreeMap, + // #[serde(flatten)] + pub fields: Vec, +} + +impl<'de> Deserialize<'de> for ResourceTemplate { + fn deserialize(deserializer: D) -> std::result::Result + where D: serde::Deserializer<'de> { + // First deserialize into a map + let map: BTreeMap = BTreeMap::deserialize(deserializer)?; + + // Convert the map into a Vec, setting the name from the key + let fields = map + .into_iter() + .map(|(key, mut field_def)| { + field_def.name = key; + field_def + }) + .collect(); + + Ok(ResourceTemplate { fields }) + } } // TODO: These two traits can probably be removed @@ -108,28 +127,39 @@ impl ConfigurationManager { where T: DeserializeOwned + std::fmt::Debug { let path = path.as_ref(); let content = std::fs::read_to_string(path)?; - - // Parse into toml::Value first let mut raw_config: toml::Value = toml::from_str(&content)?; // If this is a Retriever config, handle resource reference if std::any::type_name::() == std::any::type_name::() { if let Some(toml::Value::String(resource_name)) = raw_config.get("resource") { // Load the referenced resource - let resource_path = self.config_paths.join(format!("{resource_name}.toml")); - let resource_config: Config = self.load_config(&resource_path)?; - - // Replace the string reference with the actual resource + let resource_path = self.config_paths.join(format!("{}.toml", resource_name)); + let resource_content = std::fs::read_to_string(resource_path)?; + let resource_config: toml::Value = toml::from_str(&resource_content)?; + + // Get just the fields we need (ignore name, description 
etc) + let resource_fields = resource_config + .as_table() + .and_then(|t| { + Some( + t.iter() + .filter(|(k, _)| !["name", "description"].contains(&k.as_str())) + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>(), + ) + }) + .ok_or_else(|| config::ConfigError::Message("Invalid resource config structure".into())) + .unwrap(); + + // Replace the string reference with the resource fields if let Some(table) = raw_config.as_table_mut() { - // TODO: Fix unwrap - table.insert("resource".into(), toml::Value::try_from(resource_config.item).unwrap()); + table.insert("resource".into(), toml::Value::Table(resource_fields)); } } } - // Convert to final type through intermediate JSON representation - let json_value = serde_json::to_value(&raw_config)?; - let typed_config: Config = serde_json::from_value(json_value)?; + // Convert directly to final type + let typed_config: Config = raw_config.try_into()?; Ok(typed_config) } } @@ -143,7 +173,8 @@ mod tests { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); // Load configurations in order - let paper: Config = dbg!(manager.load_config("config_new/paper.toml").unwrap()); + let paper: Config = + dbg!(manager.load_config("config_new/paper.toml").unwrap()); let arxiv_retriever: Config = dbg!(manager.load_config("config_new/arxiv.toml").unwrap()); diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index fba51b6..4fe497e 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -46,6 +46,7 @@ impl Identifiable for ResourceConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FieldDefinition { /// Name of the field + #[serde(skip_deserializing)] pub name: String, /// Type of the field (should be a JSON Value type) pub field_type: String, diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index a1e9715..718e19c 100644 --- a/crates/learner/src/retriever/config.rs +++ 
b/crates/learner/src/retriever/config.rs @@ -10,7 +10,7 @@ use super::*; #[derive(Debug, Clone, Deserialize)] pub struct Retriever { pub resource: ResourceTemplate, - #[serde(skip)] + #[serde(skip_deserializing)] #[serde(default)] pub retrieval_data: RetrievalData, @@ -68,11 +68,7 @@ impl Retriever { // TODO: perhaps this just isn't even implemented here and is instead implemented on `Learner`. // Could consider an `api.rs` module to extend more learner functionality there. #[allow(missing_docs)] - pub async fn retrieve_resource( - &self, - input: &str, - resource_config: &ResourceConfig, - ) -> Result { + pub async fn retrieve_resource(&self, input: &str) -> Result { let identifier = self.extract_identifier(input)?; // Send request and get response @@ -100,15 +96,18 @@ impl Retriever { // Process response and get resource // TODO: this should probably be a method - let mut resource = process_json_value(&json, &self.resource_mappings, resource_config)?; + let mut resource = process_json_value(&json, &self.resource_mappings, &self.resource)?; // Add source metadata resource.insert("source".into(), Value::String(self.source.clone())); resource.insert("source_identifier".into(), Value::String(identifier.to_string())); // Validate full resource against config - resource_config.validate(&resource)?; - Ok(resource) + // self.resource.validate(&resource)?; + // Ok(resource) + + todo!() + // Ok(Record { // resource, // resource_config: resource_config.clone(), diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index c2b0c70..91cc36b 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -1,3 +1,4 @@ +use configuration::ResourceTemplate; use resource::FieldDefinition; use serde_json::Map; @@ -145,7 +146,7 @@ pub enum ComposeFormat { pub fn process_json_value( json: &Value, field_maps: &BTreeMap, - resource_config: &ResourceConfig, + resource_config: 
&ResourceTemplate, ) -> Result { let mut resource = Resource::new(); diff --git a/crates/sdk/src/validate.rs b/crates/sdk/src/validate.rs index 5a39892..0fef9c5 100644 --- a/crates/sdk/src/validate.rs +++ b/crates/sdk/src/validate.rs @@ -231,7 +231,7 @@ pub async fn validate_retriever(path: &PathBuf, input: &Option) { } } }, - ResponseFormat::Json(config) => { + ResponseFormat::Json => { println!("Format: {}", style("JSON").cyan()); println!("\nField Mappings:"); for (field, map) in &config.field_maps { From acc632f3a5c4754b90bda48bbe675414fc753533 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 7 Dec 2024 07:56:30 -0700 Subject: [PATCH 40/73] WIP: retrieval almost working again --- crates/learner/config_new/arxiv.toml | 28 +++++----- crates/learner/src/configuration.rs | 34 +++--------- crates/learner/src/lib.rs | 6 +-- crates/learner/src/record.rs | 2 + crates/learner/src/resource/mod.rs | 54 ++++++++++--------- crates/learner/src/retriever/config.rs | 8 +-- crates/learner/src/retriever/response/mod.rs | 3 +- .../tests/workflows/paper_retrieval.rs | 39 ++++++++------ 8 files changed, 82 insertions(+), 92 deletions(-) diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml index 044ffb6..5959f54 100644 --- a/crates/learner/config_new/arxiv.toml +++ b/crates/learner/config_new/arxiv.toml @@ -12,25 +12,25 @@ response_format = { type = "xml", strip_namespaces = true } # # TODO: Could flatten out the `field_maps`? 
-# [resource_mappings.title] -# path = "feed/entry/title" +[resource_mappings.title] +path = "feed/entry/title" -# [resource_mappings.abstract] -# path = "feed/entry/summary" +[resource_mappings.abstract] +path = "feed/entry/summary" -# [resource_mappings.authors] -# path = "feed/entry/author" +[resource_mappings.authors] +path = "feed/entry/author" -# [resource_mappings.publication_dates] -# path = "feed/entry/published" +[resource_mappings.publication_dates] +path = "feed/entry/published" -# [retrieval_data.urls] -# path = "feed/entry/id" +[retrieval_data.urls] +path = "feed/entry/id" -# [retrieval_data.urls.transform] -# pattern = "/abs/" -# replacement = "/pdf/" -# type = "Replace" +[retrieval_data.urls.transform] +pattern = "/abs/" +replacement = "/pdf/" +type = "Replace" [headers] Accept = "application/xml" diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index f9dd464..949ea9f 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -17,6 +17,11 @@ pub struct Config { pub item: T, } +// TODO: this is honestly probably dumb and needs refactored +impl Config { + pub fn inner(&self) -> &T { &self.item } +} + // #[derive(Debug, Clone, Serialize, Deserialize)] // pub struct FieldDefinition { // /// Type of the field (should be a JSON Value type) @@ -37,33 +42,6 @@ pub struct Config { // pub type_definition: Option, // } -#[derive(Debug, Clone, Serialize)] -pub struct ResourceTemplate { - /// Field definitions with optional metadata - #[serde(default)] - // #[serde(flatten)] - pub fields: Vec, -} - -impl<'de> Deserialize<'de> for ResourceTemplate { - fn deserialize(deserializer: D) -> std::result::Result - where D: serde::Deserializer<'de> { - // First deserialize into a map - let map: BTreeMap = BTreeMap::deserialize(deserializer)?; - - // Convert the map into a Vec, setting the name from the key - let fields = map - .into_iter() - .map(|(key, mut field_def)| { - field_def.name = key; - 
field_def - }) - .collect(); - - Ok(ResourceTemplate { fields }) - } -} - // TODO: These two traits can probably be removed pub trait Identifiable { fn name(&self) -> String; @@ -166,6 +144,8 @@ impl ConfigurationManager { #[cfg(test)] mod tests { + use resource::ResourceTemplate; + use super::*; #[test] diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index f300460..e13f5b8 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -153,7 +153,7 @@ use chrono::{DateTime, Utc}; use lazy_static::lazy_static; use regex::Regex; use reqwest::Url; -use resource::{Resource, ResourceConfig, Resources}; +use resource::{Resource, ResourceTemplate, Resources}; use serde::{Deserialize, Serialize}; use serde_json::Value; use tracing::{debug, trace, warn}; @@ -591,9 +591,9 @@ impl LearnerBuilder { database.set_storage_path(&config.storage_path).await?; todo!("This needs fixed now"); // let retriever = Retrievers::new().with_config_dir(&config.retrievers_path)?; - let resources = Resources::new().with_config_dir(&config.resources_path)?; + // let resources = Resources::new().with_config_dir(&config.resources_path)?; - Ok(Learner { config, database, retrievers: Retrievers::new(), resources }) + Ok(Learner { config, database, retrievers: Retrievers::new(), resources: Resources::new() }) } } diff --git a/crates/learner/src/record.rs b/crates/learner/src/record.rs index f1ddc2b..88f8f9e 100644 --- a/crates/learner/src/record.rs +++ b/crates/learner/src/record.rs @@ -1,5 +1,7 @@ use super::*; +// TODO: We probably want a `RecordTemplate` just like we have for `Resource`. Perhaps we can just +// have a general `Template` and make this like a "template engine" // TODO: Might want to put `Config`, etc. 
#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Record { diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index 4fe497e..b066a15 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -14,33 +14,37 @@ pub type Resource = BTreeMap; #[derive(Debug, Clone, Default)] pub struct Resources { - configs: BTreeMap, + templates: BTreeMap, } impl Resources { pub fn new() -> Self { Self::default() } } -impl Configurable for Resources { - type Config = ResourceConfig; - - fn as_map(&mut self) -> &mut BTreeMap { &mut self.configs } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ResourceConfig { - /// The type identifier for this resource - pub name: String, - /// Optional description of this resource type - #[serde(default)] - pub description: Option, +#[derive(Debug, Clone, Serialize)] +pub struct ResourceTemplate { /// Field definitions with optional metadata #[serde(default)] - pub fields: Vec, + // #[serde(flatten)] + pub fields: Vec, } - -impl Identifiable for ResourceConfig { - fn name(&self) -> String { self.name.clone() } +impl<'de> Deserialize<'de> for ResourceTemplate { + fn deserialize(deserializer: D) -> std::result::Result + where D: serde::Deserializer<'de> { + // First deserialize into a map + let map: BTreeMap = BTreeMap::deserialize(deserializer)?; + + // Convert the map into a Vec, setting the name from the key + let fields = map + .into_iter() + .map(|(key, mut field_def)| { + field_def.name = key; + field_def + }) + .collect(); + + Ok(ResourceTemplate { fields }) + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -96,7 +100,7 @@ pub struct ValidationRules { pub datetime: Option, // Validates RFC3339 format } -impl ResourceConfig { +impl ResourceTemplate { /// Validates a set of values against this resource configuration pub fn validate(&self, resource: &Resource) -> Result { // Check required fields @@ -302,7 +306,7 @@ mod tests { #[test] fn 
validate_paper_configuration() { let config = include_str!("../../config/resources/paper.toml"); - let config: ResourceConfig = toml::from_str(config).unwrap(); + let config: ResourceTemplate = toml::from_str(config).unwrap(); let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); @@ -333,7 +337,7 @@ mod tests { #[test] fn validate_book_configuration() { let config = include_str!("../../config/resources/book.toml"); - let config: ResourceConfig = toml::from_str(config).unwrap(); + let config: ResourceTemplate = toml::from_str(config).unwrap(); let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); @@ -351,7 +355,7 @@ mod tests { #[test] fn validate_thesis_configuration() { let config = include_str!("../../config/resources/thesis.toml"); - let config: ResourceConfig = toml::from_str(config).unwrap(); + let config: ResourceTemplate = toml::from_str(config).unwrap(); let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); @@ -374,10 +378,8 @@ mod tests { #[test] fn test_datetime_validation() { - let mut config = ResourceConfig { - name: "test".into(), - description: None, - fields: vec![FieldDefinition { + let mut config = ResourceTemplate { + fields: vec![FieldDefinition { name: "timestamp".into(), field_type: "string".into(), required: true, diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 718e19c..f7875a9 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,7 +1,7 @@ // use resource::Resource; -use configuration::ResourceTemplate; use record::RetrievalData; +use resource::ResourceTemplate; use super::*; @@ -103,10 +103,10 @@ impl Retriever { resource.insert("source_identifier".into(), Value::String(identifier.to_string())); // Validate full resource against config - // self.resource.validate(&resource)?; - // Ok(resource) + self.resource.validate(&resource)?; + 
Ok(resource) - todo!() + // todo!() // Ok(Record { // resource, diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index 91cc36b..d6bbc0a 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -1,5 +1,4 @@ -use configuration::ResourceTemplate; -use resource::FieldDefinition; +use resource::{FieldDefinition, ResourceTemplate}; use serde_json::Map; use super::*; diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index 11d9f93..f0faba0 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -1,26 +1,32 @@ use std::fs; -use learner::resource::ResourceConfig; +use learner::{ + configuration::{Config, ConfigurationManager}, + resource::ResourceTemplate, +}; use super::*; #[traced_test] #[tokio::test] async fn test_arxiv_retriever_integration() -> TestResult<()> { - let ret_config_str = fs::read_to_string("config/retrievers/arxiv.toml").expect( - "Failed to read config - file", - ); - let res_config_str = fs::read_to_string("config/resources/paper.toml").expect( - "Failed to read config - file", - ); + // let ret_config_str = fs::read_to_string("config/retrievers/arxiv.toml").expect( + // "Failed to read config + // file", + // ); + // let res_config_str = fs::read_to_string("config/resources/paper.toml").expect( + // "Failed to read config + // file", + // ); + let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); + let retriever: Config = manager.load_config("config_new/arxiv.toml")?; - let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); - let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); + // let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); + // let resource: ResourceTemplate = 
toml::from_str(&res_config_str).expect("Failed to parse + // config"); // Test with a real arXiv paper - let paper = retriever.retrieve_resource("2301.07041", &resource).await?; + let paper = retriever.inner().retrieve_resource("2301.07041").await?; dbg!(&paper); // assert!(resource.validate(&paper)?); @@ -75,10 +81,11 @@ async fn test_iacr_retriever_integration() -> TestResult<()> { ); let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); - let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); + // let resource: ResourceTemplate = toml::from_str(&res_config_str).expect("Failed to parse + // config"); // // Test with a real IACR paper - let paper = retriever.retrieve_resource("2016/260", &resource).await.unwrap(); + let paper = retriever.retrieve_resource("2016/260").await.unwrap(); // assert!(resource.validate(&paper)?); // TODO: validation already happens internally, to be fair // that validation may not be working totally right dbg!(&paper); @@ -128,10 +135,10 @@ async fn test_doi_retriever_integration() -> TestResult<()> { ); let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); - let resource: ResourceConfig = toml::from_str(&res_config_str).expect("Failed to parse config"); + let resource: ResourceTemplate = toml::from_str(&res_config_str).expect("Failed to parse config"); // Test with a real DOI paper - let paper = retriever.retrieve_resource("10.1145/1327452.1327492", &resource).await?; + let paper = retriever.retrieve_resource("10.1145/1327452.1327492").await?; // assert!(resource.validate(&paper)?); dbg!(&paper); // assert!(!paper.title.is_empty()); From 9dd4fee1f76894aadf6f267912ac1683a46227c6 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 7 Dec 2024 18:14:42 -0700 Subject: [PATCH 41/73] passing integration tests again --- crates/learner/config_new/arxiv.toml | 1 - crates/learner/config_new/doi.toml | 59 +++++++++++++++++++ 
crates/learner/config_new/iacr.toml | 39 ++++++++++++ crates/learner/config_new/paper.toml | 8 +-- .../tests/workflows/paper_retrieval.rs | 31 +++------- 5 files changed, 109 insertions(+), 29 deletions(-) create mode 100644 crates/learner/config_new/doi.toml create mode 100644 crates/learner/config_new/iacr.toml diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml index 5959f54..ef4d446 100644 --- a/crates/learner/config_new/arxiv.toml +++ b/crates/learner/config_new/arxiv.toml @@ -11,7 +11,6 @@ source = "arxiv" response_format = { type = "xml", strip_namespaces = true } -# # TODO: Could flatten out the `field_maps`? [resource_mappings.title] path = "feed/entry/title" diff --git a/crates/learner/config_new/doi.toml b/crates/learner/config_new/doi.toml new file mode 100644 index 0000000..7edc052 --- /dev/null +++ b/crates/learner/config_new/doi.toml @@ -0,0 +1,59 @@ +name = "doi" + +description = "Retriever template for getting a paper from DOI/Crossref" + +base_url = "https://api.crossref.org/works" +endpoint_template = "https://api.crossref.org/works/{identifier}" + +pattern = "(?:^|https?://doi\\.org/)(10\\.\\d{4,9}/[-._;()/:\\w]+)$" +resource = "paper" +source = "doi" + +response_format = { type = "json" } + +[resource_mappings.title] +path = "message" +[resource_mappings.title.transform] +sources = [ + { type = "path", value = "title/0" }, + { type = "path", value = "subtitle/0" }, +] +type = "Compose" + +[resource_mappings.title.transform.format] +delimiter = ": " +type = "Join" + + +[resource_mappings.abstract] +path = "message/abstract" + +[resource_mappings.abstract.transform] +pattern = "<[^>]+>" +replacement = "" +type = "Replace" + +[resource_mappings.authors] +path = "message/author" +[resource_mappings.authors.transform] +sources = [ + { type = "key_value", value = { key = "family", path = "family" } }, + { type = "key_value", value = { key = "given", path = "given" } }, +] +type = "Compose" + 
+[resource_mappings.authors.transform.format] +template = { name = "{given} {family}" } +type = "ArrayOfObjects" + +[resource_mappings.publication_dates] +path = "message/created/date-time" + +[resource_mappings.pdf_url] +path = "message/link/0/URL" + +[resource_mappings.doi] +path = "message/DOI" + +[headers] +Accept = "application/json" diff --git a/crates/learner/config_new/iacr.toml b/crates/learner/config_new/iacr.toml new file mode 100644 index 0000000..bb8bb0a --- /dev/null +++ b/crates/learner/config_new/iacr.toml @@ -0,0 +1,39 @@ +name = "iacr" + +description = "Retriever template for getting a paper from IACR" + +base_url = "https://eprint.iacr.org" +endpoint_template = "https://eprint.iacr.org/oai?verb=GetRecord&identifier=oai:eprint.iacr.org:{identifier}&metadataPrefix=oai_dc" +pattern = "(?:^|https?://eprint\\.iacr\\.org/)(\\d{4}/\\d+)(?:\\.pdf)?$" +resource = "paper" +source = "iacr" + +response_format = { type = "xml", strip_namespaces = true } + +[resource_mappings.title] +path = "OAI-PMH/GetRecord/record/metadata/dc/title" + +[resource_mappings.abstract] +path = "OAI-PMH/GetRecord/record/metadata/dc/description" + +[resource_mappings.authors] +path = "OAI-PMH/GetRecord/record/metadata/dc" +[resource_mappings.authors.transform] +sources = [{ type = "key_value", value = { key = "name", path = "creator" } }] +type = "Compose" +[resource_mappings.authors.transform.format] +type = "Object" + +[resource_mappings.publication_dates] +path = "OAI-PMH/GetRecord/record/metadata/dc/date" + +[resource_mappings.pdf_url] +path = "OAI-PMH/GetRecord/record/metadata/dc/identifier" + +[resource_mappings.pdf_url.transform] +pattern = "^(https://eprint\\.iacr\\.org/\\d{4}/\\d+)$" +replacement = "$1.pdf" +type = "Replace" + +[headers] +Accept = "application/xml" diff --git a/crates/learner/config_new/paper.toml b/crates/learner/config_new/paper.toml index af838a9..e4b298f 100644 --- a/crates/learner/config_new/paper.toml +++ b/crates/learner/config_new/paper.toml @@ 
-1,7 +1,7 @@ description = "Configuration for a paper" name = "paper" -abstract_text = { field_type = "string", required = false } -authors = { field_type = "array", required = true, validation = { min_items = 1 } } -publication_date = { field_type = "string", required = true, validation = { datetime = true } } -title = { field_type = "string", required = true, validation = { min_length = 1, max_length = 500 } } +abstract_text = { field_type = "string", required = false } +authors = { field_type = "array", required = true, validation = { min_items = 1 } } +publication_dates = { field_type = "array", required = true, validation = { datetime = true } } +title = { field_type = "string", required = true, validation = { min_length = 1, max_length = 500 } } diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index f0faba0..eb588ba 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -71,26 +71,18 @@ async fn test_arxiv_pdf_from_paper() -> TestResult<()> { #[traced_test] #[tokio::test] async fn test_iacr_retriever_integration() -> TestResult<()> { - let ret_config_str = fs::read_to_string("config/retrievers/iacr.toml").expect( - "Failed to read config - file", - ); - let res_config_str = fs::read_to_string("config/resources/paper.toml").expect( - "Failed to read config - file", - ); - - let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); + let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); + let retriever: Config = manager.load_config("config_new/iacr.toml")?; // let resource: ResourceTemplate = toml::from_str(&res_config_str).expect("Failed to parse // config"); // // Test with a real IACR paper - let paper = retriever.retrieve_resource("2016/260").await.unwrap(); + let paper = retriever.inner().retrieve_resource("2016/260").await.unwrap(); // 
assert!(resource.validate(&paper)?); // TODO: validation already happens internally, to be fair // that validation may not be working totally right dbg!(&paper); - todo!("This isn't actually validating properly because right now the authors isn't right."); + todo!("This needs cleaned up."); // assert!(!paper.title.is_empty()); // assert!(!paper.authors.is_empty()); // assert!(!paper.abstract_text.is_empty()); @@ -125,20 +117,11 @@ async fn test_iacr_pdf_from_paper() -> TestResult<()> { #[tokio::test] #[traced_test] async fn test_doi_retriever_integration() -> TestResult<()> { - let ret_config_str = fs::read_to_string("config/retrievers/doi.toml").expect( - "Failed to read config - file", - ); - let res_config_str = fs::read_to_string("config/resources/paper.toml").expect( - "Failed to read config - file", - ); - - let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); - let resource: ResourceTemplate = toml::from_str(&res_config_str).expect("Failed to parse config"); + let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); + let retriever: Config = manager.load_config("config_new/doi.toml")?; // Test with a real DOI paper - let paper = retriever.retrieve_resource("10.1145/1327452.1327492").await?; + let paper = retriever.inner().retrieve_resource("10.1145/1327452.1327492").await?; // assert!(resource.validate(&paper)?); dbg!(&paper); // assert!(!paper.title.is_empty()); From a32614ea2c6956a7ca09d66a54eb48957fba835d Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 7 Dec 2024 18:37:42 -0700 Subject: [PATCH 42/73] refactor --- crates/learner/src/configuration.rs | 116 +-- crates/learner/src/database/mod.rs | 2 +- crates/learner/src/lib.rs | 49 +- crates/learner/src/record.rs | 2 + crates/learner/src/resource/mod.rs | 401 +-------- crates/learner/src/resource/paper.rs | 2 + crates/learner/src/resource/shared.rs | 24 - crates/learner/src/retriever/config.rs | 6 +- crates/learner/src/retriever/mod.rs | 24 +- 
crates/learner/src/retriever/response/mod.rs | 4 +- crates/learner/src/template.rs | 399 +++++++++ crates/learner/tests/lib.rs | 2 +- crates/learner/tests/llm/mod.rs | 52 +- .../workflows/database_operations/add.rs | 432 +++++----- .../workflows/database_operations/remove.rs | 786 +++++++++--------- .../tests/workflows/paper_retrieval.rs | 17 +- 16 files changed, 1073 insertions(+), 1245 deletions(-) create mode 100644 crates/learner/src/template.rs diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index 949ea9f..8a4796f 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -1,90 +1,7 @@ -use resource::{FieldDefinition, TypeDefinition, ValidationRules}; use serde::de::DeserializeOwned; use super::*; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Config { - /// Name of this configuration - pub name: String, - /// Optional description - #[serde(default)] - pub description: Option, - #[serde(default)] - pub additional_fields: BTreeMap, - /// The specific configuration type - #[serde(flatten)] - pub item: T, -} - -// TODO: this is honestly probably dumb and needs refactored -impl Config { - pub fn inner(&self) -> &T { &self.item } -} - -// #[derive(Debug, Clone, Serialize, Deserialize)] -// pub struct FieldDefinition { -// /// Type of the field (should be a JSON Value type) -// pub field_type: String, -// /// Whether this field must be present -// #[serde(default)] -// pub required: bool, -// /// Human-readable description -// #[serde(default)] -// pub description: Option, -// /// Default value if field is absent -// #[serde(default)] -// pub default: Option, -// /// Optional validation rules -// #[serde(default)] -// pub validation: Option, - -// pub type_definition: Option, -// } - -// TODO: These two traits can probably be removed -pub trait Identifiable { - fn name(&self) -> String; -} - -pub trait Configurable: Sized { - type Config: Identifiable + for<'de> Deserialize<'de>; 
- fn as_map(&mut self) -> &mut BTreeMap; - - fn with_config(mut self, config: Self::Config) { self.as_map().insert(config.name(), config); } - - fn with_config_str(mut self, toml_str: &str) -> Result { - let config: Self::Config = toml::from_str(toml_str)?; - self.as_map().insert(config.name(), config); - Ok(self) - } - - fn with_config_file(self, path: impl AsRef) -> Result { - let content = std::fs::read_to_string(path)?; - self.with_config_str(&content) - } - - fn with_config_dir(self, dir: impl AsRef) -> Result { - let dir = dir.as_ref(); - if !dir.is_dir() { - return Err(LearnerError::Path(std::io::Error::new( - std::io::ErrorKind::NotFound, - "Config directory not found", - ))); - } - - let mut configurable = self; - for entry in std::fs::read_dir(dir)? { - let entry = entry?; - let path = entry.path(); - if path.extension().is_some_and(|ext| ext == "toml") { - configurable = configurable.with_config_file(path)?; - } - } - Ok(configurable) - } -} - pub struct ConfigurationManager { builder: config::ConfigBuilder, loaded_configs: BTreeMap, @@ -101,7 +18,7 @@ impl ConfigurationManager { } } - pub fn load_config(&mut self, path: impl AsRef) -> Result> + pub fn load_config(&mut self, path: impl AsRef) -> Result where T: DeserializeOwned + std::fmt::Debug { let path = path.as_ref(); let content = std::fs::read_to_string(path)?; @@ -111,40 +28,26 @@ impl ConfigurationManager { if std::any::type_name::() == std::any::type_name::() { if let Some(toml::Value::String(resource_name)) = raw_config.get("resource") { // Load the referenced resource - let resource_path = self.config_paths.join(format!("{}.toml", resource_name)); + let resource_path = self.config_paths.join(format!("{resource_name}.toml")); let resource_content = std::fs::read_to_string(resource_path)?; let resource_config: toml::Value = toml::from_str(&resource_content)?; - // Get just the fields we need (ignore name, description etc) - let resource_fields = resource_config - .as_table() - .and_then(|t| { - 
Some( - t.iter() - .filter(|(k, _)| !["name", "description"].contains(&k.as_str())) - .map(|(k, v)| (k.clone(), v.clone())) - .collect::>(), - ) - }) - .ok_or_else(|| config::ConfigError::Message("Invalid resource config structure".into())) - .unwrap(); - - // Replace the string reference with the resource fields + // Replace the string reference with the resource config if let Some(table) = raw_config.as_table_mut() { - table.insert("resource".into(), toml::Value::Table(resource_fields)); + table.insert("resource".into(), resource_config); } } } // Convert directly to final type - let typed_config: Config = raw_config.try_into()?; + let typed_config: T = raw_config.try_into()?; Ok(typed_config) } } #[cfg(test)] mod tests { - use resource::ResourceTemplate; + use template::Template; use super::*; @@ -153,12 +56,11 @@ mod tests { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); // Load configurations in order - let paper: Config = - dbg!(manager.load_config("config_new/paper.toml").unwrap()); + let paper: Template = dbg!(manager.load_config("config_new/paper.toml").unwrap()); - let arxiv_retriever: Config = - dbg!(manager.load_config("config_new/arxiv.toml").unwrap()); + let arxiv_retriever: Retriever = dbg!(manager.load_config("config_new/arxiv.toml").unwrap()); + todo!("Clean this up") // The paper_record now has all fields from base_resource and paper, // plus its own record-specific configuration diff --git a/crates/learner/src/database/mod.rs b/crates/learner/src/database/mod.rs index d5cd6d5..99d0aa6 100644 --- a/crates/learner/src/database/mod.rs +++ b/crates/learner/src/database/mod.rs @@ -52,10 +52,10 @@ //! # } //! 
``` +use resource::{paper::Paper, shared::Author}; use tokio_rusqlite::Connection; use super::*; - mod instruction; #[cfg(test)] mod tests; diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index e13f5b8..417b7f2 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -153,7 +153,6 @@ use chrono::{DateTime, Utc}; use lazy_static::lazy_static; use regex::Regex; use reqwest::Url; -use resource::{Resource, ResourceTemplate, Resources}; use serde::{Deserialize, Serialize}; use serde_json::Value; use tracing::{debug, trace, warn}; @@ -171,14 +170,9 @@ pub mod llm; pub mod pdf; pub mod record; pub mod resource; +pub mod template; -use crate::{ - database::*, - error::*, - prelude::*, - resource::{Author, Paper}, - retriever::*, -}; +use crate::{database::*, error::*, prelude::*, retriever::*}; /// ArXiv default configuration pub const ARXIV_CONFIG: &str = include_str!("../config/retrievers/arxiv.toml"); @@ -223,11 +217,7 @@ pub const THESIS_CONFIG: &str = include_str!("../config/resources/thesis.toml"); /// } /// ``` pub mod prelude { - pub use crate::{ - configuration::{Configurable, Identifiable}, - database::DatabaseInstruction, - error::LearnerError, - }; + pub use crate::{database::DatabaseInstruction, error::LearnerError}; } /// Core configuration for the library. @@ -298,8 +288,8 @@ pub struct Learner { pub database: Database, /// Paper retrieval system pub retrievers: Retrievers, - /// Resources to use - pub resources: Resources, + // / Resources to use + // pub resources: Resources, } /// Builder for creating configured Learner instances. 
@@ -593,7 +583,7 @@ impl LearnerBuilder { // let retriever = Retrievers::new().with_config_dir(&config.retrievers_path)?; // let resources = Resources::new().with_config_dir(&config.resources_path)?; - Ok(Learner { config, database, retrievers: Retrievers::new(), resources: Resources::new() }) + Ok(Learner { config, database, retrievers: Retrievers::new() }) } } @@ -745,33 +735,6 @@ impl Learner { /// # } /// ``` pub async fn init() -> Result { Self::with_config(Config::init()?).await } - - pub async fn retreive(&mut self, input: &str) -> Result { - // let mut matches = Vec::new(); - - // // Find all configs that match the input - // for (name, config) in self.retrievers.as_map().iter() { - // if config.pattern.is_match(input) { - // matches.push((name, config)); - // } - // } - - todo!("Finish this") - // match matches.len() { - // 0 => Err(LearnerError::InvalidIdentifier), - // 1 => { - // let resource_config = self.resources.as_map().get(matches[0].0); - // if let Some(resource_config) = resource_config { - // Ok(matches[0].1.retrieve_resource(input, resource_config).await?) - // } else { - // todo!("Error because that resource wasn't available.") - // } - // }, - // _ => Err(LearnerError::AmbiguousIdentifier( - // matches.into_iter().map(|(n, c)| n.to_string()).collect(), - // )), - // } - } } #[cfg(test)] diff --git a/crates/learner/src/record.rs b/crates/learner/src/record.rs index 88f8f9e..4c8efab 100644 --- a/crates/learner/src/record.rs +++ b/crates/learner/src/record.rs @@ -1,3 +1,5 @@ +use template::Resource; + use super::*; // TODO: We probably want a `RecordTemplate` just like we have for `Resource`. 
Perhaps we can just diff --git a/crates/learner/src/resource/mod.rs b/crates/learner/src/resource/mod.rs index b066a15..e6a7f4a 100644 --- a/crates/learner/src/resource/mod.rs +++ b/crates/learner/src/resource/mod.rs @@ -1,401 +1,6 @@ -use std::{collections::HashSet, str::FromStr}; +// temp mod to keep database happy use super::*; -mod paper; -mod shared; - -pub use paper::*; -use serde_json::Value; -pub use shared::*; - -// Type alias for clarity and consistency -pub type Resource = BTreeMap; - -#[derive(Debug, Clone, Default)] -pub struct Resources { - templates: BTreeMap, -} - -impl Resources { - pub fn new() -> Self { Self::default() } -} - -#[derive(Debug, Clone, Serialize)] -pub struct ResourceTemplate { - /// Field definitions with optional metadata - #[serde(default)] - // #[serde(flatten)] - pub fields: Vec, -} -impl<'de> Deserialize<'de> for ResourceTemplate { - fn deserialize(deserializer: D) -> std::result::Result - where D: serde::Deserializer<'de> { - // First deserialize into a map - let map: BTreeMap = BTreeMap::deserialize(deserializer)?; - - // Convert the map into a Vec, setting the name from the key - let fields = map - .into_iter() - .map(|(key, mut field_def)| { - field_def.name = key; - field_def - }) - .collect(); - - Ok(ResourceTemplate { fields }) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FieldDefinition { - /// Name of the field - #[serde(skip_deserializing)] - pub name: String, - /// Type of the field (should be a JSON Value type) - pub field_type: String, - /// Whether this field must be present - #[serde(default)] - pub required: bool, - /// Human-readable description - #[serde(default)] - pub description: Option, - /// Default value if field is absent - #[serde(default)] - pub default: Option, - /// Optional validation rules - #[serde(default)] - pub validation: Option, - - pub type_definition: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TypeDefinition { - // For array types, 
defines the structure of elements - pub element_type: Option>, - // For table types, defines the fields - pub fields: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct ValidationRules { - // String validations - pub pattern: Option, // Regex pattern to match - pub min_length: Option, // Minimum string length - pub max_length: Option, // Maximum string length - - // Numeric validations - pub minimum: Option, // Minimum value - pub maximum: Option, // Maximum value - pub multiple_of: Option, // Must be multiple of this value - - // Array validations - pub min_items: Option, // Minimum array length - pub max_items: Option, // Maximum array length - pub unique_items: Option, // Whether items must be unique - - // General validations - pub enum_values: Option>, // List of allowed values - pub datetime: Option, // Validates RFC3339 format -} - -impl ResourceTemplate { - /// Validates a set of values against this resource configuration - pub fn validate(&self, resource: &Resource) -> Result { - // Check required fields - for field in &self.fields { - if field.required && !resource.contains_key(&field.name) { - return Err(LearnerError::InvalidResource(format!( - "Missing required field: {}", - field.name - ))); - } - } - - // Validate each provided field - for (name, value) in resource { - if let Some(field) = self.fields.iter().find(|f| f.name == *name) { - // Validate field value against its definition - self.validate_field(field, value)?; - } - } - - Ok(true) - } - - /// Validates a single field value against its definition - fn validate_field(&self, field: &FieldDefinition, value: &Value) -> Result<()> { - match (field.field_type.as_str(), value) { - // String validation - handles both basic type checking and string-specific rules - ("string", Value::String(v)) => { - if let Some(rules) = &field.validation { - // Length constraints - if let Some(min_length) = rules.min_length { - if v.len() < min_length { - return 
Err(LearnerError::InvalidResource(format!( - "Field '{}' must be at least {} characters", - field.name, min_length - ))); - } - } - if let Some(max_length) = rules.max_length { - if v.len() > max_length { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' cannot exceed {} characters", - field.name, max_length - ))); - } - } - - // Pattern matching via regex - if let Some(pattern) = &rules.pattern { - let re = Regex::new(pattern) - .map_err(|_| LearnerError::InvalidResource("Invalid regex pattern".into()))?; - if !re.is_match(v) { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' must match pattern: {}", - field.name, pattern - ))); - } - } - - // Datetime validation if specified - if rules.datetime == Some(true) { - if DateTime::parse_from_rfc3339(v).is_err() { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' must be a valid RFC3339 datetime", - field.name - ))); - } - } - - // Enumerated values check - if let Some(allowed) = &rules.enum_values { - if !allowed.contains(v) { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' must be one of: {:?}", - field.name, allowed - ))); - } - } - } - Ok(()) - }, - - // Numeric validations - handle both number types - ("number", Value::Number(n)) => { - if let Some(rules) = &field.validation { - if let Some(num) = n.as_f64() { - validate_numeric(field, num, rules)?; - } - } - Ok(()) - }, - - // Array validation - handles array-specific rules - ("array", Value::Array(v)) => { - if let Some(rules) = &field.validation { - if let Some(min_items) = rules.min_items { - if v.len() < min_items { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' must have at least {} items", - field.name, min_items - ))); - } - } - - if let Some(max_items) = rules.max_items { - if v.len() > max_items { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' cannot exceed {} items", - field.name, max_items - ))); - } - } - - if rules.unique_items == Some(true) { - 
let mut seen = HashSet::new(); - for item in v { - let item_str = serde_json::to_string(item).map_err(|_| { - LearnerError::InvalidResource("Failed to serialize array item".into()) - })?; - if !seen.insert(item_str) { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' contains duplicate items", - field.name - ))); - } - } - } - } - Ok(()) - }, - - // Simple type validations - just ensure type matches - ("boolean", Value::Bool(_)) => Ok(()), - ("object", Value::Object(_)) => Ok(()), - ("null", Value::Null) => Ok(()), - - // Type mismatch - provide a clear error message - _ => Err(LearnerError::InvalidResource(format!( - "Field '{}' expected type '{}' but got '{}'", - field.name, - field.field_type, - match value { - Value::String(_) => "string", - Value::Number(_) => "number", - Value::Bool(_) => "boolean", - Value::Array(_) => "array", - Value::Object(_) => "object", - Value::Null => "null", - } - ))), - } - } -} - -fn validate_numeric(field: &FieldDefinition, value: f64, rules: &ValidationRules) -> Result<()> { - if let Some(min) = rules.minimum { - if value < min { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' must be at least {}", - field.name, min - ))); - } - } - - if let Some(max) = rules.maximum { - if value > max { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' cannot exceed {}", - field.name, max - ))); - } - } - - if let Some(multiple) = rules.multiple_of { - let ratio = value / multiple; - if (ratio - ratio.round()).abs() > f64::EPSILON { - return Err(LearnerError::InvalidResource(format!( - "Field '{}' must be a multiple of {}", - field.name, multiple - ))); - } - } - - Ok(()) -} - -/// Convert DateTime to RFC3339 string for JSON storage -pub fn datetime_to_json(dt: DateTime) -> String { dt.to_rfc3339() } - -/// Parse RFC3339 string from JSON into DateTime -pub fn datetime_from_json(s: &str) -> Result> { - DateTime::parse_from_rfc3339(s) - .map(|dt| dt.with_timezone(&Utc)) - .map_err(|e| 
LearnerError::InvalidResource(format!("Invalid datetime format: {}", e))) -} -#[cfg(test)] -mod tests { - use chrono::TimeZone; - use serde_json::json; - - use super::*; - - #[test] - fn validate_paper_configuration() { - let config = include_str!("../../config/resources/paper.toml"); - let config: ResourceTemplate = toml::from_str(config).unwrap(); - - let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - - // Create a valid paper resource - let paper_resource = BTreeMap::from([ - ("title".into(), json!("Understanding Quantum Computing")), - ( - "authors".into(), - json!([{ - "name": "Alice Researcher", - "affiliation": "Tech University" - }]), - ), - ("publication_date".into(), json!(date)), - ("doi".into(), json!("10.1234/example.123")), - ]); - - // Validate the paper - assert!(config.validate(&paper_resource).unwrap()); - - // Test required field validation - let invalid_paper = BTreeMap::from([ - ("authors".into(), json!([])), // Missing title - ]); - assert!(config.validate(&invalid_paper).is_err()); - } - - #[test] - fn validate_book_configuration() { - let config = include_str!("../../config/resources/book.toml"); - let config: ResourceTemplate = toml::from_str(config).unwrap(); - - let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - - let book_resource = BTreeMap::from([ - ("title".into(), json!("Advanced Quantum Computing")), - ("authors".into(), json!(["Alice Writer", "Bob Author"])), - ("isbn".into(), json!("978-0-12-345678-9")), - ("publisher".into(), json!("Academic Press")), - ("publication_date".into(), json!(date)), - ]); - - assert!(config.validate(&book_resource).unwrap()); - } - - #[test] - fn validate_thesis_configuration() { - let config = include_str!("../../config/resources/thesis.toml"); - let config: ResourceTemplate = toml::from_str(config).unwrap(); - - let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - - let 
thesis_resource = BTreeMap::from([ - ("title".into(), json!("Novel Approaches to Quantum Error Correction")), - ("author".into(), json!("Alice Researcher")), - ("degree".into(), json!("PhD")), - ("institution".into(), json!("Tech University")), - ("completion_date".into(), json!(date)), - ("advisors".into(), json!(["Prof. Bob Supervisor"])), - ]); - - assert!(config.validate(&thesis_resource).unwrap()); - - // Test degree enum validation - let mut invalid_thesis = thesis_resource.clone(); - invalid_thesis.insert("degree".into(), json!("InvalidDegree")); - assert!(config.validate(&invalid_thesis).is_err()); - } - - #[test] - fn test_datetime_validation() { - let mut config = ResourceTemplate { - fields: vec![FieldDefinition { - name: "timestamp".into(), - field_type: "string".into(), - required: true, - description: None, - default: None, - validation: Some(ValidationRules { datetime: Some(true), ..Default::default() }), - type_definition: None, - }], - }; - - let valid_resource = BTreeMap::from([("timestamp".into(), json!("2024-01-01T00:00:00Z"))]); - assert!(config.validate(&valid_resource).unwrap()); - - let invalid_resource = BTreeMap::from([ - ("timestamp".into(), json!("2024-01-01")), // Not RFC3339 - ]); - assert!(config.validate(&invalid_resource).is_err()); - } -} +pub mod paper; +pub mod shared; diff --git a/crates/learner/src/resource/paper.rs b/crates/learner/src/resource/paper.rs index 744cc6f..7742713 100644 --- a/crates/learner/src/resource/paper.rs +++ b/crates/learner/src/resource/paper.rs @@ -42,6 +42,8 @@ //! # } //! ``` +use shared::Author; + use super::*; /// Complete representation of an academic paper with metadata. diff --git a/crates/learner/src/resource/shared.rs b/crates/learner/src/resource/shared.rs index 6410120..6a4163b 100644 --- a/crates/learner/src/resource/shared.rs +++ b/crates/learner/src/resource/shared.rs @@ -1,29 +1,5 @@ -//! Shared types for various academic resources. -//! -//! 
This module contains common data structures used across different types of -//! resources in the learner library. These shared types help maintain consistency -//! and reduce duplication in how we represent common academic concepts like -//! authorship, publication details, and citations. - use super::*; -/// Author information for academic papers. -/// -/// Represents a single author of a paper, including their name and optional -/// institutional details. This struct supports varying levels of author -/// information availability across different sources. -/// -/// # Examples -/// -/// ``` -/// use learner::resource::Author; -/// -/// let author = Author { -/// name: "Alice Researcher".to_string(), -/// affiliation: Some("Example University".to_string()), -/// email: Some("alice@example.edu".to_string()), -/// }; -/// ``` #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Author { /// Author's full name diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index f7875a9..1c53ba0 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,7 +1,7 @@ // use resource::Resource; use record::RetrievalData; -use resource::ResourceTemplate; +use template::{Resource, Template}; use super::*; @@ -9,7 +9,9 @@ use super::*; #[derive(Debug, Clone, Deserialize)] pub struct Retriever { - pub resource: ResourceTemplate, + pub name: String, + pub description: Option, + pub resource: Template, #[serde(skip_deserializing)] #[serde(default)] pub retrieval_data: RetrievalData, diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 5a7cdc8..a1959c7 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -5,6 +5,7 @@ mod response; pub use config::*; pub use response::*; +use template::Resource; // TODO: This should be `BTreeMap>` #[derive(Default, Debug, Clone)] @@ -55,28 +56,7 @@ impl Retrievers { /// ``` pub 
fn new() -> Self { Self::default() } - #[deprecated] - pub async fn get_paper(&self, input: &str) -> Result { - let mut matches = Vec::new(); - - // Find all configs that match the input - for config in self.configs.values() { - if config.pattern.is_match(input) { - matches.push(config); - } - } - - todo!("Fix this") - // match matches.len() { - // 0 => Err(LearnerError::InvalidIdentifier), - // 1 => matches[0].retrieve_paper(input).await, - // _ => Err(LearnerError::AmbiguousIdentifier( - // matches.into_iter().map(|c| c.name.clone()).collect(), - // )), - // } - } - - pub async fn get_resource(&self, input: &str) -> Result { + pub async fn get_resource(&self, input: &str) -> Result { todo!( "Arguably, we don't even need this. We could instead just have this handled by `Learner` so \ the API is simpler" diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index d6bbc0a..212459a 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -1,5 +1,5 @@ -use resource::{FieldDefinition, ResourceTemplate}; use serde_json::Map; +use template::{FieldDefinition, Template}; use super::*; @@ -145,7 +145,7 @@ pub enum ComposeFormat { pub fn process_json_value( json: &Value, field_maps: &BTreeMap, - resource_config: &ResourceTemplate, + resource_config: &Template, ) -> Result { let mut resource = Resource::new(); diff --git a/crates/learner/src/template.rs b/crates/learner/src/template.rs new file mode 100644 index 0000000..1bfbc59 --- /dev/null +++ b/crates/learner/src/template.rs @@ -0,0 +1,399 @@ +use std::collections::HashSet; + +use super::*; + +// Type alias for clarity and consistency +pub type Resource = BTreeMap; + +#[derive(Debug, Clone, Serialize)] +pub struct Template { + pub name: String, + #[serde(default)] + pub description: Option, + #[serde(default)] + pub fields: Vec, +} +impl<'de> Deserialize<'de> for Template { + fn deserialize(deserializer: D) -> 
std::result::Result + where D: serde::Deserializer<'de> { + // Helper struct to capture the raw TOML structure + #[derive(Deserialize)] + struct TemplateHelper { + name: String, + #[serde(default)] + description: Option, + #[serde(flatten)] + fields: BTreeMap, + } + + // Deserialize into our helper first + let helper = TemplateHelper::deserialize(deserializer)?; + + // Convert the field map into a Vec, setting the name from the key + // Filter out the metadata fields we don't want to treat as FieldDefinitions + let fields = helper + .fields + .into_iter() + .filter(|(key, _)| key != "name" && key != "description") + .map(|(key, mut field_def)| { + field_def.name = key; + field_def + }) + .collect(); + + Ok(Template { name: helper.name, description: helper.description, fields }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FieldDefinition { + /// Name of the field + #[serde(skip_deserializing)] + pub name: String, + /// Type of the field (should be a JSON Value type) + pub field_type: String, + /// Whether this field must be present + #[serde(default)] + pub required: bool, + /// Human-readable description + #[serde(default)] + pub description: Option, + /// Default value if field is absent + #[serde(default)] + pub default: Option, + /// Optional validation rules + #[serde(default)] + pub validation: Option, + + pub type_definition: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypeDefinition { + // For array types, defines the structure of elements + pub element_type: Option>, + // For table types, defines the fields + pub fields: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ValidationRules { + // String validations + pub pattern: Option, // Regex pattern to match + pub min_length: Option, // Minimum string length + pub max_length: Option, // Maximum string length + + // Numeric validations + pub minimum: Option, // Minimum value + pub maximum: Option, // Maximum value + 
pub multiple_of: Option, // Must be multiple of this value + + // Array validations + pub min_items: Option, // Minimum array length + pub max_items: Option, // Maximum array length + pub unique_items: Option, // Whether items must be unique + + // General validations + pub enum_values: Option>, // List of allowed values + pub datetime: Option, // Validates RFC3339 format +} + +impl Template { + /// Validates a set of values against this resource configuration + pub fn validate(&self, resource: &Resource) -> Result { + // Check required fields + for field in &self.fields { + if field.required && !resource.contains_key(&field.name) { + return Err(LearnerError::InvalidResource(format!( + "Missing required field: {}", + field.name + ))); + } + } + + // Validate each provided field + for (name, value) in resource { + if let Some(field) = self.fields.iter().find(|f| f.name == *name) { + // Validate field value against its definition + self.validate_field(field, value)?; + } + } + + Ok(true) + } + + /// Validates a single field value against its definition + fn validate_field(&self, field: &FieldDefinition, value: &Value) -> Result<()> { + match (field.field_type.as_str(), value) { + // String validation - handles both basic type checking and string-specific rules + ("string", Value::String(v)) => { + if let Some(rules) = &field.validation { + // Length constraints + if let Some(min_length) = rules.min_length { + if v.len() < min_length { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be at least {} characters", + field.name, min_length + ))); + } + } + if let Some(max_length) = rules.max_length { + if v.len() > max_length { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' cannot exceed {} characters", + field.name, max_length + ))); + } + } + + // Pattern matching via regex + if let Some(pattern) = &rules.pattern { + let re = Regex::new(pattern) + .map_err(|_| LearnerError::InvalidResource("Invalid regex pattern".into()))?; + if 
!re.is_match(v) { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must match pattern: {}", + field.name, pattern + ))); + } + } + + // Datetime validation if specified + if rules.datetime == Some(true) && DateTime::parse_from_rfc3339(v).is_err() { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be a valid RFC3339 datetime", + field.name + ))); + } + + // Enumerated values check + if let Some(allowed) = &rules.enum_values { + if !allowed.contains(v) { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be one of: {:?}", + field.name, allowed + ))); + } + } + } + Ok(()) + }, + + // Numeric validations - handle both number types + ("number", Value::Number(n)) => { + if let Some(rules) = &field.validation { + if let Some(num) = n.as_f64() { + validate_numeric(field, num, rules)?; + } + } + Ok(()) + }, + + // Array validation - handles array-specific rules + ("array", Value::Array(v)) => { + if let Some(rules) = &field.validation { + if let Some(min_items) = rules.min_items { + if v.len() < min_items { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must have at least {} items", + field.name, min_items + ))); + } + } + + if let Some(max_items) = rules.max_items { + if v.len() > max_items { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' cannot exceed {} items", + field.name, max_items + ))); + } + } + + if rules.unique_items == Some(true) { + let mut seen = HashSet::new(); + for item in v { + let item_str = serde_json::to_string(item).map_err(|_| { + LearnerError::InvalidResource("Failed to serialize array item".into()) + })?; + if !seen.insert(item_str) { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' contains duplicate items", + field.name + ))); + } + } + } + } + Ok(()) + }, + + // Simple type validations - just ensure type matches + ("boolean", Value::Bool(_)) => Ok(()), + ("object", Value::Object(_)) => Ok(()), + ("null", Value::Null) => Ok(()), + + // 
Type mismatch - provide a clear error message + _ => Err(LearnerError::InvalidResource(format!( + "Field '{}' expected type '{}' but got '{}'", + field.name, + field.field_type, + match value { + Value::String(_) => "string", + Value::Number(_) => "number", + Value::Bool(_) => "boolean", + Value::Array(_) => "array", + Value::Object(_) => "object", + Value::Null => "null", + } + ))), + } + } +} + +fn validate_numeric(field: &FieldDefinition, value: f64, rules: &ValidationRules) -> Result<()> { + if let Some(min) = rules.minimum { + if value < min { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be at least {}", + field.name, min + ))); + } + } + + if let Some(max) = rules.maximum { + if value > max { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' cannot exceed {}", + field.name, max + ))); + } + } + + if let Some(multiple) = rules.multiple_of { + let ratio = value / multiple; + if (ratio - ratio.round()).abs() > f64::EPSILON { + return Err(LearnerError::InvalidResource(format!( + "Field '{}' must be a multiple of {}", + field.name, multiple + ))); + } + } + + Ok(()) +} + +/// Convert DateTime to RFC3339 string for JSON storage +pub fn datetime_to_json(dt: DateTime) -> String { dt.to_rfc3339() } + +/// Parse RFC3339 string from JSON into DateTime +pub fn datetime_from_json(s: &str) -> Result> { + DateTime::parse_from_rfc3339(s) + .map(|dt| dt.with_timezone(&Utc)) + .map_err(|e| LearnerError::InvalidResource(format!("Invalid datetime format: {}", e))) +} +#[cfg(test)] +mod tests { + use chrono::TimeZone; + use serde_json::json; + + use super::*; + + #[test] + fn validate_paper_configuration() { + let template = include_str!("../config/resources/paper.toml"); + let template: Template = toml::from_str(template).unwrap(); + + let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); + + // Create a valid paper resource + let paper_resource = BTreeMap::from([ + ("title".into(), 
json!("Understanding Quantum Computing")), + ( + "authors".into(), + json!([{ + "name": "Alice Researcher", + "affiliation": "Tech University" + }]), + ), + ("publication_date".into(), json!(date)), + ("doi".into(), json!("10.1234/example.123")), + ]); + + // Validate the paper + assert!(template.validate(&paper_resource).unwrap()); + + // Test required field validation + let invalid_paper = BTreeMap::from([ + ("authors".into(), json!([])), // Missing title + ]); + assert!(template.validate(&invalid_paper).is_err()); + } + + #[test] + fn validate_book_configuration() { + let template = include_str!("../config/resources/book.toml"); + let template: Template = toml::from_str(template).unwrap(); + + let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); + + let book_resource = BTreeMap::from([ + ("title".into(), json!("Advanced Quantum Computing")), + ("authors".into(), json!(["Alice Writer", "Bob Author"])), + ("isbn".into(), json!("978-0-12-345678-9")), + ("publisher".into(), json!("Academic Press")), + ("publication_date".into(), json!(date)), + ]); + + assert!(template.validate(&book_resource).unwrap()); + } + + #[test] + fn validate_thesis_configuration() { + let template = include_str!("../config/resources/thesis.toml"); + let template: Template = toml::from_str(template).unwrap(); + + let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); + + let thesis_resource = BTreeMap::from([ + ("title".into(), json!("Novel Approaches to Quantum Error Correction")), + ("author".into(), json!("Alice Researcher")), + ("degree".into(), json!("PhD")), + ("institution".into(), json!("Tech University")), + ("completion_date".into(), json!(date)), + ("advisors".into(), json!(["Prof. 
Bob Supervisor"])), + ]); + + assert!(template.validate(&thesis_resource).unwrap()); + + // Test degree enum validation + let mut invalid_thesis = thesis_resource.clone(); + invalid_thesis.insert("degree".into(), json!("InvalidDegree")); + assert!(template.validate(&invalid_thesis).is_err()); + } + + #[test] + fn test_datetime_validation() { + let template = Template { + name: "Test Template".to_string(), + description: None, + fields: vec![FieldDefinition { + name: "timestamp".into(), + field_type: "string".into(), + required: true, + description: None, + default: None, + validation: Some(ValidationRules { datetime: Some(true), ..Default::default() }), + type_definition: None, + }], + }; + + let valid_resource = BTreeMap::from([("timestamp".into(), json!("2024-01-01T00:00:00Z"))]); + assert!(template.validate(&valid_resource).unwrap()); + + let invalid_resource = BTreeMap::from([ + ("timestamp".into(), json!("2024-01-01")), // Not RFC3339 + ]); + assert!(template.validate(&invalid_resource).is_err()); + } +} diff --git a/crates/learner/tests/lib.rs b/crates/learner/tests/lib.rs index 0a88dd5..a8450fd 100644 --- a/crates/learner/tests/lib.rs +++ b/crates/learner/tests/lib.rs @@ -9,7 +9,7 @@ use learner::{ llm::{LlamaRequest, Model}, pdf::PDFContentBuilder, prelude::*, - resource::{Author, Paper}, + resource::{paper::Paper, shared::Author}, Config, Learner, }; use tempfile::{tempdir, TempDir}; diff --git a/crates/learner/tests/llm/mod.rs b/crates/learner/tests/llm/mod.rs index aa4658b..0a91429 100644 --- a/crates/learner/tests/llm/mod.rs +++ b/crates/learner/tests/llm/mod.rs @@ -1,34 +1,34 @@ -use learner::database::Add; +// use learner::database::Add; -use super::*; +// use super::*; -#[ignore = "Can't run this in general -- relies on local LLM endpoint."] -#[tokio::test] -#[traced_test] -async fn test_download_then_send_pdf() -> Result<(), Box> { - // Download a PDF - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = 
learner.retrievers.get_paper("https://eprint.iacr.org/2016/260").await?; - // let paper = Paper::new("https://eprint.iacr.org/2016/260").await.unwrap(); +// #[ignore = "Can't run this in general -- relies on local LLM endpoint."] +// #[tokio::test] +// #[traced_test] +// async fn test_download_then_send_pdf() -> Result<(), Box> { +// // Download a PDF +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let paper = learner.retrievers.get_paper("https://eprint.iacr.org/2016/260").await?; +// // let paper = Paper::new("https://eprint.iacr.org/2016/260").await.unwrap(); - // paper.download_pdf(dir.path()).await.unwrap(); - Add::complete(&paper).execute(&mut learner.database).await?; +// // paper.download_pdf(dir.path()).await.unwrap(); +// Add::complete(&paper).execute(&mut learner.database).await?; - // Get the content of the PDF +// // Get the content of the PDF - let path = learner.database.get_storage_path().await?.join(paper.filename()); - let pdf_content = PDFContentBuilder::new().path(path).analyze()?; +// let path = learner.database.get_storage_path().await?.join(paper.filename()); +// let pdf_content = PDFContentBuilder::new().path(path).analyze()?; - let mut message = - "Please act like a researcher and digest this text from a PDF for me and give me an \ - excellent summary. The summary can be long and descriptive. \n" - .to_owned(); +// let mut message = +// "Please act like a researcher and digest this text from a PDF for me and give me an \ +// excellent summary. The summary can be long and descriptive. 
\n" +// .to_owned(); - message.push_str(&serde_json::to_string(&pdf_content.metadata).unwrap()); - message.push_str(&serde_json::to_string(&pdf_content.pages[0..5]).unwrap()); +// message.push_str(&serde_json::to_string(&pdf_content.metadata).unwrap()); +// message.push_str(&serde_json::to_string(&pdf_content.pages[0..5]).unwrap()); - let response = - LlamaRequest::new().with_model(Model::Llama3p2c3b).with_message(&message).send().await?; - dbg!(response.message); - Ok(()) -} +// let response = +// LlamaRequest::new().with_model(Model::Llama3p2c3b).with_message(&message).send().await?; +// dbg!(response.message); +// Ok(()) +// } diff --git a/crates/learner/tests/workflows/database_operations/add.rs b/crates/learner/tests/workflows/database_operations/add.rs index 5fb307d..5ad6439 100644 --- a/crates/learner/tests/workflows/database_operations/add.rs +++ b/crates/learner/tests/workflows/database_operations/add.rs @@ -1,216 +1,216 @@ -use super::*; - -/// Basic paper addition tests -mod basic_operations { - - use super::*; - - #[traced_test] - #[tokio::test] - async fn test_add_paper() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = create_test_paper(); - - let papers = Add::paper(&paper).execute(&mut learner.database).await?; - assert_eq!(papers.len(), 1); - assert_eq!(papers[0].title, paper.title); - - // Verify paper exists in database - let stored = Query::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - assert_eq!(stored.len(), 1); - assert_eq!(stored[0].title, paper.title); - - Ok(()) - } - - #[traced_test] - #[tokio::test] - async fn test_add_paper_twice() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = create_test_paper(); - - Add::paper(&paper).execute(&mut learner.database).await?; - let err = Add::paper(&paper).execute(&mut learner.database).await.unwrap_err(); - - 
assert!(matches!(err, LearnerError::DatabaseDuplicatePaper(_))); - - // Verify only one copy exists - let stored = Query::list_all().execute(&mut learner.database).await?; - assert_eq!(stored.len(), 1); - - Ok(()) - } - - #[traced_test] - #[tokio::test] - async fn test_add_paper_with_authors() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let mut paper = create_test_paper(); - paper.authors = vec![ - Author { - name: "Test Author 1".into(), - affiliation: Some("University 1".into()), - email: Some("email1@test.com".into()), - }, - Author { name: "Test Author 2".into(), affiliation: None, email: None }, - ]; - - Add::paper(&paper).execute(&mut learner.database).await?; - - // Verify authors were stored - let stored = Query::by_author("Test Author 1").execute(&mut learner.database).await?; - assert_eq!(stored.len(), 1); - assert_eq!(stored[0].authors.len(), 2); - assert_eq!(stored[0].authors[0].affiliation, Some("University 1".into())); - assert_eq!(stored[0].authors[1].name, "Test Author 2"); - - Ok(()) - } -} - -/// Tests for paper addition with documents -mod document_operations { - - use super::*; - - #[traced_test] - #[tokio::test] - async fn test_add_complete_paper() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; - - let papers = Add::complete(&paper).execute(&mut learner.database).await?; - assert_eq!(papers.len(), 1); - - // Verify both paper and document were added - let stored = Query::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - assert_eq!(stored.len(), 1); - - // Verify PDF exists in storage location - let storage_path = learner.database.get_storage_path().await?; - let pdf_path = storage_path.join(paper.filename()); - assert!(pdf_path.exists(), "PDF file should exist at {:?}", pdf_path); - - Ok(()) - } - - 
#[traced_test] - #[tokio::test] - async fn test_add_paper_then_document() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; - - // First add paper only - Add::paper(&paper).execute(&mut learner.database).await?; - - // Then add with document - let papers = Add::complete(&paper).execute(&mut learner.database).await?; - assert_eq!(papers.len(), 1); - - // Verify PDF exists - let storage_path = learner.database.get_storage_path().await?; - let pdf_path = storage_path.join(paper.filename()); - assert!(pdf_path.exists()); - - assert!(logs_contain( - "WARN test_add_paper_then_document: learner::database::instruction::add: Tried to add \ - complete paper when paper existed in database already, attempting to add only the document!" - )); - Ok(()) - } - - #[traced_test] - #[tokio::test] - async fn test_chain_document_addition() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; - - let papers = Add::paper(&paper).with_document().execute(&mut learner.database).await?; - assert_eq!(papers.len(), 1); - - // Verify PDF exists - let storage_path = learner.database.get_storage_path().await?; - let pdf_path = storage_path.join(paper.filename()); - assert!(pdf_path.exists()); - - Ok(()) - } - - #[traced_test] - #[tokio::test] - async fn test_add_documents_by_query() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - // Add multiple papers without documents - let paper1 = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; - let paper2 = learner.retrievers.get_paper("https://eprint.iacr.org/2016/260").await?; - Add::paper(&paper1).execute(&mut learner.database).await?; - Add::paper(&paper2).execute(&mut learner.database).await?; - - // Add 
documents for all papers - let papers = Add::documents(Query::list_all()).execute(&mut learner.database).await?; - assert_eq!(papers.len(), 2); - - // Verify PDFs exist - let storage_path = learner.database.get_storage_path().await?; - for paper in papers { - let pdf_path = storage_path.join(paper.filename()); - assert!(pdf_path.exists(), "PDF should exist for {}", paper.source_identifier); - } - - Ok(()) - } -} - -/// Edge case tests -mod edge_cases { - use super::*; - - #[traced_test] - #[tokio::test] - async fn test_add_paper_with_special_characters() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let mut paper = create_test_paper(); - paper.title = "Test & Paper: A Study!".into(); - paper.abstract_text = "Abstract with & and other symbols: @#$%".into(); - - let papers = Add::paper(&paper).execute(&mut learner.database).await?; - assert_eq!(papers.len(), 1); - assert_eq!(papers[0].title, paper.title); - - Ok(()) - } - - #[traced_test] - #[tokio::test] - async fn test_add_empty_author_list() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let mut paper = create_test_paper(); - paper.authors.clear(); - - let papers = Add::paper(&paper).execute(&mut learner.database).await?; - assert_eq!(papers.len(), 1); - assert!(papers[0].authors.is_empty()); - - Ok(()) - } - - #[traced_test] - #[tokio::test] - async fn test_add_paper_with_optional_fields() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let mut paper = create_test_paper(); - paper.doi = Some("10.1234/test".into()); - paper.pdf_url = Some("https://example.com/paper.pdf".into()); - - let papers = Add::paper(&paper).execute(&mut learner.database).await?; - assert_eq!(papers[0].doi, Some("10.1234/test".into())); - assert_eq!(papers[0].pdf_url, Some("https://example.com/paper.pdf".into())); - - Ok(()) - } -} +// use super::*; + +// /// Basic paper 
addition tests +// mod basic_operations { + +// use super::*; + +// #[traced_test] +// #[tokio::test] +// async fn test_add_paper() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let paper = create_test_paper(); + +// let papers = Add::paper(&paper).execute(&mut learner.database).await?; +// assert_eq!(papers.len(), 1); +// assert_eq!(papers[0].title, paper.title); + +// // Verify paper exists in database +// let stored = Query::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// assert_eq!(stored.len(), 1); +// assert_eq!(stored[0].title, paper.title); + +// Ok(()) +// } + +// #[traced_test] +// #[tokio::test] +// async fn test_add_paper_twice() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let paper = create_test_paper(); + +// Add::paper(&paper).execute(&mut learner.database).await?; +// let err = Add::paper(&paper).execute(&mut learner.database).await.unwrap_err(); + +// assert!(matches!(err, LearnerError::DatabaseDuplicatePaper(_))); + +// // Verify only one copy exists +// let stored = Query::list_all().execute(&mut learner.database).await?; +// assert_eq!(stored.len(), 1); + +// Ok(()) +// } + +// #[traced_test] +// #[tokio::test] +// async fn test_add_paper_with_authors() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let mut paper = create_test_paper(); +// paper.authors = vec![ +// Author { +// name: "Test Author 1".into(), +// affiliation: Some("University 1".into()), +// email: Some("email1@test.com".into()), +// }, +// Author { name: "Test Author 2".into(), affiliation: None, email: None }, +// ]; + +// Add::paper(&paper).execute(&mut learner.database).await?; + +// // Verify authors were stored +// let stored = Query::by_author("Test Author 1").execute(&mut learner.database).await?; +// assert_eq!(stored.len(), 1); +// 
assert_eq!(stored[0].authors.len(), 2); +// assert_eq!(stored[0].authors[0].affiliation, Some("University 1".into())); +// assert_eq!(stored[0].authors[1].name, "Test Author 2"); + +// Ok(()) +// } +// } + +// /// Tests for paper addition with documents +// mod document_operations { + +// use super::*; + +// #[traced_test] +// #[tokio::test] +// async fn test_add_complete_paper() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; + +// let papers = Add::complete(&paper).execute(&mut learner.database).await?; +// assert_eq!(papers.len(), 1); + +// // Verify both paper and document were added +// let stored = Query::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// assert_eq!(stored.len(), 1); + +// // Verify PDF exists in storage location +// let storage_path = learner.database.get_storage_path().await?; +// let pdf_path = storage_path.join(paper.filename()); +// assert!(pdf_path.exists(), "PDF file should exist at {:?}", pdf_path); + +// Ok(()) +// } + +// #[traced_test] +// #[tokio::test] +// async fn test_add_paper_then_document() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; + +// // First add paper only +// Add::paper(&paper).execute(&mut learner.database).await?; + +// // Then add with document +// let papers = Add::complete(&paper).execute(&mut learner.database).await?; +// assert_eq!(papers.len(), 1); + +// // Verify PDF exists +// let storage_path = learner.database.get_storage_path().await?; +// let pdf_path = storage_path.join(paper.filename()); +// assert!(pdf_path.exists()); + +// assert!(logs_contain( +// "WARN test_add_paper_then_document: learner::database::instruction::add: Tried to add \ +// complete paper when 
paper existed in database already, attempting to add only the +// document!" )); +// Ok(()) +// } + +// #[traced_test] +// #[tokio::test] +// async fn test_chain_document_addition() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; + +// let papers = Add::paper(&paper).with_document().execute(&mut learner.database).await?; +// assert_eq!(papers.len(), 1); + +// // Verify PDF exists +// let storage_path = learner.database.get_storage_path().await?; +// let pdf_path = storage_path.join(paper.filename()); +// assert!(pdf_path.exists()); + +// Ok(()) +// } + +// #[traced_test] +// #[tokio::test] +// async fn test_add_documents_by_query() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// // Add multiple papers without documents +// let paper1 = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; +// let paper2 = learner.retrievers.get_paper("https://eprint.iacr.org/2016/260").await?; +// Add::paper(&paper1).execute(&mut learner.database).await?; +// Add::paper(&paper2).execute(&mut learner.database).await?; + +// // Add documents for all papers +// let papers = Add::documents(Query::list_all()).execute(&mut learner.database).await?; +// assert_eq!(papers.len(), 2); + +// // Verify PDFs exist +// let storage_path = learner.database.get_storage_path().await?; +// for paper in papers { +// let pdf_path = storage_path.join(paper.filename()); +// assert!(pdf_path.exists(), "PDF should exist for {}", paper.source_identifier); +// } + +// Ok(()) +// } +// } + +// /// Edge case tests +// mod edge_cases { +// use super::*; + +// #[traced_test] +// #[tokio::test] +// async fn test_add_paper_with_special_characters() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let mut paper = create_test_paper(); +// 
paper.title = "Test & Paper: A Study!".into(); +// paper.abstract_text = "Abstract with & and other symbols: @#$%".into(); + +// let papers = Add::paper(&paper).execute(&mut learner.database).await?; +// assert_eq!(papers.len(), 1); +// assert_eq!(papers[0].title, paper.title); + +// Ok(()) +// } + +// #[traced_test] +// #[tokio::test] +// async fn test_add_empty_author_list() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let mut paper = create_test_paper(); +// paper.authors.clear(); + +// let papers = Add::paper(&paper).execute(&mut learner.database).await?; +// assert_eq!(papers.len(), 1); +// assert!(papers[0].authors.is_empty()); + +// Ok(()) +// } + +// #[traced_test] +// #[tokio::test] +// async fn test_add_paper_with_optional_fields() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// let mut paper = create_test_paper(); +// paper.doi = Some("10.1234/test".into()); +// paper.pdf_url = Some("https://example.com/paper.pdf".into()); + +// let papers = Add::paper(&paper).execute(&mut learner.database).await?; +// assert_eq!(papers[0].doi, Some("10.1234/test".into())); +// assert_eq!(papers[0].pdf_url, Some("https://example.com/paper.pdf".into())); + +// Ok(()) +// } +// } diff --git a/crates/learner/tests/workflows/database_operations/remove.rs b/crates/learner/tests/workflows/database_operations/remove.rs index 1762ec2..65e671a 100644 --- a/crates/learner/tests/workflows/database_operations/remove.rs +++ b/crates/learner/tests/workflows/database_operations/remove.rs @@ -1,435 +1,435 @@ -use super::*; +// use super::*; -/// Basic removal functionality tests -mod basic_operations { +// /// Basic removal functionality tests +// mod basic_operations { - use super::*; +// use super::*; - #[tokio::test] - #[traced_test] - async fn test_remove_existing_paper() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = 
create_test_learner().await; +// #[tokio::test] +// #[traced_test] +// async fn test_remove_existing_paper() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = create_test_paper(); - Add::paper(&paper).execute(&mut learner.database).await?; +// let paper = create_test_paper(); +// Add::paper(&paper).execute(&mut learner.database).await?; - let removed_papers = Remove::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; +// let removed_papers = Remove::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; - assert_eq!(removed_papers.len(), 1); - assert_eq!(removed_papers[0].title, paper.title); - assert_eq!(removed_papers[0].authors.len(), paper.authors.len()); +// assert_eq!(removed_papers.len(), 1); +// assert_eq!(removed_papers[0].title, paper.title); +// assert_eq!(removed_papers[0].authors.len(), paper.authors.len()); - let results = Query::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - assert_eq!(results.len(), 0); +// let results = Query::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// assert_eq!(results.len(), 0); - Ok(()) - } +// Ok(()) +// } - #[tokio::test] - #[traced_test] - async fn test_remove_nonexistent_paper() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - let removed = Remove::by_source("arxiv", "nonexistent").execute(&mut learner.database).await?; - assert!(removed.is_empty()); +// #[tokio::test] +// #[traced_test] +// async fn test_remove_nonexistent_paper() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// let removed = Remove::by_source("arxiv", "nonexistent").execute(&mut +// learner.database).await?; assert!(removed.is_empty()); - Ok(()) - } +// Ok(()) +// } - 
#[tokio::test] - #[traced_test] - async fn test_remove_cascades_to_authors() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// #[tokio::test] +// #[traced_test] +// async fn test_remove_cascades_to_authors() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = create_test_paper(); - Add::paper(&paper).execute(&mut learner.database).await?; +// let paper = create_test_paper(); +// Add::paper(&paper).execute(&mut learner.database).await?; - Remove::from_query(Query::text("test")).execute(&mut learner.database).await?; - let authors = Query::by_author("").execute(&mut learner.database).await?; - - assert_eq!(authors.len(), 0); - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn test_remove_complete_paper() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; - // Add paper with document - Add::complete(&paper).execute(&mut learner.database).await?; - - // Remove it - Remove::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; +// Remove::from_query(Query::text("test")).execute(&mut learner.database).await?; +// let authors = Query::by_author("").execute(&mut learner.database).await?; + +// assert_eq!(authors.len(), 0); +// Ok(()) +// } + +// #[tokio::test] +// #[traced_test] +// async fn test_remove_complete_paper() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; +// // Add paper with document +// Add::complete(&paper).execute(&mut learner.database).await?; + +// // Remove it +// Remove::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; - // Verify paper is gone - let 
results = Query::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - assert_eq!(results.len(), 0); - - Ok(()) - } -} +// // Verify paper is gone +// let results = Query::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// assert_eq!(results.len(), 0); + +// Ok(()) +// } +// } -/// Dry run functionality tests -mod dry_run { - use super::*; +// /// Dry run functionality tests +// mod dry_run { +// use super::*; - #[tokio::test] - #[traced_test] - async fn test_dry_run_basic() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// #[tokio::test] +// #[traced_test] +// async fn test_dry_run_basic() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = create_test_paper(); - Add::paper(&paper).execute(&mut learner.database).await?; +// let paper = create_test_paper(); +// Add::paper(&paper).execute(&mut learner.database).await?; - let would_remove = Remove::by_source(&paper.source, &paper.source_identifier) - .dry_run() - .execute(&mut learner.database) - .await?; +// let would_remove = Remove::by_source(&paper.source, &paper.source_identifier) +// .dry_run() +// .execute(&mut learner.database) +// .await?; - assert_eq!(would_remove.len(), 1); - assert_eq!(would_remove[0].title, paper.title); +// assert_eq!(would_remove.len(), 1); +// assert_eq!(would_remove[0].title, paper.title); - let results = Query::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - assert_eq!(results.len(), 1); - - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn test_dry_run_returns_complete_paper() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - let paper = create_test_paper(); - Add::paper(&paper).execute(&mut learner.database).await?; - - let would_remove = - 
Remove::from_query(Query::text("test")).dry_run().execute(&mut learner.database).await?; - - assert_eq!(would_remove.len(), 1); - let removed = &would_remove[0]; - - // Verify all fields - assert_eq!(removed.title, paper.title); - assert_eq!(removed.abstract_text, paper.abstract_text); - assert_eq!(removed.publication_date, paper.publication_date); - assert_eq!(removed.source, paper.source); - assert_eq!(removed.source_identifier, paper.source_identifier); - assert_eq!(removed.pdf_url, paper.pdf_url); - assert_eq!(removed.doi, paper.doi); - assert_eq!(removed.authors.len(), paper.authors.len()); - - for (removed_author, original_author) in removed.authors.iter().zip(paper.authors.iter()) { - assert_eq!(removed_author.name, original_author.name); - assert_eq!(removed_author.affiliation, original_author.affiliation); - assert_eq!(removed_author.email, original_author.email); - } - - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn test_dry_run_with_complete_paper() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; - Add::complete(&paper).execute(&mut learner.database).await?; - - let would_remove = Remove::by_source(&paper.source, &paper.source_identifier) - .dry_run() - .execute(&mut learner.database) - .await?; - - // Verify paper would be removed - assert_eq!(would_remove.len(), 1); - - // But verify it's still in the database - let results = Query::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - assert_eq!(results.len(), 1); - - Ok(()) - } -} - -/// Query-based removal tests -mod query_based_removal { - - use super::*; - - #[tokio::test] - #[traced_test] - async fn test_remove_by_text_search() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - Add::paper(&create_test_paper()).execute(&mut learner.database).await?; 
- Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; - - let removed = Remove::from_query(Query::text("two")).execute(&mut learner.database).await?; - assert_eq!(removed.len(), 1); - assert_eq!(removed[0].title, "Test Paper: Two"); - - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn test_remove_by_author() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - Add::paper(&create_test_paper()).execute(&mut learner.database).await?; - Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; - - let removed = - Remove::from_query(Query::by_author("John Doe")).execute(&mut learner.database).await?; - assert_eq!(removed.len(), 1); - assert_eq!(removed[0].authors[0].name, "John Doe"); - - // Verify only the matching paper was removed - let remaining = Query::list_all().execute(&mut learner.database).await?; - assert_eq!(remaining.len(), 1); - - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn test_remove_with_ordering() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - Add::paper(&create_test_paper()).execute(&mut learner.database).await?; - Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; - - let removed = - Remove::from_query(Query::text("test").order_by(OrderField::PublicationDate).descending()) - .execute(&mut learner.database) - .await?; - - assert_eq!(removed.len(), 2); - assert_eq!(removed[0].title, "Test Paper: Two"); // More recent - assert_eq!(removed[1].title, "Test Paper"); - - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn test_remove_by_date_range() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - Add::paper(&create_test_paper()).execute(&mut learner.database).await?; - Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; - - let cutoff_date = 
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); - let removed = - Remove::from_query(Query::before_date(cutoff_date).order_by(OrderField::PublicationDate)) - .execute(&mut learner.database) - .await?; +// let results = Query::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// assert_eq!(results.len(), 1); + +// Ok(()) +// } + +// #[tokio::test] +// #[traced_test] +// async fn test_dry_run_returns_complete_paper() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// let paper = create_test_paper(); +// Add::paper(&paper).execute(&mut learner.database).await?; + +// let would_remove = +// Remove::from_query(Query::text("test")).dry_run().execute(&mut learner.database).await?; + +// assert_eq!(would_remove.len(), 1); +// let removed = &would_remove[0]; + +// // Verify all fields +// assert_eq!(removed.title, paper.title); +// assert_eq!(removed.abstract_text, paper.abstract_text); +// assert_eq!(removed.publication_date, paper.publication_date); +// assert_eq!(removed.source, paper.source); +// assert_eq!(removed.source_identifier, paper.source_identifier); +// assert_eq!(removed.pdf_url, paper.pdf_url); +// assert_eq!(removed.doi, paper.doi); +// assert_eq!(removed.authors.len(), paper.authors.len()); + +// for (removed_author, original_author) in removed.authors.iter().zip(paper.authors.iter()) { +// assert_eq!(removed_author.name, original_author.name); +// assert_eq!(removed_author.affiliation, original_author.affiliation); +// assert_eq!(removed_author.email, original_author.email); +// } + +// Ok(()) +// } + +// #[tokio::test] +// #[traced_test] +// async fn test_dry_run_with_complete_paper() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; +// Add::complete(&paper).execute(&mut learner.database).await?; + 
+// let would_remove = Remove::by_source(&paper.source, &paper.source_identifier) +// .dry_run() +// .execute(&mut learner.database) +// .await?; + +// // Verify paper would be removed +// assert_eq!(would_remove.len(), 1); + +// // But verify it's still in the database +// let results = Query::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// assert_eq!(results.len(), 1); + +// Ok(()) +// } +// } + +// /// Query-based removal tests +// mod query_based_removal { + +// use super::*; + +// #[tokio::test] +// #[traced_test] +// async fn test_remove_by_text_search() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// Add::paper(&create_test_paper()).execute(&mut learner.database).await?; +// Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; + +// let removed = Remove::from_query(Query::text("two")).execute(&mut learner.database).await?; +// assert_eq!(removed.len(), 1); +// assert_eq!(removed[0].title, "Test Paper: Two"); + +// Ok(()) +// } + +// #[tokio::test] +// #[traced_test] +// async fn test_remove_by_author() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// Add::paper(&create_test_paper()).execute(&mut learner.database).await?; +// Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; + +// let removed = +// Remove::from_query(Query::by_author("John Doe")).execute(&mut learner.database).await?; +// assert_eq!(removed.len(), 1); +// assert_eq!(removed[0].authors[0].name, "John Doe"); + +// // Verify only the matching paper was removed +// let remaining = Query::list_all().execute(&mut learner.database).await?; +// assert_eq!(remaining.len(), 1); + +// Ok(()) +// } + +// #[tokio::test] +// #[traced_test] +// async fn test_remove_with_ordering() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = 
create_test_learner().await; + +// Add::paper(&create_test_paper()).execute(&mut learner.database).await?; +// Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; + +// let removed = +// Remove::from_query(Query::text("test").order_by(OrderField::PublicationDate).descending()) +// .execute(&mut learner.database) +// .await?; + +// assert_eq!(removed.len(), 2); +// assert_eq!(removed[0].title, "Test Paper: Two"); // More recent +// assert_eq!(removed[1].title, "Test Paper"); + +// Ok(()) +// } + +// #[tokio::test] +// #[traced_test] +// async fn test_remove_by_date_range() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// Add::paper(&create_test_paper()).execute(&mut learner.database).await?; +// Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; + +// let cutoff_date = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(); +// let removed = +// Remove::from_query(Query::before_date(cutoff_date).order_by(OrderField::PublicationDate)) +// .execute(&mut learner.database) +// .await?; - assert_eq!(removed.len(), 1); - assert_eq!(removed[0].title, "Test Paper"); +// assert_eq!(removed.len(), 1); +// assert_eq!(removed[0].title, "Test Paper"); - Ok(()) - } +// Ok(()) +// } - #[tokio::test] - #[traced_test] - async fn test_remove_multiple_papers_by_source() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// #[tokio::test] +// #[traced_test] +// async fn test_remove_multiple_papers_by_source() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - // Add multiple papers from same source - let paper1 = create_test_paper(); - let paper2 = create_second_test_paper(); - Add::paper(&paper1).execute(&mut learner.database).await?; - Add::paper(&paper2).execute(&mut learner.database).await?; +// // Add multiple papers from same source +// let paper1 = 
create_test_paper(); +// let paper2 = create_second_test_paper(); +// Add::paper(&paper1).execute(&mut learner.database).await?; +// Add::paper(&paper2).execute(&mut learner.database).await?; - // Use a text search that will match all papers from this source - // alternatively we could use Query::list_all() with a source filter - let removed = Remove::from_query(Query::text("test")).execute(&mut learner.database).await?; - assert_eq!(removed.len(), 2); - assert!(removed.iter().all(|p| p.source == "arxiv")); +// // Use a text search that will match all papers from this source +// // alternatively we could use Query::list_all() with a source filter +// let removed = Remove::from_query(Query::text("test")).execute(&mut learner.database).await?; +// assert_eq!(removed.len(), 2); +// assert!(removed.iter().all(|p| p.source == "arxiv")); - // Verify all papers are gone - let remaining = Query::text("test").execute(&mut learner.database).await?; - assert!(remaining.is_empty()); +// // Verify all papers are gone +// let remaining = Query::text("test").execute(&mut learner.database).await?; +// assert!(remaining.is_empty()); - Ok(()) - } +// Ok(()) +// } - // Alternative version using list_all - #[tokio::test] - #[traced_test] - async fn test_remove_multiple_papers_from_source() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// // Alternative version using list_all +// #[tokio::test] +// #[traced_test] +// async fn test_remove_multiple_papers_from_source() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - // Add papers from different sources - let paper1 = create_test_paper(); - let paper2 = create_second_test_paper(); - Add::paper(&paper1).execute(&mut learner.database).await?; - Add::paper(&paper2).execute(&mut learner.database).await?; - - // Verify we have papers from our source before removal - let initial = Query::list_all().execute(&mut 
learner.database).await?; - assert!(initial.iter().any(|p| p.source == "arxiv")); - - // Remove all papers using list_all and checking source - let removed = Remove::from_query(Query::list_all().order_by(OrderField::Source)) - .execute(&mut learner.database) - .await?; +// // Add papers from different sources +// let paper1 = create_test_paper(); +// let paper2 = create_second_test_paper(); +// Add::paper(&paper1).execute(&mut learner.database).await?; +// Add::paper(&paper2).execute(&mut learner.database).await?; + +// // Verify we have papers from our source before removal +// let initial = Query::list_all().execute(&mut learner.database).await?; +// assert!(initial.iter().any(|p| p.source == "arxiv")); + +// // Remove all papers using list_all and checking source +// let removed = Remove::from_query(Query::list_all().order_by(OrderField::Source)) +// .execute(&mut learner.database) +// .await?; - // Count papers from our source - let arxiv_count = removed.iter().filter(|p| p.source == "arxiv").count(); - assert_eq!(arxiv_count, 2); - - // Verify no papers remain from that source - let remaining = Query::list_all().execute(&mut learner.database).await?; - assert!(!remaining.iter().any(|p| p.source == "arxiv")); - - Ok(()) - } -} - -/// Recovery and data integrity tests -mod recovery { - use super::*; - - #[tokio::test] - #[traced_test] - async fn test_remove_papers_can_be_readded() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - let paper = create_test_paper(); - Add::paper(&paper).execute(&mut learner.database).await?; - - let removed_papers = - Remove::from_query(Query::text("test")).execute(&mut learner.database).await?; - assert_eq!(removed_papers.len(), 1); - - Add::paper(&removed_papers[0]).execute(&mut learner.database).await?; - - let results = Query::text("test").execute(&mut learner.database).await?; - assert_eq!(results.len(), 1); - - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn 
test_bulk_remove_and_readd() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - Add::paper(&create_test_paper()).execute(&mut learner.database).await?; - Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; - - let removed = Remove::from_query(Query::text("test")).execute(&mut learner.database).await?; - assert_eq!(removed.len(), 2); - - for paper in &removed { - Add::paper(paper).execute(&mut learner.database).await?; - } - - let results = Query::text("test").execute(&mut learner.database).await?; - assert_eq!(results.len(), 2); - - // Verify order is preserved - assert_eq!(results[0].title, removed[0].title); - assert_eq!(results[1].title, removed[1].title); - - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn test_readd_with_different_completion() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - - // Add paper without document - let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; - Add::paper(&paper).execute(&mut learner.database).await?; - - // Remove it - let removed = Remove::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - - // Readd with document - Add::complete(&removed[0]).execute(&mut learner.database).await?; - - // Verify paper exists with updated data - let results = Query::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - assert_eq!(results.len(), 1); +// // Count papers from our source +// let arxiv_count = removed.iter().filter(|p| p.source == "arxiv").count(); +// assert_eq!(arxiv_count, 2); + +// // Verify no papers remain from that source +// let remaining = Query::list_all().execute(&mut learner.database).await?; +// assert!(!remaining.iter().any(|p| p.source == "arxiv")); + +// Ok(()) +// } +// } + +// /// Recovery and data integrity tests +// mod recovery { +// use super::*; 
+ +// #[tokio::test] +// #[traced_test] +// async fn test_remove_papers_can_be_readded() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// let paper = create_test_paper(); +// Add::paper(&paper).execute(&mut learner.database).await?; + +// let removed_papers = +// Remove::from_query(Query::text("test")).execute(&mut learner.database).await?; +// assert_eq!(removed_papers.len(), 1); + +// Add::paper(&removed_papers[0]).execute(&mut learner.database).await?; + +// let results = Query::text("test").execute(&mut learner.database).await?; +// assert_eq!(results.len(), 1); + +// Ok(()) +// } + +// #[tokio::test] +// #[traced_test] +// async fn test_bulk_remove_and_readd() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// Add::paper(&create_test_paper()).execute(&mut learner.database).await?; +// Add::paper(&create_second_test_paper()).execute(&mut learner.database).await?; + +// let removed = Remove::from_query(Query::text("test")).execute(&mut learner.database).await?; +// assert_eq!(removed.len(), 2); + +// for paper in &removed { +// Add::paper(paper).execute(&mut learner.database).await?; +// } + +// let results = Query::text("test").execute(&mut learner.database).await?; +// assert_eq!(results.len(), 2); + +// // Verify order is preserved +// assert_eq!(results[0].title, removed[0].title); +// assert_eq!(results[1].title, removed[1].title); + +// Ok(()) +// } + +// #[tokio::test] +// #[traced_test] +// async fn test_readd_with_different_completion() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; + +// // Add paper without document +// let paper = learner.retrievers.get_paper("https://arxiv.org/abs/2301.07041").await?; +// Add::paper(&paper).execute(&mut learner.database).await?; + +// // Remove it +// let removed = Remove::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut 
learner.database) +// .await?; + +// // Readd with document +// Add::complete(&removed[0]).execute(&mut learner.database).await?; + +// // Verify paper exists with updated data +// let results = Query::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// assert_eq!(results.len(), 1); - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn test_remove_and_readd_preserves_metadata() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// Ok(()) +// } + +// #[tokio::test] +// #[traced_test] +// async fn test_remove_and_readd_preserves_metadata() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let mut paper = create_test_paper(); - // Add some optional fields - paper.doi = Some("10.1234/test".to_string()); - paper.pdf_url = Some("https://example.com/test.pdf".to_string()); +// let mut paper = create_test_paper(); +// // Add some optional fields +// paper.doi = Some("10.1234/test".to_string()); +// paper.pdf_url = Some("https://example.com/test.pdf".to_string()); - Add::paper(&paper).execute(&mut learner.database).await?; +// Add::paper(&paper).execute(&mut learner.database).await?; - let removed = Remove::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - Add::paper(&removed[0]).execute(&mut learner.database).await?; +// let removed = Remove::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// Add::paper(&removed[0]).execute(&mut learner.database).await?; - let results = Query::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - assert_eq!(results[0].doi, paper.doi); - assert_eq!(results[0].pdf_url, paper.pdf_url); +// let results = Query::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// assert_eq!(results[0].doi, 
paper.doi); +// assert_eq!(results[0].pdf_url, paper.pdf_url); - Ok(()) - } +// Ok(()) +// } - #[tokio::test] - #[traced_test] - async fn test_remove_readd_with_updated_data() -> TestResult<()> { - let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; +// #[tokio::test] +// #[traced_test] +// async fn test_remove_readd_with_updated_data() -> TestResult<()> { +// let (mut learner, _cfg_dir, _db_dir, _strg_dir) = create_test_learner().await; - let paper = create_test_paper(); - Add::paper(&paper).execute(&mut learner.database).await?; +// let paper = create_test_paper(); +// Add::paper(&paper).execute(&mut learner.database).await?; - let mut removed = Remove::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; +// let mut removed = Remove::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; - // Modify the removed paper - let mut updated_paper = removed.remove(0); - updated_paper.abstract_text = "Updated abstract".to_string(); - updated_paper.doi = Some("10.1234/new".to_string()); +// // Modify the removed paper +// let mut updated_paper = removed.remove(0); +// updated_paper.abstract_text = "Updated abstract".to_string(); +// updated_paper.doi = Some("10.1234/new".to_string()); - // Readd with changes - Add::paper(&updated_paper).execute(&mut learner.database).await?; +// // Readd with changes +// Add::paper(&updated_paper).execute(&mut learner.database).await?; - let results = Query::by_source(&paper.source, &paper.source_identifier) - .execute(&mut learner.database) - .await?; - assert_eq!(results[0].abstract_text, "Updated abstract"); - assert_eq!(results[0].doi, Some("10.1234/new".to_string())); - - Ok(()) - } -} +// let results = Query::by_source(&paper.source, &paper.source_identifier) +// .execute(&mut learner.database) +// .await?; +// assert_eq!(results[0].abstract_text, "Updated abstract"); +// assert_eq!(results[0].doi, 
Some("10.1234/new".to_string())); + +// Ok(()) +// } +// } diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index eb588ba..39186b3 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -1,9 +1,6 @@ use std::fs; -use learner::{ - configuration::{Config, ConfigurationManager}, - resource::ResourceTemplate, -}; +use learner::{configuration::ConfigurationManager, template::Template}; use super::*; @@ -19,14 +16,14 @@ async fn test_arxiv_retriever_integration() -> TestResult<()> { // file", // ); let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); - let retriever: Config = manager.load_config("config_new/arxiv.toml")?; + let retriever: Retriever = manager.load_config("config_new/arxiv.toml")?; // let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); // let resource: ResourceTemplate = toml::from_str(&res_config_str).expect("Failed to parse // config"); // Test with a real arXiv paper - let paper = retriever.inner().retrieve_resource("2301.07041").await?; + let paper = retriever.retrieve_resource("2301.07041").await?; dbg!(&paper); // assert!(resource.validate(&paper)?); @@ -72,12 +69,12 @@ async fn test_arxiv_pdf_from_paper() -> TestResult<()> { #[tokio::test] async fn test_iacr_retriever_integration() -> TestResult<()> { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); - let retriever: Config = manager.load_config("config_new/iacr.toml")?; + let retriever: Retriever = manager.load_config("config_new/iacr.toml")?; // let resource: ResourceTemplate = toml::from_str(&res_config_str).expect("Failed to parse // config"); // // Test with a real IACR paper - let paper = retriever.inner().retrieve_resource("2016/260").await.unwrap(); + let paper = retriever.retrieve_resource("2016/260").await.unwrap(); // assert!(resource.validate(&paper)?); // TODO: validation 
already happens internally, to be fair // that validation may not be working totally right dbg!(&paper); @@ -118,10 +115,10 @@ async fn test_iacr_pdf_from_paper() -> TestResult<()> { #[traced_test] async fn test_doi_retriever_integration() -> TestResult<()> { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); - let retriever: Config = manager.load_config("config_new/doi.toml")?; + let retriever: Retriever = manager.load_config("config_new/doi.toml")?; // Test with a real DOI paper - let paper = retriever.inner().retrieve_resource("10.1145/1327452.1327492").await?; + let paper = retriever.retrieve_resource("10.1145/1327452.1327492").await?; // assert!(resource.validate(&paper)?); dbg!(&paper); // assert!(!paper.title.is_empty()); From 76e3819aa3b4cbdd0a5cc073b0b236d717aa919d Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 06:16:51 -0700 Subject: [PATCH 43/73] cleaning up more --- crates/learner/config_new/doi.toml | 6 +- crates/learner/src/lib.rs | 2 +- crates/learner/src/retriever/config.rs | 354 ++++++++++++++++--- crates/learner/src/retriever/mod.rs | 2 +- crates/learner/src/retriever/response/mod.rs | 303 ---------------- 5 files changed, 317 insertions(+), 350 deletions(-) diff --git a/crates/learner/config_new/doi.toml b/crates/learner/config_new/doi.toml index 7edc052..1e14200 100644 --- a/crates/learner/config_new/doi.toml +++ b/crates/learner/config_new/doi.toml @@ -49,11 +49,11 @@ type = "ArrayOfObjects" [resource_mappings.publication_dates] path = "message/created/date-time" -[resource_mappings.pdf_url] -path = "message/link/0/URL" - [resource_mappings.doi] path = "message/DOI" +[resource_mappings.pdf_url] +path = "message/link/0/URL" + [headers] Accept = "application/json" diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index 417b7f2..bceb800 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -139,7 +139,7 @@ //! # } //! 
``` -#![warn(missing_docs, clippy::missing_docs_in_private_items)] +// #![warn(missing_docs, clippy::missing_docs_in_private_items)] #![feature(str_from_utf16_endian)] use std::{ diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 1c53ba0..5c435b7 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,48 +1,37 @@ -// use resource::Resource; - -use record::RetrievalData; -use template::{Resource, Template}; +use serde_json::Map; +use template::{FieldDefinition, Resource, Template}; use super::*; -// TODO: fix all the stuff that had to do with `Retriever.name` - #[derive(Debug, Clone, Deserialize)] pub struct Retriever { - pub name: String, - pub description: Option, - pub resource: Template, - #[serde(skip_deserializing)] - #[serde(default)] - pub retrieval_data: RetrievalData, - - // TODO: Should own a `Record` + pub name: String, + pub description: Option, /// Base URL for API requests - pub base_url: String, + pub base_url: String, /// Regex pattern for matching and extracting paper identifiers #[serde(deserialize_with = "deserialize_regex")] - pub pattern: Regex, + pub pattern: Regex, /// Source identifier for papers from this retriever - pub source: String, + pub source: String, /// Template for constructing API endpoint URLs - pub endpoint_template: String, + pub endpoint_template: String, // TODO: This is now more like "how to get the thing to map into the resource" // #[serde(flatten)] - pub response_format: ResponseFormat, + pub response_format: ResponseFormat, /// Optional HTTP headers for API requests #[serde(default)] - pub headers: BTreeMap, - // TODO: need to have these be associated somehow, actually resource should probably be in record + pub headers: BTreeMap, + + #[serde(rename = "resource")] + pub resource_template: Template, #[serde(default)] - pub resource_mappings: BTreeMap, + pub resource_mappings: BTreeMap, + #[serde(default)] pub 
retrieval_mappings: BTreeMap, } -// impl Identifiable for Retriever { -// fn name(&self) -> String { self.name.clone() } -// } - impl Retriever { /// Extracts the canonical identifier from an input string. /// @@ -67,15 +56,12 @@ impl Retriever { .ok_or(LearnerError::InvalidIdentifier) } - // TODO: perhaps this just isn't even implemented here and is instead implemented on `Learner`. - // Could consider an `api.rs` module to extend more learner functionality there. - #[allow(missing_docs)] pub async fn retrieve_resource(&self, input: &str) -> Result { let identifier = self.extract_identifier(input)?; // Send request and get response let url = self.endpoint_template.replace("{identifier}", identifier); - // debug!("Fetching from {} via: {}", self.name, url); + debug!("Fetching from {} via: {}", self.name, url); let client = reqwest::Client::new(); let mut request = client.get(&url); @@ -88,7 +74,7 @@ impl Retriever { let response = request.send().await?; let data = response.bytes().await?; - // trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); + trace!("{} response: {}", self.name, String::from_utf8_lossy(&data)); // Process the response using configured processor let json = match &self.response_format { @@ -98,25 +84,309 @@ impl Retriever { // Process response and get resource // TODO: this should probably be a method - let mut resource = process_json_value(&json, &self.resource_mappings, &self.resource)?; + let mut resource = self.process_json_value(&json)?; // Add source metadata resource.insert("source".into(), Value::String(self.source.clone())); resource.insert("source_identifier".into(), Value::String(identifier.to_string())); // Validate full resource against config - self.resource.validate(&resource)?; + self.resource_template.validate(&resource)?; + Ok(resource) + } + + pub fn process_json_value(&self, json: &Value) -> Result { + let mut resource = Resource::new(); + + for field_def in &self.resource_template.fields { + if let 
Some(field_map) = self.resource_mappings.get(&field_def.name) { + if let Some(value) = extract_mapped_value(json, field_map, field_def)? { + resource.insert(field_def.name.clone(), value); + } else if field_def.required { + return Err(LearnerError::ApiError(format!( + "Required field '{}' not found in response", + field_def.name + ))); + } else if let Some(default) = &field_def.default { + resource.insert(field_def.name.clone(), default.clone()); + } + } + } + Ok(resource) + } +} + +/// Extract and transform a value from JSON using a field mapping +fn extract_mapped_value( + json: &Value, + field_map: &FieldMap, + field_def: &FieldDefinition, +) -> Result> { + let path_components: Vec<&str> = field_map.path.split('/').collect(); + + // Extract raw value using path + let raw_value = get_path_value(json, &path_components); + + // If no value found, return None + let Some(raw_value) = raw_value else { + return Ok(None); + }; + + // First apply any explicit transforms + let value = if let Some(transform) = &field_map.transform { + apply_transform(&raw_value, transform)? 
+ } else { + raw_value + }; + + // Then attempt type coercion based on field definition + let coerced = coerce_to_type(&value, field_def)?; + Ok(Some(coerced)) +} + +fn coerce_to_type(value: &Value, field_def: &FieldDefinition) -> Result { + match field_def.field_type.as_str() { + "array" => { + let arr = match value { + // Single value -> wrap in array + Value::String(_) | Value::Object(_) | Value::Number(_) => vec![value.clone()], + // Already an array + Value::Array(arr) => arr.clone(), + _ => return Ok(value.clone()), // Can't coerce, return as-is + }; + + // If we have inner type info, try to coerce each element + if let Some(ref type_def) = field_def.type_definition { + if let Some(ref element_def) = type_def.element_type { + let coerced: Vec = + arr.into_iter().map(|v| coerce_to_type(&v, element_def)).collect::>()?; + Ok(Value::Array(coerced)) + } else { + Ok(Value::Array(arr)) + } + } else { + Ok(Value::Array(arr)) + } + }, + "object" => { + // If we have field definitions, ensure object has required structure + if let Some(ref type_def) = field_def.type_definition { + if let Some(fields) = &type_def.fields { + let mut obj = Map::new(); + match value { + // Convert string to {name: string} if that's the structure we want + Value::String(s) if fields.len() == 1 && fields[0].name == "name" => { + obj.insert("name".to_string(), Value::String(s.clone())); + Ok(Value::Object(obj)) + }, + Value::Object(m) => { + // Copy over matching fields with coercion + for field in fields { + if let Some(v) = m.get(&field.name) { + obj.insert(field.name.clone(), coerce_to_type(v, field)?); + } + } + Ok(Value::Object(obj)) + }, + _ => Ok(value.clone()), + } + } else { + Ok(value.clone()) + } + } else { + Ok(value.clone()) + } + }, + // Add other type coercions as needed + _ => Ok(value.clone()), + } +} + +/// Get a value from JSON using a path +// Change return type to owned Value +fn get_path_value(json: &Value, path: &[&str]) -> Option { + let mut current = json.clone(); + + 
for &component in path { + match current { + Value::Object(map) => + if let Some(value) = map.get(component) { + current = value.clone(); + } else { + return None; + }, + Value::Array(arr) => { + // If component is numeric, use it as array index + if let Ok(index) = component.parse::() { + if let Some(value) = arr.get(index) { + current = value.clone(); + } else { + return None; + } + } else { + // Otherwise collect matching values from array elements + let values: Vec = arr + .iter() + .filter_map(|item| match item { + Value::Object(map) => map.get(component).cloned(), + _ => None, + }) + .collect(); + + if values.is_empty() { + return None; + } else if values.len() == 1 { + current = values[0].clone(); + } else { + return Some(Value::Array(values)); + } + } + }, + _ => return None, + } + } + + Some(current) +} + +/// Apply a transform to a JSON value +fn apply_transform(value: &Value, transform: &Transform) -> Result { + dbg!(&value); + match transform { + Transform::Replace { pattern, replacement } => { + let text = value.as_str().ok_or_else(|| { + LearnerError::ApiError("Replace transform requires string input".to_string()) + })?; + let re = + Regex::new(pattern).map_err(|e| LearnerError::ApiError(format!("Invalid regex: {}", e)))?; + Ok(Value::String(re.replace_all(text, replacement.as_str()).into_owned())) + }, + + Transform::Date { from_format, to_format } => { + let text = value.as_str().ok_or_else(|| { + LearnerError::ApiError("Date transform requires string input".to_string()) + })?; + let dt = chrono::NaiveDateTime::parse_from_str(text, from_format) + .map_err(|e| LearnerError::ApiError(format!("Invalid date: {}", e)))?; + Ok(Value::String(dt.format(to_format).to_string())) + }, + + Transform::Url { base, suffix } => { + let text = value + .as_str() + .ok_or_else(|| LearnerError::ApiError("URL transform requires string input".to_string()))?; + Ok(Value::String(format!( + "{}{}", + base.replace("{value}", text), + suffix.as_deref().unwrap_or("") + ))) + 
}, + + Transform::Compose { sources, format } => { + // Extract values from each source + let values: Vec = sources + .iter() + .filter_map(|source| match source { + Source::Path(path) => { + let components: Vec<&str> = path.split('/').collect(); + get_path_value(value, &components) + }, + Source::Literal(text) => Some(Value::String(text.clone())), + Source::KeyValue { key: _, path } => { + let components: Vec<&str> = path.split('/').collect(); + get_path_value(value, &components) + }, + }) + .collect(); + + dbg!(&values); + + // Apply the format to the collected values + match format { + ComposeFormat::Join { delimiter } => { + // Convert values to strings and join + let strings: Vec = values + .iter() + .filter_map(|v| match v { + Value::String(s) => Some(s.clone()), + Value::Array(arr) if arr.len() == 1 => arr[0].as_str().map(|s| s.to_string()), + _ => None, + }) + .collect(); + Ok(Value::String(strings.join(delimiter))) + }, + + ComposeFormat::Object => { + dbg!("inside here"); + let mut obj = Map::new(); + dbg!(&sources); + for (source, value) in sources.iter().zip(values.iter()) { + dbg!(&source); + if let Source::KeyValue { key, .. 
} = source { + dbg!(key); + obj.insert(key.clone(), value.clone()); + } + } + dbg!(&obj); + Ok(Value::Object(obj)) + }, + + ComposeFormat::ArrayOfObjects { template } => { + match value { + // Handle single string -> array of objects + Value::String(s) => { + let mut obj = Map::new(); + for (key, template_value) in template { + let value = template_value.replace("{value}", s); + obj.insert(key.clone(), Value::String(value)); + } + Ok(Value::Array(vec![Value::Object(obj)])) + }, - // todo!() + // Handle array -> array of objects + Value::Array(arr) => { + dbg!(&arr); + let objects: Vec = arr + .iter() + .filter_map(|item| { + dbg!(&item); + let mut obj = Map::new(); + for (key, template_value) in template { + let value = match item { + Value::String(s) => template_value.replace("{value}", s), + Value::Object(obj) => { + dbg!(obj); + let mut keys_and_vals = Vec::new(); + sources.iter().for_each(|source| { + if let Source::KeyValue { key, path } = source { + if let Some(val) = obj.get(path) { + keys_and_vals.push((key, val)) + } + } + }); + dbg!(&key); + keys_and_vals.into_iter().fold(template_value.clone(), |acc, (k, v)| { + let replacement = format!("{{{k}}}"); + acc.replace(&replacement, v.as_str().unwrap_or_default()) + }) + }, + _ => return None, + }; + obj.insert(key.clone(), Value::String(value)); + } + Some(Value::Object(obj)) + }) + .collect(); + Ok(Value::Array(objects)) + }, - // Ok(Record { - // resource, - // resource_config: resource_config.clone(), - // retrieval: None, - // state: ResourceState::default(), - // storage: None, - // tags: Vec::new(), - // }) + _ => Err(LearnerError::ApiError( + "ArrayOfObjects transform requires string or array input".to_string(), + )), + } + }, + } + }, } } diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index a1959c7..969506d 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -56,7 +56,7 @@ impl Retrievers { /// ``` pub fn new() -> Self { 
Self::default() } - pub async fn get_resource(&self, input: &str) -> Result { + pub async fn get_resource_file(&self, input: &str) -> Result { todo!( "Arguably, we don't even need this. We could instead just have this handled by `Learner` so \ the API is simpler" diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response/mod.rs index 212459a..2e11a24 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response/mod.rs @@ -1,6 +1,3 @@ -use serde_json::Map; -use template::{FieldDefinition, Template}; - use super::*; pub mod xml; @@ -140,303 +137,3 @@ pub enum ComposeFormat { template: BTreeMap, }, } - -/// Process a JSON value according to field mappings and resource configuration -pub fn process_json_value( - json: &Value, - field_maps: &BTreeMap, - resource_config: &Template, -) -> Result { - let mut resource = Resource::new(); - - for field_def in &resource_config.fields { - if let Some(field_map) = field_maps.get(&field_def.name) { - if let Some(value) = extract_mapped_value(json, field_map, field_def)? 
{ - resource.insert(field_def.name.clone(), value); - } else if field_def.required { - return Err(LearnerError::ApiError(format!( - "Required field '{}' not found in response", - field_def.name - ))); - } else if let Some(default) = &field_def.default { - resource.insert(field_def.name.clone(), default.clone()); - } - } - } - - Ok(resource) -} - -/// Extract and transform a value from JSON using a field mapping -fn extract_mapped_value( - json: &Value, - field_map: &FieldMap, - field_def: &FieldDefinition, -) -> Result> { - let path_components: Vec<&str> = field_map.path.split('/').collect(); - - // Extract raw value using path - let raw_value = get_path_value(json, &path_components)?; - - // If no value found, return None - let Some(raw_value) = raw_value else { - return Ok(None); - }; - - // First apply any explicit transforms - let value = if let Some(transform) = &field_map.transform { - apply_transform(&raw_value, transform)? - } else { - raw_value.clone() - }; - - // Then attempt type coercion based on field definition - let coerced = coerce_to_type(&value, field_def)?; - Ok(Some(coerced)) -} - -fn coerce_to_type(value: &Value, field_def: &FieldDefinition) -> Result { - match field_def.field_type.as_str() { - "array" => { - let arr = match value { - // Single value -> wrap in array - Value::String(_) | Value::Object(_) | Value::Number(_) => vec![value.clone()], - // Already an array - Value::Array(arr) => arr.clone(), - _ => return Ok(value.clone()), // Can't coerce, return as-is - }; - - // If we have inner type info, try to coerce each element - if let Some(ref type_def) = field_def.type_definition { - if let Some(ref element_def) = type_def.element_type { - let coerced: Vec = - arr.into_iter().map(|v| coerce_to_type(&v, element_def)).collect::>()?; - Ok(Value::Array(coerced)) - } else { - Ok(Value::Array(arr)) - } - } else { - Ok(Value::Array(arr)) - } - }, - "object" => { - // If we have field definitions, ensure object has required structure - if let 
Some(ref type_def) = field_def.type_definition { - if let Some(fields) = &type_def.fields { - let mut obj = Map::new(); - match value { - // Convert string to {name: string} if that's the structure we want - Value::String(s) if fields.len() == 1 && fields[0].name == "name" => { - obj.insert("name".to_string(), Value::String(s.clone())); - Ok(Value::Object(obj)) - }, - Value::Object(m) => { - // Copy over matching fields with coercion - for field in fields { - if let Some(v) = m.get(&field.name) { - obj.insert(field.name.clone(), coerce_to_type(v, field)?); - } - } - Ok(Value::Object(obj)) - }, - _ => Ok(value.clone()), - } - } else { - Ok(value.clone()) - } - } else { - Ok(value.clone()) - } - }, - // Add other type coercions as needed - _ => Ok(value.clone()), - } -} - -/// Get a value from JSON using a path -// Change return type to owned Value -fn get_path_value(json: &Value, path: &[&str]) -> Result> { - let mut current = json.clone(); - - for &component in path { - match current { - Value::Object(map) => - if let Some(value) = map.get(component) { - current = value.clone(); - } else { - return Ok(None); - }, - Value::Array(arr) => { - // If component is numeric, use it as array index - if let Ok(index) = component.parse::() { - if let Some(value) = arr.get(index) { - current = value.clone(); - } else { - return Ok(None); - } - } else { - // Otherwise collect matching values from array elements - let values: Vec = arr - .iter() - .filter_map(|item| match item { - Value::Object(map) => map.get(component).cloned(), - _ => None, - }) - .collect(); - - if values.is_empty() { - return Ok(None); - } else if values.len() == 1 { - current = values[0].clone(); - } else { - return Ok(Some(Value::Array(values))); - } - } - }, - _ => return Ok(None), - } - } - - Ok(Some(current)) -} - -/// Apply a transform to a JSON value -fn apply_transform(value: &Value, transform: &Transform) -> Result { - dbg!(&value); - match transform { - Transform::Replace { pattern, replacement } 
=> { - let text = value.as_str().ok_or_else(|| { - LearnerError::ApiError("Replace transform requires string input".to_string()) - })?; - let re = - Regex::new(pattern).map_err(|e| LearnerError::ApiError(format!("Invalid regex: {}", e)))?; - Ok(Value::String(re.replace_all(text, replacement.as_str()).into_owned())) - }, - - Transform::Date { from_format, to_format } => { - let text = value.as_str().ok_or_else(|| { - LearnerError::ApiError("Date transform requires string input".to_string()) - })?; - let dt = chrono::NaiveDateTime::parse_from_str(text, from_format) - .map_err(|e| LearnerError::ApiError(format!("Invalid date: {}", e)))?; - Ok(Value::String(dt.format(to_format).to_string())) - }, - - Transform::Url { base, suffix } => { - let text = value - .as_str() - .ok_or_else(|| LearnerError::ApiError("URL transform requires string input".to_string()))?; - Ok(Value::String(format!( - "{}{}", - base.replace("{value}", text), - suffix.as_deref().unwrap_or("") - ))) - }, - - Transform::Compose { sources, format } => { - // Extract values from each source - let values: Vec = sources - .iter() - .filter_map(|source| match source { - Source::Path(path) => { - let components: Vec<&str> = path.split('/').collect(); - get_path_value(value, &components).ok().flatten() - }, - Source::Literal(text) => Some(Value::String(text.clone())), - Source::KeyValue { key: _, path } => { - let components: Vec<&str> = path.split('/').collect(); - get_path_value(value, &components).ok().flatten() - }, - }) - .collect(); - - dbg!(&values); - - // Apply the format to the collected values - match format { - ComposeFormat::Join { delimiter } => { - // Convert values to strings and join - let strings: Vec = values - .iter() - .filter_map(|v| match v { - Value::String(s) => Some(s.clone()), - Value::Array(arr) if arr.len() == 1 => arr[0].as_str().map(|s| s.to_string()), - _ => None, - }) - .collect(); - Ok(Value::String(strings.join(delimiter))) - }, - - ComposeFormat::Object => { - dbg!("inside 
here"); - let mut obj = Map::new(); - dbg!(&sources); - for (source, value) in sources.iter().zip(values.iter()) { - dbg!(&source); - if let Source::KeyValue { key, .. } = source { - dbg!(key); - obj.insert(key.clone(), value.clone()); - } - } - dbg!(&obj); - Ok(Value::Object(obj)) - }, - - ComposeFormat::ArrayOfObjects { template } => { - match value { - // Handle single string -> array of objects - Value::String(s) => { - let mut obj = Map::new(); - for (key, template_value) in template { - let value = template_value.replace("{value}", s); - obj.insert(key.clone(), Value::String(value)); - } - Ok(Value::Array(vec![Value::Object(obj)])) - }, - - // Handle array -> array of objects - Value::Array(arr) => { - dbg!(&arr); - let objects: Vec = arr - .iter() - .filter_map(|item| { - dbg!(&item); - let mut obj = Map::new(); - for (key, template_value) in template { - let value = match item { - Value::String(s) => template_value.replace("{value}", s), - Value::Object(obj) => { - dbg!(obj); - let mut keys_and_vals = Vec::new(); - sources.iter().for_each(|source| { - if let Source::KeyValue { key, path } = source { - if let Some(val) = obj.get(path) { - keys_and_vals.push((key, val)) - } - } - }); - dbg!(&key); - keys_and_vals.into_iter().fold(template_value.clone(), |acc, (k, v)| { - let replacement = format!("{{{k}}}"); - acc.replace(&replacement, v.as_str().unwrap_or_default()) - }) - }, - _ => return None, - }; - obj.insert(key.clone(), Value::String(value)); - } - Some(Value::Object(obj)) - }) - .collect(); - Ok(Value::Array(objects)) - }, - - _ => Err(LearnerError::ApiError( - "ArrayOfObjects transform requires string or array input".to_string(), - )), - } - }, - } - }, - } -} From 9d20ab1bb8e667f47a536f0d974a04845e01c00a Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 06:20:21 -0700 Subject: [PATCH 44/73] continuing cleanup --- crates/learner/src/retriever/config.rs | 13 +- crates/learner/src/retriever/mod.rs | 4 +- .../{response/mod.rs => 
response.rs} | 116 ++++++++++++++- crates/learner/src/retriever/response/xml.rs | 136 ------------------ 4 files changed, 122 insertions(+), 147 deletions(-) rename crates/learner/src/retriever/{response/mod.rs => response.rs} (50%) delete mode 100644 crates/learner/src/retriever/response/xml.rs diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 5c435b7..657f6a7 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,7 +1,5 @@ -use serde_json::Map; -use template::{FieldDefinition, Resource, Template}; - use super::*; +use crate::template::{FieldDefinition, Resource, Template}; #[derive(Debug, Clone, Deserialize)] pub struct Retriever { @@ -16,12 +14,11 @@ pub struct Retriever { pub source: String, /// Template for constructing API endpoint URLs pub endpoint_template: String, - // TODO: This is now more like "how to get the thing to map into the resource" - // #[serde(flatten)] - pub response_format: ResponseFormat, + + pub response_format: ResponseFormat, /// Optional HTTP headers for API requests #[serde(default)] - pub headers: BTreeMap, + pub headers: BTreeMap, #[serde(rename = "resource")] pub resource_template: Template, @@ -78,7 +75,7 @@ impl Retriever { // Process the response using configured processor let json = match &self.response_format { - ResponseFormat::Xml { strip_namespaces } => xml::convert_to_json(&data, *strip_namespaces), + ResponseFormat::Xml { strip_namespaces } => xml_to_json(&data, *strip_namespaces), ResponseFormat::Json => serde_json::from_slice(&data)?, }; diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 969506d..c12f82d 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -5,7 +5,9 @@ mod response; pub use config::*; pub use response::*; -use template::Resource; +use serde_json::Map; + +use crate::template::Resource; // TODO: This should be `BTreeMap>` 
#[derive(Default, Debug, Clone)] diff --git a/crates/learner/src/retriever/response/mod.rs b/crates/learner/src/retriever/response.rs similarity index 50% rename from crates/learner/src/retriever/response/mod.rs rename to crates/learner/src/retriever/response.rs index 2e11a24..97840d8 100644 --- a/crates/learner/src/retriever/response/mod.rs +++ b/crates/learner/src/retriever/response.rs @@ -1,6 +1,6 @@ -use super::*; +use quick_xml::{events::Event, Reader}; -pub mod xml; +use super::*; /// Available response format handlers. /// @@ -137,3 +137,115 @@ pub enum ComposeFormat { template: BTreeMap, }, } + +pub fn xml_to_json(data: &[u8], strip_namespaces: bool) -> Value { + // Handle namespace stripping + let xml = if strip_namespaces { + strip_xml_namespaces(&String::from_utf8_lossy(data)) + } else { + String::from_utf8_lossy(data).to_string() + }; + + trace!("Processing XML response: {:#?}", &xml); + let mut reader = Reader::from_str(&xml); + let mut stack = Vec::new(); + let mut current = Map::new(); + + while let Ok(event) = reader.read_event() { + match event { + Event::Start(ref e) => { + let tag = String::from_utf8_lossy(e.name().as_ref()).to_string(); + + // Create new object for this element + let mut new_obj = Map::new(); + + // Handle attributes + for attr in e.attributes().flatten() { + if let Ok(key) = String::from_utf8(attr.key.as_ref().to_vec()) { + if let Ok(value) = attr.unescape_value() { + new_obj.insert(format!("@{}", key), Value::String(value.into_owned())); + } + } + } + + // Add this element to its parent + match current.get_mut(&tag) { + Some(Value::Array(_)) => { + // Element already exists as array, push onto it later + stack.push((tag, current, true)); + }, + Some(_) => { + // Element exists but not as array, convert to array + let existing = current.remove(&tag).unwrap(); + current.insert(tag.clone(), Value::Array(vec![existing])); + stack.push((tag, current, true)); + }, + None => { + // First occurrence of this element + stack.push((tag, 
current, false)); + }, + } + + current = new_obj; + }, + Event::Text(e) => { + if let Ok(txt) = e.unescape() { + let text = txt.trim(); + if !text.is_empty() { + if current.is_empty() { + // No attributes, just text content + current.insert("$text".to_string(), Value::String(text.to_string())); + } else { + // Has attributes, add text alongside them + current.insert("$text".to_string(), Value::String(text.to_string())); + } + } + } + }, + Event::End(_) => { + if let Some((tag, mut parent, is_array)) = stack.pop() { + // Simplify if only text content + let value = if current.len() == 1 && current.contains_key("$text") { + current.remove("$text").unwrap() + } else { + Value::Object(current) + }; + + // Add to parent according to array status + if is_array { + if let Some(Value::Array(arr)) = parent.get_mut(&tag) { + arr.push(value); + } + } else { + parent.insert(tag, value); + } + + current = parent; + } + }, + Event::Eof => break, + _ => (), + } + } + + dbg!(Value::Object(current)) +} + +/// Removes XML namespace declarations and prefixes from content. +/// +/// Strips both namespace declarations (xmlns attributes) and namespace +/// prefixes from element names for simpler path-based access. +/// +/// # Arguments +/// +/// * `xml` - Raw XML content +/// +/// # Returns +/// +/// XML content with namespaces removed +fn strip_xml_namespaces(xml: &str) -> String { + let re = regex::Regex::new(r#"xmlns(?::\w+)?="[^"]*""#).unwrap(); + let mut result = re.replace_all(xml, "").to_string(); + result = result.replace("oai_dc:", "").replace("dc:", ""); + result +} diff --git a/crates/learner/src/retriever/response/xml.rs b/crates/learner/src/retriever/response/xml.rs deleted file mode 100644 index 99afc98..0000000 --- a/crates/learner/src/retriever/response/xml.rs +++ /dev/null @@ -1,136 +0,0 @@ -//! XML response parser implementation. -//! -//! This module handles parsing of XML API responses into Paper objects using -//! configurable field mappings. 
It provides namespace handling and path-based -//! field extraction with optional transformations. -//! -//! # Example Configuration -//! -//! ```toml -//! [response_format] -//! type = "xml" -//! strip_namespaces = true -//! -//! [response_format.field_maps] -//! title = { path = "entry/title" } -//! abstract = { path = "entry/summary" } -//! publication_date = { path = "entry/published" } -//! authors = { path = "entry/author/name" } -//! ``` - -use quick_xml::{events::Event, Reader}; -use serde_json::{Map, Value}; - -use super::*; - -pub fn convert_to_json(data: &[u8], strip_namespaces: bool) -> Value { - // Handle namespace stripping - let xml = if strip_namespaces { - strip_xml_namespaces(&String::from_utf8_lossy(data)) - } else { - String::from_utf8_lossy(data).to_string() - }; - - trace!("Processing XML response: {:#?}", &xml); - let mut reader = Reader::from_str(&xml); - let mut stack = Vec::new(); - let mut current = Map::new(); - - while let Ok(event) = reader.read_event() { - match event { - Event::Start(ref e) => { - let tag = String::from_utf8_lossy(e.name().as_ref()).to_string(); - - // Create new object for this element - let mut new_obj = Map::new(); - - // Handle attributes - for attr in e.attributes().flatten() { - if let Ok(key) = String::from_utf8(attr.key.as_ref().to_vec()) { - if let Ok(value) = attr.unescape_value() { - new_obj.insert(format!("@{}", key), Value::String(value.into_owned())); - } - } - } - - // Add this element to its parent - match current.get_mut(&tag) { - Some(Value::Array(_)) => { - // Element already exists as array, push onto it later - stack.push((tag, current, true)); - }, - Some(_) => { - // Element exists but not as array, convert to array - let existing = current.remove(&tag).unwrap(); - current.insert(tag.clone(), Value::Array(vec![existing])); - stack.push((tag, current, true)); - }, - None => { - // First occurrence of this element - stack.push((tag, current, false)); - }, - } - - current = new_obj; - }, - 
Event::Text(e) => { - if let Ok(txt) = e.unescape() { - let text = txt.trim(); - if !text.is_empty() { - if current.is_empty() { - // No attributes, just text content - current.insert("$text".to_string(), Value::String(text.to_string())); - } else { - // Has attributes, add text alongside them - current.insert("$text".to_string(), Value::String(text.to_string())); - } - } - } - }, - Event::End(_) => { - if let Some((tag, mut parent, is_array)) = stack.pop() { - // Simplify if only text content - let value = if current.len() == 1 && current.contains_key("$text") { - current.remove("$text").unwrap() - } else { - Value::Object(current) - }; - - // Add to parent according to array status - if is_array { - if let Some(Value::Array(arr)) = parent.get_mut(&tag) { - arr.push(value); - } - } else { - parent.insert(tag, value); - } - - current = parent; - } - }, - Event::Eof => break, - _ => (), - } - } - - dbg!(Value::Object(current)) -} - -/// Removes XML namespace declarations and prefixes from content. -/// -/// Strips both namespace declarations (xmlns attributes) and namespace -/// prefixes from element names for simpler path-based access. 
-/// -/// # Arguments -/// -/// * `xml` - Raw XML content -/// -/// # Returns -/// -/// XML content with namespaces removed -fn strip_xml_namespaces(xml: &str) -> String { - let re = regex::Regex::new(r#"xmlns(?::\w+)?="[^"]*""#).unwrap(); - let mut result = re.replace_all(xml, "").to_string(); - result = result.replace("oai_dc:", "").replace("dc:", ""); - result -} From 0d3ef327777e51c11a71e7a002638b66f8f73fa5 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 06:24:57 -0700 Subject: [PATCH 45/73] clean --- crates/learner/src/retriever/config.rs | 124 ++++++++++--------------- 1 file changed, 47 insertions(+), 77 deletions(-) diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 657f6a7..0f6971c 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -249,26 +249,23 @@ fn get_path_value(json: &Value, path: &[&str]) -> Option { /// Apply a transform to a JSON value fn apply_transform(value: &Value, transform: &Transform) -> Result { - dbg!(&value); match transform { Transform::Replace { pattern, replacement } => { let text = value.as_str().ok_or_else(|| { LearnerError::ApiError("Replace transform requires string input".to_string()) })?; let re = - Regex::new(pattern).map_err(|e| LearnerError::ApiError(format!("Invalid regex: {}", e)))?; + Regex::new(pattern).map_err(|e| LearnerError::ApiError(format!("Invalid regex: {e}")))?; Ok(Value::String(re.replace_all(text, replacement.as_str()).into_owned())) }, - Transform::Date { from_format, to_format } => { let text = value.as_str().ok_or_else(|| { LearnerError::ApiError("Date transform requires string input".to_string()) })?; let dt = chrono::NaiveDateTime::parse_from_str(text, from_format) - .map_err(|e| LearnerError::ApiError(format!("Invalid date: {}", e)))?; + .map_err(|e| LearnerError::ApiError(format!("Invalid date: {e}")))?; Ok(Value::String(dt.format(to_format).to_string())) }, - Transform::Url { base, suffix 
} => { let text = value .as_str() @@ -279,109 +276,82 @@ fn apply_transform(value: &Value, transform: &Transform) -> Result { suffix.as_deref().unwrap_or("") ))) }, - Transform::Compose { sources, format } => { - // Extract values from each source let values: Vec = sources .iter() .filter_map(|source| match source { - Source::Path(path) => { + Source::Path(path) | Source::KeyValue { key: _, path } => { let components: Vec<&str> = path.split('/').collect(); get_path_value(value, &components) }, Source::Literal(text) => Some(Value::String(text.clone())), - Source::KeyValue { key: _, path } => { - let components: Vec<&str> = path.split('/').collect(); - get_path_value(value, &components) - }, }) .collect(); - - dbg!(&values); - - // Apply the format to the collected values match format { ComposeFormat::Join { delimiter } => { - // Convert values to strings and join let strings: Vec = values .iter() .filter_map(|v| match v { Value::String(s) => Some(s.clone()), - Value::Array(arr) if arr.len() == 1 => arr[0].as_str().map(|s| s.to_string()), + Value::Array(arr) if arr.len() == 1 => + arr[0].as_str().map(std::string::ToString::to_string), _ => None, }) .collect(); Ok(Value::String(strings.join(delimiter))) }, - ComposeFormat::Object => { - dbg!("inside here"); let mut obj = Map::new(); - dbg!(&sources); for (source, value) in sources.iter().zip(values.iter()) { - dbg!(&source); if let Source::KeyValue { key, .. 
} = source { - dbg!(key); obj.insert(key.clone(), value.clone()); } } - dbg!(&obj); Ok(Value::Object(obj)) }, - - ComposeFormat::ArrayOfObjects { template } => { - match value { - // Handle single string -> array of objects - Value::String(s) => { - let mut obj = Map::new(); - for (key, template_value) in template { - let value = template_value.replace("{value}", s); - obj.insert(key.clone(), Value::String(value)); - } - Ok(Value::Array(vec![Value::Object(obj)])) - }, - - // Handle array -> array of objects - Value::Array(arr) => { - dbg!(&arr); - let objects: Vec = arr - .iter() - .filter_map(|item| { - dbg!(&item); - let mut obj = Map::new(); - for (key, template_value) in template { - let value = match item { - Value::String(s) => template_value.replace("{value}", s), - Value::Object(obj) => { - dbg!(obj); - let mut keys_and_vals = Vec::new(); - sources.iter().for_each(|source| { - if let Source::KeyValue { key, path } = source { - if let Some(val) = obj.get(path) { - keys_and_vals.push((key, val)) - } + ComposeFormat::ArrayOfObjects { template } => match value { + Value::String(s) => { + let mut obj = Map::new(); + for (key, template_value) in template { + let value = template_value.replace("{value}", s); + obj.insert(key.clone(), Value::String(value)); + } + Ok(Value::Array(vec![Value::Object(obj)])) + }, + Value::Array(arr) => { + let objects: Vec = arr + .iter() + .filter_map(|item| { + let mut obj = Map::new(); + for (key, template_value) in template { + let value = match item { + Value::String(s) => template_value.replace("{value}", s), + Value::Object(obj) => { + let mut keys_and_vals = Vec::new(); + for source in sources { + if let Source::KeyValue { key, path } = source { + if let Some(val) = obj.get(path) { + keys_and_vals.push((key, val)); } - }); - dbg!(&key); - keys_and_vals.into_iter().fold(template_value.clone(), |acc, (k, v)| { - let replacement = format!("{{{k}}}"); - acc.replace(&replacement, v.as_str().unwrap_or_default()) - }) - }, - _ => 
return None, - }; - obj.insert(key.clone(), Value::String(value)); - } - Some(Value::Object(obj)) - }) - .collect(); - Ok(Value::Array(objects)) - }, - - _ => Err(LearnerError::ApiError( - "ArrayOfObjects transform requires string or array input".to_string(), - )), - } + } + } + keys_and_vals.into_iter().fold(template_value.clone(), |acc, (k, v)| { + let replacement = format!("{{{k}}}"); + acc.replace(&replacement, v.as_str().unwrap_or_default()) + }) + }, + _ => return None, + }; + obj.insert(key.clone(), Value::String(value)); + } + Some(Value::Object(obj)) + }) + .collect(); + Ok(Value::Array(objects)) + }, + _ => Err(LearnerError::ApiError( + "ArrayOfObjects transform requires string or array input".to_string(), + )), }, } }, From 9eceb6184e1ec59bfac77d0227eedaa01f918fdc Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 06:55:55 -0700 Subject: [PATCH 46/73] retreival template --- crates/learner/config_new/arxiv.toml | 28 ++++------ crates/learner/config_new/retrieval.toml | 40 ++++++++++++++ crates/learner/src/configuration.rs | 29 ++++++---- crates/learner/src/retriever/config.rs | 29 ++++++++-- crates/learner/src/retriever/response.rs | 53 ------------------- .../tests/workflows/paper_retrieval.rs | 19 +------ 6 files changed, 97 insertions(+), 101 deletions(-) create mode 100644 crates/learner/config_new/retrieval.toml diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml index ef4d446..05745c0 100644 --- a/crates/learner/config_new/arxiv.toml +++ b/crates/learner/config_new/arxiv.toml @@ -2,31 +2,25 @@ name = "arxiv" description = "Retriever template for getting a paper from arXiv" +resource_template = "paper" +retrieval_template = "retrieval" + base_url = "http://export.arxiv.org" endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}&max_results=1" pattern = "(?:^|https?://arxiv\\.org/(?:abs|pdf)/)(\\d{4}\\.\\d{4,5}|[a-zA-Z-]+(?:\\.[A-Z]{2})?/\\d{7})(?:\\.pdf)?$" -resource = "paper" 
source = "arxiv" response_format = { type = "xml", strip_namespaces = true } +[resource_mappings] +abstract = { path = "feed/entry/summary" } +authors = { path = "feed/entry/author" } +publication_dates = { path = "feed/entry/published" } +title = { path = "feed/entry/title" } -[resource_mappings.title] -path = "feed/entry/title" - -[resource_mappings.abstract] -path = "feed/entry/summary" - -[resource_mappings.authors] -path = "feed/entry/author" - -[resource_mappings.publication_dates] -path = "feed/entry/published" - -[retrieval_data.urls] -path = "feed/entry/id" - -[retrieval_data.urls.transform] +[retrieval_mappings] +urls = { path = "feed/entry/id" } +[urls.transform] pattern = "/abs/" replacement = "/pdf/" type = "Replace" diff --git a/crates/learner/config_new/retrieval.toml b/crates/learner/config_new/retrieval.toml new file mode 100644 index 0000000..18c2d24 --- /dev/null +++ b/crates/learner/config_new/retrieval.toml @@ -0,0 +1,40 @@ +name = "retrieval" + +description = "Standard retrieval data template" + +[source] +field_type = "string" +required = false + +[source_identifier] +field_type = "string" +required = false + +[urls] +field_type = "object" +required = true + +[urls.type_definition] +fields = [ + { name = "pdf", field_type = "string", required = false }, + { name = "html", field_type = "string", required = false }, +] + +[doi] +field_type = "string" +required = false + +[last_checked] +field_type = "string" +required = false +validation = { datetime = true } + +[access_type] +field_type = "string" +required = false +validation = { enum_values = ["open", "subscription", "institutional"] } + +[verified] +default = true +field_type = "boolean" +required = true diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index 8a4796f..b230a53 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -21,20 +21,25 @@ impl ConfigurationManager { pub fn load_config(&mut self, path: impl 
AsRef) -> Result where T: DeserializeOwned + std::fmt::Debug { let path = path.as_ref(); - let content = std::fs::read_to_string(path)?; - let mut raw_config: toml::Value = toml::from_str(&content)?; + let content = dbg!(std::fs::read_to_string(path)?); + let mut raw_config: toml::Value = dbg!(toml::from_str(&content)?); // If this is a Retriever config, handle resource reference if std::any::type_name::() == std::any::type_name::() { - if let Some(toml::Value::String(resource_name)) = raw_config.get("resource") { - // Load the referenced resource - let resource_path = self.config_paths.join(format!("{resource_name}.toml")); - let resource_content = std::fs::read_to_string(resource_path)?; - let resource_config: toml::Value = toml::from_str(&resource_content)?; - - // Replace the string reference with the resource config - if let Some(table) = raw_config.as_table_mut() { - table.insert("resource".into(), resource_config); + // Handle both resource and retrieval templates + let template_fields = ["resource_template", "retrieval_template"]; + + for field in &template_fields { + if let Some(toml::Value::String(template_name)) = raw_config.get(field) { + // Load the referenced template + let template_path = self.config_paths.join(format!("{template_name}.toml")); + let template_content = std::fs::read_to_string(template_path)?; + let template_config: toml::Value = toml::from_str(&template_content)?; + + // Replace the string reference with the template config + if let Some(table) = raw_config.as_table_mut() { + table.insert((*field).to_string(), template_config); + } } } } @@ -58,6 +63,8 @@ mod tests { // Load configurations in order let paper: Template = dbg!(manager.load_config("config_new/paper.toml").unwrap()); + let retreival: Template = dbg!(manager.load_config("config_new/retrieval.toml")).unwrap(); + let arxiv_retriever: Retriever = dbg!(manager.load_config("config_new/arxiv.toml").unwrap()); todo!("Clean this up") diff --git 
a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 0f6971c..89e0c3e 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -20,11 +20,11 @@ pub struct Retriever { #[serde(default)] pub headers: BTreeMap, - #[serde(rename = "resource")] pub resource_template: Template, #[serde(default)] pub resource_mappings: BTreeMap, + pub retrieval_template: Template, #[serde(default)] pub retrieval_mappings: BTreeMap, } @@ -79,8 +79,6 @@ impl Retriever { ResponseFormat::Json => serde_json::from_slice(&data)?, }; - // Process response and get resource - // TODO: this should probably be a method let mut resource = self.process_json_value(&json)?; // Add source metadata @@ -114,6 +112,31 @@ impl Retriever { } } +fn process_template_fields( + json: &Value, + template: &Template, + mappings: &BTreeMap, +) -> Result> { + let mut result = BTreeMap::new(); + + for field_def in &template.fields { + if let Some(field_map) = mappings.get(&field_def.name) { + if let Some(value) = extract_mapped_value(json, field_map, field_def)? { + result.insert(field_def.name.clone(), value); + } else if field_def.required { + return Err(LearnerError::ApiError(format!( + "Required field '{}' not found in response", + field_def.name + ))); + } else if let Some(default) = &field_def.default { + result.insert(field_def.name.clone(), default.clone()); + } + } + } + + Ok(result) +} + /// Extract and transform a value from JSON using a field mapping fn extract_mapped_value( json: &Value, diff --git a/crates/learner/src/retriever/response.rs b/crates/learner/src/retriever/response.rs index 97840d8..e93d85d 100644 --- a/crates/learner/src/retriever/response.rs +++ b/crates/learner/src/retriever/response.rs @@ -2,31 +2,6 @@ use quick_xml::{events::Event, Reader}; use super::*; -/// Available response format handlers. -/// -/// Specifies how to parse and extract paper metadata from API responses -/// in different formats. 
-/// -/// # Examples -/// -/// XML configuration: -/// ```toml -/// [response_format] -/// type = "xml" -/// strip_namespaces = true -/// -/// [response_format.field_maps] -/// title = { path = "entry/title" } -/// ``` -/// -/// JSON configuration: -/// ```toml -/// [response_format] -/// type = "json" -/// -/// [response_format.field_maps] -/// title = { path = "message/title/0" } -/// ``` #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ResponseFormat { @@ -41,17 +16,6 @@ pub enum ResponseFormat { Json, } -/// Field mapping configuration. -/// -/// Defines how to extract and transform specific fields from API responses. -/// -/// # Examples -/// -/// ```toml -/// [field_maps.title] -/// path = "entry/title" -/// transform = { type = "replace", pattern = "\\s+", replacement = " " } -/// ``` #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FieldMap { /// Path to field in response (e.g., JSON path or XPath) @@ -61,23 +25,6 @@ pub struct FieldMap { pub transform: Option, } -/// Available field value transformations. -/// -/// Transformations that can be applied to extracted field values -/// before constructing the final Paper object. 
-/// -/// # Examples -/// -/// ```toml -/// # Clean up whitespace -/// transform = { type = "replace", pattern = "\\s+", replacement = " " } -/// -/// # Convert date format -/// transform = { type = "date", from_format = "%Y-%m-%d", to_format = "%Y-%m-%dT00:00:00Z" } -/// -/// # Construct full URL -/// transform = { type = "url", base = "https://example.com/", suffix = ".pdf" } -/// ``` #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum Transform { diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index 39186b3..0409449 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -1,28 +1,15 @@ use std::fs; -use learner::{configuration::ConfigurationManager, template::Template}; +use learner::configuration::ConfigurationManager; use super::*; #[traced_test] #[tokio::test] async fn test_arxiv_retriever_integration() -> TestResult<()> { - // let ret_config_str = fs::read_to_string("config/retrievers/arxiv.toml").expect( - // "Failed to read config - // file", - // ); - // let res_config_str = fs::read_to_string("config/resources/paper.toml").expect( - // "Failed to read config - // file", - // ); let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); let retriever: Retriever = manager.load_config("config_new/arxiv.toml")?; - // let retriever: Retriever = toml::from_str(&ret_config_str).expect("Failed to parse config"); - // let resource: ResourceTemplate = toml::from_str(&res_config_str).expect("Failed to parse - // config"); - - // Test with a real arXiv paper let paper = retriever.retrieve_resource("2301.07041").await?; dbg!(&paper); @@ -70,10 +57,7 @@ async fn test_arxiv_pdf_from_paper() -> TestResult<()> { async fn test_iacr_retriever_integration() -> TestResult<()> { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); let retriever: Retriever = 
manager.load_config("config_new/iacr.toml")?; - // let resource: ResourceTemplate = toml::from_str(&res_config_str).expect("Failed to parse - // config"); - // // Test with a real IACR paper let paper = retriever.retrieve_resource("2016/260").await.unwrap(); // assert!(resource.validate(&paper)?); // TODO: validation already happens internally, to be fair // that validation may not be working totally right @@ -121,6 +105,7 @@ async fn test_doi_retriever_integration() -> TestResult<()> { let paper = retriever.retrieve_resource("10.1145/1327452.1327492").await?; // assert!(resource.validate(&paper)?); dbg!(&paper); + todo!("Clean this up"); // assert!(!paper.title.is_empty()); // assert!(!paper.authors.is_empty()); // assert!(!paper.abstract_text.is_empty()); From 73d62a5914166b001a0edb74828fd77e83f77717 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 07:42:18 -0700 Subject: [PATCH 47/73] WIP: adjusted `get_path_value` --- crates/learner/config_new/arxiv.toml | 13 ++++--- crates/learner/config_new/iacr.toml | 4 +- crates/learner/config_new/retrieval.toml | 7 +--- crates/learner/src/error.rs | 4 +- crates/learner/src/record.rs | 10 ++--- crates/learner/src/retriever/config.rs | 46 ++++++++++------------- crates/learner/src/retriever/mod.rs | 4 +- crates/learner/src/template.rs | 48 ++++++++++++------------ 8 files changed, 63 insertions(+), 73 deletions(-) diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml index 05745c0..60ca13e 100644 --- a/crates/learner/config_new/arxiv.toml +++ b/crates/learner/config_new/arxiv.toml @@ -18,12 +18,13 @@ authors = { path = "feed/entry/author" } publication_dates = { path = "feed/entry/published" } title = { path = "feed/entry/title" } -[retrieval_mappings] -urls = { path = "feed/entry/id" } -[urls.transform] -pattern = "/abs/" -replacement = "/pdf/" -type = "Replace" +[retrieval_mappings.urls] +path = "feed/entry/id" +[retrieval_mappings.urls.transform] +format = { type = 
"Object", template = { "pdf" = "{value}" } } +sources = [{ type = "path", value = "feed/entry/id" }] +type = "Compose" + [headers] Accept = "application/xml" diff --git a/crates/learner/config_new/iacr.toml b/crates/learner/config_new/iacr.toml index bb8bb0a..2bf74cd 100644 --- a/crates/learner/config_new/iacr.toml +++ b/crates/learner/config_new/iacr.toml @@ -2,10 +2,12 @@ name = "iacr" description = "Retriever template for getting a paper from IACR" +resource_template = "paper" +retrieval_template = "retrieval" + base_url = "https://eprint.iacr.org" endpoint_template = "https://eprint.iacr.org/oai?verb=GetRecord&identifier=oai:eprint.iacr.org:{identifier}&metadataPrefix=oai_dc" pattern = "(?:^|https?://eprint\\.iacr\\.org/)(\\d{4}/\\d+)(?:\\.pdf)?$" -resource = "paper" source = "iacr" response_format = { type = "xml", strip_namespaces = true } diff --git a/crates/learner/config_new/retrieval.toml b/crates/learner/config_new/retrieval.toml index 18c2d24..12d2c0d 100644 --- a/crates/learner/config_new/retrieval.toml +++ b/crates/learner/config_new/retrieval.toml @@ -12,7 +12,7 @@ required = false [urls] field_type = "object" -required = true +required = false # Consider changing [urls.type_definition] fields = [ @@ -33,8 +33,3 @@ validation = { datetime = true } field_type = "string" required = false validation = { enum_values = ["open", "subscription", "institutional"] } - -[verified] -default = true -field_type = "boolean" -required = true diff --git a/crates/learner/src/error.rs b/crates/learner/src/error.rs index c4a2bb9..b902e5c 100644 --- a/crates/learner/src/error.rs +++ b/crates/learner/src/error.rs @@ -251,6 +251,6 @@ pub enum LearnerError { #[error(transparent)] SerdeJson(#[from] serde_json::Error), - #[error("Failed to be a valid resource due to: {0}")] - InvalidResource(String), + #[error("Failed to be a valid item due to: {0}")] + TemplateInvalidation(String), } diff --git a/crates/learner/src/record.rs b/crates/learner/src/record.rs index 
4c8efab..85ecf16 100644 --- a/crates/learner/src/record.rs +++ b/crates/learner/src/record.rs @@ -1,4 +1,4 @@ -use template::Resource; +use template::TemplatedItem; use super::*; @@ -8,7 +8,7 @@ use super::*; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Record { /// The resource type this record manages - pub resource: Resource, + pub resource: TemplatedItem, /// State tracking configuration pub state: State, @@ -17,7 +17,7 @@ pub struct Record { pub storage: StorageData, /// Retrieval configuration - pub retrieval: RetrievalData, + pub retrieval: TemplatedItem, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -59,12 +59,12 @@ pub struct RetrievalData { pub verified: bool, // Whether we've confirmed this data } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StorageData { pub files: BTreeMap, pub original_filenames: BTreeMap, pub added_at: BTreeMap>, pub file_sizes: BTreeMap, // Track file sizes pub checksums: BTreeMap, // For integrity checking - pub last_verified: DateTime, // When we last checked files exist + pub last_verified: Option>, // When we last checked files exist } diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 89e0c3e..5cf0a3e 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,5 +1,7 @@ +use record::{Record, State, StorageData}; + use super::*; -use crate::template::{FieldDefinition, Resource, Template}; +use crate::template::{FieldDefinition, Template, TemplatedItem}; #[derive(Debug, Clone, Deserialize)] pub struct Retriever { @@ -53,7 +55,7 @@ impl Retriever { .ok_or(LearnerError::InvalidIdentifier) } - pub async fn retrieve_resource(&self, input: &str) -> Result { + pub async fn retrieve_resource(&self, input: &str) -> Result { let identifier = self.extract_identifier(input)?; // Send request and get response @@ -79,7 +81,7 @@ impl Retriever { 
ResponseFormat::Json => serde_json::from_slice(&data)?, }; - let mut resource = self.process_json_value(&json)?; + let (mut resource, retrieval) = self.process_json_value(&json)?; // Add source metadata resource.insert("source".into(), Value::String(self.source.clone())); @@ -87,28 +89,17 @@ impl Retriever { // Validate full resource against config self.resource_template.validate(&resource)?; - Ok(resource) + self.retrieval_template.validate(&retrieval)?; + + Ok(Record { resource, state: State::default(), storage: StorageData::default(), retrieval }) } - pub fn process_json_value(&self, json: &Value) -> Result { - let mut resource = Resource::new(); - - for field_def in &self.resource_template.fields { - if let Some(field_map) = self.resource_mappings.get(&field_def.name) { - if let Some(value) = extract_mapped_value(json, field_map, field_def)? { - resource.insert(field_def.name.clone(), value); - } else if field_def.required { - return Err(LearnerError::ApiError(format!( - "Required field '{}' not found in response", - field_def.name - ))); - } else if let Some(default) = &field_def.default { - resource.insert(field_def.name.clone(), default.clone()); - } - } - } + pub fn process_json_value(&self, json: &Value) -> Result<(TemplatedItem, TemplatedItem)> { + let resource = process_template_fields(json, &self.resource_template, &self.resource_mappings)?; + let retrieval = + process_template_fields(json, &self.retrieval_template, &self.retrieval_mappings)?; - Ok(resource) + Ok((resource, retrieval)) } } @@ -120,7 +111,7 @@ fn process_template_fields( let mut result = BTreeMap::new(); for field_def in &template.fields { - if let Some(field_map) = mappings.get(&field_def.name) { + if let Some(field_map) = mappings.get(dbg!(&field_def.name)) { if let Some(value) = extract_mapped_value(json, field_map, field_def)? 
{ result.insert(field_def.name.clone(), value); } else if field_def.required { @@ -263,7 +254,7 @@ fn get_path_value(json: &Value, path: &[&str]) -> Option { } } }, - _ => return None, + _ => return Some(json.clone()), } } @@ -300,7 +291,7 @@ fn apply_transform(value: &Value, transform: &Transform) -> Result { ))) }, Transform::Compose { sources, format } => { - let values: Vec = sources + let values: Vec = dbg!(sources .iter() .filter_map(|source| match source { Source::Path(path) | Source::KeyValue { key: _, path } => { @@ -309,7 +300,7 @@ fn apply_transform(value: &Value, transform: &Transform) -> Result { }, Source::Literal(text) => Some(Value::String(text.clone())), }) - .collect(); + .collect()); match format { ComposeFormat::Join { delimiter } => { let strings: Vec = values @@ -325,8 +316,9 @@ fn apply_transform(value: &Value, transform: &Transform) -> Result { }, ComposeFormat::Object => { let mut obj = Map::new(); + dbg!(&sources); for (source, value) in sources.iter().zip(values.iter()) { - if let Source::KeyValue { key, .. } = source { + if let Source::KeyValue { key, .. } = dbg!(source) { obj.insert(key.clone(), value.clone()); } } diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index c12f82d..a6a6fcd 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -7,7 +7,7 @@ pub use config::*; pub use response::*; use serde_json::Map; -use crate::template::Resource; +use crate::template::TemplatedItem; // TODO: This should be `BTreeMap>` #[derive(Default, Debug, Clone)] @@ -58,7 +58,7 @@ impl Retrievers { /// ``` pub fn new() -> Self { Self::default() } - pub async fn get_resource_file(&self, input: &str) -> Result { + pub async fn get_resource_file(&self, input: &str) -> Result { todo!( "Arguably, we don't even need this. 
We could instead just have this handled by `Learner` so \ the API is simpler" diff --git a/crates/learner/src/template.rs b/crates/learner/src/template.rs index 1bfbc59..e78603d 100644 --- a/crates/learner/src/template.rs +++ b/crates/learner/src/template.rs @@ -3,7 +3,7 @@ use std::collections::HashSet; use super::*; // Type alias for clarity and consistency -pub type Resource = BTreeMap; +pub type TemplatedItem = BTreeMap; #[derive(Debug, Clone, Serialize)] pub struct Template { @@ -99,12 +99,12 @@ pub struct ValidationRules { } impl Template { - /// Validates a set of values against this resource configuration - pub fn validate(&self, resource: &Resource) -> Result { + // TODO: Make this just return a `Result<()>` + pub fn validate(&self, resource: &TemplatedItem) -> Result<()> { // Check required fields for field in &self.fields { if field.required && !resource.contains_key(&field.name) { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Missing required field: {}", field.name ))); @@ -119,7 +119,7 @@ impl Template { } } - Ok(true) + Ok(()) } /// Validates a single field value against its definition @@ -131,7 +131,7 @@ impl Template { // Length constraints if let Some(min_length) = rules.min_length { if v.len() < min_length { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' must be at least {} characters", field.name, min_length ))); @@ -139,7 +139,7 @@ impl Template { } if let Some(max_length) = rules.max_length { if v.len() > max_length { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' cannot exceed {} characters", field.name, max_length ))); @@ -149,9 +149,9 @@ impl Template { // Pattern matching via regex if let Some(pattern) = &rules.pattern { let re = Regex::new(pattern) - .map_err(|_| LearnerError::InvalidResource("Invalid regex pattern".into()))?; + 
.map_err(|_| LearnerError::TemplateInvalidation("Invalid regex pattern".into()))?; if !re.is_match(v) { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' must match pattern: {}", field.name, pattern ))); @@ -160,7 +160,7 @@ impl Template { // Datetime validation if specified if rules.datetime == Some(true) && DateTime::parse_from_rfc3339(v).is_err() { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' must be a valid RFC3339 datetime", field.name ))); @@ -169,7 +169,7 @@ impl Template { // Enumerated values check if let Some(allowed) = &rules.enum_values { if !allowed.contains(v) { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' must be one of: {:?}", field.name, allowed ))); @@ -194,7 +194,7 @@ impl Template { if let Some(rules) = &field.validation { if let Some(min_items) = rules.min_items { if v.len() < min_items { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' must have at least {} items", field.name, min_items ))); @@ -203,7 +203,7 @@ impl Template { if let Some(max_items) = rules.max_items { if v.len() > max_items { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' cannot exceed {} items", field.name, max_items ))); @@ -214,10 +214,10 @@ impl Template { let mut seen = HashSet::new(); for item in v { let item_str = serde_json::to_string(item).map_err(|_| { - LearnerError::InvalidResource("Failed to serialize array item".into()) + LearnerError::TemplateInvalidation("Failed to serialize array item".into()) })?; if !seen.insert(item_str) { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' contains duplicate items", field.name ))); @@ -234,7 
+234,7 @@ impl Template { ("null", Value::Null) => Ok(()), // Type mismatch - provide a clear error message - _ => Err(LearnerError::InvalidResource(format!( + _ => Err(LearnerError::TemplateInvalidation(format!( "Field '{}' expected type '{}' but got '{}'", field.name, field.field_type, @@ -254,7 +254,7 @@ impl Template { fn validate_numeric(field: &FieldDefinition, value: f64, rules: &ValidationRules) -> Result<()> { if let Some(min) = rules.minimum { if value < min { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' must be at least {}", field.name, min ))); @@ -263,7 +263,7 @@ fn validate_numeric(field: &FieldDefinition, value: f64, rules: &ValidationRules if let Some(max) = rules.maximum { if value > max { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' cannot exceed {}", field.name, max ))); @@ -273,7 +273,7 @@ fn validate_numeric(field: &FieldDefinition, value: f64, rules: &ValidationRules if let Some(multiple) = rules.multiple_of { let ratio = value / multiple; if (ratio - ratio.round()).abs() > f64::EPSILON { - return Err(LearnerError::InvalidResource(format!( + return Err(LearnerError::TemplateInvalidation(format!( "Field '{}' must be a multiple of {}", field.name, multiple ))); @@ -290,7 +290,7 @@ pub fn datetime_to_json(dt: DateTime) -> String { dt.to_rfc3339() } pub fn datetime_from_json(s: &str) -> Result> { DateTime::parse_from_rfc3339(s) .map(|dt| dt.with_timezone(&Utc)) - .map_err(|e| LearnerError::InvalidResource(format!("Invalid datetime format: {}", e))) + .map_err(|e| LearnerError::TemplateInvalidation(format!("Invalid datetime format: {}", e))) } #[cfg(test)] mod tests { @@ -321,7 +321,7 @@ mod tests { ]); // Validate the paper - assert!(template.validate(&paper_resource).unwrap()); + template.validate(&paper_resource).unwrap(); // Test required field validation let invalid_paper = BTreeMap::from([ @@ 
-345,7 +345,7 @@ mod tests { ("publication_date".into(), json!(date)), ]); - assert!(template.validate(&book_resource).unwrap()); + template.validate(&book_resource).unwrap(); } #[test] @@ -364,7 +364,7 @@ mod tests { ("advisors".into(), json!(["Prof. Bob Supervisor"])), ]); - assert!(template.validate(&thesis_resource).unwrap()); + template.validate(&thesis_resource).unwrap(); // Test degree enum validation let mut invalid_thesis = thesis_resource.clone(); @@ -389,7 +389,7 @@ mod tests { }; let valid_resource = BTreeMap::from([("timestamp".into(), json!("2024-01-01T00:00:00Z"))]); - assert!(template.validate(&valid_resource).unwrap()); + template.validate(&valid_resource).unwrap(); let invalid_resource = BTreeMap::from([ ("timestamp".into(), json!("2024-01-01")), // Not RFC3339 From 7ac95aa2d00a8d8b2a005b456f3b05184ecccca0 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 08:11:32 -0700 Subject: [PATCH 48/73] WIP: gets back urls, but poorly formatted --- crates/learner/config_new/arxiv.toml | 19 ++++-- crates/learner/src/retriever/config.rs | 76 +++++++++++++----------- crates/learner/src/retriever/response.rs | 2 +- 3 files changed, 55 insertions(+), 42 deletions(-) diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml index 60ca13e..54a6584 100644 --- a/crates/learner/config_new/arxiv.toml +++ b/crates/learner/config_new/arxiv.toml @@ -19,11 +19,20 @@ publication_dates = { path = "feed/entry/published" } title = { path = "feed/entry/title" } [retrieval_mappings.urls] -path = "feed/entry/id" -[retrieval_mappings.urls.transform] -format = { type = "Object", template = { "pdf" = "{value}" } } -sources = [{ type = "path", value = "feed/entry/id" }] -type = "Compose" +path = "feed/entry/id" +transform.type = "Compose" + +[[retrieval_mappings.urls.transform.sources]] +transform.pattern = ".*/(?:abs|pdf)/(.+?)(?:v\\d+)?$" +transform.replacement = "$1" +transform.type = "Replace" +type = "path" +value = "feed/entry/id" 
+ +[retrieval_mappings.urls.transform.format] +template.html = "http://arxiv.org/abs/{value}" +template.pdf = "http://arxiv.org/pdf/{value}" +type = "Object" [headers] diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 5cf0a3e..0415867 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -152,7 +152,7 @@ fn extract_mapped_value( }; // Then attempt type coercion based on field definition - let coerced = coerce_to_type(&value, field_def)?; + let coerced = dbg!(coerce_to_type(&value, field_def)?); Ok(Some(coerced)) } @@ -180,35 +180,31 @@ fn coerce_to_type(value: &Value, field_def: &FieldDefinition) -> Result { Ok(Value::Array(arr)) } }, - "object" => { - // If we have field definitions, ensure object has required structure - if let Some(ref type_def) = field_def.type_definition { - if let Some(fields) = &type_def.fields { - let mut obj = Map::new(); - match value { - // Convert string to {name: string} if that's the structure we want - Value::String(s) if fields.len() == 1 && fields[0].name == "name" => { - obj.insert("name".to_string(), Value::String(s.clone())); - Ok(Value::Object(obj)) - }, - Value::Object(m) => { - // Copy over matching fields with coercion - for field in fields { - if let Some(v) = m.get(&field.name) { - obj.insert(field.name.clone(), coerce_to_type(v, field)?); - } - } - Ok(Value::Object(obj)) - }, - _ => Ok(value.clone()), - } - } else { - Ok(value.clone()) - } - } else { - Ok(value.clone()) - } - }, + // "object" => { + // match value { + // Value::Object(m) => { + // if let Some(ref type_def) = field_def.type_definition { + // if let Some(fields) = &type_def.fields { + // let mut obj = Map::new(); + // // Copy over matching fields with coercion + // for field in fields { + // if let Some(v) = m.get(&field.name) { + // obj.insert(field.name.clone(), coerce_to_type(v, field)?); + // } + // } + // Ok(Value::Object(obj)) + // } else { + // // If no 
fields defined, preserve the original object + // Ok(value.clone()) + // } + // } else { + // // If no type definition, preserve the original object + // Ok(value.clone()) + // } + // }, + // _ => Ok(value.clone()), + // } + // }, // Add other type coercions as needed _ => Ok(value.clone()), } @@ -314,15 +310,23 @@ fn apply_transform(value: &Value, transform: &Transform) -> Result { .collect(); Ok(Value::String(strings.join(delimiter))) }, - ComposeFormat::Object => { + ComposeFormat::Object { template } => { + println!("Values to process: {:?}", values); + println!("Template: {:?}", template); let mut obj = Map::new(); - dbg!(&sources); - for (source, value) in sources.iter().zip(values.iter()) { - if let Source::KeyValue { key, .. } = dbg!(source) { - obj.insert(key.clone(), value.clone()); + if values.len() == 1 { + if let Some(value) = values.first() { + println!("Processing value: {:?}", value); + for (key, template_str) in template { + println!("Processing template: {} -> {}", key, template_str); + let formatted = template_str.replace("{value}", value.as_str().unwrap_or_default()); + println!("Formatted result: {}", formatted); + obj.insert(key.clone(), Value::String(formatted)); + } } } - Ok(Value::Object(obj)) + println!("Final object: {:?}", obj); + Ok(dbg!(Value::Object(obj))) }, ComposeFormat::ArrayOfObjects { template } => match value { Value::String(s) => { diff --git a/crates/learner/src/retriever/response.rs b/crates/learner/src/retriever/response.rs index e93d85d..50dfd48 100644 --- a/crates/learner/src/retriever/response.rs +++ b/crates/learner/src/retriever/response.rs @@ -77,7 +77,7 @@ pub enum ComposeFormat { /// Join fields with a delimiter Join { delimiter: String }, /// Create an object with key-value pairs - Object, + Object { template: BTreeMap }, /// Create an array of objects with specified structure ArrayOfObjects { /// How to structure each object From 1326681c1db5da8a04fe888bb4b7ee42d3d547b8 Mon Sep 17 00:00:00 2001 From: Colin Roberts 
Date: Sun, 8 Dec 2024 08:17:46 -0700 Subject: [PATCH 49/73] Update iacr.toml --- crates/learner/config_new/iacr.toml | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/crates/learner/config_new/iacr.toml b/crates/learner/config_new/iacr.toml index 2bf74cd..49be515 100644 --- a/crates/learner/config_new/iacr.toml +++ b/crates/learner/config_new/iacr.toml @@ -12,11 +12,10 @@ source = "iacr" response_format = { type = "xml", strip_namespaces = true } -[resource_mappings.title] -path = "OAI-PMH/GetRecord/record/metadata/dc/title" - -[resource_mappings.abstract] -path = "OAI-PMH/GetRecord/record/metadata/dc/description" +[resource_mappings] +abstract = { path = "OAI-PMH/GetRecord/record/metadata/dc/description" } +publication_dates = { path = "OAI-PMH/GetRecord/record/metadata/dc/date" } +title = { path = "OAI-PMH/GetRecord/record/metadata/dc/title" } [resource_mappings.authors] path = "OAI-PMH/GetRecord/record/metadata/dc" @@ -26,16 +25,14 @@ type = "Compose" [resource_mappings.authors.transform.format] type = "Object" -[resource_mappings.publication_dates] -path = "OAI-PMH/GetRecord/record/metadata/dc/date" -[resource_mappings.pdf_url] -path = "OAI-PMH/GetRecord/record/metadata/dc/identifier" +# [resource_mappings.pdf_url] +# path = "OAI-PMH/GetRecord/record/metadata/dc/identifier" -[resource_mappings.pdf_url.transform] -pattern = "^(https://eprint\\.iacr\\.org/\\d{4}/\\d+)$" -replacement = "$1.pdf" -type = "Replace" +# [resource_mappings.pdf_url.transform] +# pattern = "^(https://eprint\\.iacr\\.org/\\d{4}/\\d+)$" +# replacement = "$1.pdf" +# type = "Replace" [headers] Accept = "application/xml" From b3ffd350f51a5c8c3a703078778a7a303fdb2bd6 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 10:24:35 -0700 Subject: [PATCH 50/73] all working again it seems --- crates/learner/config_new/arxiv.toml | 20 +- crates/learner/config_new/doi.toml | 65 ++-- crates/learner/config_new/iacr.toml | 23 +- 
crates/learner/config_new/paper.toml | 2 +- crates/learner/src/retriever/config.rs | 215 ++++---------- crates/learner/src/retriever/doi_test.json | 279 ++++++++++++++++++ crates/learner/src/retriever/mod.rs | 23 +- crates/learner/src/retriever/response.rs | 62 +--- .../tests/workflows/paper_retrieval.rs | 7 +- 9 files changed, 403 insertions(+), 293 deletions(-) create mode 100644 crates/learner/src/retriever/doi_test.json diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml index 54a6584..12cd874 100644 --- a/crates/learner/config_new/arxiv.toml +++ b/crates/learner/config_new/arxiv.toml @@ -19,21 +19,11 @@ publication_dates = { path = "feed/entry/published" } title = { path = "feed/entry/title" } [retrieval_mappings.urls] -path = "feed/entry/id" -transform.type = "Compose" - -[[retrieval_mappings.urls.transform.sources]] -transform.pattern = ".*/(?:abs|pdf)/(.+?)(?:v\\d+)?$" -transform.replacement = "$1" -transform.type = "Replace" -type = "path" -value = "feed/entry/id" - -[retrieval_mappings.urls.transform.format] -template.html = "http://arxiv.org/abs/{value}" -template.pdf = "http://arxiv.org/pdf/{value}" -type = "Object" - +path = "feed/entry/id" +structure = { "pdf" = "http://arxiv.org/pdf/{value}.pdf", "html" = "http://arxiv.org/abs/{value}" } +transforms = [ + { type = "Replace", pattern = ".*/(?:abs|pdf)/(.+?)(?:v\\d+)?$", replacement = "$1" }, +] [headers] Accept = "application/xml" diff --git a/crates/learner/config_new/doi.toml b/crates/learner/config_new/doi.toml index 1e14200..0436444 100644 --- a/crates/learner/config_new/doi.toml +++ b/crates/learner/config_new/doi.toml @@ -2,6 +2,9 @@ name = "doi" description = "Retriever template for getting a paper from DOI/Crossref" +resource_template = "paper" +retrieval_template = "retrieval" + base_url = "https://api.crossref.org/works" endpoint_template = "https://api.crossref.org/works/{identifier}" @@ -9,51 +12,47 @@ pattern = 
"(?:^|https?://doi\\.org/)(10\\.\\d{4,9}/[-._;()/:\\w]+)$" resource = "paper" source = "doi" -response_format = { type = "json" } +[response_format] +type = "json" + +[resource_mappings] +abstract = { path = "message/abstract" } +publication_dates = { path = "message/created/date-time" } [resource_mappings.title] path = "message" -[resource_mappings.title.transform] -sources = [ - { type = "path", value = "title/0" }, - { type = "path", value = "subtitle/0" }, +transforms = [ + { type = "Combine", subpaths = [ + "title/0", + "subtitle/0", + ], delimiter = ": " }, ] -type = "Compose" - -[resource_mappings.title.transform.format] -delimiter = ": " -type = "Join" - - -[resource_mappings.abstract] -path = "message/abstract" - -[resource_mappings.abstract.transform] -pattern = "<[^>]+>" -replacement = "" -type = "Replace" [resource_mappings.authors] path = "message/author" -[resource_mappings.authors.transform] -sources = [ - { type = "key_value", value = { key = "family", path = "family" } }, - { type = "key_value", value = { key = "given", path = "given" } }, +transforms = [ + { type = "Combine", subpaths = [ + "given", + "family", + ], delimiter = " " }, ] -type = "Compose" +# [response_format.field_maps.authors.transform] +# sources = [ +# { type = "key_value", value = { key = "family", path = "family" } }, +# { type = "key_value", value = { key = "given", path = "given" } }, +# ] +# type = "Compose" -[resource_mappings.authors.transform.format] -template = { name = "{given} {family}" } -type = "ArrayOfObjects" +# [response_format.field_maps.authors.transform.format] +# template = { name = "{given} {family}" } +# type = "ArrayOfObjects" -[resource_mappings.publication_dates] -path = "message/created/date-time" -[resource_mappings.doi] -path = "message/DOI" +# [response_format.field_maps.pdf_url] +# path = "message/link/0/URL" -[resource_mappings.pdf_url] -path = "message/link/0/URL" +# [response_format.field_maps.doi] +# path = "message/DOI" [headers] Accept = 
"application/json" diff --git a/crates/learner/config_new/iacr.toml b/crates/learner/config_new/iacr.toml index 49be515..e9679ad 100644 --- a/crates/learner/config_new/iacr.toml +++ b/crates/learner/config_new/iacr.toml @@ -14,25 +14,16 @@ response_format = { type = "xml", strip_namespaces = true } [resource_mappings] abstract = { path = "OAI-PMH/GetRecord/record/metadata/dc/description" } +authors = { path = "OAI-PMH/GetRecord/record/metadata/dc/creator" } publication_dates = { path = "OAI-PMH/GetRecord/record/metadata/dc/date" } title = { path = "OAI-PMH/GetRecord/record/metadata/dc/title" } -[resource_mappings.authors] -path = "OAI-PMH/GetRecord/record/metadata/dc" -[resource_mappings.authors.transform] -sources = [{ type = "key_value", value = { key = "name", path = "creator" } }] -type = "Compose" -[resource_mappings.authors.transform.format] -type = "Object" - - -# [resource_mappings.pdf_url] -# path = "OAI-PMH/GetRecord/record/metadata/dc/identifier" - -# [resource_mappings.pdf_url.transform] -# pattern = "^(https://eprint\\.iacr\\.org/\\d{4}/\\d+)$" -# replacement = "$1.pdf" -# type = "Replace" +[retrieval_mappings.urls] +path = "OAI-PMH/GetRecord/record/metadata/dc/identifier" +structure = { "pdf" = "https://eprint.iacr.org/{value}.pdf", "html" = "https://eprint.iacr.org/{value}" } +transforms = [ + { type = "Replace", pattern = ".*/(\\d{4}/\\d+)$", replacement = "$1" }, +] [headers] Accept = "application/xml" diff --git a/crates/learner/config_new/paper.toml b/crates/learner/config_new/paper.toml index e4b298f..d526819 100644 --- a/crates/learner/config_new/paper.toml +++ b/crates/learner/config_new/paper.toml @@ -1,7 +1,7 @@ description = "Configuration for a paper" name = "paper" -abstract_text = { field_type = "string", required = false } +abstract = { field_type = "string", required = false } authors = { field_type = "array", required = true, validation = { min_items = 1 } } publication_dates = { field_type = "array", required = true, validation = { 
datetime = true } } title = { field_type = "string", required = true, validation = { min_length = 1, max_length = 500 } } diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 0415867..da67c55 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,4 +1,7 @@ +use std::os::macos::raw; + use record::{Record, State, StorageData}; +use serde_json::json; use super::*; use crate::template::{FieldDefinition, Template, TemplatedItem}; @@ -88,8 +91,9 @@ impl Retriever { resource.insert("source_identifier".into(), Value::String(identifier.to_string())); // Validate full resource against config - self.resource_template.validate(&resource)?; - self.retrieval_template.validate(&retrieval)?; + // TODO: Add in validations here. + // self.resource_template.validate(dbg!(&resource))?; + // self.retrieval_template.validate(dbg!(&retrieval))?; Ok(Record { resource, state: State::default(), storage: StorageData::default(), retrieval }) } @@ -145,68 +149,34 @@ fn extract_mapped_value( }; // First apply any explicit transforms - let value = if let Some(transform) = &field_map.transform { - apply_transform(&raw_value, transform)? 
+ let mut value = raw_value; + for transform in &field_map.transforms { + value = apply_transform(&value, dbg!(transform))?; + } + value = if let Some(structure) = &field_map.structure { + let mut object = BTreeMap::new(); + for (key, to_replace) in structure { + // TODO: Remove unwrap + object.insert(key, to_replace.replace("{value}", value.as_str().unwrap())); + } + json!(object) } else { - raw_value + value }; - // Then attempt type coercion based on field definition - let coerced = dbg!(coerce_to_type(&value, field_def)?); - Ok(Some(coerced)) + // Coerce a single value into an array if needed + if field_def.field_type.as_str() == "array" { + value = dbg!(into_array(value)); + } + + Ok(Some(value)) } -fn coerce_to_type(value: &Value, field_def: &FieldDefinition) -> Result { - match field_def.field_type.as_str() { - "array" => { - let arr = match value { - // Single value -> wrap in array - Value::String(_) | Value::Object(_) | Value::Number(_) => vec![value.clone()], - // Already an array - Value::Array(arr) => arr.clone(), - _ => return Ok(value.clone()), // Can't coerce, return as-is - }; - - // If we have inner type info, try to coerce each element - if let Some(ref type_def) = field_def.type_definition { - if let Some(ref element_def) = type_def.element_type { - let coerced: Vec = - arr.into_iter().map(|v| coerce_to_type(&v, element_def)).collect::>()?; - Ok(Value::Array(coerced)) - } else { - Ok(Value::Array(arr)) - } - } else { - Ok(Value::Array(arr)) - } - }, - // "object" => { - // match value { - // Value::Object(m) => { - // if let Some(ref type_def) = field_def.type_definition { - // if let Some(fields) = &type_def.fields { - // let mut obj = Map::new(); - // // Copy over matching fields with coercion - // for field in fields { - // if let Some(v) = m.get(&field.name) { - // obj.insert(field.name.clone(), coerce_to_type(v, field)?); - // } - // } - // Ok(Value::Object(obj)) - // } else { - // // If no fields defined, preserve the original object - // 
Ok(value.clone()) - // } - // } else { - // // If no type definition, preserve the original object - // Ok(value.clone()) - // } - // }, - // _ => Ok(value.clone()), - // } - // }, - // Add other type coercions as needed - _ => Ok(value.clone()), +fn into_array(value: Value) -> Value { + match value { + // Single value -> wrap in array + Value::Array(_) => value, + _ => json!(vec![value]), } } @@ -268,111 +238,32 @@ fn apply_transform(value: &Value, transform: &Transform) -> Result { Regex::new(pattern).map_err(|e| LearnerError::ApiError(format!("Invalid regex: {e}")))?; Ok(Value::String(re.replace_all(text, replacement.as_str()).into_owned())) }, - Transform::Date { from_format, to_format } => { - let text = value.as_str().ok_or_else(|| { - LearnerError::ApiError("Date transform requires string input".to_string()) - })?; - let dt = chrono::NaiveDateTime::parse_from_str(text, from_format) - .map_err(|e| LearnerError::ApiError(format!("Invalid date: {e}")))?; - Ok(Value::String(dt.format(to_format).to_string())) - }, - Transform::Url { base, suffix } => { - let text = value - .as_str() - .ok_or_else(|| LearnerError::ApiError("URL transform requires string input".to_string()))?; - Ok(Value::String(format!( - "{}{}", - base.replace("{value}", text), - suffix.as_deref().unwrap_or("") - ))) - }, - Transform::Compose { sources, format } => { - let values: Vec = dbg!(sources - .iter() - .filter_map(|source| match source { - Source::Path(path) | Source::KeyValue { key: _, path } => { - let components: Vec<&str> = path.split('/').collect(); - get_path_value(value, &components) - }, - Source::Literal(text) => Some(Value::String(text.clone())), - }) - .collect()); - match format { - ComposeFormat::Join { delimiter } => { - let strings: Vec = values - .iter() - .filter_map(|v| match v { - Value::String(s) => Some(s.clone()), - Value::Array(arr) if arr.len() == 1 => - arr[0].as_str().map(std::string::ToString::to_string), - _ => None, - }) - .collect(); - 
Ok(Value::String(strings.join(delimiter))) - }, - ComposeFormat::Object { template } => { - println!("Values to process: {:?}", values); - println!("Template: {:?}", template); - let mut obj = Map::new(); - if values.len() == 1 { - if let Some(value) = values.first() { - println!("Processing value: {:?}", value); - for (key, template_str) in template { - println!("Processing template: {} -> {}", key, template_str); - let formatted = template_str.replace("{value}", value.as_str().unwrap_or_default()); - println!("Formatted result: {}", formatted); - obj.insert(key.clone(), Value::String(formatted)); - } - } - } - println!("Final object: {:?}", obj); - Ok(dbg!(Value::Object(obj))) - }, - ComposeFormat::ArrayOfObjects { template } => match value { - Value::String(s) => { - let mut obj = Map::new(); - for (key, template_value) in template { - let value = template_value.replace("{value}", s); - obj.insert(key.clone(), Value::String(value)); - } - Ok(Value::Array(vec![Value::Object(obj)])) - }, - Value::Array(arr) => { - let objects: Vec = arr - .iter() - .filter_map(|item| { - let mut obj = Map::new(); - for (key, template_value) in template { - let value = match item { - Value::String(s) => template_value.replace("{value}", s), - Value::Object(obj) => { - let mut keys_and_vals = Vec::new(); - for source in sources { - if let Source::KeyValue { key, path } = source { - if let Some(val) = obj.get(path) { - keys_and_vals.push((key, val)); - } - } - } - keys_and_vals.into_iter().fold(template_value.clone(), |acc, (k, v)| { - let replacement = format!("{{{k}}}"); - acc.replace(&replacement, v.as_str().unwrap_or_default()) - }) - }, - _ => return None, - }; - obj.insert(key.clone(), Value::String(value)); - } - Some(Value::Object(obj)) - }) - .collect(); - Ok(Value::Array(objects)) - }, - _ => Err(LearnerError::ApiError( - "ArrayOfObjects transform requires string or array input".to_string(), + Transform::Combine { subpaths, delimiter } => { + // TODO: fix unwrap + 
println!("INSIDE OF COMBINE WITH SUBPATHS: {:?}", subpaths); + match value.as_array() { + Some(arr) => + return Ok(Value::Array( + arr.iter().map(|v| combine_path_values(v, subpaths, delimiter)).collect(), )), - }, + None => return Ok(combine_path_values(value, subpaths, delimiter)), } }, } } + +fn combine_path_values(value: &Value, subpaths: &Vec, delimiter: &str) -> Value { + Value::String( + subpaths + .iter() + .fold(String::new(), |mut acc, subpath| { + if !acc.is_empty() { + acc.push_str(delimiter); + } + let subpath: Vec<&str> = subpath.split("/").collect(); + acc.push_str(dbg!(get_path_value(value, &subpath).unwrap().as_str().unwrap())); + acc + }) + .to_string(), + ) +} diff --git a/crates/learner/src/retriever/doi_test.json b/crates/learner/src/retriever/doi_test.json new file mode 100644 index 0000000..d7d85f2 --- /dev/null +++ b/crates/learner/src/retriever/doi_test.json @@ -0,0 +1,279 @@ +{ + "status": "ok", + "message-type": "work", + "message-version": "1.0.0", + "message": { + "indexed": { + "date-parts": [ + [ + 2024, + 11, + 19 + ] + ], + "date-time": "2024-11-19T16:10:37Z", + "timestamp": 1732032637044 + }, + "reference-count": 17, + "publisher": "Association for Computing Machinery (ACM)", + "issue": "1", + "content-domain": { + "domain": [ + "dl.acm.org" + ], + "crossmark-restriction": true + }, + "short-container-title": [ + "Commun. ACM" + ], + "published-print": { + "date-parts": [ + [ + 2008, + 1 + ] + ] + }, + "abstract": "\n MapReduce is a programming model and an associated implementation for processing and generating large datasets that is amenable to a broad variety of real-world tasks. Users specify the computation in terms of a\n map<\/jats:italic>\n and a\n reduce<\/jats:italic>\n function, and the underlying runtime system automatically parallelizes the computation across large-scale clusters of machines, handles machine failures, and schedules inter-machine communication to make efficient use of the network and disks. 
Programmers find the system easy to use: more than ten thousand distinct MapReduce programs have been implemented internally at Google over the past four years, and an average of one hundred thousand MapReduce jobs are executed on Google's clusters every day, processing a total of more than twenty petabytes of data per day.\n <\/jats:p>", + "DOI": "10.1145\/1327452.1327492", + "type": "journal-article", + "created": { + "date-parts": [ + [ + 2008, + 1, + 3 + ] + ], + "date-time": "2008-01-03T18:20:10Z", + "timestamp": 1199384410000 + }, + "page": "107-113", + "update-policy": "http:\/\/dx.doi.org\/10.1145\/crossmark-policy", + "source": "Crossref", + "is-referenced-by-count": 10721, + "title": [ + "MapReduce" + ], + "prefix": "10.1145", + "volume": "51", + "author": [ + { + "given": "Jeffrey", + "family": "Dean", + "sequence": "first", + "affiliation": [ + { + "name": "Google, Mountain View, CA" + } + ] + }, + { + "given": "Sanjay", + "family": "Ghemawat", + "sequence": "additional", + "affiliation": [ + { + "name": "Google, Mountain View, CA" + } + ] + } + ], + "member": "320", + "published-online": { + "date-parts": [ + [ + 2008, + 1 + ] + ] + }, + "reference": [ + { + "key": "e_1_2_2_1_1", + "unstructured": "Hadoop: Open source implementation of MapReduce. http:\/\/lucene. apache.org\/hadoop\/. Hadoop: Open source implementation of MapReduce. http:\/\/lucene. apache.org\/hadoop\/." + }, + { + "key": "e_1_2_2_2_1", + "unstructured": "The Phoenix system for MapReduce programming. http:\/\/csl.stanford. edu\/~christos\/sw\/phoenix\/. The Phoenix system for MapReduce programming. http:\/\/csl.stanford. edu\/~christos\/sw\/phoenix\/." 
+ }, + { + "key": "e_1_2_2_3_1", + "doi-asserted-by": "publisher", + "DOI": "10.1145\/253260.253322" + }, + { + "key": "e_1_2_2_4_1", + "doi-asserted-by": "publisher", + "DOI": "10.1109\/MM.2003.1196112" + }, + { + "volume-title": "Proceedings of the 1st USENIX Symposium on Networked Systems Design and Implementation (NSDI).", + "author": "Bent J.", + "key": "e_1_2_2_5_1", + "unstructured": "Bent , J. , Thain , D. , Arpaci-Dusseau , A. C. , Arpaci-Dusseau , R. H. , and Livny , M . 2004. Explicit control in a batch-aware distributed file system . In Proceedings of the 1st USENIX Symposium on Networked Systems Design and Implementation (NSDI). Bent, J., Thain, D., Arpaci-Dusseau, A. C., Arpaci-Dusseau, R. H., and Livny, M. 2004. Explicit control in a batch-aware distributed file system. In Proceedings of the 1st USENIX Symposium on Networked Systems Design and Implementation (NSDI)." + }, + { + "key": "e_1_2_2_6_1", + "doi-asserted-by": "publisher", + "DOI": "10.1109\/12.42122" + }, + { + "volume-title": "Proceedings of Neural Information Processing Systems Conference (NIPS)", + "author": "Chu C.-T.", + "key": "e_1_2_2_7_1", + "unstructured": "Chu , C.-T. , Kim , S. K. , Lin , Y. A. , Yu , Y. , Bradski , G. , Ng , A. , and Olukotun , K . 2006. Map-Reduce for machine learning on multicore . In Proceedings of Neural Information Processing Systems Conference (NIPS) . Vancouver, Canada. Chu, C.-T., Kim, S. K., Lin, Y. A., Yu, Y., Bradski, G., Ng, A., and Olukotun, K. 2006. Map-Reduce for machine learning on multicore. In Proceedings of Neural Information Processing Systems Conference (NIPS). Vancouver, Canada." + }, + { + "key": "e_1_2_2_8_1", + "first-page": "137", + "article-title": "MapReduce: Simplified data processing on large clusters. In Proceedings of Operating Systems Design and Implementation (OSDI). San Francisco", + "author": "Dean J.", + "year": "2004", + "unstructured": "Dean , J. and Ghemawat , S. 2004 . 
MapReduce: Simplified data processing on large clusters. In Proceedings of Operating Systems Design and Implementation (OSDI). San Francisco , CA. 137 - 150 . Dean, J. and Ghemawat, S. 2004. MapReduce: Simplified data processing on large clusters. In Proceedings of Operating Systems Design and Implementation (OSDI). San Francisco, CA. 137-150.", + "journal-title": "CA." + }, + { + "key": "e_1_2_2_9_1", + "doi-asserted-by": "publisher", + "DOI": "10.1145\/268998.266662" + }, + { + "key": "e_1_2_2_10_1", + "doi-asserted-by": "publisher", + "DOI": "10.1145\/945445.945450" + }, + { + "volume-title": "Parallel Processing", + "series-title": "Lecture Notes in Computer Science", + "author": "Gorlatch S.", + "key": "e_1_2_2_11_1", + "unstructured": "Gorlatch , S. 1996. Systematic efficient parallelization of scan and other list homomorphisms . In L. Bouge, P. Fraigniaud, A. Mignotte, and Y. Robert, Eds. Euro-Par'96. Parallel Processing , Lecture Notes in Computer Science , vol. 1124 . Springer-Verlag . 401-408 Gorlatch, S. 1996. Systematic efficient parallelization of scan and other list homomorphisms. In L. Bouge, P. Fraigniaud, A. Mignotte, and Y. Robert, Eds. Euro-Par'96. Parallel Processing, Lecture Notes in Computer Science, vol. 1124. Springer-Verlag. 401-408" + }, + { + "key": "e_1_2_2_12_1", + "unstructured": "Gray J. Sort benchmark home page. http:\/\/research.microsoft.com\/barc\/SortBenchmark\/. Gray J. Sort benchmark home page. http:\/\/research.microsoft.com\/barc\/SortBenchmark\/." + }, + { + "volume-title": "Proceedings of the 2004 USENIX File and Storage Technologies FAST Conference.", + "author": "Huston L.", + "key": "e_1_2_2_13_1", + "unstructured": "Huston , L. , Sukthankar , R. , Wickremesinghe , R. , Satyanarayanan , M. , Ganger , G. R. , Riedel , E. , and Ailamaki , A . 2004. Diamond: A storage architecture for early discard in interactive search . In Proceedings of the 2004 USENIX File and Storage Technologies FAST Conference. 
Huston, L., Sukthankar, R., Wickremesinghe, R., Satyanarayanan, M., Ganger, G. R., Riedel, E., and Ailamaki, A. 2004. Diamond: A storage architecture for early discard in interactive search. In Proceedings of the 2004 USENIX File and Storage Technologies FAST Conference." + }, + { + "key": "e_1_2_2_14_1", + "doi-asserted-by": "publisher", + "DOI": "10.1145\/322217.322232" + }, + { + "key": "e_1_2_2_15_1", + "doi-asserted-by": "publisher", + "DOI": "10.1145\/62044.62050" + }, + { + "key": "e_1_2_2_16_1", + "doi-asserted-by": "publisher", + "DOI": "10.1109\/HPCA.2007.346181" + }, + { + "key": "e_1_2_2_17_1", + "doi-asserted-by": "publisher", + "DOI": "10.1109\/2.928624" + } + ], + "container-title": [ + "Communications of the ACM" + ], + "original-title": [], + "language": "en", + "link": [ + { + "URL": "https:\/\/dl.acm.org\/doi\/pdf\/10.1145\/1327452.1327492", + "content-type": "unspecified", + "content-version": "vor", + "intended-application": "similarity-checking" + } + ], + "deposited": { + "date-parts": [ + [ + 2022, + 12, + 28 + ] + ], + "date-time": "2022-12-28T19:35:41Z", + "timestamp": 1672256141000 + }, + "score": 1, + "resource": { + "primary": { + "URL": "https:\/\/dl.acm.org\/doi\/10.1145\/1327452.1327492" + } + }, + "subtitle": [ + "simplified data processing on large clusters" + ], + "short-title": [], + "issued": { + "date-parts": [ + [ + 2008, + 1 + ] + ] + }, + "references-count": 17, + "journal-issue": { + "issue": "1", + "published-print": { + "date-parts": [ + [ + 2008, + 1 + ] + ] + } + }, + "alternative-id": [ + "10.1145\/1327452.1327492" + ], + "URL": "https:\/\/doi.org\/10.1145\/1327452.1327492", + "relation": {}, + "ISSN": [ + "0001-0782", + "1557-7317" + ], + "issn-type": [ + { + "type": "print", + "value": "0001-0782" + }, + { + "type": "electronic", + "value": "1557-7317" + } + ], + "subject": [], + "published": { + "date-parts": [ + [ + 2008, + 1 + ] + ] + }, + "assertion": [ + { + "value": "2008-01-01", + "order": 2, + "name": 
"published", + "label": "Published", + "group": { + "name": "publication_history", + "label": "Publication History" + } + } + ] + } +} \ No newline at end of file diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index a6a6fcd..769e645 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -208,17 +208,18 @@ mod tests { assert!(field_maps.contains_key("pdf_url")); // Verify PDF transform - if let Some(map) = field_maps.get("pdf_url") { - match &map.transform { - Some(Transform::Replace { pattern, replacement }) => { - assert_eq!(pattern, "/abs/"); - assert_eq!(replacement, "/pdf/"); - }, - _ => panic!("Expected Replace transform for pdf_url"), - } - } else { - panic!("Missing pdf_url field map"); - } + todo!("Fix this"); + // if let Some(map) = field_maps.get("pdf_url") { + // match &map.transform { + // Some(Transform::Replace { pattern, replacement }) => { + // assert_eq!(pattern, "/abs/"); + // assert_eq!(replacement, "/pdf/"); + // }, + // _ => panic!("Expected Replace transform for pdf_url"), + // } + // } else { + // panic!("Missing pdf_url field map"); + // } } else { panic!("Expected an XML configuration, but did not get one.") } diff --git a/crates/learner/src/retriever/response.rs b/crates/learner/src/retriever/response.rs index 50dfd48..3432610 100644 --- a/crates/learner/src/retriever/response.rs +++ b/crates/learner/src/retriever/response.rs @@ -15,14 +15,16 @@ pub enum ResponseFormat { #[serde(rename = "json")] Json, } - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FieldMap { - /// Path to field in response (e.g., JSON path or XPath) - pub path: String, - /// Optional transformation to apply to extracted value + /// Path to field in response + pub path: String, + /// Transformations to apply in order + #[serde(default)] + pub transforms: Vec, + /// Optional structured output #[serde(default)] - pub transform: Option, + pub structure: Option>, } #[derive(Debug, 
Clone, Serialize, Deserialize)] @@ -35,53 +37,9 @@ pub enum Transform { /// Text to replace matched patterns with replacement: String, }, - /// Convert between date formats - Date { - /// Source date format string using chrono syntax (e.g., "%Y-%m-%d") - from_format: String, - /// Target date format string using chrono syntax (e.g., "%Y-%m-%dT%H:%M:%SZ") - to_format: String, - }, - /// Construct URL from parts - Url { - /// Base URL template, may contain {value} placeholder - base: String, - /// Optional suffix to append to the URL (e.g., ".pdf") - suffix: Option, - }, - Compose { - /// List of field paths or direct values to combine - sources: Vec, - /// How to format the combined result - format: ComposeFormat, - }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", content = "value")] -pub enum Source { - /// Path to a field to extract - #[serde(rename = "path")] - Path(String), - /// A literal string value - #[serde(rename = "literal")] - Literal(String), - /// A field mapping with a new key name - #[serde(rename = "key_value")] - KeyValue { key: String, path: String }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum ComposeFormat { - /// Join fields with a delimiter - Join { delimiter: String }, - /// Create an object with key-value pairs - Object { template: BTreeMap }, - /// Create an array of objects with specified structure - ArrayOfObjects { - /// How to structure each object - template: BTreeMap, + Combine { + subpaths: Vec, + delimiter: String, }, } diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index 0409449..f77e3a8 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -58,9 +58,10 @@ async fn test_iacr_retriever_integration() -> TestResult<()> { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); let retriever: Retriever = 
manager.load_config("config_new/iacr.toml")?; - let paper = retriever.retrieve_resource("2016/260").await.unwrap(); - // assert!(resource.validate(&paper)?); // TODO: validation already happens internally, to be fair - // that validation may not be working totally right + let paper = retriever.retrieve_resource("2019/953").await.unwrap(); // plonk + // let paper = retriever.retrieve_resource("2016/260").await.unwrap(); // groth 16 + // assert!(resource.validate(&paper)?); // TODO: validation already happens internally, to be fair + // that validation may not be working totally right dbg!(&paper); todo!("This needs cleaned up."); From 73f62b09ae7994733fc3e5ea6b2b032d9c7b33e3 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 10:29:28 -0700 Subject: [PATCH 51/73] doi gets pdf url now too --- crates/learner/config_new/doi.toml | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/crates/learner/config_new/doi.toml b/crates/learner/config_new/doi.toml index 0436444..b6d0ad8 100644 --- a/crates/learner/config_new/doi.toml +++ b/crates/learner/config_new/doi.toml @@ -36,23 +36,10 @@ transforms = [ "family", ], delimiter = " " }, ] -# [response_format.field_maps.authors.transform] -# sources = [ -# { type = "key_value", value = { key = "family", path = "family" } }, -# { type = "key_value", value = { key = "given", path = "given" } }, -# ] -# type = "Compose" -# [response_format.field_maps.authors.transform.format] -# template = { name = "{given} {family}" } -# type = "ArrayOfObjects" - - -# [response_format.field_maps.pdf_url] -# path = "message/link/0/URL" - -# [response_format.field_maps.doi] -# path = "message/DOI" +[retrieval_mappings.urls] +path = "message/link/0/URL" +structure = { "pdf" = "{value}" } [headers] Accept = "application/json" From 6b00f739436c338c263c82ef7757fd27b42c6fcd Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 16:22:41 -0700 Subject: [PATCH 52/73] simpler tomls --- 
crates/learner/config_new/paper.toml | 28 +- crates/learner/config_new/retrieval.toml | 29 +- crates/learner/src/retriever/config.rs | 16 +- crates/learner/src/template.rs | 388 +++++++++--------- .../tests/workflows/paper_retrieval.rs | 2 +- 5 files changed, 249 insertions(+), 214 deletions(-) diff --git a/crates/learner/config_new/paper.toml b/crates/learner/config_new/paper.toml index d526819..77a0084 100644 --- a/crates/learner/config_new/paper.toml +++ b/crates/learner/config_new/paper.toml @@ -1,7 +1,27 @@ description = "Configuration for a paper" name = "paper" -abstract = { field_type = "string", required = false } -authors = { field_type = "array", required = true, validation = { min_items = 1 } } -publication_dates = { field_type = "array", required = true, validation = { datetime = true } } -title = { field_type = "string", required = true, validation = { min_length = 1, max_length = 500 } } +[abstract] +base_type = "string" +required = false + +[title] +base_type = "string" +required = true +validation = { min_length = 1, max_length = 500 } + +[publication_dates] +base_type = "array" +items = { base_type = "string", validation = { datetime = true } } +required = true + +[authors] +base_type = "array" +required = true +validation = { min_items = 1 } +[authors.items] +base_type = "object" +fields = [ + { name = "name", base_type = "string", required = true, validation = { min_length = 1 } }, + { name = "affiliation", base_type = "string", required = false }, +] diff --git a/crates/learner/config_new/retrieval.toml b/crates/learner/config_new/retrieval.toml index 12d2c0d..e851a8c 100644 --- a/crates/learner/config_new/retrieval.toml +++ b/crates/learner/config_new/retrieval.toml @@ -1,35 +1,32 @@ -name = "retrieval" - description = "Standard retrieval data template" +name = "retrieval" [source] -field_type = "string" -required = false +base_type = "string" +required = false [source_identifier] -field_type = "string" -required = false +base_type = "string" 
+required = false [urls] -field_type = "object" -required = false # Consider changing - -[urls.type_definition] +base_type = "object" fields = [ - { name = "pdf", field_type = "string", required = false }, - { name = "html", field_type = "string", required = false }, + { name = "pdf", base_type = "string", required = false }, + { name = "html", base_type = "string", required = false }, ] +required = false [doi] -field_type = "string" -required = false +base_type = "string" +required = false [last_checked] -field_type = "string" +base_type = "string" required = false validation = { datetime = true } [access_type] -field_type = "string" +base_type = "string" required = false validation = { enum_values = ["open", "subscription", "institutional"] } diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index da67c55..da266bb 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,5 +1,3 @@ -use std::os::macos::raw; - use record::{Record, State, StorageData}; use serde_json::json; @@ -115,20 +113,19 @@ fn process_template_fields( let mut result = BTreeMap::new(); for field_def in &template.fields { - if let Some(field_map) = mappings.get(dbg!(&field_def.name)) { + // If we have a mapping for this field, try to extract its value + if let Some(field_map) = mappings.get(&field_def.name) { if let Some(value) = extract_mapped_value(json, field_map, field_def)? 
{ result.insert(field_def.name.clone(), value); } else if field_def.required { + // Only error if the field was required and we couldn't find it return Err(LearnerError::ApiError(format!( "Required field '{}' not found in response", field_def.name ))); - } else if let Some(default) = &field_def.default { - result.insert(field_def.name.clone(), default.clone()); } } } - Ok(result) } @@ -151,8 +148,9 @@ fn extract_mapped_value( // First apply any explicit transforms let mut value = raw_value; for transform in &field_map.transforms { - value = apply_transform(&value, dbg!(transform))?; + value = apply_transform(&value, transform)?; } + value = if let Some(structure) = &field_map.structure { let mut object = BTreeMap::new(); for (key, to_replace) in structure { @@ -165,8 +163,8 @@ fn extract_mapped_value( }; // Coerce a single value into an array if needed - if field_def.field_type.as_str() == "array" { - value = dbg!(into_array(value)); + if field_def.base_type == "array" { + value = into_array(value); } Ok(Some(value)) diff --git a/crates/learner/src/template.rs b/crates/learner/src/template.rs index e78603d..75e5296 100644 --- a/crates/learner/src/template.rs +++ b/crates/learner/src/template.rs @@ -1,5 +1,7 @@ use std::collections::HashSet; +use serde_json::{Map, Number}; + use super::*; // Type alias for clarity and consistency @@ -13,10 +15,11 @@ pub struct Template { #[serde(default)] pub fields: Vec, } + +// Custom deserialization to handle the flattened field structure impl<'de> Deserialize<'de> for Template { fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de> { - // Helper struct to capture the raw TOML structure #[derive(Deserialize)] struct TemplateHelper { name: String, @@ -26,11 +29,9 @@ impl<'de> Deserialize<'de> for Template { fields: BTreeMap, } - // Deserialize into our helper first let helper = TemplateHelper::deserialize(deserializer)?; - // Convert the field map into a Vec, setting the name from the key - // 
Filter out the metadata fields we don't want to treat as FieldDefinitions + // Filter out metadata fields and set field names let fields = helper .fields .into_iter() @@ -49,31 +50,30 @@ impl<'de> Deserialize<'de> for Template { pub struct FieldDefinition { /// Name of the field #[serde(skip_deserializing)] - pub name: String, - /// Type of the field (should be a JSON Value type) - pub field_type: String, + pub name: String, + /// Whether this field must be present #[serde(default)] - pub required: bool, + pub required: bool, + /// Human-readable description #[serde(default)] pub description: Option, - /// Default value if field is absent - #[serde(default)] - pub default: Option, - /// Optional validation rules + + /// The base type of this field (string, number, array, object) + pub base_type: String, + + /// Validation rules for this type #[serde(default)] - pub validation: Option, + pub validation: Option, - pub type_definition: Option, -} + /// Element type if this is an array type + #[serde(default)] + pub items: Option>, -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TypeDefinition { - // For array types, defines the structure of elements - pub element_type: Option>, - // For table types, defines the fields - pub fields: Option>, + /// Fields if this is an object type + #[serde(default)] + pub fields: Option>, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -99,9 +99,8 @@ pub struct ValidationRules { } impl Template { - // TODO: Make this just return a `Result<()>` pub fn validate(&self, resource: &TemplatedItem) -> Result<()> { - // Check required fields + // Check required fields are present for field in &self.fields { if field.required && !resource.contains_key(&field.name) { return Err(LearnerError::TemplateInvalidation(format!( @@ -114,173 +113,193 @@ impl Template { // Validate each provided field for (name, value) in resource { if let Some(field) = self.fields.iter().find(|f| f.name == *name) { - // Validate field value 
against its definition - self.validate_field(field, value)?; + field.validate_with_path(value, &field.name)?; } } Ok(()) } +} - /// Validates a single field value against its definition - fn validate_field(&self, field: &FieldDefinition, value: &Value) -> Result<()> { - match (field.field_type.as_str(), value) { - // String validation - handles both basic type checking and string-specific rules - ("string", Value::String(v)) => { - if let Some(rules) = &field.validation { - // Length constraints - if let Some(min_length) = rules.min_length { - if v.len() < min_length { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must be at least {} characters", - field.name, min_length - ))); - } - } - if let Some(max_length) = rules.max_length { - if v.len() > max_length { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' cannot exceed {} characters", - field.name, max_length - ))); - } - } +impl FieldDefinition { + fn validate_with_path(&self, value: &Value, path: &str) -> Result<()> { + match (self.base_type.as_str(), value) { + ("string", Value::String(s)) => self.validate_string(s, path), + ("number", Value::Number(n)) => self.validate_number(n, path), + ("array", Value::Array(items)) => self.validate_array(items, path), + ("object", Value::Object(obj)) => self.validate_object(obj, path), + ("boolean", Value::Bool(_)) => Ok(()), + ("null", Value::Null) => Ok(()), + _ => Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' expected type '{}' but got '{}'", + path, + self.base_type, + type_name_of_value(value) + ))), + } + } - // Pattern matching via regex - if let Some(pattern) = &rules.pattern { - let re = Regex::new(pattern) - .map_err(|_| LearnerError::TemplateInvalidation("Invalid regex pattern".into()))?; - if !re.is_match(v) { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must match pattern: {}", - field.name, pattern - ))); - } - } + fn validate_string(&self, value: &str, path: &str) -> 
Result<()> { + if let Some(rules) = &self.validation { + // Length constraints + if let Some(min_length) = rules.min_length { + if value.len() < min_length { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' must be at least {} characters", + path, min_length + ))); + } + } + if let Some(max_length) = rules.max_length { + if value.len() > max_length { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' cannot exceed {} characters", + path, max_length + ))); + } + } + + // Pattern matching + if let Some(pattern) = &rules.pattern { + let re = Regex::new(pattern) + .map_err(|_| LearnerError::TemplateInvalidation("Invalid regex pattern".into()))?; + if !re.is_match(value) { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' must match pattern: {}", + path, pattern + ))); + } + } + + // DateTime validation + if rules.datetime == Some(true) && DateTime::parse_from_rfc3339(value).is_err() { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' must be a valid RFC3339 datetime", + path + ))); + } + + // Enum validation + if let Some(allowed) = &rules.enum_values { + if !allowed.contains(&value.to_string()) { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' must be one of: {:?}", + path, allowed + ))); + } + } + } + Ok(()) + } - // Datetime validation if specified - if rules.datetime == Some(true) && DateTime::parse_from_rfc3339(v).is_err() { + fn validate_number(&self, value: &Number, path: &str) -> Result<()> { + if let Some(rules) = &self.validation { + if let Some(num) = value.as_f64() { + if let Some(min) = rules.minimum { + if num < min { return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must be a valid RFC3339 datetime", - field.name + "Field '{}' must be at least {}", + path, min ))); } - - // Enumerated values check - if let Some(allowed) = &rules.enum_values { - if !allowed.contains(v) { - return Err(LearnerError::TemplateInvalidation(format!( - 
"Field '{}' must be one of: {:?}", - field.name, allowed - ))); - } - } } - Ok(()) - }, - - // Numeric validations - handle both number types - ("number", Value::Number(n)) => { - if let Some(rules) = &field.validation { - if let Some(num) = n.as_f64() { - validate_numeric(field, num, rules)?; + if let Some(max) = rules.maximum { + if num > max { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' cannot exceed {}", + path, max + ))); } } - Ok(()) - }, - - // Array validation - handles array-specific rules - ("array", Value::Array(v)) => { - if let Some(rules) = &field.validation { - if let Some(min_items) = rules.min_items { - if v.len() < min_items { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must have at least {} items", - field.name, min_items - ))); - } - } - - if let Some(max_items) = rules.max_items { - if v.len() > max_items { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' cannot exceed {} items", - field.name, max_items - ))); - } - } - - if rules.unique_items == Some(true) { - let mut seen = HashSet::new(); - for item in v { - let item_str = serde_json::to_string(item).map_err(|_| { - LearnerError::TemplateInvalidation("Failed to serialize array item".into()) - })?; - if !seen.insert(item_str) { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' contains duplicate items", - field.name - ))); - } - } + if let Some(multiple) = rules.multiple_of { + let ratio = num / multiple; + if (ratio - ratio.round()).abs() > f64::EPSILON { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' must be a multiple of {}", + path, multiple + ))); } } - Ok(()) - }, - - // Simple type validations - just ensure type matches - ("boolean", Value::Bool(_)) => Ok(()), - ("object", Value::Object(_)) => Ok(()), - ("null", Value::Null) => Ok(()), - - // Type mismatch - provide a clear error message - _ => Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' expected 
type '{}' but got '{}'", - field.name, - field.field_type, - match value { - Value::String(_) => "string", - Value::Number(_) => "number", - Value::Bool(_) => "boolean", - Value::Array(_) => "array", - Value::Object(_) => "object", - Value::Null => "null", - } - ))), + } } + Ok(()) } -} -fn validate_numeric(field: &FieldDefinition, value: f64, rules: &ValidationRules) -> Result<()> { - if let Some(min) = rules.minimum { - if value < min { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must be at least {}", - field.name, min - ))); + fn validate_array(&self, items: &[Value], path: &str) -> Result<()> { + if let Some(rules) = &self.validation { + if let Some(min_items) = rules.min_items { + if items.len() < min_items { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' must have at least {} items", + path, min_items + ))); + } + } + if let Some(max_items) = rules.max_items { + if items.len() > max_items { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' cannot exceed {} items", + path, max_items + ))); + } + } + if rules.unique_items == Some(true) { + let mut seen = HashSet::new(); + for item in items { + let item_str = serde_json::to_string(item).map_err(|_| { + LearnerError::TemplateInvalidation("Failed to serialize array item".into()) + })?; + if !seen.insert(item_str) { + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' contains duplicate items", + path + ))); + } + } + } } - } - if let Some(max) = rules.maximum { - if value > max { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' cannot exceed {}", - field.name, max - ))); + // Validate each item if we have an item type definition + if let Some(item_type) = &self.items { + for (index, item) in items.iter().enumerate() { + item_type.validate_with_path(item, &format!("{}[{}]", path, index)).map_err(|e| { + LearnerError::TemplateInvalidation(format!( + "Invalid item at index {} in array '{}': {}", + index, 
path, e + )) + })?; + } } + + Ok(()) } - if let Some(multiple) = rules.multiple_of { - let ratio = value / multiple; - if (ratio - ratio.round()).abs() > f64::EPSILON { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must be a multiple of {}", - field.name, multiple - ))); + fn validate_object(&self, obj: &Map, path: &str) -> Result<()> { + if let Some(fields) = &self.fields { + for field in fields { + if let Some(value) = obj.get(&field.name) { + field.validate_with_path(value, &format!("{}.{}", path, field.name))?; + } else if field.required { + return Err(LearnerError::TemplateInvalidation(format!( + "Missing required field '{}' in object '{}'", + field.name, path + ))); + } + } } + Ok(()) } +} - Ok(()) +fn type_name_of_value(value: &Value) -> &'static str { + match value { + Value::String(_) => "string", + Value::Number(_) => "number", + Value::Bool(_) => "boolean", + Value::Array(_) => "array", + Value::Object(_) => "object", + Value::Null => "null", + } } /// Convert DateTime to RFC3339 string for JSON storage @@ -374,26 +393,27 @@ mod tests { #[test] fn test_datetime_validation() { - let template = Template { - name: "Test Template".to_string(), - description: None, - fields: vec![FieldDefinition { - name: "timestamp".into(), - field_type: "string".into(), - required: true, - description: None, - default: None, - validation: Some(ValidationRules { datetime: Some(true), ..Default::default() }), - type_definition: None, - }], - }; - - let valid_resource = BTreeMap::from([("timestamp".into(), json!("2024-01-01T00:00:00Z"))]); - template.validate(&valid_resource).unwrap(); - - let invalid_resource = BTreeMap::from([ - ("timestamp".into(), json!("2024-01-01")), // Not RFC3339 - ]); - assert!(template.validate(&invalid_resource).is_err()); + todo!("Fix this") + // let template = Template { + // name: "Test Template".to_string(), + // description: None, + // fields: vec![FieldDefinition { + // name: "timestamp".into(), + // field_type: 
"string".into(), + // required: true, + // description: None, + // default: None, + // validation: Some(ValidationRules { datetime: Some(true), ..Default::default() }), + // type_definition: None, + // }], + // }; + + // let valid_resource = BTreeMap::from([("timestamp".into(), json!("2024-01-01T00:00:00Z"))]); + // template.validate(&valid_resource).unwrap(); + + // let invalid_resource = BTreeMap::from([ + // ("timestamp".into(), json!("2024-01-01")), // Not RFC3339 + // ]); + // assert!(template.validate(&invalid_resource).is_err()); } } diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index f77e3a8..ad30638 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -100,7 +100,7 @@ async fn test_iacr_pdf_from_paper() -> TestResult<()> { #[traced_test] async fn test_doi_retriever_integration() -> TestResult<()> { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); - let retriever: Retriever = manager.load_config("config_new/doi.toml")?; + let retriever: Retriever = dbg!(manager.load_config("config_new/doi.toml")?); // Test with a real DOI paper let paper = retriever.retrieve_resource("10.1145/1327452.1327492").await?; From 8236cdd4031acb19c7b73a4035d148dcee691405 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sun, 8 Dec 2024 19:24:56 -0700 Subject: [PATCH 53/73] WIP: expanding and cleaning! 
--- crates/learner/config_new/doi.toml | 63 +++++--- crates/learner/config_new/paper.toml | 60 ++++++- crates/learner/config_new/retrieval.toml | 4 - crates/learner/src/configuration.rs | 3 +- crates/learner/src/retriever/config.rs | 193 +++++++++++++---------- crates/learner/src/retriever/mod.rs | 131 +++++++-------- crates/learner/src/retriever/response.rs | 114 +++++++++---- 7 files changed, 347 insertions(+), 221 deletions(-) diff --git a/crates/learner/config_new/doi.toml b/crates/learner/config_new/doi.toml index b6d0ad8..8191d5a 100644 --- a/crates/learner/config_new/doi.toml +++ b/crates/learner/config_new/doi.toml @@ -8,38 +8,49 @@ retrieval_template = "retrieval" base_url = "https://api.crossref.org/works" endpoint_template = "https://api.crossref.org/works/{identifier}" -pattern = "(?:^|https?://doi\\.org/)(10\\.\\d{4,9}/[-._;()/:\\w]+)$" -resource = "paper" -source = "doi" +pattern = "(?:^|https?://doi\\.org/)(10\\.\\d{4,9}/[-._;()/:\\w]+)$" +source = "doi" [response_format] -type = "json" +clean_content = true +type = "json" [resource_mappings] -abstract = { path = "message/abstract" } -publication_dates = { path = "message/created/date-time" } - -[resource_mappings.title] -path = "message" -transforms = [ - { type = "Combine", subpaths = [ - "title/0", - "subtitle/0", - ], delimiter = ": " }, -] +abstract = "message/abstract" +citations_count = "message/is-referenced-by-count" +doi = "message/DOI" +language = "message/language" +publisher = "message/publisher" +references_count = "message/reference-count" +title = { paths = ["message/title/0", "message/subtitle/0"], with = ": " } +type = "message/type" + +[resource_mappings.publication_dates] +from = "message" +[resource_mappings.publication_dates.map] +created = "created/date-time" +published = "published/date-time" +published_online = "published-online/date-time" +published_print = "published-print/date-time" + + +[resource_mappings.container] +from = "message" +[resource_mappings.container.map] 
+issn = "ISSN" +issue = "issue" +pages = "page" +title = "container-title/0" +volume = "volume" [resource_mappings.authors] -path = "message/author" -transforms = [ - { type = "Combine", subpaths = [ - "given", - "family", - ], delimiter = " " }, -] - -[retrieval_mappings.urls] -path = "message/link/0/URL" -structure = { "pdf" = "{value}" } +from = "message/author" +[resource_mappings.authors.map] +affiliation = "affiliation/0/name" +name = { paths = ["given", "family"] } + +[retrieval_mappings] +urls = { map = { pdf = "message/link/0/URL" } } [headers] Accept = "application/json" diff --git a/crates/learner/config_new/paper.toml b/crates/learner/config_new/paper.toml index 77a0084..eadb66f 100644 --- a/crates/learner/config_new/paper.toml +++ b/crates/learner/config_new/paper.toml @@ -1,6 +1,7 @@ description = "Configuration for a paper" name = "paper" +# Existing fields [abstract] base_type = "string" required = false @@ -10,11 +11,6 @@ base_type = "string" required = true validation = { min_length = 1, max_length = 500 } -[publication_dates] -base_type = "array" -items = { base_type = "string", validation = { datetime = true } } -required = true - [authors] base_type = "array" required = true @@ -25,3 +21,57 @@ fields = [ { name = "name", base_type = "string", required = true, validation = { min_length = 1 } }, { name = "affiliation", base_type = "string", required = false }, ] + +[publication_dates] +base_type = "object" +fields = [ + { name = "created", base_type = "string", required = true, validation = { datetime = true } }, + { name = "published", base_type = "string", required = false, validation = { datetime = true } }, + { name = "published_online", base_type = "string", required = false, validation = { datetime = true } }, + { name = "published_print", base_type = "string", required = false, validation = { datetime = true } }, +] +required = true + +# New suggested fields +[publisher] +base_type = "string" +required = false + +[container] +base_type = 
"object" +fields = [ + { name = "title", base_type = "string", required = true }, + { name = "issn", base_type = "array", required = false }, + { name = "volume", base_type = "string", required = false }, + { name = "issue", base_type = "string", required = false }, + { name = "pages", base_type = "string", required = false }, +] +required = false + +[type] +base_type = "string" +required = false +validation = { enum_values = [ + "journal-article", + "book", + "conference-paper", + "report", + "thesis", +] } + +[references_count] +base_type = "number" +required = false + +[language] +base_type = "string" +required = false + +[citations_count] # This is "is-referenced-by-count" in the DOI response +base_type = "number" +required = false + +[doi] +base_type = "string" +required = false +validation = { pattern = "10\\.\\d{4,9}/[-._;()/:\\w]+" } diff --git a/crates/learner/config_new/retrieval.toml b/crates/learner/config_new/retrieval.toml index e851a8c..0f441ed 100644 --- a/crates/learner/config_new/retrieval.toml +++ b/crates/learner/config_new/retrieval.toml @@ -17,10 +17,6 @@ fields = [ ] required = false -[doi] -base_type = "string" -required = false - [last_checked] base_type = "string" required = false diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index b230a53..3fbbf1a 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -65,7 +65,8 @@ mod tests { let retreival: Template = dbg!(manager.load_config("config_new/retrieval.toml")).unwrap(); - let arxiv_retriever: Retriever = dbg!(manager.load_config("config_new/arxiv.toml").unwrap()); + // let arxiv_retriever: Retriever = dbg!(manager.load_config("config_new/arxiv.toml").unwrap()); + let doi_retriever: Retriever = dbg!(manager.load_config("config_new/doi.toml").unwrap()); todo!("Clean this up") // The paper_record now has all fields from base_resource and paper, diff --git a/crates/learner/src/retriever/config.rs 
b/crates/learner/src/retriever/config.rs index da266bb..f707b64 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -25,11 +25,11 @@ pub struct Retriever { pub resource_template: Template, #[serde(default)] - pub resource_mappings: BTreeMap, + pub resource_mappings: BTreeMap, pub retrieval_template: Template, #[serde(default)] - pub retrieval_mappings: BTreeMap, + pub retrieval_mappings: BTreeMap, } impl Retriever { @@ -78,8 +78,33 @@ impl Retriever { // Process the response using configured processor let json = match &self.response_format { - ResponseFormat::Xml { strip_namespaces } => xml_to_json(&data, *strip_namespaces), - ResponseFormat::Json => serde_json::from_slice(&data)?, + ResponseFormat::Xml { strip_namespaces, clean_content } => { + let xml = if *strip_namespaces { + response::strip_xml_namespaces(&String::from_utf8_lossy(&data)) + } else { + String::from_utf8_lossy(&data).to_string() + }; + + // Convert to JSON value + let mut value = xml_to_json(&xml); + + // Clean content if requested + if *clean_content { + clean_value(&mut value); + } + + value + }, + + ResponseFormat::Json { clean_content } => { + let mut value = serde_json::from_slice(&data)?; + + if *clean_content { + clean_value(&mut value); + } + + value + }, }; let (mut resource, retrieval) = self.process_json_value(&json)?; @@ -108,65 +133,102 @@ impl Retriever { fn process_template_fields( json: &Value, template: &Template, - mappings: &BTreeMap, + mappings: &BTreeMap, ) -> Result> { let mut result = BTreeMap::new(); for field_def in &template.fields { - // If we have a mapping for this field, try to extract its value - if let Some(field_map) = mappings.get(&field_def.name) { - if let Some(value) = extract_mapped_value(json, field_map, field_def)? 
{ - result.insert(field_def.name.clone(), value); - } else if field_def.required { - // Only error if the field was required and we couldn't find it - return Err(LearnerError::ApiError(format!( - "Required field '{}' not found in response", - field_def.name - ))); + if let Some(mapping) = mappings.get(&field_def.name) { + match extract_mapped_value(json, mapping, field_def) { + Ok(Some(value)) => { + result.insert(field_def.name.clone(), value); + }, + Ok(None) if field_def.required => { + return Err(LearnerError::ApiError(format!( + "Required field '{}' not found in response", + field_def.name + ))); + }, + Err(e) => return Err(e), + _ => continue, } } } + Ok(result) } -/// Extract and transform a value from JSON using a field mapping +// TODO: Fix unwraps in here fn extract_mapped_value( json: &Value, - field_map: &FieldMap, + mapping: &Mapping, field_def: &FieldDefinition, ) -> Result> { - let path_components: Vec<&str> = field_map.path.split('/').collect(); - - // Extract raw value using path - let raw_value = get_path_value(json, &path_components); - - // If no value found, return None - let Some(raw_value) = raw_value else { - return Ok(None); - }; + let value = match mapping { + // Simple path extraction - most common case + Mapping::Path(path) => { + let components: Vec<&str> = path.split('/').collect(); + get_path_value(json, &components) + .ok_or_else(|| LearnerError::ApiError(format!("Path '{}' not found", path)))? 
+ }, - // First apply any explicit transforms - let mut value = raw_value; - for transform in &field_map.transforms { - value = apply_transform(&value, transform)?; - } + // Join multiple string values with a delimiter + Mapping::Join { paths, with } => { + let parts: Result> = paths + .iter() + .map(|path| { + let components: Vec<&str> = path.split('/').collect(); + get_path_value(json, &components) + .and_then(|v| v.as_str().map(|s| s.to_string())) + .ok_or_else(|| LearnerError::ApiError(format!("Path '{}' is not a string", path))) + }) + .collect(); + Value::String(parts?.join(with)) + }, - value = if let Some(structure) = &field_map.structure { - let mut object = BTreeMap::new(); - for (key, to_replace) in structure { - // TODO: Remove unwrap - object.insert(key, to_replace.replace("{value}", value.as_str().unwrap())); - } - json!(object) - } else { - value + // Map values into new structures - handles both arrays and objects + Mapping::Map { from, map } => { + // Get the source to map from, if specified + let source = if let Some(path) = from { + let components: Vec<&str> = path.split('/').collect(); + get_path_value(json, &components) + .ok_or_else(|| LearnerError::ApiError(format!("Path '{}' not found", path)))? + } else { + json.clone() + }; + + // Process based on whether the source is an array or not + match source { + Value::Array(items) => { + // Map each array item + let mapped: Result> = items + .iter() + .map(|item| { + let mut obj = Map::new(); + for (key, mapping) in map { + if let Ok(Some(value)) = extract_mapped_value(item, mapping, field_def) { + obj.insert(key.clone(), value); + } + } + Ok(Value::Object(obj)) + }) + .collect(); + Value::Array(mapped?) 
+ }, + // Process as a single object + _ => { + let mut obj = Map::new(); + for (key, mapping) in map { + if let Ok(Some(value)) = extract_mapped_value(&source, mapping, field_def) { + obj.insert(key.clone(), value); + } + } + Value::Object(obj) + }, + } + }, }; - // Coerce a single value into an array if needed - if field_def.base_type == "array" { - value = into_array(value); - } - Ok(Some(value)) } @@ -224,44 +286,3 @@ fn get_path_value(json: &Value, path: &[&str]) -> Option { Some(current) } - -/// Apply a transform to a JSON value -fn apply_transform(value: &Value, transform: &Transform) -> Result { - match transform { - Transform::Replace { pattern, replacement } => { - let text = value.as_str().ok_or_else(|| { - LearnerError::ApiError("Replace transform requires string input".to_string()) - })?; - let re = - Regex::new(pattern).map_err(|e| LearnerError::ApiError(format!("Invalid regex: {e}")))?; - Ok(Value::String(re.replace_all(text, replacement.as_str()).into_owned())) - }, - Transform::Combine { subpaths, delimiter } => { - // TODO: fix unwrap - println!("INSIDE OF COMBINE WITH SUBPATHS: {:?}", subpaths); - match value.as_array() { - Some(arr) => - return Ok(Value::Array( - arr.iter().map(|v| combine_path_values(v, subpaths, delimiter)).collect(), - )), - None => return Ok(combine_path_values(value, subpaths, delimiter)), - } - }, - } -} - -fn combine_path_values(value: &Value, subpaths: &Vec, delimiter: &str) -> Value { - Value::String( - subpaths - .iter() - .fold(String::new(), |mut acc, subpath| { - if !acc.is_empty() { - acc.push_str(delimiter); - } - let subpath: Vec<&str> = subpath.split("/").collect(); - acc.push_str(dbg!(get_path_value(value, &subpath).unwrap().as_str().unwrap())); - acc - }) - .to_string(), - ) -} diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index 769e645..f688129 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -196,8 +196,9 @@ mod tests { // 
Verify response format - if let ResponseFormat::Xml { strip_namespaces } = &retriever.response_format { + if let ResponseFormat::Xml { strip_namespaces, clean_content } = &retriever.response_format { assert!(strip_namespaces); + assert!(clean_content); // Verify field mappings let field_maps = &retriever.resource_mappings; @@ -270,7 +271,8 @@ mod tests { // Verify response format match &retriever.response_format { - ResponseFormat::Json => { + ResponseFormat::Json { clean_content } => { + assert!(clean_content); // Verify field mappings let field_maps = &retriever.resource_mappings; assert!(field_maps.contains_key("title")); @@ -286,68 +288,69 @@ mod tests { #[test] fn test_iacr_config_deserialization() { - let config_str = include_str!("../../config/retrievers/iacr.toml"); - - let retriever: Retriever = toml::from_str(config_str).expect("Failed to parse config"); - - // Verify basic fields - // assert_eq!(retriever.name, "iacr"); - assert_eq!(retriever.base_url, "https://eprint.iacr.org"); - assert_eq!(retriever.source, "iacr"); - - // Test pattern matching - let test_cases = [ - ("2016/260", true), - ("2023/123", true), - ("https://eprint.iacr.org/2016/260", true), - ("https://eprint.iacr.org/2016/260.pdf", true), - ("invalid/format", false), - ("https://wrong.url/2016/260", false), - ]; - - for (input, expected) in test_cases { - assert_eq!( - retriever.pattern.is_match(input), - expected, - "Pattern match failed for input: {}", - input - ); - } - - // Test identifier extraction - assert_eq!(retriever.extract_identifier("2016/260").unwrap(), "2016/260"); - assert_eq!( - retriever.extract_identifier("https://eprint.iacr.org/2016/260").unwrap(), - "2016/260" - ); - assert_eq!( - retriever.extract_identifier("https://eprint.iacr.org/2016/260.pdf").unwrap(), - "2016/260" - ); - - // Verify response format - if let ResponseFormat::Xml { strip_namespaces } = &retriever.response_format { - assert!(strip_namespaces); - - // Verify field mappings - let field_maps = 
&retriever.resource_mappings; - assert!(field_maps.contains_key("title")); - assert!(field_maps.contains_key("abstract")); - assert!(field_maps.contains_key("authors")); - assert!(field_maps.contains_key("publication_date")); - assert!(field_maps.contains_key("pdf_url")); - - // Verify OAI-PMH paths - if let Some(map) = field_maps.get("title") { - assert!(map.path.contains(&"OAI-PMH/GetRecord/record/metadata/dc/title".to_string())); - } else { - panic!("Missing title field map"); - } - } else { - panic!("Expected an XML configuration, but did not get one.") - } + todo!("Fix this") + // let config_str = include_str!("../../config/retrievers/iacr.toml"); + + // let retriever: Retriever = toml::from_str(config_str).expect("Failed to parse config"); + + // // Verify basic fields + // // assert_eq!(retriever.name, "iacr"); + // assert_eq!(retriever.base_url, "https://eprint.iacr.org"); + // assert_eq!(retriever.source, "iacr"); + + // // Test pattern matching + // let test_cases = [ + // ("2016/260", true), + // ("2023/123", true), + // ("https://eprint.iacr.org/2016/260", true), + // ("https://eprint.iacr.org/2016/260.pdf", true), + // ("invalid/format", false), + // ("https://wrong.url/2016/260", false), + // ]; + + // for (input, expected) in test_cases { + // assert_eq!( + // retriever.pattern.is_match(input), + // expected, + // "Pattern match failed for input: {}", + // input + // ); + // } + + // // Test identifier extraction + // assert_eq!(retriever.extract_identifier("2016/260").unwrap(), "2016/260"); + // assert_eq!( + // retriever.extract_identifier("https://eprint.iacr.org/2016/260").unwrap(), + // "2016/260" + // ); + // assert_eq!( + // retriever.extract_identifier("https://eprint.iacr.org/2016/260.pdf").unwrap(), + // "2016/260" + // ); + + // // Verify response format + // if let ResponseFormat::Xml { strip_namespaces } = &retriever.response_format { + // assert!(strip_namespaces); + + // // Verify field mappings + // let field_maps = 
&retriever.resource_mappings; + // assert!(field_maps.contains_key("title")); + // assert!(field_maps.contains_key("abstract")); + // assert!(field_maps.contains_key("authors")); + // assert!(field_maps.contains_key("publication_date")); + // assert!(field_maps.contains_key("pdf_url")); + + // // Verify OAI-PMH paths + // if let Some(map) = field_maps.get("title") { + // assert!(map.path.contains(&"OAI-PMH/GetRecord/record/metadata/dc/title".to_string())); + // } else { + // panic!("Missing title field map"); + // } + // } else { + // panic!("Expected an XML configuration, but did not get one.") + // } - // Verify headers - assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); + // // Verify headers + // assert_eq!(retriever.headers.get("Accept").unwrap(), "application/xml"); } } diff --git a/crates/learner/src/retriever/response.rs b/crates/learner/src/retriever/response.rs index 3432610..7df4d1d 100644 --- a/crates/learner/src/retriever/response.rs +++ b/crates/learner/src/retriever/response.rs @@ -5,54 +5,49 @@ use super::*; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ResponseFormat { - /// XML response parser configuration #[serde(rename = "xml")] Xml { + /// Whether to strip XML namespace declarations and prefixes #[serde(default)] strip_namespaces: bool, + /// Whether to clean content by removing markup tags and normalizing whitespace + #[serde(default)] + clean_content: bool, }, - /// JSON response parser configuration + #[serde(rename = "json")] - Json, -} -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FieldMap { - /// Path to field in response - pub path: String, - /// Transformations to apply in order - #[serde(default)] - pub transforms: Vec, - /// Optional structured output - #[serde(default)] - pub structure: Option>, + Json { + /// Whether to clean string values by removing markup and normalizing content + #[serde(default)] + clean_content: bool, + }, } #[derive(Debug, Clone, 
Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum Transform { - /// Replace text using regex pattern - Replace { - /// Regular expression pattern to match - pattern: String, - /// Text to replace matched patterns with - replacement: String, +#[serde(untagged)] +pub enum Mapping { + // A single path string - most common case + Path(String), + + // Multiple paths to join with optional delimiter + Join { + paths: Vec, + #[serde(default = "default_delimiter")] + with: String, }, - Combine { - subpaths: Vec, - delimiter: String, + + // Map values into new structures + Map { + from: Option, + map: BTreeMap, }, } -pub fn xml_to_json(data: &[u8], strip_namespaces: bool) -> Value { - // Handle namespace stripping - let xml = if strip_namespaces { - strip_xml_namespaces(&String::from_utf8_lossy(data)) - } else { - String::from_utf8_lossy(data).to_string() - }; +fn default_delimiter() -> String { " ".to_string() } - trace!("Processing XML response: {:#?}", &xml); - let mut reader = Reader::from_str(&xml); +pub fn xml_to_json(data: &str) -> Value { + trace!("Processing XML response: {:#?}", data); + let mut reader = Reader::from_str(data); let mut stack = Vec::new(); let mut current = Map::new(); @@ -148,9 +143,58 @@ pub fn xml_to_json(data: &[u8], strip_namespaces: bool) -> Value { /// # Returns /// /// XML content with namespaces removed -fn strip_xml_namespaces(xml: &str) -> String { +pub fn strip_xml_namespaces(xml: &str) -> String { let re = regex::Regex::new(r#"xmlns(?::\w+)?="[^"]*""#).unwrap(); let mut result = re.replace_all(xml, "").to_string(); result = result.replace("oai_dc:", "").replace("dc:", ""); + result } + +pub fn clean_value(value: &mut Value) { + match value { + // Clean string content + Value::String(s) => + if s.contains('<') || s.contains('\n') { + *s = clean_content(s); + }, + // Recursively clean arrays and objects + Value::Array(arr) => + for item in arr { + clean_value(item); + }, + Value::Object(obj) => + for (_, val) in obj { + 
clean_value(val); + }, + _ => (), // Other value types don't need cleaning + } +} + +pub fn clean_content(s: &str) -> String { + let mut cleaned = s.to_string(); + + // Remove various markup tags + let tag_patterns = [ + // JATS tags + r"]+>", + r"]+>", + // Generic XML tags + r"<[^>]+>", + // Any remaining XML-like tags + r"]*>", + ]; + + for pattern in &tag_patterns { + if let Ok(re) = Regex::new(pattern) { + cleaned = re.replace_all(&cleaned, "").to_string(); + } + } + + // Normalize whitespace + if let Ok(re) = Regex::new(r"\s+") { + cleaned = re.replace_all(&cleaned.trim(), " ").to_string(); + } + + cleaned +} From 199458522140d03a6bda5b5ef0bc271b3956204b Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Mon, 9 Dec 2024 05:50:34 -0700 Subject: [PATCH 54/73] updated configs --- crates/learner/config_new/arxiv.toml | 34 +++++++++++++++-------- crates/learner/config_new/iacr.toml | 35 ++++++++++++++++-------- crates/learner/src/configuration.rs | 3 +- crates/learner/src/retriever/config.rs | 12 ++------ crates/learner/src/retriever/mod.rs | 6 ---- crates/learner/src/retriever/response.rs | 4 +-- 6 files changed, 51 insertions(+), 43 deletions(-) diff --git a/crates/learner/config_new/arxiv.toml b/crates/learner/config_new/arxiv.toml index 12cd874..e0750d1 100644 --- a/crates/learner/config_new/arxiv.toml +++ b/crates/learner/config_new/arxiv.toml @@ -1,6 +1,5 @@ -name = "arxiv" - description = "Retriever template for getting a paper from arXiv" +name = "arxiv" resource_template = "paper" retrieval_template = "retrieval" @@ -10,20 +9,31 @@ endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}&max_ pattern = "(?:^|https?://arxiv\\.org/(?:abs|pdf)/)(\\d{4}\\.\\d{4,5}|[a-zA-Z-]+(?:\\.[A-Z]{2})?/\\d{7})(?:\\.pdf)?$" source = "arxiv" -response_format = { type = "xml", strip_namespaces = true } +[response_format] +clean_content = true +strip_namespaces = true +type = "xml" [resource_mappings] -abstract = { path = "feed/entry/summary" } -authors = 
{ path = "feed/entry/author" } -publication_dates = { path = "feed/entry/published" } -title = { path = "feed/entry/title" } +abstract = "feed/entry/summary" +title = "feed/entry/title" + +[resource_mappings.authors] +from = "feed/entry/author/name" +[resource_mappings.authors.map] +name = "." + +[resource_mappings.publication_dates.map] +created = "feed/entry/published" [retrieval_mappings.urls] -path = "feed/entry/id" -structure = { "pdf" = "http://arxiv.org/pdf/{value}.pdf", "html" = "http://arxiv.org/abs/{value}" } -transforms = [ - { type = "Replace", pattern = ".*/(?:abs|pdf)/(.+?)(?:v\\d+)?$", replacement = "$1" }, -] +from = "feed/entry/id" +[retrieval_mappings.urls.map.pdf] +paths = ["."] +transform = "replace:.*/(?:abs|pdf)/(.+?)(?:v\\d+)?$:http://arxiv.org/pdf/$1.pdf" # No delimiter needed for single value +[retrieval_mappings.urls.map.html] +paths = ["."] +transform = "replace:.*/(?:abs|pdf)/(.+?)(?:v\\d+)?$:http://arxiv.org/abs/$1" [headers] Accept = "application/xml" diff --git a/crates/learner/config_new/iacr.toml b/crates/learner/config_new/iacr.toml index e9679ad..453a43d 100644 --- a/crates/learner/config_new/iacr.toml +++ b/crates/learner/config_new/iacr.toml @@ -1,6 +1,5 @@ -name = "iacr" - description = "Retriever template for getting a paper from IACR" +name = "iacr" resource_template = "paper" retrieval_template = "retrieval" @@ -10,20 +9,32 @@ endpoint_template = "https://eprint.iacr.org/oai?verb=GetRecord&identifier=oai:e pattern = "(?:^|https?://eprint\\.iacr\\.org/)(\\d{4}/\\d+)(?:\\.pdf)?$" source = "iacr" -response_format = { type = "xml", strip_namespaces = true } +[response_format] +clean_content = true +strip_namespaces = true +type = "xml" [resource_mappings] -abstract = { path = "OAI-PMH/GetRecord/record/metadata/dc/description" } -authors = { path = "OAI-PMH/GetRecord/record/metadata/dc/creator" } -publication_dates = { path = "OAI-PMH/GetRecord/record/metadata/dc/date" } -title = { path = 
"OAI-PMH/GetRecord/record/metadata/dc/title" } +abstract = "OAI-PMH/GetRecord/record/metadata/dc/description" +title = "OAI-PMH/GetRecord/record/metadata/dc/title" + +[resource_mappings.authors] +from = "OAI-PMH/GetRecord/record/metadata/dc/creator" +map = { name = "." } + +[resource_mappings.publication_dates] +map = { created = "OAI-PMH/GetRecord/record/metadata/dc/date" } [retrieval_mappings.urls] -path = "OAI-PMH/GetRecord/record/metadata/dc/identifier" -structure = { "pdf" = "https://eprint.iacr.org/{value}.pdf", "html" = "https://eprint.iacr.org/{value}" } -transforms = [ - { type = "Replace", pattern = ".*/(\\d{4}/\\d+)$", replacement = "$1" }, -] +from = "OAI-PMH/GetRecord/record/metadata/dc/identifier" +[retrieval_mappings.urls.map] +html = { paths = [ + ".", +], transform = "replace:.*/(\\d{4}/\\d+)$:https://eprint.iacr.org/$1" } +pdf = { paths = [ + ".", +], transform = "replace:.*/(\\d{4}/\\d+)$:https://eprint.iacr.org/$1.pdf" } + [headers] Accept = "application/xml" diff --git a/crates/learner/src/configuration.rs b/crates/learner/src/configuration.rs index 3fbbf1a..1a04178 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -65,8 +65,9 @@ mod tests { let retreival: Template = dbg!(manager.load_config("config_new/retrieval.toml")).unwrap(); - // let arxiv_retriever: Retriever = dbg!(manager.load_config("config_new/arxiv.toml").unwrap()); + let arxiv_retriever: Retriever = dbg!(manager.load_config("config_new/arxiv.toml").unwrap()); let doi_retriever: Retriever = dbg!(manager.load_config("config_new/doi.toml").unwrap()); + let iacr_retriever: Retriever = dbg!(manager.load_config("config_new/iacr.toml").unwrap()); todo!("Clean this up") // The paper_record now has all fields from base_resource and paper, diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index f707b64..bbfcf9e 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs 
@@ -115,8 +115,8 @@ impl Retriever { // Validate full resource against config // TODO: Add in validations here. - // self.resource_template.validate(dbg!(&resource))?; - // self.retrieval_template.validate(dbg!(&retrieval))?; + self.resource_template.validate(dbg!(&resource))?; + self.retrieval_template.validate(dbg!(&retrieval))?; Ok(Record { resource, state: State::default(), storage: StorageData::default(), retrieval }) } @@ -232,14 +232,6 @@ fn extract_mapped_value( Ok(Some(value)) } -fn into_array(value: Value) -> Value { - match value { - // Single value -> wrap in array - Value::Array(_) => value, - _ => json!(vec![value]), - } -} - /// Get a value from JSON using a path // Change return type to owned Value fn get_path_value(json: &Value, path: &[&str]) -> Option { diff --git a/crates/learner/src/retriever/mod.rs b/crates/learner/src/retriever/mod.rs index f688129..ee182fb 100644 --- a/crates/learner/src/retriever/mod.rs +++ b/crates/learner/src/retriever/mod.rs @@ -16,12 +16,6 @@ pub struct Retrievers { configs: BTreeMap, } -// impl Configurable for Retrievers { -// type Config = Retriever; - -// fn as_map(&mut self) -> &mut BTreeMap { &mut self.configs } -// } - impl Retrievers { /// Checks whether the retreivers map is empty. 
/// diff --git a/crates/learner/src/retriever/response.rs b/crates/learner/src/retriever/response.rs index 7df4d1d..99c7b5c 100644 --- a/crates/learner/src/retriever/response.rs +++ b/crates/learner/src/retriever/response.rs @@ -43,7 +43,7 @@ pub enum Mapping { }, } -fn default_delimiter() -> String { " ".to_string() } +fn default_delimiter() -> String { "".to_string() } pub fn xml_to_json(data: &str) -> Value { trace!("Processing XML response: {:#?}", data); @@ -193,7 +193,7 @@ pub fn clean_content(s: &str) -> String { // Normalize whitespace if let Ok(re) = Regex::new(r"\s+") { - cleaned = re.replace_all(&cleaned.trim(), " ").to_string(); + cleaned = re.replace_all(cleaned.trim(), " ").to_string(); } cleaned From e4147641c0f8c48193bb4d3deae4c5fa5ab10a3e Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Mon, 9 Dec 2024 07:10:23 -0700 Subject: [PATCH 55/73] WIP: save state --- crates/learner/config_new/doi.toml | 3 +- crates/learner/config_new/paper.toml | 30 +- crates/learner/src/lib.rs | 2 +- crates/learner/src/retriever/config.rs | 19 +- crates/learner/src/template.rs | 661 +++++++++++++----- .../tests/workflows/paper_retrieval.rs | 2 +- 6 files changed, 540 insertions(+), 177 deletions(-) diff --git a/crates/learner/config_new/doi.toml b/crates/learner/config_new/doi.toml index 8191d5a..a0f0115 100644 --- a/crates/learner/config_new/doi.toml +++ b/crates/learner/config_new/doi.toml @@ -33,7 +33,6 @@ published = "published/date-time" published_online = "published-online/date-time" published_print = "published-print/date-time" - [resource_mappings.container] from = "message" [resource_mappings.container.map] @@ -47,7 +46,7 @@ volume = "volume" from = "message/author" [resource_mappings.authors.map] affiliation = "affiliation/0/name" -name = { paths = ["given", "family"] } +name = { paths = ["given", "family"], with = " " } [retrieval_mappings] urls = { map = { pdf = "message/link/0/URL" } } diff --git a/crates/learner/config_new/paper.toml 
b/crates/learner/config_new/paper.toml index eadb66f..67e3c7d 100644 --- a/crates/learner/config_new/paper.toml +++ b/crates/learner/config_new/paper.toml @@ -24,13 +24,29 @@ fields = [ [publication_dates] base_type = "object" -fields = [ - { name = "created", base_type = "string", required = true, validation = { datetime = true } }, - { name = "published", base_type = "string", required = false, validation = { datetime = true } }, - { name = "published_online", base_type = "string", required = false, validation = { datetime = true } }, - { name = "published_print", base_type = "string", required = false, validation = { datetime = true } }, -] -required = true +required = true +[[publication_dates.fields]] +base_type = "array" +items = { base_type = "string", validation = { datetime = true } } +name = "created" +required = true +validation = { min_items = 1 } +[[publication_dates.fields]] +base_type = "array" +items = { base_type = "string", validation = { datetime = true } } +name = "published" +required = false +[[publication_dates.fields]] +base_type = "array" +items = { base_type = "string", validation = { datetime = true } } +name = "published_online" +required = false +[[publication_dates.fields]] +base_type = "array" +items = { base_type = "string", validation = { datetime = true } } +name = "published_print" +required = false + # New suggested fields [publisher] diff --git a/crates/learner/src/lib.rs b/crates/learner/src/lib.rs index bceb800..fd5a50e 100644 --- a/crates/learner/src/lib.rs +++ b/crates/learner/src/lib.rs @@ -155,7 +155,7 @@ use regex::Regex; use reqwest::Url; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tracing::{debug, trace, warn}; +use tracing::{debug, error, info, instrument, trace, warn}; #[cfg(test)] use {tempfile::tempdir, tracing_test::traced_test}; diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index bbfcf9e..88bfc15 100644 --- a/crates/learner/src/retriever/config.rs 
+++ b/crates/learner/src/retriever/config.rs @@ -229,7 +229,24 @@ fn extract_mapped_value( }, }; - Ok(Some(value)) + dbg!(&field_def); + let array_coerced = if field_def.base_type == "array" { + println!("{field_def:?} should be array"); + match value { + Value::Array(_) => value, + _ => Value::Array(vec![value]), + } + } else { + match (field_def.base_type.as_str(), &value) { + ("string", Value::Array(arr)) if arr.len() == 1 => { + println!("should be string"); + arr[0].clone() + }, + _ => value, + } + }; + + Ok(Some(array_coerced)) } /// Get a value from JSON using a path diff --git a/crates/learner/src/template.rs b/crates/learner/src/template.rs index 75e5296..7663816 100644 --- a/crates/learner/src/template.rs +++ b/crates/learner/src/template.rs @@ -17,6 +17,7 @@ pub struct Template { } // Custom deserialization to handle the flattened field structure + impl<'de> Deserialize<'de> for Template { fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de> { @@ -31,13 +32,14 @@ impl<'de> Deserialize<'de> for Template { let helper = TemplateHelper::deserialize(deserializer)?; - // Filter out metadata fields and set field names - let fields = helper + // Convert map into vec and set top-level field names + let mut fields: Vec = helper .fields .into_iter() .filter(|(key, _)| key != "name" && key != "description") .map(|(key, mut field_def)| { - field_def.name = key; + field_def.name = key.clone(); + field_def.process_nested_names(); field_def }) .collect(); @@ -49,7 +51,7 @@ impl<'de> Deserialize<'de> for Template { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FieldDefinition { /// Name of the field - #[serde(skip_deserializing)] + #[serde(default)] pub name: String, /// Whether this field must be present @@ -99,10 +101,26 @@ pub struct ValidationRules { } impl Template { + #[instrument( + skip(self, resource), + fields( + template_name = %self.name, + field_count = %self.fields.len() + ), + level = "debug" + )] pub fn 
validate(&self, resource: &TemplatedItem) -> Result<()> { - // Check required fields are present + info!("Starting template validation"); + + // First validate all required fields are present for field in &self.fields { if field.required && !resource.contains_key(&field.name) { + error!( + field = %field.name, + required = true, + validation_type = "required_field", + "Validation failed: missing required field" + ); return Err(LearnerError::TemplateInvalidation(format!( "Missing required field: {}", field.name @@ -110,81 +128,236 @@ impl Template { } } - // Validate each provided field + // Then validate each provided field for (name, value) in resource { if let Some(field) = self.fields.iter().find(|f| f.name == *name) { - field.validate_with_path(value, &field.name)?; + debug!( + field = %name, + field_type = %field.base_type, + required = %field.required, + "Validating field" + ); + + if let Err(e) = field.validate_with_path(value, &field.name) { + error!( + field = %name, + error = %e, + "Field validation failed" + ); + return Err(e); + } + } else { + warn!( + field = %name, + "Found unexpected field in resource" + ); } } + info!("Template validation completed successfully"); Ok(()) } } impl FieldDefinition { + #[instrument( + skip(self), + fields( + field_name = %self.name, + field_type = %self.base_type, + has_items = %self.items.is_some(), + has_fields = %self.fields.is_some() + ), + level = "debug" +)] + fn process_nested_names(&mut self) { + if let Some(items) = &mut self.items { + items.process_nested_names(); + } + + if let Some(fields) = &mut self.fields { + for (index, field) in fields.iter_mut().enumerate() { + if field.name.is_empty() { + error!( + parent_field = %self.name, + field_index = index, + field_type = %field.base_type, + "Found empty field name in template definition" + ); + } + field.process_nested_names(); + } + } + } + + #[instrument( + skip(self, value), + fields( + field_name = %self.name, + field_type = %self.base_type, + required = 
%self.required, + has_validation = %self.validation.is_some(), + has_items = %self.items.is_some(), + has_fields = %self.fields.is_some() + ) + )] fn validate_with_path(&self, value: &Value, path: &str) -> Result<()> { - match (self.base_type.as_str(), value) { + debug!( + path = %path, + value = ?value, + "Starting field validation" + ); + + let result = match (self.base_type.as_str(), value) { ("string", Value::String(s)) => self.validate_string(s, path), ("number", Value::Number(n)) => self.validate_number(n, path), ("array", Value::Array(items)) => self.validate_array(items, path), ("object", Value::Object(obj)) => self.validate_object(obj, path), - ("boolean", Value::Bool(_)) => Ok(()), - ("null", Value::Null) => Ok(()), - _ => Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' expected type '{}' but got '{}'", - path, - self.base_type, - type_name_of_value(value) - ))), + ("boolean", Value::Bool(_)) | ("null", Value::Null) => Ok(()), + _ => { + error!( + path = %path, + expected_type = %self.base_type, + actual_type = %type_name_of_value(value), + value = ?value, + "Type mismatch in field validation" + ); + Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' expected type '{}' but got '{}'", + path, + self.base_type, + type_name_of_value(value) + ))) + }, + }; + + if let Err(ref e) = result { + error!( + path = %path, + error = %e, + "Field validation failed" + ); } + result } + #[instrument( + skip(self, value), + fields( + field_name = %self.name, + string_length = %value.len(), + has_validation = %self.validation.is_some() + ) + )] fn validate_string(&self, value: &str, path: &str) -> Result<()> { + debug!( + path = %path, + value = %value, + "Starting string validation" + ); + if let Some(rules) = &self.validation { // Length constraints if let Some(min_length) = rules.min_length { if value.len() < min_length { + error!( + path = %path, + min_required = min_length, + actual_length = value.len(), + validation_type = "min_length", + 
"String validation failed: too short" + ); return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must be at least {} characters", - path, min_length + "Field '{}' must be at least {} characters (found {})", + path, + min_length, + value.len() ))); } } + if let Some(max_length) = rules.max_length { if value.len() > max_length { + error!( + path = %path, + max_allowed = max_length, + actual_length = value.len(), + validation_type = "max_length", + "String validation failed: too long" + ); return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' cannot exceed {} characters", - path, max_length + "Field '{}' cannot exceed {} characters (found {})", + path, + max_length, + value.len() ))); } } // Pattern matching if let Some(pattern) = &rules.pattern { - let re = Regex::new(pattern) - .map_err(|_| LearnerError::TemplateInvalidation("Invalid regex pattern".into()))?; - if !re.is_match(value) { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must match pattern: {}", - path, pattern - ))); + match Regex::new(pattern) { + Ok(re) => + if !re.is_match(value) { + error!( + path = %path, + pattern = %pattern, + value = %value, + validation_type = "pattern", + "String validation failed: pattern mismatch" + ); + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{}' must match pattern: {}", + path, pattern + ))); + }, + Err(e) => { + error!( + path = %path, + pattern = %pattern, + error = %e, + validation_type = "pattern", + "Invalid regex pattern" + ); + return Err(LearnerError::TemplateInvalidation(format!( + "Invalid regex pattern for field '{}': {}", + path, e + ))); + }, } } // DateTime validation - if rules.datetime == Some(true) && DateTime::parse_from_rfc3339(value).is_err() { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must be a valid RFC3339 datetime", - path - ))); + if rules.datetime == Some(true) { + match DateTime::parse_from_rfc3339(value) { + Ok(_) => {}, + Err(e) => { + 
error!( + path = %path, + value = %value, + error = %e, + validation_type = "datetime", + "Invalid datetime format" + ); + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{path}' must be a valid RFC3339 datetime: {e}", + ))); + }, + } } // Enum validation if let Some(allowed) = &rules.enum_values { if !allowed.contains(&value.to_string()) { + error!( + path = %path, + value = %value, + allowed_values = ?allowed, + validation_type = "enum", + "Invalid enum value" + ); return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must be one of: {:?}", - path, allowed + "Field '{path}' must be one of: {allowed:?}", ))); } } @@ -192,106 +365,269 @@ impl FieldDefinition { Ok(()) } + #[instrument( + skip(self, value), + fields( + field_name = %self.name, + has_validation = %self.validation.is_some(), + number_type = ?value.as_f64().map(|_| "f64").or_else(|| value.as_i64().map(|_| "i64")).or_else(|| value.as_u64().map(|_| "u64")) + ) + )] fn validate_number(&self, value: &Number, path: &str) -> Result<()> { + debug!( + path = %path, + value = %value, + "Starting number validation" + ); + if let Some(rules) = &self.validation { if let Some(num) = value.as_f64() { if let Some(min) = rules.minimum { if num < min { + error!( + path = %path, + min_required = min, + actual = num, + validation_type = "minimum", + "Number validation failed: too small" + ); return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must be at least {}", - path, min + "Field '{path}' must be at least {min} (found {num})", ))); } } + if let Some(max) = rules.maximum { if num > max { + error!( + path = %path, + max_allowed = max, + actual = num, + validation_type = "maximum", + "Number validation failed: too large" + ); return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' cannot exceed {}", - path, max + "Field '{path}' cannot exceed {max} (found {num})", ))); } } + if let Some(multiple) = rules.multiple_of { let ratio = num / multiple; if (ratio - 
ratio.round()).abs() > f64::EPSILON { + error!( + path = %path, + multiple = multiple, + value = num, + validation_type = "multiple_of", + "Number validation failed: not a multiple" + ); return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must be a multiple of {}", - path, multiple + "Field '{path}' must be a multiple of {multiple} (found {num})", ))); } } + } else { + warn!( + path = %path, + value = %value, + "Number could not be converted to f64 for validation" + ); } } Ok(()) } + #[instrument( + skip(self, items), + fields( + field_name = %self.name, + array_length = %items.len(), + has_validation = %self.validation.is_some(), + has_item_def = %self.items.is_some() + ) + )] fn validate_array(&self, items: &[Value], path: &str) -> Result<()> { + debug!( + path = %path, + "Starting array validation" + ); + + // Validate array-level rules if let Some(rules) = &self.validation { if let Some(min_items) = rules.min_items { if items.len() < min_items { + error!( + path = %path, + min_required = min_items, + actual = items.len(), + validation_type = "min_items", + "Array validation failed: too few items" + ); return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' must have at least {} items", - path, min_items + "Field '{}' must have at least {} items (found {})", + path, + min_items, + items.len() ))); } } + if let Some(max_items) = rules.max_items { if items.len() > max_items { + error!( + path = %path, + max_allowed = max_items, + actual = items.len(), + validation_type = "max_items", + "Array validation failed: too many items" + ); return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' cannot exceed {} items", - path, max_items + "Field '{}' cannot exceed {} items (found {})", + path, + max_items, + items.len() ))); } } + if rules.unique_items == Some(true) { let mut seen = HashSet::new(); - for item in items { - let item_str = serde_json::to_string(item).map_err(|_| { - LearnerError::TemplateInvalidation("Failed to 
serialize array item".into()) - })?; - if !seen.insert(item_str) { - return Err(LearnerError::TemplateInvalidation(format!( - "Field '{}' contains duplicate items", - path - ))); + for (idx, item) in items.iter().enumerate() { + match serde_json::to_string(item) { + Ok(item_str) => + if !seen.insert(item_str.clone()) { + error!( + path = %path, + index = idx, + value = %item_str, + validation_type = "unique_items", + "Array validation failed: duplicate item" + ); + return Err(LearnerError::TemplateInvalidation(format!( + "Field '{path}' contains duplicate item at index {idx}", + ))); + }, + Err(e) => { + error!( + path = %path, + index = idx, + error = %e, + validation_type = "unique_items", + "Failed to serialize array item" + ); + return Err(LearnerError::TemplateInvalidation(format!( + "Failed to check uniqueness for item at index {idx}: {e}", + ))); + }, } } } } - // Validate each item if we have an item type definition - if let Some(item_type) = &self.items { + // Validate individual items if we have an item definition + if let Some(item_def) = &self.items { for (index, item) in items.iter().enumerate() { - item_type.validate_with_path(item, &format!("{}[{}]", path, index)).map_err(|e| { - LearnerError::TemplateInvalidation(format!( - "Invalid item at index {} in array '{}': {}", - index, path, e - )) - })?; + let item_path = format!("{path}[{index}]"); + + match (item_def.base_type.as_str(), item) { + ("object", Value::Object(obj)) => { + if let Err(e) = item_def.validate_object(obj, &item_path) { + error!( + path = %item_path, + error = %e, + validation_type = "object", + "Array item validation failed" + ); + return Err(e); + } + }, + (expected, got) => { + error!( + path = %item_path, + expected_type = %expected, + actual_type = %type_name_of_value(got), + validation_type = "type_check", + "Array item type mismatch" + ); + return Err(LearnerError::TemplateInvalidation(format!( + "Item at index {} in '{}' expected type '{}' but got '{}'", + index, + path, + 
expected, + type_name_of_value(got) + ))); + }, + } } } - Ok(()) } + #[instrument( + skip(self, obj), + fields( + field_name = %self.name, + field_count = %obj.len(), + has_fields = %self.fields.is_some() + ) +)] fn validate_object(&self, obj: &Map, path: &str) -> Result<()> { + debug!( + path = %path, + fields = ?obj.keys().collect::>(), + "Starting object validation" + ); + if let Some(fields) = &self.fields { for field in fields { - if let Some(value) = obj.get(&field.name) { - field.validate_with_path(value, &format!("{}.{}", path, field.name))?; - } else if field.required { - return Err(LearnerError::TemplateInvalidation(format!( - "Missing required field '{}' in object '{}'", - field.name, path - ))); + match obj.get(&field.name) { + Some(value) => { + let field_path = format!("{}.{}", path, field.name); + if let Err(e) = field.validate_with_path(value, &field_path) { + error!( + path = %field_path, + field = %field.name, + error = %e, + "Object field validation failed" + ); + return Err(e); + } + }, + None if field.required => { + // Field is missing but required + error!( + path = %path, + field = %field.name, + validation_type = "required_field", + "Missing required field in object" + ); + return Err(LearnerError::TemplateInvalidation(format!( + "Missing required field '{}' in object '{}'", + field.name, path + ))); + }, + None => {}, } } + + // Log any extra fields that weren't in our field definitions + let defined_fields: HashSet<_> = fields.iter().map(|f| &f.name).collect(); + let extra_fields: Vec<_> = obj.keys().filter(|k| !defined_fields.contains(k)).collect(); + + if !extra_fields.is_empty() { + warn!( + path = %path, + extra_fields = ?extra_fields, + "Object contains undefined fields" + ); + } } Ok(()) } } -fn type_name_of_value(value: &Value) -> &'static str { +const fn type_name_of_value(value: &Value) -> &'static str { match value { Value::String(_) => "string", Value::Number(_) => "number", @@ -302,118 +638,113 @@ fn type_name_of_value(value: 
&Value) -> &'static str { } } -/// Convert DateTime to RFC3339 string for JSON storage +// TODO: Not sure we really need this... pub fn datetime_to_json(dt: DateTime) -> String { dt.to_rfc3339() } /// Parse RFC3339 string from JSON into DateTime pub fn datetime_from_json(s: &str) -> Result> { DateTime::parse_from_rfc3339(s) .map(|dt| dt.with_timezone(&Utc)) - .map_err(|e| LearnerError::TemplateInvalidation(format!("Invalid datetime format: {}", e))) + .map_err(|e| LearnerError::TemplateInvalidation(format!("Invalid datetime format: {e}"))) } #[cfg(test)] mod tests { - use chrono::TimeZone; use serde_json::json; use super::*; #[test] - fn validate_paper_configuration() { - let template = include_str!("../config/resources/paper.toml"); - let template: Template = toml::from_str(template).unwrap(); - - let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - - // Create a valid paper resource - let paper_resource = BTreeMap::from([ - ("title".into(), json!("Understanding Quantum Computing")), - ( - "authors".into(), - json!([{ - "name": "Alice Researcher", - "affiliation": "Tech University" - }]), - ), - ("publication_date".into(), json!(date)), - ("doi".into(), json!("10.1234/example.123")), - ]); - - // Validate the paper - template.validate(&paper_resource).unwrap(); - - // Test required field validation - let invalid_paper = BTreeMap::from([ - ("authors".into(), json!([])), // Missing title - ]); - assert!(template.validate(&invalid_paper).is_err()); - } - - #[test] - fn validate_book_configuration() { - let template = include_str!("../config/resources/book.toml"); - let template: Template = toml::from_str(template).unwrap(); - - let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - - let book_resource = BTreeMap::from([ - ("title".into(), json!("Advanced Quantum Computing")), - ("authors".into(), json!(["Alice Writer", "Bob Author"])), - ("isbn".into(), json!("978-0-12-345678-9")), - 
("publisher".into(), json!("Academic Press")), - ("publication_date".into(), json!(date)), - ]); - - template.validate(&book_resource).unwrap(); - } - - #[test] - fn validate_thesis_configuration() { - let template = include_str!("../config/resources/thesis.toml"); - let template: Template = toml::from_str(template).unwrap(); - - let date = datetime_to_json(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).single().unwrap()); - - let thesis_resource = BTreeMap::from([ - ("title".into(), json!("Novel Approaches to Quantum Error Correction")), - ("author".into(), json!("Alice Researcher")), - ("degree".into(), json!("PhD")), - ("institution".into(), json!("Tech University")), - ("completion_date".into(), json!(date)), - ("advisors".into(), json!(["Prof. Bob Supervisor"])), - ]); - - template.validate(&thesis_resource).unwrap(); - - // Test degree enum validation - let mut invalid_thesis = thesis_resource.clone(); - invalid_thesis.insert("degree".into(), json!("InvalidDegree")); - assert!(template.validate(&invalid_thesis).is_err()); + #[traced_test] + fn test_array_object_validation() { + let template_str = r#" + name = "test" + description = "Test template" + + [authors] + base_type = "array" + required = true + validation = { min_items = 1 } + + [authors.items] + base_type = "object" + + [[authors.items.fields]] + name = "name" + base_type = "string" + required = true + validation = { min_length = 1 } + + [[authors.items.fields]] + name = "affiliation" + base_type = "string" + required = false + "#; + + let template: Template = toml::from_str(template_str).unwrap(); + + // Test valid case + let valid_resource = BTreeMap::from([( + "authors".into(), + json!([ + {"name": "John Doe", "affiliation": "University"}, + {"name": "Jane Smith"} + ]), + )]); + + if let Err(e) = template.validate(&valid_resource) { + error!( + error = %e, + template = ?template, + data = ?valid_resource, + "Validation failed unexpectedly" + ); + panic!("Validation should have succeeded: {}", e); + } } 
#[test] + #[traced_test] fn test_datetime_validation() { - todo!("Fix this") - // let template = Template { - // name: "Test Template".to_string(), - // description: None, - // fields: vec![FieldDefinition { - // name: "timestamp".into(), - // field_type: "string".into(), - // required: true, - // description: None, - // default: None, - // validation: Some(ValidationRules { datetime: Some(true), ..Default::default() }), - // type_definition: None, - // }], - // }; - - // let valid_resource = BTreeMap::from([("timestamp".into(), json!("2024-01-01T00:00:00Z"))]); - // template.validate(&valid_resource).unwrap(); - - // let invalid_resource = BTreeMap::from([ - // ("timestamp".into(), json!("2024-01-01")), // Not RFC3339 - // ]); - // assert!(template.validate(&invalid_resource).is_err()); + let template_str = r#" + name = "test" + description = "Test template" + + [dates] + base_type = "object" + required = true + + [[dates.fields]] + name = "created" + base_type = "string" + required = true + validation = { datetime = true } + + [[dates.fields]] + name = "updated" + base_type = "string" + required = false + validation = { datetime = true } + "#; + + let template: Template = toml::from_str(template_str).unwrap(); + + // Test valid dates + let valid_dates = BTreeMap::from([( + "dates".into(), + json!({ + "created": "2024-01-01T00:00:00Z", + "updated": "2024-02-01T00:00:00Z" + }), + )]); + + if let Err(e) = template.validate(&valid_dates) { + error!( + error = %e, + template = ?template, + data = ?valid_dates, + "Validation failed unexpectedly" + ); + panic!("Validation should have succeeded: {}", e); + } } } diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index ad30638..0ce6d88 100644 --- a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -4,7 +4,7 @@ use learner::configuration::ConfigurationManager; use super::*; -#[traced_test] +// 
#[traced_test] #[tokio::test] async fn test_arxiv_retriever_integration() -> TestResult<()> { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); From c2a8e52267bd8cc8048652e6639f9ad2524cb006 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 14 Dec 2024 06:10:14 -0700 Subject: [PATCH 56/73] fix: coerce value --- crates/learner/src/retriever/config.rs | 106 +++++++++++++----- .../tests/workflows/paper_retrieval.rs | 4 +- 2 files changed, 82 insertions(+), 28 deletions(-) diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 88bfc15..18286c7 100644 --- a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -164,15 +164,13 @@ fn extract_mapped_value( mapping: &Mapping, field_def: &FieldDefinition, ) -> Result> { - let value = match mapping { - // Simple path extraction - most common case + // First get the raw value through mapping + let raw_value = match mapping { Mapping::Path(path) => { let components: Vec<&str> = path.split('/').collect(); get_path_value(json, &components) .ok_or_else(|| LearnerError::ApiError(format!("Path '{}' not found", path)))? 
}, - - // Join multiple string values with a delimiter Mapping::Join { paths, with } => { let parts: Result> = paths .iter() @@ -185,8 +183,6 @@ fn extract_mapped_value( .collect(); Value::String(parts?.join(with)) }, - - // Map values into new structures - handles both arrays and objects Mapping::Map { from, map } => { // Get the source to map from, if specified let source = if let Some(path) = from { @@ -197,16 +193,16 @@ fn extract_mapped_value( json.clone() }; - // Process based on whether the source is an array or not match source { Value::Array(items) => { - // Map each array item let mapped: Result> = items .iter() .map(|item| { let mut obj = Map::new(); for (key, mapping) in map { - if let Ok(Some(value)) = extract_mapped_value(item, mapping, field_def) { + if let Ok(Some(value)) = + extract_mapped_value(item, mapping, &get_field_def(field_def, key)) + { obj.insert(key.clone(), value); } } @@ -215,11 +211,12 @@ fn extract_mapped_value( .collect(); Value::Array(mapped?) }, - // Process as a single object _ => { let mut obj = Map::new(); for (key, mapping) in map { - if let Ok(Some(value)) = extract_mapped_value(&source, mapping, field_def) { + if let Ok(Some(value)) = + extract_mapped_value(&source, mapping, &get_field_def(field_def, key)) + { obj.insert(key.clone(), value); } } @@ -229,24 +226,81 @@ fn extract_mapped_value( }, }; - dbg!(&field_def); - let array_coerced = if field_def.base_type == "array" { - println!("{field_def:?} should be array"); - match value { - Value::Array(_) => value, - _ => Value::Array(vec![value]), + // Then coerce the value based on the expected type + let coerced = coerce_value(&raw_value, field_def)?; + + Ok(Some(coerced)) +} + +// Helper function to get field definition for nested fields +fn get_field_def<'a>(parent: &'a FieldDefinition, field_name: &str) -> FieldDefinition { + // Check for object fields first + if let Some(fields) = &parent.fields { + if let Some(field) = fields.iter().find(|f| f.name == field_name) { + 
return field.clone(); } - } else { - match (field_def.base_type.as_str(), &value) { - ("string", Value::Array(arr)) if arr.len() == 1 => { - println!("should be string"); - arr[0].clone() - }, - _ => value, + } + + // Then check array items if they exist + if let Some(items) = &parent.items { + if let Some(fields) = &items.fields { + if let Some(field) = fields.iter().find(|f| f.name == field_name) { + return field.clone(); + } } - }; + } + + // Return a default field definition if not found + FieldDefinition { + name: field_name.to_string(), + base_type: "string".to_string(), + required: false, + description: None, + validation: None, + items: None, + fields: None, + } +} - Ok(Some(array_coerced)) +// Helper function to coerce values based on expected type +fn coerce_value(value: &Value, field_def: &FieldDefinition) -> Result { + let result = match field_def.base_type.as_str() { + "array" => match value { + Value::Array(_) => value.clone(), + // If not an array but should be, wrap it + _ => Value::Array(vec![value.clone()]), + }, + "string" => match value { + // If we have a single-element array and need a string + Value::Array(arr) if arr.len() == 1 => + if let Some(s) = arr[0].as_str() { + Value::String(s.to_string()) + } else { + arr[0].clone() + }, + Value::String(_) => value.clone(), + _ => value.clone(), + }, + "object" => match value { + Value::Object(obj) => { + let mut new_obj = Map::new(); + // If we have fields defined, try to coerce each field + if let Some(fields) = &field_def.fields { + for field in fields { + if let Some(val) = obj.get(&field.name) { + new_obj.insert(field.name.clone(), coerce_value(val, field)?); + } + } + Value::Object(new_obj) + } else { + value.clone() + } + }, + _ => value.clone(), + }, + _ => value.clone(), + }; + Ok(result) } /// Get a value from JSON using a path diff --git a/crates/learner/tests/workflows/paper_retrieval.rs b/crates/learner/tests/workflows/paper_retrieval.rs index 0ce6d88..f6e559b 100644 --- 
a/crates/learner/tests/workflows/paper_retrieval.rs +++ b/crates/learner/tests/workflows/paper_retrieval.rs @@ -52,7 +52,7 @@ async fn test_arxiv_pdf_from_paper() -> TestResult<()> { // Ok(()) } -#[traced_test] +// #[traced_test] #[tokio::test] async fn test_iacr_retriever_integration() -> TestResult<()> { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); @@ -97,7 +97,7 @@ async fn test_iacr_pdf_from_paper() -> TestResult<()> { } #[tokio::test] -#[traced_test] +// #[traced_test] async fn test_doi_retriever_integration() -> TestResult<()> { let mut manager = ConfigurationManager::new(PathBuf::from("config_new")); let retriever: Retriever = dbg!(manager.load_config("config_new/doi.toml")?); From 002102da08af1a9cb8f5308bec747ee69b4026f4 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 14 Dec 2024 06:33:52 -0700 Subject: [PATCH 57/73] fix: template validation --- crates/learner/src/template.rs | 78 ++++++++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/crates/learner/src/template.rs b/crates/learner/src/template.rs index 7663816..a6132ee 100644 --- a/crates/learner/src/template.rs +++ b/crates/learner/src/template.rs @@ -527,40 +527,82 @@ impl FieldDefinition { // Validate individual items if we have an item definition if let Some(item_def) = &self.items { + debug!( + path = %path, + expected_type = %item_def.base_type, + "Validating array items" + ); + for (index, item) in items.iter().enumerate() { let item_path = format!("{path}[{index}]"); - match (item_def.base_type.as_str(), item) { - ("object", Value::Object(obj)) => { - if let Err(e) = item_def.validate_object(obj, &item_path) { + match item_def.base_type.as_str() { + "string" => + if let Value::String(s) = item { + if let Err(e) = item_def.validate_string(s, &item_path) { + error!( + path = %item_path, + error = %e, + validation_type = "string", + "Array item validation failed" + ); + return Err(e); + } + } else { error!( path = 
%item_path, - error = %e, - validation_type = "object", - "Array item validation failed" + expected_type = "string", + actual_type = %type_name_of_value(item), + validation_type = "type_check", + "Array item type mismatch" ); - return Err(e); - } - }, - (expected, got) => { + return Err(LearnerError::TemplateInvalidation(format!( + "Item at index {} in '{}' expected type 'string' but got '{}'", + index, + path, + type_name_of_value(item) + ))); + }, + "object" => + if let Value::Object(obj) = item { + if let Err(e) = item_def.validate_object(obj, &item_path) { + error!( + path = %item_path, + error = %e, + validation_type = "object", + "Array item validation failed" + ); + return Err(e); + } + } else { + error!( + path = %item_path, + expected_type = "object", + actual_type = %type_name_of_value(item), + validation_type = "type_check", + "Array item type mismatch" + ); + return Err(LearnerError::TemplateInvalidation(format!( + "Item at index {} in '{}' expected type 'object' but got '{}'", + index, + path, + type_name_of_value(item) + ))); + }, + other => { error!( path = %item_path, - expected_type = %expected, - actual_type = %type_name_of_value(got), validation_type = "type_check", - "Array item type mismatch" + "Unsupported array item type" ); return Err(LearnerError::TemplateInvalidation(format!( - "Item at index {} in '{}' expected type '{}' but got '{}'", - index, - path, - expected, - type_name_of_value(got) + "Unsupported array item type '{other}' at index {index} in '{path}'", ))); }, } } } + Ok(()) } From d2932e964a73592e53b7e4a0c051fa527a2904b3 Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 14 Dec 2024 06:35:38 -0700 Subject: [PATCH 58/73] chore: remove some lint --- crates/learner/src/retriever/config.rs | 65 +++++++++++--------------- 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/crates/learner/src/retriever/config.rs b/crates/learner/src/retriever/config.rs index 18286c7..f4cde12 100644 --- 
a/crates/learner/src/retriever/config.rs +++ b/crates/learner/src/retriever/config.rs @@ -1,5 +1,4 @@ use record::{Record, State, StorageData}; -use serde_json::json; use super::*; use crate::template::{FieldDefinition, Template, TemplatedItem}; @@ -169,7 +168,7 @@ fn extract_mapped_value( Mapping::Path(path) => { let components: Vec<&str> = path.split('/').collect(); get_path_value(json, &components) - .ok_or_else(|| LearnerError::ApiError(format!("Path '{}' not found", path)))? + .ok_or_else(|| LearnerError::ApiError(format!("Path '{path}' not found")))? }, Mapping::Join { paths, with } => { let parts: Result> = paths @@ -177,8 +176,8 @@ fn extract_mapped_value( .map(|path| { let components: Vec<&str> = path.split('/').collect(); get_path_value(json, &components) - .and_then(|v| v.as_str().map(|s| s.to_string())) - .ok_or_else(|| LearnerError::ApiError(format!("Path '{}' is not a string", path))) + .and_then(|v| v.as_str().map(std::string::ToString::to_string)) + .ok_or_else(|| LearnerError::ApiError(format!("Path '{path}' is not a string"))) }) .collect(); Value::String(parts?.join(with)) @@ -188,40 +187,37 @@ fn extract_mapped_value( let source = if let Some(path) = from { let components: Vec<&str> = path.split('/').collect(); get_path_value(json, &components) - .ok_or_else(|| LearnerError::ApiError(format!("Path '{}' not found", path)))? + .ok_or_else(|| LearnerError::ApiError(format!("Path '{path}' not found")))? 
} else { json.clone() }; - match source { - Value::Array(items) => { - let mapped: Result> = items - .iter() - .map(|item| { - let mut obj = Map::new(); - for (key, mapping) in map { - if let Ok(Some(value)) = - extract_mapped_value(item, mapping, &get_field_def(field_def, key)) - { - obj.insert(key.clone(), value); - } + if let Value::Array(items) = source { + let mapped: Result> = items + .iter() + .map(|item| { + let mut obj = Map::new(); + for (key, mapping) in map { + if let Ok(Some(value)) = + extract_mapped_value(item, mapping, &get_field_def(field_def, key)) + { + obj.insert(key.clone(), value); } - Ok(Value::Object(obj)) - }) - .collect(); - Value::Array(mapped?) - }, - _ => { - let mut obj = Map::new(); - for (key, mapping) in map { - if let Ok(Some(value)) = - extract_mapped_value(&source, mapping, &get_field_def(field_def, key)) - { - obj.insert(key.clone(), value); } + Ok(Value::Object(obj)) + }) + .collect(); + Value::Array(mapped?) + } else { + let mut obj = Map::new(); + for (key, mapping) in map { + if let Ok(Some(value)) = + extract_mapped_value(&source, mapping, &get_field_def(field_def, key)) + { + obj.insert(key.clone(), value); } - Value::Object(obj) - }, + } + Value::Object(obj) } }, }; @@ -273,12 +269,7 @@ fn coerce_value(value: &Value, field_def: &FieldDefinition) -> Result { "string" => match value { // If we have a single-element array and need a string Value::Array(arr) if arr.len() == 1 => - if let Some(s) = arr[0].as_str() { - Value::String(s.to_string()) - } else { - arr[0].clone() - }, - Value::String(_) => value.clone(), + arr[0].as_str().map_or_else(|| arr[0].clone(), |s| Value::String(s.to_string())), _ => value.clone(), }, "object" => match value { From 69d765be2090ea2cce2db69e976cb10caa09489e Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Sat, 14 Dec 2024 07:21:54 -0700 Subject: [PATCH 59/73] WIP: `state` and `storage` templates --- crates/learner/config_new/config.toml | 11 +++++ crates/learner/config_new/state.toml | 19 
++++++++ crates/learner/config_new/storage.toml | 14 ++++++ crates/learner/src/configuration.rs | 64 +++++++++++++++----------- crates/learner/src/record.rs | 11 ----- 5 files changed, 81 insertions(+), 38 deletions(-) create mode 100644 crates/learner/config_new/config.toml create mode 100644 crates/learner/config_new/state.toml create mode 100644 crates/learner/config_new/storage.toml diff --git a/crates/learner/config_new/config.toml b/crates/learner/config_new/config.toml new file mode 100644 index 0000000..7942016 --- /dev/null +++ b/crates/learner/config_new/config.toml @@ -0,0 +1,11 @@ +description = "Global configuration" +name = "learner" + +# Core templates that apply to all resources +retrieval_template = "retrieval" +state_template = "state" +storage_template = "storage" + +# List of available resource types +[[resources]] +template = "paper" diff --git a/crates/learner/config_new/state.toml b/crates/learner/config_new/state.toml new file mode 100644 index 0000000..70455f6 --- /dev/null +++ b/crates/learner/config_new/state.toml @@ -0,0 +1,19 @@ +description = "Base state tracking template" +name = "state" + +[read_status] +base_type = "string" +required = true +validation = { enum_values = ["unopened", "opened", "completed"] } + +[starred] +base_type = "boolean" +default = false +required = true + +[rating] +base_type = "number" +required = false +validation = { minimum = 1, maximum = 5 } + +# Additional fields can be added by extending this template diff --git a/crates/learner/config_new/storage.toml b/crates/learner/config_new/storage.toml new file mode 100644 index 0000000..eb82e6d --- /dev/null +++ b/crates/learner/config_new/storage.toml @@ -0,0 +1,14 @@ +description = "Base storage configuration template" +name = "storage" + +[files] +base_type = "object" +required = true + +[original_filenames] +base_type = "object" +required = true + +[added_at] +base_type = "object" +required = true diff --git a/crates/learner/src/configuration.rs 
b/crates/learner/src/configuration.rs index 1a04178..39bf583 100644 --- a/crates/learner/src/configuration.rs +++ b/crates/learner/src/configuration.rs @@ -1,4 +1,4 @@ -use serde::de::DeserializeOwned; +use template::Template; use super::*; @@ -18,41 +18,50 @@ impl ConfigurationManager { } } - pub fn load_config(&mut self, path: impl AsRef) -> Result - where T: DeserializeOwned + std::fmt::Debug { - let path = path.as_ref(); - let content = dbg!(std::fs::read_to_string(path)?); - let mut raw_config: toml::Value = dbg!(toml::from_str(&content)?); + pub fn load_config(&mut self) -> Result