From af1fe06037710584b5088a963ce614587627f064 Mon Sep 17 00:00:00 2001 From: Mikail Khan Date: Thu, 21 Oct 2021 00:10:33 -0400 Subject: [PATCH] TCO --- examples/block.rsc | 3 +- examples/tailrec.rsc | 5 ++ lib/eval.ml | 144 ++++++++++++++++++++++++------------------- lib/types.ml | 4 +- test/dune | 2 +- test/tailrec.ml | 14 +++++ 6 files changed, 106 insertions(+), 66 deletions(-) create mode 100644 examples/tailrec.rsc create mode 100644 test/tailrec.ml diff --git a/examples/block.rsc b/examples/block.rsc index 3c8853c..75cc81f 100644 --- a/examples/block.rsc +++ b/examples/block.rsc @@ -13,4 +13,5 @@ f(10, 5, 3) let c = 5 + { 5 + 10 * 2 } -c +inspect(a) +inspect(b) diff --git a/examples/tailrec.rsc b/examples/tailrec.rsc new file mode 100644 index 0000000..33a2428 --- /dev/null +++ b/examples/tailrec.rsc @@ -0,0 +1,5 @@ +let sum = fn(n, acc) => { + if n == 0 + then acc + else sum(n - 1, acc + n) +} diff --git a/lib/eval.ml b/lib/eval.ml index ab94a7a..6f52769 100644 --- a/lib/eval.ml +++ b/lib/eval.ml @@ -71,14 +71,23 @@ let rec eval_let lhs rhs = fun state -> let new_state = (bind lhs evaled) new_state in (Unit, new_state) -and eval_lambda_call call = +and unwrap_thunk thunk state = match thunk with + | Thunk {thunk_fn = thunk_fn; thunk_args = thunk_args} -> + let inner_state = (bind thunk_fn.lambda_args thunk_args) state in + let (new_thunk, _) = (eval_expr ~tc:true thunk_fn.lambda_expr) inner_state in + unwrap_thunk new_thunk inner_state + | value -> value, state + +and eval_lambda_call ?tc:(tail_call=false) call = fun (state: state) -> match Map.find state call.callee with - | Some(Lambda (lambda_val) as l) -> begin + | Some(Lambda lambda_val) -> begin let (evaled, _) = (eval_expr call.call_args) state in - let inner_state = (bind lambda_val.lambda_args evaled) state in - let inner_state = Map.set inner_state ~key:call.callee ~data:l in - let (result, _) = (eval_expr lambda_val.lambda_expr) inner_state in - (result, state) + let thunk = Thunk {thunk_fn = lambda_val; thunk_args = evaled} in + if tail_call + then (thunk, state) + else + let res, _ = unwrap_thunk thunk state in + (res, state) end | None -> begin match call.callee with @@ -98,66 +107,77 @@ and eval_lambda_call call = end | _ -> assert false -and eval_if_expr if_expr = fun state -> +and eval_if_expr ?tc:(tail_call=false) if_expr = fun state -> match (eval_expr if_expr.cond) state with - | Boolean true, state -> (eval_expr if_expr.then_expr) state + | Boolean true, state -> (eval_expr ~tc:tail_call if_expr.then_expr) state | Boolean false, state -> - (eval_expr if_expr.else_expr) state + (eval_expr ~tc:tail_call if_expr.else_expr) state | _ -> assert false -and eval_block_expr ls state = +and eval_block_expr ?tc:(tail_call=false) ls state = let (res, _) = - List.fold_left ~init:(Unit, state) ~f:(fun (_, state) e -> (eval_expr e) state) ls + let len = List.length ls in + match List.split_n ls (len - 1) with + | exprs, [last_expr] -> + let state = + List.fold_left + ~init:state + ~f:(fun state e -> let _, s = (eval_expr e) state in s) + exprs + in + (eval_expr ~tc:tail_call last_expr) state + | _ -> assert false in (res, state) -and eval_expr: expr -> state -> value * state = fun expr -> - (* printf "Evaluating: %s\n" (string_of_expr expr); *) - match expr with - | Atomic n -> fun s -> n, s - | Ident n -> fun state -> begin - match Map.find state n with - | Some v -> v, state - | None -> - printf "Error: variable not found: %s\n" n; - assert false - end - | Binary ({op = Add; _} as e) -> fun s -> - let (lhs, s) = (eval_expr e.lhs) s in - let (rhs, s) = (eval_expr e.rhs) s in - val_add lhs rhs, s - | Binary ({op = Sub; _} as e) -> fun s -> - let (lhs, s) = (eval_expr e.lhs) s in - let (rhs, s) = (eval_expr e.rhs) s in - val_sub lhs rhs, s - | Binary ({op = Mul; _} as e) -> fun s -> - let (lhs, s) = (eval_expr e.lhs) s in - let (rhs, s) = (eval_expr e.rhs) s in - val_mul lhs rhs, s - | Binary ({op = Div; _} as e) -> fun s -> - let (lhs, s) = (eval_expr e.lhs) s in - let (rhs, s) = (eval_expr e.rhs) s in - val_div lhs rhs, s - | Binary ({op = EQ; _} as e) -> fun s -> - let (lhs, s) = (eval_expr e.lhs) s in - let (rhs, s) = (eval_expr e.rhs) s in - val_eq lhs rhs, s - | Binary ({op = LT; _} as e) -> fun s -> - let (lhs, s) = (eval_expr e.lhs) s in - let (rhs, s) = (eval_expr e.rhs) s in - val_lt lhs rhs, s - | Binary ({op = GT; _} as e) -> fun s -> - let (lhs, s) = (eval_expr e.lhs) s in - let (rhs, s) = (eval_expr e.rhs) s in - val_gt lhs rhs, s - | Let l -> fun s -> (eval_let l.assignee l.assigned_expr) s - | TupleExpr ls -> fun s -> - let (eval_ls, state) = - List.fold_left - ~init:([], s) - ~f:(fun (acc, s) e -> let (ev, s) = eval_expr e s in (ev::acc, s)) - ls - in - Tuple (List.rev eval_ls), state - | LambdaCall l -> fun s -> (eval_lambda_call l) s - | IfExpr i -> fun s -> (eval_if_expr i) s - | BlockExpr ls -> fun s -> eval_block_expr ls s +and eval_expr: expr -> ?tc:bool -> state -> value * state = + fun expr ?tc:(tail_call=false) -> + (* printf "Evaluating: %s\n" (string_of_expr expr); *) + match expr with + | Atomic n -> fun s -> n, s + | Ident n -> fun state -> begin + match Map.find state n with + | Some v -> v, state + | None -> + printf "Error: variable not found: %s\n" n; + assert false + end + | Binary ({op = Add; _} as e) -> fun s -> + let (lhs, s) = (eval_expr e.lhs) s in + let (rhs, s) = (eval_expr e.rhs) s in + val_add lhs rhs, s + | Binary ({op = Sub; _} as e) -> fun s -> + let (lhs, s) = (eval_expr e.lhs) s in + let (rhs, s) = (eval_expr e.rhs) s in + val_sub lhs rhs, s + | Binary ({op = Mul; _} as e) -> fun s -> + let (lhs, s) = (eval_expr e.lhs) s in + let (rhs, s) = (eval_expr e.rhs) s in + val_mul lhs rhs, s + | Binary ({op = Div; _} as e) -> fun s -> + let (lhs, s) = (eval_expr e.lhs) s in + let (rhs, s) = (eval_expr e.rhs) s in + val_div lhs rhs, s + | Binary ({op = EQ; _} as e) -> fun s -> + let (lhs, s) = (eval_expr e.lhs) s in + let (rhs, s) = (eval_expr e.rhs) s in + val_eq lhs rhs, s + | Binary ({op = LT; _} as e) -> fun s -> + let (lhs, s) = (eval_expr e.lhs) s in + let (rhs, s) = (eval_expr e.rhs) s in + val_lt lhs rhs, s + | Binary ({op = GT; _} as e) -> fun s -> + let (lhs, s) = (eval_expr e.lhs) s in + let (rhs, s) = (eval_expr e.rhs) s in + val_gt lhs rhs, s + | Let l -> fun s -> (eval_let l.assignee l.assigned_expr) s + | TupleExpr ls -> fun s -> + let (eval_ls, state) = + List.fold_left + ~init:([], s) + ~f:(fun (acc, s) e -> let (ev, s) = eval_expr e s in (ev::acc, s)) + ls + in + Tuple (List.rev eval_ls), state + | LambdaCall l -> fun s -> (eval_lambda_call ~tc:tail_call l) s + | IfExpr i -> fun s -> (eval_if_expr ~tc:tail_call i) s + | BlockExpr ls -> fun s -> eval_block_expr ~tc:tail_call ls s diff --git a/lib/types.ml b/lib/types.ml index aeb6feb..02dcdc3 100644 --- a/lib/types.ml +++ b/lib/types.ml @@ -16,7 +16,6 @@ type pattern = type lambda = {lambda_expr: expr; lambda_args: pattern} and lambda_call = {callee: string; call_args: expr} - and if_expr = {cond: expr; then_expr: expr; else_expr: expr} and value = @@ -25,7 +24,7 @@ and value = | Tuple of value list | Unit | Lambda of lambda - + | Thunk of {thunk_fn: lambda; thunk_args: value} and expr = | Atomic of value | Ident of string @@ -42,6 +41,7 @@ let rec string_of_val = function | Tuple ls -> "(" ^ String.concat ~sep:", " (List.map ~f:string_of_val ls) ^ ")" | Unit -> "()" | Lambda _ -> "Lambda" + | Thunk _ -> "Thunk" let rec string_of_expr = function | Atomic v -> string_of_val v diff --git a/test/dune b/test/dune index 6357c74..22a4420 100644 --- a/test/dune +++ b/test/dune @@ -1,3 +1,3 @@ (tests - (names fib tuple block comments) + (names fib tuple block comments tailrec) (libraries base stdio rustscript)) diff --git a/test/tailrec.ml b/test/tailrec.ml new file mode 100644 index 0000000..80cf7f2 --- /dev/null +++ b/test/tailrec.ml @@ -0,0 +1,14 @@ +open Base +open Stdio + +open Rustscript.Run +open Util + +let () = + let state = + Map.empty (module String) |> run_file (test_file "tailrec.rsc") in + + (* Evaluating this stack overflows when tail recursion isn't optimized*) + assert_equal_expressions "sum(300000, 0)" "45000150000" state; + + printf "Passed\n"