diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/__init__.py b/libs/tracker/llmstudio_tracker/prompt_manager/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/crud.py b/libs/tracker/llmstudio_tracker/prompt_manager/crud.py new file mode 100644 index 00000000..f4341592 --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_manager/crud.py @@ -0,0 +1,91 @@ +from llmstudio_tracker.prompt_manager import models, schemas +from sqlalchemy.orm import Session + + +def get_prompt_by_name_model_provider( + db: Session, name: str, model: str, provider: str +): + return ( + db.query(models.PromptDefault) + .filter( + models.PromptDefault.name == name, + models.PromptDefault.model == model, + models.PromptDefault.provider == provider, + models.PromptDefault.is_active == True, + ) + .order_by(models.PromptDefault.version.desc()) + .first() + ) + + +def get_prompt_by_id(db: Session, prompt_id: str): + return ( + db.query(models.PromptDefault) + .filter(models.PromptDefault.prompt_id == prompt_id) + .first() + ) + + +def get_prompt( + db: Session, + prompt_id: str = None, + name: str = None, + model: str = None, + provider: str = None, +): + if prompt_id: + return get_prompt_by_id(db, prompt_id) + else: + return get_prompt_by_name_model_provider(db, name, model, provider) + + +def add_prompt(db: Session, prompt: schemas.PromptDefault): + + prompt_created = models.PromptDefault.create_with_incremental_version( + db, + config=prompt.config, + prompt=prompt.prompt, + is_active=prompt.is_active, + name=prompt.name, + label=prompt.label, + model=prompt.model, + provider=prompt.provider, + ) + db.add(prompt_created) + db.commit() + db.refresh(prompt_created) + return prompt_created + + +def update_prompt(db: Session, prompt: schemas.PromptDefault): + if prompt.prompt_id: + existing_prompt = get_prompt_by_id(db, prompt.prompt_id) + else: + existing_prompt = get_prompt_by_name_model_provider( + db, prompt.name, prompt.model, prompt.provider + ) + + existing_prompt.config = prompt.config + existing_prompt.prompt = prompt.prompt + existing_prompt.is_active = prompt.is_active + existing_prompt.name = prompt.name + existing_prompt.model = prompt.model + existing_prompt.provider = prompt.provider + existing_prompt.version = prompt.version + existing_prompt.label = prompt.label + + db.commit() + db.refresh(existing_prompt) + return existing_prompt + + +def delete_prompt(db: Session, prompt: schemas.PromptDefault): + if prompt.prompt_id: + existing_prompt = get_prompt_by_id(db, prompt.prompt_id) + else: + existing_prompt = get_prompt_by_name_model_provider( + db, prompt.name, prompt.model, prompt.provider + ) + + db.delete(existing_prompt) + db.commit() diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py b/libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py new file mode 100644 index 00000000..bcffc07b --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py @@ -0,0 +1,56 @@ +from fastapi import APIRouter, Depends +from llmstudio_tracker.database import engine, get_db +from llmstudio_tracker.prompt_manager import crud, models, schemas +from sqlalchemy.orm import Session + +models.Base.metadata.create_all(bind=engine) + + +class PromptsRoutes: + def __init__(self, router: APIRouter): + self.router = router + self.define_routes() + + def define_routes(self): + self.router.post( + "/add/prompt", + response_model=schemas.PromptDefault, + )(self.add_prompt) + + self.router.get("/get/prompt", response_model=schemas.PromptDefault)( + self.get_prompt + ) + + self.router.patch("/update/prompt", response_model=schemas.PromptDefault)( + self.update_prompt + ) + + self.router.delete("/delete/prompt")(self.delete_prompt) + + async def add_prompt( + self, prompt: schemas.PromptDefault, db: Session = Depends(get_db) + ): + return crud.add_prompt(db=db, prompt=prompt) + + async def update_prompt( + self, prompt: schemas.PromptDefault, db: Session = Depends(get_db) + ): + return crud.update_prompt(db, prompt) + + async def get_prompt( + self, + prompt_info: schemas.PromptInfo, + db: Session = Depends(get_db), + ): + return crud.get_prompt( + db, + prompt_id=prompt_info.prompt_id, + name=prompt_info.name, + model=prompt_info.model, + provider=prompt_info.provider, + ) + + async def delete_prompt( + self, prompt: schemas.PromptDefault, db: Session = Depends(get_db) + ): + return crud.delete_prompt(db, prompt) diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/manager.py b/libs/tracker/llmstudio_tracker/prompt_manager/manager.py new file mode 100644 index 00000000..7f47ec09 --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_manager/manager.py @@ -0,0 +1,61 @@ +import json + +import requests +from llmstudio_tracker.prompt_manager.schemas import PromptDefault +from llmstudio_tracker.tracker import TrackingConfig + + +class PromptManager: + def __init__(self, tracking_config: TrackingConfig): + self.tracking_url = tracking_config.url + self._session = requests.Session() + + def add_prompt(self, prompt: PromptDefault): + req = self._session.post( + f"{self.tracking_url}/api/tracking/add/prompt", + headers={"accept": "application/json", "Content-Type": "application/json"}, + data=prompt.model_dump_json(), + timeout=100, + ) + return req + + def delete_prompt(self, prompt: PromptDefault): + req = self._session.delete( + f"{self.tracking_url}/api/tracking/delete/prompt", + headers={"accept": "application/json", "Content-Type": "application/json"}, + data=prompt.model_dump_json(), + timeout=100, + ) + return req + + def update_prompt(self, prompt: PromptDefault): + req = self._session.patch( + f"{self.tracking_url}/api/tracking/update/prompt", + headers={"accept": "application/json", "Content-Type": "application/json"}, + data=prompt.model_dump_json(), + timeout=100, + ) + return req + + def get_prompt( + self, + prompt_id: str = None, + name: str = None, + model: str = None, + provider: str = None, + ): + + data = { + "prompt_id": prompt_id, + "name": name, + "model": model, + "provider": provider, + } + + req = self._session.get( + f"{self.tracking_url}/api/tracking/get/prompt", + headers={"accept": "application/json", "Content-Type": "application/json"}, + timeout=100, + data=json.dumps(data), + ) + return req diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/models.py b/libs/tracker/llmstudio_tracker/prompt_manager/models.py new file mode 100644 index 00000000..2d96e35b --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_manager/models.py @@ -0,0 +1,102 @@ +import uuid +from datetime import datetime, timezone + +from llmstudio_tracker.config import DB_TYPE +from llmstudio_tracker.database import Base +from llmstudio_tracker.db_utils import JSONEncodedDict +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + Integer, + String, + UniqueConstraint, + event, + func, +) +from sqlalchemy.orm import Session + + +class PromptDefault(Base): + __tablename__ = "prompts" + + if DB_TYPE == "bigquery": + prompt_id = Column( + String, + primary_key=True, + default=lambda: str(uuid.uuid4()), + ) + config = Column(JSONEncodedDict, nullable=True) + else: + prompt_id = Column( + String, primary_key=True, default=lambda: str(uuid.uuid4()) + ) # Generate UUID as a string + config = Column(JSON, nullable=True) + + prompt = Column(String) + is_active = Column(Boolean, default=False) + name = Column(String, nullable=False) + model = Column(String, nullable=False) + provider = Column(String, nullable=False) + version = Column(Integer, nullable=False) + label = Column(String) + updated_at = Column( + DateTime(timezone=True), + onupdate=lambda: datetime.now(timezone.utc), + default=lambda: datetime.now(timezone.utc), + ) + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + __table_args__ = ( + UniqueConstraint( + "name", "provider", "model", "version", name="uq_prompt_version" + ), + ) + + @staticmethod + def get_next_version(session, name, model, provider): + """ + Get the next version number for a combination of name, model, and provider. + """ + max_version = ( + session.query(func.max(PromptDefault.version)) + .filter_by(name=name, model=model, provider=provider) + .scalar() + ) + return (max_version or 0) + 1 + + @classmethod + def create_with_incremental_version(cls, session, **kwargs): + """ + Create a new PromptDefault entry with an incremental version. + """ + name = kwargs.get("name") + model = kwargs.get("model") + provider = kwargs.get("provider") + if not all([name, model, provider]): + raise ValueError("name, model, and provider must be provided") + + kwargs["version"] = cls.get_next_version(session, name, model, provider) + + instance = cls(**kwargs) + session.add(instance) + return instance + + @event.listens_for(Session, "before_flush") + def ensure_single_active_prompt(session, flush_context, instances): + """ + Ensures only one PromptDefault entry per (name, model, provider) can have is_active=True. + If a new entry is set as is_active=True, deactivate others in the same group. + """ + for instance in session.new.union(session.dirty): + if isinstance(instance, PromptDefault) and instance.is_active: + session.query(PromptDefault).filter( + PromptDefault.name == instance.name, + PromptDefault.model == instance.model, + PromptDefault.provider == instance.provider, + PromptDefault.is_active == True, + PromptDefault.prompt_id != instance.prompt_id, + ).update({"is_active": False}, synchronize_session="fetch") diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/schemas.py b/libs/tracker/llmstudio_tracker/prompt_manager/schemas.py new file mode 100644 index 00000000..3209f8ee --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_manager/schemas.py @@ -0,0 +1,25 @@ +from datetime import datetime +from typing import Dict, Optional + +from pydantic import BaseModel + + +class PromptInfo(BaseModel): + prompt_id: Optional[str] = None + name: Optional[str] = None + model: Optional[str] = None + provider: Optional[str] = None + + +class PromptDefault(BaseModel): + prompt_id: Optional[str] = None + config: Optional[Dict] = {} + prompt: str + is_active: Optional[bool] = None + name: str + version: Optional[int] = None + label: Optional[str] = "production" + model: str + provider: str + updated_at: Optional[datetime] = None + created_at: Optional[datetime] = None diff --git a/libs/tracker/llmstudio_tracker/server.py b/libs/tracker/llmstudio_tracker/server.py index ed22c94c..8caff406 100644 --- a/libs/tracker/llmstudio_tracker/server.py +++ b/libs/tracker/llmstudio_tracker/server.py @@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware from llmstudio_tracker.config import TRACKING_HOST, TRACKING_PORT from llmstudio_tracker.logs.endpoints import LogsRoutes +from llmstudio_tracker.prompt_manager.endpoints import PromptsRoutes from llmstudio_tracker.session.endpoints import SessionsRoutes from llmstudio_tracker.utils import get_current_version @@ -42,6 +43,7 @@ def health_check(): tracking_router = APIRouter(prefix=TRACKING_BASE_ENDPOINT) LogsRoutes(tracking_router) SessionsRoutes(tracking_router) + PromptsRoutes(tracking_router) app.include_router(tracking_router)