diff --git a/ml-proto/host/parser.mly b/ml-proto/host/parser.mly index 9d92550205..24875067b8 100644 --- a/ml-proto/host/parser.mly +++ b/ml-proto/host/parser.mly @@ -233,7 +233,8 @@ expr1 : | LOOP labeling1 labeling1 expr_list { fun c -> let c' = $2 c in let c'' = $3 c' in Loop ($4 c'') } | BR var expr_opt { fun c -> Br ($2 c label, $3 c) } - | BR_IF expr var expr_opt { fun c -> Br_if ($2 c, $3 c label, $4 c) } + | BR_IF var expr { fun c -> Br_if ($2 c label, None, $3 c) } + | BR_IF var expr expr { fun c -> Br_if ($2 c label, Some ($3 c), $4 c) } | RETURN expr_opt { let at1 = ati 1 in fun c -> Return (label c ("return" @@ at1) @@ at1, $2 c) } diff --git a/ml-proto/spec/ast.ml b/ml-proto/spec/ast.ml index 086837736d..4154b859e3 100644 --- a/ml-proto/spec/ast.ml +++ b/ml-proto/spec/ast.ml @@ -19,7 +19,7 @@ and expr' = | Block of expr list | Loop of expr list | Br of var * expr option - | Br_if of expr * var * expr option + | Br_if of var * expr option * expr | Return of var * expr option | If of expr * expr | If_else of expr * expr * expr diff --git a/ml-proto/spec/check.ml b/ml-proto/spec/check.ml index fdd19b6830..42d9c4f7f1 100644 --- a/ml-proto/spec/check.ml +++ b/ml-proto/spec/check.ml @@ -133,6 +133,11 @@ let rec check_expr c et e = | Break (x, eo) -> check_expr_opt c (label c x) eo e.at + | Br_if (x, eo, e) -> + check_expr_opt c (label c x) eo e.at; + check_expr c (Some Int32Type) e; + check_type None et e.at + | If (e1, e2, e3) -> check_expr c (Some Int32Type) e1; check_expr c et e2; diff --git a/ml-proto/spec/desugar.ml b/ml-proto/spec/desugar.ml index bad75cc63b..ebd467528e 100644 --- a/ml-proto/spec/desugar.ml +++ b/ml-proto/spec/desugar.ml @@ -17,6 +17,9 @@ and shift' n = function | Break (x, eo) -> let x' = if x.it < n then x else (x.it + 1) @@ x.at in Break (x', Lib.Option.map (shift n) eo) + | Br_if (x, eo, e) -> + let x' = if x.it < n then x else (x.it + 1) @@ x.at in + Br_if (x', Lib.Option.map (shift n) eo, shift n e) | If (e1, e2, e3) -> If (shift n e1, shift n e2, shift n e3) | Switch (e, xs, x, es) -> Switch (shift n e, xs, x, List.map (shift n) es) | Call (x, es) -> Call (x, List.map (shift n) es) @@ -53,8 +56,7 @@ and expr' at = function | Ast.Block es -> Block (List.map expr es) | Ast.Loop es -> Block [Loop (seq es) @@ at] | Ast.Br (x, eo) -> Break (x, Lib.Option.map expr eo) - | Ast.Br_if (e, x, eo) -> - If (expr e, Break (x, Lib.Option.map expr eo) @@ at, opt eo) + | Ast.Br_if (x, eo, e) -> Br_if (x, Lib.Option.map expr eo, expr e) | Ast.Return (x, eo) -> Break (x, Lib.Option.map expr eo) | Ast.If (e1, e2) -> If (expr e1, expr e2, Nop @@ Source.after e2.at) | Ast.If_else (e1, e2, e3) -> If (expr e1, expr e2, expr e3) diff --git a/ml-proto/spec/eval.ml b/ml-proto/spec/eval.ml index 601fbf0b00..403b0859e7 100644 --- a/ml-proto/spec/eval.ml +++ b/ml-proto/spec/eval.ml @@ -154,6 +154,11 @@ let rec eval_expr (c : config) (e : expr) = | Break (x, eo) -> raise (label c x (eval_expr_opt c eo)) + | Br_if (x, eo, e) -> + let v = eval_expr_opt c eo in + let i = int32 (eval_expr c e) e.at in + if i <> 0l then raise (label c x v) else None + | If (e1, e2, e3) -> let i = int32 (eval_expr c e1) e1.at in eval_expr c (if i <> 0l then e2 else e3) diff --git a/ml-proto/spec/kernel.ml b/ml-proto/spec/kernel.ml index 36d70484d4..054ed7666e 100644 --- a/ml-proto/spec/kernel.ml +++ b/ml-proto/spec/kernel.ml @@ -82,6 +82,7 @@ and expr' = | Block of expr list (* execute in sequence *) | Loop of expr (* loop header *) | Break of var * expr option (* break to n-th surrounding label *) + | Br_if of var * expr option * expr (* conditional break *) | If of expr * expr * expr (* conditional *) | Switch of expr * var list * var * expr list (* table switch *) | Call of var * expr list (* call function *) diff --git a/ml-proto/test/fac.wast b/ml-proto/test/fac.wast index 4d19074d12..40459ca95e 100644 --- a/ml-proto/test/fac.wast +++ b/ml-proto/test/fac.wast @@ -62,11 +62,11 @@ (local i64) (set_local 1 (i64.const 1)) (block - (br_if (i64.lt_s (get_local 0) (i64.const 2)) 0) + (br_if 0 (i64.lt_s (get_local 0) (i64.const 2))) (loop (set_local 1 (i64.mul (get_local 1) (get_local 0))) (set_local 0 (i64.add (get_local 0) (i64.const -1))) - (br_if (i64.gt_s (get_local 0) (i64.const 1)) 0) + (br_if 0 (i64.gt_s (get_local 0) (i64.const 1))) ) ) (get_local 1) diff --git a/ml-proto/test/labels.wast b/ml-proto/test/labels.wast index 539f4ef3a3..8c51f054c6 100644 --- a/ml-proto/test/labels.wast +++ b/ml-proto/test/labels.wast @@ -90,23 +90,46 @@ ) ) - (func $br_if (result i32) + (func $br_if0 (result i32) (local $i i32) (set_local $i (i32.const 0)) (block $outer (block $inner - (br_if (i32.const 0) $inner) + (br_if $inner (i32.const 0)) (set_local $i (i32.or (get_local $i) (i32.const 0x1))) - (br_if (i32.const 1) $inner) + (br_if $inner (i32.const 1)) (set_local $i (i32.or (get_local $i) (i32.const 0x2))) ) - (br_if (i32.const 0) $outer (set_local $i (i32.or (get_local $i) (i32.const 0x4)))) + (br_if $outer (set_local $i (i32.or (get_local $i) (i32.const 0x4))) (i32.const 0)) (set_local $i (i32.or (get_local $i) (i32.const 0x8))) - (br_if (i32.const 1) $outer (set_local $i (i32.or (get_local $i) (i32.const 0x10)))) + (br_if $outer (set_local $i (i32.or (get_local $i) (i32.const 0x10))) (i32.const 1)) (set_local $i (i32.or (get_local $i) (i32.const 0x20))) ) ) + (func $br_if1 (result i32) + (block $l0 + (br_if $l0 (block $l1 (br $l1 (i32.const 1))) (i32.const 1)) + (i32.const 1))) + + (func $br_if2 (result i32) + (block $l0 + (if (i32.const 1) + (br $l0 + (block $l1 + (br $l1 (i32.const 1))))) + (i32.const 1))) + + (func $br_if3 (result i32) + (local $i1 i32) + (i32.add (block $l0 + (br_if $l0 + (set_local $i1 (i32.const 1)) + (set_local $i1 (i32.const 2))) + (i32.const 0)) + (i32.const 0)) + (get_local $i1)) + (func $misc1 (result i32) (block $l1 (i32.xor (br $l1 (i32.const 1)) (i32.const 2))) ) @@ -123,7 +146,10 @@ (export "loop5" $loop5) (export "switch" $switch) (export "return" $return) - (export "br_if" $br_if) + (export "br_if0" $br_if0) + (export "br_if1" $br_if1) + (export "br_if2" $br_if2) + (export "br_if3" $br_if3) (export "misc1" $misc1) (export "misc2" $misc2) ) @@ -143,11 +169,29 @@ (assert_return (invoke "return" (i32.const 0)) (i32.const 0)) (assert_return (invoke "return" (i32.const 1)) (i32.const 2)) (assert_return (invoke "return" (i32.const 2)) (i32.const 2)) -(assert_return (invoke "br_if") (i32.const 0x1d)) +(assert_return (invoke "br_if0") (i32.const 0x1d)) +(assert_return (invoke "br_if1") (i32.const 1)) +(assert_return (invoke "br_if2") (i32.const 1)) +(assert_return (invoke "br_if3") (i32.const 2)) (assert_return (invoke "misc1") (i32.const 1)) (assert_return (invoke "misc2") (i32.const 1)) (assert_invalid (module (func (loop $l (br $l (i32.const 0))))) "arity mismatch") -(assert_invalid (module (func (block $l (f32.neg (br_if (i32.const 1) $l)) (nop)))) "type mismatch") -(assert_invalid (module (func (result f32) (block $l (br_if (i32.const 1) $l (f32.const 0))))) "type mismatch") +(assert_invalid (module (func (block $l (f32.neg (br_if $l (i32.const 1))) (nop)))) "type mismatch") +(assert_invalid (module (func (result f32) (block $l (br_if $l (f32.const 0) (i32.const 1))))) "type mismatch") +(assert_invalid (module (func (result i32) (block $l (br_if $l (f32.const 0) (i32.const 1))))) "type mismatch") +(assert_invalid (module (func (block $l (f32.neg (br_if $l (f32.const 0) (i32.const 1)))))) "arity mismatch") +(assert_invalid (module (func (param i32) (result i32) (block $l (f32.neg (br_if $l (f32.const 0) (get_local 0)))))) "type mismatch") +(assert_invalid (module (func (param i32) (result f32) + (block $l (f32.neg (block $i (br_if $l (f32.const 3) (get_local 0))))))) + "type mismatch") +(assert_invalid (module (func (block $l0 (br_if $l0 (nop) (i32.const 1))))) + "arity mismatch") +(assert_invalid (module (func (result i32) + (block $l0 + (if_else (i32.const 1) + (br $l0 (block $l1 (br $l1 (i32.const 1)))) + (block (block $l1 (br $l1 (i32.const 1))) (nop)) + ) + (i32.const 1)))) "arity mismatch")