diff --git a/qiling/debugger/qdb/frontend.py b/qiling/debugger/qdb/frontend.py index 03f584d8f..cf4a1d087 100644 --- a/qiling/debugger/qdb/frontend.py +++ b/qiling/debugger/qdb/frontend.py @@ -5,16 +5,14 @@ from __future__ import annotations from typing import Optional, Mapping, Iterable, Union - import copy, math, os -from contextlib import contextmanager - -from qiling.const import QL_ARCH import unicorn -from .utils import dump_regs, get_x86_eflags, get_arm_flags, disasm, _parse_int, handle_bnj -from .const import * +from qiling.const import QL_ARCH + +from .utils import disasm, get_x86_eflags, setup_branch_predictor +from .const import color, SIZE_LETTER, FORMAT_LETTER # read data from memory of qiling instance @@ -71,7 +69,7 @@ def extract_count(t): if elem in ql.reg.register_mapping.keys(): items.append(getattr(ql.reg, elem, None)) else: - items.append(_parse_int(elem)) + items.append(read_int(elem)) addr = sum(items) @@ -146,224 +144,364 @@ def _try_read(ql: Qiling, address: int, size: int) -> Optional[bytes]: return (result, err_msg) -# divider printer -@contextmanager -def context_printer(ql: Qiling, field_name: str, ruler: str = "─") -> None: - height, width = get_terminal_size() - bar = (width - len(field_name)) // 2 - 1 - print(ruler * bar, field_name, ruler * bar) - yield - if "DISASM" in field_name: - print(ruler * width) +""" + Context Manager for rendering UI -def context_reg(ql: Qiling, saved_states: Optional[Mapping[str, int]] = None, /, *args, **kwargs) -> None: +""" - # context render for registers - with context_printer(ql, "[ REGISTERS ]"): +COLORS = (color.DARKCYAN, color.BLUE, color.RED, color.YELLOW, color.GREEN, color.PURPLE, color.CYAN, color.WHITE) - _cur_regs = dump_regs(ql) +# decorator function for printing divider +def context_printer(field_name, ruler="─"): + def decorator(context_dumper): + def wrapper(*args, **kwargs): + height, width = get_terminal_size() + bar = (width - len(field_name)) // 2 - 1 + print(ruler * bar, field_name, ruler * bar) + context_dumper(*args, **kwargs) + if "DISASM" in field_name: + print(ruler * width) + return wrapper + return decorator - _colors = (color.DARKCYAN, color.BLUE, color.RED, color.YELLOW, color.GREEN, color.PURPLE, color.CYAN, color.WHITE) - if ql.archtype == QL_ARCH.MIPS: +def setup_ctx_manager(ql: Qiling) -> CtxManager: + return { + QL_ARCH.X86: CtxManager_X86, + QL_ARCH.ARM: CtxManager_ARM, + QL_ARCH.ARM_THUMB: CtxManager_ARM, + QL_ARCH.CORTEX_M: CtxManager_ARM, + QL_ARCH.MIPS: CtxManager_MIPS, + }.get(ql.archtype)(ql) - _cur_regs.update({"fp": _cur_regs.pop("s8")}) - if saved_states is not None: - _saved_states = copy.deepcopy(saved_states) - _saved_states.update({"fp": _saved_states.pop("s8")}) - _diff = [k for k in _cur_regs if _cur_regs[k] != _saved_states[k]] +class CtxManager(object): + def __init__(self, ql): + self.ql = ql + self.predictor = setup_branch_predictor(ql) - else: - _diff = None + def print_asm(self, insn: CsInsn, to_jump: Optional[bool] = None, address: int = None) -> None: - lines = "" - for idx, r in enumerate(_cur_regs, 1): - line = "{}{}: 0x{{:08x}} {}\t".format(_colors[(idx-1) // 4], r, color.END) + opcode = "".join(f"{b:02x}" for b in insn.bytes) + if self.ql.archtype in (QL_ARCH.X86, QL_ARCH.X8664): + trace_line = f"0x{insn.address:08x} │ {opcode:20s} {insn.mnemonic:10} {insn.op_str:35s}" + else: + trace_line = f"0x{insn.address:08x} │ {opcode:10s} {insn.mnemonic:10} {insn.op_str:35s}" - if _diff and r in _diff: - line = f"{color.UNDERLINE}{color.BOLD}{line}" + cursor = " " + if self.ql.reg.arch_pc == insn.address: + cursor = "►" - if idx % 4 == 0 and idx != 32: - line += "\n" + jump_sign = " " + if to_jump: + jump_sign = f"{color.RED}✓{color.END}" - lines += line + print(f"{jump_sign} {cursor} {color.DARKGRAY}{trace_line}{color.END}") - print(lines.format(*_cur_regs.values())) + def dump_regs(self): + return {reg_name: getattr(self.ql.reg, reg_name) for reg_name in self.regs} - elif ql.archtype == QL_ARCH.X86: + def context_reg(self, saved_states): + return NotImplementedError - if saved_states is not None: - _saved_states = copy.deepcopy(saved_states) - _diff = [k for k in _cur_regs if _cur_regs[k] != _saved_states[k]] + @context_printer("[ STACK ]") + def context_stack(self): - else: - _diff = None + for idx in range(10): + addr = self.ql.reg.arch_sp + idx * self.ql.pointersize + if (val := _try_read(self.ql, addr, self.ql.pointersize)[0]): + print(f"$sp+0x{idx*self.ql.pointersize:02x}│ [0x{addr:08x}] —▸ 0x{self.ql.unpack(val):08x}", end="") + + # try to dereference wether it's a pointer + if (buf := _try_read(self.ql, addr, self.ql.pointersize))[0] is not None: + + if (addr := self.ql.unpack(buf[0])): - lines = "" - for idx, r in enumerate(_cur_regs, 1): - if len(r) == 2: - line = "{}{}: 0x{{:08x}} {}\t\t".format(_colors[(idx-1) // 4], r, color.END) - else: - line = "{}{}: 0x{{:08x}} {}\t".format(_colors[(idx-1) // 4], r, color.END) + # try to dereference again + if (buf := _try_read(self.ql, addr, self.ql.pointersize))[0] is not None: + try: + s = self.ql.mem.string(addr) + except: + s = None - if _diff and r in _diff: - line = f"{color.UNDERLINE}{color.BOLD}{line}" + if s and s.isprintable(): + print(f" ◂— {self.ql.mem.string(addr)}", end="") + else: + print(f" ◂— 0x{self.ql.unpack(buf[0]):08x}", end="") + print() - if idx % 4 == 0 and idx != 32: - line += "\n" + @context_printer("[ DISASM ]") + def context_asm(self): + # assembly before current location + past_list = [] + cur_addr = self.ql.reg.arch_pc - lines += line + line = disasm(self.ql, cur_addr-0x10) - print(lines.format(*_cur_regs.values())) - print(color.GREEN, "EFLAGS: [CF: {flags[CF]}, PF: {flags[PF]}, AF: {flags[AF]}, ZF: {flags[ZF]}, SF: {flags[SF]}, OF: {flags[OF]}]".format(flags=get_x86_eflags(ql.reg.ef)), color.END, sep="") + while line: + if line.address == cur_addr: + break - elif ql.archtype in (QL_ARCH.ARM, QL_ARCH.ARM_THUMB, QL_ARCH.CORTEX_M): + addr = line.address + line.size + line = disasm(self.ql, addr) - _cur_regs.update({"sl": _cur_regs.pop("r10")}) - _cur_regs.update({"ip": _cur_regs.pop("r12")}) - _cur_regs.update({"fp": _cur_regs.pop("r11")}) + if not line: + break - regs_in_row = 4 - if ql.archtype == QL_ARCH.CORTEX_M: - regs_in_row = 3 + past_list.append(line) - # for re-order - _cur_regs.update({"xpsr": _cur_regs.pop("xpsr")}) - _cur_regs.update({"control": _cur_regs.pop("control")}) - _cur_regs.update({"primask": _cur_regs.pop("primask")}) - _cur_regs.update({"faultmask": _cur_regs.pop("faultmask")}) - _cur_regs.update({"basepri": _cur_regs.pop("basepri")}) + # print four insns before current location + for line in past_list[:-1]: + self.print_asm(line) - _diff = None - if saved_states is not None: - _saved_states = copy.deepcopy(saved_states) - _saved_states.update({"sl": _saved_states.pop("r10")}) - _saved_states.update({"ip": _saved_states.pop("r12")}) - _saved_states.update({"fp": _saved_states.pop("r11")}) - _diff = [k for k in _cur_regs if _cur_regs[k] != _saved_states[k]] + # assembly for current location - lines = "" - for idx, r in enumerate(_cur_regs, 1): + cur_insn = disasm(self.ql, cur_addr) + prophecy = self.predictor.predict() + self.print_asm(cur_insn, to_jump=prophecy.going) - line = "{}{:}: 0x{{:08x}} {} ".format(_colors[(idx-1) // regs_in_row], r, color.END) + # assembly after current location - if _diff and r in _diff: - line = "{}{}".format(color.UNDERLINE, color.BOLD) + line + forward_insn_size = cur_insn.size + for _ in range(5): + forward_insn = disasm(self.ql, cur_addr+forward_insn_size) + if forward_insn: + self.print_asm(forward_insn) + forward_insn_size += forward_insn.size - if idx % regs_in_row == 0: - line += "\n" - lines += line +class CtxManager_ARM(CtxManager): + def __init__(self, ql): + super().__init__(ql) - print(lines.format(*_cur_regs.values())) - print(color.GREEN, "[{cpsr[mode]} mode], Thumb: {cpsr[thumb]}, FIQ: {cpsr[fiq]}, IRQ: {cpsr[irq]}, NEG: {cpsr[neg]}, ZERO: {cpsr[zero]}, Carry: {cpsr[carry]}, Overflow: {cpsr[overflow]}".format(cpsr=get_arm_flags(ql.reg.cpsr)), color.END, sep="") + self.regs = ( + "r0", "r1", "r2", "r3", + "r4", "r5", "r6", "r7", + "r8", "r9", "r10", "r11", + "r12", "sp", "lr", "pc", + ) - if ql.archtype != QL_ARCH.CORTEX_M: - # context render for Stack, skip this for CORTEX_M - with context_printer(ql, "[ STACK ]", ruler="─"): + @staticmethod + def get_flags(bits: int) -> Mapping[str, int]: - for idx in range(10): - addr = ql.reg.arch_sp + idx * ql.pointersize - val = ql.mem.read(addr, ql.pointersize) - print(f"$sp+0x{idx*ql.pointersize:02x}│ [0x{addr:08x}] —▸ 0x{ql.unpack(val):08x}", end="") + def _get_mode(bits): + return { + 0b10000: "User", + 0b10001: "FIQ", + 0b10010: "IRQ", + 0b10011: "Supervisor", + 0b10110: "Monitor", + 0b10111: "Abort", + 0b11010: "Hypervisor", + 0b11011: "Undefined", + 0b11111: "System", + }.get(bits & 0x00001f) - # try to dereference wether it's a pointer - if (buf := _try_read(ql, addr, ql.pointersize))[0] is not None: + return { + "mode": _get_mode(bits), + "thumb": bits & 0x00000020 != 0, + "fiq": bits & 0x00000040 != 0, + "irq": bits & 0x00000080 != 0, + "neg": bits & 0x80000000 != 0, + "zero": bits & 0x40000000 != 0, + "carry": bits & 0x20000000 != 0, + "overflow": bits & 0x10000000 != 0, + } - if (addr := ql.unpack(buf[0])): + @context_printer("[ REGISTERS ]") + def context_reg(self, saved_reg_dump): + cur_regs = self.dump_regs() - # try to dereference again - if (buf := _try_read(ql, addr, ql.pointersize))[0] is not None: - try: - s = ql.mem.string(addr) - except: - s = None + cur_regs.update({"sl": cur_regs.pop("r10")}) + cur_regs.update({"ip": cur_regs.pop("r12")}) + cur_regs.update({"fp": cur_regs.pop("r11")}) - if s and s.isprintable(): - print(f" ◂— {ql.mem.string(addr)}", end="") - else: - print(f" ◂— 0x{ql.unpack(buf[0]):08x}", end="") - print() + regs_in_row = 4 + diff = None + if saved_reg_dump is not None: + reg_dump = copy.deepcopy(saved_reg_dump) + reg_dump.update({"sl": reg_dump.pop("r10")}) + reg_dump.update({"ip": reg_dump.pop("r12")}) + reg_dump.update({"fp": reg_dump.pop("r11")}) + diff = [k for k in cur_regs if cur_regs[k] != reg_dump[k]] -def print_asm(ql: Qiling, insn: CsInsn, to_jump: Optional[bool] = None, address: int = None) -> None: + lines = "" + for idx, r in enumerate(cur_regs, 1): - opcode = "".join(f"{b:02x}" for b in insn.bytes) - if ql.archtype in (QL_ARCH.X86, QL_ARCH.X8664): - trace_line = f"0x{insn.address:08x} │ {opcode:20s} {insn.mnemonic:10} {insn.op_str:35s}" - else: - trace_line = f"0x{insn.address:08x} │ {opcode:10s} {insn.mnemonic:10} {insn.op_str:35s}" + line = "{}{:}: 0x{{:08x}} {} ".format(COLORS[(idx-1) // regs_in_row], r, color.END) - cursor = " " - if ql.reg.arch_pc == insn.address: - cursor = "►" + if diff and r in diff: + line = f"{color.UNDERLINE}{color.BOLD}{line}" - jump_sign = " " - if to_jump and address != ql.reg.arch_pc+4: - jump_sign = f"{color.RED}✓{color.END}" + if idx % regs_in_row == 0: + line += "\n" - print(f"{jump_sign} {cursor} {color.DARKGRAY}{trace_line}{color.END}") + lines += line + print(lines.format(*cur_regs.values())) + print(color.GREEN, "[{cpsr[mode]} mode], Thumb: {cpsr[thumb]}, FIQ: {cpsr[fiq]}, IRQ: {cpsr[irq]}, NEG: {cpsr[neg]}, ZERO: {cpsr[zero]}, Carry: {cpsr[carry]}, Overflow: {cpsr[overflow]}".format(cpsr=self.get_flags(self.ql.reg.cpsr)), color.END, sep="") -def context_asm(ql: Qiling, address: int) -> None: - with context_printer(ql, field_name="[ DISASM ]"): +class CtxManager_MIPS(CtxManager): + def __init__(self, ql): + super().__init__(ql) - if ql.archtype in (QL_ARCH.X86, QL_ARCH.X8664): - past_list = [] + self.regs = ( + "gp", "at", "v0", "v1", + "a0", "a1", "a2", "a3", + "t0", "t1", "t2", "t3", + "t4", "t5", "t6", "t7", + "t8", "t9", "sp", "s8", + "s0", "s1", "s2", "s3", + "s4", "s5", "s6", "s7", + "ra", "k0", "k1", "pc", + ) - # assembly before current location + @context_printer("[ REGISTERS ]") + def context_reg(self, saved_reg_dump): - line = disasm(ql, address) - acc_size = line.size + cur_regs = self.dump_regs() - while line and len(past_list) != 10: - past_list.append(line) - next_start = address + acc_size - line = disasm(ql, next_start) - acc_size += line.size + cur_regs.update({"fp": cur_regs.pop("s8")}) - # print four insns before current location - for line in past_list[:-1]: - print_asm(ql, line) + diff = None + if saved_reg_dump is not None: + reg_dump = copy.deepcopy(saved_reg_dump) + reg_dump.update({"fp": saved_reg_dump.pop("s8")}) + diff = [k for k in cur_regs if cur_regs[k] != reg_dump[k]] - else: + lines = "" + for idx, r in enumerate(cur_regs, 1): + line = "{}{}: 0x{{:08x}} {}\t".format(COLORS[(idx-1) // 4], r, color.END) + + if diff and r in diff: + line = f"{color.UNDERLINE}{color.BOLD}{line}" + + if idx % 4 == 0 and idx != 32: + line += "\n" + + lines += line + + print(lines.format(*cur_regs.values())) + + +class CtxManager_X86(CtxManager): + def __init__(self, ql): + super().__init__(ql) + + self.regs = ( + "eax", "ebx", "ecx", "edx", + "esp", "ebp", "esi", "edi", + "eip", "ss", "cs", "ds", "es", + "fs", "gs", "ef", + ) + @context_printer("[ REGISTERS ]") + def context_reg(self, saved_reg_dump): + cur_regs = self.dump_regs() + + diff = None + if saved_reg_dump is not None: + reg_dump = copy.deepcopy(saved_reg_dump) + diff = [k for k in cur_regs if cur_regs[k] != saved_reg_dump[k]] + + lines = "" + for idx, r in enumerate(cur_regs, 1): + if len(r) == 2: + line = "{}{}: 0x{{:08x}} {}\t\t".format(COLORS[(idx-1) // 4], r, color.END) + else: + line = "{}{}: 0x{{:08x}} {}\t".format(COLORS[(idx-1) // 4], r, color.END) + + if diff and r in diff: + line = f"{color.UNDERLINE}{color.BOLD}{line}" + + if idx % 4 == 0 and idx != 32: + line += "\n" + + lines += line + + print(lines.format(*cur_regs.values())) + print(color.GREEN, "EFLAGS: [CF: {flags[CF]}, PF: {flags[PF]}, AF: {flags[AF]}, ZF: {flags[ZF]}, SF: {flags[SF]}, OF: {flags[OF]}]".format(flags=get_x86_eflags(self.ql.reg.ef)), color.END, sep="") + + @context_printer("[ DISASM ]") + def context_asm(self): + past_list = [] + cur_addr = self.ql.reg.arch_pc + + cur_insn = disasm(self.ql, cur_addr) + prophecy = self.predictor.predict() + self.print_asm(cur_insn, to_jump=prophecy.going) + + # assembly before current location + + line = disasm(self.ql, cur_addr+cur_insn.size) + acc_size = line.size + cur_insn.size + + while line and len(past_list) != 8: + past_list.append(line) + next_start = cur_addr + acc_size + line = disasm(self.ql, next_start) + acc_size += line.size + + # print four insns before current location + for line in past_list[:-1]: + self.print_asm(line) + + +class CtxManager_CORTEX_M(CtxManager): + def __init__(self, ql): + super().__init__(ql) + + self.regs = ( + "r0", "r1", "r2", "r3", + "r4", "r5", "r6", "r7", + "r8", "r9", "r10", "r11", + "r12", "sp", "lr", "pc", + "xpsr", "control", "primask", "basepri", "faultmask" + ) + + @context_printer("[ REGISTERS ]") + def context_reg(self, saved_reg_dump): + + cur_regs.update({"sl": cur_regs.pop("r10")}) + cur_regs.update({"ip": cur_regs.pop("r12")}) + cur_regs.update({"fp": cur_regs.pop("r11")}) - # assembly before current location + regs_in_row = 3 - past_list = [] + # for re-order + cur_regs.update({"xpsr": cur_regs.pop("xpsr")}) + cur_regs.update({"control": cur_regs.pop("control")}) + cur_regs.update({"primask": cur_regs.pop("primask")}) + cur_regs.update({"faultmask": cur_regs.pop("faultmask")}) + cur_regs.update({"basepri": cur_regs.pop("basepri")}) - line = disasm(ql, address-0x10) + diff = None + if saved_reg_dump is not None: + reg_dump = copy.deepcopy(saved_reg_dump) + reg_dump.update({"sl": reg_dump.pop("r10")}) + reg_dump.update({"ip": reg_dump.pop("r12")}) + reg_dump.update({"fp": reg_dump.pop("r11")}) + diff = [k for k in cur_regs if cur_regs[k] != reg_dump[k]] - while line: - if line.address == address: - break + lines = "" + for idx, r in enumerate(_cur_regs, 1): - addr = line.address + line.size - line = disasm(ql, addr) + line = "{}{:}: 0x{{:08x}} {} ".format(_colors[(idx-1) // regs_in_row], r, color.END) - if not line: - break + if _diff and r in _diff: + line = "{}{}".format(color.UNDERLINE, color.BOLD) + line - past_list.append(line) + if idx % regs_in_row == 0: + line += "\n" - # print four insns before current location - for line in past_list[:-1][:4]: - print_asm(ql, line) + lines += line - # assembly for current location + print(lines.format(cur_regs.values())) + print(color.GREEN, "[{cpsr[mode]} mode], Thumb: {cpsr[thumb]}, FIQ: {cpsr[fiq]}, IRQ: {cpsr[irq]}, NEG: {cpsr[neg]}, ZERO: {cpsr[zero]}, Carry: {cpsr[carry]}, Overflow: {cpsr[overflow]}".format(cpsr=get_arm_flags(self.ql.reg.cpsr)), color.END, sep="") - cur_ins = disasm(ql, address) - to_jump, next_stop = handle_bnj(ql, address) - print_asm(ql, cur_ins, to_jump=to_jump) - # assembly after current location +if __name__ == "__main__": + pass - forward_insn_size = cur_ins.size - for _ in range(5): - forward_insn = disasm(ql, address+forward_insn_size) - if forward_insn: - print_asm(ql, forward_insn) - forward_insn_size += forward_insn.size diff --git a/qiling/debugger/qdb/qdb.py b/qiling/debugger/qdb/qdb.py index 3a0ef8369..d6366254a 100644 --- a/qiling/debugger/qdb/qdb.py +++ b/qiling/debugger/qdb/qdb.py @@ -12,10 +12,10 @@ from qiling.const import QL_ARCH, QL_VERBOSE from qiling.debugger import QlDebugger -from .frontend import context_reg, context_asm, examine_mem -from .utils import _parse_int, handle_bnj, is_thumb, CODE_END, parse_int -from .utils import Breakpoint, TempBreakpoint -from .const import * +from .frontend import examine_mem, setup_ctx_manager +from .utils import is_thumb, parse_int, setup_branch_predictor, disasm +from .utils import Breakpoint, TempBreakpoint, read_inst +from .const import color class QlQdb(cmd.Cmd, QlDebugger): @@ -31,6 +31,9 @@ def __init__(self: QlQdb, ql: Qiling, init_hook: str = "", rr: bool = False) -> if self.rr: self._states_list = [] + self.ctx = setup_ctx_manager(ql) + self.predictor = setup_branch_predictor(ql) + super().__init__() self.dbg_hook(init_hook) @@ -40,6 +43,26 @@ def dbg_hook(self: QlQdb, init_hook: str): # self.ql.loader.entry_point # ld.so # self.ql.loader.elf_entry # .text of binary + def bp_handler(ql, address, size, bp_list): + + if (bp := self.bp_list.get(address, None)): + + if isinstance(bp, TempBreakpoint): + # remove TempBreakpoint once hitted + self.del_breakpoint(bp) + + else: + if bp.hitted: + return + + print(f"{color.CYAN}[+] hit breakpoint at 0x{self.cur_addr:08x}{color.END}") + bp.hitted = True + + ql.stop() + self.do_context() + + self.ql.hook_code(bp_handler, self.bp_list) + if init_hook and self.ql.loader.entry_point != init_hook: self.do_breakpoint(init_hook) @@ -70,27 +93,6 @@ def cur_addr(self: QlQdb, address: int) -> None: self.ql.reg.arch_pc = address - def _bp_handler(self: QlQdb, *args) -> None: - """ - internal function for handling once breakpoint hitted - """ - - if (bp := self.bp_list.get(self.cur_addr, None)): - - if isinstance(bp, TempBreakpoint): - # remove TempBreakpoint once hitted - self.del_breakpoint(bp) - - else: - if bp.hitted: - return - - print(f"{color.CYAN}[+] hit breakpoint at 0x{self.cur_addr:08x}{color.END}") - bp.hitted = True - - self.ql.stop() - self.do_context() - def _save(self: QlQdb, *args) -> None: """ internal function for saving state of qiling instance @@ -191,8 +193,9 @@ def do_context(self: QlQdb, *args) -> None: show context information for current location """ - context_reg(self.ql, self._saved_reg_dump) - context_asm(self.ql, self.cur_addr) + self.ctx.context_reg(self._saved_reg_dump) + self.ctx.context_stack() + self.ctx.context_asm() def do_backward(self: QlQdb, *args) -> None: """ @@ -207,40 +210,60 @@ def do_backward(self: QlQdb, *args) -> None: self._restore() self.do_context() - def do_step(self: QlQdb, *args) -> Optional[bool]: + def update_reg_dump(self: QlQdb) -> None: + """ + internal function for updating registers dump + """ + self._saved_reg_dump = dict(filter(lambda d: isinstance(d[0], str), self.ql.reg.save().items())) + + def do_step_in(self: QlQdb, *args) -> Optional[bool]: """ - execute one instruction at a time + execute one instruction at a time, will enter subroutine """ if self.ql is None: print(f"{color.RED}[!] The program is not being run.{color.END}") else: - # save reg dump for data highlighting changes - self._saved_reg_dump = dict(filter(lambda d: isinstance(d[0], str), self.ql.reg.save().items())) + self.update_reg_dump() if self.rr: self._save() - _, next_stop = handle_bnj(self.ql, self.cur_addr) + prophecy = self.predictor.predict() - if next_stop is CODE_END: + if prophecy.where is True: return True if self.ql.archtype == QL_ARCH.CORTEX_M: self.ql.arch.step() - self.ql.count -= 1 - else: - if self.ql.archtype in (QL_ARCH.X86, QL_ARCH.X8664): - count = 1 - else: - count = 1 if next_stop == self.cur_addr + 4 else 2 - - self._run(count=count) + self._run(count=1) self.do_context() + def do_step_over(self: QlQdb, *args) -> Option[bool]: + """ + execute one instruction at a time, but WON't enter subroutine + """ + + if self.ql is None: + print(f"{color.RED}[!] The program is not being run.{color.END}") + + else: + + prophecy = self.predictor.predict() + self.update_reg_dump() + + if prophecy.going: + cur_insn = disasm(self.ql, self.cur_addr) + self.set_breakpoint(self.cur_addr + cur_insn.size, is_temp=True) + + else: + self.set_breakpoint(prophecy.where, is_temp=True) + + self._run() + def set_breakpoint(self: QlQdb, address: int, is_temp: bool = False) -> None: """ internal function for placing breakpoint @@ -248,19 +271,14 @@ def set_breakpoint(self: QlQdb, address: int, is_temp: bool = False) -> None: bp = TempBreakpoint(address) if is_temp else Breakpoint(address) - if self.ql.archtype != QL_ARCH.CORTEX_M: - # skip hook_address for cortex_m - bp.hook = self.ql.hook_address(self._bp_handler, address) - self.bp_list.update({address: bp}) def del_breakpoint(self: QlQdb, bp: Union[Breakpoint, TempBreakpoint]) -> None: """ - internal function for removing breakpoints + internal function for removing breakpoint """ - if self.bp_list.pop(bp.addr, None): - bp.hook.remove() + self.bp_list.pop(bp.addr, None) def do_start(self: QlQdb, *args) -> None: """ @@ -356,8 +374,10 @@ def do_EOF(self: QlQdb, *args) -> None: if input(f"{color.RED}[!] Are you sure about saying good bye ~ ? [Y/n]{color.END} ").strip() == "Y": self.do_quit() + do_r = do_run - do_s = do_step + do_s = do_step_in + do_n = do_step_over do_q = do_quit do_x = do_examine do_p = do_backward diff --git a/qiling/debugger/qdb/utils.py b/qiling/debugger/qdb/utils.py index bd15ada90..9178a2621 100644 --- a/qiling/debugger/qdb/utils.py +++ b/qiling/debugger/qdb/utils.py @@ -5,123 +5,38 @@ from __future__ import annotations from typing import Callable, Optional, Mapping -from functools import partial - from qiling.const import * -CODE_END = True - - -def dump_regs(ql: Qiling) -> Mapping[str, int]: - - if ql.archtype == QL_ARCH.MIPS: - - _reg_order = ( - "gp", "at", "v0", "v1", - "a0", "a1", "a2", "a3", - "t0", "t1", "t2", "t3", - "t4", "t5", "t6", "t7", - "t8", "t9", "sp", "s8", - "s0", "s1", "s2", "s3", - "s4", "s5", "s6", "s7", - "ra", "k0", "k1", "pc", - ) - - elif ql.archtype in (QL_ARCH.ARM, QL_ARCH.ARM_THUMB): - - _reg_order = ( - "r0", "r1", "r2", "r3", - "r4", "r5", "r6", "r7", - "r8", "r9", "r10", "r11", - "r12", "sp", "lr", "pc", - ) - - elif ql.archtype == QL_ARCH.X86: - - _reg_order = ( - "eax", "ebx", "ecx", "edx", - "esp", "ebp", "esi", "edi", - "eip", "ss", "cs", "ds", "es", - "fs", "gs", "ef", - ) - - elif ql.archtype == QL_ARCH.CORTEX_M: - - _reg_order = ( - "r0", "r1", "r2", "r3", - "r4", "r5", "r6", "r7", - "r8", "r9", "r10", "r11", - "r12", "sp", "lr", "pc", - "xpsr", "control", "primask", "basepri", "faultmask" - ) - - return {reg_name: getattr(ql.reg, reg_name) for reg_name in _reg_order} - - -def get_arm_flags(bits: int) -> Mapping[str, int]: - - def _get_mode(bits): - return { - 0b10000: "User", - 0b10001: "FIQ", - 0b10010: "IRQ", - 0b10011: "Supervisor", - 0b10110: "Monitor", - 0b10111: "Abort", - 0b11010: "Hypervisor", - 0b11011: "Undefined", - 0b11111: "System", - }.get(bits & 0x00001f) - - return { - "mode": _get_mode(bits), - "thumb": bits & 0x00000020 != 0, - "fiq": bits & 0x00000040 != 0, - "irq": bits & 0x00000080 != 0, - "neg": bits & 0x80000000 != 0, - "zero": bits & 0x40000000 != 0, - "carry": bits & 0x20000000 != 0, - "overflow": bits & 0x10000000 != 0, - } - +from collections import namedtuple +import ast, re # parse unsigned integer from string -def _parse_int(s: str) -> int: +def read_int(s: str) -> int: return int(s, 0) -# function dectorator for parse argument as integer +# function dectorator for parsing argument as integer def parse_int(func: Callable) -> Callable: def wrap(qdb, s: str = "") -> int: assert type(s) is str try: - ret = _parse_int(s) + ret = read_int(s) except: ret = None return func(qdb, ret) return wrap + # check wether negative value or not def is_negative(i: int) -> int: return i & (1 << 31) -# convert valu to signed +# signed value convertion def signed_val(val: int) -> int: return (val-1 << 32) if is_negative(val) else val -# handle braches and jumps so we can set berakpoint properly -def handle_bnj(ql: Qiling, cur_addr: str) -> Callable[[Qiling, str], int]: - return { - QL_ARCH.MIPS : handle_bnj_mips, - QL_ARCH.ARM : handle_bnj_arm, - QL_ARCH.ARM_THUMB: handle_bnj_arm, - QL_ARCH.CORTEX_M : handle_bnj_arm, - QL_ARCH.X86 : handle_bnj_x86, - }.get(ql.archtype)(ql, cur_addr) - - def get_cpsr(bits: int) -> (bool, bool, bool, bool): return ( bits & 0x10000000 != 0, # V, overflow flag @@ -147,11 +62,14 @@ def is_thumb(bits: int) -> bool: def disasm(ql: Qiling, address: int, detail: bool = False) -> Optional[int]: + """ + helper function for disassembling + """ md = ql.disassembler md.detail = detail try: - ret = next(md.disasm(_read_inst(ql, address), address)) + ret = next(md.disasm(read_inst(ql, address), address)) except StopIteration: ret = None @@ -159,8 +77,7 @@ def disasm(ql: Qiling, address: int, detail: bool = False) -> Optional[int]: return ret -def _read_inst(ql: Qiling, addr: int) -> int: - +def read_inst(ql: Qiling, addr: int) -> int: result = ql.mem.read(addr, 4) if ql.archtype in (QL_ARCH.ARM, QL_ARCH.ARM_THUMB, QL_ARCH.CORTEX_M): @@ -187,312 +104,459 @@ def _read_inst(ql: Qiling, addr: int) -> int: return result -def handle_bnj_x86(ql: Qilng, cur_addr: str) -> int: - - # FIXME: NO HANDLE BRANCH AND JUMP FOR X86 FOR NOW - - to_jump = False - ret_addr = None - - return (to_jump, ret_addr) +""" + Try to predict certian branch will be taken or not based on current context +""" +def setup_branch_predictor(ql: Qiling) -> BranchPredictor: + return { + QL_ARCH.X86: BranchPredictor_X86, + QL_ARCH.ARM: BranchPredictor_ARM, + QL_ARCH.ARM_THUMB: BranchPredictor_ARM, + QL_ARCH.CORTEX_M: BranchPredictor_CORTEX_M, + QL_ARCH.MIPS: BranchPredictor_MIPS, + }.get(ql.archtype)(ql) -def handle_bnj_arm(ql: Qiling, cur_addr: str) -> int: - - def _read_reg_val(regs, _reg): - return getattr(ql.reg, _reg.replace("ip", "r12").replace("fp", "r11")) - - def regdst_eq_pc(op_str): - return op_str.partition(", ")[0] == "pc" - - read_inst = partial(_read_inst, ql) - read_reg_val = partial(_read_reg_val, ql.reg) - - ARM_INST_SIZE = 4 - ARM_THUMB_INST_SIZE = 2 - - line = disasm(ql, cur_addr) - ret_addr = cur_addr + line.size - - if line.mnemonic == "udf": # indicates program exited - return CODE_END - - jump_table = { - # unconditional branch - "b" : (lambda *_: True), - "bl" : (lambda *_: True), - "bx" : (lambda *_: True), - "blx" : (lambda *_: True), - "b.w" : (lambda *_: True), - - # branch on equal, Z == 1 - "beq" : (lambda V, C, Z, N: Z == 1), - "bxeq" : (lambda V, C, Z, N: Z == 1), - "beq.w": (lambda V, C, Z, N: Z == 1), - - # branch on not equal, Z == 0 - "bne" : (lambda V, C, Z, N: Z == 0), - "bxne" : (lambda V, C, Z, N: Z == 0), - "bne.w": (lambda V, C, Z, N: Z == 0), - - # branch on signed greater than, Z == 0 and N == V - "bgt" : (lambda V, C, Z, N: (Z == 0 and N == V)), - "bgt.w": (lambda V, C, Z, N: (Z == 0 and N == V)), - - # branch on signed less than, N != V - "blt" : (lambda V, C, Z, N: N != V), - - # branch on signed greater than or equal, N == V - "bge" : (lambda V, C, Z, N: N == V), - - # branch on signed less than or queal - "ble" : (lambda V, C, Z, N: Z == 1 or N != V), - - # branch on unsigned higher or same (or carry set), C == 1 - "bhs" : (lambda V, C, Z, N: C == 1), - "bcs" : (lambda V, C, Z, N: C == 1), - - # branch on unsigned lower (or carry clear), C == 0 - "bcc" : (lambda V, C, Z, N: C == 0), - "blo" : (lambda V, C, Z, N: C == 0), - "bxlo" : (lambda V, C, Z, N: C == 0), - "blo.w": (lambda V, C, Z, N: C == 0), - - # branch on negative or minus, N == 1 - "bmi" : (lambda V, C, Z, N: N == 1), - - # branch on positive or plus, N == 0 - "bpl" : (lambda V, C, Z, N: N == 0), - - # branch on signed overflow - "bvs" : (lambda V, C, Z, N: V == 1), - - # branch on no signed overflow - "bvc" : (lambda V, C, Z, N: V == 0), +class Prophecy(object): + def __init__(self): + self.going = False + self.where = None - # branch on unsigned higher - "bhi" : (lambda V, C, Z, N: (Z == 0 and C == 1)), - "bxhi" : (lambda V, C, Z, N: (Z == 0 and C == 1)), - "bhi.w": (lambda V, C, Z, N: (Z == 0 and C == 1)), + def __iter__(self): + return iter((self.going, self.where)) - # branch on unsigned lower - "bls" : (lambda V, C, Z, N: (C == 0 or Z == 1)), - "bls.w": (lambda V, C, Z, N: (C == 0 or Z == 1)), - } +class BranchPredictor(object): + def __init__(self, ql): + self.ql = ql - cb_table = { - # branch on equal to zero - "cbz" : (lambda r: r == 0), + def read_reg(self, reg_name): + return getattr(self.ql.reg, reg_name) - # branch on not equal to zero - "cbnz": (lambda r: r != 0), - } + def predict(self): + return NotImplementedError - to_jump = False - if line.mnemonic in jump_table: - to_jump = jump_table.get(line.mnemonic)(*get_cpsr(ql.reg.cpsr)) +class BranchPredictor_ARM(BranchPredictor): + def __init__(self, ql): + super().__init__(ql) - elif line.mnemonic in cb_table: - to_jump = cb_table.get(line.mnemonic)(read_reg_val(line.op_str.split(", ")[0])) + self.INST_SIZE = 4 + self.THUMB_INST_SIZE = 2 + self.CODE_END = "udf" - if to_jump: - if "#" in line.op_str: - ret_addr = _parse_int(line.op_str.split("#")[-1]) - else: - ret_addr = read_reg_val(line.op_str) + def read_reg(self, reg_name): + reg_name = reg_name.replace("ip", "r12").replace("fp", "r11") + return getattr(self.ql.reg, reg_name) - if regdst_eq_pc(line.op_str): - next_addr = cur_addr + line.size - n2_addr = next_addr + len(read_inst(next_addr)) - ret_addr += len(read_inst(n2_addr)) + len(read_inst(next_addr)) + def regdst_eq_pc(self, op_str): + return op_str.partition(", ")[0] == "pc" - elif line.mnemonic.startswith("it"): - # handle IT block here + def predict(self): + prophecy = Prophecy() + cur_addr = self.ql.reg.arch_pc + line = disasm(self.ql, cur_addr) + prophecy.where = cur_addr + line.size + + if line.mnemonic == self.CODE_END: # indicates program exited + return True + + jump_table = { + # unconditional branch + "b" : (lambda *_: True), + "bl" : (lambda *_: True), + "bx" : (lambda *_: True), + "blx" : (lambda *_: True), + "b.w" : (lambda *_: True), + + # branch on equal, Z == 1 + "beq" : (lambda V, C, Z, N: Z == 1), + "bxeq" : (lambda V, C, Z, N: Z == 1), + "beq.w": (lambda V, C, Z, N: Z == 1), + + # branch on not equal, Z == 0 + "bne" : (lambda V, C, Z, N: Z == 0), + "bxne" : (lambda V, C, Z, N: Z == 0), + "bne.w": (lambda V, C, Z, N: Z == 0), + + # branch on signed greater than, Z == 0 and N == V + "bgt" : (lambda V, C, Z, N: (Z == 0 and N == V)), + "bgt.w": (lambda V, C, Z, N: (Z == 0 and N == V)), + + # branch on signed less than, N != V + "blt" : (lambda V, C, Z, N: N != V), - cond_met = { - "eq": lambda V, C, Z, N: (Z == 1), - "ne": lambda V, C, Z, N: (Z == 0), - "ge": lambda V, C, Z, N: (N == V), - "hs": lambda V, C, Z, N: (C == 1), - "lo": lambda V, C, Z, N: (C == 0), - "mi": lambda V, C, Z, N: (N == 1), - "pl": lambda V, C, Z, N: (N == 0), - "ls": lambda V, C, Z, N: (C == 0 or Z == 1), - "le": lambda V, C, Z, N: (Z == 1 or N != V), - "hi": lambda V, C, Z, N: (Z == 0 and C == 1), - }.get(line.op_str)(*get_cpsr(ql.reg.cpsr)) + # branch on signed greater than or equal, N == V + "bge" : (lambda V, C, Z, N: N == V), - it_block_range = [each_char for each_char in line.mnemonic[1:]] + # branch on signed less than or queal + "ble" : (lambda V, C, Z, N: Z == 1 or N != V), + + # branch on unsigned higher or same (or carry set), C == 1 + "bhs" : (lambda V, C, Z, N: C == 1), + "bcs" : (lambda V, C, Z, N: C == 1), + + # branch on unsigned lower (or carry clear), C == 0 + "bcc" : (lambda V, C, Z, N: C == 0), + "blo" : (lambda V, C, Z, N: C == 0), + "bxlo" : (lambda V, C, Z, N: C == 0), + "blo.w": (lambda V, C, Z, N: C == 0), - next_addr = cur_addr + ARM_THUMB_INST_SIZE - for each in it_block_range: - _inst = read_inst(next_addr) - n2_addr = handle_bnj_arm(ql, next_addr) + # branch on negative or minus, N == 1 + "bmi" : (lambda V, C, Z, N: N == 1), + + # branch on positive or plus, N == 0 + "bpl" : (lambda V, C, Z, N: N == 0), + + # branch on signed overflow + "bvs" : (lambda V, C, Z, N: V == 1), + + # branch on no signed overflow + "bvc" : (lambda V, C, Z, N: V == 0), - if (cond_met and each == "t") or (not cond_met and each == "e"): - if n2_addr != (next_addr+len(_inst)): # branch detected - break + # branch on unsigned higher + "bhi" : (lambda V, C, Z, N: (Z == 0 and C == 1)), + "bxhi" : (lambda V, C, Z, N: (Z == 0 and C == 1)), + "bhi.w": (lambda V, C, Z, N: (Z == 0 and C == 1)), - next_addr += len(_inst) + # branch on unsigned lower + "bls" : (lambda V, C, Z, N: (C == 0 or Z == 1)), + "bls.w": (lambda V, C, Z, N: (C == 0 or Z == 1)), + } - ret_addr = next_addr + cb_table = { + # branch on equal to zero + "cbz" : (lambda r: r == 0), - elif line.mnemonic in ("ldr",): + # branch on not equal to zero + "cbnz": (lambda r: r != 0), + } - if regdst_eq_pc(line.op_str): - _, _, rn_offset = line.op_str.partition(", ") - r, _, imm = rn_offset.strip("[]!").partition(", #") + if line.mnemonic in jump_table: + prophecy.going = jump_table.get(line.mnemonic)(*get_cpsr(self.ql.reg.cpsr)) - if "]" in rn_offset.split(", ")[1]: # pre-indexed immediate - ret_addr = ql.unpack32(ql.mem.read(_parse_int(imm) + read_reg_val(r), ARM_INST_SIZE)) + elif line.mnemonic in cb_table: + prophecy.going = cb_table.get(line.mnemonic)(self.read_reg(line.op_str.split(", ")[0])) - else: # post-indexed immediate - # FIXME: weired behavior, immediate here does not apply - ret_addr = ql.unpack32(ql.mem.read(read_reg_val(r), ARM_INST_SIZE)) + if prophecy.going: + if "#" in line.op_str: + prophecy.where = read_int(line.op_str.split("#")[-1]) + else: + prophecy.where = self.read_reg(line.op_str) - elif line.mnemonic in ("addls", "addne", "add") and regdst_eq_pc(line.op_str): - V, C, Z, N = get_cpsr(ql.reg.cpsr) - r0, r1, r2, *imm = line.op_str.split(", ") + if self.regdst_eq_pc(line.op_str): + next_addr = cur_addr + line.size + n2_addr = next_addr + len(read_inst(next_addr)) + prophecy.where += len(read_inst(n2_addr)) + len(read_inst(next_addr)) - # program counter is awalys 8 bytes ahead when it comes with pc, need to add extra 8 bytes - extra = 8 if 'pc' in (r0, r1, r2) else 0 + elif line.mnemonic.startswith("it"): + # handle IT block here - if imm: - expr = imm[0].split() - # TODO: should support more bit shifting and rotating operation - if expr[0] == "lsl": # logical shift left - n = _parse_int(expr[-1].strip("#")) * 2 + cond_met = { + "eq": lambda V, C, Z, N: (Z == 1), + "ne": lambda V, C, Z, N: (Z == 0), + "ge": lambda V, C, Z, N: (N == V), + "hs": lambda V, C, Z, N: (C == 1), + "lo": lambda V, C, Z, N: (C == 0), + "mi": lambda V, C, Z, N: (N == 1), + "pl": lambda V, C, Z, N: (N == 0), + "ls": lambda V, C, Z, N: (C == 0 or Z == 1), + "le": lambda V, C, Z, N: (Z == 1 or N != V), + "hi": lambda V, C, Z, N: (Z == 0 and C == 1), + }.get(line.op_str)(*get_cpsr(ql.reg.cpsr)) - if line.mnemonic == "addls" and (C == 0 or Z == 1): - ret_addr = extra + read_reg_val(r1) + read_reg_val(r2) * n + it_block_range = [each_char for each_char in line.mnemonic[1:]] - elif line.mnemonic == "add" or (line.mnemonic == "addne" and Z == 0): - ret_addr = extra + read_reg_val(r1) + (read_reg_val(r2) * n if imm else read_reg_val(r2)) + next_addr = cur_addr + self.THUMB_INST_SIZE + for each in it_block_range: + _inst = read_inst(next_addr) + n2_addr = handle_bnj_arm(ql, next_addr) - elif line.mnemonic in ("tbh", "tbb"): + if (cond_met and each == "t") or (not cond_met and each == "e"): + if n2_addr != (next_addr+len(_inst)): # branch detected + break - cur_addr += ARM_INST_SIZE - r0, r1, *imm = line.op_str.strip("[]").split(", ") + next_addr += len(_inst) - if imm: - expr = imm[0].split() - if expr[0] == "lsl": # logical shift left - n = _parse_int(expr[-1].strip("#")) * 2 + prophecy.where = next_addr - if line.mnemonic == "tbh": + elif line.mnemonic in ("ldr",): - r1 = read_reg_val(r1) * n + if self.regdst_eq_pc(line.op_str): + _, _, rn_offset = line.op_str.partition(", ") + r, _, imm = rn_offset.strip("[]!").partition(", #") - elif line.mnemonic == "tbb": + if "]" in rn_offset.split(", ")[1]: # pre-indexed immediate + prophecy.where = ql.unpack32(ql.mem.read(read_int(imm) + self.read_reg(r), self.INST_SIZE)) - r1 = read_reg_val(r1) + else: # post-indexed immediate + # FIXME: weired behavior, immediate here does not apply + prophecy.where = ql.unpack32(ql.mem.read(self.read_reg(r), self.INST_SIZE)) - to_add = int.from_bytes(ql.mem.read(cur_addr+r1, 2 if line.mnemonic == "tbh" else 1), byteorder="little") * n - ret_addr = cur_addr + to_add + elif line.mnemonic in ("addls", "addne", "add") and self.regdst_eq_pc(line.op_str): + V, C, Z, N = get_cpsr(ql.reg.cpsr) + r0, r1, r2, *imm = line.op_str.split(", ") - elif line.mnemonic.startswith("pop") and "pc" in line.op_str: + # program counter is awalys 8 bytes ahead when it comes with pc, need to add extra 8 bytes + extra = 8 if 'pc' in (r0, r1, r2) else 0 - ret_addr = ql.stack_read(line.op_str.strip("{}").split(", ").index("pc") * ARM_INST_SIZE) - if not { # step to next instruction if cond does not meet - "pop" : lambda *_: True, - "pop.w": lambda *_: True, - "popeq": lambda V, C, Z, N: (Z == 1), - "popne": lambda V, C, Z, N: (Z == 0), - "pophi": lambda V, C, Z, N: (C == 1), - "popge": lambda V, C, Z, N: (N == V), - "poplt": lambda V, C, Z, N: (N != V), - }.get(line.mnemonic)(*get_cpsr(ql.reg.cpsr)): + if imm: + expr = imm[0].split() + # TODO: should support more bit shifting and rotating operation + if expr[0] == "lsl": # logical shift left + n = read_int(expr[-1].strip("#")) * 2 + + if line.mnemonic == "addls" and (C == 0 or Z == 1): + prophecy.where = extra + self.read_reg(r1) + self.read_reg(r2) * n + + elif line.mnemonic == "add" or (line.mnemonic == "addne" and Z == 0): + prophecy.where = extra + self.read_reg(r1) + (self.read_reg(r2) * n if imm else self.read_reg(r2)) + + elif line.mnemonic in ("tbh", "tbb"): + + cur_addr += self.INST_SIZE + r0, r1, *imm = line.op_str.strip("[]").split(", ") + + if imm: + expr = imm[0].split() + if expr[0] == "lsl": # logical shift left + n = read_int(expr[-1].strip("#")) * 2 + + if line.mnemonic == "tbh": + + r1 = self.read_reg(r1) * n + + elif line.mnemonic == "tbb": + + r1 = self.read_reg(r1) + + to_add = int.from_bytes(ql.mem.read(cur_addr+r1, 2 if line.mnemonic == "tbh" else 1), byteorder="little") * n + prophecy.where = cur_addr + to_add + + elif line.mnemonic.startswith("pop") and "pc" in line.op_str: + + prophecy.where = ql.stack_read(line.op_str.strip("{}").split(", ").index("pc") * self.INST_SIZE) + if not { # step to next instruction if cond does not meet + "pop" : lambda *_: True, + "pop.w": lambda *_: True, + "popeq": lambda V, C, Z, N: (Z == 1), + "popne": lambda V, C, Z, N: (Z == 0), + "pophi": lambda V, C, Z, N: (C == 1), + "popge": lambda V, C, Z, N: (N == V), + "poplt": lambda V, C, Z, N: (N != V), + }.get(line.mnemonic)(*get_cpsr(ql.reg.cpsr)): + + prophecy.where = cur_addr + self.INST_SIZE + + elif line.mnemonic == "sub" and self.regdst_eq_pc(line.op_str): + _, r, imm = line.op_str.split(", ") + prophecy.where = self.read_reg(r) - read_int(imm.strip("#")) + + elif line.mnemonic == "mov" and self.regdst_eq_pc(line.op_str): + _, r = line.op_str.split(", ") + prophecy.where = self.read_reg(r) + + if prophecy.where & 1: + prophecy.where -= 1 - ret_addr = cur_addr + ARM_INST_SIZE + return prophecy - elif line.mnemonic == "sub" and regdst_eq_pc(line.op_str): - _, r, imm = line.op_str.split(", ") - ret_addr = read_reg_val(r) - _parse_int(imm.strip("#")) +class BranchPredictor_MIPS(BranchPredictor): + def __init__(self, ql): + super().__init__(ql) + self.CODE_END = "break" + self.INST_SIZE = 4 - elif line.mnemonic == "mov" and regdst_eq_pc(line.op_str): - _, r = line.op_str.split(", ") - ret_addr = read_reg_val(r) + def read_reg(self, reg_name): + reg_name = reg_name.strip("$").replace("fp", "s8") + return signed_val(getattr(self.ql.reg, reg_name)) - if ret_addr & 1: - ret_addr -= 1 + def predict(self): + prophecy = Prophecy() + cur_addr = self.ql.reg.arch_pc + line = disasm(self.ql, cur_addr) - return (to_jump, ret_addr) + if line.mnemonic == self.CODE_END: # indicates program extied + return True + prophecy.where = cur_addr + self.INST_SIZE + if line.mnemonic.startswith('j') or line.mnemonic.startswith('b'): -def handle_bnj_mips(ql: Qiling, cur_addr: str) -> int: - MIPS_INST_SIZE = 4 + # make sure at least delay slot executed + prophecy.where += self.INST_SIZE - def _read_reg(regs, _reg): - return signed_val(getattr(regs, _reg.strip('$').replace("fp", "s8"))) + # get registers or memory address from op_str + targets = [ + self.read_reg(each) + if '$' in each else read_int(each) + for each in line.op_str.split(", ") + ] - read_reg_val = partial(_read_reg, ql.reg) + prophecy.going = { + "j" : (lambda _: True), # unconditional jump + "jr" : (lambda _: True), # unconditional jump + "jal" : (lambda _: True), # unconditional jump + "jalr" : (lambda _: True), # unconditional jump + "b" : (lambda _: True), # unconditional branch + "bl" : (lambda _: True), # unconditional branch + "bal" : (lambda _: True), # unconditional branch + "beq" : (lambda r0, r1, _: r0 == r1), # branch on equal + "bne" : (lambda r0, r1, _: r0 != r1), # branch on not equal + "blt" : (lambda r0, r1, _: r0 < r1), # branch on r0 less than r1 + "bgt" : (lambda r0, r1, _: r0 > r1), # branch on r0 greater than r1 + "ble" : (lambda r0, r1, _: r0 <= r1), # brach on r0 less than or equal to r1 + "bge" : (lambda r0, r1, _: r0 >= r1), # branch on r0 greater than or equal to r1 + "beqz" : (lambda r, _: r == 0), # branch on equal to zero + "bnez" : (lambda r, _: r != 0), # branch on not equal to zero + "bgtz" : (lambda r, _: r > 0), # branch on greater than zero + "bltz" : (lambda r, _: r < 0), # branch on less than zero + "bltzal" : (lambda r, _: r < 0), # branch on less than zero and link + "blez" : (lambda r, _: r <= 0), # branch on less than or equal to zero + "bgez" : (lambda r, _: r >= 0), # branch on greater than or equal to zero + "bgezal" : (lambda r, _: r >= 0), # branch on greater than or equal to zero and link + }.get(line.mnemonic)(*targets) + + if prophecy.going: + # target address is always the rightmost one + prophecy.where = targets[-1] + + return prophecy + +class BranchPredictor_X86(BranchPredictor): + def __init__(self, ql): + super().__init__(ql) + + def predict(self): + prophecy = Prophecy() + cur_addr = self.ql.reg.arch_pc + line = disasm(self.ql, cur_addr) + + jump_table = { + # conditional jump + + "jo" : (lambda C, P, A, Z, S, O: O == 1), + "jno" : (lambda C, P, A, Z, S, O: O == 0), + + "js" : (lambda C, P, A, Z, S, O: S == 1), + "jns" : (lambda C, P, A, Z, S, O: S == 0), + + "je" : (lambda C, P, A, Z, S, O: Z == 1), + "jz" : (lambda C, P, A, Z, S, O: Z == 1), + + "jne" : (lambda C, P, A, Z, S, O: Z == 0), + "jnz" : (lambda C, P, A, Z, S, O: Z == 0), + + "jb" : (lambda C, P, A, Z, S, O: C == 1), + "jc" : (lambda C, P, A, Z, S, O: C == 1), + "jnae" : (lambda C, P, A, Z, S, O: C == 1), + + "jnb" : (lambda C, P, A, Z, S, O: C == 0), + "jnc" : (lambda C, P, A, Z, S, O: C == 0), + "jae" : (lambda C, P, A, Z, S, O: C == 0), + + "jbe" : (lambda C, P, A, Z, S, O: C == 1 or Z == 1), + "jna" : (lambda C, P, A, Z, S, O: C == 1 or Z == 1), + + "ja" : (lambda C, P, A, Z, S, O: C == 0 and Z == 0), + "jnbe" : (lambda C, P, A, Z, S, O: C == 0 and Z == 0), + + "jl" : (lambda C, P, A, Z, S, O: S != O), + "jnge" : (lambda C, P, A, Z, S, O: S != O), + + "jge" : (lambda C, P, A, Z, S, O: S == O), + "jnl" : (lambda C, P, A, Z, S, O: S == O), + + "jle" : (lambda C, P, A, Z, S, O: Z == 1 or S != O), + "jng" : (lambda C, P, A, Z, S, O: Z == 1 or S != O), + + "jg" : (lambda C, P, A, Z, S, O: Z == 0 or S == O), + "jnle" : (lambda C, P, A, Z, S, O: Z == 0 or S == O), + + "jp" : (lambda C, P, A, Z, S, O: P == 1), + "jpe" : (lambda C, P, A, Z, S, O: P == 1), + + "jnp" : (lambda C, P, A, Z, S, O: P == 0), + "jpo" : (lambda C, P, A, Z, S, O: P == 0), + + # unconditional jump + + "call" : (lambda *_: True), + "jmp" : (lambda *_: True), + + } + + jump_reg_table = { + "jcxz" : (lambda cx: cx == 0), + "jecxz" : (lambda ecx: ecx == 0), + "jrcxz" : (lambda rcx: rcx == 0), + } + + if line.mnemonic in jump_table: + eflags = get_x86_eflags(self.ql.reg.ef).values() + prophecy.going = jump_table.get(line.mnemonic)(*eflags) + + elif line.mnemonic in jump_reg_table: + prophecy.going = jump_reg_table.get(line.mnemonic)(self.ql.reg.ecx) + + if prophecy.going: + takeaway_list = ["ptr", "dword", "[", "]"] + class AST_checker(ast.NodeVisitor): + def generic_visit(self, node): + if type(node) in (ast.Module, ast.Expr, ast.BinOp, ast.Constant, ast.Add, ast.Mult, ast.Sub): + ast.NodeVisitor.generic_visit(self, node) + else: + raise ParseError("malform or invalid ast node") - line = disasm(ql, cur_addr) + if len(line.op_str.split()) > 1: + new_line = line.op_str.replace(":", "+") + for each in takeaway_list: + new_line = new_line.replace(each, " ") - if line.mnemonic == "break": # indicates program extied - return CODE_END + new_line = " ".join(new_line.split()) + for each_reg in filter(lambda r: len(r) == 3, self.ql.reg.register_mapping.keys()): + if each_reg in new_line: + new_line = re.sub(each_reg, hex(self.read_reg(each_reg)), new_line) + + for each_reg in filter(lambda r: len(r) == 2, self.ql.reg.register_mapping.keys()): + if each_reg in new_line: + new_line = re.sub(each_reg, hex(self.read_reg(each_reg)), new_line) - # default breakpoint address if no jumps and branches here - ret_addr = cur_addr + MIPS_INST_SIZE + checker = AST_checker() + ast_tree = ast.parse(new_line) - to_jump = False - if line.mnemonic.startswith('j') or line.mnemonic.startswith('b'): + checker.visit(ast_tree) - # make sure at least delay slot executed - ret_addr += MIPS_INST_SIZE + prophecy.where = eval(new_line) - # get registers or memory address from op_str - targets = [ - read_reg_val(each) - if '$' in each else _parse_int(each) - for each in line.op_str.split(", ") - ] + elif line.op_str in self.ql.reg.register_mapping: + prophecy.where = getattr(self.ql.reg, line.op_str) - to_jump = { - "j" : (lambda _: True), # uncontitional jump - "jr" : (lambda _: True), # uncontitional jump - "jal" : (lambda _: True), # uncontitional jump - "jalr" : (lambda _: True), # uncontitional jump - "b" : (lambda _: True), # unconditional branch - "bl" : (lambda _: True), # unconditional branch - "bal" : (lambda _: True), # unconditional branch - "beq" : (lambda r0, r1, _: r0 == r1), # branch on equal - "bne" : (lambda r0, r1, _: r0 != r1), # branch on not equal - "blt" : (lambda r0, r1, _: r0 < r1), # branch on r0 less than r1 - "bgt" : (lambda r0, r1, _: r0 > r1), # branch on r0 greater than r1 - "ble" : (lambda r0, r1, _: r0 <= r1), # brach on r0 less than or equal to r1 - "bge" : (lambda r0, r1, _: r0 >= r1), # branch on r0 greater than or equal to r1 - "beqz" : (lambda r, _: r == 0), # branch on equal to zero - "bnez" : (lambda r, _: r != 0), # branch on not equal to zero - "bgtz" : (lambda r, _: r > 0), # branch on greater than zero - "bltz" : (lambda r, _: r < 0), # branch on less than zero - "bltzal" : (lambda r, _: r < 0), # branch on less than zero and link - "blez" : (lambda r, _: r <= 0), # branch on less than or equal to zero - "bgez" : (lambda r, _: r >= 0), # branch on greater than or equal to zero - "bgezal" : (lambda r, _: r >= 0), # branch on greater than or equal to zero and link - }.get(line.mnemonic)(*targets) + else: + prophecy.where = read_int(line.op_str) + else: + prophecy.where = cur_addr + line.size - if to_jump: - # target address is always the rightmost one - ret_addr = targets[-1] + return prophecy - return (to_jump, ret_addr) +class BranchPredictor_CORTEX_M(BranchPredictor_ARM): + def __init__(self, ql): + super().__init__(ql) class Breakpoint(object): """ dummy class for breakpoint """ - def __init__(self, address: int): - self.addr = address + def __init__(self, addr): + self.addr = addr self.hitted = False - self.hook = None class TempBreakpoint(Breakpoint): """ dummy class for temporay breakpoint """ - def __init__(self, address): - super().__init__(address) + def __init__(self, addr): + super().__init__(addr) + +class ParseError(Exception): + pass if __name__ == "__main__":