|
8 | 8 | from republic import AsyncStreamEvents, TapeContext |
9 | 9 | from republic.tape import TapeStore |
10 | 10 |
|
| 11 | +from bub import inquirer as bub_inquirer |
11 | 12 | from bub.builtin.agent import Agent |
12 | 13 | from bub.builtin.context import default_tape_context |
| 14 | +from bub.builtin.settings import DEFAULT_MODEL |
13 | 15 | from bub.channels.base import Channel |
14 | 16 | from bub.channels.message import ChannelMessage, MediaItem |
15 | 17 | from bub.envelope import content_of, field_of |
|
18 | 20 | from bub.types import Envelope, MessageHandler, State |
19 | 21 |
|
20 | 22 | AGENTS_FILE_NAME = "AGENTS.md" |
| 23 | +MODEL_PROVIDER_CHOICES: tuple[str, ...] = ( |
| 24 | + "openrouter", |
| 25 | + "openai", |
| 26 | + "anthropic", |
| 27 | + "gemini", |
| 28 | + "azure", |
| 29 | + "bedrock", |
| 30 | + "ollama", |
| 31 | + "groq", |
| 32 | + "mistral", |
| 33 | + "deepseek", |
| 34 | +) |
| 35 | +API_FORMAT_CHOICES: tuple[str, ...] = ("completion", "responses", "messages") |
21 | 36 | DEFAULT_SYSTEM_PROMPT = """\ |
22 | 37 | <general_instruct> |
23 | 38 | Call tools or skills to finish the task. |
@@ -55,6 +70,37 @@ def _get_agent(self) -> Agent: |
55 | 70 | self._agent = Agent(self.framework) |
56 | 71 | return self._agent |
57 | 72 |
|
| 73 | + @staticmethod |
| 74 | + async def _discard_message(_: ChannelMessage) -> None: |
| 75 | + return |
| 76 | + |
| 77 | + @staticmethod |
| 78 | + def _split_model_identifier(model: str) -> tuple[str, str]: |
| 79 | + provider, separator, model_name = model.partition(":") |
| 80 | + if separator and provider and model_name: |
| 81 | + return provider.strip(), model_name.strip() |
| 82 | + default_provider, _, default_model_name = DEFAULT_MODEL.partition(":") |
| 83 | + fallback_model_name = model.strip() or default_model_name |
| 84 | + return default_provider, fallback_model_name |
| 85 | + |
| 86 | + @staticmethod |
| 87 | + def _provider_choices(current_provider: str) -> list[str]: |
| 88 | + choices = list(MODEL_PROVIDER_CHOICES) |
| 89 | + if current_provider and current_provider not in choices: |
| 90 | + choices.append(current_provider) |
| 91 | + choices.append("custom") |
| 92 | + return choices |
| 93 | + |
| 94 | + def _channel_choices(self) -> list[str]: |
| 95 | + return [c for c in self.framework.get_channels(self._discard_message) if c != "cli"] |
| 96 | + |
| 97 | + @staticmethod |
| 98 | + def _default_enabled_channels(current_value: object, available_channels: list[str]) -> list[str]: |
| 99 | + if isinstance(current_value, str) and current_value.strip() and current_value.strip().lower() != "all": |
| 100 | + selected = [name.strip() for name in current_value.split(",") if name.strip() in available_channels] |
| 101 | + return selected |
| 102 | + return available_channels |
| 103 | + |
58 | 104 | @hookimpl |
59 | 105 | def resolve_session(self, message: ChannelMessage) -> str: |
60 | 106 | session_id = field_of(message, "session_id") |
@@ -124,13 +170,69 @@ def register_cli_commands(self, app: typer.Typer) -> None: |
124 | 170 |
|
125 | 171 | app.command("run")(cli.run) |
126 | 172 | app.command("chat")(cli.chat) |
| 173 | + app.command("onboard")(cli.onboard) |
127 | 174 | app.add_typer(cli.login_app) |
128 | 175 | app.command("hooks", hidden=True)(cli.list_hooks) |
129 | 176 | app.command("gateway")(cli.gateway) |
130 | 177 | app.command("install")(cli.install) |
131 | 178 | app.command("uninstall")(cli.uninstall) |
132 | 179 | app.command("update")(cli.update) |
133 | 180 |
|
| 181 | + @hookimpl |
| 182 | + def onboard_config(self, current_config: dict[str, object]) -> dict[str, object] | None: |
| 183 | + current_model = current_config.get("model") |
| 184 | + model_default = str(current_model) if isinstance(current_model, str) and current_model else DEFAULT_MODEL |
| 185 | + provider_default, model_name_default = self._split_model_identifier(model_default) |
| 186 | + |
| 187 | + provider = bub_inquirer.ask_fuzzy( |
| 188 | + "LLM provider", |
| 189 | + choices=self._provider_choices(provider_default), |
| 190 | + default=provider_default, |
| 191 | + ) |
| 192 | + if provider == "custom": |
| 193 | + provider = bub_inquirer.ask_text("Custom provider", default=provider_default) or provider_default |
| 194 | + |
| 195 | + model_name = bub_inquirer.ask_text("LLM model", default=model_name_default) |
| 196 | + if not model_name: |
| 197 | + model_name = model_name_default |
| 198 | + model = f"{provider}:{model_name}" |
| 199 | + |
| 200 | + api_key = bub_inquirer.ask_secret("API key (optional)") |
| 201 | + |
| 202 | + current_api_base = current_config.get("api_base") |
| 203 | + api_base_default = str(current_api_base) if isinstance(current_api_base, str) else "" |
| 204 | + api_base = bub_inquirer.ask_text("API base (optional)", default=api_base_default) |
| 205 | + |
| 206 | + current_api_format = current_config.get("api_format") |
| 207 | + api_format_default = ( |
| 208 | + str(current_api_format) |
| 209 | + if isinstance(current_api_format, str) and current_api_format in API_FORMAT_CHOICES |
| 210 | + else API_FORMAT_CHOICES[0] |
| 211 | + ) |
| 212 | + api_format = bub_inquirer.ask_select("API format", choices=list(API_FORMAT_CHOICES), default=api_format_default) |
| 213 | + |
| 214 | + available_channels = self._channel_choices() |
| 215 | + default_channels = self._default_enabled_channels(current_config.get("enabled_channels"), available_channels) |
| 216 | + enabled_channels = bub_inquirer.ask_checkbox( |
| 217 | + "Channels", |
| 218 | + choices=available_channels, |
| 219 | + enabled=default_channels, |
| 220 | + validate=lambda values: True if values else "Select at least one channel.", |
| 221 | + ) |
| 222 | + |
| 223 | + stream_output = bub_inquirer.ask_confirm("Stream output", default=bool(current_config.get("stream_output"))) |
| 224 | + config: dict[str, object] = { |
| 225 | + "model": model, |
| 226 | + "api_format": api_format, |
| 227 | + "enabled_channels": ",".join(enabled_channels), |
| 228 | + "stream_output": stream_output, |
| 229 | + } |
| 230 | + if api_key: |
| 231 | + config["api_key"] = api_key |
| 232 | + if api_base: |
| 233 | + config["api_base"] = api_base |
| 234 | + return config |
| 235 | + |
134 | 236 | def _read_agents_file(self, state: State) -> str: |
135 | 237 | workspace = state.get("_runtime_workspace", str(Path.cwd())) |
136 | 238 | prompt_path = Path(workspace) / AGENTS_FILE_NAME |
|
0 commit comments