diff --git a/ndc_analyser/src/analyser.rs b/ndc_analyser/src/analyser.rs index 2dae6a01..60a5d380 100644 --- a/ndc_analyser/src/analyser.rs +++ b/ndc_analyser/src/analyser.rs @@ -1,11 +1,14 @@ use std::collections::HashMap; use std::fmt::Debug; -use crate::scope::ScopeTree; -use itertools::Itertools; +use crate::scope::{ScopeTree, TypeBinding}; +use itertools::{Itertools, izip}; use ndc_core::{StaticType, TypeSignature}; use ndc_lexer::Span; -use ndc_parser::{Binding, Expression, ExpressionLocation, ForBody, ForIteration, Lvalue, NodeId}; +use ndc_parser::{ + Binding, Expression, ExpressionLocation, ForBody, ForIteration, FunctionParameter, Lvalue, + NodeId, +}; /// Side table holding semantic information keyed by AST node identity. /// Keeps tooling-specific data (like per-expression types) out of the AST. @@ -13,6 +16,9 @@ use ndc_parser::{Binding, Expression, ExpressionLocation, ForBody, ForIteration, pub struct AnalysisResult { /// Maps each expression node to its inferred result type. pub expr_types: HashMap, + /// Inferred return types for functions without explicit annotations. + /// Keyed by the FunctionDeclaration's `NodeId`. + pub inferred_return_types: HashMap, /// Errors accumulated during analysis. Non-empty when the analyser /// encountered problems but was able to continue with fallback types. pub errors: Vec, @@ -98,7 +104,9 @@ impl Analyser { fn analyse_inner( &mut self, ExpressionLocation { - expression, span, .. + expression, + span, + id, }: &mut ExpressionLocation, ) -> Result { match expression { @@ -142,25 +150,41 @@ impl Analyser { Ok(StaticType::Bool) } Expression::Grouping(expr) => self.analyse(expr), - Expression::VariableDeclaration { l_value, value } => { - let typ = self.analyse_or_any(value); - self.resolve_lvalue_declarative(l_value, typ, *span); + Expression::VariableDeclaration { + l_value, + annotated_type, + value, + } => { + let found_type = self.analyse_or_any(value); + + self.resolve_lvalue_declarative( + l_value, + annotated_type.to_owned(), + found_type.clone(), + *span, + ); Ok(StaticType::unit()) } Expression::Assignment { l_value, r_value } => { let old_type = self.resolve_lvalue_or_any(l_value, *span); let new_type = self.analyse_or_any(r_value); - // Widen the binding's type to the LUB so subsequent uses - // see the broader type. if let Lvalue::Identifier { resolved: Some(target), .. } = l_value { let widened = old_type.lub(&new_type); - if widened != old_type { - self.scope_tree.update_binding_type(*target, widened); + if widened != old_type + && let Err(annotated_type) = + self.scope_tree.update_binding_type(*target, widened) + && !new_type.is_subtype(&annotated_type) + { + self.emit(AnalysisError::mismatched_types( + &new_type, + &annotated_type, + *span, + )); } } @@ -190,42 +214,112 @@ impl Analyser { )); } + // Determine the result type of the operation + let result_type = match resolved_operation { + Binding::Resolved(res) => { + if let StaticType::Function { return_type, .. } = + self.scope_tree.get_type(*res) + { + Some(return_type.as_ref().clone()) + } else { + None + } + } + _ => None, + }; + + if let Some(result_type) = result_type { + match l_value { + // Direct variable: widen or reject if annotated + Lvalue::Identifier { + resolved: Some(target), + .. + } => { + let widened = arg_types[0].lub(&result_type); + if widened != arg_types[0] + && let Err(annotated_type) = + self.scope_tree.update_binding_type(*target, widened) + && !result_type.is_subtype(&annotated_type) + { + self.emit(AnalysisError::mismatched_types( + &result_type, + &annotated_type, + *span, + )); + } + } + // Index into a container: widen the container's type + Lvalue::Index { value, .. } => { + if let Expression::Identifier { + resolved: Binding::Resolved(target), + .. + } = &value.expression + { + let container_type = self.scope_tree.get_type(*target).clone(); + if let Some(elem_type) = container_type.index_element_type() { + let widened_elem = elem_type.lub(&result_type); + if widened_elem != elem_type { + let new_container = + container_type.with_element_type(widened_elem); + let _ = self + .scope_tree + .update_binding_type(*target, new_container); + } + } + } + } + _ => {} + } + } + Ok(StaticType::unit()) } Expression::FunctionDeclaration { name, resolved_name, - type_signature, + parameters, body, return_type: return_type_slot, captures, .. } => { + let type_signature = FunctionParameter::to_type_signature(parameters); + // Pre-register the function before analysing its body so recursive calls can // resolve the name. The return type is unknown at this point so we use Any. - let pre_slot = if let Some(name) = name { - let arity = type_signature.types().map(|t| t.len()); - if self.scope_tree.has_function_in_current_scope(name, arity) { - self.emit(AnalysisError::function_redefinition(name, arity, *span)); - // Skip re-registering but still analyse the body below. - None + let pre_slot = + if let Some(name) = name { + let arity = type_signature.types().map(|t| t.len()); + if self.scope_tree.has_function_in_current_scope(name, arity) { + self.emit(AnalysisError::function_redefinition(name, arity, *span)); + // Skip re-registering but still analyse the body below. + None + } else { + let placeholder = StaticType::Function { + parameters: type_signature.types(), + return_type: Box::new( + return_type_slot.clone().unwrap_or(StaticType::Any), + ), + }; + Some(self.scope_tree.create_local_binding( + name.clone(), + TypeBinding::Inferred(placeholder), + )) + } } else { - let placeholder = StaticType::Function { - parameters: type_signature.types(), - return_type: Box::new(StaticType::Any), - }; - Some( - self.scope_tree - .create_local_binding(name.clone(), placeholder), - ) - } - } else { - None - }; + None + }; self.scope_tree.new_function_scope(); self.return_type_stack.push(None); - let param_types = self.resolve_parameters_declarative(type_signature, *span); + let param_types = self.resolve_parameters_declarative(&type_signature, *span); + + // Fill inferred_type on parameter Lvalues for LSP hints. + for (p, typ) in parameters.iter_mut().zip(¶m_types) { + if let Lvalue::Identifier { inferred_type, .. } = &mut p.lvalue { + *inferred_type = Some(typ.clone()); + } + } let implicit_return = self.analyse_or_any(body); let explicit_return = self.return_type_stack.pop().unwrap(); @@ -233,23 +327,37 @@ impl Analyser { self.scope_tree.destroy_scope(); // Combine explicit `return` types with the block's implicit return type. - let return_type = match explicit_return { + let inferred_return = match explicit_return { Some(ret) => ret.lub(&implicit_return), None => implicit_return, }; - *return_type_slot = Some(return_type); + + // If there is an annotated return type, validate it; + // otherwise record the inferred type in the side table. + if let Some(annotated) = return_type_slot { + if !inferred_return.is_subtype(annotated) { + self.emit(AnalysisError::mismatched_types( + &inferred_return, + annotated, + *span, + )); + } + } else { + self.result + .inferred_return_types + .insert(*id, inferred_return.clone()); + } + + let effective_return = return_type_slot.clone().unwrap_or(inferred_return); let function_type = StaticType::Function { parameters: Some(param_types.clone()), - return_type: Box::new( - return_type_slot - .clone() - .expect("must have a value at this point"), - ), + return_type: Box::new(effective_return), }; if let Some(slot) = pre_slot { - self.scope_tree + let _ = self + .scope_tree .update_binding_type(slot, function_type.clone()); *resolved_name = Some(slot); } @@ -401,10 +509,22 @@ impl Analyser { } Binding::Resolved(res) => self.scope_tree.get_type(*res).clone(), - Binding::Dynamic(_) => StaticType::Function { - parameters: None, - return_type: Box::new(StaticType::Any), - }, + Binding::Dynamic(candidates) => { + let return_type = candidates + .iter() + .map(|c| self.scope_tree.get_type(*c).clone()) + .filter_map(|t| match t { + StaticType::Function { return_type, .. } => Some(*return_type), + _ => None, + }) + .reduce(|a, b| a.lub(&b)) + .unwrap_or(StaticType::Any); + + StaticType::Function { + parameters: None, + return_type: Box::new(return_type), + } + } }; *resolved = binding; @@ -429,13 +549,14 @@ impl Analyser { self.scope_tree.new_iteration_scope(); - self.resolve_lvalue_declarative( - l_value, - sequence_type - .sequence_element_type() - .unwrap_or(StaticType::Any), - span, - ); + let found_type = sequence_type + .sequence_element_type() + .unwrap_or(StaticType::Any); + + // TOOD: get this from the AST when the parser adds it + let expected_type = None; + + self.resolve_lvalue_declarative(l_value, expected_type, found_type, span); do_destroy = true; } ForIteration::Guard(expr) => { @@ -585,33 +706,62 @@ impl Analyser { let mut seen_names: Vec<&str> = Vec::new(); for param in parameters { - types.push(StaticType::Any); + let has_annotation = param.type_name != StaticType::Any; + let binding = if has_annotation { + TypeBinding::Annotated(param.type_name.clone()) + } else { + TypeBinding::Inferred(StaticType::Any) + }; + + types.push(param.type_name.clone()); if seen_names.contains(¶m.name.as_str()) { self.emit(AnalysisError::parameter_redefined(¶m.name, span)); - // Skip duplicate but continue checking remaining params. continue; } seen_names.push(¶m.name); self.scope_tree - .create_local_binding(param.name.clone(), StaticType::Any); + .create_local_binding(param.name.clone(), binding); } types } - fn resolve_lvalue_declarative(&mut self, lvalue: &mut Lvalue, typ: StaticType, span: Span) { + fn resolve_lvalue_declarative( + &mut self, + lvalue: &mut Lvalue, + expected_type: Option, + found_type: StaticType, + span: Span, + ) { match lvalue { Lvalue::Identifier { identifier, resolved, inferred_type, - .. + span, } => { + // If there is a type annotation and the given type is not a subtype of the annotated type we emit an error + if let Some(expected_type) = &expected_type + && !found_type.is_subtype(expected_type) + { + self.emit(AnalysisError::mismatched_types( + &found_type, + expected_type, + *span, + )); + } + + let type_binding = match expected_type { + Some(annotated) => TypeBinding::Annotated(annotated), + None => TypeBinding::Inferred(found_type), + }; + *resolved = Some( self.scope_tree - .create_local_binding(identifier.clone(), typ.clone()), + .create_local_binding(identifier.clone(), type_binding.clone()), ); - *inferred_type = Some(typ); + + *inferred_type = Some(type_binding.typ().clone()) } Lvalue::Index { index, value, .. } => { self.analyse_or_any(index); @@ -623,25 +773,40 @@ impl Analyser { // can happen when a variable is declared with one type (e.g. ()) // and later reassigned to a tuple of a different arity — the // analyser doesn't track reassignment types. + let is_annotated = expected_type.is_some(); + let resolved_type = expected_type.unwrap_or(found_type.clone()); + let sub_types: Box> = - if let StaticType::Tuple(elems) = &typ { + if let StaticType::Tuple(elems) = &resolved_type { if elems.len() != seq.len() { Box::new(std::iter::repeat(&StaticType::Any)) } else { Box::new(elems.iter()) } - } else if let Some(iter) = typ.unpack() { + } else if let Some(iter) = resolved_type.unpack() { iter } else { - self.emit(AnalysisError::unable_to_unpack_type(&typ, span)); + self.emit(AnalysisError::unable_to_unpack_type(&resolved_type, span)); return; }; - for (sub_lvalue, sub_lvalue_type) in seq.iter_mut().zip(sub_types) { + let found_types = found_type + .unpack() + .unwrap_or_else(|| Box::new(std::iter::repeat(&StaticType::Any))); + + for (sub_lvalue, sub_type, found_type) in + izip!(seq.iter_mut(), sub_types, found_types) + { + let sub_expected = if is_annotated { + Some(sub_type.clone()) + } else { + None + }; self.resolve_lvalue_declarative( sub_lvalue, - sub_lvalue_type.clone(), - /* todo: figure out how to narrow this span */ span, + sub_expected, + found_type.clone(), + span, ); } } @@ -678,6 +843,14 @@ impl AnalysisError { pub fn span(&self) -> Span { self.span } + + fn mismatched_types(found: &StaticType, expected: &StaticType, span: Span) -> Self { + Self { + text: format!("mismatched types: found {found} but expected {expected}"), + span, + } + } + fn function_redefinition(name: &str, arity: Option, span: Span) -> Self { let arity_desc = match arity { Some(n) => format!("{n} parameter{}", if n == 1 { "" } else { "s" }), diff --git a/ndc_analyser/src/scope.rs b/ndc_analyser/src/scope.rs index 69107bad..bb6f5a9d 100644 --- a/ndc_analyser/src/scope.rs +++ b/ndc_analyser/src/scope.rs @@ -2,13 +2,37 @@ use ndc_core::StaticType; use ndc_parser::{Binding, CaptureSource, ResolvedVar}; use std::fmt::{Debug, Formatter}; +#[derive(Debug, Clone)] +pub(crate) enum TypeBinding { + Inferred(StaticType), + Annotated(StaticType), +} + +impl TypeBinding { + pub fn typ(&self) -> &StaticType { + match self { + Self::Inferred(t) | Self::Annotated(t) => t, + } + } + + pub fn is_annotated(&self) -> bool { + matches!(self, Self::Annotated(_)) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ScopeBinding { + pub name: String, + pub binding: TypeBinding, +} + #[derive(Debug, Clone)] pub(crate) struct Scope { parent_idx: Option, creates_environment: bool, // Only true for function scopes and for-loop iterations base_offset: usize, function_scope_idx: usize, - identifiers: Vec<(String, StaticType)>, + identifiers: Vec, upvalues: Vec<(String, CaptureSource)>, } @@ -56,7 +80,7 @@ impl Scope { pub(crate) fn find_slot_by_name(&self, find_ident: &str) -> Option { self.identifiers .iter() - .rposition(|(ident, _)| ident == find_ident) + .rposition(|b| b.name == find_ident) .map(|idx| idx + self.base_offset) } @@ -68,8 +92,8 @@ impl Scope { self.identifiers .iter() .enumerate() - .filter_map(|(slot, (ident, typ))| { - if ident == find_ident && typ.could_be_callable() { + .filter_map(|(slot, b)| { + if b.name == find_ident && b.binding.typ().could_be_callable() { Some(slot + self.base_offset) } else { None @@ -82,13 +106,13 @@ impl Scope { self.identifiers.iter() .enumerate() .rev() - .filter_map(|(slot, (ident, typ))| { - if ident != find_ident { + .filter_map(|(slot, b)| { + if b.name != find_ident { return None; } // If the thing is not a function we're not interested - let StaticType::Function { parameters, .. } = typ else { + let StaticType::Function { parameters, .. } = b.binding.typ() else { return None; }; @@ -107,17 +131,17 @@ impl Scope { fn find_function(&self, find_ident: &str, find_types: &[StaticType]) -> Option { self.identifiers .iter() - .rposition(|(ident, typ)| ident == find_ident && typ.is_fn_and_matches(find_types)) + .rposition(|b| b.name == find_ident && b.binding.typ().is_fn_and_matches(find_types)) .map(|idx| idx + self.base_offset) } /// Check if this scope already contains a function with the given name and arity. fn has_function_with_arity(&self, name: &str, arity: Option) -> bool { - self.identifiers.iter().any(|(ident, typ)| { - if ident != name { + self.identifiers.iter().any(|b| { + if b.name != name { return false; } - match typ { + match b.binding.typ() { StaticType::Function { parameters: Some(params), .. @@ -133,9 +157,11 @@ impl Scope { }) } - fn allocate(&mut self, name: String, typ: StaticType) -> usize { - self.identifiers.push((name, typ)); - // Slot is just the length of the list minus one + fn allocate(&mut self, name: String, type_binding: TypeBinding) -> usize { + self.identifiers.push(ScopeBinding { + name, + binding: type_binding, + }); self.base_offset + self.identifiers.len() - 1 } @@ -184,7 +210,13 @@ impl ScopeTree { /// user-level shadowing. pub fn from_global_scope(global_scope_map: Vec<(String, StaticType)>) -> Self { let mut global_scope = Scope::new_function_scope(None, 0); - global_scope.identifiers = global_scope_map; + global_scope.identifiers = global_scope_map + .into_iter() + .map(|(name, typ)| ScopeBinding { + name, + binding: TypeBinding::Inferred(typ), + }) + .collect(); Self { current_scope_idx: 0, @@ -217,7 +249,7 @@ impl ScopeTree { } } } - ResolvedVar::Global { slot } => &self.global_scope.identifiers[slot].1, + ResolvedVar::Global { slot } => self.global_scope.identifiers[slot].binding.typ(), } } @@ -233,7 +265,7 @@ impl ScopeTree { loop { let scope = &self.scopes[scope_idx]; if slot >= scope.base_offset && slot < scope.base_offset + scope.identifiers.len() { - return &scope.identifiers[slot - scope.base_offset].1; + return scope.identifiers[slot - scope.base_offset].binding.typ(); } scope_idx = scope .parent_idx @@ -407,9 +439,13 @@ impl ScopeTree { Binding::None } - pub(crate) fn create_local_binding(&mut self, ident: String, typ: StaticType) -> ResolvedVar { + pub(crate) fn create_local_binding( + &mut self, + ident: String, + binding: TypeBinding, + ) -> ResolvedVar { ResolvedVar::Local { - slot: self.scopes[self.current_scope_idx].allocate(ident, typ), + slot: self.scopes[self.current_scope_idx].allocate(ident, binding), } } @@ -427,15 +463,31 @@ impl ScopeTree { /// Uses `"\x00"` as a sentinel name that can never collide with user identifiers /// since the lexer never produces null bytes. pub(crate) fn reserve_anonymous_slot(&mut self) -> usize { - self.scopes[self.current_scope_idx].allocate("\x00".to_string(), StaticType::Any) + self.scopes[self.current_scope_idx] + .allocate("\x00".to_string(), TypeBinding::Inferred(StaticType::Any)) + } + + /// Try to update a binding's type. Returns `Err` with the annotated type + /// if the binding has an explicit type annotation and cannot be widened. + pub(crate) fn update_binding_type( + &mut self, + var: ResolvedVar, + new_type: StaticType, + ) -> Result<(), StaticType> { + let binding = self.get_binding_mut(var); + if binding.is_annotated() { + return Err(binding.typ().clone()); + } + *binding = TypeBinding::Inferred(new_type); + Ok(()) } - pub(crate) fn update_binding_type(&mut self, var: ResolvedVar, new_type: StaticType) { + fn get_binding_mut(&mut self, var: ResolvedVar) -> &mut TypeBinding { match var { ResolvedVar::Local { slot } => { let scope_idx = self.find_scope_owning_slot(self.current_scope_idx, slot); let base = self.scopes[scope_idx].base_offset; - self.scopes[scope_idx].identifiers[slot - base].1 = new_type; + &mut self.scopes[scope_idx].identifiers[slot - base].binding } ResolvedVar::Upvalue { slot } => { let mut scope_idx = self.scopes[self.current_scope_idx].function_scope_idx; @@ -451,8 +503,7 @@ impl ScopeTree { .expect("expected parent scope"); let owning = self.find_scope_owning_slot(parent, local_slot); let base = self.scopes[owning].base_offset; - self.scopes[owning].identifiers[local_slot - base].1 = new_type; - return; + return &mut self.scopes[owning].identifiers[local_slot - base].binding; } CaptureSource::Upvalue(uv_slot) => { scope_idx = self.get_parent_function_scope_idx(scope_idx); @@ -462,7 +513,7 @@ impl ScopeTree { } } ResolvedVar::Global { .. } => { - panic!("update_binding_type called with a global binding") + panic!("get_binding_mut called with a global binding") } } } @@ -584,7 +635,7 @@ mod tests { #[test] fn single_local_in_function_scope() { let mut tree = empty_scope_tree(); - let var = tree.create_local_binding("x".into(), StaticType::Int); + let var = tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(var, ResolvedVar::Local { slot: 0 }); assert_eq!( tree.get_binding_any("x"), @@ -595,9 +646,9 @@ mod tests { #[test] fn multiple_locals_get_ascending_slots() { let mut tree = empty_scope_tree(); - let x = tree.create_local_binding("x".into(), StaticType::Int); - let y = tree.create_local_binding("y".into(), StaticType::Int); - let z = tree.create_local_binding("z".into(), StaticType::Int); + let x = tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); + let y = tree.create_local_binding("y".into(), TypeBinding::Inferred(StaticType::Int)); + let z = tree.create_local_binding("z".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(x, ResolvedVar::Local { slot: 0 }); assert_eq!(y, ResolvedVar::Local { slot: 1 }); assert_eq!(z, ResolvedVar::Local { slot: 2 }); @@ -606,11 +657,11 @@ mod tests { #[test] fn block_scope_continues_flat_numbering() { let mut tree = empty_scope_tree(); - let x = tree.create_local_binding("x".into(), StaticType::Int); + let x = tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(x, ResolvedVar::Local { slot: 0 }); tree.new_block_scope(); - let y = tree.create_local_binding("y".into(), StaticType::Int); + let y = tree.create_local_binding("y".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(y, ResolvedVar::Local { slot: 1 }); assert_eq!( @@ -622,21 +673,21 @@ mod tests { #[test] fn nested_block_scopes_continue_numbering() { let mut tree = empty_scope_tree(); - tree.create_local_binding("a".into(), StaticType::Int); + tree.create_local_binding("a".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_block_scope(); - let b = tree.create_local_binding("b".into(), StaticType::Int); + let b = tree.create_local_binding("b".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(b, ResolvedVar::Local { slot: 1 }); tree.new_block_scope(); - let c = tree.create_local_binding("c".into(), StaticType::Int); + let c = tree.create_local_binding("c".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(c, ResolvedVar::Local { slot: 2 }); } #[test] fn block_scope_does_not_create_upvalue() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_block_scope(); assert_eq!( @@ -648,10 +699,10 @@ mod tests { #[test] fn function_scope_resets_slots_and_captures_as_upvalue() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_function_scope(); - let y = tree.create_local_binding("y".into(), StaticType::Int); + let y = tree.create_local_binding("y".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(y, ResolvedVar::Local { slot: 0 }); assert_eq!( @@ -663,10 +714,10 @@ mod tests { #[test] fn iteration_scope_continues_numbering_and_is_transparent() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_iteration_scope(); - let i = tree.create_local_binding("i".into(), StaticType::Int); + let i = tree.create_local_binding("i".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(i, ResolvedVar::Local { slot: 1 }); assert_eq!( @@ -694,21 +745,21 @@ mod tests { #[test] fn slot_reuse_after_scope_destroy() { let mut tree = empty_scope_tree(); - tree.create_local_binding("a".into(), StaticType::Int); + tree.create_local_binding("a".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_block_scope(); - tree.create_local_binding("b".into(), StaticType::Int); + tree.create_local_binding("b".into(), TypeBinding::Inferred(StaticType::Int)); tree.destroy_scope(); - let c = tree.create_local_binding("c".into(), StaticType::Int); + let c = tree.create_local_binding("c".into(), TypeBinding::Inferred(StaticType::Int)); assert_eq!(c, ResolvedVar::Local { slot: 1 }); } #[test] fn get_type_returns_correct_type() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); - tree.create_local_binding("y".into(), StaticType::String); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); + tree.create_local_binding("y".into(), TypeBinding::Inferred(StaticType::String)); assert_eq!( tree.get_type(ResolvedVar::Local { slot: 0 }), @@ -727,7 +778,7 @@ mod tests { #[test] fn upvalue_hoisting_across_two_function_scopes() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_function_scope(); // outer tree.new_function_scope(); // inner @@ -754,8 +805,8 @@ mod tests { #[test] fn multiple_upvalues_get_distinct_indices() { let mut tree = empty_scope_tree(); - tree.create_local_binding("a".into(), StaticType::Int); - tree.create_local_binding("b".into(), StaticType::String); + tree.create_local_binding("a".into(), TypeBinding::Inferred(StaticType::Int)); + tree.create_local_binding("b".into(), TypeBinding::Inferred(StaticType::String)); tree.new_function_scope(); @@ -770,7 +821,7 @@ mod tests { #[test] fn duplicate_upvalue_resolution_reuses_index() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_function_scope(); @@ -789,7 +840,7 @@ mod tests { #[test] fn get_type_follows_upvalue_chain() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_function_scope(); // outer tree.new_function_scope(); // inner @@ -805,7 +856,7 @@ mod tests { #[test] fn sibling_closure_finds_existing_upvalue() { let mut tree = empty_scope_tree(); - tree.create_local_binding("x".into(), StaticType::Int); + tree.create_local_binding("x".into(), TypeBinding::Inferred(StaticType::Int)); tree.new_function_scope(); // middle diff --git a/ndc_bin/src/main.rs b/ndc_bin/src/main.rs index c249f6fd..f0b3edc1 100644 --- a/ndc_bin/src/main.rs +++ b/ndc_bin/src/main.rs @@ -125,7 +125,10 @@ impl TryFrom for Action { Command::Run { file: Some(file), options, - } => Self::RunFile { path: file, options }, + } => Self::RunFile { + path: file, + options, + }, Command::Run { file: None, .. } => Self::StartRepl, Command::Lsp { stdio: _ } => Self::RunLsp, Command::Disassemble { file } => Self::DisassembleFile(file), @@ -153,10 +156,7 @@ fn main() -> anyhow::Result<()> { let action: Action = cli.command.unwrap_or_default().try_into()?; match action { - Action::RunFile { - path, - options, - } => { + Action::RunFile { path, options } => { let filename = path .file_name() .and_then(|name| name.to_str()) diff --git a/ndc_core/src/static_type.rs b/ndc_core/src/static_type.rs index aff10d29..3e13ed78 100644 --- a/ndc_core/src/static_type.rs +++ b/ndc_core/src/static_type.rs @@ -49,6 +49,17 @@ impl TypeSignature { } } + pub fn from_annotated_bindings(bindings: Vec<(String, Option)>) -> Self { + Self::Exact( + bindings + .into_iter() + .map(|(name, annotation)| { + Parameter::new(name, annotation.unwrap_or(StaticType::Any)) + }) + .collect(), + ) + } + pub fn types(&self) -> Option> { match self { Self::Variadic => None, @@ -108,7 +119,127 @@ pub enum StaticType { Deque(Box), } +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct StaticTypeConstructionError { + message: String, + help_text: String, +} + +impl StaticTypeConstructionError { + fn new, H: Into>(message: M, help_text: H) -> Self { + Self { + message: message.into(), + help_text: help_text.into(), + } + } + + pub fn help_text(&self) -> &str { + &self.help_text + } +} + +impl fmt::Display for StaticTypeConstructionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.message.fmt(f) + } +} + impl StaticType { + pub fn from_name_and_args( + name: &str, + args: Vec, + ) -> Result { + match name { + "Any" => Self::require_no_args(name, &args).map(|_| Self::Any), + "Never" => Self::require_no_args(name, &args).map(|_| Self::Never), + "Bool" => Self::require_no_args(name, &args).map(|_| Self::Bool), + "Number" => Self::require_no_args(name, &args).map(|_| Self::Number), + "Float" => Self::require_no_args(name, &args).map(|_| Self::Float), + "Int" => Self::require_no_args(name, &args).map(|_| Self::Int), + "Rational" => Self::require_no_args(name, &args).map(|_| Self::Rational), + "Complex" => Self::require_no_args(name, &args).map(|_| Self::Complex), + "String" => Self::require_no_args(name, &args).map(|_| Self::String), + "Option" => { + Self::require_exactly_one_arg(name, args).map(|elem| Self::Option(Box::new(elem))) + } + "Sequence" => { + Self::require_exactly_one_arg(name, args).map(|elem| Self::Sequence(Box::new(elem))) + } + "List" => { + Self::require_exactly_one_arg(name, args).map(|elem| Self::List(Box::new(elem))) + } + "Iterator" => { + Self::require_exactly_one_arg(name, args).map(|elem| Self::Iterator(Box::new(elem))) + } + "MinHeap" => { + Self::require_exactly_one_arg(name, args).map(|elem| Self::MinHeap(Box::new(elem))) + } + "MaxHeap" => { + Self::require_exactly_one_arg(name, args).map(|elem| Self::MaxHeap(Box::new(elem))) + } + "Deque" => { + Self::require_exactly_one_arg(name, args).map(|elem| Self::Deque(Box::new(elem))) + } + "Tuple" => Self::require_at_least_one_arg(name, args).map(Self::Tuple), + "Map" => { + let [key, value] = Self::require_exactly_n_args::<2>(name, args)?; + Ok(Self::Map { + key: Box::new(key), + value: Box::new(value), + }) + } + _ => Err(StaticTypeConstructionError::new( + format!("unknown type `{name}`"), + "Use a valid type name in this annotation.", + )), + } + } + + fn require_no_args(name: &str, args: &[Self]) -> Result<(), StaticTypeConstructionError> { + if args.is_empty() { + Ok(()) + } else { + Err(StaticTypeConstructionError::new( + format!("type `{name}` does not take generic arguments"), + format!("Remove the generic arguments from `{name}`."), + )) + } + } + + fn require_exactly_one_arg( + name: &str, + args: Vec, + ) -> Result { + let [arg] = Self::require_exactly_n_args(name, args)?; + Ok(arg) + } + + fn require_exactly_n_args( + name: &str, + args: Vec, + ) -> Result<[Self; N], StaticTypeConstructionError> { + args.try_into().map_err(|_err: Vec| { + StaticTypeConstructionError::new( + format!("type `{name}` expects exactly {N} generic arguments"), + format!("Use `{name}<...>` with {N} type arguments."), + ) + }) + } + + fn require_at_least_one_arg( + name: &str, + args: Vec, + ) -> Result, StaticTypeConstructionError> { + if args.is_empty() { + Err(StaticTypeConstructionError::new( + format!("type `{name}` requires generic arguments"), + format!("Add generic arguments like `{name}<...>`."), + )) + } else { + Ok(args) + } + } + /// Checks if `self` is a subtype of `other`. /// /// A type S is a subtype of T (S <: T) if a value of type S can be safely @@ -469,6 +600,21 @@ impl StaticType { !self.is_subtype(other) && !other.is_subtype(self) } + /// Returns a new type with the element type replaced. For container types + /// like `List`, this returns `List`. Returns `None` if the + /// type does not have a replaceable element type. + pub fn with_element_type(&self, new_elem: Self) -> Self { + match self { + Self::List(_) => Self::List(Box::new(new_elem)), + Self::Sequence(_) => Self::Sequence(Box::new(new_elem)), + Self::Iterator(_) => Self::Iterator(Box::new(new_elem)), + Self::MinHeap(_) => Self::MinHeap(Box::new(new_elem)), + Self::MaxHeap(_) => Self::MaxHeap(Box::new(new_elem)), + Self::Deque(_) => Self::Deque(Box::new(new_elem)), + _ => self.clone(), + } + } + pub fn index_element_type(&self) -> Option { if let Self::Map { value, .. } = self { return Some(value.as_ref().clone()); diff --git a/ndc_lsp/src/features/inlay_hints.rs b/ndc_lsp/src/features/inlay_hints.rs index 9883883e..339e8c94 100644 --- a/ndc_lsp/src/features/inlay_hints.rs +++ b/ndc_lsp/src/features/inlay_hints.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use ndc_core::StaticType; use ndc_interpreter::AnalysisResult; use ndc_lexer::Span; -use ndc_parser::ExpressionLocation; +use ndc_parser::{ExpressionLocation, NodeId}; use tower_lsp::lsp_types::{InlayHint, InlayHintKind, InlayHintLabel}; use crate::util::position_from_offset; @@ -54,35 +54,124 @@ impl AstVisitor for HintCollector<'_> { } } - fn on_declaration(&mut self, identifier: &str, inferred_type: Option<&StaticType>, span: Span) { + fn on_declaration( + &mut self, + identifier: &str, + inferred_type: Option<&StaticType>, + has_annotation: bool, + span: Span, + ) { if let Some(typ) = inferred_type { - self.hints.push(InlayHint { - position: position_from_offset(self.text, span.end()), - label: InlayHintLabel::String(format!(": {typ}")), - kind: Some(InlayHintKind::TYPE), - text_edits: None, - tooltip: None, - padding_left: None, - padding_right: Some(true), - data: None, - }); + if !has_annotation { + self.hints.push(InlayHint { + position: position_from_offset(self.text, span.end()), + label: InlayHintLabel::String(format!(": {typ}")), + kind: Some(InlayHintKind::TYPE), + text_edits: None, + tooltip: None, + padding_left: None, + padding_right: Some(true), + data: None, + }); + } self.variable_types .insert(identifier.to_string(), typ.clone()); } } - fn on_function_declaration(&mut self, return_type: Option<&StaticType>, parameters_span: Span) { - if let Some(rt) = return_type { - self.hints.push(InlayHint { - position: position_from_offset(self.text, parameters_span.end()), - label: InlayHintLabel::String(format!(" -> {rt}")), - kind: Some(InlayHintKind::TYPE), - text_edits: None, - tooltip: None, - padding_left: None, - padding_right: None, - data: None, - }); + fn on_function_declaration( + &mut self, + return_type: Option<&StaticType>, + parameters_span: Span, + node_id: NodeId, + ) { + // return_type is Some only when explicitly annotated by the user — skip the hint. + // Inferred return types are stored in the side table. + if return_type.is_none() { + if let Some(rt) = self.analysis_result.inferred_return_types.get(&node_id) { + self.hints.push(InlayHint { + position: position_from_offset(self.text, parameters_span.end()), + label: InlayHintLabel::String(format!(" -> {rt}")), + kind: Some(InlayHintKind::TYPE), + text_edits: None, + tooltip: None, + padding_left: None, + padding_right: None, + data: None, + }); + } } } } + +#[cfg(test)] +mod tests { + use super::*; + use ndc_interpreter::Interpreter; + + fn collect_hints(source: &str) -> AnalysisInfo { + let mut interpreter = Interpreter::capturing(); + interpreter.configure(ndc_stdlib::register); + let (expressions, analysis_result) = interpreter + .analyse_str(source) + .expect("analysis should succeed"); + collect(&expressions, &analysis_result, source) + } + + #[test] + fn inferred_let_binding_gets_type_inlay() { + let info = collect_hints("let value = 1;"); + assert!( + info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label == ": Int") + ) + ); + } + + #[test] + fn annotated_let_binding_skips_type_inlay() { + let info = collect_hints("let value: Int = 1;"); + assert!( + !info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label == ": Int") + ) + ); + assert_eq!(info.variable_types.get("value"), Some(&StaticType::Int)); + } + + #[test] + fn annotated_return_type_skips_inlay() { + let info = collect_hints("fn foo(x: Int) -> Int { x + 1 }"); + assert!(!info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label.contains("->")) + )); + } + + #[test] + fn inferred_return_type_gets_inlay() { + let info = collect_hints("fn foo() { 42 }"); + assert!(info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label == " -> Int") + )); + } + + #[test] + fn annotated_param_skips_inlay() { + let info = collect_hints("fn foo(x: Int) { x }"); + assert!( + !info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label == ": Int") + ) + ); + } + + #[test] + fn unannotated_param_gets_inlay() { + let info = collect_hints("fn foo(x) { x }"); + assert!( + info.hints.iter().any( + |hint| matches!(&hint.label, InlayHintLabel::String(label) if label == ": Any") + ) + ); + } +} diff --git a/ndc_lsp/src/visitor.rs b/ndc_lsp/src/visitor.rs index f22b934a..417ff525 100644 --- a/ndc_lsp/src/visitor.rs +++ b/ndc_lsp/src/visitor.rs @@ -1,6 +1,6 @@ use ndc_core::StaticType; use ndc_lexer::Span; -use ndc_parser::{Expression, ExpressionLocation, ForBody, ForIteration, Lvalue}; +use ndc_parser::{Expression, ExpressionLocation, ForBody, ForIteration, Lvalue, NodeId}; /// Trait for visiting interesting nodes during an AST walk. /// @@ -13,6 +13,7 @@ pub trait AstVisitor { &mut self, _identifier: &str, _inferred_type: Option<&StaticType>, + _has_annotation: bool, _span: Span, ) { } @@ -25,6 +26,7 @@ pub trait AstVisitor { &mut self, _return_type: Option<&StaticType>, _parameters_span: Span, + _node_id: NodeId, ) { } } @@ -39,17 +41,25 @@ pub fn walk_ast(visitor: &mut impl AstVisitor, expressions: &[ExpressionLocation fn walk_expression(visitor: &mut impl AstVisitor, expr: &ExpressionLocation) { visitor.on_expression(expr); match &expr.expression { - Expression::VariableDeclaration { l_value, value } => { - walk_lvalue(visitor, l_value); + Expression::VariableDeclaration { + l_value, + annotated_type, + value, + } => { + walk_lvalue(visitor, l_value, annotated_type.is_some()); walk_expression(visitor, value); } Expression::FunctionDeclaration { return_type, + parameters, parameters_span, body, .. } => { - visitor.on_function_declaration(return_type.as_ref(), *parameters_span); + for p in parameters { + walk_lvalue(visitor, &p.lvalue, p.annotation.is_some()); + } + visitor.on_function_declaration(return_type.as_ref(), *parameters_span, expr.id); walk_expression(visitor, body); } Expression::Statement(inner) | Expression::Grouping(inner) => { @@ -82,7 +92,7 @@ fn walk_expression(visitor: &mut impl AstVisitor, expr: &ExpressionLocation) { for iteration in iterations { match iteration { ForIteration::Iteration { l_value, sequence } => { - walk_lvalue(visitor, l_value); + walk_lvalue(visitor, l_value, false); walk_expression(visitor, sequence); } ForIteration::Guard(expr) => walk_expression(visitor, expr), @@ -113,7 +123,7 @@ fn walk_expression(visitor: &mut impl AstVisitor, expr: &ExpressionLocation) { } } -fn walk_lvalue(visitor: &mut impl AstVisitor, lvalue: &Lvalue) { +fn walk_lvalue(visitor: &mut impl AstVisitor, lvalue: &Lvalue, has_annotation: bool) { match lvalue { Lvalue::Identifier { identifier, @@ -121,11 +131,11 @@ fn walk_lvalue(visitor: &mut impl AstVisitor, lvalue: &Lvalue) { span, .. } => { - visitor.on_declaration(identifier, inferred_type.as_ref(), *span); + visitor.on_declaration(identifier, inferred_type.as_ref(), has_annotation, *span); } Lvalue::Sequence(lvalues) => { for lv in lvalues { - walk_lvalue(visitor, lv); + walk_lvalue(visitor, lv, has_annotation); } } Lvalue::Index { .. } => {} diff --git a/ndc_parser/src/expression.rs b/ndc_parser/src/expression.rs index f8c0a803..a9677a78 100644 --- a/ndc_parser/src/expression.rs +++ b/ndc_parser/src/expression.rs @@ -76,6 +76,7 @@ pub enum Expression { Grouping(Box), VariableDeclaration { l_value: Lvalue, + annotated_type: Option, value: Box, }, Assignment { @@ -92,7 +93,7 @@ pub enum Expression { FunctionDeclaration { name: Option, resolved_name: Option, - type_signature: TypeSignature, + parameters: Vec, parameters_span: Span, body: Box, return_type: Option, @@ -170,6 +171,29 @@ pub enum ForBody { }, } +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct FunctionParameter { + pub lvalue: Lvalue, + pub annotation: Option, + pub span: Span, +} + +impl FunctionParameter { + pub fn to_type_signature(params: &[Self]) -> TypeSignature { + TypeSignature::from_annotated_bindings( + params + .iter() + .map(|p| { + let Lvalue::Identifier { identifier, .. } = &p.lvalue else { + panic!("expected identifier in parameter list: {:?}", p.lvalue); + }; + (identifier.clone(), p.annotation.clone()) + }) + .collect(), + ) + } +} + #[derive(Debug, Eq, PartialEq, Clone)] pub enum Lvalue { // Example: `let foo = ...` diff --git a/ndc_parser/src/lib.rs b/ndc_parser/src/lib.rs index a5227e63..cae9582f 100644 --- a/ndc_parser/src/lib.rs +++ b/ndc_parser/src/lib.rs @@ -3,8 +3,8 @@ mod operator; mod parser; pub use expression::{ - Binding, CaptureSource, Expression, ExpressionLocation, ForBody, ForIteration, Lvalue, NodeId, - ResolvedVar, + Binding, CaptureSource, Expression, ExpressionLocation, ForBody, ForIteration, + FunctionParameter, Lvalue, NodeId, ResolvedVar, }; pub use operator::{BinaryOperator, LogicalOperator, UnaryOperator}; pub use parser::Error; diff --git a/ndc_parser/src/parser.rs b/ndc_parser/src/parser.rs index ef70156c..077b6828 100644 --- a/ndc_parser/src/parser.rs +++ b/ndc_parser/src/parser.rs @@ -1,7 +1,9 @@ use std::fmt::Write; use crate::expression::Expression; -use crate::expression::{Binding, ExpressionLocation, ForBody, ForIteration, Lvalue, NodeId}; +use crate::expression::{ + Binding, ExpressionLocation, ForBody, ForIteration, FunctionParameter, Lvalue, NodeId, +}; use crate::operator::{BinaryOperator, LogicalOperator, UnaryOperator}; use ndc_core::{Parameter, StaticType, TypeSignature}; use ndc_lexer::{Span, Token, TokenLocation}; @@ -309,16 +311,7 @@ impl Parser { .require_current_token_matches(&Token::Let) .expect("guaranteed to match by caller"); - let maybe_lvalue = self.tuple_expression(Self::single_expression, false)?; - let lvalue_span = maybe_lvalue.span; - - let Ok(lvalue) = Lvalue::try_from(maybe_lvalue) else { - return Err(Error::with_help( - "Invalid assignment target".to_string(), - lvalue_span, - "Assignment target is not a valid lvalue. Only a few expressions can be assigned a value. Check that the left-hand side of the assignment is a valid target.".to_string(), - )); - }; + let (lvalue, annotated_type) = self.named_binding()?; self.require_current_token_matches(&Token::EqualsSign)?; @@ -326,6 +319,7 @@ impl Parser { let end = expression.span; let declaration = Expression::VariableDeclaration { l_value: lvalue, + annotated_type, value: Box::new(expression), }; @@ -431,25 +425,54 @@ impl Parser { } } + fn delimited_comma_separated( + &mut self, + open: &Token, + close: &Token, + parse_item: fn(&mut Self) -> Result, + allow_empty: bool, + ) -> Result<(Vec, Span), Error> { + let open_span = self.require_current_token_matches(open)?.span; + + if let Some(close_token) = self.consume_token_if(std::slice::from_ref(close)) { + if allow_empty { + return Ok((Vec::new(), open_span.merge(close_token.span))); + } + + return Err(Error::with_help( + format!("expected an item before '{close}'"), + close_token.span, + "This delimited list cannot be empty.".to_string(), + )); + } + + let mut items = vec![parse_item(self)?]; + + while self.consume_token_if(&[Token::Comma]).is_some() { + if self.match_token(std::slice::from_ref(close)).is_some() { + break; + } + + items.push(parse_item(self)?); + } + + let close_span = self.require_current_token_matches(close)?.span; + Ok((items, open_span.merge(close_span))) + } + /// Parses a delimited tuple (enclosed in parentheses) that can be empty fn delimited_tuple( &mut self, next: fn(&mut Self) -> Result, ) -> Result { - let start = self.require_current_token_matches(&Token::LeftParentheses)?; - if let Some(end) = self.consume_token_if(&[Token::RightParentheses]) { - Ok(Expression::Tuple { values: vec![] }.to_location(start.span.merge(end.span))) - } else { - let mut tuple_expression = self.tuple_expression(next, true)?; - let right_paren_span = self - .require_current_token_matches(&Token::RightParentheses)? - .span; - - // Include the right paretheses in the span - tuple_expression.span = tuple_expression.span.merge(right_paren_span); + let (values, span) = self.delimited_comma_separated( + &Token::LeftParentheses, + &Token::RightParentheses, + next, + true, + )?; - Ok(tuple_expression) - } + Ok(Expression::Tuple { values }.to_location(span)) } fn single_expression(&mut self) -> Result { @@ -1158,7 +1181,22 @@ impl Parser { } }; - let argument_list = self.delimited_tuple(Self::single_expression)?; + // let argument_list = self.delimited_tuple(Self::single_expression)?; + + let (argument_list, parameters_span) = self.delimited_comma_separated( + &Token::LeftParentheses, + &Token::RightParentheses, + Self::named_parameter, + true, + )?; + + // Optional return type annotation: `-> Type` + let annotated_return_type = if self.peek_current_token() == Some(&Token::RightArrow) { + self.advance(); + Some(self.static_type()?) + } else { + None + }; // Next we either expect a body block `{ ... }` or a fat arrow followed by a single expression `=> ...` @@ -1175,20 +1213,17 @@ impl Parser { "Expected that the argument list is followed by either a body `{}` or a fat arrow `=>`".to_string(), )) } - None => return Err(Error::end_of_input(argument_list.span)), + None => return Err(Error::end_of_input(parameters_span)), }; - let parameters_span = argument_list.span; let span = fn_token.span.merge(body.span); Ok(ExpressionLocation { expression: Expression::FunctionDeclaration { name: identifier, - type_signature: argument_list - .try_into() - .expect("INTERNAL ERROR: type of argument list is incorrect"), + parameters: argument_list, parameters_span, body: Box::new(body), - return_type: None, // At some point in the future we could use type declarations here to insert the type (return type inference is cringe anyway) + return_type: annotated_return_type, resolved_name: None, captures: vec![], pure: is_pure, @@ -1296,6 +1331,171 @@ impl Parser { }; Ok(Expression::Map { values, default }.to_location(map_open_span.merge(map_close_span))) } + + pub fn static_type(&mut self) -> Result { + let Some(TokenLocation { token, span }) = self.peek_current_token_location() else { + return Err(Error::end_of_input( + self.tokens.last().expect("last token exists").span, + )); + }; + + match token { + Token::Identifier(_) => self.named_or_generic_type(), + Token::LeftParentheses => self.tuple_type(), + _ => Err(Error::with_help( + format!("expected a type annotation, found `{token}`"), + *span, + "Use a valid type name or tuple type annotation in this position.".to_string(), + )), + } + } + + pub fn named_or_generic_type(&mut self) -> Result { + let Ok(TokenLocation { + token: Token::Identifier(ident), + span, + }) = self.require_current_token() + else { + unreachable!("this should have been checked"); + }; + + let generic_args = if self.peek_current_token() == Some(&Token::Less) { + self.delimited_type_params()? + } else { + Vec::new() + }; + + StaticType::from_name_and_args(ident.as_str(), generic_args) + .map_err(|err| Error::with_help(err.to_string(), span, err.help_text().to_string())) + } + + /// Parses `` type parameter lists, handling the `>>` / `>=` / `>>=` + /// ambiguity that arises with nested generics like `List>`. + fn delimited_type_params(&mut self) -> Result, Error> { + self.require_current_token_matches(&Token::Less)?; + + let mut items = vec![self.static_type()?]; + + while self.consume_token_if(&[Token::Comma]).is_some() { + if self.peek_current_token() == Some(&Token::Greater) { + break; + } + items.push(self.static_type()?); + } + + self.consume_closing_angle_bracket()?; + Ok(items) + } + + /// Consumes a closing `>` for a generic type parameter list. If the current + /// token is `>>`, `>=`, or `>>=`, it is split so that the leading `>` is + /// consumed and the remainder is left as the current token. + fn consume_closing_angle_bracket(&mut self) -> Result { + if let Some(token) = self.consume_token_if(&[Token::Greater]) { + return Ok(token.span); + } + + let Some(loc) = self.peek_current_token_location() else { + return Err(Error::end_of_input( + self.tokens.last().expect("last token exists").span, + )); + }; + + let greater_span = Span::new(loc.span.source_id(), loc.span.offset(), 1); + let rest_span = Span::new( + loc.span.source_id(), + loc.span.offset() + 1, + loc.span.end() - loc.span.offset() - 1, + ); + + let remainder = match &loc.token { + // >> becomes > + Token::GreaterGreater => Token::Greater, + // >= becomes = + Token::GreaterEquals => Token::EqualsSign, + // >>= (OpAssign(>>)) becomes >= + Token::OpAssign(inner) if inner.token == Token::GreaterGreater => Token::GreaterEquals, + _ => { + let loc = loc.clone(); + return Err(Error::text( + format!("Expected token '>' but got '{}' instead", loc.token), + loc.span, + )); + } + }; + + self.tokens[self.current] = TokenLocation { + token: remainder, + span: rest_span, + }; + + Ok(greater_span) + } + + pub fn tuple_type(&mut self) -> Result { + let (types, _span) = self.delimited_comma_separated( + &Token::LeftParentheses, + &Token::RightParentheses, + Self::static_type, + true, + )?; + Ok(StaticType::Tuple(types)) + } + + fn named_parameter(&mut self) -> Result { + let maybe_lvalue = self.single_expression()?; + let lvalue_span = maybe_lvalue.span; + + let Ok(lvalue) = Lvalue::try_from(maybe_lvalue) else { + return Err(Error::with_help( + "Expected parameter name".to_string(), + lvalue_span, + "Function parameters must be identifiers, optionally followed by a type annotation (e.g. `x` or `x: Int`).".to_string(), + )); + }; + + let annotation = if self.peek_current_token() == Some(&Token::Colon) { + self.advance(); + Some(self.static_type()?) + } else { + None + }; + + let span = if annotation.is_some() { + lvalue_span.merge(self.tokens[self.current - 1].span) + } else { + lvalue_span + }; + + Ok(FunctionParameter { + lvalue, + annotation, + span, + }) + } + + pub fn named_binding(&mut self) -> Result<(Lvalue, Option), Error> { + let maybe_lvalue = self.tuple_expression(Self::single_expression, false)?; + let lvalue_span = maybe_lvalue.span; + + let Ok(lvalue) = Lvalue::try_from(maybe_lvalue) else { + return Err(Error::with_help( + "Invalid assignment target".to_string(), + lvalue_span, + "Assignment target is not a valid lvalue. Only a few expressions can be assigned a value. Check that the left-hand side of the assignment is a valid target.".to_string(), + )); + }; + + let annotated_type = if self.peek_current_token() == Some(&Token::Colon) { + self.advance(); + Some(self.static_type()?) + } else { + None + }; + + Ok((lvalue, annotated_type)) + } + fn peek_range_end(&self) -> bool { matches!( self.peek_current_token(), @@ -1320,9 +1520,12 @@ pub struct Error { impl Error { #[must_use] - pub fn text(text: String, span: Span) -> Self { + pub fn text(text: S, span: Span) -> Self + where + S: Into, + { Self { - text, + text: text.into(), span, help_text: None, } diff --git a/ndc_stdlib/src/math.rs b/ndc_stdlib/src/math.rs index b2377dc0..5732c824 100644 --- a/ndc_stdlib/src/math.rs +++ b/ndc_stdlib/src/math.rs @@ -172,6 +172,7 @@ mod inner { pub mod f64 { use super::{Number, ToPrimitive, f64}; use ndc_core::StaticType; + use ndc_core::int::Int; use ndc_core::num::BinaryOperatorError; use ndc_vm::error::VmError; use ndc_vm::value::{NativeFunc, NativeFunction, Value}; @@ -241,6 +242,103 @@ pub mod f64 { "Returns the Euclidean remainder of dividing two numbers. The result is always non-negative." ); + // Int-specific overloads: fast path on i64, fall back to Number on overflow/BigInt. + macro_rules! implement_binary_operator_on_int { + ($operator:literal, $checked_method:ident, $fallback:expr, $docs:literal) => { + env.declare_global_fn(Rc::new(NativeFunction { + name: $operator.to_string(), + documentation: Some($docs.to_string()), + static_type: StaticType::Function { + parameters: Some(vec![StaticType::Int, StaticType::Int]), + return_type: Box::new(StaticType::Int), + }, + func: NativeFunc::Simple(Box::new(|args| match args { + [Value::Int(l), Value::Int(r)] => { + if let Some(result) = l.$checked_method(*r) { + Ok(Value::Int(result)) + } else { + let l = Int::Int64(*l); + let r = Int::Int64(*r); + Ok(Value::from_int($fallback(l, r))) + } + } + [left, right] => { + let l = left.to_int().ok_or_else(|| { + VmError::native(format!("expected int, got {}", left.static_type())) + })?; + let r = right.to_int().ok_or_else(|| { + VmError::native(format!( + "expected int, got {}", + right.static_type() + )) + })?; + Ok(Value::from_int($fallback(l, r))) + } + _ => Err(VmError::native(format!( + "expected 2 arguments, got {}", + args.len() + ))), + })), + })); + }; + } + + implement_binary_operator_on_int!( + "+", + checked_add, + std::ops::Add::add, + "Adds two integers." + ); + implement_binary_operator_on_int!( + "-", + checked_sub, + std::ops::Sub::sub, + "Subtracts two integers." + ); + implement_binary_operator_on_int!( + "*", + checked_mul, + std::ops::Mul::mul, + "Multiplies two integers." + ); + implement_binary_operator_on_int!( + "%", + checked_rem, + std::ops::Rem::rem, + "Returns the remainder of dividing two integers." + ); + + // Float-specific overloads: operate directly on f64. + macro_rules! implement_binary_operator_on_float { + ($operator:literal, $op:expr, $docs:literal) => { + env.declare_global_fn(Rc::new(NativeFunction { + name: $operator.to_string(), + documentation: Some($docs.to_string()), + static_type: StaticType::Function { + parameters: Some(vec![StaticType::Float, StaticType::Float]), + return_type: Box::new(StaticType::Float), + }, + func: NativeFunc::Simple(Box::new(|args| match args { + [Value::Float(l), Value::Float(r)] => Ok(Value::Float($op(*l, *r))), + _ => Err(VmError::native(format!( + "expected 2 float arguments, got {}", + args.len() + ))), + })), + })); + }; + } + + implement_binary_operator_on_float!("+", std::ops::Add::add, "Adds two floats."); + implement_binary_operator_on_float!("-", std::ops::Sub::sub, "Subtracts two floats."); + implement_binary_operator_on_float!("*", std::ops::Mul::mul, "Multiplies two floats."); + implement_binary_operator_on_float!("/", std::ops::Div::div, "Divides two floats."); + implement_binary_operator_on_float!( + "%", + std::ops::Rem::rem, + "Returns the remainder of dividing two floats." + ); + env.declare_global_fn(Rc::new(NativeFunction { name: "-".to_string(), documentation: Some("Negates a number.".to_string()), diff --git a/ndc_vm/src/compiler.rs b/ndc_vm/src/compiler.rs index 1ce86153..cdbe8b94 100644 --- a/ndc_vm/src/compiler.rs +++ b/ndc_vm/src/compiler.rs @@ -4,8 +4,8 @@ use crate::{Object, Value}; use ndc_core::{StaticType, TypeSignature}; use ndc_lexer::Span; use ndc_parser::{ - Binding, CaptureSource, Expression, ExpressionLocation, ForBody, ForIteration, LogicalOperator, - Lvalue, ResolvedVar, + Binding, CaptureSource, Expression, ExpressionLocation, ForBody, ForIteration, + FunctionParameter, LogicalOperator, Lvalue, ResolvedVar, }; use std::rc::Rc; @@ -152,7 +152,7 @@ impl Compiler { } } } - Expression::VariableDeclaration { value, l_value } => { + Expression::VariableDeclaration { value, l_value, .. } => { self.compile_expr(*value)?; self.compile_declare_lvalue(l_value, span)?; } @@ -299,12 +299,13 @@ impl Compiler { name, resolved_name, body, - type_signature, + parameters, return_type, captures, pure, .. } => { + let type_signature = FunctionParameter::to_type_signature(¶meters); self.compile_function_decl( name, resolved_name, diff --git a/tests/programs/004_basic/046_annotated_let_binding.ndc b/tests/programs/004_basic/046_annotated_let_binding.ndc new file mode 100644 index 00000000..6b88f8df --- /dev/null +++ b/tests/programs/004_basic/046_annotated_let_binding.ndc @@ -0,0 +1,24 @@ +// This test asserts that supported annotated let bindings are valid syntax. +let any_value: Any = 3; +while false { + let never_value: Never = break; +} +let bool_value: Bool = true; +let int_value: Int = 3; +let float_value: Float = 3.0; +let rational_value: Number = 3 / 4; +let complex_value: Number = 1 + 2i; +let number_value: Number = 3; +let string_value: String = "hello"; + +let option_value: Option = Some(3); +let sequence_value: Sequence = [1, 2, 3]; +let list_value: List = [1, 2, 3]; +let iterator_value: Iterator = 1..10; +let min_heap_value: MinHeap = MinHeap(); +let max_heap_value: MaxHeap = MaxHeap(); +let deque_value: Deque = Deque(); +let map_value: Map = %{"a": 1, "b": 2}; +let tuple_named_value: Tuple = (1, "hello"); +let tuple_shorthand_value: (Int, String) = (1, "hello"); +let tuple_empty_value: () = (); diff --git a/tests/programs/004_basic/047_annotated_let_type_mismatch.ndc b/tests/programs/004_basic/047_annotated_let_type_mismatch.ndc new file mode 100644 index 00000000..5359df52 --- /dev/null +++ b/tests/programs/004_basic/047_annotated_let_type_mismatch.ndc @@ -0,0 +1,2 @@ +// expect-error: mismatched types: found String but expected Int +let x: Int = "hello"; diff --git a/tests/programs/004_basic/048_annotated_let_type_mismatch_bool.ndc b/tests/programs/004_basic/048_annotated_let_type_mismatch_bool.ndc new file mode 100644 index 00000000..fd05d342 --- /dev/null +++ b/tests/programs/004_basic/048_annotated_let_type_mismatch_bool.ndc @@ -0,0 +1,2 @@ +// expect-error: mismatched types: found Bool but expected String +let x: String = true; diff --git a/tests/programs/004_basic/049_annotated_let_type_mismatch_list.ndc b/tests/programs/004_basic/049_annotated_let_type_mismatch_list.ndc new file mode 100644 index 00000000..7f292235 --- /dev/null +++ b/tests/programs/004_basic/049_annotated_let_type_mismatch_list.ndc @@ -0,0 +1,2 @@ +// expect-error: mismatched types: found List but expected List +let x: List = [1, 2, 3]; diff --git a/tests/programs/004_basic/050_annotated_let_type_mismatch_tuple.ndc b/tests/programs/004_basic/050_annotated_let_type_mismatch_tuple.ndc new file mode 100644 index 00000000..06c3dd48 --- /dev/null +++ b/tests/programs/004_basic/050_annotated_let_type_mismatch_tuple.ndc @@ -0,0 +1,2 @@ +// expect-error: mismatched types: found Tuple but expected Tuple +let x: (String, String) = (1, 2); diff --git a/tests/programs/004_basic/051_annotated_let_subtype_accepted.ndc b/tests/programs/004_basic/051_annotated_let_subtype_accepted.ndc new file mode 100644 index 00000000..7a231d3a --- /dev/null +++ b/tests/programs/004_basic/051_annotated_let_subtype_accepted.ndc @@ -0,0 +1,3 @@ +// expect-output: 42 +let x: Number = 42; +print(x); diff --git a/tests/programs/004_basic/052_annotated_let_rejects_supertype.ndc b/tests/programs/004_basic/052_annotated_let_rejects_supertype.ndc new file mode 100644 index 00000000..b9ee794f --- /dev/null +++ b/tests/programs/004_basic/052_annotated_let_rejects_supertype.ndc @@ -0,0 +1,3 @@ +// expect-error: mismatched types: found Number but expected Int +let x: Number = 3; +let y: Int = x; diff --git a/tests/programs/004_basic/053_nested_generics.ndc b/tests/programs/004_basic/053_nested_generics.ndc new file mode 100644 index 00000000..c47bf566 --- /dev/null +++ b/tests/programs/004_basic/053_nested_generics.ndc @@ -0,0 +1,19 @@ +// expect-output: [[1]] +// expect-output: [2] +// expect-output: [[3]] + +// This test ensures the parser correctly splits compound `>` tokens +// when closing nested generic type parameters. + +// >> is split into > > +let xs: List> = [[1]]; + +// >= is split into > = (no space before `=`) +let ys: List= [2]; + +// >>= is split into > >= then > = (no space before `=` with nested generics) +let zs: List>= [[3]]; + +print(xs); +print(ys); +print(zs); diff --git a/tests/programs/004_basic/054_annotated_let_reassignment_rejected.ndc b/tests/programs/004_basic/054_annotated_let_reassignment_rejected.ndc new file mode 100644 index 00000000..78dda3c0 --- /dev/null +++ b/tests/programs/004_basic/054_annotated_let_reassignment_rejected.ndc @@ -0,0 +1,3 @@ +let x: Int = 5; +x = "test"; +// expect-error: mismatched types diff --git a/tests/programs/004_basic/055_annotated_let_op_assign_rejected.ndc b/tests/programs/004_basic/055_annotated_let_op_assign_rejected.ndc new file mode 100644 index 00000000..88cbc43e --- /dev/null +++ b/tests/programs/004_basic/055_annotated_let_op_assign_rejected.ndc @@ -0,0 +1,3 @@ +let x: Int = 3; +x /= 4; +// expect-error: mismatched types diff --git a/tests/programs/005_functions/037_return_type_annotation.ndc b/tests/programs/005_functions/037_return_type_annotation.ndc new file mode 100644 index 00000000..c6374d61 --- /dev/null +++ b/tests/programs/005_functions/037_return_type_annotation.ndc @@ -0,0 +1,6 @@ +fn greet(name: String) -> String => "hello " <> name; +assert_eq(greet("world"), "hello world"); + +fn identity(x: Int) -> Int => x; +assert_eq(identity(5), 5); +// expect-output: diff --git a/tests/programs/005_functions/038_return_type_annotation_mismatch.ndc b/tests/programs/005_functions/038_return_type_annotation_mismatch.ndc new file mode 100644 index 00000000..90ace995 --- /dev/null +++ b/tests/programs/005_functions/038_return_type_annotation_mismatch.ndc @@ -0,0 +1,2 @@ +fn bad() -> Int { "hello" } +// expect-error: mismatched types diff --git a/tests/programs/900_bugs/bug0017_function_parser_crash.ndc b/tests/programs/900_bugs/bug0017_function_parser_crash.ndc new file mode 100644 index 00000000..4d5ebada --- /dev/null +++ b/tests/programs/900_bugs/bug0017_function_parser_crash.ndc @@ -0,0 +1,2 @@ +// expect-error: Expected parameter name +fn x(1 + 1) { }