diff --git a/src/lmcode/agent/_display.py b/src/lmcode/agent/_display.py index ee7a616..caac11b 100644 --- a/src/lmcode/agent/_display.py +++ b/src/lmcode/agent/_display.py @@ -34,11 +34,13 @@ from rich.console import Console from rich.console import Group as RenderGroup from rich.markup import escape as _escape +from rich.padding import Padding from rich.panel import Panel as _Panel from rich.rule import Rule +from rich.style import Style from rich.syntax import Syntax from rich.table import Table -from rich.text import Text +from rich.text import Span, Text from lmcode.ui.colors import ( ACCENT, @@ -48,6 +50,7 @@ SUCCESS, TEXT_MUTED, TEXT_SECONDARY, + WARNING, ) # --------------------------------------------------------------------------- @@ -272,7 +275,7 @@ def _print_tool_call(name: str, args: dict[str, Any]) -> None: def _render_diff_sidebyside( - old_lines: list[str], new_lines: list[str], max_rows: int = 50 + old_lines: list[str], new_lines: list[str], filename: str = "", max_rows: int = 50 ) -> tuple[Table, int, int]: """Build a side-by-side diff table and return ``(table, n_added, n_removed)``. @@ -280,10 +283,8 @@ def _render_diff_sidebyside( the same palette used by Claude Code's diff view. Equal lines receive a subtle violet-tinted neutral background to keep the panel cohesive. """ - _EQ_BG = "#1c1a2e" # unchanged — violet-tinted neutral - _DEL_FG = "#f38ba8" # Catppuccin Mocha rose + _EQ_BG = None # transparent for unchanged lines _DEL_BG = "#4a221d" # Codex dark-TC del bg — warm maroon - _ADD_FG = "#a6e3a1" # Catppuccin Mocha green _ADD_BG = "#1e3a2a" # Codex dark-TC add bg — deep forest green _SEP = Text("│", style=f"dim {ACCENT}") @@ -292,47 +293,111 @@ def _render_diff_sidebyside( table.add_column(width=1, no_wrap=True) # separator table.add_column(ratio=1, no_wrap=True, overflow="fold") + lexer_name = Syntax.guess_lexer(filename, code="".join(old_lines)) if filename else "text" + + old_text_obj = Syntax( + "".join(old_lines), lexer_name, theme="one-dark", background_color="default" + ).highlight("".join(old_lines)) + new_text_obj = Syntax( + "".join(new_lines), lexer_name, theme="one-dark", background_color="default" + ).highlight("".join(new_lines)) + + old_hlt_lines = old_text_obj.split("\n") + new_hlt_lines = new_text_obj.split("\n") + + def _style_line(line_text: Text, bg_color: str | None, is_empty: bool = False) -> Any: + if is_empty: + return ( + Text("") + if not bg_color + else Padding(Text(""), (0, 0), expand=True, style=f"on {bg_color}") + ) + + new_spans = [] + for span in line_text.spans: + if isinstance(span.style, str): + new_spans.append(Span(span.start, span.end, span.style)) + else: + new_spans.append( + Span( + span.start, + span.end, + Style( + color=span.style.color, bold=span.style.bold, italic=span.style.italic + ), + ) + ) + + line_text = line_text.copy() + line_text.spans = new_spans + line_text.style = ( + Style(color=line_text.style.color) if isinstance(line_text.style, Style) else Style() + ) + + if not bg_color: + return line_text + return Padding(line_text, (0, 0), expand=True, style=f"on {bg_color}") + added = removed = rows = 0 matcher = difflib.SequenceMatcher(None, old_lines, new_lines, autojunk=False) - def _row(left: Text, right: Text) -> None: + def _row(left: Any, right: Any) -> None: table.add_row(left, _SEP, right) for op, i1, i2, j1, j2 in matcher.get_opcodes(): if rows >= max_rows: break if op == "equal": - for old, new in zip(old_lines[i1:i2], new_lines[j1:j2], strict=False): - _row( - Text(old.rstrip("\n"), style=f"#abb2bf on {_EQ_BG}"), - Text(new.rstrip("\n"), style=f"#abb2bf on {_EQ_BG}"), + for i in range(i2 - i1): + old_len = len(old_hlt_lines) + new_len = len(new_hlt_lines) + left = ( + _style_line(old_hlt_lines[i1 + i], _EQ_BG) + if i1 + i < old_len + else _style_line(Text(""), _EQ_BG, is_empty=True) ) + right = ( + _style_line(new_hlt_lines[j1 + i], _EQ_BG) + if j1 + i < new_len + else _style_line(Text(""), _EQ_BG, is_empty=True) + ) + _row(left, right) rows += 1 elif op == "replace": old_chunk = old_lines[i1:i2] new_chunk = new_lines[j1:j2] for i in range(max(len(old_chunk), len(new_chunk))): - _row( - Text( - old_chunk[i].rstrip("\n") if i < len(old_chunk) else "", - style=f"{_DEL_FG} on {_DEL_BG}", - ), - Text( - new_chunk[i].rstrip("\n") if i < len(new_chunk) else "", - style=f"{_ADD_FG} on {_ADD_BG}", - ), - ) + if i < len(old_chunk) and i1 + i < len(old_hlt_lines): + left = _style_line(old_hlt_lines[i1 + i], _DEL_BG) + else: + left = _style_line(Text(""), _DEL_BG, is_empty=True) + + if i < len(new_chunk) and j1 + i < len(new_hlt_lines): + right = _style_line(new_hlt_lines[j1 + i], _ADD_BG) + else: + right = _style_line(Text(""), _ADD_BG, is_empty=True) + _row(left, right) rows += 1 removed += i2 - i1 added += j2 - j1 elif op == "delete": - for line in old_lines[i1:i2]: - _row(Text(line.rstrip("\n"), style=f"{_DEL_FG} on {_DEL_BG}"), Text("")) + for i in range(i2 - i1): + if i1 + i < len(old_hlt_lines): + left = _style_line(old_hlt_lines[i1 + i], _DEL_BG) + else: + left = _style_line(Text(""), _DEL_BG, is_empty=True) + right = _style_line(Text(""), _DEL_BG, is_empty=True) + _row(left, right) rows += 1 removed += i2 - i1 elif op == "insert": - for line in new_lines[j1:j2]: - _row(Text(""), Text(line.rstrip("\n"), style=f"{_ADD_FG} on {_ADD_BG}")) + for i in range(j2 - j1): + left = _style_line(Text(""), _ADD_BG, is_empty=True) + if j1 + i < len(new_hlt_lines): + right = _style_line(new_hlt_lines[j1 + i], _ADD_BG) + else: + right = _style_line(Text(""), _ADD_BG, is_empty=True) + _row(left, right) rows += 1 added += j2 - j1 @@ -392,7 +457,9 @@ def _print_tool_result( else: old_ls = old_content.splitlines(keepends=True) new_ls = new_content.splitlines(keepends=True) - diff_table, n_added, n_removed = _render_diff_sidebyside(old_ls, new_ls) + diff_table, n_added, n_removed = _render_diff_sidebyside( + old_ls, new_ls, filename=path + ) if n_added == 0 and n_removed == 0: console.print( f" [{SUCCESS}]✓ write_file[/] [{TEXT_MUTED}]{short} (no changes)[/]" @@ -547,3 +614,68 @@ def _print_lmstudio_closed() -> None: "_render_diff_sidebyside", "_rewrite_as_history", ] + + +def _print_tool_preview( + name: str, + args: dict[str, Any], + old_content: str | None = None, +) -> None: + """Print a preview of a tool execution (before it happens) in ask mode.""" + if name == "write_file": + path = args.get("path", "") + new_content = args.get("content", "") + if path and new_content: + short = pathlib.Path(path).name + ext = pathlib.Path(path).suffix.lstrip(".") + if old_content is None: + lines = new_content.splitlines() + n = min(len(lines), _WRITE_PREVIEW_LINES) + preview = "\n".join(lines[:n]) + more = f"\n… ({len(lines) - n} more lines)" if len(lines) > n else "" + body: Any = Syntax( + preview + more, ext or "text", theme="one-dark", line_numbers=True + ) + title = f"[{TEXT_MUTED}]{short}[/] [{WARNING}]new file (preview)[/]" + else: + old_ls = old_content.splitlines(keepends=True) + new_ls = new_content.splitlines(keepends=True) + diff_table, n_added, n_removed = _render_diff_sidebyside( + old_ls, new_ls, filename=path + ) + if n_added == 0 and n_removed == 0: + console.print( + f" [{WARNING}]? write_file[/] [{TEXT_MUTED}]{short} " + "(no changes - preview)[/]" + ) + return + parts = [] + if n_added: + parts.append(f"[{SUCCESS}]+{n_added}[/]") + if n_removed: + parts.append(f"[{ERROR}]-{n_removed}[/]") + title = f"[{TEXT_MUTED}]{short}[/] {' '.join(parts)} [{WARNING}](preview)[/]" + body = diff_table + console.print( + _Panel(body, title=title, border_style=WARNING, box=box.ROUNDED, padding=(0, 1)) + ) + return + + if name == "run_shell": + cmd = args.get("command", "") + if cmd: + in_grid = Table.grid(padding=(0, 1)) + in_grid.add_column(style=f"bold {WARNING}", width=3, no_wrap=True) + in_grid.add_column(style=TEXT_SECONDARY) + in_grid.add_row("IN", _escape(cmd)) + title = f"[{WARNING}]run_shell (preview)[/]" + console.print( + _Panel( + in_grid, + title=title, + border_style=WARNING, + box=box.ROUNDED, + padding=(0, 1), + ) + ) + return diff --git a/src/lmcode/agent/core.py b/src/lmcode/agent/core.py index a12ef43..f968ee8 100644 --- a/src/lmcode/agent/core.py +++ b/src/lmcode/agent/core.py @@ -41,6 +41,7 @@ _print_log_event, _print_startup_tip, _print_tool_call, + _print_tool_preview, _print_tool_result, _rewrite_as_history, console, @@ -52,6 +53,7 @@ from lmcode.lms_bridge import load_model, stream_model_log, unload_model from lmcode.tools import filesystem # noqa: F401 — ensures @register decorators run from lmcode.tools.registry import get_all +from lmcode.ui._interactive_prompt import display_interactive_approval from lmcode.ui.colors import ( ACCENT, ACCENT_BRIGHT, @@ -298,6 +300,7 @@ def __init__(self, model_id: str = "auto") -> None: self._max_file_bytes: int = get_settings().agent.max_file_bytes self._show_tips: bool = get_settings().ui.show_tips self._show_stats: bool = get_settings().ui.show_stats + self._always_allowed_tools: set[str] = set() self._inference_config: dict[str, Any] = {} # passed as config= to model.act() # ------------------------------------------------------------------ @@ -896,6 +899,53 @@ async def _reveal_markdown(self, text: str) -> None: # _run_turn — single agent iteration # ------------------------------------------------------------------ + def _wrap_tool(self, fn: Any) -> Any: + _params = list(inspect.signature(fn).parameters.keys()) + + @functools.wraps(fn) + def _wrapper(*args: Any, **kwargs: Any) -> str: + merged = {_params[i]: v for i, v in enumerate(args)} + merged.update(kwargs) + name = fn.__name__ + + old_content: str | None = None + if name == "write_file": + try: + fp = pathlib.Path(merged.get("path", "")) + old_content = fp.read_text(encoding="utf-8") if fp.exists() else None + except Exception: + pass + + is_dangerous = name in ("write_file", "run_shell") + + if is_dangerous and self._mode == "ask" and name not in self._always_allowed_tools: + _print_tool_preview(name, merged, old_content=old_content) + path_or_cmd = merged.get("path") or merged.get("command") or "" + ans = display_interactive_approval(name, str(path_or_cmd)) + if ans is None: + return "error: Tool execution cancelled by user." + elif ans == "no": + return "error: Tool execution denied by user." + elif ans == "always": + self._always_allowed_tools.add(name) + elif ans not in ("yes", "always"): + return ( + f"error: Tool execution denied. " + f"User provided this instruction instead: {ans}" + ) + + if self._verbose: + _print_tool_call(name, merged) + + result: str = fn(*args, **kwargs) + + if self._verbose: + _print_tool_result(name, str(result), merged, old_content=old_content) + + return result + + return _wrapper + async def _run_turn(self, model: Any, user_input: str, live: Any = None) -> tuple[str, str]: """Send one user message, run the tool loop, return ``(response, stats_line)``. @@ -955,7 +1005,7 @@ def _on_fragment(fragment: Any, _round_index: int) -> None: """Count generated tokens for the spinner label.""" tok_count[0] += 1 - tools = [_wrap_tool_verbose(t) for t in self._tools] if self._verbose else self._tools + tools = [self._wrap_tool(t) for t in self._tools] stop_evt = asyncio.Event() shuffled_tips = random.sample(_TIPS, len(_TIPS)) if self._show_tips else [] diff --git a/src/lmcode/ui/_interactive_prompt.py b/src/lmcode/ui/_interactive_prompt.py new file mode 100644 index 0000000..7973596 --- /dev/null +++ b/src/lmcode/ui/_interactive_prompt.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any + +from prompt_toolkit import Application +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.layout.containers import HSplit, Window +from prompt_toolkit.layout.controls import FormattedTextControl +from prompt_toolkit.layout.layout import Layout +from prompt_toolkit.widgets import TextArea + + +def display_interactive_approval(tool_name: str, path_or_cmd: str) -> str | None: + """Show an inline interactive approval menu for tools in 'ask' mode. + Returns: + "yes": user approved + "no": user denied + "always": user approved and wants to auto-allow this tool + "": user typed a redirect instruction + None: user pressed Ctrl+C + """ + options = [ + ("yes", "Yes"), + ("no", "No"), + ("always", "Yes — and allow this tool automatically from now on"), + ] + selected_index = 0 + + text_area = TextArea( + prompt="[Text input box] ...or tell lmcode what to do instead: ", multiline=False + ) + + def get_radio_text() -> list[tuple[str, str]]: + result: list[tuple[str, str]] = [] + result.append(("", f"Allow this change? ({tool_name})\n")) + for i, (_val, label) in enumerate(options): + if i == selected_index: + result.append(("class:selected", f"❯ {label}\n")) + else: + result.append(("", f" {label}\n")) + return result + + radio_window = Window(content=FormattedTextControl(get_radio_text), dont_extend_height=True) + + root_container = HSplit([radio_window, text_area]) + + layout = Layout(root_container, focused_element=text_area) + + kb = KeyBindings() + + @kb.add("up") + def _up(event: Any) -> None: + nonlocal selected_index + selected_index = max(0, selected_index - 1) + + @kb.add("down") + def _down(event: Any) -> None: + nonlocal selected_index + selected_index = min(len(options) - 1, selected_index + 1) + + @kb.add("enter") + def _enter(event: Any) -> None: + if text_area.text.strip(): + event.app.exit(result=text_area.text) + else: + event.app.exit(result=options[selected_index][0]) + + # Keyboard interrupt handler + @kb.add("c-c") + def _ctrl_c(event: Any) -> None: + event.app.exit(result=None) + + app: Application[str | None] = Application( + layout=layout, + key_bindings=kb, + full_screen=False, + ) + + return app.run()