Skip to content
Merged

TCO #10

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/block.rsc
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ f(10, 5, 3)

let c = 5 + { 5 + 10 * 2 }

c
inspect(a)
inspect(b)
5 changes: 5 additions & 0 deletions examples/tailrec.rsc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
let sum = fn(n, acc) => {
if n == 0
then acc
else sum(n - 1, acc + n)
}
144 changes: 82 additions & 62 deletions lib/eval.ml
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,23 @@ let rec eval_let lhs rhs = fun state ->
let new_state = (bind lhs evaled) new_state in
(Unit, new_state)

and eval_lambda_call call =
and unwrap_thunk thunk state = match thunk with
| Thunk {thunk_fn = thunk_fn; thunk_args = thunk_args} ->
let inner_state = (bind thunk_fn.lambda_args thunk_args) state in
let (new_thunk, _) = (eval_expr ~tc:true thunk_fn.lambda_expr) inner_state in
unwrap_thunk new_thunk inner_state
| value -> value, state

and eval_lambda_call ?tc:(tail_call=false) call =
fun (state: state) -> match Map.find state call.callee with
| Some(Lambda (lambda_val) as l) -> begin
| Some(Lambda lambda_val) -> begin
let (evaled, _) = (eval_expr call.call_args) state in
let inner_state = (bind lambda_val.lambda_args evaled) state in
let inner_state = Map.set inner_state ~key:call.callee ~data:l in
let (result, _) = (eval_expr lambda_val.lambda_expr) inner_state in
(result, state)
let thunk = Thunk {thunk_fn = lambda_val; thunk_args = evaled} in
if tail_call
then (thunk, state)
else
let res, _ = unwrap_thunk thunk state in
(res, state)
end
| None -> begin
match call.callee with
Expand All @@ -98,66 +107,77 @@ and eval_lambda_call call =
end
| _ -> assert false

and eval_if_expr if_expr = fun state ->
and eval_if_expr ?tc:(tail_call=false) if_expr = fun state ->
match (eval_expr if_expr.cond) state with
| Boolean true, state -> (eval_expr if_expr.then_expr) state
| Boolean true, state -> (eval_expr ~tc:tail_call if_expr.then_expr) state
| Boolean false, state ->
(eval_expr if_expr.else_expr) state
(eval_expr ~tc:tail_call if_expr.else_expr) state
| _ -> assert false

and eval_block_expr ls state =
and eval_block_expr ?tc:(tail_call=false) ls state =
let (res, _) =
List.fold_left ~init:(Unit, state) ~f:(fun (_, state) e -> (eval_expr e) state) ls
let len = List.length ls in
match List.split_n ls (len - 1) with
| exprs, [last_expr] ->
let state =
List.fold_left
~init:state
~f:(fun state e -> let _, s = (eval_expr e) state in s)
exprs
in
(eval_expr ~tc:tail_call last_expr) state
| _ -> assert false
in (res, state)

and eval_expr: expr -> state -> value * state = fun expr ->
(* printf "Evaluating: %s\n" (string_of_expr expr); *)
match expr with
| Atomic n -> fun s -> n, s
| Ident n -> fun state -> begin
match Map.find state n with
| Some v -> v, state
| None ->
printf "Error: variable not found: %s\n" n;
assert false
end
| Binary ({op = Add; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_add lhs rhs, s
| Binary ({op = Sub; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_sub lhs rhs, s
| Binary ({op = Mul; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_mul lhs rhs, s
| Binary ({op = Div; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_div lhs rhs, s
| Binary ({op = EQ; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_eq lhs rhs, s
| Binary ({op = LT; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_lt lhs rhs, s
| Binary ({op = GT; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_gt lhs rhs, s
| Let l -> fun s -> (eval_let l.assignee l.assigned_expr) s
| TupleExpr ls -> fun s ->
let (eval_ls, state) =
List.fold_left
~init:([], s)
~f:(fun (acc, s) e -> let (ev, s) = eval_expr e s in (ev::acc, s))
ls
in
Tuple (List.rev eval_ls), state
| LambdaCall l -> fun s -> (eval_lambda_call l) s
| IfExpr i -> fun s -> (eval_if_expr i) s
| BlockExpr ls -> fun s -> eval_block_expr ls s
and eval_expr: expr -> ?tc:bool -> state -> value * state =
fun expr ?tc:(tail_call=false) ->
(* printf "Evaluating: %s\n" (string_of_expr expr); *)
match expr with
| Atomic n -> fun s -> n, s
| Ident n -> fun state -> begin
match Map.find state n with
| Some v -> v, state
| None ->
printf "Error: variable not found: %s\n" n;
assert false
end
| Binary ({op = Add; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_add lhs rhs, s
| Binary ({op = Sub; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_sub lhs rhs, s
| Binary ({op = Mul; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_mul lhs rhs, s
| Binary ({op = Div; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_div lhs rhs, s
| Binary ({op = EQ; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_eq lhs rhs, s
| Binary ({op = LT; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_lt lhs rhs, s
| Binary ({op = GT; _} as e) -> fun s ->
let (lhs, s) = (eval_expr e.lhs) s in
let (rhs, s) = (eval_expr e.rhs) s in
val_gt lhs rhs, s
| Let l -> fun s -> (eval_let l.assignee l.assigned_expr) s
| TupleExpr ls -> fun s ->
let (eval_ls, state) =
List.fold_left
~init:([], s)
~f:(fun (acc, s) e -> let (ev, s) = eval_expr e s in (ev::acc, s))
ls
in
Tuple (List.rev eval_ls), state
| LambdaCall l -> fun s -> (eval_lambda_call ~tc:tail_call l) s
| IfExpr i -> fun s -> (eval_if_expr ~tc:tail_call i) s
| BlockExpr ls -> fun s -> eval_block_expr ~tc:tail_call ls s
4 changes: 2 additions & 2 deletions lib/types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ type pattern =

type lambda = {lambda_expr: expr; lambda_args: pattern}
and lambda_call = {callee: string; call_args: expr}

and if_expr = {cond: expr; then_expr: expr; else_expr: expr}

and value =
Expand All @@ -25,7 +24,7 @@ and value =
| Tuple of value list
| Unit
| Lambda of lambda

| Thunk of {thunk_fn: lambda; thunk_args: value}
and expr =
| Atomic of value
| Ident of string
Expand All @@ -42,6 +41,7 @@ let rec string_of_val = function
| Tuple ls -> "(" ^ String.concat ~sep:", " (List.map ~f:string_of_val ls) ^ ")"
| Unit -> "()"
| Lambda _ -> "Lambda"
| Thunk _ -> "Thunk"

let rec string_of_expr = function
| Atomic v -> string_of_val v
Expand Down
2 changes: 1 addition & 1 deletion test/dune
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
(tests
(names fib tuple block comments)
(names fib tuple block comments tailrec)
(libraries base stdio rustscript))
14 changes: 14 additions & 0 deletions test/tailrec.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
open Base
open Stdio

open Rustscript.Run
open Util

let () =
let state =
Map.empty (module String) |> run_file (test_file "tailrec.rsc") in

(* Evaluating this stack overflows when tail recursion isn't optimized*)
assert_equal_expressions "sum(300000, 0)" "45000150000" state;

printf "Passed\n"