Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ Repository = "https://github.com/ogx-ai/ogx-client-python"
[project.optional-dependencies]
aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"]

[project.scripts]
ogx-client = "ogx_client.lib.cli.ogx_client:main"

[tool.uv]
managed = true
required-version = ">=0.9"
Expand Down
2 changes: 1 addition & 1 deletion src/ogx_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __init__(
"""Construct an async Agent backed by the responses + conversations APIs.

Args:
client: An async OpenAI-compatible client (e.g., openai.AsyncOpenAI() or AsyncLlamaStackClient).
client: An async OpenAI-compatible client (e.g., openai.AsyncOpenAI() or AsyncOgxClient).
The client must support the responses and conversations APIs.
"""
self.client = client
Expand Down
14 changes: 7 additions & 7 deletions src/ogx_client/lib/cli/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import click
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from llama_stack_client.lib.cli.constants import LLAMA_STACK_CLIENT_CONFIG_DIR, get_config_file_path
from ogx_client.lib.cli.constants import OGX_CLIENT_CONFIG_DIR, get_config_file_path


def get_config():
Expand All @@ -20,18 +20,18 @@ def get_config():

@click.command()
@click.help_option("-h", "--help")
@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="")
@click.option("--api-key", type=str, help="Llama Stack distribution API key", default="")
@click.option("--endpoint", type=str, help="OGX server endpoint", default="")
@click.option("--api-key", type=str, help="OGX server API key", default="")
def configure(endpoint: str | None, api_key: str | None):
"""Configure Llama Stack Client CLI."""
os.makedirs(LLAMA_STACK_CLIENT_CONFIG_DIR, exist_ok=True)
"""Configure OGX Client CLI."""
os.makedirs(OGX_CLIENT_CONFIG_DIR, exist_ok=True)
config_path = get_config_file_path()

if endpoint != "":
final_endpoint = endpoint
else:
final_endpoint = prompt(
"> Enter the endpoint of the Llama Stack distribution server: ",
"> Enter the endpoint of the OGX server: ",
validator=Validator.from_callable(
lambda x: len(x) > 0 and (parsed := urlparse(x)).scheme and parsed.netloc,
error_message="Endpoint cannot be empty and must be a valid URL, please enter a valid endpoint",
Expand Down Expand Up @@ -60,4 +60,4 @@ def configure(endpoint: str | None, api_key: str | None):
)
)

print(f"Done! You can now use the Llama Stack Client CLI with endpoint {final_endpoint}") # noqa: T201
print(f"Done! You can now use the OGX Client CLI with endpoint {final_endpoint}") # noqa: T201
4 changes: 2 additions & 2 deletions src/ogx_client/lib/cli/constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
from pathlib import Path

LLAMA_STACK_CLIENT_CONFIG_DIR = Path(os.path.expanduser("~/.llama/client"))
OGX_CLIENT_CONFIG_DIR = Path(os.path.expanduser("~/.ogx/client"))


def get_config_file_path():
return LLAMA_STACK_CLIENT_CONFIG_DIR / "config.yaml"
return OGX_CLIENT_CONFIG_DIR / "config.yaml"
6 changes: 2 additions & 4 deletions src/ogx_client/lib/cli/eval/run_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@click.option(
"--dataset-id",
required=False,
help="Pre-registered dataset_id to score (from llama-stack-client datasets list)",
help="Pre-registered dataset_id to score (from ogx-client datasets list)",
)
@click.option(
"--dataset-path",
Expand Down Expand Up @@ -75,9 +75,7 @@ def run_scoring(
if dataset_id is not None:
dataset = client.datasets.retrieve(dataset_id=dataset_id)
if not dataset:
click.BadParameter(
f"Dataset {dataset_id} not found. Please register using llama-stack-client datasets register"
)
click.BadParameter(f"Dataset {dataset_id} not found. Please register using ogx-client datasets register")

# TODO: this will eventually be replaced with jobs polling from server via score_batch
# For now, get all datasets rows via datasets API
Expand Down
2 changes: 1 addition & 1 deletion src/ogx_client/lib/cli/inspect/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@click.pass_context
@handle_client_errors("inspect version")
def inspect_version(ctx):
"""Show Llama Stack version on distribution endpoint"""
"""Show OGX server version on distribution endpoint"""
client = ctx.obj["client"]
console = Console()
version_response = client.inspect.version()
Expand Down
2 changes: 1 addition & 1 deletion src/ogx_client/lib/cli/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def models():
"""Manage GenAI models."""


@click.command(name="list", help="Show available llama models at distribution endpoint")
@click.command(name="list", help="Show available models at distribution endpoint")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("list models")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import json
from importlib.metadata import version

import yaml
Expand All @@ -24,16 +25,15 @@

@click.group()
@click.help_option("-h", "--help")
@click.version_option(version=version("ogx-client"), prog_name="llama-stack-client")
@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="")
@click.option("--api-key", type=str, help="Llama Stack distribution API key", default="")
@click.version_option(version=version("ogx_client"), prog_name="ogx-client")
@click.option("--endpoint", type=str, help="OGX server endpoint", default="")
@click.option("--api-key", type=str, help="OGX server API key", default="")
@click.option("--config", type=str, help="Path to config file", default=None)
@click.pass_context
def llama_stack_client(ctx, endpoint: str, api_key: str, config: str | None):
"""Welcome to the llama-stack-client CLI - a command-line interface for interacting with Llama Stack"""
def ogx_client(ctx, endpoint: str, api_key: str, config: str | None):
"""Welcome to the ogx-client CLI - a command-line interface for interacting with an OGX server"""
ctx.ensure_object(dict)

# If no config provided, check default location
if config and endpoint:
raise ValueError("Cannot use both config and endpoint")

Expand All @@ -55,40 +55,45 @@ def llama_stack_client(ctx, endpoint: str, api_key: str, config: str | None):
if endpoint == "":
endpoint = "http://localhost:8321"

default_headers = {}
default_headers: dict[str, str] = {}
if api_key != "":
default_headers = {
"Authorization": f"Bearer {api_key}",
}
default_headers["Authorization"] = f"Bearer {api_key}"

client = OgxClient(
base_url=endpoint,
provider_data={
provider_data = {
k: v
for k, v in {
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
},
}.items()
if v
}
if provider_data:
default_headers["X-OGX-Provider-Data"] = json.dumps(provider_data)

client = OgxClient(
base_url=endpoint,
api_key=api_key or None,
default_headers=default_headers,
)
ctx.obj = {"client": client}


# Register all subcommands
llama_stack_client.add_command(models, "models")
llama_stack_client.add_command(vector_stores, "vector_stores")
llama_stack_client.add_command(shields, "shields")
llama_stack_client.add_command(eval_tasks, "eval_tasks")
llama_stack_client.add_command(providers, "providers")
llama_stack_client.add_command(datasets, "datasets")
llama_stack_client.add_command(configure, "configure")
llama_stack_client.add_command(scoring_functions, "scoring_functions")
llama_stack_client.add_command(eval, "eval")
llama_stack_client.add_command(inference, "inference")
llama_stack_client.add_command(inspect, "inspect")
ogx_client.add_command(models, "models")
ogx_client.add_command(vector_stores, "vector_stores")
ogx_client.add_command(shields, "shields")
ogx_client.add_command(eval_tasks, "eval_tasks")
ogx_client.add_command(providers, "providers")
ogx_client.add_command(datasets, "datasets")
ogx_client.add_command(configure, "configure")
ogx_client.add_command(scoring_functions, "scoring_functions")
ogx_client.add_command(eval, "eval")
ogx_client.add_command(inference, "inference")
ogx_client.add_command(inspect, "inspect")


def main():
llama_stack_client()
ogx_client()


if __name__ == "__main__":
Expand Down
Loading