diff --git a/examples/euler5.rsc b/examples/euler5.rsc new file mode 100644 index 0000000..3963887 --- /dev/null +++ b/examples/euler5.rsc @@ -0,0 +1,18 @@ +let gcd = fn(a, b) => match (a, b) + | (a, 0) -> a + | (a, b) -> gcd(b, a % b) + +let lcm = fn(a, b) => (a * b) / gcd(a, b) + +let range = { + let helper = fn (l, r, acc) => + if l == r then acc else helper(l, r - 1, (r, acc)) + + fn (l, r) => helper(l - 1, r, ()) +} + +let foldl = fn (acc, f, ls) => match ls + | () -> acc + | (x, xs) -> foldl(f(acc, x), f, xs) + +let euler5 = foldl(1, lcm, range(1, 20)) diff --git a/examples/map.rsc b/examples/map.rsc new file mode 100644 index 0000000..2c2e7d5 --- /dev/null +++ b/examples/map.rsc @@ -0,0 +1,7 @@ +let m = %{ + 1: 2, + 3: 4, + (5, 6): (7, 8) +} + +let %{1: x, 3: y, (5, 6): z, 467: a} = m diff --git a/lib/eval.ml b/lib/eval.ml index 80d7313..5b671f4 100644 --- a/lib/eval.ml +++ b/lib/eval.ml @@ -28,15 +28,30 @@ let rec bind lhs rhs = let s = (bind (ListPat (FullPat head_pat_ls)) (ValList head_ls)) s in let s = (bind tail_pat (ValList tail_ls)) s in s + | MapPat kv_pairs, Dictionary rhs -> fun s -> + let fetched_pairs = kv_pairs + |> List.map ~f:(fun (k, v) -> let ev_k, _ = (eval_expr k) s in ev_k, v) + |> List.map ~f:(fun (k, v) -> dict_get rhs k, v) + in + let fold_step state (k, v) = (bind v k) state in + List.fold_left ~init:s ~f:fold_step fetched_pairs | WildcardPat, _ -> fun state -> state | _ -> assert false -let rec list_equal_len lhs rhs = match lhs, rhs with +and dict_get dict key = + (* Can probably be replaced by Base.Option functions *) + match Map.find dict (hash_value key) with + | Some found_values -> + let res = List.Assoc.find found_values ~equal:val_eq_bool key in + Option.value ~default:(Tuple []) res + | _ -> Tuple [] + +and list_equal_len lhs rhs = match lhs, rhs with | [], [] -> true | [], _ | _, [] -> false | _::xs, _::ys -> list_equal_len xs ys -let rec pattern_matches pat value = +and pattern_matches pat value = match pat, value with | WildcardPat, _ -> true | SinglePat _, _ -> true @@ -54,7 +69,7 @@ let rec pattern_matches pat value = head_matches && tail_matches | _ -> false -let rec eval_op op lhs rhs = fun s -> +and eval_op op lhs rhs = fun s -> let (lhs, s) = (eval_expr lhs) s in let (rhs, s) = (eval_expr rhs) s in op lhs rhs, s @@ -100,7 +115,7 @@ and eval_lambda_call ?tc:(tail_call=false) call = | None -> begin match call.callee with | "inspect" -> - let (result, _) = (eval_expr call.call_args) state in begin + let (result, state) = (eval_expr ~tc:tail_call call.call_args) state in begin match result with | Tuple [v] -> printf "%s\n" (string_of_val v); @@ -109,6 +124,21 @@ and eval_lambda_call ?tc:(tail_call=false) call = printf "Expected only one argument to inspect"; assert false end + | "get" -> + let (args, state) = (eval_expr call.call_args) state in begin + match args with + | Tuple [Dictionary m; key] -> begin + match Map.find m (hash_value key) with + | Some found_values -> + let res = List.Assoc.find found_values ~equal:val_eq_bool key in + let v = Option.value ~default:(Tuple []) res in + v, state + | None -> (Tuple [], state) + end + | _ -> + printf "get requires two arguments, a list, and a value"; + assert false + end | _ -> printf "Error: function not found: %s\n" call.callee; assert false @@ -175,6 +205,22 @@ and eval_match_expr ?tc:(tail_call=false) match_val match_arms state = printf "No patterns matched in match expression\n"; assert false +and eval_map_expr ?tc:(tail_call=false) map_pairs state = + let fold_fn = fun (map_acc, state) (key_expr, val_expr) -> + let key_val, state = (eval_expr ~tc:tail_call key_expr) state in + let data_val, state = (eval_expr ~tc:tail_call val_expr) state in + let key_hash = hash_value key_val in + let new_data = match Map.find map_acc key_hash with + | Some assoc_list -> (key_val, data_val)::assoc_list + | None -> [(key_val, data_val)] + in + (Map.set map_acc ~key:key_hash ~data:new_data, state) + in + let start_map = Map.empty (module Int) in + let (val_map, state) = + List.fold_left ~init:(start_map, state) ~f:fold_fn map_pairs + in (Dictionary val_map, state) + and eval_list_expr ?tc:(_tail_call=false) ls tail = fun s -> let eval_expr_list ~init = List.fold_left @@ -227,4 +273,5 @@ and eval_expr: expr -> ?tc:bool -> state -> value * state = | IfExpr i -> fun s -> (eval_if_expr ~tc:tail_call i) s | BlockExpr ls -> fun s -> eval_block_expr ~tc:tail_call ls s | MatchExpr m -> fun s -> eval_match_expr ~tc:tail_call m.match_val m.match_arms s + | MapExpr ls -> fun s -> eval_map_expr ~tc:tail_call ls s | ListExpr (ls, tail) -> eval_list_expr ls tail diff --git a/lib/operators.ml b/lib/operators.ml index d242359..92a6225 100644 --- a/lib/operators.ml +++ b/lib/operators.ml @@ -39,6 +39,8 @@ let rec val_eq lhs rhs = match lhs, rhs with Boolean false | _ -> Boolean false +let val_eq_bool l r = val_is_true (val_eq l r) + let val_neq lhs rhs = Boolean (not (val_is_true (val_eq lhs rhs))) let val_lt lhs rhs = match lhs, rhs with diff --git a/lib/parser.ml b/lib/parser.ml index 81223ef..e9c3a29 100644 --- a/lib/parser.ml +++ b/lib/parser.ml @@ -16,6 +16,7 @@ let binary_op_bp = function let prefix_op_bp = 13 let rec complete_expr lhs ls min_bp = match ls with + | Percent::xs -> complete_expr lhs ((Operator Mod)::xs) min_bp | (Operator op)::xs -> let (l_bp, r_bp) = binary_op_bp op in @@ -112,6 +113,33 @@ and parse_pat ls = match ls with | Some tail_pat -> HeadTailPat (pat_list, tail_pat) in ListPat parsed_list_pat, rest + | Percent::LBrace::xs -> + let parse_pair toks = + let key, rest = parse toks 0 in + match rest with + | Colon::more -> + let val_pat, more = parse_pat more in + (key, val_pat), more + | _ -> + printf "Expected a colon\n"; + assert false + in + let rec aux toks acc = match toks with + | RBrace::rest -> acc, rest + | Comma::rest -> + let pair, more = parse_pair rest in + aux more (pair::acc) + | _ -> assert false + in begin match xs with + | RBrace::rest -> MapPat [], rest + | _ -> + let first_pair, rest = parse_pair xs in + let pair_ls, more = aux rest [first_pair] in + MapPat (List.rev pair_ls), more + end + | Percent::_ -> + printf "Expected LBrace\n"; + assert false | (Ident s)::xs -> (SinglePat s, xs) | (Number f)::xs -> (NumberPat f, xs) | Underscore::xs -> (WildcardPat, xs) @@ -201,6 +229,41 @@ and parse_block_expr ls = aux rest (next_expr::acc) in aux ls [] +and parse_map = function + | LBrace::rest -> + let rest = skip_newlines rest in + let parse_key_val ls = + let key_expr, xs = parse ls 0 in + match xs with + | Colon::xs -> + let xs = skip_newlines xs in + let (val_expr, more) = parse xs 0 in + (key_expr, val_expr, more) + | _ -> + printf "Expected comma"; + assert false + in + let rec aux ls acc = match ls with + | RBrace::more -> (acc, more) + | Comma::xs -> + let xs = skip_newlines xs in + let (key_expr, val_expr, rest) = parse_key_val xs in + let rest = skip_newlines rest in + aux rest ((key_expr, val_expr)::acc) + | _ -> assert false + in begin match rest with + | RBrace::xs -> + (MapExpr [], xs) + | _ -> + let k0, v0, rest = parse_key_val rest in + let res, more = aux rest [(k0, v0)] in + (MapExpr (List.rev res), more) + end + + | ls -> + printf "Expected LBrace, got %s\n" (string_of_toks ls); + assert false + and parse_match_expr ls = let (match_val, rest) = parse ls 0 in let rest = skip_newlines rest in @@ -223,8 +286,11 @@ and parse_match_expr ls = match rest with | Newline::xs -> parse_match_arms xs ((arm_pat, arm_expr, cond)::acc) + | Pipe::_ -> + printf "Must break line after each match arm\n"; + assert false | _ -> - printf "Must break line after each match arm"; + printf "Error parsing expression in match arm\n"; assert false end | _ -> @@ -244,25 +310,30 @@ and parse_match_expr ls = and parse: token list -> int -> expr * (token list) = fun s min_bp -> let s = skip_newlines s in match s with - | LBrace::xs -> parse_block_expr xs - | (Ident _)::LParen::_ -> + | LBrace::xs -> + let (block, xs) = parse_block_expr xs in + complete_expr block xs min_bp + | Percent::xs -> + let (map, xs) = parse_map xs in + complete_expr map xs min_bp + | (Ident _)::LParen::_ -> let (call, xs) = parse_lambda_call s in complete_expr call xs min_bp - | LParen::_ -> expr_bp s 0 - | LBracket::_ -> expr_bp s 0 - | (Operator _)::_ -> expr_bp s 0 - | (True|False|Number _| Ident _)::_ -> expr_bp s min_bp - | Let::xs -> parse_let xs - | Fn::_ -> + | LParen::_ -> expr_bp s 0 + | LBracket::_ -> expr_bp s 0 + | (Operator _)::_ -> expr_bp s 0 + | (True|False|Number _| Ident _)::_ -> expr_bp s min_bp + | Let::xs -> parse_let xs + | Fn::_ -> let (lambda_parsed, xs) = parse_lambda s in complete_expr lambda_parsed xs min_bp - | If::_ -> + | If::_ -> let (if_parsed, xs) = parse_if_expr s in complete_expr if_parsed xs min_bp - | Match::xs -> + | Match::xs -> let (match_parsed, xs) = parse_match_expr xs in complete_expr match_parsed xs min_bp - | _ -> + | _ -> printf "Expected expression, got (%s)\n" (string_of_toks s); assert false diff --git a/lib/scanner.ml b/lib/scanner.ml index 03253da..6269378 100644 --- a/lib/scanner.ml +++ b/lib/scanner.ml @@ -28,6 +28,8 @@ type token = | Comma | Pipe | Underscore + | Colon + | Percent let is_numeric d = Base.Char.is_digit d || phys_equal d '.' let is_identic c = Base.Char.is_alphanum c || phys_equal c '_' @@ -72,7 +74,7 @@ and scan_ls = function | '&'::'&'::xs -> Operator And :: scan_ls xs | '='::'='::xs -> Operator EQ :: scan_ls xs | '!'::'='::xs -> Operator NEQ :: scan_ls xs - | '%'::xs -> Operator Mod :: scan_ls xs + | '%'::xs -> Percent :: scan_ls xs | '^'::xs -> Operator Head :: scan_ls xs | '$'::xs -> Operator Tail :: scan_ls xs | '!'::xs -> Operator Not :: scan_ls xs @@ -89,6 +91,7 @@ and scan_ls = function | '|'::xs -> Pipe :: scan_ls xs | 'T'::xs -> True :: scan_ls xs | 'F'::xs -> False :: scan_ls xs + | ':'::xs -> Colon :: scan_ls xs | d::_ as ls when Char.is_digit d -> scan_digit ls | i::_ as ls when Char.is_alpha i -> scan_ident ls | ls -> @@ -134,6 +137,8 @@ let string_of_tok = function | Match -> "Match" | MatchArrow -> "MatchArrow" | Underscore -> "Underscore" + | Colon -> "Colon" + | Percent -> "Percent" let string_of_toks ls = String.concat ~sep:" " (List.map ~f:string_of_tok ls) let print_toks ls = ls |> string_of_toks |> printf "%s\n" diff --git a/lib/types.ml b/lib/types.ml index 88ca08d..3d2ee6f 100644 --- a/lib/types.ml +++ b/lib/types.ml @@ -1,5 +1,6 @@ open Base open Printf +open Stdio type operator = | Add @@ -24,12 +25,14 @@ type value = | ValList of value list | Lambda of lambda | Thunk of {thunk_fn: lambda; thunk_args: value; thunk_fn_name: string} + | Dictionary of (int, (value * value) list, Int.comparator_witness) Map.t and pattern = | SinglePat of string | NumberPat of float | TuplePat of pattern list | ListPat of list_pattern + | MapPat of (expr * pattern) list | WildcardPat and list_pattern = @@ -54,6 +57,7 @@ and expr = | TupleExpr of expr list | BlockExpr of expr list | MatchExpr of {match_val: expr; match_arms: (pattern * expr * expr option) list} + | MapExpr of (expr * expr) list | ListExpr of (expr list) * (expr option) let rec string_of_val = function @@ -63,6 +67,7 @@ let rec string_of_val = function | ValList ls -> "[" ^ String.concat ~sep:", " (List.map ~f:string_of_val ls) ^ "]" | Lambda _ -> "Lambda" | Thunk _ -> "Thunk" + | Dictionary _ -> "Map" let rec string_of_expr = function | Atomic v -> string_of_val v @@ -80,6 +85,7 @@ let rec string_of_expr = function | IfExpr _ -> "IfExpr" | BlockExpr ls -> sprintf "{\n\t%s\n}" (String.concat ~sep:"\n\t" (List.map ~f:string_of_expr ls)) | MatchExpr _ -> "MatchExpr" + | MapExpr _ -> "Map" and string_of_list_pat = function | FullPat ls -> "[" ^ (String.concat ~sep:", " (List.map ~f:string_of_pat ls)) ^ "]" @@ -88,6 +94,15 @@ and string_of_list_pat = function and string_of_pat = function | SinglePat s -> s | ListPat lp -> (string_of_list_pat lp) + | MapPat _ -> "MapPat" | NumberPat f -> Float.to_string f | TuplePat ls -> sprintf "(%s)" (String.concat ~sep:", " (List.map ~f:string_of_pat ls)) | WildcardPat -> "_" + +let rec hash_value = function + | Number f -> Hashtbl.hash (0, f) + | Boolean b -> Hashtbl.hash (1, b) + | Tuple ls -> Hashtbl.hash (List.map ~f:hash_value ls) + | _ -> + printf "Tried to hash an unhashable type"; + assert false diff --git a/test/dune b/test/dune index b27de3b..0e1b572 100644 --- a/test/dune +++ b/test/dune @@ -1,3 +1,3 @@ (tests - (names fib tuple block comments tailrec euler match_expr) + (names fib tuple block comments tailrec euler match_expr map) (libraries base stdio rustscript)) diff --git a/test/euler.ml b/test/euler.ml index 4637fe9..297234a 100644 --- a/test/euler.ml +++ b/test/euler.ml @@ -21,4 +21,8 @@ let () = Map.empty (module String) |> run_file (test_file "euler3.rsc") in assert_equal_expressions "euler3" "6857" state; + let state = + Map.empty (module String) |> run_file (test_file "euler5.rsc") in + assert_equal_expressions "euler5" "232792560" state; + printf "Passed\n" diff --git a/test/map.ml b/test/map.ml new file mode 100644 index 0000000..c2938ba --- /dev/null +++ b/test/map.ml @@ -0,0 +1,21 @@ +open Base +open Stdio + +open Rustscript.Run +open Util + +let () = + let state = + Map.empty (module String) |> run_file (test_file "map.rsc") in + + assert_equal_expressions "get(m, 1)" "2" state; + assert_equal_expressions "get(m, 3)" "4" state; + assert_equal_expressions "get(m, (5, 6))" "(7, 8)" state; + assert_equal_expressions "get(m, 467))" "()" state; + + assert_equal_expressions "get(m, 1)" "x" state; + assert_equal_expressions "get(m, 3)" "y" state; + assert_equal_expressions "get(m, (5, 6))" "z" state; + assert_equal_expressions "get(m, 467))" "a" state; + + printf "Passed\n"