diff --git a/interpreter/binary/decode.ml b/interpreter/binary/decode.ml index fa7eb7f9..da646e0d 100644 --- a/interpreter/binary/decode.ml +++ b/interpreter/binary/decode.ml @@ -1097,16 +1097,16 @@ let module_ s = customs -let decode_custom m custom = +let decode_custom m bs custom = let open Source in let Custom.{name; content; place} = custom.it in match Custom.handler name, Custom.handler (Utf8.decode "custom") with | Some (module Handler), _ -> - let fmt = Handler.decode m custom in + let fmt = Handler.decode m bs custom in let module S = struct module Handler = Handler let it = fmt end in [(module S : Custom.Section)] | None, Some (module Handler') -> - let fmt = Handler'.decode m custom in + let fmt = Handler'.decode m bs custom in let module S = struct module Handler = Handler' let it = fmt end in [(module S : Custom.Section)] | None, None -> @@ -1121,6 +1121,6 @@ let decode_with_custom name bs = let open Source in let m', cs = m_cs.it in let m = m' @@ m_cs.at in - m, List.flatten (List.map (decode_custom m) cs) + m, List.flatten (List.map (decode_custom m bs) cs) let decode name bs = fst (decode_with_custom name bs) diff --git a/interpreter/binary/encode.ml b/interpreter/binary/encode.ml index 753db1d9..6b0b1922 100644 --- a/interpreter/binary/encode.ml +++ b/interpreter/binary/encode.ml @@ -920,15 +920,17 @@ struct end -let encode_custom m (module S : Custom.Section) = +let encode_custom m bs (module S : Custom.Section) = let open Source in - let c = S.Handler.encode m S.it in + let c = S.Handler.encode m bs S.it in Custom.{c.it with place = S.Handler.place S.it} @@ c.at +let encode m = + let module E = E (struct let stream = stream () end) in + E.module_ m []; to_string E.s + let encode_with_custom (m, secs) = + let bs = encode m in let module E = E (struct let stream = stream () end) in - let cs = List.map (encode_custom m) secs in + let cs = List.map (encode_custom m bs) secs in E.module_ m cs; to_string E.s - -let encode m = - encode_with_custom (m, []) diff --git a/interpreter/custom/custom.ml b/interpreter/custom/custom.ml index 800c42a7..17ed651f 100644 --- a/interpreter/custom/custom.ml +++ b/interpreter/custom/custom.ml @@ -51,10 +51,10 @@ sig type format = format' Source.phrase val name : Ast.name val place : format -> place - val decode : Ast.module_ -> custom -> format (* raise Code *) - val encode : Ast.module_ -> format -> custom - val parse : Ast.module_ -> Annot.annot list -> format list (* raise Syntax *) - val arrange : Ast.module_ -> format -> Sexpr.sexpr + val decode : Ast.module_ -> string -> custom -> format (* raise Code *) + val encode : Ast.module_ -> string -> format -> custom + val parse : Ast.module_ -> string -> Annot.annot list -> format list (* raise Syntax *) + val arrange : Ast.module_ -> Sexpr.sexpr -> format -> Sexpr.sexpr val check : Ast.module_ -> format -> unit (* raise Invalid *) end diff --git a/interpreter/custom/handler_branch_hint.ml b/interpreter/custom/handler_branch_hint.ml new file mode 100644 index 00000000..7746584f --- /dev/null +++ b/interpreter/custom/handler_branch_hint.ml @@ -0,0 +1,425 @@ +(* Handler for "metadata.code.branch_hint" section and @metadata.code.branch_hint annotations *) + +open Custom +open Annot +open Source +open Ast + +module IdxMap = Map.Make(Int32) + +type kind = Likely | Unlikely +type hint = hint' Source.phrase +and hint' = kind * int +type hints = hint list +type func_hints = hints IdxMap.t + +type format = format' Source.phrase +and format' = +{ + func_hints : func_hints; +} + +let empty = {func_hints = IdxMap.empty } + +let name = Utf8.decode "metadata.code.branch_hint" + +let place _fmt = Before Code + + +let is_contained r1 r2 = r1.left >= r2.left && r1.right <= r2.right +let is_left r1 r2 = (r1.right.line < r2.left.line) || + (r1.right.line == r2.left.line && + r1.right.column <= r2.left.column) +let starts_before r1 r2 = (r1.left.line < r2.left.line) || + (r1.left.line == r2.left.line && + r1.left.column < r2.left.column) + +let get_func m fidx = + let nimp = List.length m.it.imports in + let fn = (Int32.to_int fidx) - nimp in + List.nth m.it.funcs fn + +let flatten_instr_locs is = + let rec flatten is = match is with + | [] -> [] + | i::rest -> + let group = match i.it with + | Block (_, inner) -> [i] @ (flatten inner) + | Loop (_, inner) -> [i] @ (flatten inner) + | If (_, inner1, inner2) -> [i] @ (flatten inner1) @ (flatten inner2) + | _ -> [i] in + group @ (flatten rest) in + let flat = flatten is in + let indexed = List.mapi (fun idx i -> (idx, i)) flat in + let sorter (_,i1) (_,i2) = + if starts_before i1.at i2.at then -1 else 1 in + let sorted = List.sort sorter indexed in + sorted + +let get_inst_idx locs at = + let finder (i, l) = is_left at l.at in + let o = List.find_opt finder locs in + match o with + | Some (i, _) -> Some i + | None -> None + +let get_nth_inst locs idx = + match List.find_opt (fun (i_idx, i) -> i_idx == idx) locs with + | Some((_, i)) -> Some(i) + | None -> None + +(* Decoding *) + +(* TODO: make Decode module reusable instead of duplicating code *) + +type stream = {bytes : string; pos : int ref} + +exception EOS + +let stream bs = {bytes = bs; pos = ref 0} + +let len s = String.length s.bytes +let pos s = !(s.pos) + +let check n s = if pos s + n > len s then raise EOS +let skip n s = if n < 0 then raise EOS else check n s; s.pos := !(s.pos) + n + +let read s = Char.code (s.bytes.[!(s.pos)]) +let get s = check 1 s; let b = read s in skip 1 s; b + +let position file pos = Source.{file; line = -1; column = pos} +let region file left right = Source.{left = position file left; right = position file right} + +let decode_error pos msg = raise (Custom.Code (region "@metadata.code.branch_hint section" pos pos, msg)) +let require b pos msg = if not b then decode_error pos msg + +let decode_byte s = + get s + +let rec decode_uN n s = + require (n > 0) (pos s) "integer representation too long"; + let b = decode_byte s in + require (n >= 7 || b land 0x7f < 1 lsl n) (pos s - 1) "integer too large"; + let x = Int32.of_int (b land 0x7f) in + if b land 0x80 = 0 then x else + Int32.(logor x (shift_left (decode_uN (n - 7) s) 7)) + +let decode_u32 = decode_uN 32 + +let decode_vec f s = + let n = decode_u32 s in + let n = Int32.to_int n in + let rec it i s = + if i = 0 then + [] + else + [f s] @ it (i - 1) s + in + it n s + +let decode_hint locs foff s = + let off = decode_u32 s in + let one = decode_u32 s in + require (one = 1l) (pos s) "@metadata.code.branch_hint section: missing reserved byte"; + let k = decode_byte s in + let hint = match k with + | 0x00 -> Unlikely + | 0x01 -> Likely + | _ -> decode_error (pos s) "@metadata.code.branch_hint section: invalid hint value" in + let abs_off = Int32.to_int (Int32.add off foff) in + let at = region "" abs_off abs_off in + let idx = match get_inst_idx locs at with + | Some i -> i + | None -> decode_error (pos s) "@metadata.code.branch_hint section: invalid offset" in + (hint, idx) @@ at + +let decode_func_hints locs foff = + decode_vec (decode_hint locs foff) + + +let decode_func m s = + let fidx = decode_u32 s in + let f = get_func m fidx in + let foff = Int32.of_int f.at.left.column in + let locs = flatten_instr_locs f.it.body in + let hs = decode_func_hints locs foff s in + (fidx, List.rev hs) + +let decode_funcs m s = + let fs = decode_vec (decode_func m) s in + IdxMap.add_seq (List.to_seq fs) IdxMap.empty + +let decode m _ custom = + let s = stream custom.it.content in + try + { func_hints = decode_funcs m s } @@ custom.at + with EOS -> decode_error (pos s) "unexpected end of name section" + + +(* Encoding *) + +(* TODO: make Encode module reusable *) + +let encode_byte buf b = + Buffer.add_char buf (Char.chr b) + +let rec encode_u32 buf i = + let b = Int32.(to_int (logand i 0x7fl)) in + if 0l <= i && i < 128l then encode_byte buf b + else ( + encode_byte buf (b lor 0x80); + encode_u32 buf (Int32.shift_right_logical i 7) + ) + +let encode_size buf n = + encode_u32 buf (Int32.of_int n) + +let encode_vec buf f v = + encode_size buf (List.length v); + let rec it v = match v with + | [] -> () + | e::es -> f buf e; it es in + it v + +let encode_hint locs foff buf h = + let kind, idx = h.it in + let i = match get_nth_inst locs idx with + | Some(i) -> i + | None -> assert false in + let off = i.at.left.column - foff in + encode_size buf off; + encode_u32 buf 1l; + let b = match kind with + | Unlikely -> 0l + | Likely -> 1l in + encode_u32 buf b + +let encode_func_hints buf locs foff = + encode_vec buf (encode_hint locs foff) + +let encode_func m buf t = + let fidx, hs = t in + encode_u32 buf fidx; + let f = get_func m fidx in + let foff = f.at.left.column in + let locs = flatten_instr_locs f.it.body in + encode_func_hints buf locs foff hs + +let encode_funcs buf m fhs = + encode_vec buf (encode_func m) (List.of_seq (IdxMap.to_seq fhs)) + +let encode m bs sec = + let {func_hints} = sec.it in + let m2 = Decode.decode "" bs in + let buf = Buffer.create 200 in + encode_funcs buf m2 func_hints; + let content = Buffer.contents buf in + {name = Utf8.decode "metadata.code.branch_hint"; content; place = Before Code} @@ sec.at + + +(* Parsing *) + +open Ast + +let parse_error at msg = raise (Custom.Syntax (at, msg)) + +let merge_func_hints = IdxMap.merge (fun key x y -> + match x, y with + | Some a, None -> Some a + | None, Some b -> Some b + | Some a, Some [{at=_; it=(_, idx)} as b] -> + let (_, last_idx) = (List.hd (List.rev a)).it in + if last_idx >= idx then + parse_error b.at "@metadata.code.branch_hint annotation: duplicate annotation" + else + Some (a @ [b]) + | Some _, Some _ -> + assert false + | None, None -> None ) + +let merge s1 s2 = + { + func_hints = merge_func_hints s1.it.func_hints s2.it.func_hints + } @@ {left = s1.at.left; right = s2.at.right} + +let find_func_idx m annot = + let idx = Lib.List.index_where (fun f -> is_contained annot.at f.at ) m.it.funcs in + match idx with + | Some i -> Int32.of_int (i + List.length m.it.imports) + | None -> parse_error annot.at "@metadata.code.branch_hint annotation: not in a function" + +let print_list ls = + let print_one (i, l) = + Printf.printf "[%d] " i; + Print.instr Out_channel.stdout 80 l; + in + Printf.printf "-------------------\n"; + List.iter print_one ls + +let rec parse m _bs annots = + let annots' = List.rev annots in + let ms = List.map (parse_annot m) annots' in + match ms with + | [] -> [] + | m::ms' -> [List.fold_left merge (empty @@ m.at) ms] + +and parse_annot m annot = + let {name = n; items} = annot.it in + assert (n = name); + let payload a = match a.it with + | String s -> s + | _ -> parse_error a.at "@metadata.code.branch_hint annotation: unexpected token" in + let fold_payload bs a = bs ^ (payload a) in + let p = List.fold_left fold_payload "" items in + let at_last = (List.hd (List.rev items)).at in + let fidx = find_func_idx m annot in + let f = get_func m fidx in + let locs = flatten_instr_locs f.it.body in + let hint = match p with + | "\x00" -> Unlikely + | "\x01" -> Likely + | _ -> parse_error annot.at "@metadata.code.branch_hint annotation: invalid hint value" in + let at = Source.{left = annot.at.left; right = at_last.right} in + let hidx = match get_inst_idx locs at with + | Some i -> i + | None -> parse_error annot.at "@metadata.code.branch_hint annotation: invalid placement" in + let e = { func_hints = IdxMap.add fidx [(hint, hidx) @@ at] IdxMap.empty } in + e @@ at + +(* Arranging *) + +let hint_to_string = function + | Likely -> "\"\\01\"" + | Unlikely -> "\"\\00\"" + +let collect_one f hat = + let (h, hidx) = hat.it in + (Int32.to_int f, hidx, Sexpr.Node ("@metadata.code.branch_hint ", [Sexpr.Atom (hint_to_string h)])) + +let collect_func (f, hs) = + List.map (collect_one f) hs + +let collect_funcs (fhs) = + List.concat (List.map collect_func fhs) + +let rec get_instrs n1 n2 = + match n2 with + | [] -> (n1, []) + | (Sexpr.Atom s)::rest -> (n1 @ [Sexpr.Atom s], rest) + | (Sexpr.Node (h, els))::rest -> + if ( String.starts_with ~prefix:"type " h + || String.equal "local" h + || String.starts_with ~prefix:"result" h ) then + get_instrs (n1 @ [Sexpr.Node(h, els)]) rest + else + (n1, n2) + + +let get_annot annots fidx idx h = + match !annots with + | [] -> [] + | (a_fidx, a_hidx, a_node)::rest -> + if a_fidx = fidx && a_hidx = idx then + begin + annots := rest; + [a_node] + end + else + [] + + +let rec apply_instrs annots fidx curi is = + match is with + | [] -> [] + | i::rest -> + let idx = !curi in + curi := idx+1; + let newn = match i with + | Sexpr.Node (h, ns) -> + let annot = get_annot annots fidx idx h in + if ( String.starts_with ~prefix:"block" h + || String.starts_with ~prefix:"loop" h ) then + let pre, inner = get_instrs [] ns in + annot @ [Sexpr.Node(h, pre @ apply_instrs annots fidx curi inner)] + else if String.starts_with ~prefix:"if" h then + match ns with + | [Sexpr.Node(hif, nif); Sexpr.Node(helse, nelse)] -> + let newif = apply_instrs annots fidx curi nif in + let newelse = apply_instrs annots fidx curi nelse in + annot @ [Sexpr.Node(h, [Sexpr.Node(hif, newif); Sexpr.Node(helse, newelse)])] + | [Sexpr.Node("result",res); Sexpr.Node(hif, nif); Sexpr.Node(helse, nelse)] -> + let newif = apply_instrs annots fidx curi nif in + let newelse = apply_instrs annots fidx curi nelse in + annot @ [Sexpr.Node(h, [Sexpr.Node("result",res); Sexpr.Node(hif, newif); Sexpr.Node(helse, newelse)])] + | _ -> assert false + else + annot @ [Sexpr.Node(h, ns)] + | Sexpr.Atom s -> [Sexpr.Atom s] in + newn @ apply_instrs annots fidx curi rest + + +let apply_func nodes annots fidx = + let curi = ref 0 in + let pre, instrs = get_instrs [] nodes in + let new_instrs = apply_instrs annots fidx curi instrs in + pre @ new_instrs + +let apply_secs annots curf node = + match node with + | Sexpr.Atom a -> Sexpr.Atom a + | Sexpr.Node (head, rest) -> + if String.starts_with ~prefix:"func" head then + begin + let ret = apply_func rest annots !curf in + curf := !curf + 1; + Sexpr.Node (head, ret) + end + else + begin + Sexpr.Node (head, rest) + end + +let apply mnode annots curf = + match mnode with + | Sexpr.Atom a -> Sexpr.Atom a + | Sexpr.Node (h, secs) -> Sexpr.Node(h, List.map (apply_secs annots curf) secs) + +let arrange m mnode fmt = + let annots = ref (collect_funcs (List.of_seq (IdxMap.to_seq fmt.it.func_hints))) in + let curf = ref 0 in + let ret = apply mnode annots curf in + ret + + + +(* Checking *) + +let check_error at msg = raise (Custom.Invalid (at, msg)) + +let check_one locs prev_hidx h = + let kind, idx = h.it in + match get_nth_inst locs idx with + | None -> assert false + | Some i -> + (match i.it with + | If _ | BrIf _ -> + if !prev_hidx >= idx then + check_error h.at "@metadata.code.branch_hint annotation: invalid order" + else + begin + prev_hidx := idx; + () + end + | _ -> check_error h.at "@metadata.code.branch_hint annotation: invalid target") + +let check_fun m fidx hs = + let f = get_func m fidx in + let locs = flatten_instr_locs f.it.body in + let prev_hidx = ref 0 in + List.iter (check_one locs prev_hidx) hs + + +let check (m : module_) (fmt : format) = + IdxMap.iter (check_fun m) fmt.it.func_hints; + () + diff --git a/interpreter/custom/handler_custom.ml b/interpreter/custom/handler_custom.ml index b268d441..3a1d8054 100644 --- a/interpreter/custom/handler_custom.ml +++ b/interpreter/custom/handler_custom.ml @@ -21,7 +21,7 @@ let decode_content m custom = let module S = struct module Handler = Handler - let it = Handler.decode m custom + let it = Handler.decode m "" custom end in Some (module S : Custom.Section) | None -> @@ -31,18 +31,18 @@ let decode_content m custom = else None -let decode m custom = +let decode m _bs custom = ignore (decode_content m custom); custom -let encode _m custom = custom +let encode _m _bs custom = custom (* Parsing *) let parse_error at msg = raise (Custom.Syntax (at, msg)) -let rec parse m annots = List.map (parse_annot m) annots +let rec parse m _bs annots = List.map (parse_annot m) annots and parse_annot m annot = let {name = n; items} = annot.it in @@ -128,11 +128,14 @@ and parse_end = function open Sexpr -let rec arrange _m custom = +let rec arrange _m mnode custom = let {name; content; place} = custom.it in - Node ("@custom " ^ Arrange.name name, + let node = Node ("@custom " ^ Arrange.name name, arrange_place place :: Arrange.break_bytes content - ) + ) in + match mnode with + | Sexpr.Atom _ -> assert false + | Node (name, secs) -> Node (name, secs @ [node]) and arrange_place = function | Before sec -> Node ("before", [Atom (arrange_sec sec)]) diff --git a/interpreter/custom/handler_name.ml b/interpreter/custom/handler_name.ml index 9c42d7ff..7c28708b 100644 --- a/interpreter/custom/handler_name.ml +++ b/interpreter/custom/handler_name.ml @@ -130,7 +130,7 @@ let decode_subsec id f default s = require (pos s = pos' + n) (pos s) "name subsection size mismatch"; ss -let decode _m custom = +let decode _m _bs custom = let s = stream custom.it.content in try let module_ = decode_subsec 0x00 decode_module None s in @@ -215,7 +215,7 @@ let encode_locals buf name_map_map = encode_subsec_end buf subsec end -let encode _m sec = +let encode _m _bs sec = let {module_; funcs; locals} = sec.it in let buf = Buffer.create 200 in encode_module buf module_; @@ -264,14 +264,16 @@ let merge s1 s2 = let is_contained r1 r2 = r1.left >= r2.left && r1.right <= r2.right let is_left r1 r2 = r1.right <= r2.left -let locate_func x name at (f : func) = +let locate_func bs x name at (f : func) = if is_left at f.it.ftype.at then {empty with funcs = IdxMap.singleton x name} - else - (* TODO *) + else if f.it.body = [] || is_left at (List.hd f.it.body).at then + (* TODO re-parse the function params and locals from bs *) parse_error at "@name annotation: local names not yet supported" + else + parse_error at "@name annotation: misplaced annotation" -let locate_module name at (m : module_) = +let locate_module bs name at (m : module_) = if not (is_contained at m.at) then parse_error at "misplaced @name annotation"; let {types; globals; tables; memories; funcs; start; @@ -293,22 +295,22 @@ let locate_module name at (m : module_) = | at1::_ when is_left at at1 -> {empty with module_ = Some name} | _ -> match Lib.List.index_where (fun f -> is_contained at f.at) funcs with - | Some x -> locate_func (Int32.of_int x) name at (List.nth funcs x) + | Some x -> locate_func bs (Int32.of_int x) name at (List.nth funcs x) | None -> parse_error at "misplaced @name annotation" -let rec parse m annots = - let ms = List.map (parse_annot m) annots in +let rec parse m bs annots = + let ms = List.map (parse_annot m bs) annots in match ms with | [] -> [] | m::ms' -> [List.fold_left merge (empty @@ m.at) ms] -and parse_annot m annot = +and parse_annot m bs annot = let {name = n; items} = annot.it in assert (n = name); let name, items' = parse_name annot.at items in parse_end items'; - locate_module name annot.at m @@ annot.at + locate_module bs name annot.at m @@ annot.at and parse_name at = function | {it = String s; at} :: items -> @@ -326,9 +328,9 @@ and parse_end = function (* Printing *) -let arrange m fmt = +let arrange m bs fmt = (* Print as generic custom section *) - Handler_custom.arrange m (encode m fmt) + Handler_custom.arrange m bs (encode m "" fmt) (* Checking *) diff --git a/interpreter/main/main.ml b/interpreter/main/main.ml index 19533113..fe98bb64 100644 --- a/interpreter/main/main.ml +++ b/interpreter/main/main.ml @@ -4,6 +4,7 @@ let version = "1.1" let all_handlers = [ (module Handler_custom : Custom.Handler); (module Handler_name : Custom.Handler); + (module Handler_branch_hint : Custom.Handler); ] let configure custom_handlers = diff --git a/interpreter/text/annot.ml b/interpreter/text/annot.ml index 8f1900c4..779086e6 100644 --- a/interpreter/text/annot.ml +++ b/interpreter/text/annot.ml @@ -22,8 +22,14 @@ module NameMap = Map.Make(struct type t = Ast.name let compare = compare end) type map = annot list NameMap.t let current : map ref = ref NameMap.empty +let current_source : Buffer.t = Buffer.create 512 -let clear () = current := NameMap.empty +let reset () = + current := NameMap.empty; + Buffer.clear current_source + +let get_source () = + Buffer.contents current_source let record annot = let old = Lib.Option.get (NameMap.find_opt annot.it.name !current) [] in diff --git a/interpreter/text/arrange.ml b/interpreter/text/arrange.ml index 72239170..2db264a9 100644 --- a/interpreter/text/arrange.ml +++ b/interpreter/text/arrange.ml @@ -619,25 +619,11 @@ let global off i g = let {gtype; ginit} = g.it in Node ("global $" ^ nat (off + i), global_type gtype :: list instr ginit.it) - -(* Custom section *) - -let custom_section m place (module S : Custom.Section) = - if Custom.(compare_place (S.Handler.place S.it) place) <= +1 then - Some (S.Handler.arrange m S.it) - else - None +let custom m mnode (module S : Custom.Section) = + S.Handler.arrange m mnode S.it (* Module *) -let rec iterate f xs = - match xs with - | [] -> [], [] - | x::xs' -> - match f x with - | Some y -> let ys', xs'' = iterate f xs' in y::ys', xs'' - | None -> [], xs - let var_opt = function | None -> "" | Some x -> " " ^ x.it @@ -648,30 +634,19 @@ let module_with_var_opt x_opt (m, cs) = let mx = ref 0 in let gx = ref 0 in let imports = list (import fx tx mx gx) m.it.imports in - let open Custom in - Node ("module" ^ var_opt x_opt, - let secs, cs = iterate (custom_section m (Before Type)) cs in secs @ + let ret = Node ("module" ^ var_opt x_opt, listi typedef m.it.types @ - let secs, cs = iterate (custom_section m (Before Import)) cs in secs @ imports @ - let secs, cs = iterate (custom_section m (Before Table)) cs in secs @ listi (table !tx) m.it.tables @ - let secs, cs = iterate (custom_section m (Before Memory)) cs in secs @ listi (memory !mx) m.it.memories @ - let secs, cs = iterate (custom_section m (Before Global)) cs in secs @ listi (global !gx) m.it.globals @ - let secs, cs = iterate (custom_section m (Before Export)) cs in secs @ list export m.it.exports @ - let secs, cs = iterate (custom_section m (Before Start)) cs in secs @ opt start m.it.start @ - let secs, cs = iterate (custom_section m (Before Elem)) cs in secs @ listi elem m.it.elems @ - let secs, cs = iterate (custom_section m (Before Code)) cs in secs @ listi (func_with_index !fx) m.it.funcs @ - let secs, cs = iterate (custom_section m (Before Data)) cs in secs @ - listi data m.it.datas @ - let secs, cs = iterate (custom_section m (After Data)) cs in secs - ) + listi data m.it.datas + ) in + List.fold_left (custom m) ret cs let binary_module_with_var_opt x_opt bs = diff --git a/interpreter/text/parse.ml b/interpreter/text/parse.ml index 5dbbf443..666df799 100644 --- a/interpreter/text/parse.ml +++ b/interpreter/text/parse.ml @@ -5,10 +5,27 @@ type 'a start = exception Syntax = Script.Syntax + +let wrap_lexbuf lexbuf = + let open Lexing in + let inner_refill = lexbuf.refill_buff in + let refill_buff lexbuf = + let oldlen = lexbuf.lex_buffer_len - lexbuf.lex_start_pos in + inner_refill lexbuf; + let newlen = lexbuf.lex_buffer_len - lexbuf.lex_start_pos in + let start = lexbuf.lex_start_pos + oldlen in + let n = newlen - oldlen in + Buffer.add_subbytes Annot.current_source lexbuf.lex_buffer start n + in + let n = lexbuf.lex_buffer_len - lexbuf.lex_start_pos in + Buffer.add_subbytes Annot.current_source lexbuf.lex_buffer lexbuf.lex_start_pos n; + {lexbuf with refill_buff} + let parse' name lexbuf start = + Annot.reset (); + let lexbuf = wrap_lexbuf lexbuf in lexbuf.Lexing.lex_curr_p <- {lexbuf.Lexing.lex_curr_p with Lexing.pos_fname = name}; - Annot.clear (); try let result = start Lexer.token lexbuf in let annots = Annot.get_all () in diff --git a/interpreter/text/parser.mly b/interpreter/text/parser.mly index 416219f3..7e119b80 100644 --- a/interpreter/text/parser.mly +++ b/interpreter/text/parser.mly @@ -206,12 +206,13 @@ let inline_type_explicit (c : context) x ft at = (* Custom annotations *) let parse_annots (m : module_) : Custom.section list = + let bs = Annot.get_source () in let annots = Annot.get m.at in let secs = Annot.NameMap.fold (fun name anns secs -> match Custom.handler name with | Some (module Handler) -> - let secs' = Handler.parse m anns in + let secs' = Handler.parse m bs anns in List.map (fun fmt -> let module S = struct module Handler = Handler let it = fmt end in (module S : Custom.Section) diff --git a/test/custom/metadata.code.branch_hint/branch_hint.wast b/test/custom/metadata.code.branch_hint/branch_hint.wast new file mode 100644 index 00000000..e50a1860 --- /dev/null +++ b/test/custom/metadata.code.branch_hint/branch_hint.wast @@ -0,0 +1,99 @@ +(module + (type (;0;) (func (param i32))) + (memory (;0;) 1 1) + (func $dummy) + (func $test1 (type 0) + (local i32) + local.get 1 + local.get 0 + i32.eq + (@metadata.code.branch_hint "\00" ) if + return + end + return + ) + (func $test2 (type 0) + (local i32) + local.get 1 + local.get 0 + i32.eq + (@metadata.code.branch_hint "\01" ) if + return + end + return + ) + (func (export "nested") (param i32 i32) (result i32) + (@metadata.code.branch_hint "\00") + (if (result i32) (local.get 0) + (then + (if (local.get 1) (then (call $dummy) (block) (nop))) + (if (local.get 1) (then) (else (call $dummy) (block) (nop))) + (@metadata.code.branch_hint "\01") + (if (result i32) (local.get 1) + (then (call $dummy) (i32.const 9)) + (else (call $dummy) (i32.const 10)) + ) + ) + (else + (if (local.get 1) (then (call $dummy) (block) (nop))) + (@metadata.code.branch_hint "\00") + (if (local.get 1) (then) (else (call $dummy) (block) (nop))) + (if (result i32) (local.get 1) + (then (call $dummy) (i32.const 10)) + (else (call $dummy) (i32.const 11)) + ) + ) + ) + ) +) + +(assert_malformed + (module quote + "(func $test2 (type 0)" + " (local i32)" + " local.get 1" + " local.get 0" + " i32.eq" + " (@metadata.code.branch_hint \"\\01\" )" + " (@metadata.code.branch_hint \"\\01\" )" + " if" + " return" + " end" + " return" + ")" + ) + "@metadata.code.branch_hint annotation: duplicate annotation" +) +(assert_malformed + (module quote + "(module" + " (@metadata.code.branch_hint \"\\01\" )" + " (type (;0;) (func (param i32)))" + " (memory (;0;) 1 1)" + " (func $test (type 0)" + " (local i32)" + " local.get 1" + " local.get 0" + " i32.eq" + " return" + " )" + ")" + ) + "@metadata.code.branch_hint annotation: not in a function" +) + +(assert_invalid + (module + (type (;0;) (func (param i32))) + (memory (;0;) 1 1) + (func $test (type 0) + (local i32) + local.get 1 + local.get 0 + (@metadata.code.branch_hint "\01" ) + i32.eq + return + ) + ) + "@metadata.code.branch_hint annotation: invalid target" +)