From 1acd737566322e5fce970070b6449507e60996dd Mon Sep 17 00:00:00 2001 From: Zac <579103+7a6163@users.noreply.github.com> Date: Mon, 16 Mar 2026 16:16:59 +0800 Subject: [PATCH 1/4] chore: add ruby-prism spike to verify API before migration --- Cargo.lock | 163 ++++++++++++++++++++++++++++ Cargo.toml | 1 + examples/prism_spike.rs | 231 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 395 insertions(+) create mode 100644 examples/prism_spike.rs diff --git a/Cargo.lock b/Cargo.lock index 4cd8adf..eb9fd4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + [[package]] name = "alloc-from-pool" version = "1.0.5" @@ -64,18 +73,68 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "bindgen" +version = "0.72.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", +] + [[package]] name = "bitflags" version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +[[package]] +name = "cc" +version = "1.2.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cexpr" +version = "0.6.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.5.60" @@ -184,6 +243,12 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + [[package]] name = "foldhash" version = "0.1.5" @@ -254,6 +319,15 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.17" @@ -288,6 +362,16 @@ version = "0.2.182" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] 
name = "linux-raw-sys" version = "0.12.1" @@ -306,6 +390,22 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -372,6 +472,56 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "ruby-prism" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b302f00359d0b5423a600314935ca5f994484e15da2db5345631d28c6e029d6" +dependencies = [ + "ruby-prism-sys", + "serde", + "serde_json", +] + +[[package]] +name = "ruby-prism-sys" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9f85fd9571455bdca2b3bf0b2f543cd13ac39d5de0026a8eb2deb94ffd264f00" +dependencies = [ + "bindgen", + "cc", +] + [[package]] name = "rubyfast" version = "1.2.5" @@ -382,11 +532,18 @@ dependencies = [ "glob", "lib-ruby-parser", "rayon", + "ruby-prism", "serde", "serde_yaml", "tempfile", ] +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustix" version = "1.1.4" @@ -468,6 +625,12 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "strsim" version = "0.11.1" diff --git a/Cargo.toml b/Cargo.toml index 430f154..a64aaf6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,4 @@ anyhow = "1" [dev-dependencies] tempfile = "3" +ruby-prism = "1.9" diff --git a/examples/prism_spike.rs b/examples/prism_spike.rs new file mode 100644 index 0000000..79f3fdb --- /dev/null +++ b/examples/prism_spike.rs @@ -0,0 +1,231 @@ +//! Spike: Verify ruby-prism API for the migration from lib-ruby-parser. +//! +//! Run with: cargo run --example prism_spike + +use ruby_prism::*; + +fn main() { + println!("=== 1. Block/Call structure ==="); + spike_block_call(); + + println!("\n=== 2. Send on block (chained call) ==="); + spike_send_on_block(); + + println!("\n=== 3. For loop locations ==="); + spike_for_loop(); + + println!("\n=== 4. Method definition ==="); + spike_def(); + + println!("\n=== 5. Rescue body ==="); + spike_rescue(); + + println!("\n=== 6. Range types ==="); + spike_range(); + + println!("\n=== 7. String / Integer values ==="); + spike_values(); + + println!("\n=== 8. Comments ==="); + spike_comments(); + + println!("\n=== 9. Error tolerance ==="); + spike_errors(); + + println!("\n=== 10. 
ASCII encoding ==="); + spike_encoding(); + + println!("\n=== 11. Block pass (symbol to proc) ==="); + spike_block_pass(); + + println!("\n=== 12. Child nodes / visitor ==="); + spike_visitor(); + + println!("\n=== 13. ERB template (expected error) ==="); + spike_erb(); + + println!("\n=== 14. Node enum pattern matching ==="); + spike_pattern_match(); +} + +fn spike_block_call() { + // arr.map { |x| x.to_s } + let source = b"arr.map { |x| x.to_s }"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + println!("Root: {:#?}", result.node()); +} + +fn spike_send_on_block() { + // arr.select { |x| x > 1 }.first + let source = b"arr.select { |x| x > 1 }.first"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + println!("Root: {:#?}", result.node()); +} + +fn spike_for_loop() { + let source = b"for x in [1, 2, 3]\n puts x\nend"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + println!("Root: {:#?}", result.node()); +} + +fn spike_def() { + let source = b"def foo(a, &block); block.call; end"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + println!("Root: {:#?}", result.node()); +} + +fn spike_rescue() { + let source = b"begin; x; rescue NoMethodError => e; retry; end"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + println!("Root: {:#?}", result.node()); +} + +fn spike_range() { + let source = b"(1..10).include?(5); (1...10).cover?(5)"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + println!("Root: {:#?}", result.node()); +} + +fn spike_values() { + let source = b"'x'; 42; 1; :sym"; + let result = parse(source); + 
print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + println!("Root: {:#?}", result.node()); +} + +fn spike_comments() { + let source = b"x = 1 # rubyfast:disable shuffle_first_vs_sample\ny = 2\n"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + for comment in result.comments() { + let loc = comment.location(); + println!( + " Comment at {}..{}: {:?}", + loc.start_offset(), + loc.end_offset(), + std::str::from_utf8(loc.as_slice()).unwrap_or("") + ); + } +} + +fn spike_errors() { + let source = b"def foo; end; def def; end"; + let result = parse(source); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + let error_count: usize = result.errors().count(); + println!("Errors: {}", error_count); + for err in result.errors() { + println!(" Error: {:?}", err.message()); + } + println!("Has AST: true (prism always produces one)"); + println!("Root: {:#?}", result.node()); +} + +fn spike_encoding() { + let source = b"# encoding: us-ascii\nx = 1\n"; + let result = parse(source); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + let error_count: usize = result.errors().count(); + println!("Errors: {}", error_count); + print_errors(&result); + println!("Root: {:#?}", result.node()); +} + +fn spike_block_pass() { + let source = b"arr.map(&:to_s)"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + println!("Root: {:#?}", result.node()); +} + +fn spike_visitor() { + let source = b"arr.select { |x| x > 1 }.first"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + + // Use the Visit trait + struct CallCounter { + count: usize, + } + impl<'pr> Visit<'pr> for CallCounter { + fn visit_call_node(&mut self, node: &CallNode<'pr>) { + println!( + " Found CallNode: name={:?}, has_receiver={}, 
has_block={}", + std::str::from_utf8(node.name().as_slice()).unwrap_or("?"), + node.receiver().is_some(), + node.block().is_some() + ); + self.count += 1; + // Must call the default visitor to recurse into children + visit_call_node(self, node); + } + } + + let mut counter = CallCounter { count: 0 }; + counter.visit(&result.node()); + println!(" Total CallNodes found: {}", counter.count); +} + +fn spike_erb() { + let source = b"class Foo < ActiveRecord::Migration<%= migration_version %>\nend"; + let result = parse(source); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + let error_count: usize = result.errors().count(); + println!("Errors: {}", error_count); + for err in result.errors() { + println!(" Error: {:?}", err.message()); + } +} + +fn spike_pattern_match() { + let source = b"arr.map { |x| x.to_s }"; + let result = parse(source); + print_errors(&result); + println!("Source: {:?}", std::str::from_utf8(source).unwrap()); + + // Walk the top-level statements + let program = result.node(); + if let Node::ProgramNode { .. } = &program { + let prog = program.as_program_node().unwrap(); + let stmts = prog.statements(); + for node in stmts.body().iter() { + println!(" Top-level node variant:"); + match &node { + Node::CallNode { .. 
} => { + let call = node.as_call_node().unwrap(); + println!( + " CallNode: name={:?}, has_block={}", + std::str::from_utf8(call.name().as_slice()).unwrap_or("?"), + call.block().is_some() + ); + if let Some(block) = call.block() { + println!(" Block: {:#?}", block); + } + } + other => println!(" {:?}", other), + } + } + } +} + +fn print_errors(result: &ParseResult) { + for err in result.errors() { + println!(" PARSE ERROR: {:?}", err.message()); + } +} From 4585d4858f295e96daf5b789f39f30a32ca5aaa4 Mon Sep 17 00:00:00 2001 From: Zac <579103+7a6163@users.noreply.github.com> Date: Mon, 16 Mar 2026 16:48:57 +0800 Subject: [PATCH 2/4] feat: migrate parser from lib-ruby-parser to ruby-prism Replace lib-ruby-parser (Ruby 3.1.2) with ruby-prism (Ruby 3.3+), the official Ruby parser with Rust bindings. This brings: - Support for Ruby 3.2, 3.3, 3.4+ syntax - Native handling of all encodings (ASCII, US-ASCII, etc.) - Error-tolerant parsing (always produces an AST) - Future-proof: prism is Ruby's default parser since 3.3 Key changes: - Cargo.toml: replace lib-ruby-parser with ruby-prism 1.9 - ast_helpers.rs: rewrite all helpers for prism node types (CallNode, DefNode, BlockNode, etc.) - ast_visitor.rs: rewrite visitor for prism's Node enum - analyzer.rs: adapt walk_node for Block/Call inversion (prism: CallNode owns BlockNode, not vice versa) - fix.rs: simplify verify_syntax using prism - comment_directives.rs: use prism's Comment iterator - All 4 scanners rewritten for prism node types - Remove custom ASCII encoding decoder (prism handles natively) All 244 tests pass. Zero clippy warnings. 
--- Cargo.lock | 23 - Cargo.toml | 3 +- src/analyzer.rs | 226 +++-- src/ast_helpers.rs | 622 ++++++------ src/ast_visitor.rs | 1144 ++++++++++++---------- src/comment_directives.rs | 21 +- src/file_traverser.rs | 6 +- src/fix.rs | 18 +- src/scanner/for_loop_scanner.rs | 137 ++- src/scanner/method_call_scanner.rs | 498 +++++----- src/scanner/method_definition_scanner.rs | 102 +- src/scanner/rescue_scanner.rs | 58 +- 12 files changed, 1533 insertions(+), 1325 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eb9fd4d..ff1ad06 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,12 +11,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "alloc-from-pool" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bee030c58cf5648ea793d06e5aa6039f913bfbf9f68a0635c76ba429d393fa6c" - [[package]] name = "anstream" version = "0.6.21" @@ -340,22 +334,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" -[[package]] -name = "lib-ruby-parser" -version = "4.0.6+ruby-3.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a64ffd6ab03aa1e6a986b42260202e25fb8d197dd7be0de11088cca389f67ce" -dependencies = [ - "alloc-from-pool", - "lib-ruby-parser-ast", -] - -[[package]] -name = "lib-ruby-parser-ast" -version = "0.55.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "461948472e72b507a0f97e144c453e29b8772e986f18b410e0f2318edd45258c" - [[package]] name = "libc" version = "0.2.182" @@ -530,7 +508,6 @@ dependencies = [ "clap", "colored", "glob", - "lib-ruby-parser", "rayon", "ruby-prism", "serde", diff --git a/Cargo.toml b/Cargo.toml index a64aaf6..bbb67b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ keywords = ["ruby", "linter", "performance", "static-analysis"] categories = ["command-line-utilities", "development-tools"] [dependencies] -lib-ruby-parser = "4.0" 
+ruby-prism = "1.9" clap = { version = "4", features = ["derive"] } serde = { version = "1", features = ["derive"] } serde_yaml = "0.9" @@ -21,4 +21,3 @@ anyhow = "1" [dev-dependencies] tempfile = "3" -ruby-prism = "1.9" diff --git a/src/analyzer.rs b/src/analyzer.rs index 7781d8c..f12d852 100644 --- a/src/analyzer.rs +++ b/src/analyzer.rs @@ -1,9 +1,9 @@ use std::path::Path; -use lib_ruby_parser::{ErrorLevel, Node, Parser}; +use ruby_prism::Node; -use crate::ast_helpers::{byte_offset_to_line, compute_newline_positions, parser_options}; -use crate::ast_visitor::for_each_child; +use crate::ast_helpers::{byte_offset_to_line, compute_newline_positions}; +use crate::ast_visitor::for_each_direct_child; use crate::comment_directives::build_disabled_set; use crate::config::Config; use crate::offense::Offense; @@ -35,49 +35,26 @@ pub fn analyze_file(path: &Path, config: &Config) -> Result>() - .join(", "), - }); - } - // Recovered AST with errors — skip analysis to avoid false positives + // Prism always produces an AST, but if there are errors, skip analysis + // to avoid false positives (matching lib-ruby-parser behavior). 
return Ok(AnalysisResult { path: path.display().to_string(), offenses: vec![], }); } - let ast = match result.ast { - Some(ast) => ast, - None => { - return Ok(AnalysisResult { - path: path.display().to_string(), - offenses: vec![], - }); - } - }; + let root = result.node(); - let disabled_set = build_disabled_set(&result.comments, &source_clone, &newline_positions); + let disabled_set = build_disabled_set(&result, &source, &newline_positions); let mut offenses = Vec::new(); - walk_node(&ast, &mut offenses, &source_clone); + walk_node(&root, &mut offenses, &source); // Resolve byte offsets to line numbers, then filter by config and inline directives let offenses = offenses @@ -101,51 +78,145 @@ pub fn analyze_file(path: &Path, config: &Config) -> Result, source: &[u8]) { +fn walk_node(node: &Node<'_>, offenses: &mut Vec, source: &[u8]) { match node { - Node::For(f) => { - offenses.extend(for_loop_scanner::scan(f, source)); - for_each_child(node, |child| walk_node(child, offenses, source)); - } - Node::RescueBody(rb) => { - offenses.extend(rescue_scanner::scan(rb)); - for_each_child(node, |child| walk_node(child, offenses, source)); + Node::ProgramNode { .. } => { + let prog = node.as_program_node().unwrap(); + for child in prog.statements().body().iter() { + walk_node(&child, offenses, source); + } } - Node::Def(d) => { - offenses.extend(method_definition_scanner::scan(d)); - for_each_child(node, |child| walk_node(child, offenses, source)); + Node::ForNode { .. } => { + let f = node.as_for_node().unwrap(); + offenses.extend(for_loop_scanner::scan(&f, source)); + for_each_direct_child(node, &mut |child| walk_node(child, offenses, source)); } - Node::Send(s) => { - if let Some(Node::Block(recv_block)) = s.recv.as_deref() { - offenses.extend(method_call_scanner::scan_send_on_block(s, recv_block)); + Node::BeginNode { .. 
} => { + let begin = node.as_begin_node().unwrap(); + // Visit statements + if let Some(stmts) = begin.statements() { + for child in stmts.body().iter() { + walk_node(&child, offenses, source); + } } - offenses.extend(method_call_scanner::scan_send(s)); - for_each_child(node, |child| walk_node(child, offenses, source)); - } - Node::Block(b) => { - offenses.extend(method_call_scanner::scan_block(b)); - // Walk children manually to skip the inner Send (avoids double-scanning). - if let Node::Send(s) = b.call.as_ref() { - if let Some(recv) = &s.recv { - walk_node(recv, offenses, source); + // Visit rescue clauses + if let Some(rescue) = begin.rescue_clause() { + walk_rescue_node(&rescue, offenses, source); + } + // Visit else clause + if let Some(else_clause) = begin.else_clause() + && let Some(stmts) = else_clause.statements() + { + for child in stmts.body().iter() { + walk_node(&child, offenses, source); } - for arg in &s.args { - walk_node(arg, offenses, source); + } + // Visit ensure clause + if let Some(ensure) = begin.ensure_clause() + && let Some(stmts) = ensure.statements() + { + for child in stmts.body().iter() { + walk_node(&child, offenses, source); } } - if let Some(args) = &b.args { - walk_node(args, offenses, source); + } + Node::RescueNode { .. } => { + let rn = node.as_rescue_node().unwrap(); + walk_rescue_node(&rn, offenses, source); + } + Node::DefNode { .. } => { + let d = node.as_def_node().unwrap(); + offenses.extend(method_definition_scanner::scan(&d)); + // Walk the body + if let Some(body) = d.body() { + walk_node(&body, offenses, source); + } + } + Node::CallNode { .. } => { + let call = node.as_call_node().unwrap(); + + // Check if receiver is a CallNode with a BlockNode (chained: .select{}.first) + if let Some(recv) = call.receiver() + && let Some(recv_call) = recv.as_call_node() + && let Some(Node::BlockNode { .. 
}) = recv_call.block() + { + offenses.extend(method_call_scanner::scan_call_on_block_call( + &call, &recv_call, + )); } - if let Some(body) = &b.body { - walk_node(body, offenses, source); + + // Check if this call has a block (CallNode owns BlockNode in prism) + match call.block() { + Some(Node::BlockNode { .. }) => { + let block = call.block().unwrap().as_block_node().unwrap(); + offenses.extend(method_call_scanner::scan_call_with_block(&call, &block)); + + // Walk receiver and arguments (skip the block's call — we already scanned it) + if let Some(recv) = call.receiver() { + walk_node(&recv, offenses, source); + } + if let Some(args) = call.arguments() { + for arg in args.arguments().iter() { + walk_node(&arg, offenses, source); + } + } + // Walk block body + if let Some(body) = block.body() { + walk_node(&body, offenses, source); + } + } + _ => { + // No block or block argument — scan as plain send + offenses.extend(method_call_scanner::scan_call(&call)); + // Walk all children + if let Some(recv) = call.receiver() { + walk_node(&recv, offenses, source); + } + if let Some(args) = call.arguments() { + for arg in args.arguments().iter() { + walk_node(&arg, offenses, source); + } + } + if let Some(block) = call.block() { + walk_node(&block, offenses, source); + } + } } } _ => { - for_each_child(node, |child| walk_node(child, offenses, source)); + for_each_direct_child(node, &mut |child| walk_node(child, offenses, source)); } } } +/// Walk a RescueNode and its chain of subsequent rescue clauses. 
+fn walk_rescue_node( + rescue: &ruby_prism::RescueNode<'_>, + offenses: &mut Vec, + source: &[u8], +) { + offenses.extend(rescue_scanner::scan(rescue)); + + // Walk exception list + for exc in rescue.exceptions().iter() { + walk_node(&exc, offenses, source); + } + // Walk reference + if let Some(reference) = rescue.reference() { + walk_node(&reference, offenses, source); + } + // Walk statements + if let Some(stmts) = rescue.statements() { + for child in stmts.body().iter() { + walk_node(&child, offenses, source); + } + } + // Walk subsequent rescue clauses + if let Some(subsequent) = rescue.subsequent() { + walk_rescue_node(&subsequent, offenses, source); + } +} + #[cfg(test)] mod tests { use crate::ast_helpers::{byte_offset_to_line, compute_newline_positions}; @@ -168,29 +239,13 @@ mod tests { } #[test] - fn analyze_file_with_parse_errors_no_ast_returns_error() { + fn analyze_file_with_parse_errors_returns_empty() { let dir = tempfile::TempDir::new().unwrap(); - // This produces a fatal parse error with no recoverable AST let file = dir.path().join("fatal.rb"); - std::fs::write(&file, "\x00\x01\x02").unwrap(); + std::fs::write(&file, "def def def").unwrap(); let config = crate::config::Config::default(); - let result = super::analyze_file(&file, &config); - // May be Ok with empty offenses or Err depending on parser behavior - // Either way it should not panic - let _ = result; - } - - #[test] - fn analyze_file_with_recovered_ast_returns_empty() { - let dir = tempfile::TempDir::new().unwrap(); - let file = dir.path().join("recovered.rb"); - std::fs::write(&file, "def foo; end; def def; end").unwrap(); - let config = crate::config::Config::default(); - let result = super::analyze_file(&file, &config); - match result { - Ok(analysis) => assert!(analysis.offenses.is_empty()), - Err(_) => {} // Also acceptable — fatal parse error - } + let result = super::analyze_file(&file, &config).unwrap(); + assert!(result.offenses.is_empty()); } #[test] @@ -229,8 +284,7 @@ mod 
tests { } #[test] - fn walk_node_block_with_non_send_call() { - // A numblock (numbered params) has a different call structure + fn walk_node_block_with_symbol_to_proc() { let dir = tempfile::TempDir::new().unwrap(); let file = dir.path().join("test.rb"); std::fs::write(&file, "arr.map { |x| x.to_s }").unwrap(); diff --git a/src/ast_helpers.rs b/src/ast_helpers.rs index 6f11b91..4f02912 100644 --- a/src/ast_helpers.rs +++ b/src/ast_helpers.rs @@ -1,7 +1,4 @@ -use lib_ruby_parser::Node; -use lib_ruby_parser::ParserOptions; -use lib_ruby_parser::nodes::{Block, Def, Send}; -use lib_ruby_parser::source::{Decoder, DecoderResult}; +use ruby_prism::Node; /// Convert a byte offset to a 1-based line number using pre-computed newline positions. pub fn byte_offset_to_line(newline_positions: &[usize], byte_offset: usize) -> usize { @@ -11,193 +8,252 @@ pub fn byte_offset_to_line(newline_positions: &[usize], byte_offset: usize) -> u } } -/// Check if a Send node's receiver is itself a Send with a given method name. -pub fn receiver_is_send_with_name(recv: &Option>, name: &str) -> bool { - match recv.as_deref() { - Some(Node::Send(s)) => s.method_name == name, - _ => false, +/// Check if a node is a CallNode with the given method name. +pub fn receiver_is_call_with_name(recv: &Option>, name: &[u8]) -> bool { + match recv { + Some(node) => { + if let Some(call) = node.as_call_node() { + call.name().as_slice() == name + } else { + false + } + } + None => false, } } -/// Extract the inner Send from a receiver, if it is one. -pub fn receiver_as_send(recv: &Option>) -> Option<&Send> { - match recv.as_deref() { - Some(Node::Send(s)) => Some(s), - _ => None, +/// Extract the inner CallNode from a receiver, if it is one. +pub fn receiver_as_call<'pr>(recv: &'pr Option>) -> Option> { + match recv { + Some(node) => node.as_call_node(), + None => None, } } -/// Extract the Send from a Block's call field. 
-pub fn block_call_as_send(block: &Block) -> Option<&Send> { - match block.call.as_ref() { - Node::Send(s) => Some(s), - _ => None, - } +/// Check if a CallNode has a BlockArgumentNode in its block field. +pub fn has_block_pass(call: &ruby_prism::CallNode<'_>) -> bool { + matches!(call.block(), Some(Node::BlockArgumentNode { .. })) } -/// Check if an argument list contains a BlockPass (e.g., `&:foo`). -pub fn has_block_pass(args: &[Node]) -> bool { - args.iter().any(|a| matches!(a, Node::BlockPass(_))) +/// Check if a CallNode has a full BlockNode (not just a BlockArgumentNode). +pub fn has_full_block(call: &ruby_prism::CallNode<'_>) -> bool { + matches!(call.block(), Some(Node::BlockNode { .. })) } -/// Count non-BlockPass arguments. -pub fn arg_count_without_block_pass(args: &[Node]) -> usize { - args.iter() - .filter(|a| !matches!(a, Node::BlockPass(_))) - .count() +/// Count arguments from a CallNode's arguments (excluding block argument which is in block field). +pub fn arg_count(call: &ruby_prism::CallNode<'_>) -> usize { + match call.arguments() { + Some(args) => args.arguments().iter().count(), + None => 0, + } +} + +/// Get arguments as a collected Vec from a CallNode. +pub fn call_args<'pr>(call: &ruby_prism::CallNode<'pr>) -> Vec> { + match call.arguments() { + Some(args) => args.arguments().iter().collect(), + None => vec![], + } } /// Check if a node is a single-character string literal. -pub fn is_single_char_string(node: &Node) -> bool { - match node { - Node::Str(s) => s.value.to_string_lossy().chars().count() == 1, - _ => false, +pub fn is_single_char_string(node: &Node<'_>) -> bool { + match node.as_string_node() { + Some(s) => s.unescaped().len() == 1, + None => false, } } -/// Check if the receiver is a range (Irange or Erange). -/// Also handles parenthesized ranges: `(1..10)` parses as `Begin(Irange(...))`. 
-pub fn receiver_is_range(recv: &Option>) -> bool { - match recv.as_deref() { - Some(Node::Irange(_) | Node::Erange(_)) => true, - Some(Node::Begin(b)) => { - b.statements.len() == 1 && matches!(b.statements[0], Node::Irange(_) | Node::Erange(_)) +/// Check if the receiver is a range (RangeNode, inclusive or exclusive). +/// Also handles parenthesized ranges: `(1..10)` parses as `ParenthesesNode(RangeNode)`. +pub fn receiver_is_range(recv: &Option>) -> bool { + match recv { + Some(node) => { + if node.as_range_node().is_some() { + return true; + } + if let Some(paren) = node.as_parentheses_node() + && let Some(body) = paren.body() + { + if let Some(stmts) = body.as_statements_node() { + let body_nodes: Vec<_> = stmts.body().iter().collect(); + return body_nodes.len() == 1 && body_nodes[0].as_range_node().is_some(); + } + return body.as_range_node().is_some(); + } + false } - _ => false, + None => false, } } /// Check if a node is a literal/primitive (not a variable reference or method call). -pub fn is_primitive(node: &Node) -> bool { +pub fn is_primitive(node: &Node<'_>) -> bool { matches!( node, - Node::Int(_) - | Node::Float(_) - | Node::Str(_) - | Node::Sym(_) - | Node::True(_) - | Node::False(_) - | Node::Nil(_) - | Node::Array(_) - | Node::Hash(_) - | Node::Irange(_) - | Node::Erange(_) - | Node::Rational(_) - | Node::Complex(_) + Node::IntegerNode { .. } + | Node::FloatNode { .. } + | Node::StringNode { .. } + | Node::SymbolNode { .. } + | Node::TrueNode { .. } + | Node::FalseNode { .. } + | Node::NilNode { .. } + | Node::ArrayNode { .. } + | Node::HashNode { .. } + | Node::RangeNode { .. } + | Node::RationalNode { .. } + | Node::ImaginaryNode { .. } ) } -/// Check if the first argument to a Send is a Hash/Kwargs node with exactly one key-value pair. -/// `h.merge!(item: 1)` parses as Kwargs, `h.merge!({item: 1})` parses as Hash. 
-pub fn first_arg_is_single_pair_hash(args: &[Node]) -> bool { +/// Check if the first argument is a Hash/KeywordHash node with exactly one key-value pair. +/// `h.merge!(item: 1)` parses as KeywordHashNode, `h.merge!({item: 1})` parses as HashNode. +pub fn first_arg_is_single_pair_hash(args: &[Node<'_>]) -> bool { match args.first() { - Some(Node::Hash(h)) => h.pairs.len() == 1, - Some(Node::Kwargs(k)) => k.pairs.len() == 1, - _ => false, - } -} - -/// Check if a node is an Int with value 1. -pub fn is_int_one(node: &Node) -> bool { - match node { - Node::Int(i) => i.value == "1", - _ => false, + Some(node) => { + if let Some(h) = node.as_hash_node() { + return h.elements().iter().count() == 1; + } + if let Some(k) = node.as_keyword_hash_node() { + return k.elements().iter().count() == 1; + } + false + } + None => false, } } -/// Get block argument names from Args node. -pub fn block_arg_names(args: &Option>) -> Vec { - match args.as_deref() { - Some(Node::Args(a)) => a - .args - .iter() - .filter_map(|arg| match arg { - Node::Arg(a) => Some(a.name.clone()), - Node::Procarg0(p) => match p.args.as_slice() { - [Node::Arg(a)] => Some(a.name.clone()), - _ => None, - }, - _ => None, - }) - .collect(), - _ => Vec::new(), +/// Check if a node is an IntegerNode with value 1. +pub fn is_int_one(node: &Node<'_>) -> bool { + if let Some(i) = node.as_integer_node() { + // Check the location text as a reliable way to get the value + let text = i.location().as_slice(); + text == b"1" + } else { + false } } -/// Check if a Def node has a block argument (&block), returning its name if so. -pub fn def_block_arg_name(def: &Def) -> Option { - let args_node = def.args.as_deref()?; - if let Node::Args(args) = args_node { - for arg in &args.args { - if let Node::Blockarg(ba) = arg { - return ba.name.clone(); +/// Get block argument names from a BlockNode's parameters. +/// BlockNode.parameters() returns Option which is typically a BlockParametersNode. 
+pub fn block_arg_names(params: &Option>) -> Vec { + match params { + Some(node) => { + if let Some(block_params) = node.as_block_parameters_node() + && let Some(inner_params) = block_params.parameters() + { + return inner_params + .requireds() + .iter() + .filter_map(|p| { + p.as_required_parameter_node() + .map(|rp| String::from_utf8_lossy(rp.name().as_slice()).to_string()) + }) + .collect(); } + // Handle NumberedParametersNode or other cases + Vec::new() } + None => Vec::new(), } - None } -/// Count regular (non-optional, non-keyword, non-rest, non-block) arguments in a Def. -pub fn def_regular_arg_count(def: &Def) -> usize { - match def.args.as_deref() { - Some(Node::Args(args)) => args - .args - .iter() - .filter(|a| matches!(a, Node::Arg(_))) - .count(), - _ => 0, - } +/// Check if a DefNode has a block argument (&block), returning its name if so. +pub fn def_block_arg_name(def: &ruby_prism::DefNode<'_>) -> Option { + let params = def.parameters()?; + let block_param = params.block()?; + let name = block_param.name()?; + Some(String::from_utf8_lossy(name.as_slice()).to_string()) } -/// Get the first regular argument name from a Def. -pub fn def_first_arg_name(def: &Def) -> Option { - match def.args.as_deref() { - Some(Node::Args(args)) => args.args.iter().find_map(|a| match a { - Node::Arg(arg) => Some(arg.name.clone()), - _ => None, - }), - _ => None, +/// Count regular (required) arguments in a DefNode. +pub fn def_regular_arg_count(def: &ruby_prism::DefNode<'_>) -> usize { + match def.parameters() { + Some(params) => params.requireds().iter().count(), + None => 0, } } +/// Get the first regular argument name from a DefNode. +pub fn def_first_arg_name(def: &ruby_prism::DefNode<'_>) -> Option { + let params = def.parameters()?; + let first = params.requireds().iter().next()?; + first + .as_required_parameter_node() + .map(|rp| String::from_utf8_lossy(rp.name().as_slice()).to_string()) +} + /// Check if a string literal contains "def". 
-pub fn str_contains_def(node: &Node) -> bool { - match node { - Node::Str(s) => s.value.to_string_lossy().contains("def"), - Node::Heredoc(h) => h.parts.iter().any(|part| match part { - Node::Str(s) => s.value.to_string_lossy().contains("def"), - _ => false, - }), - _ => false, +pub fn str_contains_def(node: &Node<'_>) -> bool { + if let Some(s) = node.as_string_node() { + return String::from_utf8_lossy(s.unescaped()).contains("def"); + } + if let Some(interp) = node.as_interpolated_string_node() { + return interp.parts().iter().any(|part| { + if let Some(s) = part.as_string_node() { + String::from_utf8_lossy(s.unescaped()).contains("def") + } else { + false + } + }); } + false } -/// Get expressions from a body node. If it's a Begin, return its statements. -/// Otherwise return a single-element slice-like iterator. -pub fn body_expressions(body: &Option>) -> Vec<&Node> { - match body.as_deref() { - None => vec![], - Some(Node::Begin(b)) => b.statements.iter().collect(), - Some(node) => vec![node], +/// Get the number of top-level expressions in a body node, and optionally the single expression. +/// Returns (count, Option) — the option is Some only when count == 1. +pub fn body_single_expression<'pr>(body: &Option>) -> (usize, Option>) { + match body { + None => (0, None), + Some(node) => { + if let Some(stmts) = node.as_statements_node() { + let body_nodes: Vec<_> = stmts.body().iter().collect(); + let count = body_nodes.len(); + if count == 1 { + (1, Some(body_nodes.into_iter().next().unwrap())) + } else { + (count, None) + } + } else { + // Single expression body (no StatementsNode wrapper). + // We need to return the node itself. Since prism's Node is just + // a thin wrapper around pointers, we can reconstruct it from the body. + // Re-access from the parent to get an owned Node. + (1, body.as_ref().map(|n| reconstruct_node_from_body(n))) + } + } } } -/// Build ParserOptions with a custom decoder that handles ASCII/US-ASCII encodings. 
-/// The lib_ruby_parser only supports UTF-8 and ASCII-8BIT out of the box. -/// Since ASCII is a subset of UTF-8, we pass the bytes through unchanged. -pub fn parser_options() -> ParserOptions { - let decoder = Decoder::new(Box::new(|encoding: String, input: Vec| match encoding - .to_uppercase() - .as_str() - { - "ASCII" | "US-ASCII" => DecoderResult::Ok(input), - _ => DecoderResult::Err(lib_ruby_parser::source::InputError::UnsupportedEncoding( - encoding, - )), - })); - ParserOptions { - decoder: Some(decoder), - ..Default::default() +/// Helper: given a reference to a Node, produce a new owned Node with the same data. +/// This works because ruby_prism Node variants are just (parser, pointer, marker) tuples +/// and the data is borrowed from the parse result, not owned. +fn reconstruct_node_from_body<'pr>(node: &Node<'pr>) -> Node<'pr> { + // The body_expressions approach doesn't work because Node isn't Clone. + // Instead, callers should re-call body() to get a fresh owned Node. + // This function exists as a workaround: since all callers already have + // the def/body, they can just re-call .body() to get an owned Node. + // + // For now, we'll use unsafe to transmute since the Node is just pointers. + // Safety: Node<'pr> is a repr(C)-like enum of (parser, pointer, marker) + // and the lifetime is tied to the ParseResult. Copying the pointer data is safe + // as long as the ParseResult outlives the copy. + unsafe { std::ptr::read(node as *const Node<'pr>) } +} + +/// Get expressions from a body node. If it's a StatementsNode, return its body items. +/// Otherwise return a single-element vec. +/// IMPORTANT: Caller must ensure `body` outlives the returned Vec. 
+pub fn body_expressions<'pr>(body: &Option>) -> Vec> { + match body { + None => vec![], + Some(node) => { + if let Some(stmts) = node.as_statements_node() { + stmts.body().iter().collect() + } else { + vec![reconstruct_node_from_body(node)] + } + } } } @@ -214,12 +270,15 @@ pub fn compute_newline_positions(source: &[u8]) -> Vec { #[cfg(test)] mod tests { use super::*; - use lib_ruby_parser::Parser; - fn parse(source: &[u8]) -> Option> { - Parser::new(source.to_vec(), Default::default()) - .do_parse() - .ast + fn parse_first_stmt(source: &'static [u8]) -> Node<'static> { + // Leak the parse result to get a 'static lifetime for tests + let result = ruby_prism::parse(source); + let result = Box::leak(Box::new(result)); + let program = result.node(); + let prog = program.as_program_node().unwrap(); + let stmts: Vec<_> = prog.statements().body().iter().collect(); + stmts.into_iter().next().unwrap() } #[test] @@ -238,165 +297,153 @@ mod tests { } #[test] - fn receiver_is_send_with_name_works() { - let ast = parse(b"a.foo.bar").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(receiver_is_send_with_name(&s.recv, "foo")); - assert!(!receiver_is_send_with_name(&s.recv, "baz")); - } + fn receiver_is_call_with_name_works() { + let node = parse_first_stmt(b"a.foo.bar"); + let call = node.as_call_node().unwrap(); + assert!(receiver_is_call_with_name(&call.receiver(), b"foo")); + assert!(!receiver_is_call_with_name(&call.receiver(), b"baz")); } #[test] - fn receiver_is_send_with_name_none() { - assert!(!receiver_is_send_with_name(&None, "foo")); + fn receiver_is_call_with_name_none() { + assert!(!receiver_is_call_with_name(&None, b"foo")); } #[test] - fn receiver_as_send_works() { - let ast = parse(b"a.foo.bar").unwrap(); - if let Node::Send(s) = ast.as_ref() { - let inner = receiver_as_send(&s.recv).unwrap(); - assert_eq!(inner.method_name, "foo"); - } + fn receiver_as_call_works() { + let node = parse_first_stmt(b"a.foo.bar"); + let call = 
node.as_call_node().unwrap(); + let recv = call.receiver(); + let inner = receiver_as_call(&recv).unwrap(); + assert_eq!(inner.name().as_slice(), b"foo"); } #[test] - fn receiver_as_send_not_send() { - assert!(receiver_as_send(&None).is_none()); - } - - #[test] - fn block_call_as_send_works() { - let ast = parse(b"arr.map { |x| x }").unwrap(); - if let Node::Block(b) = ast.as_ref() { - let send = block_call_as_send(b).unwrap(); - assert_eq!(send.method_name, "map"); - } + fn receiver_as_call_not_call() { + assert!(receiver_as_call(&None).is_none()); } #[test] fn has_block_pass_works() { - let ast = parse(b"arr.map(&:to_s)").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(has_block_pass(&s.args)); - } + let node = parse_first_stmt(b"arr.map(&:to_s)"); + let call = node.as_call_node().unwrap(); + assert!(has_block_pass(&call)); } #[test] fn has_block_pass_without() { - let ast = parse(b"arr.map(1)").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(!has_block_pass(&s.args)); - } + let node = parse_first_stmt(b"arr.map(1)"); + let call = node.as_call_node().unwrap(); + assert!(!has_block_pass(&call)); } #[test] - fn arg_count_without_block_pass_works() { - let ast = parse(b"arr.select(&:odd?).first").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert_eq!(arg_count_without_block_pass(&s.args), 0); - } + fn arg_count_works() { + let node = parse_first_stmt(b"arr.select(1, 2)"); + let call = node.as_call_node().unwrap(); + assert_eq!(arg_count(&call), 2); } #[test] fn is_single_char_string_works() { - let ast = parse(b"'x'").unwrap(); - assert!(is_single_char_string(&ast)); - let ast2 = parse(b"'xy'").unwrap(); - assert!(!is_single_char_string(&ast2)); + let node = parse_first_stmt(b"'x'"); + assert!(is_single_char_string(&node)); + let node2 = parse_first_stmt(b"'xy'"); + assert!(!is_single_char_string(&node2)); } #[test] fn is_single_char_string_not_string() { - let ast = parse(b"42").unwrap(); - assert!(!is_single_char_string(&ast)); + 
let node = parse_first_stmt(b"42"); + assert!(!is_single_char_string(&node)); } #[test] - fn receiver_is_range_irange() { - let ast = parse(b"(1..10).include?(5)").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(receiver_is_range(&s.recv)); - } + fn receiver_is_range_inclusive() { + let node = parse_first_stmt(b"(1..10).include?(5)"); + let call = node.as_call_node().unwrap(); + assert!(receiver_is_range(&call.receiver())); } #[test] - fn receiver_is_range_erange() { - let ast = parse(b"(1...10).include?(5)").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(receiver_is_range(&s.recv)); - } + fn receiver_is_range_exclusive() { + let node = parse_first_stmt(b"(1...10).include?(5)"); + let call = node.as_call_node().unwrap(); + assert!(receiver_is_range(&call.receiver())); } #[test] fn receiver_is_range_not_range() { - let ast = parse(b"[1].include?(5)").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(!receiver_is_range(&s.recv)); - } + let node = parse_first_stmt(b"[1].include?(5)"); + let call = node.as_call_node().unwrap(); + assert!(!receiver_is_range(&call.receiver())); } #[test] fn is_primitive_covers_types() { - assert!(is_primitive(&parse(b"42").unwrap())); - assert!(is_primitive(&parse(b"3.14").unwrap())); - assert!(is_primitive(&parse(b"'s'").unwrap())); - assert!(is_primitive(&parse(b":sym").unwrap())); - assert!(is_primitive(&parse(b"true").unwrap())); - assert!(is_primitive(&parse(b"false").unwrap())); - assert!(is_primitive(&parse(b"nil").unwrap())); - assert!(is_primitive(&parse(b"[]").unwrap())); - assert!(is_primitive(&parse(b"{}").unwrap())); - assert!(is_primitive(&parse(b"1..5").unwrap())); - assert!(is_primitive(&parse(b"1...5").unwrap())); - assert!(!is_primitive(&parse(b"x").unwrap())); + assert!(is_primitive(&parse_first_stmt(b"42"))); + assert!(is_primitive(&parse_first_stmt(b"3.14"))); + assert!(is_primitive(&parse_first_stmt(b"'s'"))); + assert!(is_primitive(&parse_first_stmt(b":sym"))); + 
assert!(is_primitive(&parse_first_stmt(b"true"))); + assert!(is_primitive(&parse_first_stmt(b"false"))); + assert!(is_primitive(&parse_first_stmt(b"nil"))); + assert!(is_primitive(&parse_first_stmt(b"[]"))); + assert!(is_primitive(&parse_first_stmt(b"{}"))); + assert!(is_primitive(&parse_first_stmt(b"1..5"))); + assert!(is_primitive(&parse_first_stmt(b"1...5"))); + assert!(!is_primitive(&parse_first_stmt(b"x"))); } #[test] fn first_arg_is_single_pair_hash_kwargs() { - let ast = parse(b"h.merge!(a: 1)").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(first_arg_is_single_pair_hash(&s.args)); - } + let node = parse_first_stmt(b"h.merge!(a: 1)"); + let call = node.as_call_node().unwrap(); + let args = call_args(&call); + assert!(first_arg_is_single_pair_hash(&args)); } #[test] fn first_arg_is_single_pair_hash_explicit() { - let ast = parse(b"h.merge!({a: 1})").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(first_arg_is_single_pair_hash(&s.args)); - } + let node = parse_first_stmt(b"h.merge!({a: 1})"); + let call = node.as_call_node().unwrap(); + let args = call_args(&call); + assert!(first_arg_is_single_pair_hash(&args)); } #[test] fn first_arg_is_single_pair_hash_multi() { - let ast = parse(b"h.merge!(a: 1, b: 2)").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(!first_arg_is_single_pair_hash(&s.args)); - } + let node = parse_first_stmt(b"h.merge!(a: 1, b: 2)"); + let call = node.as_call_node().unwrap(); + let args = call_args(&call); + assert!(!first_arg_is_single_pair_hash(&args)); } #[test] fn first_arg_is_single_pair_hash_not_hash() { - let ast = parse(b"h.merge!(x)").unwrap(); - if let Node::Send(s) = ast.as_ref() { - assert!(!first_arg_is_single_pair_hash(&s.args)); - } + let node = parse_first_stmt(b"h.merge!(x)"); + let call = node.as_call_node().unwrap(); + let args = call_args(&call); + assert!(!first_arg_is_single_pair_hash(&args)); } #[test] fn is_int_one_works() { - assert!(is_int_one(&parse(b"1").unwrap())); - 
assert!(!is_int_one(&parse(b"2").unwrap())); - assert!(!is_int_one(&parse(b"'1'").unwrap())); + assert!(is_int_one(&parse_first_stmt(b"1"))); + assert!(!is_int_one(&parse_first_stmt(b"2"))); + assert!(!is_int_one(&parse_first_stmt(b"'1'"))); } #[test] fn block_arg_names_single() { - let ast = parse(b"arr.map { |x| x }").unwrap(); - if let Node::Block(b) = ast.as_ref() { - let names = block_arg_names(&b.args); + let node = parse_first_stmt(b"arr.map { |x| x }"); + let call = node.as_call_node().unwrap(); + if let Some(Node::BlockNode { .. }) = call.block() { + let block = call.block().unwrap().as_block_node().unwrap(); + let names = block_arg_names(&block.parameters()); assert_eq!(names, vec!["x".to_string()]); + } else { + panic!("Expected BlockNode"); } } @@ -408,74 +455,68 @@ mod tests { #[test] fn def_block_arg_name_present() { - let ast = parse(b"def foo(&block); end").unwrap(); - if let Node::Def(d) = ast.as_ref() { - assert_eq!(def_block_arg_name(d), Some("block".to_string())); - } + let node = parse_first_stmt(b"def foo(&block); end"); + let def = node.as_def_node().unwrap(); + assert_eq!(def_block_arg_name(&def), Some("block".to_string())); } #[test] fn def_block_arg_name_absent() { - let ast = parse(b"def foo(x); end").unwrap(); - if let Node::Def(d) = ast.as_ref() { - assert_eq!(def_block_arg_name(d), None); - } + let node = parse_first_stmt(b"def foo(x); end"); + let def = node.as_def_node().unwrap(); + assert_eq!(def_block_arg_name(&def), None); } #[test] fn def_regular_arg_count_works() { - let ast = parse(b"def foo(a, b); end").unwrap(); - if let Node::Def(d) = ast.as_ref() { - assert_eq!(def_regular_arg_count(d), 2); - } + let node = parse_first_stmt(b"def foo(a, b); end"); + let def = node.as_def_node().unwrap(); + assert_eq!(def_regular_arg_count(&def), 2); } #[test] fn def_regular_arg_count_no_args() { - let ast = parse(b"def foo; end").unwrap(); - if let Node::Def(d) = ast.as_ref() { - assert_eq!(def_regular_arg_count(d), 0); - } + let node = 
parse_first_stmt(b"def foo; end"); + let def = node.as_def_node().unwrap(); + assert_eq!(def_regular_arg_count(&def), 0); } #[test] fn def_first_arg_name_works() { - let ast = parse(b"def foo(bar); end").unwrap(); - if let Node::Def(d) = ast.as_ref() { - assert_eq!(def_first_arg_name(d), Some("bar".to_string())); - } + let node = parse_first_stmt(b"def foo(bar); end"); + let def = node.as_def_node().unwrap(); + assert_eq!(def_first_arg_name(&def), Some("bar".to_string())); } #[test] fn def_first_arg_name_no_args() { - let ast = parse(b"def foo; end").unwrap(); - if let Node::Def(d) = ast.as_ref() { - assert_eq!(def_first_arg_name(d), None); - } + let node = parse_first_stmt(b"def foo; end"); + let def = node.as_def_node().unwrap(); + assert_eq!(def_first_arg_name(&def), None); } #[test] fn str_contains_def_in_string() { - let ast = parse(b"\"def foo\"").unwrap(); - assert!(str_contains_def(&ast)); + let node = parse_first_stmt(b"\"def foo\""); + assert!(str_contains_def(&node)); } #[test] fn str_contains_def_no_def() { - let ast = parse(b"\"hello\"").unwrap(); - assert!(!str_contains_def(&ast)); + let node = parse_first_stmt(b"\"hello\""); + assert!(!str_contains_def(&node)); } #[test] fn str_contains_def_not_string() { - let ast = parse(b"42").unwrap(); - assert!(!str_contains_def(&ast)); + let node = parse_first_stmt(b"42"); + assert!(!str_contains_def(&node)); } #[test] fn str_contains_def_heredoc() { - let ast = parse(b"<<~RUBY\ndef foo\nRUBY\n").unwrap(); - assert!(str_contains_def(&ast)); + let node = parse_first_stmt(b"<<~RUBY\ndef foo\nRUBY\n"); + assert!(str_contains_def(&node)); } #[test] @@ -485,20 +526,18 @@ mod tests { #[test] fn body_expressions_single() { - let ast = parse(b"def foo; 42; end").unwrap(); - if let Node::Def(d) = ast.as_ref() { - let exprs = body_expressions(&d.body); - assert_eq!(exprs.len(), 1); - } + let node = parse_first_stmt(b"def foo; 42; end"); + let def = node.as_def_node().unwrap(); + let exprs = body_expressions(&def.body()); 
+ assert_eq!(exprs.len(), 1); } #[test] fn body_expressions_begin() { - let ast = parse(b"def foo; 1; 2; 3; end").unwrap(); - if let Node::Def(d) = ast.as_ref() { - let exprs = body_expressions(&d.body); - assert_eq!(exprs.len(), 3); - } + let node = parse_first_stmt(b"def foo; 1; 2; 3; end"); + let def = node.as_def_node().unwrap(); + let exprs = body_expressions(&def.body()); + assert_eq!(exprs.len(), 3); } #[test] @@ -519,27 +558,16 @@ mod tests { } #[test] - fn parser_options_handles_ascii_encoding() { + fn prism_handles_ascii_encoding() { let source = b"# encoding: ASCII\nx = 1\n"; - let result = Parser::new(source.to_vec(), parser_options()).do_parse(); - assert!(result.ast.is_some()); + let result = ruby_prism::parse(source); + assert!(result.errors().next().is_none()); } #[test] - fn parser_options_handles_us_ascii_encoding() { + fn prism_handles_us_ascii_encoding() { let source = b"# encoding: us-ascii\nx = 1\n"; - let result = Parser::new(source.to_vec(), parser_options()).do_parse(); - assert!(result.ast.is_some()); - } - - #[test] - fn parser_options_rejects_unknown_encoding() { - let source = b"# encoding: SHIFT_JIS\nx = 1\n"; - let result = Parser::new(source.to_vec(), parser_options()).do_parse(); - let has_encoding_error = result - .diagnostics - .iter() - .any(|d| format!("{:?}", d.message).contains("UnsupportedEncoding")); - assert!(has_encoding_error); + let result = ruby_prism::parse(source); + assert!(result.errors().next().is_none()); } } diff --git a/src/ast_visitor.rs b/src/ast_visitor.rs index a527f65..7011e94 100644 --- a/src/ast_visitor.rs +++ b/src/ast_visitor.rs @@ -1,583 +1,695 @@ -use lib_ruby_parser::Node; -#[allow(unused_imports)] -use lib_ruby_parser::nodes::*; +use ruby_prism::Node; -/// Collect all direct child nodes of a given node. -/// Prefer `for_each_child` in hot paths to avoid Vec allocation. 
-pub fn node_children(node: &Node) -> Vec<&Node> { - let mut children = Vec::new(); - for_each_child(node, |child| children.push(child)); - children +/// Recursively visit all descendant nodes of a given node, calling `f` for each. +/// This is used by scanners that need to search for patterns inside a subtree. +pub fn for_each_descendant<'pr>(node: &Node<'pr>, f: &mut impl FnMut(&Node<'pr>)) { + for_each_direct_child(node, &mut |child: &Node<'pr>| { + f(child); + for_each_descendant(child, f); + }); } -/// Visit each direct child of a node via callback — zero allocation. -#[inline] -pub fn for_each_child<'a>(node: &'a Node, mut f: impl FnMut(&'a Node)) { - visit_children(node, &mut f); -} - -fn visit_opt<'a>(opt: &'a Option>, f: &mut impl FnMut(&'a Node)) { - if let Some(n) = opt.as_deref() { - f(n); +/// Iterate over direct children of a node, calling f for each. +/// This is the core traversal function for ruby-prism nodes. +/// +/// Note: Some prism accessors return specific types (ElseNode, EnsureNode, etc.) +/// rather than Node. For those, we visit their inner statements directly. +pub fn for_each_direct_child<'pr>(node: &Node<'pr>, f: &mut impl FnMut(&Node<'pr>)) { + match node { + Node::ProgramNode { .. } => { + let n = node.as_program_node().unwrap(); + for child in n.statements().body().iter() { + f(&child); + } + } + Node::StatementsNode { .. } => { + let n = node.as_statements_node().unwrap(); + for child in n.body().iter() { + f(&child); + } + } + Node::CallNode { .. } => { + let n = node.as_call_node().unwrap(); + if let Some(recv) = n.receiver() { + f(&recv); + } + if let Some(args) = n.arguments() { + for arg in args.arguments().iter() { + f(&arg); + } + } + if let Some(block) = n.block() { + f(&block); + } + } + Node::BlockNode { .. } => { + let n = node.as_block_node().unwrap(); + if let Some(params) = n.parameters() { + f(¶ms); + } + if let Some(body) = n.body() { + f(&body); + } + } + Node::BlockArgumentNode { .. 
} => { + let n = node.as_block_argument_node().unwrap(); + if let Some(expr) = n.expression() { + f(&expr); + } + } + Node::DefNode { .. } => { + let n = node.as_def_node().unwrap(); + // Skip parameters - they don't contain scannable code + if let Some(body) = n.body() { + f(&body); + } + } + Node::ForNode { .. } => { + let n = node.as_for_node().unwrap(); + f(&n.index()); + f(&n.collection()); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::BeginNode { .. } => { + let n = node.as_begin_node().unwrap(); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + if let Some(rescue) = n.rescue_clause() { + visit_rescue_chain_children(&rescue, f); + } + if let Some(else_clause) = n.else_clause() + && let Some(stmts) = else_clause.statements() + { + for child in stmts.body().iter() { + f(&child); + } + } + if let Some(ensure) = n.ensure_clause() + && let Some(stmts) = ensure.statements() + { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::RescueNode { .. } => { + let n = node.as_rescue_node().unwrap(); + visit_rescue_children(&n, f); + } + Node::EnsureNode { .. } => { + let n = node.as_ensure_node().unwrap(); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::ElseNode { .. } => { + let n = node.as_else_node().unwrap(); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::IfNode { .. } => { + let n = node.as_if_node().unwrap(); + f(&n.predicate()); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + if let Some(subsequent) = n.subsequent() { + f(&subsequent); + } + } + Node::UnlessNode { .. 
} => { + let n = node.as_unless_node().unwrap(); + f(&n.predicate()); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + if let Some(else_clause) = n.else_clause() + && let Some(stmts) = else_clause.statements() + { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::WhileNode { .. } => { + let n = node.as_while_node().unwrap(); + f(&n.predicate()); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::UntilNode { .. } => { + let n = node.as_until_node().unwrap(); + f(&n.predicate()); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::CaseNode { .. } => { + let n = node.as_case_node().unwrap(); + if let Some(pred) = n.predicate() { + f(&pred); + } + for condition in n.conditions().iter() { + f(&condition); + } + if let Some(else_clause) = n.else_clause() + && let Some(stmts) = else_clause.statements() + { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::WhenNode { .. } => { + let n = node.as_when_node().unwrap(); + for cond in n.conditions().iter() { + f(&cond); + } + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::ClassNode { .. } => { + let n = node.as_class_node().unwrap(); + f(&n.constant_path()); + if let Some(superclass) = n.superclass() { + f(&superclass); + } + if let Some(body) = n.body() { + f(&body); + } + } + Node::ModuleNode { .. } => { + let n = node.as_module_node().unwrap(); + f(&n.constant_path()); + if let Some(body) = n.body() { + f(&body); + } + } + Node::SingletonClassNode { .. } => { + let n = node.as_singleton_class_node().unwrap(); + f(&n.expression()); + if let Some(body) = n.body() { + f(&body); + } + } + Node::AndNode { .. } => { + let n = node.as_and_node().unwrap(); + f(&n.left()); + f(&n.right()); + } + Node::OrNode { .. 
} => { + let n = node.as_or_node().unwrap(); + f(&n.left()); + f(&n.right()); + } + Node::ArrayNode { .. } => { + let n = node.as_array_node().unwrap(); + for elem in n.elements().iter() { + f(&elem); + } + } + Node::HashNode { .. } => { + let n = node.as_hash_node().unwrap(); + for elem in n.elements().iter() { + f(&elem); + } + } + Node::KeywordHashNode { .. } => { + let n = node.as_keyword_hash_node().unwrap(); + for elem in n.elements().iter() { + f(&elem); + } + } + Node::AssocNode { .. } => { + let n = node.as_assoc_node().unwrap(); + f(&n.key()); + f(&n.value()); + } + Node::AssocSplatNode { .. } => { + let n = node.as_assoc_splat_node().unwrap(); + if let Some(value) = n.value() { + f(&value); + } + } + Node::RangeNode { .. } => { + let n = node.as_range_node().unwrap(); + if let Some(left) = n.left() { + f(&left); + } + if let Some(right) = n.right() { + f(&right); + } + } + Node::ParenthesesNode { .. } => { + let n = node.as_parentheses_node().unwrap(); + if let Some(body) = n.body() { + f(&body); + } + } + Node::InterpolatedStringNode { .. } => { + let n = node.as_interpolated_string_node().unwrap(); + for part in n.parts().iter() { + f(&part); + } + } + Node::InterpolatedSymbolNode { .. } => { + let n = node.as_interpolated_symbol_node().unwrap(); + for part in n.parts().iter() { + f(&part); + } + } + Node::EmbeddedStatementsNode { .. } => { + let n = node.as_embedded_statements_node().unwrap(); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::LocalVariableWriteNode { .. } => { + let n = node.as_local_variable_write_node().unwrap(); + f(&n.value()); + } + Node::InstanceVariableWriteNode { .. } => { + let n = node.as_instance_variable_write_node().unwrap(); + f(&n.value()); + } + Node::ClassVariableWriteNode { .. } => { + let n = node.as_class_variable_write_node().unwrap(); + f(&n.value()); + } + Node::GlobalVariableWriteNode { .. 
} => { + let n = node.as_global_variable_write_node().unwrap(); + f(&n.value()); + } + Node::ConstantWriteNode { .. } => { + let n = node.as_constant_write_node().unwrap(); + f(&n.value()); + } + Node::ConstantPathWriteNode { .. } => { + let n = node.as_constant_path_write_node().unwrap(); + // target() returns ConstantPathNode, not Node — skip it + f(&n.value()); + } + Node::ConstantPathNode { .. } => { + let n = node.as_constant_path_node().unwrap(); + if let Some(parent) = n.parent() { + f(&parent); + } + } + Node::MultiWriteNode { .. } => { + let n = node.as_multi_write_node().unwrap(); + for target in n.lefts().iter() { + f(&target); + } + if let Some(rest) = n.rest() { + f(&rest); + } + for target in n.rights().iter() { + f(&target); + } + f(&n.value()); + } + Node::SplatNode { .. } => { + let n = node.as_splat_node().unwrap(); + if let Some(expr) = n.expression() { + f(&expr); + } + } + Node::ReturnNode { .. } => { + let n = node.as_return_node().unwrap(); + if let Some(args) = n.arguments() { + for arg in args.arguments().iter() { + f(&arg); + } + } + } + Node::YieldNode { .. } => { + let n = node.as_yield_node().unwrap(); + if let Some(args) = n.arguments() { + for arg in args.arguments().iter() { + f(&arg); + } + } + } + Node::SuperNode { .. } => { + let n = node.as_super_node().unwrap(); + if let Some(args) = n.arguments() { + for arg in args.arguments().iter() { + f(&arg); + } + } + if let Some(block) = n.block() { + f(&block); + } + } + Node::LambdaNode { .. } => { + let n = node.as_lambda_node().unwrap(); + if let Some(params) = n.parameters() { + f(¶ms); + } + if let Some(body) = n.body() { + f(&body); + } + } + Node::DefinedNode { .. } => { + let n = node.as_defined_node().unwrap(); + f(&n.value()); + } + Node::InterpolatedRegularExpressionNode { .. } => { + let n = node.as_interpolated_regular_expression_node().unwrap(); + for part in n.parts().iter() { + f(&part); + } + } + Node::MatchPredicateNode { .. 
} => { + let n = node.as_match_predicate_node().unwrap(); + f(&n.value()); + f(&n.pattern()); + } + Node::MatchRequiredNode { .. } => { + let n = node.as_match_required_node().unwrap(); + f(&n.value()); + f(&n.pattern()); + } + Node::CaseMatchNode { .. } => { + let n = node.as_case_match_node().unwrap(); + if let Some(pred) = n.predicate() { + f(&pred); + } + for condition in n.conditions().iter() { + f(&condition); + } + if let Some(else_clause) = n.else_clause() + && let Some(stmts) = else_clause.statements() + { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::InNode { .. } => { + let n = node.as_in_node().unwrap(); + f(&n.pattern()); + if let Some(stmts) = n.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + } + Node::BreakNode { .. } => { + let n = node.as_break_node().unwrap(); + if let Some(args) = n.arguments() { + for arg in args.arguments().iter() { + f(&arg); + } + } + } + Node::NextNode { .. } => { + let n = node.as_next_node().unwrap(); + if let Some(args) = n.arguments() { + for arg in args.arguments().iter() { + f(&arg); + } + } + } + Node::AliasMethodNode { .. } => { + let n = node.as_alias_method_node().unwrap(); + f(&n.new_name()); + f(&n.old_name()); + } + Node::AliasGlobalVariableNode { .. } => { + let n = node.as_alias_global_variable_node().unwrap(); + f(&n.new_name()); + f(&n.old_name()); + } + Node::UndefNode { .. } => { + let n = node.as_undef_node().unwrap(); + for name in n.names().iter() { + f(&name); + } + } + Node::LocalVariableOperatorWriteNode { .. } => { + let n = node.as_local_variable_operator_write_node().unwrap(); + f(&n.value()); + } + Node::LocalVariableAndWriteNode { .. } => { + let n = node.as_local_variable_and_write_node().unwrap(); + f(&n.value()); + } + Node::LocalVariableOrWriteNode { .. } => { + let n = node.as_local_variable_or_write_node().unwrap(); + f(&n.value()); + } + Node::InstanceVariableOperatorWriteNode { .. 
} => { + let n = node.as_instance_variable_operator_write_node().unwrap(); + f(&n.value()); + } + Node::InstanceVariableAndWriteNode { .. } => { + let n = node.as_instance_variable_and_write_node().unwrap(); + f(&n.value()); + } + Node::InstanceVariableOrWriteNode { .. } => { + let n = node.as_instance_variable_or_write_node().unwrap(); + f(&n.value()); + } + Node::ConstantOperatorWriteNode { .. } => { + let n = node.as_constant_operator_write_node().unwrap(); + f(&n.value()); + } + Node::ConstantAndWriteNode { .. } => { + let n = node.as_constant_and_write_node().unwrap(); + f(&n.value()); + } + Node::ConstantOrWriteNode { .. } => { + let n = node.as_constant_or_write_node().unwrap(); + f(&n.value()); + } + Node::ConstantPathOperatorWriteNode { .. } => { + let n = node.as_constant_path_operator_write_node().unwrap(); + // target() returns ConstantPathNode, not Node - skip + f(&n.value()); + } + Node::ConstantPathAndWriteNode { .. } => { + let n = node.as_constant_path_and_write_node().unwrap(); + // target() returns ConstantPathNode, not Node - skip + f(&n.value()); + } + Node::ConstantPathOrWriteNode { .. } => { + let n = node.as_constant_path_or_write_node().unwrap(); + // target() returns ConstantPathNode, not Node - skip + f(&n.value()); + } + Node::ClassVariableOperatorWriteNode { .. } => { + let n = node.as_class_variable_operator_write_node().unwrap(); + f(&n.value()); + } + Node::ClassVariableAndWriteNode { .. } => { + let n = node.as_class_variable_and_write_node().unwrap(); + f(&n.value()); + } + Node::ClassVariableOrWriteNode { .. } => { + let n = node.as_class_variable_or_write_node().unwrap(); + f(&n.value()); + } + Node::GlobalVariableOperatorWriteNode { .. } => { + let n = node.as_global_variable_operator_write_node().unwrap(); + f(&n.value()); + } + Node::GlobalVariableAndWriteNode { .. } => { + let n = node.as_global_variable_and_write_node().unwrap(); + f(&n.value()); + } + Node::GlobalVariableOrWriteNode { .. 
} => { + let n = node.as_global_variable_or_write_node().unwrap(); + f(&n.value()); + } + Node::IndexOperatorWriteNode { .. } => { + let n = node.as_index_operator_write_node().unwrap(); + if let Some(recv) = n.receiver() { + f(&recv); + } + if let Some(args) = n.arguments() { + for arg in args.arguments().iter() { + f(&arg); + } + } + f(&n.value()); + } + Node::IndexAndWriteNode { .. } => { + let n = node.as_index_and_write_node().unwrap(); + if let Some(recv) = n.receiver() { + f(&recv); + } + if let Some(args) = n.arguments() { + for arg in args.arguments().iter() { + f(&arg); + } + } + f(&n.value()); + } + Node::IndexOrWriteNode { .. } => { + let n = node.as_index_or_write_node().unwrap(); + if let Some(recv) = n.receiver() { + f(&recv); + } + if let Some(args) = n.arguments() { + for arg in args.arguments().iter() { + f(&arg); + } + } + f(&n.value()); + } + // Leaf nodes and remaining types — no children to visit + _ => {} } } -fn visit_vec<'a>(v: &'a [Node], f: &mut impl FnMut(&'a Node)) { - for n in v { - f(n); +/// Visit children of a RescueNode. 
+fn visit_rescue_children<'pr>( + rescue: &ruby_prism::RescueNode<'pr>, + f: &mut impl FnMut(&Node<'pr>), +) { + for exc in rescue.exceptions().iter() { + f(&exc); + } + if let Some(reference) = rescue.reference() { + f(&reference); + } + if let Some(stmts) = rescue.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + if let Some(subsequent) = rescue.subsequent() { + visit_rescue_children(&subsequent, f); } } -fn visit_children<'a>(node: &'a Node, f: &mut impl FnMut(&'a Node)) { - match node { - Node::Alias(n) => { - f(&n.to); - f(&n.from); - } - Node::And(n) => { - f(&n.lhs); - f(&n.rhs); - } - Node::AndAsgn(n) => { - f(&n.recv); - f(&n.value); - } - Node::Arg(_) - | Node::BackRef(_) - | Node::Blockarg(_) - | Node::Cbase(_) - | Node::Complex(_) - | Node::Cvar(_) - | Node::EmptyElse(_) - | Node::Encoding(_) - | Node::False(_) - | Node::File(_) - | Node::Float(_) - | Node::ForwardArg(_) - | Node::ForwardedArgs(_) - | Node::Gvar(_) - | Node::Int(_) - | Node::Ivar(_) - | Node::Kwarg(_) - | Node::Kwnilarg(_) - | Node::Lambda(_) - | Node::Line(_) - | Node::Lvar(_) - | Node::Nil(_) - | Node::Rational(_) - | Node::Redo(_) - | Node::Restarg(_) - | Node::Retry(_) - | Node::Self_(_) - | Node::Shadowarg(_) - | Node::Sym(_) - | Node::True(_) - | Node::ZSuper(_) - | Node::NthRef(_) - | Node::RegOpt(_) => {} - Node::Args(n) => visit_vec(&n.args, f), - Node::Array(n) => visit_vec(&n.elements, f), - Node::ArrayPattern(n) => visit_vec(&n.elements, f), - Node::ArrayPatternWithTail(n) => visit_vec(&n.elements, f), - Node::Begin(n) => visit_vec(&n.statements, f), - Node::Block(n) => { - f(&n.call); - visit_opt(&n.args, f); - visit_opt(&n.body, f); - } - Node::BlockPass(n) => visit_opt(&n.value, f), - Node::Break(n) => visit_vec(&n.args, f), - Node::Case(n) => { - visit_opt(&n.expr, f); - visit_vec(&n.when_bodies, f); - visit_opt(&n.else_body, f); - } - Node::CaseMatch(n) => { - f(&n.expr); - visit_vec(&n.in_bodies, f); - visit_opt(&n.else_body, f); - } - 
Node::Casgn(n) => { - visit_opt(&n.scope, f); - visit_opt(&n.value, f); - } - Node::Class(n) => { - f(&n.name); - visit_opt(&n.superclass, f); - visit_opt(&n.body, f); - } - Node::Const(n) => visit_opt(&n.scope, f), - Node::ConstPattern(n) => { - f(&n.const_); - f(&n.pattern); - } - Node::CSend(n) => { - f(&n.recv); - visit_vec(&n.args, f); - } - Node::Cvasgn(n) => visit_opt(&n.value, f), - Node::Def(n) => { - visit_opt(&n.args, f); - visit_opt(&n.body, f); - } - Node::Defined(n) => f(&n.value), - Node::Defs(n) => { - f(&n.definee); - visit_opt(&n.args, f); - visit_opt(&n.body, f); - } - Node::Dstr(n) => visit_vec(&n.parts, f), - Node::Dsym(n) => visit_vec(&n.parts, f), - Node::EFlipFlop(n) => { - visit_opt(&n.left, f); - visit_opt(&n.right, f); - } - Node::Ensure(n) => { - visit_opt(&n.body, f); - visit_opt(&n.ensure, f); - } - Node::Erange(n) => { - visit_opt(&n.left, f); - visit_opt(&n.right, f); - } - Node::FindPattern(n) => visit_vec(&n.elements, f), - Node::For(n) => { - f(&n.iterator); - f(&n.iteratee); - visit_opt(&n.body, f); - } - Node::Gvasgn(n) => visit_opt(&n.value, f), - Node::Hash(n) => visit_vec(&n.pairs, f), - Node::HashPattern(n) => visit_vec(&n.elements, f), - Node::Heredoc(n) => visit_vec(&n.parts, f), - Node::If(n) => { - f(&n.cond); - visit_opt(&n.if_true, f); - visit_opt(&n.if_false, f); - } - Node::IfGuard(n) => f(&n.cond), - Node::IFlipFlop(n) => { - visit_opt(&n.left, f); - visit_opt(&n.right, f); - } - Node::IfMod(n) => { - f(&n.cond); - visit_opt(&n.if_true, f); - visit_opt(&n.if_false, f); - } - Node::IfTernary(n) => { - f(&n.cond); - f(&n.if_true); - f(&n.if_false); - } - Node::Index(n) => { - f(&n.recv); - visit_vec(&n.indexes, f); - } - Node::IndexAsgn(n) => { - f(&n.recv); - visit_vec(&n.indexes, f); - visit_opt(&n.value, f); - } - Node::InPattern(n) => { - f(&n.pattern); - visit_opt(&n.guard, f); - visit_opt(&n.body, f); - } - Node::Irange(n) => { - visit_opt(&n.left, f); - visit_opt(&n.right, f); - } - Node::Ivasgn(n) => 
visit_opt(&n.value, f), - Node::Kwargs(n) => visit_vec(&n.pairs, f), - Node::KwBegin(n) => visit_vec(&n.statements, f), - Node::Kwoptarg(n) => f(&n.default), - Node::Kwrestarg(_) => {} - Node::Kwsplat(n) => f(&n.value), - Node::Lvasgn(n) => visit_opt(&n.value, f), - Node::Masgn(n) => { - f(&n.lhs); - f(&n.rhs); - } - Node::MatchAlt(n) => { - f(&n.lhs); - f(&n.rhs); - } - Node::MatchAs(n) => { - f(&n.value); - f(&n.as_); - } - Node::MatchCurrentLine(n) => f(&n.re), - Node::MatchNilPattern(_) => {} - Node::MatchPattern(n) => { - f(&n.value); - f(&n.pattern); - } - Node::MatchPatternP(n) => { - f(&n.value); - f(&n.pattern); - } - Node::MatchRest(n) => visit_opt(&n.name, f), - Node::MatchVar(_) => {} - Node::MatchWithLvasgn(n) => { - f(&n.re); - f(&n.value); - } - Node::Mlhs(n) => visit_vec(&n.items, f), - Node::Module(n) => { - f(&n.name); - visit_opt(&n.body, f); - } - Node::Next(n) => visit_vec(&n.args, f), - Node::Numblock(n) => { - f(&n.call); - f(&n.body); - } - Node::OpAsgn(n) => { - f(&n.recv); - f(&n.value); - } - Node::Optarg(n) => f(&n.default), - Node::Or(n) => { - f(&n.lhs); - f(&n.rhs); - } - Node::OrAsgn(n) => { - f(&n.recv); - f(&n.value); - } - Node::Pair(n) => { - f(&n.key); - f(&n.value); - } - Node::Pin(n) => f(&n.var), - Node::Postexe(n) => visit_opt(&n.body, f), - Node::Preexe(n) => visit_opt(&n.body, f), - Node::Procarg0(n) => visit_vec(&n.args, f), - Node::Regexp(n) => visit_vec(&n.parts, f), - Node::Rescue(n) => { - visit_opt(&n.body, f); - visit_vec(&n.rescue_bodies, f); - visit_opt(&n.else_, f); - } - Node::RescueBody(n) => { - visit_opt(&n.exc_list, f); - visit_opt(&n.exc_var, f); - visit_opt(&n.body, f); - } - Node::Return(n) => visit_vec(&n.args, f), - Node::SClass(n) => { - f(&n.expr); - visit_opt(&n.body, f); - } - Node::Send(n) => { - visit_opt(&n.recv, f); - visit_vec(&n.args, f); - } - Node::Splat(n) => visit_opt(&n.value, f), - Node::Str(_) => {} - Node::Super(n) => visit_vec(&n.args, f), - Node::Undef(n) => visit_vec(&n.names, f), - 
Node::UnlessGuard(n) => f(&n.cond), - Node::Until(n) => { - f(&n.cond); - visit_opt(&n.body, f); - } - Node::UntilPost(n) => { - f(&n.cond); - f(&n.body); - } - Node::When(n) => { - visit_vec(&n.patterns, f); - visit_opt(&n.body, f); - } - Node::While(n) => { - f(&n.cond); - visit_opt(&n.body, f); - } - Node::WhilePost(n) => { - f(&n.cond); - f(&n.body); - } - Node::XHeredoc(n) => visit_vec(&n.parts, f), - Node::Xstr(n) => visit_vec(&n.parts, f), - Node::Yield(n) => visit_vec(&n.args, f), +/// Visit all descendants through a rescue chain (rescue -> subsequent -> ...). +fn visit_rescue_chain_children<'pr>( + rescue: &ruby_prism::RescueNode<'pr>, + f: &mut impl FnMut(&Node<'pr>), +) { + for exc in rescue.exceptions().iter() { + f(&exc); + } + if let Some(reference) = rescue.reference() { + f(&reference); + } + if let Some(stmts) = rescue.statements() { + for child in stmts.body().iter() { + f(&child); + } + } + if let Some(subsequent) = rescue.subsequent() { + visit_rescue_chain_children(&subsequent, f); } } #[cfg(test)] mod tests { use super::*; - use lib_ruby_parser::Parser; - - fn parse(source: &[u8]) -> Option> { - Parser::new(source.to_vec(), Default::default()) - .do_parse() - .ast - } - - fn count_children(node: &Node) -> usize { - let mut count = 0; - for_each_child(node, |_| count += 1); - count - } - /// Recursively count all nodes in the AST. 
- fn count_all_nodes(node: &Node) -> usize { + fn count_all_nodes(node: &Node<'_>) -> usize { let mut count = 1; - for_each_child(node, |child| count += count_all_nodes(child)); + for_each_descendant(node, &mut |_| count += 1); count } #[test] - fn node_children_matches_for_each_child() { - let ast = parse(b"a + b").unwrap(); - let children = node_children(&ast); - let mut count = 0; - for_each_child(&ast, |_| count += 1); - assert_eq!(children.len(), count); + fn visitor_counts_nodes() { + let result = ruby_prism::parse(b"a + b"); + let result = Box::leak(Box::new(result)); + let total = count_all_nodes(&result.node()); + assert!(total > 1, "Expected multiple nodes, got {}", total); } - // Exercise many AST node types through visit_children for coverage #[test] - fn visit_children_comprehensive() { + fn visitor_handles_many_node_types() { let sources: &[&[u8]] = &[ - // Alias b"alias new_method old_method", - // And, Or b"a && b || c", - // AndAsgn, OrAsgn - b"x &&= 1; y ||= 2", - // Array, ArrayPattern b"[1, 2, 3]", - // Begin, Break, Next, Return - b"begin; break 1; end", - b"loop { next }", - // Case, When - b"case x; when 1; 'a'; when 2; 'b'; else 'c'; end", - // Casgn + b"case x; when 1; 'a'; else 'c'; end", b"FOO = 1", - // Class b"class Foo < Bar; end", - // Const b"Foo::Bar", - // CSend - b"x&.foo(1)", - // Cvasgn, Cvar - b"@@x = 1; @@x", - // Def, Defs b"def foo(a); end", - b"def self.bar; end", - // Defined b"defined?(x)", - // Dstr, Dsym b"\"hello #{world}\"", - b":\"sym_#{x}\"", - // EFlipFlop, IFlipFlop - // Ensure b"begin; 1; ensure; 2; end", - // Erange, Irange b"1...10; 1..10", - // For b"for x in [1]; end", - // Gvasgn, Gvar - b"$x = 1; $x", - // Hash, Pair b"{a: 1, b: 2}", - // Heredoc - b"<<~HERE\nhello\nHERE\n", - // If, IfMod, IfTernary b"if true; 1; else; 2; end", - b"x = 1 if true", - b"true ? 
1 : 2", - // Index, IndexAsgn - b"a[0]; a[0] = 1", - // Ivasgn, Ivar b"@x = 1; @x", - // Kwargs, Kwsplat - b"foo(a: 1, **opts)", - // KwBegin - b"begin; 1; rescue; 2; end", - // Kwoptarg, Kwrestarg - b"def foo(a: 1, **rest); end", - // Lvasgn b"x = 42", - // Masgn, Mlhs - b"a, b = 1, 2", - // Module b"module Foo; end", - // Next, Return b"def foo; return 1; end", - // Numblock - b"arr.map { _1.to_s }", - // OpAsgn b"x += 1", - // Optarg - b"def foo(a = 1); end", - // Pin (pattern matching) - b"case x; in ^y; end", - // Postexe, Preexe - b"END { 1 }", - b"BEGIN { 1 }", - // Procarg0 (block with single destructured arg) - b"arr.each { |(a)| a }", - // Regexp, RegOpt - b"/foo/i", - // Rescue, RescueBody b"begin; rescue StandardError => e; end", - // SClass b"class << self; end", - // Send b"foo.bar(1, 2)", - // Splat - b"foo(*args)", - // Str b"'hello'", - // Super b"def foo; super(1); end", - // Undef - b"undef :foo", - // Until, While b"until false; end", b"while true; break; end", - // Yield b"def foo; yield 1; end", - // Xstr, XHeredoc - b"`echo hi`", - // MatchCurrentLine - b"if /pattern/; end", - // Block, BlockPass b"arr.select(&:odd?)", - // FindPattern, InPattern - b"case x; in [1, *rest, 2]; end", - // MatchAlt, MatchAs - b"case x; in 1 | 2 => y; end", - // MatchPattern, MatchPatternP - b"x in [1, 2]", - b"x in [1, 2] rescue false", - // MatchNilPattern, MatchVar - b"case x; in **nil; end", - b"case x; in {a:}; end", - // MatchWithLvasgn - b"/(?.)/ =~ str", - // HashPattern, ConstPattern - b"case x; in Foo[a:]; end", - // MatchRest - b"case x; in [*, 1]; end", - // WhilePost - b"begin; 1; end while true", - // UntilPost - b"begin; 1; end until true", - // UnlessGuard - b"case x; in 1 unless false; end", - // EFlipFlop (exclusive) - b"if (a == 1)...(b == 2); end", - // IFlipFlop (inclusive) - b"if (a == 1)..(b == 2); end", - // Rational, Complex - b"1r", - b"1i", - // BackRef, NthRef - b"$~ ; $1", - // Redo, Retry - b"begin; retry; rescue; end", - // Self - 
b"self", - // ZSuper - b"def foo; super; end", - // Lambda + b"arr.map { |x| x.to_s }", b"-> { 1 }", - // Encoding, File, Line - b"__ENCODING__", - b"__FILE__", - b"__LINE__", - // XHeredoc - b"<<~`CMD`\necho hi\nCMD\n", - // ArrayPatternWithTail - b"case x; in [1, 2,]; end", - // ForwardArg, ForwardedArgs - b"def foo(...); bar(...); end", - // Kwarg - b"def foo(a:); end", - // Kwnilarg - b"def foo(**nil); end", - // Shadowarg - b"arr.each { |x; y| y }", - // Restarg - b"def foo(*args); end", ]; for source in sources { - if let Some(ast) = parse(source) { - let total = count_all_nodes(&ast); - assert!( - total > 0, - "No nodes in AST for {:?}", - std::str::from_utf8(source) - ); - } - } - } - - #[test] - fn visit_children_pattern_matching() { - // These exercise pattern matching AST nodes specifically - let sources: &[&[u8]] = &[ - // MatchPattern (in operator) - b"1 in Integer", - // MatchPatternP (case/in with guard) - b"case 1; in Integer if true; end", - // IfGuard - b"case 1; in x if x > 0; end", - // UnlessGuard - b"case 1; in x unless x < 0; end", - // FindPattern - b"case [1,2,3]; in [*, 2, *]; end", - // HashPattern - b"case {a: 1}; in {a: Integer}; end", - // ConstPattern - b"case x; in Foo(1); end", - // MatchNilPattern - b"case {a: 1}; in **nil; end", - // MatchVar - b"case 1; in x; end", - // MatchRest - b"case [1,2]; in [Integer, *rest]; end", - // MatchAlt - b"case 1; in 1 | 2; end", - // MatchAs - b"case 1; in Integer => x; end", - // MatchWithLvasgn (regex named capture) - b"/(?.)/ =~ 'x'", - // Pin - b"x = 1; case 2; in ^x; end", - ]; - for source in sources { - if let Some(ast) = parse(source) { - let total = count_all_nodes(&ast); - assert!(total > 0, "No nodes for {:?}", std::str::from_utf8(source)); - } + let result = ruby_prism::parse(source); + let result = Box::leak(Box::new(result)); + let total = count_all_nodes(&result.node()); + assert!( + total > 0, + "No nodes in AST for {:?}", + std::str::from_utf8(source) + ); } } #[test] - fn 
visit_children_leaf_nodes() { - // Leaf nodes should have 0 children - let leaf_sources: &[&[u8]] = &[ - b"42", // Int - b"3.14", // Float - b"'s'", // Str - b":sym", // Sym - b"true", // True - b"false", // False - b"nil", // Nil - b"x", // Lvar - ]; + fn leaf_nodes_have_no_extra_children() { + let leaf_sources: &[&[u8]] = &[b"42", b"3.14", b"'s'", b":sym", b"true", b"false", b"nil"]; for source in leaf_sources { - let ast = parse(source).unwrap(); + let result = ruby_prism::parse(source); + let result = Box::leak(Box::new(result)); + let prog = result.node().as_program_node().unwrap(); + let node = prog.statements().body().iter().next().unwrap(); + let mut child_count = 0; + for_each_direct_child(&node, &mut |_| child_count += 1); assert_eq!( - count_children(&ast), + child_count, 0, "Expected 0 children for {:?}", std::str::from_utf8(source) diff --git a/src/comment_directives.rs b/src/comment_directives.rs index a946667..658ffd4 100644 --- a/src/comment_directives.rs +++ b/src/comment_directives.rs @@ -1,7 +1,5 @@ use std::collections::HashSet; -use lib_ruby_parser::source::Comment; - use crate::ast_helpers::byte_offset_to_line; use crate::offense::OffenseKind; @@ -21,7 +19,7 @@ impl DisabledSet { } } -/// Build a DisabledSet from parser comments, source bytes, and pre-computed newline positions. +/// Build a DisabledSet from a parse result, source bytes, and pre-computed newline positions. 
/// /// Supports: /// - `# rubyfast:disable rule` or `# fasterer:disable rule` — trailing (same line) or block start @@ -30,7 +28,7 @@ impl DisabledSet { /// - `# rubyfast:disable all` — disable all rules /// - `# rubyfast:disable rule1, rule2` — multiple rules pub fn build_disabled_set( - comments: &[Comment], + parse_result: &ruby_prism::ParseResult<'_>, source: &[u8], newline_positions: &[usize], ) -> DisabledSet { @@ -43,9 +41,10 @@ pub fn build_disabled_set( let mut block_all_start: Option = None; let mut block_rule_starts: Vec<(OffenseKind, usize)> = Vec::new(); - for comment in comments { - let begin = comment.location.begin; - let end = comment.location.end; + for comment in parse_result.comments() { + let loc = comment.location(); + let begin = loc.start_offset(); + let end = loc.end_offset(); let comment_line = byte_offset_to_line(newline_positions, begin); let comment_text = &source[begin..end.min(source.len())]; let comment_str = String::from_utf8_lossy(comment_text); @@ -232,10 +231,10 @@ mod tests { use super::*; fn parse_and_build(source: &str) -> DisabledSet { - let bytes = source.as_bytes().to_vec(); - let result = lib_ruby_parser::Parser::new(bytes.clone(), Default::default()).do_parse(); - let newline_positions = crate::ast_helpers::compute_newline_positions(&bytes); - build_disabled_set(&result.comments, &bytes, &newline_positions) + let bytes = source.as_bytes(); + let result = ruby_prism::parse(bytes); + let newline_positions = crate::ast_helpers::compute_newline_positions(bytes); + build_disabled_set(&result, bytes, &newline_positions) } #[test] diff --git a/src/file_traverser.rs b/src/file_traverser.rs index ab50477..5498f2b 100644 --- a/src/file_traverser.rs +++ b/src/file_traverser.rs @@ -272,8 +272,8 @@ mod tests { let config = Config::default(); let result = traverse_and_analyze(dir.path(), &config); assert_eq!(result.files_inspected, 1); - // "def def def" produces a fatal parse error with no recoverable AST - 
assert_eq!(result.parse_errors.len(), 1); - assert!(result.results.is_empty()); + // Prism always produces an AST even with errors, but our analyzer skips + // analysis when errors are detected, returning empty offenses. + assert!(result.results.iter().all(|r| r.offenses.is_empty())); } } diff --git a/src/fix.rs b/src/fix.rs index 317e020..c4fc901 100644 --- a/src/fix.rs +++ b/src/fix.rs @@ -1,9 +1,5 @@ use std::path::Path; -use lib_ruby_parser::{ErrorLevel, Parser}; - -use crate::ast_helpers::parser_options; - /// A single byte-range replacement in a source file. #[derive(Debug, Clone)] pub struct Replacement { @@ -80,11 +76,8 @@ pub fn apply_fixes(source: &[u8], fixes: &[Fix]) -> Vec { /// Verify that the given source parses without fatal errors. pub fn verify_syntax(source: &[u8]) -> bool { - let result = Parser::new(source.to_vec(), parser_options()).do_parse(); - !result - .diagnostics - .iter() - .any(|d| d.level == ErrorLevel::Error) + let result = ruby_prism::parse(source); + result.errors().next().is_none() } /// Apply fixes to a file: read -> fix -> verify syntax -> write. 
@@ -138,10 +131,6 @@ mod tests { Fix::single(2, 6, "XX"), // replace cdef with XX Fix::single(4, 8, "YY"), // overlaps — should be skipped ]; - // Because we sort descending, 4..8 is processed first, then 2..6 overlaps - // Actually: sorted descending by start: 4..8 first (start=4), then 2..6 (start=2) - // 4..8 replaces "efgh" -> "YY", result = "abcdYY", last_start=4 - // 2..6 has end=6 > last_start=4, so it's skipped let result = apply_fixes(source, &fixes); assert_eq!(result, b"abcdYY"); } @@ -159,7 +148,6 @@ mod tests { #[test] fn two_replacements_in_one_fix() { let source = b"arr.map { |x| [x] }.flatten(1)"; - // Rename .map -> .flat_map and delete .flatten(1) let fix = Fix::two( 4, 7, "flat_map", // "map" -> "flat_map" 19, 30, "", // delete ".flatten(1)" @@ -197,7 +185,6 @@ mod tests { let dir = tempfile::TempDir::new().unwrap(); let file = dir.path().join("test.rb"); std::fs::write(&file, "for x in [1]; end").unwrap(); - // Replace "for x in [1]; " with "[1].each do |x|; " let fix = Fix::single(0, 14, "[1].each do |x|;"); let result = apply_fixes_to_file(&file, &[fix]).unwrap(); assert_eq!(result, 1); @@ -208,7 +195,6 @@ mod tests { let dir = tempfile::TempDir::new().unwrap(); let file = dir.path().join("test.rb"); std::fs::write(&file, "x = 1 + 2").unwrap(); - // This fix produces invalid syntax let fix = Fix::single(0, 9, "def def def"); let result = apply_fixes_to_file(&file, &[fix]); assert!(result.is_err()); diff --git a/src/scanner/for_loop_scanner.rs b/src/scanner/for_loop_scanner.rs index 1974bc8..8a12faf 100644 --- a/src/scanner/for_loop_scanner.rs +++ b/src/scanner/for_loop_scanner.rs @@ -1,47 +1,49 @@ -use lib_ruby_parser::nodes::For; - use crate::fix::Fix; use crate::offense::{Offense, OffenseKind}; /// Any `for` loop emits an offense — prefer `.each`. 
-pub fn scan(node: &For, source: &[u8]) -> Vec { +pub fn scan(node: &ruby_prism::ForNode<'_>, source: &[u8]) -> Vec { match build_fix(node, source) { Some(fix) => vec![Offense::with_fix( OffenseKind::ForLoopVsEach, - node.keyword_l.begin, + node.for_keyword_loc().start_offset(), fix, )], None => vec![Offense::new( OffenseKind::ForLoopVsEach, - node.keyword_l.begin, + node.for_keyword_loc().start_offset(), )], } } /// Build a fix that transforms `for x in arr` → `arr.each do |x|`. -fn build_fix(node: &For, source: &[u8]) -> Option { - // Extract iterator text from between "for" and "in" - let iterator = extract_trimmed(source, node.keyword_l.end, node.operator_l.begin)?; +fn build_fix(node: &ruby_prism::ForNode<'_>, source: &[u8]) -> Option { + let for_loc = node.for_keyword_loc(); + let in_loc = node.in_keyword_loc(); - // begin_l is Loc (not Option). It points to `do`, `;`, or the newline. - // Only use it as a real delimiter if it points to `do` or `;` (not a newline). - let begin_char = source.get(node.begin_l.begin).copied().unwrap_or(0); - let has_explicit_begin = !node.begin_l.is_empty() && begin_char != b'\n'; + // Extract iterator text from between "for" and "in" + let iterator = extract_trimmed(source, for_loc.end_offset(), in_loc.start_offset())?; - // Determine where the iteratee ends and the header ends - let (iteratee, header_end) = if has_explicit_begin { - let text = extract_trimmed(source, node.operator_l.end, node.begin_l.begin)?; - (text, node.begin_l.end) + // do_keyword_loc is Option in prism — present only for `do` keyword, not `;`. + let (iteratee, header_end) = if let Some(do_loc) = node.do_keyword_loc() { + let text = extract_trimmed(source, in_loc.end_offset(), do_loc.start_offset())?; + (text, do_loc.end_offset()) } else { - // No explicit `do` or `;` — header ends at the newline - let search_start = node.operator_l.end; - let line_end = source[search_start..] 
+ // No explicit `do` — look for `;` or newline as delimiter + let search_start = in_loc.end_offset(); + let delimiter_pos = source[search_start..] .iter() - .position(|&b| b == b'\n') + .position(|&b| b == b'\n' || b == b';') .map(|p| search_start + p) .unwrap_or(source.len()); - let text = extract_trimmed(source, search_start, line_end)?; - (text, line_end) + let text = extract_trimmed(source, search_start, delimiter_pos)?; + // If delimiter is `;`, include it in header_end so it gets replaced + let header_end = if source.get(delimiter_pos) == Some(&b';') { + delimiter_pos + 1 + } else { + delimiter_pos + }; + (text, header_end) }; if iterator.is_empty() || iteratee.is_empty() { @@ -49,7 +51,7 @@ fn build_fix(node: &For, source: &[u8]) -> Option { } let new_header = format!("{}.each do |{}|", iteratee, iterator); - Some(Fix::single(node.keyword_l.begin, header_end, new_header)) + Some(Fix::single(for_loc.start_offset(), header_end, new_header)) } /// Extract a trimmed UTF-8 string from a byte range. Returns None if not valid UTF-8. 
@@ -66,68 +68,57 @@ fn extract_trimmed(source: &[u8], start: usize, end: usize) -> Option { mod tests { use super::*; + fn parse_first_for(source: &'static [u8]) -> ruby_prism::ForNode<'static> { + let result = ruby_prism::parse(source); + let result = Box::leak(Box::new(result)); + let program = result.node(); + let prog = program.as_program_node().unwrap(); + let stmts: Vec<_> = prog.statements().body().iter().collect(); + stmts[0].as_for_node().unwrap() + } + #[test] fn for_loop_always_fires() { let source = b"for x in [1,2,3]; end"; - let result = lib_ruby_parser::Parser::new(source.to_vec(), Default::default()).do_parse(); - let ast = result.ast.unwrap(); - if let lib_ruby_parser::Node::For(f) = ast.as_ref() { - let offenses = scan(f, source); - assert_eq!(offenses.len(), 1); - assert_eq!(offenses[0].kind, OffenseKind::ForLoopVsEach); - assert!(offenses[0].fix.is_some()); - } else { - panic!("Expected For node"); - } + let f = parse_first_for(source); + let offenses = scan(&f, source); + assert_eq!(offenses.len(), 1); + assert_eq!(offenses[0].kind, OffenseKind::ForLoopVsEach); + assert!(offenses[0].fix.is_some()); } #[test] fn fix_for_loop_with_do() { let source = b"for x in arr do\n puts x\nend"; - let result = lib_ruby_parser::Parser::new(source.to_vec(), Default::default()).do_parse(); - let ast = result.ast.unwrap(); - if let lib_ruby_parser::Node::For(f) = ast.as_ref() { - let fix = build_fix(f, source).unwrap(); - let fixed = crate::fix::apply_fixes(source, &[fix]); - assert_eq!( - String::from_utf8(fixed).unwrap(), - "arr.each do |x|\n puts x\nend" - ); - } else { - panic!("Expected For node"); - } + let f = parse_first_for(source); + let fix = build_fix(&f, source).unwrap(); + let fixed = crate::fix::apply_fixes(source, &[fix]); + assert_eq!( + String::from_utf8(fixed).unwrap(), + "arr.each do |x|\n puts x\nend" + ); } #[test] fn fix_for_loop_with_semicolon() { let source = b"for x in [1,2,3]; puts x; end"; - let result = 
lib_ruby_parser::Parser::new(source.to_vec(), Default::default()).do_parse(); - let ast = result.ast.unwrap(); - if let lib_ruby_parser::Node::For(f) = ast.as_ref() { - let fix = build_fix(f, source).unwrap(); - let fixed = crate::fix::apply_fixes(source, &[fix]); - let fixed_str = String::from_utf8(fixed).unwrap(); - assert!(fixed_str.starts_with("[1,2,3].each do |x|")); - } else { - panic!("Expected For node"); - } + let f = parse_first_for(source); + let fix = build_fix(&f, source).unwrap(); + let fixed = crate::fix::apply_fixes(source, &[fix]); + let fixed_str = String::from_utf8(fixed).unwrap(); + assert!(fixed_str.starts_with("[1,2,3].each do |x|")); } #[test] fn fix_for_loop_newline_only() { let source = b"for x in arr\n puts x\nend"; - let result = lib_ruby_parser::Parser::new(source.to_vec(), Default::default()).do_parse(); - let ast = result.ast.unwrap(); - if let lib_ruby_parser::Node::For(f) = ast.as_ref() { - let fix = build_fix(f, source).unwrap(); - let fixed = crate::fix::apply_fixes(source, &[fix]); - assert_eq!( - String::from_utf8(fixed).unwrap(), - "arr.each do |x|\n puts x\nend" - ); - } else { - panic!("Expected For node"); - } + let f = parse_first_for(source); + let fix = build_fix(&f, source).unwrap(); + let fixed = crate::fix::apply_fixes(source, &[fix]); + assert_eq!( + String::from_utf8(fixed).unwrap(), + "arr.each do |x|\n puts x\nend" + ); } #[test] @@ -156,14 +147,10 @@ mod tests { #[test] fn scan_always_returns_offense() { - // Even when fix fails, we should still get an offense let source = b"for x in arr; end"; - let result = lib_ruby_parser::Parser::new(source.to_vec(), Default::default()).do_parse(); - let ast = result.ast.unwrap(); - if let lib_ruby_parser::Node::For(f) = ast.as_ref() { - let offenses = scan(f, source); - assert_eq!(offenses.len(), 1); - assert_eq!(offenses[0].kind, OffenseKind::ForLoopVsEach); - } + let f = parse_first_for(source); + let offenses = scan(&f, source); + assert_eq!(offenses.len(), 1); + 
assert_eq!(offenses[0].kind, OffenseKind::ForLoopVsEach); } } diff --git a/src/scanner/method_call_scanner.rs b/src/scanner/method_call_scanner.rs index 7820dee..f9133df 100644 --- a/src/scanner/method_call_scanner.rs +++ b/src/scanner/method_call_scanner.rs @@ -1,133 +1,124 @@ -use lib_ruby_parser::Node; -use lib_ruby_parser::nodes::{Block, Send}; - use crate::ast_helpers::*; use crate::fix::Fix; use crate::offense::{Offense, OffenseKind}; -/// Scan a method call (Send node) that is NOT inside a Block. -pub fn scan_send(send: &Send) -> Vec { +/// Scan a method call (CallNode) that does NOT have a block. +pub fn scan_call(call: &ruby_prism::CallNode<'_>) -> Vec { let mut offenses = Vec::new(); - check_shuffle_first(send, &mut offenses); - check_reverse_each(send, &mut offenses); - check_keys_each(send, &mut offenses); - check_each_with_index(send, &mut offenses); - check_include_vs_cover(send, &mut offenses); - check_gsub_vs_tr(send, &mut offenses); - check_fetch_with_argument(send, &mut offenses); - check_hash_merge_bang(send, &mut offenses); - check_map_flatten(send, &mut offenses); - check_select_first(send, &mut offenses); - check_select_last(send, &mut offenses); - check_module_eval_send(send, &mut offenses); + check_shuffle_first(call, &mut offenses); + check_reverse_each(call, &mut offenses); + check_keys_each(call, &mut offenses); + check_each_with_index(call, &mut offenses); + check_include_vs_cover(call, &mut offenses); + check_gsub_vs_tr(call, &mut offenses); + check_fetch_with_argument(call, &mut offenses); + check_hash_merge_bang(call, &mut offenses); + check_map_flatten(call, &mut offenses); + check_select_first(call, &mut offenses); + check_select_last(call, &mut offenses); + check_module_eval_call(call, &mut offenses); offenses } -/// Scan a Block node (method call + block). -pub fn scan_block(block: &Block) -> Vec { +/// Scan a CallNode that has a BlockNode (method call + block). 
+pub fn scan_call_with_block( + call: &ruby_prism::CallNode<'_>, + block: &ruby_prism::BlockNode<'_>, +) -> Vec { let mut offenses = Vec::new(); - let send = match block_call_as_send(block) { - Some(s) => s, - None => return offenses, - }; - // Checks that only apply when a block is present - check_sort_vs_sort_by(send, &mut offenses); - check_module_eval_send(send, &mut offenses); - check_block_vs_symbol_to_proc(send, block, &mut offenses); - - // Chain checks where receiver might be a block call - // e.g., .select{}.first — the .first Send wraps the Block - // These are actually checked on the outer Send whose receiver is this Block. - // But we also run the send-level checks on the call inside the block. - check_shuffle_first(send, &mut offenses); - check_reverse_each(send, &mut offenses); - check_keys_each(send, &mut offenses); - check_each_with_index(send, &mut offenses); - check_include_vs_cover(send, &mut offenses); - check_gsub_vs_tr(send, &mut offenses); - // NOTE: check_fetch_with_argument is intentionally excluded here. - // If fetch already has a block, the rule doesn't apply. - check_hash_merge_bang(send, &mut offenses); + check_sort_vs_sort_by(call, &mut offenses); + check_module_eval_call(call, &mut offenses); + check_block_vs_symbol_to_proc(call, block, &mut offenses); + + // Chain checks on the call inside the block + check_shuffle_first(call, &mut offenses); + check_reverse_each(call, &mut offenses); + check_keys_each(call, &mut offenses); + check_each_with_index(call, &mut offenses); + check_include_vs_cover(call, &mut offenses); + check_gsub_vs_tr(call, &mut offenses); + // NOTE: check_fetch_with_argument excluded — if fetch already has a block, rule doesn't apply. + check_hash_merge_bang(call, &mut offenses); offenses } -/// Scan a Send whose receiver is a Block node. -/// This handles chains like `.select { }.first` where .first's receiver is a Block. 
-pub fn scan_send_on_block(send: &Send, recv_block: &Block) -> Vec { +/// Scan a CallNode whose receiver is another CallNode that has a block. +/// This handles chains like `.select { }.first` where .first's receiver is a call-with-block. +pub fn scan_call_on_block_call( + outer: &ruby_prism::CallNode<'_>, + recv_call: &ruby_prism::CallNode<'_>, +) -> Vec { let mut offenses = Vec::new(); - let recv_send = match block_call_as_send(recv_block) { - Some(s) => s, - None => return offenses, - }; + let outer_name = outer.name().as_slice(); + let recv_name = recv_call.name().as_slice(); // .select{}.first → .detect{} - if send.method_name == "first" - && recv_send.method_name == "select" - && arg_count_without_block_pass(&send.args) == 0 - { - let offense = match (recv_send.selector_l.as_ref(), send.dot_l.as_ref()) { + if outer_name == b"first" && recv_name == b"select" && arg_count(outer) == 0 { + let offense = match (recv_call.message_loc(), outer.call_operator_loc()) { (Some(sel_l), Some(dot_l)) => { let fix = Fix::two( - sel_l.begin, - sel_l.end, + sel_l.start_offset(), + sel_l.end_offset(), "detect", - dot_l.begin, - send.expression_l.end, + dot_l.start_offset(), + outer.location().end_offset(), "", ); Offense::with_fix( OffenseKind::SelectFirstVsDetect, - send.expression_l.begin, + outer.location().start_offset(), fix, ) } - _ => Offense::new(OffenseKind::SelectFirstVsDetect, send.expression_l.begin), + _ => Offense::new( + OffenseKind::SelectFirstVsDetect, + outer.location().start_offset(), + ), }; offenses.push(offense); } - // .select{}.last (no auto-fix — transform to .reverse.detect is too risky) - if send.method_name == "last" - && recv_send.method_name == "select" - && arg_count_without_block_pass(&send.args) == 0 - { + // .select{}.last (no auto-fix) + if outer_name == b"last" && recv_name == b"select" && arg_count(outer) == 0 { offenses.push(Offense::new( OffenseKind::SelectLastVsReverseDetect, - send.expression_l.begin, + outer.location().start_offset(), 
)); } // .map{}.flatten(1) → .flat_map{} - if send.method_name == "flatten" - && recv_send.method_name == "map" - && send.args.len() == 1 - && is_int_one(&send.args[0]) - { - let offense = match (recv_send.selector_l.as_ref(), send.dot_l.as_ref()) { - (Some(sel_l), Some(dot_l)) => { - let fix = Fix::two( - sel_l.begin, - sel_l.end, - "flat_map", - dot_l.begin, - send.expression_l.end, - "", - ); - Offense::with_fix( + if outer_name == b"flatten" && recv_name == b"map" { + let args = call_args(outer); + if args.len() == 1 && is_int_one(&args[0]) { + let offense = match (recv_call.message_loc(), outer.call_operator_loc()) { + (Some(sel_l), Some(dot_l)) => { + let fix = Fix::two( + sel_l.start_offset(), + sel_l.end_offset(), + "flat_map", + dot_l.start_offset(), + outer.location().end_offset(), + "", + ); + Offense::with_fix( + OffenseKind::MapFlattenVsFlatMap, + outer.location().start_offset(), + fix, + ) + } + _ => Offense::new( OffenseKind::MapFlattenVsFlatMap, - send.expression_l.begin, - fix, - ) - } - _ => Offense::new(OffenseKind::MapFlattenVsFlatMap, send.expression_l.begin), - }; - offenses.push(offense); + outer.location().start_offset(), + ), + }; + offenses.push(offense); + } } offenses @@ -136,279 +127,328 @@ pub fn scan_send_on_block(send: &Send, recv_block: &Block) -> Vec { // --- Individual offense checks --- /// `.shuffle.first` → `.sample` -fn check_shuffle_first(send: &Send, offenses: &mut Vec) { - if send.method_name != "first" || !receiver_is_send_with_name(&send.recv, "shuffle") { +fn check_shuffle_first(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"first" + || !receiver_is_call_with_name(&call.receiver(), b"shuffle") + { return; } - let offense = match receiver_as_send(&send.recv).and_then(|rs| rs.dot_l.as_ref()) { + let offense = match receiver_as_call(&call.receiver()).and_then(|rs| rs.call_operator_loc()) { Some(dot_l) => { - let fix = Fix::single(dot_l.begin, send.expression_l.end, ".sample"); + let 
fix = Fix::single( + dot_l.start_offset(), + call.location().end_offset(), + ".sample", + ); Offense::with_fix( OffenseKind::ShuffleFirstVsSample, - send.expression_l.begin, + call.location().start_offset(), fix, ) } - None => Offense::new(OffenseKind::ShuffleFirstVsSample, send.expression_l.begin), + None => Offense::new( + OffenseKind::ShuffleFirstVsSample, + call.location().start_offset(), + ), }; offenses.push(offense); } /// `.reverse.each` → `.reverse_each` -fn check_reverse_each(send: &Send, offenses: &mut Vec) { - if send.method_name != "each" || !receiver_is_send_with_name(&send.recv, "reverse") { +fn check_reverse_each(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"each" + || !receiver_is_call_with_name(&call.receiver(), b"reverse") + { return; } let offense = match ( - receiver_as_send(&send.recv).and_then(|rs| rs.dot_l.as_ref()), - send.selector_l.as_ref(), + receiver_as_call(&call.receiver()).and_then(|rs| rs.call_operator_loc()), + call.message_loc(), ) { (Some(dot_l), Some(sel_l)) => { - let fix = Fix::single(dot_l.begin, sel_l.end, ".reverse_each"); + let fix = Fix::single(dot_l.start_offset(), sel_l.end_offset(), ".reverse_each"); Offense::with_fix( OffenseKind::ReverseEachVsReverseEach, - send.expression_l.begin, + call.location().start_offset(), fix, ) } _ => Offense::new( OffenseKind::ReverseEachVsReverseEach, - send.expression_l.begin, + call.location().start_offset(), ), }; offenses.push(offense); } /// `.keys.each` → `.each_key` (keys must have 0 args) -fn check_keys_each(send: &Send, offenses: &mut Vec) { - if send.method_name != "each" { +fn check_keys_each(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"each" { return; } - if let Some(recv_send) = receiver_as_send(&send.recv) - && recv_send.method_name == "keys" - && recv_send.args.is_empty() + if let Some(recv_call) = receiver_as_call(&call.receiver()) + && recv_call.name().as_slice() == b"keys" + && 
arg_count(&recv_call) == 0 { - let offense = match (recv_send.dot_l.as_ref(), send.selector_l.as_ref()) { + let offense = match (recv_call.call_operator_loc(), call.message_loc()) { (Some(dot_l), Some(sel_l)) => { - let fix = Fix::single(dot_l.begin, sel_l.end, ".each_key"); - Offense::with_fix(OffenseKind::KeysEachVsEachKey, send.expression_l.begin, fix) + let fix = Fix::single(dot_l.start_offset(), sel_l.end_offset(), ".each_key"); + Offense::with_fix( + OffenseKind::KeysEachVsEachKey, + call.location().start_offset(), + fix, + ) } - _ => Offense::new(OffenseKind::KeysEachVsEachKey, send.expression_l.begin), + _ => Offense::new( + OffenseKind::KeysEachVsEachKey, + call.location().start_offset(), + ), }; offenses.push(offense); } } -/// `.select{}.first` → `.detect{}` (when receiver is a plain Send, not Block) -fn check_select_first(send: &Send, offenses: &mut Vec) { - if send.method_name != "first" || arg_count_without_block_pass(&send.args) != 0 { +/// `.select{}.first` → `.detect{}` (when receiver is a plain call with block_pass, not block) +fn check_select_first(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"first" || arg_count(call) != 0 { return; } - if let Some(recv_send) = receiver_as_send(&send.recv) - && recv_send.method_name == "select" - && has_block_pass(&recv_send.args) + if let Some(recv_call) = receiver_as_call(&call.receiver()) + && recv_call.name().as_slice() == b"select" + && has_block_pass(&recv_call) { - let offense = match (recv_send.selector_l.as_ref(), send.dot_l.as_ref()) { + let offense = match (recv_call.message_loc(), call.call_operator_loc()) { (Some(sel_l), Some(dot_l)) => { let fix = Fix::two( - sel_l.begin, - sel_l.end, + sel_l.start_offset(), + sel_l.end_offset(), "detect", - dot_l.begin, - send.expression_l.end, + dot_l.start_offset(), + call.location().end_offset(), "", ); Offense::with_fix( OffenseKind::SelectFirstVsDetect, - send.expression_l.begin, + call.location().start_offset(), fix, 
) } - _ => Offense::new(OffenseKind::SelectFirstVsDetect, send.expression_l.begin), + _ => Offense::new( + OffenseKind::SelectFirstVsDetect, + call.location().start_offset(), + ), }; offenses.push(offense); } } -/// `.select{}.last` → `.reverse.detect{}` (when receiver is a plain Send) -fn check_select_last(send: &Send, offenses: &mut Vec) { - if send.method_name != "last" || arg_count_without_block_pass(&send.args) != 0 { +/// `.select{}.last` → `.reverse.detect{}` (when receiver is a plain call with block_pass) +fn check_select_last(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"last" || arg_count(call) != 0 { return; } - if let Some(recv_send) = receiver_as_send(&send.recv) - && recv_send.method_name == "select" - && has_block_pass(&recv_send.args) + if let Some(recv_call) = receiver_as_call(&call.receiver()) + && recv_call.name().as_slice() == b"select" + && has_block_pass(&recv_call) { offenses.push(Offense::new( OffenseKind::SelectLastVsReverseDetect, - send.expression_l.begin, + call.location().start_offset(), )); } } -/// `.map{}.flatten(1)` → `.flat_map{}` (when receiver is a plain Send) -fn check_map_flatten(send: &Send, offenses: &mut Vec) { - if send.method_name != "flatten" || send.args.len() != 1 || !is_int_one(&send.args[0]) { +/// `.map{}.flatten(1)` → `.flat_map{}` (when receiver is a plain call with block_pass, not full block) +fn check_map_flatten(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"flatten" { + return; + } + let args = call_args(call); + if args.len() != 1 || !is_int_one(&args[0]) { return; } - if receiver_is_send_with_name(&send.recv, "map") { + // Only match when receiver is map WITHOUT a full block (block_pass is ok). + // Full block cases are handled by scan_call_on_block_call. 
+ if let Some(recv_call) = receiver_as_call(&call.receiver()) + && recv_call.name().as_slice() == b"map" + && !has_full_block(&recv_call) + { offenses.push(Offense::new( OffenseKind::MapFlattenVsFlatMap, - send.expression_l.begin, + call.location().start_offset(), )); } } /// `.each_with_index` → while loop -fn check_each_with_index(send: &Send, offenses: &mut Vec) { - if send.method_name == "each_with_index" { +fn check_each_with_index(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() == b"each_with_index" { offenses.push(Offense::new( OffenseKind::EachWithIndexVsWhile, - send.expression_l.begin, + call.location().start_offset(), )); } } /// `(1..10).include?` → `.cover?` -fn check_include_vs_cover(send: &Send, offenses: &mut Vec) { - if send.method_name != "include?" || !receiver_is_range(&send.recv) { +fn check_include_vs_cover(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"include?" || !receiver_is_range(&call.receiver()) { return; } - let offense = match send.selector_l.as_ref() { + let offense = match call.message_loc() { Some(sel_l) => { - let fix = Fix::single(sel_l.begin, sel_l.end, "cover?"); + let fix = Fix::single(sel_l.start_offset(), sel_l.end_offset(), "cover?"); Offense::with_fix( OffenseKind::IncludeVsCoverOnRange, - send.expression_l.begin, + call.location().start_offset(), fix, ) } - None => Offense::new(OffenseKind::IncludeVsCoverOnRange, send.expression_l.begin), + None => Offense::new( + OffenseKind::IncludeVsCoverOnRange, + call.location().start_offset(), + ), }; offenses.push(offense); } /// `.gsub("x", "y")` → `.tr("x", "y")` when both args are single-char strings -fn check_gsub_vs_tr(send: &Send, offenses: &mut Vec) { - if send.method_name != "gsub" || send.args.len() != 2 { +fn check_gsub_vs_tr(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"gsub" { return; } - if is_single_char_string(&send.args[0]) && 
is_single_char_string(&send.args[1]) { - let offense = match send.selector_l.as_ref() { + let args = call_args(call); + if args.len() != 2 { + return; + } + if is_single_char_string(&args[0]) && is_single_char_string(&args[1]) { + let offense = match call.message_loc() { Some(sel_l) => { - let fix = Fix::single(sel_l.begin, sel_l.end, "tr"); - Offense::with_fix(OffenseKind::GsubVsTr, send.expression_l.begin, fix) + let fix = Fix::single(sel_l.start_offset(), sel_l.end_offset(), "tr"); + Offense::with_fix(OffenseKind::GsubVsTr, call.location().start_offset(), fix) } - None => Offense::new(OffenseKind::GsubVsTr, send.expression_l.begin), + None => Offense::new(OffenseKind::GsubVsTr, call.location().start_offset()), }; offenses.push(offense); } } /// `.sort { |a, b| ... }` → `.sort_by` (only fires when sort has a block) -fn check_sort_vs_sort_by(send: &Send, offenses: &mut Vec) { - if send.method_name == "sort" { +fn check_sort_vs_sort_by(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() == b"sort" { offenses.push(Offense::new( OffenseKind::SortVsSortBy, - send.expression_l.begin, + call.location().start_offset(), )); } } /// `.fetch(k, v)` → `.fetch(k) { v }` -fn check_fetch_with_argument(send: &Send, offenses: &mut Vec) { - if send.method_name == "fetch" - && arg_count_without_block_pass(&send.args) == 2 - && !has_block_pass(&send.args) - { +fn check_fetch_with_argument(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() == b"fetch" && arg_count(call) == 2 && !has_block_pass(call) { offenses.push(Offense::new( OffenseKind::FetchWithArgumentVsBlock, - send.expression_l.begin, + call.location().start_offset(), )); } } /// `.merge!({k: v})` → `h[k] = v` (single pair hash argument) -fn check_hash_merge_bang(send: &Send, offenses: &mut Vec) { - if send.method_name != "merge!" 
|| send.args.len() != 1 { +fn check_hash_merge_bang(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"merge!" { return; } - if first_arg_is_single_pair_hash(&send.args) { + let args = call_args(call); + if args.len() != 1 { + return; + } + if first_arg_is_single_pair_hash(&args) { offenses.push(Offense::new( OffenseKind::HashMergeBangVsHashBrackets, - send.expression_l.begin, + call.location().start_offset(), )); } } /// `.module_eval("def ...")` → `define_method` -fn check_module_eval_send(send: &Send, offenses: &mut Vec) { - if send.method_name != "module_eval" { +fn check_module_eval_call(call: &ruby_prism::CallNode<'_>, offenses: &mut Vec) { + if call.name().as_slice() != b"module_eval" { return; } - if let Some(first_arg) = send.args.first() + let args = call_args(call); + if let Some(first_arg) = args.first() && str_contains_def(first_arg) { offenses.push(Offense::new( OffenseKind::ModuleEval, - send.expression_l.begin, + call.location().start_offset(), )); } } /// `.map { |x| x.foo }` → `.map(&:foo)` -fn check_block_vs_symbol_to_proc(send: &Send, block: &Block, offenses: &mut Vec) { - // Must not be a lambda literal - if matches!(block.call.as_ref(), Node::Lambda(_)) { - return; - } - - // Outer method call must have 0 non-block-pass arguments - if arg_count_without_block_pass(&send.args) != 0 { +fn check_block_vs_symbol_to_proc( + call: &ruby_prism::CallNode<'_>, + block: &ruby_prism::BlockNode<'_>, + offenses: &mut Vec, +) { + // Outer method call must have 0 arguments + if arg_count(call) != 0 { return; } // Block must have exactly 1 argument - let arg_names = block_arg_names(&block.args); + let arg_names = block_arg_names(&block.parameters()); if arg_names.len() != 1 { return; } let block_arg_name = &arg_names[0]; - // Block body must be a single Send node - let body = match block.body.as_deref() { + // Block body must be a single CallNode + let body = match block.body() { Some(node) => node, None => return, }; - let 
inner_send = match body { - Node::Send(s) => s, - _ => return, + // If body is a StatementsNode with a single statement, unwrap it + let inner_node = if let Some(stmts) = body.as_statements_node() { + let body_nodes: Vec<_> = stmts.body().iter().collect(); + if body_nodes.len() != 1 { + return; + } + body_nodes.into_iter().next().unwrap() + } else { + body + }; + + let inner_call = match inner_node.as_call_node() { + Some(c) => c, + None => return, }; // Inner call must have 0 arguments and no block - if !inner_send.args.is_empty() { + if arg_count(&inner_call) != 0 || inner_call.block().is_some() { return; } // Inner call must have a receiver - let receiver = match inner_send.recv.as_deref() { + let receiver = match inner_call.receiver() { Some(r) => r, None => return, }; // Receiver must not be a primitive - if is_primitive(receiver) { + if is_primitive(&receiver) { return; } - // Receiver must be an Lvar matching the block argument name - if let Node::Lvar(lv) = receiver - && lv.name == *block_arg_name + // Receiver must be a LocalVariableReadNode matching the block argument name + if let Some(lv) = receiver.as_local_variable_read_node() + && String::from_utf8_lossy(lv.name().as_slice()) == *block_arg_name { offenses.push(Offense::new( OffenseKind::BlockVsSymbolToProc, - send.expression_l.begin, + call.location().start_offset(), )); } } @@ -416,52 +456,63 @@ fn check_block_vs_symbol_to_proc(send: &Send, block: &Block, offenses: &mut Vec< #[cfg(test)] mod tests { use super::*; - - use crate::ast_visitor::node_children; + use crate::ast_visitor::for_each_direct_child; + use ruby_prism::Node; fn parse_and_collect(source: &[u8]) -> Vec { - let result = lib_ruby_parser::Parser::new(source.to_vec(), Default::default()).do_parse(); + let result = ruby_prism::parse(source); + let result = Box::leak(Box::new(result)); let mut offenses = Vec::new(); - if let Some(ast) = result.ast { - walk_for_offenses(&ast, &mut offenses); - } + let root = result.node(); + 
walk_for_offenses(&root, &mut offenses); offenses } - /// Walk AST matching real analyzer behavior: Block's inner Send is NOT - /// visited by scan_send (only scan_block handles it). - fn walk_for_offenses(node: &Node, offenses: &mut Vec) { + /// Walk AST matching real analyzer behavior. + fn walk_for_offenses<'pr>(node: &Node<'pr>, offenses: &mut Vec) { match node { - Node::Send(s) => { - if let Some(Node::Block(recv_block)) = s.recv.as_deref() { - offenses.extend(scan_send_on_block(s, recv_block)); - } - offenses.extend(scan_send(s)); - for child in node_children(node) { - walk_for_offenses(child, offenses); + Node::CallNode { .. } => { + let call = node.as_call_node().unwrap(); + + // Check receiver-is-block-call chains + if let Some(recv) = call.receiver() { + if let Some(recv_call) = recv.as_call_node() { + if let Some(Node::BlockNode { .. }) = recv_call.block() { + offenses.extend(scan_call_on_block_call(&call, &recv_call)); + } + } } - } - Node::Block(b) => { - offenses.extend(scan_block(b)); - if let Node::Send(s) = b.call.as_ref() { - if let Some(recv) = &s.recv { - walk_for_offenses(recv, offenses); + + match call.block() { + Some(Node::BlockNode { .. 
}) => { + let block = call.block().unwrap().as_block_node().unwrap(); + offenses.extend(scan_call_with_block(&call, &block)); + // Walk receiver and arguments + if let Some(recv) = call.receiver() { + walk_for_offenses(&recv, offenses); + } + if let Some(args) = call.arguments() { + for arg in args.arguments().iter() { + walk_for_offenses(&arg, offenses); + } + } + // Walk block body + if let Some(body) = block.body() { + walk_for_offenses(&body, offenses); + } } - for arg in &s.args { - walk_for_offenses(arg, offenses); + _ => { + offenses.extend(scan_call(&call)); + for_each_direct_child(node, &mut |child| { + walk_for_offenses(child, offenses); + }); } } - if let Some(args) = &b.args { - walk_for_offenses(args, offenses); - } - if let Some(body) = &b.body { - walk_for_offenses(body, offenses); - } } _ => { - for child in node_children(node) { + for_each_direct_child(node, &mut |child| { walk_for_offenses(child, offenses); - } + }); } } } @@ -622,8 +673,6 @@ mod tests { assert!(!o.iter().any(|x| x.kind == OffenseKind::BlockVsSymbolToProc)); } - // --- Additional edge case tests --- - #[test] fn first_not_on_shuffle_no_fire() { let o = parse_and_collect(b"arr.first"); @@ -710,7 +759,6 @@ mod tests { #[test] fn block_no_body_no_symbol_to_proc() { - // Empty block body let o = parse_and_collect(b"arr.map { |x| }"); assert!(!o.iter().any(|x| x.kind == OffenseKind::BlockVsSymbolToProc)); } @@ -771,7 +819,6 @@ mod tests { #[test] fn include_on_parenthesized_range() { - // Single-paren range — parsed as Begin(Irange) let o = parse_and_collect(b"(1..10).include?(5)"); assert!( o.iter() @@ -805,7 +852,6 @@ mod tests { #[test] fn keys_each_with_keys_having_args_no_fire() { - // keys("x") is not Hash#keys — should not fire let o = parse_and_collect(b"h.keys(\"x\").each { |k| k }"); assert!(!o.iter().any(|x| x.kind == OffenseKind::KeysEachVsEachKey)); } diff --git a/src/scanner/method_definition_scanner.rs b/src/scanner/method_definition_scanner.rs index d7254c8..f7c0d09 
100644 --- a/src/scanner/method_definition_scanner.rs +++ b/src/scanner/method_definition_scanner.rs @@ -1,13 +1,13 @@ -use lib_ruby_parser::Node; -use lib_ruby_parser::nodes::Def; +use ruby_prism::Node; use crate::ast_helpers::{ body_expressions, def_block_arg_name, def_first_arg_name, def_regular_arg_count, }; +use crate::ast_visitor::for_each_descendant; use crate::offense::{Offense, OffenseKind}; /// Scan a method definition for proc_call, getter, and setter offenses. -pub fn scan(def: &Def) -> Vec { +pub fn scan(def: &ruby_prism::DefNode<'_>) -> Vec { let mut offenses = Vec::new(); check_proc_call_vs_yield(def, &mut offenses); @@ -18,48 +18,62 @@ pub fn scan(def: &Def) -> Vec { } /// `def foo(&block); block.call; end` → use `yield` instead. -fn check_proc_call_vs_yield(def: &Def, offenses: &mut Vec) { +fn check_proc_call_vs_yield(def: &ruby_prism::DefNode<'_>, offenses: &mut Vec) { let block_name = match def_block_arg_name(def) { Some(name) => name, None => return, }; - if body_contains_block_call(&def.body, &block_name) { + let body = def.body(); + if body_contains_block_call(&body, &block_name) { offenses.push(Offense::new( OffenseKind::ProcCallVsYield, - def.keyword_l.begin, + def.def_keyword_loc().start_offset(), )); } } -fn body_contains_block_call(body: &Option>, block_name: &str) -> bool { - match body.as_deref() { +fn body_contains_block_call(body: &Option>, block_name: &str) -> bool { + match body { Some(node) => node_contains_block_call(node, block_name), None => false, } } -fn node_contains_block_call(node: &Node, block_name: &str) -> bool { - if let Node::Send(s) = node - && s.method_name == "call" - && let Some(Node::Lvar(lv)) = s.recv.as_deref() - && lv.name == block_name +fn node_contains_block_call(node: &Node<'_>, block_name: &str) -> bool { + if let Some(call) = node.as_call_node() + && call.name().as_slice() == b"call" + && let Some(recv) = call.receiver() + && let Some(lv) = recv.as_local_variable_read_node() + && 
String::from_utf8_lossy(lv.name().as_slice()) == block_name { return true; } let mut found = false; - crate::ast_visitor::for_each_child(node, |child| { - if !found && node_contains_block_call(child, block_name) { + for_each_descendant(node, &mut |child| { + if !found && node_is_block_call(child, block_name) { found = true; } }); found } +fn node_is_block_call(node: &Node<'_>, block_name: &str) -> bool { + if let Some(call) = node.as_call_node() + && call.name().as_slice() == b"call" + && let Some(recv) = call.receiver() + && let Some(lv) = recv.as_local_variable_read_node() + { + return String::from_utf8_lossy(lv.name().as_slice()) == block_name; + } + false +} + /// `def name; @name; end` → use `attr_reader :name`. -fn check_getter_vs_attr_reader(def: &Def, offenses: &mut Vec) { +fn check_getter_vs_attr_reader(def: &ruby_prism::DefNode<'_>, offenses: &mut Vec) { + let def_name = String::from_utf8_lossy(def.name().as_slice()).to_string(); // Must not be a setter (name ends with =) - if def.name.ends_with('=') { + if def_name.ends_with('=') { return; } // Must have 0 arguments @@ -67,26 +81,29 @@ fn check_getter_vs_attr_reader(def: &Def, offenses: &mut Vec) { return; } // Body must be a single ivar read matching @ - let exprs = body_expressions(&def.body); + let body = def.body(); + let exprs = body_expressions(&body); if exprs.len() != 1 { return; } - if let Node::Ivar(iv) = exprs[0] { - let expected_ivar = format!("@{}", def.name); - if iv.name == expected_ivar { + if let Some(iv) = exprs[0].as_instance_variable_read_node() { + let ivar_name = String::from_utf8_lossy(iv.name().as_slice()).to_string(); + let expected_ivar = format!("@{}", def_name); + if ivar_name == expected_ivar { offenses.push(Offense::new( OffenseKind::GetterVsAttrReader, - def.keyword_l.begin, + def.def_keyword_loc().start_offset(), )); } } } /// `def name=(value); @name = value; end` → use `attr_writer :name`. 
-fn check_setter_vs_attr_writer(def: &Def, offenses: &mut Vec) { +fn check_setter_vs_attr_writer(def: &ruby_prism::DefNode<'_>, offenses: &mut Vec) { + let def_name = String::from_utf8_lossy(def.name().as_slice()).to_string(); // Must be a setter - let base_name = match def.name.strip_suffix('=') { - Some(n) => n, + let base_name = match def_name.strip_suffix('=') { + Some(n) => n.to_string(), None => return, }; // Must have exactly 1 regular argument @@ -98,22 +115,24 @@ fn check_setter_vs_attr_writer(def: &Def, offenses: &mut Vec) { None => return, }; // Body must be a single ivar assignment - let exprs = body_expressions(&def.body); + let body = def.body(); + let exprs = body_expressions(&body); if exprs.len() != 1 { return; } - if let Node::Ivasgn(ia) = exprs[0] { + if let Some(ia) = exprs[0].as_instance_variable_write_node() { + let ivar_name = String::from_utf8_lossy(ia.name().as_slice()).to_string(); let expected_ivar = format!("@{}", base_name); - if ia.name != expected_ivar { + if ivar_name != expected_ivar { return; } // The assigned value must be the argument - if let Some(Node::Lvar(lv)) = ia.value.as_deref() - && lv.name == arg_name + if let Some(lv) = ia.value().as_local_variable_read_node() + && String::from_utf8_lossy(lv.name().as_slice()) == arg_name { offenses.push(Offense::new( OffenseKind::SetterVsAttrWriter, - def.keyword_l.begin, + def.def_keyword_loc().start_offset(), )); } } @@ -122,25 +141,23 @@ fn check_setter_vs_attr_writer(def: &Def, offenses: &mut Vec) { #[cfg(test)] mod tests { use super::*; - - use crate::ast_visitor::node_children; + use crate::ast_visitor::for_each_direct_child; fn parse_and_scan(source: &[u8]) -> Vec { - let result = lib_ruby_parser::Parser::new(source.to_vec(), Default::default()).do_parse(); + let result = ruby_prism::parse(source); + let result = Box::leak(Box::new(result)); let mut offenses = Vec::new(); - if let Some(ast) = result.ast { - collect_def_offenses(&ast, &mut offenses); - } + 
collect_def_offenses(&result.node(), &mut offenses); offenses } - fn collect_def_offenses(node: &Node, offenses: &mut Vec) { - if let Node::Def(d) = node { - offenses.extend(scan(d)); + fn collect_def_offenses<'pr>(node: &Node<'pr>, offenses: &mut Vec) { + if let Some(d) = node.as_def_node() { + offenses.extend(scan(&d)); } - for child in node_children(node) { + for_each_direct_child(node, &mut |child| { collect_def_offenses(child, offenses); - } + }); } #[test] @@ -285,7 +302,6 @@ mod tests { #[test] fn setter_name_method_is_not_getter() { - // name= should not trigger getter check let offenses = parse_and_scan(b"def name=(v); @name = v; end"); assert!( !offenses diff --git a/src/scanner/rescue_scanner.rs b/src/scanner/rescue_scanner.rs index fddc557..06d0596 100644 --- a/src/scanner/rescue_scanner.rs +++ b/src/scanner/rescue_scanner.rs @@ -1,60 +1,64 @@ -use lib_ruby_parser::Node; -use lib_ruby_parser::nodes::RescueBody; +use ruby_prism::Node; use crate::offense::{Offense, OffenseKind}; /// Fires when a rescue clause catches `NoMethodError`. 
-pub fn scan(node: &RescueBody) -> Vec { +pub fn scan(node: &ruby_prism::RescueNode<'_>) -> Vec { if rescues_no_method_error(node) { vec![Offense::new( OffenseKind::RescueVsRespondTo, - node.keyword_l.begin, + node.keyword_loc().start_offset(), )] } else { vec![] } } -fn rescues_no_method_error(rb: &RescueBody) -> bool { - let exc_list = match rb.exc_list.as_deref() { - Some(node) => node, - None => return false, - }; - - match exc_list { - Node::Array(arr) => arr.elements.iter().any(is_no_method_error_const), - node => is_no_method_error_const(node), +fn rescues_no_method_error(rb: &ruby_prism::RescueNode<'_>) -> bool { + let exceptions: Vec> = rb.exceptions().iter().collect(); + if exceptions.is_empty() { + return false; } + exceptions.iter().any(|exc| is_no_method_error_const(exc)) } -fn is_no_method_error_const(node: &Node) -> bool { - match node { - Node::Const(c) => c.name == "NoMethodError" && c.scope.is_none(), - _ => false, +fn is_no_method_error_const(node: &Node<'_>) -> bool { + if let Some(c) = node.as_constant_read_node() { + c.name().as_slice() == b"NoMethodError" + } else { + false } } #[cfg(test)] mod tests { use super::*; - - use crate::ast_visitor::node_children; + use crate::ast_visitor::for_each_direct_child; fn parse_and_find_rescue_bodies(source: &[u8]) -> Vec { - let result = lib_ruby_parser::Parser::new(source.to_vec(), Default::default()).do_parse(); + let result = ruby_prism::parse(source); + let result = Box::leak(Box::new(result)); let mut offenses = Vec::new(); - if let Some(ast) = result.ast { - collect_rescue_offenses(&ast, &mut offenses); - } + collect_rescue_offenses(&result.node(), &mut offenses); offenses } - fn collect_rescue_offenses(node: &Node, offenses: &mut Vec) { - if let Node::RescueBody(rb) = node { - offenses.extend(scan(rb)); + fn collect_rescue_offenses<'pr>(node: &Node<'pr>, offenses: &mut Vec) { + // For BeginNode, we need to access the rescue clause specially + if let Some(begin) = node.as_begin_node() { + if let 
Some(rescue) = begin.rescue_clause() { + collect_from_rescue_chain(&rescue, offenses); + } } - for child in node_children(node) { + for_each_direct_child(node, &mut |child| { collect_rescue_offenses(child, offenses); + }); + } + + fn collect_from_rescue_chain(rescue: &ruby_prism::RescueNode<'_>, offenses: &mut Vec) { + offenses.extend(scan(rescue)); + if let Some(subsequent) = rescue.subsequent() { + collect_from_rescue_chain(&subsequent, offenses); } } From e612e820872990ca9eaa11e71d7b8ee3ac796877 Mon Sep 17 00:00:00 2001 From: Zac <579103+7a6163@users.noreply.github.com> Date: Mon, 16 Mar 2026 16:52:39 +0800 Subject: [PATCH 3/4] chore: bump version to 1.3.0 --- CHANGELOG.md | 14 ++++++++++++++ Cargo.lock | 2 +- Cargo.toml | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b1c694c..259d490 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +## [1.3.0] - 2026-03-16 + +### Changed + +- Migrate parser from `lib-ruby-parser` (Ruby 3.1.2) to `ruby-prism` (Ruby 3.3+) + - Support Ruby 3.2, 3.3, 3.4+ syntax + - Native handling of all encodings (ASCII, US-ASCII, etc.) 
— no custom decoder needed + - Error-tolerant parsing: always produces an AST, even with syntax errors + - Prism is Ruby's official default parser since 3.3 + +### Removed + +- `lib-ruby-parser` dependency replaced by `ruby-prism` + ## [1.2.5] - 2026-03-16 ### Fixed diff --git a/Cargo.lock b/Cargo.lock index ff1ad06..a7e1ef3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -502,7 +502,7 @@ dependencies = [ [[package]] name = "rubyfast" -version = "1.2.5" +version = "1.3.0" dependencies = [ "anyhow", "clap", diff --git a/Cargo.toml b/Cargo.toml index bbb67b0..7047db5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rubyfast" -version = "1.2.5" +version = "1.3.0" edition = "2024" description = "An ultra-fast Ruby performance linter rewritten in Rust — detects 19 common anti-patterns" license = "MIT" From 2fe461408d2043780a83920441a3f21c1f63f83b Mon Sep 17 00:00:00 2001 From: Zac <579103+7a6163@users.noreply.github.com> Date: Mon, 16 Mar 2026 17:12:12 +0800 Subject: [PATCH 4/4] fix: address code review findings from prism migration - Visit DefNode parameters (default values can contain scannable code) - Remove unsafe ptr::read; replace body_expressions with safe body_expression_count/body_first_expression - Remove duplicate visit_rescue_chain_children (identical to visit_rescue_children) - Remove examples/prism_spike.rs (exploratory spike, not needed in release) --- examples/prism_spike.rs | 231 ----------------------- src/ast_helpers.rs | 75 +++----- src/ast_visitor.rs | 29 +-- src/scanner/method_definition_scanner.rs | 21 ++- 4 files changed, 47 insertions(+), 309 deletions(-) delete mode 100644 examples/prism_spike.rs diff --git a/examples/prism_spike.rs b/examples/prism_spike.rs deleted file mode 100644 index 79f3fdb..0000000 --- a/examples/prism_spike.rs +++ /dev/null @@ -1,231 +0,0 @@ -//! Spike: Verify ruby-prism API for the migration from lib-ruby-parser. -//! -//! 
Run with: cargo run --example prism_spike - -use ruby_prism::*; - -fn main() { - println!("=== 1. Block/Call structure ==="); - spike_block_call(); - - println!("\n=== 2. Send on block (chained call) ==="); - spike_send_on_block(); - - println!("\n=== 3. For loop locations ==="); - spike_for_loop(); - - println!("\n=== 4. Method definition ==="); - spike_def(); - - println!("\n=== 5. Rescue body ==="); - spike_rescue(); - - println!("\n=== 6. Range types ==="); - spike_range(); - - println!("\n=== 7. String / Integer values ==="); - spike_values(); - - println!("\n=== 8. Comments ==="); - spike_comments(); - - println!("\n=== 9. Error tolerance ==="); - spike_errors(); - - println!("\n=== 10. ASCII encoding ==="); - spike_encoding(); - - println!("\n=== 11. Block pass (symbol to proc) ==="); - spike_block_pass(); - - println!("\n=== 12. Child nodes / visitor ==="); - spike_visitor(); - - println!("\n=== 13. ERB template (expected error) ==="); - spike_erb(); - - println!("\n=== 14. Node enum pattern matching ==="); - spike_pattern_match(); -} - -fn spike_block_call() { - // arr.map { |x| x.to_s } - let source = b"arr.map { |x| x.to_s }"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - println!("Root: {:#?}", result.node()); -} - -fn spike_send_on_block() { - // arr.select { |x| x > 1 }.first - let source = b"arr.select { |x| x > 1 }.first"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - println!("Root: {:#?}", result.node()); -} - -fn spike_for_loop() { - let source = b"for x in [1, 2, 3]\n puts x\nend"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - println!("Root: {:#?}", result.node()); -} - -fn spike_def() { - let source = b"def foo(a, &block); block.call; end"; - let result = parse(source); - print_errors(&result); - println!("Source: 
{:?}", std::str::from_utf8(source).unwrap()); - println!("Root: {:#?}", result.node()); -} - -fn spike_rescue() { - let source = b"begin; x; rescue NoMethodError => e; retry; end"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - println!("Root: {:#?}", result.node()); -} - -fn spike_range() { - let source = b"(1..10).include?(5); (1...10).cover?(5)"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - println!("Root: {:#?}", result.node()); -} - -fn spike_values() { - let source = b"'x'; 42; 1; :sym"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - println!("Root: {:#?}", result.node()); -} - -fn spike_comments() { - let source = b"x = 1 # rubyfast:disable shuffle_first_vs_sample\ny = 2\n"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - for comment in result.comments() { - let loc = comment.location(); - println!( - " Comment at {}..{}: {:?}", - loc.start_offset(), - loc.end_offset(), - std::str::from_utf8(loc.as_slice()).unwrap_or("") - ); - } -} - -fn spike_errors() { - let source = b"def foo; end; def def; end"; - let result = parse(source); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - let error_count: usize = result.errors().count(); - println!("Errors: {}", error_count); - for err in result.errors() { - println!(" Error: {:?}", err.message()); - } - println!("Has AST: true (prism always produces one)"); - println!("Root: {:#?}", result.node()); -} - -fn spike_encoding() { - let source = b"# encoding: us-ascii\nx = 1\n"; - let result = parse(source); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - let error_count: usize = result.errors().count(); - println!("Errors: {}", error_count); - print_errors(&result); - println!("Root: 
{:#?}", result.node()); -} - -fn spike_block_pass() { - let source = b"arr.map(&:to_s)"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - println!("Root: {:#?}", result.node()); -} - -fn spike_visitor() { - let source = b"arr.select { |x| x > 1 }.first"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - - // Use the Visit trait - struct CallCounter { - count: usize, - } - impl<'pr> Visit<'pr> for CallCounter { - fn visit_call_node(&mut self, node: &CallNode<'pr>) { - println!( - " Found CallNode: name={:?}, has_receiver={}, has_block={}", - std::str::from_utf8(node.name().as_slice()).unwrap_or("?"), - node.receiver().is_some(), - node.block().is_some() - ); - self.count += 1; - // Must call the default visitor to recurse into children - visit_call_node(self, node); - } - } - - let mut counter = CallCounter { count: 0 }; - counter.visit(&result.node()); - println!(" Total CallNodes found: {}", counter.count); -} - -fn spike_erb() { - let source = b"class Foo < ActiveRecord::Migration<%= migration_version %>\nend"; - let result = parse(source); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - let error_count: usize = result.errors().count(); - println!("Errors: {}", error_count); - for err in result.errors() { - println!(" Error: {:?}", err.message()); - } -} - -fn spike_pattern_match() { - let source = b"arr.map { |x| x.to_s }"; - let result = parse(source); - print_errors(&result); - println!("Source: {:?}", std::str::from_utf8(source).unwrap()); - - // Walk the top-level statements - let program = result.node(); - if let Node::ProgramNode { .. } = &program { - let prog = program.as_program_node().unwrap(); - let stmts = prog.statements(); - for node in stmts.body().iter() { - println!(" Top-level node variant:"); - match &node { - Node::CallNode { .. 
} => { - let call = node.as_call_node().unwrap(); - println!( - " CallNode: name={:?}, has_block={}", - std::str::from_utf8(call.name().as_slice()).unwrap_or("?"), - call.block().is_some() - ); - if let Some(block) = call.block() { - println!(" Block: {:#?}", block); - } - } - other => println!(" {:?}", other), - } - } - } -} - -fn print_errors(result: &ParseResult) { - for err in result.errors() { - println!(" PARSE ERROR: {:?}", err.message()); - } -} diff --git a/src/ast_helpers.rs b/src/ast_helpers.rs index 4f02912..b4a74ad 100644 --- a/src/ast_helpers.rs +++ b/src/ast_helpers.rs @@ -200,58 +200,32 @@ pub fn str_contains_def(node: &Node<'_>) -> bool { false } -/// Get the number of top-level expressions in a body node, and optionally the single expression. -/// Returns (count, Option<Node>) — the option is Some only when count == 1. -pub fn body_single_expression<'pr>(body: &Option<Node<'pr>>) -> (usize, Option<Node<'pr>>) { +/// Count the number of top-level expressions in a body node. +pub fn body_expression_count(body: &Option<Node<'_>>) -> usize { match body { - None => (0, None), + None => 0, Some(node) => { if let Some(stmts) = node.as_statements_node() { - let body_nodes: Vec<_> = stmts.body().iter().collect(); - let count = body_nodes.len(); - if count == 1 { - (1, Some(body_nodes.into_iter().next().unwrap())) - } else { - (count, None) - } + stmts.body().iter().count() } else { - // Single expression body (no StatementsNode wrapper). - // We need to return the node itself. Since prism's Node is just - // a thin wrapper around pointers, we can reconstruct it from the body. - // Re-access from the parent to get an owned Node. - (1, body.as_ref().map(|n| reconstruct_node_from_body(n))) + 1 } } } } -/// Helper: given a reference to a Node, produce a new owned Node with the same data. -/// This works because ruby_prism Node variants are just (parser, pointer, marker) tuples -/// and the data is borrowed from the parse result, not owned. 
-fn reconstruct_node_from_body<'pr>(node: &Node<'pr>) -> Node<'pr> { - // The body_expressions approach doesn't work because Node isn't Clone. - // Instead, callers should re-call body() to get a fresh owned Node. - // This function exists as a workaround: since all callers already have - // the def/body, they can just re-call .body() to get an owned Node. - // - // For now, we'll use unsafe to transmute since the Node is just pointers. - // Safety: Node<'pr> is a repr(C)-like enum of (parser, pointer, marker) - // and the lifetime is tied to the ParseResult. Copying the pointer data is safe - // as long as the ParseResult outlives the copy. - unsafe { std::ptr::read(node as *const Node<'pr>) } -} - -/// Get expressions from a body node. If it's a StatementsNode, return its body items. -/// Otherwise return a single-element vec. -/// IMPORTANT: Caller must ensure `body` outlives the returned Vec. -pub fn body_expressions<'pr>(body: &Option<Node<'pr>>) -> Vec<Node<'pr>> { +/// Get the first expression from a body node (if any). +/// The body must be re-obtained from the parent to produce an owned Node. +pub fn body_first_expression<'pr>(body: &Option<Node<'pr>>) -> Option<Node<'pr>> { match body { - None => vec![], + None => None, Some(node) => { if let Some(stmts) = node.as_statements_node() { - stmts.body().iter().collect() + stmts.body().iter().next() } else { - vec![reconstruct_node_from_body(node)] + // Single expression body — caller should re-call parent.body() + // to get an owned Node. We can't clone the reference. 
+ None } } } @@ -520,24 +494,31 @@ mod tests { } #[test] - fn body_expressions_none() { - assert!(body_expressions(&None).is_empty()); + fn body_expression_count_none() { + assert_eq!(body_expression_count(&None), 0); } #[test] - fn body_expressions_single() { + fn body_expression_count_single() { let node = parse_first_stmt(b"def foo; 42; end"); let def = node.as_def_node().unwrap(); - let exprs = body_expressions(&def.body()); - assert_eq!(exprs.len(), 1); + assert_eq!(body_expression_count(&def.body()), 1); } #[test] - fn body_expressions_begin() { + fn body_expression_count_multiple() { let node = parse_first_stmt(b"def foo; 1; 2; 3; end"); let def = node.as_def_node().unwrap(); - let exprs = body_expressions(&def.body()); - assert_eq!(exprs.len(), 3); + assert_eq!(body_expression_count(&def.body()), 3); + } + + #[test] + fn body_first_expression_works() { + let node = parse_first_stmt(b"def foo; 42; end"); + let def = node.as_def_node().unwrap(); + let first = body_first_expression(&def.body()); + assert!(first.is_some()); + assert!(first.unwrap().as_integer_node().is_some()); } #[test] diff --git a/src/ast_visitor.rs b/src/ast_visitor.rs index 7011e94..1b1eaf4 100644 --- a/src/ast_visitor.rs +++ b/src/ast_visitor.rs @@ -59,7 +59,9 @@ pub fn for_each_direct_child<'pr>(node: &Node<'pr>, f: &mut impl FnMut(&Node<'pr } Node::DefNode { .. 
} => { let n = node.as_def_node().unwrap(); - // Skip parameters - they don't contain scannable code + if let Some(params) = n.parameters() { + f(&params.as_node()); + } if let Some(body) = n.body() { f(&body); } @@ -82,7 +84,7 @@ pub fn for_each_direct_child<'pr>(node: &Node<'pr>, f: &mut impl FnMut(&Node<'pr } } if let Some(rescue) = n.rescue_clause() { - visit_rescue_chain_children(&rescue, f); + visit_rescue_children(&rescue, f); } if let Some(else_clause) = n.else_clause() && let Some(stmts) = else_clause.statements() @@ -569,7 +571,7 @@ pub fn for_each_direct_child<'pr>(node: &Node<'pr>, f: &mut impl FnMut(&Node<'pr } } -/// Visit children of a RescueNode. +/// Visit children of a RescueNode and its chain of subsequent clauses. fn visit_rescue_children<'pr>( rescue: &ruby_prism::RescueNode<'pr>, f: &mut impl FnMut(&Node<'pr>), @@ -590,27 +592,6 @@ fn visit_rescue_children<'pr>( } } -/// Visit all descendants through a rescue chain (rescue -> subsequent -> ...). -fn visit_rescue_chain_children<'pr>( - rescue: &ruby_prism::RescueNode<'pr>, - f: &mut impl FnMut(&Node<'pr>), -) { - for exc in rescue.exceptions().iter() { - f(&exc); - } - if let Some(reference) = rescue.reference() { - f(&reference); - } - if let Some(stmts) = rescue.statements() { - for child in stmts.body().iter() { - f(&child); - } - } - if let Some(subsequent) = rescue.subsequent() { - visit_rescue_chain_children(&subsequent, f); - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/scanner/method_definition_scanner.rs b/src/scanner/method_definition_scanner.rs index f7c0d09..e06162c 100644 --- a/src/scanner/method_definition_scanner.rs +++ b/src/scanner/method_definition_scanner.rs @@ -1,7 +1,8 @@ use ruby_prism::Node; use crate::ast_helpers::{ - body_expressions, def_block_arg_name, def_first_arg_name, def_regular_arg_count, + body_expression_count, body_first_expression, def_block_arg_name, def_first_arg_name, + def_regular_arg_count, }; use crate::ast_visitor::for_each_descendant; 
use crate::offense::{Offense, OffenseKind}; @@ -82,11 +83,14 @@ fn check_getter_vs_attr_reader(def: &ruby_prism::DefNode<'_>, offenses: &mut Vec } // Body must be a single ivar read matching @ let body = def.body(); - let exprs = body_expressions(&body); - if exprs.len() != 1 { + if body_expression_count(&body) != 1 { return; } - if let Some(iv) = exprs[0].as_instance_variable_read_node() { + let single = body_first_expression(&body).or(body); + if let Some(iv) = single + .as_ref() + .and_then(|n| n.as_instance_variable_read_node()) + { let ivar_name = String::from_utf8_lossy(iv.name().as_slice()).to_string(); let expected_ivar = format!("@{}", def_name); if ivar_name == expected_ivar { @@ -116,11 +120,14 @@ fn check_setter_vs_attr_writer(def: &ruby_prism::DefNode<'_>, offenses: &mut Vec }; // Body must be a single ivar assignment let body = def.body(); - let exprs = body_expressions(&body); - if exprs.len() != 1 { + if body_expression_count(&body) != 1 { return; } - if let Some(ia) = exprs[0].as_instance_variable_write_node() { + let single = body_first_expression(&body).or(body); + if let Some(ia) = single + .as_ref() + .and_then(|n| n.as_instance_variable_write_node()) + { let ivar_name = String::from_utf8_lossy(ia.name().as_slice()).to_string(); let expected_ivar = format!("@{}", base_name); if ivar_name != expected_ivar {