diff --git a/bin/rustscript_cli.ml b/bin/rustscript_cli.ml index 11458d2..b919d31 100644 --- a/bin/rustscript_cli.ml +++ b/bin/rustscript_cli.ml @@ -9,10 +9,10 @@ let rec repl state = | Some "\n" -> () | None -> () | Some line -> - match Rustscript.Run.eval state line with + match Rustscript.Run.eval {static_atoms = []} state line with | (Tuple [], new_state) -> repl new_state | (evaled, new_state) -> - printf "%s\n" (Rustscript.Types.string_of_val evaled); + printf "%s\n" (Rustscript.Types.string_of_val {static_atoms = []} evaled); Out_channel.flush stdout; repl new_state diff --git a/editor/rustscript.vim b/editor/rustscript.vim index 7e3a7b4..0cace2e 100644 --- a/editor/rustscript.vim +++ b/editor/rustscript.vim @@ -49,6 +49,9 @@ syntax match rscBool "T" syntax match rscBool "F" highlight link rscBool Boolean +syntax match rscAtom "\v:\[A-za-z][A-za-z0-9_]+" +highlight link rscAtom Constant + syntax match rscIdentifier "\v[A-Za-z@!?][A-Za-z0-9@!?]*" syntax match rscIdentifier "\v_" highlight link rscIdentifier Identifier diff --git a/examples/atom.rsc b/examples/atom.rsc new file mode 100644 index 0000000..cb815b2 --- /dev/null +++ b/examples/atom.rsc @@ -0,0 +1,9 @@ +let x = :a1 +let y = :b2 + +let m = %{:a: 1, :b: %{:a: 3, :b: 4, :c: 5}} +let %{:a: i, :b: m2} = m + +let %{:a: z, :b: x, :c: y} = m2 + +inspect((i, z, x, y)) diff --git a/examples/mergesort.rsc b/examples/mergesort.rsc index 77cee5a..25042be 100644 --- a/examples/mergesort.rsc +++ b/examples/mergesort.rsc @@ -15,4 +15,4 @@ let sort = fn(ls) => { loop([[x] for x in ls]) } -inspect(sort([5, 4, 12, 17, 6, 7, 4, 3, 2, 8, 9])) +# inspect(sort([5, 4, 12, 17, 6, 7, 4, 3, 2, 8, 9])) diff --git a/lib/dune b/lib/dune index b05e4e6..16543c8 100644 --- a/lib/dune +++ b/lib/dune @@ -1,7 +1,7 @@ (library (public_name rustscript) (libraries base stdio) - (modules run types parser scanner eval operators)) + (modules run types parser scanner eval operators preprocess)) (env (release diff --git a/lib/eval.ml b/lib/eval.ml index 52deecc..79580f1 100644 --- a/lib/eval.ml +++ b/lib/eval.ml @@ -3,7 +3,11 @@ open Stdio open Base open Operators -let rec bind lhs rhs = +(* Throughout, static state is abbreviated as ss *) + +let rec bind lhs rhs ss = + let bind lhs rhs = bind lhs rhs ss in + let pattern_matches lhs rhs = pattern_matches lhs rhs ss in (* printf "Binding %s to %s\n" (string_of_pat lhs) (string_of_val rhs); *) match lhs, rhs with | SinglePat s, _ -> fun state -> @@ -25,7 +29,7 @@ let rec bind lhs rhs = printf "\n"; printf "Tried to bind %s of len %d to %s of len %d\n" (string_of_pat lhs) (List.length lhs_ls) - (string_of_val rhs) (List.length rhs_ls); + (string_of_val ss rhs) (List.length rhs_ls); assert false end | (ListPat (HeadTailPat (head_pat_ls, tail_pat))), ValList rhs_ls -> fun s -> @@ -35,19 +39,19 @@ let rec bind lhs rhs = s | MapPat kv_pairs, Dictionary rhs -> fun s -> let fetched_pairs = kv_pairs - |> List.map ~f:(fun (k, v) -> let ev_k, _ = (eval_expr k) s in ev_k, v) - |> List.map ~f:(fun (k, v) -> dict_get rhs k, v) + |> List.map ~f:(fun (k, v) -> let ev_k, _ = (eval_expr k ss) s in ev_k, v) + |> List.map ~f:(fun (k, v) -> dict_get rhs k ss, v) in let fold_step state (k, v) = (bind v k) state in List.fold_left ~init:s ~f:fold_step fetched_pairs | WildcardPat, _ -> fun state -> state | _ -> assert false -and dict_get dict key = +and dict_get dict key ss = (* Can probably be replaced by Base.Option functions *) match Map.find dict (hash_value key) with | Some found_values -> - let res = List.Assoc.find found_values ~equal:val_eq_bool key in + let res = List.Assoc.find found_values ~equal:(fun a b -> val_eq_bool a b ss) key in Option.value ~default:(Tuple []) res | _ -> Tuple [] @@ -56,7 +60,9 @@ and list_equal_len lhs rhs = match lhs, rhs with | [], _ | _, [] -> false | _::xs, _::ys -> list_equal_len xs ys -and pattern_matches pat value state = +and pattern_matches pat value ss state = + let pattern_matches pat value = pattern_matches pat value ss in + let eval_expr expr ?tc:(tc=false) = eval_expr expr ss ~tc:tc in match pat, value with | WildcardPat, _ -> true | SinglePat _, _ -> true @@ -77,15 +83,15 @@ and pattern_matches pat value state = | (MapPat kv_pairs, Dictionary rhs) -> let fetched_pairs = kv_pairs |> List.map ~f:(fun (k, v) -> let ev_k, _ = (eval_expr k) state in ev_k, v) - |> List.map ~f:(fun (k, v) -> dict_get rhs k, v) + |> List.map ~f:(fun (k, v) -> dict_get rhs k ss, v) in List.for_all ~f:(fun (k, v) -> pattern_matches v k state) fetched_pairs | _ -> false -and inspect_builtin (args, state) = +and inspect_builtin (args, state) ss = match args with | Tuple [v] -> - printf "%s\n" (string_of_val v); + printf "%s\n" (string_of_val ss v); (v, state) | _ -> printf "Expected only one argument to inspect"; @@ -102,12 +108,12 @@ and range_builtin (args, state) = printf "Expected three integer arguments to range_step"; assert false -and fold_builtin (args, state) = +and fold_builtin (args, state) ss = match args with | Tuple [init; Lambda fn; ValList ls] -> let call_fn = fun args -> let lambda_call = Thunk {thunk_fn = fn; thunk_args= args; thunk_fn_name = ""} in - let res, _ = unwrap_thunk lambda_call state in + let res, _ = unwrap_thunk lambda_call state ss in res in let fold_result = @@ -121,14 +127,14 @@ and fold_builtin (args, state) = printf "Expected (init, fn, ls) as arguments to fold\n"; assert false -and eval_op op lhs rhs = fun s -> - let (lhs, s) = (eval_expr lhs) s in - let (rhs, s) = (eval_expr rhs) s in - op lhs rhs, s +and eval_op op lhs rhs ss = fun s -> + let (lhs, s) = (eval_expr lhs ss) s in + let (rhs, s) = (eval_expr rhs ss) s in + op lhs rhs ss, s -and eval_prefix_op op rhs = fun s -> - let (rhs, s) = (eval_expr rhs) s in - op rhs, s +and eval_prefix_op op rhs ss = fun s -> + let (rhs, s) = (eval_expr rhs ss) s in + op rhs ss, s and eval_ident name = fun state -> match Map.find state name with @@ -137,45 +143,45 @@ and eval_ident name = fun state -> printf "Error: variable not found: %s\n" name; assert false -and eval_let lhs rhs = fun state -> - let (evaled, new_state) = (eval_expr rhs) state in - let new_state = (bind lhs evaled) new_state in +and eval_let lhs rhs ss = fun state -> + let (evaled, new_state) = (eval_expr rhs ss) state in + let new_state = (bind lhs evaled ss) new_state in (Tuple [], new_state) and eval_lambda_def e args = fun s -> (Lambda {lambda_expr = e; lambda_args = args; enclosed_state = s}), s -and unwrap_thunk thunk state = match thunk with +and unwrap_thunk thunk state ss = match thunk with | Thunk {thunk_fn = thunk_fn; thunk_args = thunk_args; thunk_fn_name = thunk_fn_name} -> - let inner_state = (bind thunk_fn.lambda_args thunk_args) thunk_fn.enclosed_state in + let inner_state = (bind thunk_fn.lambda_args thunk_args ss) thunk_fn.enclosed_state in let inner_state = Map.set inner_state ~key:thunk_fn_name ~data:(Lambda thunk_fn) in - let (new_thunk, _) = (eval_expr ~tc:true thunk_fn.lambda_expr) inner_state in - unwrap_thunk new_thunk state + let (new_thunk, _) = (eval_expr ~tc:true thunk_fn.lambda_expr ss) inner_state in + unwrap_thunk new_thunk state ss | value -> value, state -and eval_lambda_call ?tc:(tail_call=false) call = +and eval_lambda_call ?tc:(tail_call=false) call ss = fun (state: state) -> match Map.find state call.callee with | Some(Lambda lambda_val) -> begin - let (evaled, _) = (eval_expr call.call_args) state in + let (evaled, _) = (eval_expr call.call_args ss) state in let thunk = Thunk {thunk_fn = lambda_val; thunk_args = evaled; thunk_fn_name = call.callee} in if tail_call then (thunk, state) else - let res, _ = unwrap_thunk thunk state in + let res, _ = unwrap_thunk thunk state ss in (res, state) end | None -> begin match call.callee with - | "inspect" -> inspect_builtin ((eval_expr call.call_args) state) - | "range_step" -> range_builtin ((eval_expr call.call_args) state) - | "fold" -> fold_builtin ((eval_expr call.call_args) state) + | "inspect" -> inspect_builtin ((eval_expr call.call_args ss) state) ss + | "range_step" -> range_builtin ((eval_expr call.call_args ss) state) + | "fold" -> fold_builtin ((eval_expr call.call_args ss) state) ss | "get" -> - let (args, state) = (eval_expr call.call_args) state in begin + let (args, state) = (eval_expr call.call_args ss) state in begin match args with | Tuple [Dictionary m; key] -> begin match Map.find m (hash_value key) with | Some found_values -> - let res = List.Assoc.find found_values ~equal:val_eq_bool key in + let res = List.Assoc.find found_values ~equal:(fun a b -> val_eq_bool a b ss) key in let v = Option.value ~default:(Tuple []) res in v, state | None -> (Tuple [], state) @@ -190,23 +196,23 @@ and eval_lambda_call ?tc:(tail_call=false) call = end | _ -> assert false -and eval_tuple_expr ls state = +and eval_tuple_expr ls ss state = let (eval_ls, state) = List.fold_left ~init:([], state) - ~f:(fun (acc, s) e -> let (ev, s) = eval_expr e s in (ev::acc, s)) + ~f:(fun (acc, s) e -> let (ev, s) = (eval_expr e ss) s in (ev::acc, s)) ls in Tuple (List.rev eval_ls), state -and eval_if_expr ?tc:(tail_call=false) if_expr = fun state -> - match (eval_expr if_expr.cond) state with - | Boolean true, state -> (eval_expr ~tc:tail_call if_expr.then_expr) state +and eval_if_expr ?tc:(tail_call=false) if_expr ss = fun state -> + match (eval_expr if_expr.cond ss) state with + | Boolean true, state -> (eval_expr ~tc:tail_call if_expr.then_expr ss) state | Boolean false, state -> - (eval_expr ~tc:tail_call if_expr.else_expr) state + (eval_expr ~tc:tail_call if_expr.else_expr ss) state | _ -> assert false -and eval_block_expr ?tc:(tail_call=false) ls state = +and eval_block_expr ?tc:(tail_call=false) ls ss state = let (res, _) = let len = List.length ls in match List.split_n ls (len - 1) with @@ -214,23 +220,25 @@ and eval_block_expr ?tc:(tail_call=false) ls state = let block_state = List.fold_left ~init:state - ~f:(fun line_state e -> let _, s = (eval_expr e) line_state in s) + ~f:(fun line_state e -> let _, s = (eval_expr e ss) line_state in s) exprs in - (eval_expr ~tc:tail_call last_expr) block_state + (eval_expr ~tc:tail_call last_expr ss) block_state | _ -> assert false in (res, state) -and eval_match_expr ?tc:(tail_call=false) match_val match_arms state = - let (match_val, state) = (eval_expr match_val) state in +and eval_match_expr ?tc:(tail_call=false) match_val match_arms ss state = + let (match_val, state) = (eval_expr match_val ss) state in + let eval_expr expr ?tc:(tc=false) = eval_expr expr ss ~tc:tc in + let bind lhs rhs = bind lhs rhs ss in let result_state_opt = List.find_map ~f:( fun (pat, arm_expr, cond) -> - if pattern_matches pat match_val state then + if pattern_matches pat match_val ss state then match cond with | Some cond -> let inner_state = (bind pat match_val) state in let cond_eval, inner_state = (eval_expr cond) inner_state in - if val_is_true cond_eval then + if val_is_true cond_eval ss then let (result, _) = (eval_expr ~tc:tail_call arm_expr) inner_state in Some (result, state) else @@ -250,10 +258,10 @@ and eval_match_expr ?tc:(tail_call=false) match_val match_arms state = printf "No patterns matched in match expression\n"; assert false -and eval_map_expr ?tc:(tail_call=false) map_pairs tail_map state = +and eval_map_expr ?tc:(tail_call=false) map_pairs tail_map ss state = let fold_fn = fun (map_acc, state) (key_expr, val_expr) -> - let key_val, state = (eval_expr ~tc:tail_call key_expr) state in - let data_val, state = (eval_expr ~tc:tail_call val_expr) state in + let key_val, state = (eval_expr ~tc:tail_call key_expr ss) state in + let data_val, state = (eval_expr ~tc:tail_call val_expr ss) state in let key_hash = hash_value key_val in let new_data = match Map.find map_acc key_hash with | Some assoc_list -> (key_val, data_val)::assoc_list @@ -263,7 +271,7 @@ and eval_map_expr ?tc:(tail_call=false) map_pairs tail_map state = in let tail_map, state = match tail_map with | Some e -> - let m, state = (eval_expr e) state in + let m, state = (eval_expr e ss) state in Some m, state | None -> None, state in @@ -278,14 +286,14 @@ and eval_map_expr ?tc:(tail_call=false) map_pairs tail_map state = List.fold_left ~init:(start_map, state) ~f:fold_fn map_pairs in (Dictionary val_map, state) -and eval_list_expr ?tc:(_tail_call=false) ls tail = fun s -> +and eval_list_expr ?tc:(_tail_call=false) ls tail ss = fun s -> let eval_expr_list ~init = List.fold_left ~init:init - ~f:(fun (acc, s) e -> let (ev, s) = eval_expr e s in (ev::acc, s)) + ~f:(fun (acc, s) e -> let (ev, s) = (eval_expr e ss) s in (ev::acc, s)) in let eval_prepend ls tail = - let (tail_eval, s) = (eval_expr tail) s in + let (tail_eval, s) = (eval_expr tail ss) s in match tail_eval with | ValList tail_ls -> let (eval_ls, state) = eval_expr_list ~init:(tail_ls, s) (List.rev ls) in @@ -300,8 +308,10 @@ and eval_list_expr ?tc:(_tail_call=false) ls tail = fun s -> let (eval_ls, state) = eval_expr_list ~init:([], s) ls in ValList (List.rev eval_ls), state -and eval_expr: expr -> ?tc:bool -> state -> value * state = - fun expr ?tc:(tail_call=false) -> +and eval_expr: expr -> static_state -> ?tc:bool -> state -> value * state = + fun expr ss ?tc:(tail_call=false) -> + let eval_prefix_op op e = eval_prefix_op op e ss in + let eval_op op lhs rhs = eval_op op lhs rhs ss in (* printf "Evaluating: %s\n" (string_of_expr expr); *) match expr with | Atomic v -> fun s -> v, s @@ -326,11 +336,14 @@ and eval_expr: expr -> ?tc:bool -> state -> value * state = | Binary ({op = Mod; _} as e) -> eval_op val_mod e.lhs e.rhs | Binary ({op = _op; _}) -> assert false (* Invalid binary op *) | LambdaDef d -> eval_lambda_def d.lambda_def_expr d.lambda_def_args - | Let l -> fun s -> (eval_let l.assignee l.assigned_expr) s - | TupleExpr ls -> fun s -> eval_tuple_expr ls s - | LambdaCall l -> fun s -> (eval_lambda_call ~tc:tail_call l) s - | IfExpr i -> fun s -> (eval_if_expr ~tc:tail_call i) s - | BlockExpr ls -> fun s -> eval_block_expr ~tc:tail_call ls s - | MatchExpr m -> fun s -> eval_match_expr ~tc:tail_call m.match_val m.match_arms s - | MapExpr (ls, tail) -> fun s -> eval_map_expr ~tc:tail_call ls tail s - | ListExpr (ls, tail) -> eval_list_expr ls tail + | Let l -> fun s -> (eval_let l.assignee l.assigned_expr ss) s + | TupleExpr ls -> fun s -> (eval_tuple_expr ls ss) s + | LambdaCall l -> fun s -> (eval_lambda_call ~tc:tail_call l ss) s + | IfExpr i -> fun s -> (eval_if_expr ~tc:tail_call i ss) s + | BlockExpr ls -> fun s -> (eval_block_expr ~tc:tail_call ls ss) s + | MatchExpr m -> fun s -> (eval_match_expr ~tc:tail_call m.match_val m.match_arms ss) s + | MapExpr (ls, tail) -> fun s -> (eval_map_expr ~tc:tail_call ls tail ss) s + | ListExpr (ls, tail) -> eval_list_expr ls tail ss + | UnresolvedAtom n -> + printf "Found unresolved atom %s\n" n; + assert false diff --git a/lib/operators.ml b/lib/operators.ml index bcc70be..c3fae27 100644 --- a/lib/operators.ml +++ b/lib/operators.ml @@ -2,91 +2,92 @@ open Types open Stdio open Base -let val_add lhs rhs = match lhs, rhs with +let val_add lhs rhs ss = match lhs, rhs with | Number lhs, Number rhs -> Number (lhs +. rhs) | ValList lhs, ValList rhs -> ValList (lhs @ rhs) | _ -> - printf "Invalid Add: lhs = %s, rhs = %s\n" (string_of_val lhs) (string_of_val rhs); + printf "Invalid Add: lhs = %s, rhs = %s\n" (string_of_val ss lhs) (string_of_val ss rhs); assert false -let val_sub lhs rhs = match lhs, rhs with +let val_sub lhs rhs _ss = match lhs, rhs with | Number lhs, Number rhs -> Number (lhs -. rhs) | _ -> assert false -let val_mul lhs rhs = match lhs, rhs with +let val_mul lhs rhs ss = match lhs, rhs with | Number lhs, Number rhs -> Number (lhs *. rhs) | _ -> - printf "Invalid Mul: lhs = %s, rhs = %s\n" (string_of_val lhs) (string_of_val rhs); + printf "Invalid Mul: lhs = %s, rhs = %s\n" (string_of_val ss lhs) (string_of_val ss rhs); assert false -let val_div lhs rhs = match lhs, rhs with +let val_div lhs rhs _ss = match lhs, rhs with | Number lhs, Number rhs -> Number (lhs /. rhs) | _ -> assert false -let val_is_true = function +let val_is_true v _ss = match v with | Boolean true -> true | _ -> false -let rec val_eq lhs rhs = match lhs, rhs with +let rec val_eq lhs rhs ss = match lhs, rhs with | Number lhs, Number rhs -> Boolean (Float.equal lhs rhs) | Boolean lhs, Boolean rhs -> Boolean (Bool.equal lhs rhs) | (Tuple lhs, Tuple rhs)|(ValList lhs, ValList rhs) -> begin match List.zip lhs rhs with | Ok zipped -> - let res = List.for_all zipped ~f:(fun (a, b) -> val_is_true (val_eq a b)) + let res = List.for_all zipped ~f:(fun (a, b) -> val_is_true (val_eq a b ss) ss) in Boolean res | _ -> Boolean false end + | Atom lhs, Atom rhs -> Boolean (Int.equal lhs rhs) | _ -> Boolean false -let val_eq_bool l r = val_is_true (val_eq l r) +let val_eq_bool l r ss = val_is_true (val_eq l r ss) ss -let val_neq lhs rhs = Boolean (not (val_is_true (val_eq lhs rhs))) +let val_neq lhs rhs ss = Boolean (not (val_is_true (val_eq lhs rhs ss) ss)) -let val_lt lhs rhs = match lhs, rhs with +let val_lt lhs rhs _ss = match lhs, rhs with | Number lhs, Number rhs -> Boolean (Float.compare lhs rhs < 0) | _ -> assert false -let val_gt lhs rhs = match lhs, rhs with +let val_gt lhs rhs _ss = match lhs, rhs with | Number lhs, Number rhs -> Boolean (Float.compare lhs rhs > 0) | _ -> assert false -let val_leq lhs rhs = match lhs, rhs with +let val_leq lhs rhs _ss = match lhs, rhs with | Number lhs, Number rhs -> Boolean (Float.compare lhs rhs <= 0) | _ -> assert false -let val_geq lhs rhs = match lhs, rhs with +let val_geq lhs rhs _ss = match lhs, rhs with | Number lhs, Number rhs -> Boolean (Float.compare lhs rhs >= 0) | _ -> assert false -let val_and lhs rhs = match lhs, rhs with +let val_and lhs rhs _ss = match lhs, rhs with | Boolean lhs, Boolean rhs -> Boolean (lhs && rhs) | _ -> assert false -let val_or lhs rhs = match lhs, rhs with +let val_or lhs rhs _ss = match lhs, rhs with | Boolean lhs, Boolean rhs -> Boolean (lhs || rhs) | _ -> assert false -let val_mod lhs rhs = match lhs, rhs with +let val_mod lhs rhs _ss = match lhs, rhs with | Number lhs, Number rhs -> Number (Float.mod_float lhs rhs) | _ -> assert false -let val_negate rhs = match rhs with +let val_negate rhs _ss = match rhs with | Number rhs -> Number (~-.rhs) | _ -> assert false -let val_negate_bool rhs = match rhs with +let val_negate_bool rhs _ss = match rhs with | Boolean rhs -> Boolean (not rhs) | _ -> assert false -let val_list_head rhs = match rhs with +let val_list_head rhs ss = match rhs with | ValList (head::_) -> head | _ -> - printf "Invalid Head: rhs = %s\n" (string_of_val rhs); + printf "Invalid Head: rhs = %s\n" (string_of_val ss rhs); assert false -let val_list_tail rhs = match rhs with +let val_list_tail rhs ss = match rhs with | ValList (_::tail) -> ValList tail | _ -> - printf "Invalid Tail: rhs = %s\n" (string_of_val rhs); + printf "Invalid Tail: rhs = %s\n" (string_of_val ss rhs); assert false diff --git a/lib/parser.ml b/lib/parser.ml index f8276d5..a7a3733 100644 --- a/lib/parser.ml +++ b/lib/parser.ml @@ -400,6 +400,7 @@ and parse: token list -> int -> expr * (token list) = fun s min_bp -> | Percent::xs -> let (map, xs) = parse_map xs in complete_expr map xs min_bp + | Colon::(Ident n)::xs -> complete_expr (UnresolvedAtom n) xs min_bp | (Ident _)::LParen::_ -> let (call, xs) = parse_lambda_call s in complete_expr call xs min_bp diff --git a/lib/preprocess.ml b/lib/preprocess.ml new file mode 100644 index 0000000..764f825 --- /dev/null +++ b/lib/preprocess.ml @@ -0,0 +1,83 @@ +open Base +open Types + +let rec find_atoms: expr -> (string * int) list -> (string * int) list = + fun expr atoms -> match expr with + | Binary b -> atoms |> find_atoms b.lhs |> find_atoms b.rhs + | Prefix p -> atoms |> find_atoms p.rhs + | Let l -> atoms |> find_atoms l.assigned_expr + | LambdaDef d -> atoms |> find_atoms d.lambda_def_expr + | LambdaCall c -> atoms |> find_atoms c.call_args + | IfExpr i -> atoms |> find_atoms i.cond |> find_atoms i.then_expr |> find_atoms i.else_expr + | TupleExpr ls -> List.fold_left ~init:atoms ~f:(fun atoms e -> atoms |> find_atoms e) ls + | BlockExpr ls -> List.fold_left ~init:atoms ~f:(fun atoms e -> atoms |> find_atoms e) ls + | MatchExpr m -> atoms |> find_atoms m.match_val + | MapExpr (m, _) -> List.fold_left ~init:atoms ~f:(fun atoms (a, b) -> atoms |> find_atoms a |> find_atoms b) m + | ListExpr (ls, _) -> List.fold_left ~init:atoms ~f:(fun atoms e -> atoms |> find_atoms e) ls + | UnresolvedAtom s -> begin match List.Assoc.find atoms ~equal:String.equal s with + | Some _ -> atoms + | None -> (s, List.length atoms)::atoms + end + | _ -> atoms + +let rec resolve_pat_atoms ss p = + let resolve = resolve_pat_atoms ss in + let resolve_expr = resolve_atoms ss in + match p with + | SinglePat _ | NumberPat _ | AtomPat _ | WildcardPat -> p + | TuplePat ls -> TuplePat (List.map ~f:resolve ls) + | ListPat (FullPat ls) -> ListPat (FullPat (List.map ~f:resolve ls)) + | ListPat (HeadTailPat (ls, p)) -> ListPat (HeadTailPat (List.map ~f:resolve ls, resolve p)) + | MapPat pairs -> MapPat (List.map ~f:(fun (k, v) -> resolve_expr k, resolve v) pairs) + | OrPat (l, r) -> OrPat (resolve l, resolve r) + | AsPat (p, n) -> AsPat (resolve p, n) + | UnresolvedAtomPat s -> AtomPat (List.Assoc.find_exn ss.static_atoms ~equal:String.equal s) + +and resolve_atoms ss e = + let resolve = resolve_atoms ss in + match e with + | Atomic _ | Ident _ -> e + | Binary b -> + let lhs = resolve b.lhs in + let rhs = resolve b.rhs in + Binary { lhs; rhs; op = b.op } + | Prefix p -> + let rhs = resolve p.rhs in + Prefix { rhs; op = p.op } + | Let l -> + let assignee = resolve_pat_atoms ss l.assignee in + let assigned_expr = resolve l.assigned_expr in + Let { assignee; assigned_expr } + | LambdaDef d -> + let lambda_def_expr = resolve d.lambda_def_expr in + let lambda_def_args = resolve_pat_atoms ss d.lambda_def_args in + LambdaDef { lambda_def_expr; lambda_def_args } + | LambdaCall c -> + let call_args = resolve c.call_args in + LambdaCall { call_args; callee = c.callee } + | IfExpr i -> + let cond = resolve i.cond in + let then_expr = resolve i.then_expr in + let else_expr = resolve i.else_expr in + IfExpr { cond; then_expr; else_expr } + | TupleExpr ls -> + TupleExpr (List.map ~f:resolve ls) + | BlockExpr ls -> + BlockExpr (List.map ~f:resolve ls) + | MatchExpr m -> + let match_val = resolve m.match_val in + let match_arms = + List.map + ~f:(fun (p, a, b) -> (resolve_pat_atoms ss p, resolve a, Option.map ~f:resolve b)) + m.match_arms + in + MatchExpr { match_val; match_arms } + | MapExpr (pairs, tail) -> + let pairs = List.map ~f:(fun (a, b) -> (resolve a, resolve b)) pairs in + let tail = Option.map ~f:resolve tail in + MapExpr (pairs, tail) + | ListExpr (ls, tail) -> + let ls = List.map ~f:resolve ls in + let tail = Option.map ~f:resolve tail in + ListExpr (ls, tail) + | UnresolvedAtom s -> Atomic (Atom (List.Assoc.find_exn ss.static_atoms ~equal:String.equal s)) diff --git a/lib/run.ml b/lib/run.ml index 00e6f88..83ea26e 100644 --- a/lib/run.ml +++ b/lib/run.ml @@ -3,16 +3,17 @@ open Stdio open Types open Scanner -let eval state s = +let eval ss state s = let (parsed, _remaining) = Parser.parse_str s in - let eval_closure = Eval.eval_expr parsed in + let eval_closure = Eval.eval_expr parsed ss in eval_closure state -let run_line state line = - match eval state line with +(* TODO: Make it support atoms *) +let run_line ss state line = + match eval ss state line with | (Tuple [], new_state) -> new_state | (evaled, new_state) -> - printf "%s\n" (string_of_val evaled); + printf "%s\n" (string_of_val ss evaled); Out_channel.flush Stdio.stdout; new_state @@ -67,7 +68,8 @@ let enumerate_rev_rsc = let enumerate_rsc = "let enumerate = fn(ls) => reverse(enumerate_rev(ls))" let load_stdlib state = - let run_line_swap line state = run_line state line in + let ss = { static_atoms = [] } in + let run_line_swap line state = run_line ss state line in state |> run_line_swap "let sum = fn(ls) => fold(0, fn(a, b) => a + b, ls)" |> run_line_swap reverse_rsc @@ -82,15 +84,24 @@ let load_stdlib state = |> run_line_swap enumerate_rev_rsc |> run_line_swap enumerate_rsc -let default_state = Map.empty(module String) |> load_stdlib +let default_state: state = Map.empty (module String) |> load_stdlib -let run_file filename state = +let run_file filename state = let in_stream = In_channel.create filename in let in_string = In_channel.input_all in_stream in let tokens = in_string |> Scanner.scan |> skip_newlines in - let rec aux (parsed, remaining) state = - let remaining = skip_newlines remaining in - match (Eval.eval_expr parsed state), remaining with - | (_, new_state), [] -> new_state - | (_, new_state), remaining -> aux (Parser.parse remaining 0) new_state - in aux (Parser.parse tokens 0) state + let expr_ls = + let rec aux remaining acc = match (skip_newlines remaining) with + | [] -> acc + | remaining -> + let (parsed, remaining) = Parser.parse remaining 0 in + aux remaining (parsed::acc) + in + let (parsed, remaining) = Parser.parse tokens 0 in + List.rev (aux remaining [parsed]) + in + let block = BlockExpr expr_ls in + let ss = { static_atoms = Preprocess.find_atoms block [] } in + let expr_ls = List.map ~f:(Preprocess.resolve_atoms ss) expr_ls in + let fold_step = fun state e -> let _, s = (Eval.eval_expr e ss) state in s in + ss, List.fold_left ~init:state ~f:fold_step expr_ls diff --git a/lib/types.ml b/lib/types.ml index 579ed4d..7da0a0e 100644 --- a/lib/types.ml +++ b/lib/types.ml @@ -28,10 +28,13 @@ type value = | Lambda of lambda | Thunk of {thunk_fn: lambda; thunk_args: value; thunk_fn_name: string} | Dictionary of (int, (value * value) list, Int.comparator_witness) Map.t + | Atom of int and pattern = | SinglePat of string | NumberPat of float + | UnresolvedAtomPat of string + | AtomPat of int | TuplePat of pattern list | ListPat of list_pattern | MapPat of (expr * pattern) list @@ -43,6 +46,7 @@ and list_pattern = | FullPat of pattern list | HeadTailPat of (pattern list) * pattern +and static_state = { static_atoms: (string * int) list } and state = (string, value, String.comparator_witness) Map.t and lambda = {lambda_expr: expr; lambda_args: pattern; enclosed_state: state} @@ -57,14 +61,17 @@ and expr = | Let of {assignee: pattern; assigned_expr: expr} | LambdaDef of {lambda_def_expr: expr; lambda_def_args: pattern} | LambdaCall of lambda_call - | IfExpr of if_expr + | IfExpr of if_expr | TupleExpr of expr list | BlockExpr of expr list | MatchExpr of {match_val: expr; match_arms: (pattern * expr * expr option) list} | MapExpr of ((expr * expr) list) * (expr option) | ListExpr of (expr list) * (expr option) + | UnresolvedAtom of string -let rec string_of_val = function +let rec string_of_val ss v = + let string_of_val = string_of_val ss in + match v with | Number n -> Float.to_string n | Boolean b -> Bool.to_string b | Tuple ls -> "(" ^ String.concat ~sep:", " (List.map ~f:string_of_val ls) ^ ")" @@ -72,8 +79,14 @@ let rec string_of_val = function | Lambda _ -> "Lambda" | Thunk _ -> "Thunk" | Dictionary _ -> "Map" + | Atom n -> + let reverse_map = List.Assoc.inverse ss.static_atoms in + sprintf ":%s" (List.Assoc.find_exn reverse_map ~equal:Int.equal n) -let rec string_of_expr = function +let rec string_of_expr ss e = + let string_of_expr = string_of_expr ss in + let string_of_val = string_of_val ss in + match e with | Atomic v -> string_of_val v | Ident s -> s | Prefix (_ as p) -> sprintf "{rhs: %s}" (string_of_expr p.rhs) @@ -90,6 +103,7 @@ let rec string_of_expr = function | BlockExpr ls -> sprintf "{\n\t%s\n}" (String.concat ~sep:"\n\t" (List.map ~f:string_of_expr ls)) | MatchExpr _ -> "MatchExpr" | MapExpr _ -> "Map" + | UnresolvedAtom _ -> "UnresolvedAtom" and string_of_list_pat = function | FullPat ls -> "[" ^ (String.concat ~sep:", " (List.map ~f:string_of_pat ls)) ^ "]" @@ -104,11 +118,14 @@ and string_of_pat = function | WildcardPat -> "_" | OrPat _ -> "OrPat" | AsPat _ -> "AsPat" + | UnresolvedAtomPat _ -> "UnresolvedAtomPat" + | AtomPat _ -> "AtomPat" let rec hash_value = function | Number f -> Hashtbl.hash (0, f) | Boolean b -> Hashtbl.hash (1, b) - | Tuple ls -> Hashtbl.hash (List.map ~f:hash_value ls) + | Tuple ls -> Hashtbl.hash (2, List.map ~f:hash_value ls) + | Atom i -> Hashtbl.hash (3, i) | _ -> printf "Tried to hash an unhashable type"; assert false diff --git a/test/block.ml b/test/block.ml index a3cd38e..71dfe79 100644 --- a/test/block.ml +++ b/test/block.ml @@ -6,15 +6,15 @@ open Rustscript.Run open Util let () = - let state = + let ss, state = Map.empty (module String) |> run_file (test_file "block.rsc") in - assert_equal_expressions "a + b" "20" state; - assert_equal_expressions "f(10, 5, 3)" "28" state; - assert_equal_expressions "c" (Float.to_string (5. +. (5. +. 10. *. 2.))) state; + assert_equal_expressions "a + b" "20" ss state; + assert_equal_expressions "f(10, 5, 3)" "28" ss state; + assert_equal_expressions "c" (Float.to_string (5. +. (5. +. 10. *. 2.))) ss state; - let state = + let ss, state = Map.empty (module String) |> run_file (test_file "fmap_tuple.rsc") in let input = "(5, (10, (20, (30, (1, ())))))" in let output = "(10, (20, (40, (60, (2, ())))))" in - assert_equal_expressions (sprintf "fmap(f, %s)" input) output state; + assert_equal_expressions (sprintf "fmap(f, %s)" input) output ss state; printf "Passed\n" diff --git a/test/comments.ml b/test/comments.ml index 40c67db..e276975 100644 --- a/test/comments.ml +++ b/test/comments.ml @@ -5,18 +5,18 @@ open Rustscript.Run open Util let () = - let state = + let ss, state = Map.empty (module String) |> run_file (test_file "block.rsc") in - assert_equal_expressions "a + b" "20" state; - assert_equal_expressions "f(10, 5, 3)" "28" state; + assert_equal_expressions "a + b" "20" ss state; + assert_equal_expressions "f(10, 5, 3)" "28" ss state; - let state = + let ss, state = Map.empty (module String) |> run_file (test_file "comment.rsc") in let input = "a" in let output = "5" in - assert_equal_expressions input output state; + assert_equal_expressions input output ss state; let input = "b" in let output = "(5, 10, 15)" in - assert_equal_expressions input output state; + assert_equal_expressions input output ss state; printf "Passed\n" diff --git a/test/euler.ml b/test/euler.ml index 1f85f20..4afe7ea 100644 --- a/test/euler.ml +++ b/test/euler.ml @@ -5,36 +5,36 @@ open Rustscript.Run open Util let () = - let state = + let ss, state = default_state |> run_file (test_file "euler1.rsc") in - assert_equal_expressions "euler1" "233168" state; + assert_equal_expressions "euler1" "233168" ss state; - let state = + let ss, state = default_state |> run_file (test_file "euler1_no_listcomp.rsc") in - assert_equal_expressions "sum(filter_rev(predicate, range(1, 1000)))" "233168" state; + assert_equal_expressions "sum(filter_rev(predicate, range(1, 1000)))" "233168" ss state; - let state = + let ss, state = default_state |> run_file (test_file "euler1_tup.rsc") in - assert_equal_expressions "sum(filter(predicate, range(1, 1000)))" "233168" state; + assert_equal_expressions "sum(filter(predicate, range(1, 1000)))" "233168" ss state; - let state = + let ss, state = default_state |> run_file (test_file "euler1_tup.rsc") in - assert_equal_expressions "sum(filter(predicate, range(1, 1000)))" "233168" state; + assert_equal_expressions "sum(filter(predicate, range(1, 1000)))" "233168" ss state; - let state = + let ss, state = default_state |> run_file (test_file "euler2.rsc") in - assert_equal_expressions "euler2" "4613732" state; + assert_equal_expressions "euler2" "4613732" ss state; - let state = + let ss, state = default_state |> run_file (test_file "euler3.rsc") in - assert_equal_expressions "euler3" "6857" state; + assert_equal_expressions "euler3" "6857" ss state; - let state = + let ss, state = default_state |> run_file (test_file "euler5.rsc") in - assert_equal_expressions "euler5" "232792560" state; + assert_equal_expressions "euler5" "232792560" ss state; - let state = + let ss, state = default_state |> run_file (test_file "euler6.rsc") in - assert_equal_expressions "euler6" "25164150" state; + assert_equal_expressions "euler6" "25164150" ss state; printf "Passed\n" diff --git a/test/fib.ml b/test/fib.ml index 26bb875..52d30f5 100644 --- a/test/fib.ml +++ b/test/fib.ml @@ -5,7 +5,8 @@ open Rustscript.Run open Util let () = + let ss = { Rustscript.Types.static_atoms = [] } in let state = Map.empty (module String) in - let (_, state) = eval state "let fib = fn(n) => if n < 1 then 1 else fib(n - 1) + fib(n - 2)" in - assert_equal_expressions "fib(10)" "144" state; + let (_, state) = eval ss state "let fib = fn(n) => if n < 1 then 1 else fib(n - 1) + fib(n - 2)" in + assert_equal_expressions "fib(10)" "144" ss state; printf "Passed\n" diff --git a/test/map.ml b/test/map.ml index c2938ba..e29ab1d 100644 --- a/test/map.ml +++ b/test/map.ml @@ -5,17 +5,17 @@ open Rustscript.Run open Util let () = - let state = + let ss, state = Map.empty (module String) |> run_file (test_file "map.rsc") in - assert_equal_expressions "get(m, 1)" "2" state; - assert_equal_expressions "get(m, 3)" "4" state; - assert_equal_expressions "get(m, (5, 6))" "(7, 8)" state; - assert_equal_expressions "get(m, 467))" "()" state; + assert_equal_expressions "get(m, 1)" "2" ss state; + assert_equal_expressions "get(m, 3)" "4" ss state; + assert_equal_expressions "get(m, (5, 6))" "(7, 8)" ss state; + assert_equal_expressions "get(m, 467))" "()" ss state; - assert_equal_expressions "get(m, 1)" "x" state; - assert_equal_expressions "get(m, 3)" "y" state; - assert_equal_expressions "get(m, (5, 6))" "z" state; - assert_equal_expressions "get(m, 467))" "a" state; + assert_equal_expressions "get(m, 1)" "x" ss state; + assert_equal_expressions "get(m, 3)" "y" ss state; + assert_equal_expressions "get(m, (5, 6))" "z" ss state; + assert_equal_expressions "get(m, 467))" "a" ss state; printf "Passed\n" diff --git a/test/match_expr.ml b/test/match_expr.ml index 552c686..d5877fb 100644 --- a/test/match_expr.ml +++ b/test/match_expr.ml @@ -5,8 +5,8 @@ open Rustscript.Run open Util let () = - let state = + let ss, state = Map.empty (module String) |> run_file (test_file "match_expr.rsc") in - assert_equal_expressions "fib(20)" "10946" state; + assert_equal_expressions "fib(20)" "10946" ss state; printf "Passed\n" diff --git a/test/run_len_encode.ml b/test/run_len_encode.ml index f7a4442..7ff0a02 100644 --- a/test/run_len_encode.ml +++ b/test/run_len_encode.ml @@ -5,12 +5,13 @@ open Rustscript.Run open Util let () = - let state = + let ss, state = default_state |> run_file (test_file "run_len_encode.rsc") in assert_equal_expressions "run_len_encode(test_ls)" "[(1., 2.), (2., 1.), (3., 1.), (4., 3.), (5., 1.), (6., 1.), (1., 1.), (2., 2.)]" + ss state; printf "Passed\n" diff --git a/test/sort.ml b/test/sort.ml index c651dd5..b87442a 100644 --- a/test/sort.ml +++ b/test/sort.ml @@ -5,12 +5,12 @@ open Rustscript.Run open Util let () = - let state = + let ss, state = default_state |> run_file (test_file "quicksort.rsc") in - assert_equal_expressions "sort([5, 3, 9, 10, 4, 7, 6])" "[3, 4, 5, 6, 7, 9, 10]" state; + assert_equal_expressions "sort([5, 3, 9, 10, 4, 7, 6])" "[3, 4, 5, 6, 7, 9, 10]" ss state; - let state = + let ss, state = default_state |> run_file (test_file "mergesort.rsc") in - assert_equal_expressions "sort([5, 3, 9, 10, 4, 7, 6])" "[3, 4, 5, 6, 7, 9, 10]" state; + assert_equal_expressions "sort([5, 3, 9, 10, 4, 7, 6])" "[3, 4, 5, 6, 7, 9, 10]" ss state; printf "Passed\n" diff --git a/test/tailrec.ml b/test/tailrec.ml index a8075c3..a85466e 100644 --- a/test/tailrec.ml +++ b/test/tailrec.ml @@ -5,13 +5,13 @@ open Rustscript.Run open Util let () = - let state = + let ss, state = Map.empty (module String) |> run_file (test_file "tailrec.rsc") in (* Evaluating this stack overflows when tail recursion isn't optimized*) - assert_equal_expressions "sum(300000, 0)" "45000150000" state; + assert_equal_expressions "sum(300000, 0)" "45000150000" ss state; - let state = + let ss, state = Map.empty (module String) |> run_file (test_file "fib_tc.rsc") in - assert_equal_expressions "fib(30)" "2178309" state; + assert_equal_expressions "fib(30)" "2178309" ss state; printf "Passed\n" diff --git a/test/tuple.ml b/test/tuple.ml index d296944..1337b3e 100644 --- a/test/tuple.ml +++ b/test/tuple.ml @@ -6,8 +6,9 @@ open Util let () = let state = Map.empty (module String) in - let (_, state) = eval state "let x = (1, 2, (3, 4, 5), (6, 7), 8)" in - let (_, state) = eval state "let (a, b, c, d, e) = x" in - let (_, state) = eval state "let (f, g, h) = c" in - assert_equal_expressions "(a, b, e, f, g, h)" "(1, 2, 8, 3, 4, 5)" state; + let ss = { Rustscript.Types.static_atoms = [] } in + let (_, state) = eval ss state "let x = (1, 2, (3, 4, 5), (6, 7), 8)" in + let (_, state) = eval ss state "let (a, b, c, d, e) = x" in + let (_, state) = eval ss state "let (f, g, h) = c" in + assert_equal_expressions "(a, b, e, f, g, h)" "(1, 2, 8, 3, 4, 5)" ss state; printf "Passed\n" diff --git a/test/two_sum.ml b/test/two_sum.ml index bdb758a..f726378 100644 --- a/test/two_sum.ml +++ b/test/two_sum.ml @@ -5,11 +5,11 @@ open Rustscript.Run open Util let () = - let state = + let ss, state = default_state |> run_file (test_file "two_sum.rsc") in - assert_equal_expressions "two_sum([1,9,13,20,47], 10)" "(0, 1)" state; - assert_equal_expressions "two_sum([3,2,4,1,9], 12)" "(0, 4)" state; - assert_equal_expressions "two_sum([], 10)" "()" state; + assert_equal_expressions "two_sum([1,9,13,20,47], 10)" "(0, 1)" ss state; + assert_equal_expressions "two_sum([3,2,4,1,9], 12)" "(0, 4)" ss state; + assert_equal_expressions "two_sum([], 10)" "()" ss state; printf "Passed\n" diff --git a/test/util.ml b/test/util.ml index c17890f..a3cf98c 100644 --- a/test/util.ml +++ b/test/util.ml @@ -5,13 +5,13 @@ open Rustscript.Types let test_file filename = Printf.sprintf "../../../examples/%s" filename -let assert_equal_expressions lhs rhs state = - let (lhs_res, _) = eval state lhs in - let (rhs_res, _) = eval state rhs in - match (Rustscript.Operators.val_eq lhs_res rhs_res) with +let assert_equal_expressions lhs rhs ss state = + let (lhs_res, _) = eval ss state lhs in + let (rhs_res, _) = eval ss state rhs in + match (Rustscript.Operators.val_eq lhs_res rhs_res ss) with | Boolean true -> assert true | _ -> - printf "Expected LHS: %s\n" (string_of_val lhs_res); - printf "Got RHS: %s\n" (string_of_val rhs_res); + printf "Expected LHS: %s\n" (string_of_val ss lhs_res); + printf "Got RHS: %s\n" (string_of_val ss rhs_res); printf "\n"; assert false