diff --git a/ml-proto/host/lexer.mll b/ml-proto/host/lexer.mll index 7f3b325ce9..bee9f1f5d1 100644 --- a/ml-proto/host/lexer.mll +++ b/ml-proto/host/lexer.mll @@ -135,6 +135,7 @@ rule token = parse | "break" { BREAK } | "case" { CASE } | "fallthrough" { FALLTHROUGH } + | "br_if" { BR_IF } | "call" { CALL } | "call_import" { CALL_IMPORT } | "call_indirect" { CALL_INDIRECT } @@ -159,6 +160,7 @@ rule token = parse | (nxx as t)".switch" { SWITCH (value_type t) } | (nxx as t)".const" { CONST (value_type t) } + | (nxx as t)".br_switch" { BR_SWITCH (value_type t) } | (ixx as t)".clz" { UNARY (intop t Int32Op.Clz Int64Op.Clz) } | (ixx as t)".ctz" { UNARY (intop t Int32Op.Ctz Int64Op.Ctz) } diff --git a/ml-proto/host/parser.mly b/ml-proto/host/parser.mly index e8f4f239ce..523bca9c2e 100644 --- a/ml-proto/host/parser.mly +++ b/ml-proto/host/parser.mly @@ -95,6 +95,7 @@ let anon_label c = {c with labels = VarMap.map ((+) 1) c.labels} %token INT FLOAT TEXT VAR TYPE LPAR RPAR %token NOP BLOCK IF LOOP LABEL BREAK SWITCH CASE FALLTHROUGH +%token BR_IF BR_SWITCH %token CALL CALL_IMPORT CALL_INDIRECT RETURN %token GET_LOCAL SET_LOCAL LOAD STORE %token CONST UNARY BINARY COMPARE CONVERT @@ -110,6 +111,7 @@ let anon_label c = {c with labels = VarMap.map ((+) 1) c.labels} %token TYPE %token CONST %token SWITCH +%token BR_SWITCH %token UNARY %token BINARY %token COMPARE @@ -185,6 +187,12 @@ expr1 : { let at1 = ati 1 in fun c -> let c', l = $2 c in let cs, e = $4 c' in switch (l, $1 @@ at1, $3 c', List.map (fun a -> a $1) cs, e) } + | BR_IF var expr expr_opt { fun c -> br_if ($2 c label, $3 c, $4 c) } + | BR_SWITCH expr var br_switch_arms expr_opt + { let at1 = ati 1 in + let t = $1 in + fun c -> br_switch (t @@ at1, $2 c, $3 c label, + List.map (fun (s, a) -> (literal s t, a c label)) $4, $5 c) } | CALL var expr_list { fun c -> call ($2 c func, $3 c) } | CALL_IMPORT var expr_list { fun c -> call_import ($2 c import, $3 c) } | CALL_INDIRECT var expr expr_list @@ -232,6 +240,11 @@ cases : | case cases { fun c -> let x, y = $2 c in $1 c :: x, y } ; +br_switch_arms : + | /* empty */ { [] } + | INT var br_switch_arms { let at = at () in ($1 @@ at, $2) :: $3 } +; + /* Functions */ diff --git a/ml-proto/spec/sugar.ml b/ml-proto/spec/sugar.ml index 134bb4dd1b..d499b91ff7 100644 --- a/ml-proto/spec/sugar.ml +++ b/ml-proto/spec/sugar.ml @@ -43,6 +43,17 @@ let return (x, eo) = let switch (l, t, e1, cs, e2) = labeling l (Switch (t, e1, cs, e2)) +let br_if (x, e1, e2) = + if_ (e1, break (x, e2) @@ Source.no_region, Some (nop @@ Source.no_region)) + +let br_switch (t, e1, x, xs, e2) = + switch (Unlabelled @@ Source.no_region, t, e1, + List.map (fun (i, x) -> {value = i; + expr = break (x, e2) @@ Source.no_region; + fallthru = true} @@ Source.no_region) + xs, + (break (x, e2) @@ Source.no_region)) + let call (x, es) = Call (x, es) diff --git a/ml-proto/spec/sugar.mli b/ml-proto/spec/sugar.mli index 2cb3cdf3ca..08334daebb 100644 --- a/ml-proto/spec/sugar.mli +++ b/ml-proto/spec/sugar.mli @@ -11,6 +11,8 @@ val label : expr -> expr' val break : var * expr option -> expr' val return : var * expr option -> expr' val switch : labeling * value_type * expr * case list * expr -> expr' +val br_if : var * expr * expr option -> expr' +val br_switch : value_type * expr * var * (Values.value Source.phrase * var) list * expr option -> expr' val call : var * expr list -> expr' val call_import : var * expr list -> expr' val call_indirect : var * expr * expr list -> expr' diff --git a/ml-proto/test/br_if.wast b/ml-proto/test/br_if.wast new file mode 100644 index 0000000000..a12fb74ff8 --- /dev/null +++ b/ml-proto/test/br_if.wast @@ -0,0 +1,29 @@ +;; Similar to fac.wasm, but using br_if instead of if. + +(module + (func (param i64) (result i64) + (block $else + (br_if $else (i64.ne (get_local 0) (i64.const 0))) + (return (i64.const 1)) + ) + (i64.mul (get_local 0) (call 0 (i64.sub (get_local 0) (i64.const 1)))) + ) + + (func (param i64) (result i64) + (local i64 i64) + (set_local 1 (get_local 0)) + (set_local 2 (i64.const 1)) + (loop $done $loop + (br_if $done (i64.eq (get_local 1) (i64.const 0))) + (set_local 2 (i64.mul (get_local 1) (get_local 2))) + (set_local 1 (i64.sub (get_local 1) (i64.const 1))) + (break $loop)) + (return (get_local 2)) + ) + + (export "fac-rec" 0) + (export "fac-iter" 1) +) + +(assert_return (invoke "fac-rec" (i64.const 25)) (i64.const 7034535277573963776)) +(assert_return (invoke "fac-iter" (i64.const 25)) (i64.const 7034535277573963776)) diff --git a/ml-proto/test/br_switch.wast b/ml-proto/test/br_switch.wast new file mode 100644 index 0000000000..bee795f33a --- /dev/null +++ b/ml-proto/test/br_switch.wast @@ -0,0 +1,67 @@ +(module + ;; Statement br_switch + (func $stmt (param $i i32) (result i32) + (local $j i32) + (set_local $j (i32.const 100)) + (block $end + (block $default + (block $case6 + (block $case5 + (block $case4 + (block $case3 + (block $case2 + (block $case1 + (block $case0 + (i32.br_switch (get_local $i) + $default 0 $case0 1 $case1 2 $case2 3 $case3 4 $case4 5 $case5 6 $case6)) + (return (get_local $i))))) + (set_local $j (i32.sub (i32.const 0) (get_local $i))) (break $end)) + (break $end)) + (set_local $j (i32.const 101)) (break $end)) + (set_local $j (i32.const 101))) + (set_local $j (i32.const 102))) + (return (get_local $j)) + ) + + ;; Expression br_switch + (func $expr (param $i i64) (result i64) + (local $j i64) + (set_local $j (i64.const 100)) + (return + (block $exit + (block $default + (block $case6 + (block $case3 + (block $case2 + (block $case1 + (block $case0 + (i64.br_switch (get_local $i) + $default 0 $case0 1 $case1 2 $case2 3 $case3 6 $case6)) + (return (get_local $i))))) + (break $exit (i64.sub (i64.const 0) (get_local $i)))) + (set_local $j (i64.const 101))) + (get_local $j)) + ) + ) + + (export "stmt" $stmt) + (export "expr" $expr) +) + +(assert_return (invoke "stmt" (i32.const 0)) (i32.const 0)) +(assert_return (invoke "stmt" (i32.const 1)) (i32.const -1)) +(assert_return (invoke "stmt" (i32.const 2)) (i32.const -2)) +(assert_return (invoke "stmt" (i32.const 3)) (i32.const -3)) +(assert_return (invoke "stmt" (i32.const 4)) (i32.const 100)) +(assert_return (invoke "stmt" (i32.const 5)) (i32.const 101)) +(assert_return (invoke "stmt" (i32.const 6)) (i32.const 102)) +(assert_return (invoke "stmt" (i32.const 7)) (i32.const 102)) +(assert_return (invoke "stmt" (i32.const -10)) (i32.const 102)) + +(assert_return (invoke "expr" (i64.const 0)) (i64.const 0)) +(assert_return (invoke "expr" (i64.const 1)) (i64.const -1)) +(assert_return (invoke "expr" (i64.const 2)) (i64.const -2)) +(assert_return (invoke "expr" (i64.const 3)) (i64.const -3)) +(assert_return (invoke "expr" (i64.const 6)) (i64.const 101)) +(assert_return (invoke "expr" (i64.const 7)) (i64.const 100)) +(assert_return (invoke "expr" (i64.const -10)) (i64.const 100))