From 74cb3b8644e3156762549afc9ff503c2050c3c32 Mon Sep 17 00:00:00 2001 From: brunoalho Date: Tue, 21 Jan 2025 17:29:34 +0000 Subject: [PATCH 1/3] [feat] prompt management --- .../prompt_management/__init__.py | 0 .../prompt_management/crud.py | 74 +++++++++++++++++++ .../prompt_management/endpoints.py | 55 ++++++++++++++ .../prompt_management/manager.py | 50 +++++++++++++ .../prompt_management/models.py | 47 ++++++++++++ .../prompt_management/schemas.py | 16 ++++ 6 files changed, 242 insertions(+) create mode 100644 libs/tracker/llmstudio_tracker/prompt_management/__init__.py create mode 100644 libs/tracker/llmstudio_tracker/prompt_management/crud.py create mode 100644 libs/tracker/llmstudio_tracker/prompt_management/endpoints.py create mode 100644 libs/tracker/llmstudio_tracker/prompt_management/manager.py create mode 100644 libs/tracker/llmstudio_tracker/prompt_management/models.py create mode 100644 libs/tracker/llmstudio_tracker/prompt_management/schemas.py diff --git a/libs/tracker/llmstudio_tracker/prompt_management/__init__.py b/libs/tracker/llmstudio_tracker/prompt_management/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/tracker/llmstudio_tracker/prompt_management/crud.py b/libs/tracker/llmstudio_tracker/prompt_management/crud.py new file mode 100644 index 00000000..f0401272 --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_management/crud.py @@ -0,0 +1,74 @@ +from llmstudio_tracker.prompt_management import models, schemas +from sqlalchemy.orm import Session + + +def get_project_by_name(db: Session, name: str): + return db.query(models.Project).filter(models.Project.name == name).first() + + +def get_prompt_by_name_and_label( + db: Session, name: str, label: str = "production", skip: int = 0, limit: int = 100 +): + return ( + db.query(models.PromptDefault) + .filter(models.PromptDefault.name == name, models.PromptDefault.label == label) + .order_by(models.PromptDefault.created_at.asc()) + .offset(skip) + .limit(limit) + .first() + ) + + +def get_prompt_by_id(db: Session, prompt_id: int): + return ( + db.query(models.PromptDefault) + .filter(models.PromptDefault.prompt_id == prompt_id) + .first() + ) + + +def get_prompt(db: Session, prompt_id: int = None, name: str = None, label: str = None): + if prompt_id: + return get_prompt_by_id(db, prompt_id) + else: + return get_prompt_by_name_and_label(db, name, label) + + +def add_prompt(db: Session, prompt: schemas.PromptDefault): + db_session = models.PromptDefault(**prompt.dict()) + db.add(db_session) + db.commit() + db.refresh(db_session) + return db_session + + +def update_prompt(db: Session, prompt: schemas.PromptDefault): + if prompt.prompt_id: + existing_prompt = get_prompt_by_id(db, prompt.prompt_id) + else: + existing_prompt = get_prompt_by_name_and_label(db, prompt.name, prompt.label) + + existing_prompt.config = prompt.config + existing_prompt.prompt = prompt.prompt + existing_prompt.is_active = prompt.is_active + existing_prompt.name = prompt.name + existing_prompt.version = prompt.version + existing_prompt.label = prompt.label + + db.commit() + db.refresh(existing_prompt) + return existing_prompt + + +def delete_prompt(db: Session, prompt: schemas.PromptDefault): + db_prompt = ( + db.query(schemas.PromptDefault) + .filter( + models.PromptDefault.name == prompt.name, + models.PromptDefault.label == prompt.label, + ) + .one() + ) + + db.delete(db_prompt) + db.commit() diff --git a/libs/tracker/llmstudio_tracker/prompt_management/endpoints.py b/libs/tracker/llmstudio_tracker/prompt_management/endpoints.py new file mode 100644 index 00000000..eba670b1 --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_management/endpoints.py @@ -0,0 +1,55 @@ +from typing import List + +from fastapi import APIRouter, Depends +from llmstudio_tracker.database import engine, get_db +from llmstudio_tracker.prompt_management import crud, models, schemas +from sqlalchemy.orm import Session + +models.Base.metadata.create_all(bind=engine) + + +class PromptsRoutes: + def __init__(self, router: APIRouter): + self.router = router + self.define_routes() + + def define_routes(self): + # Add session + self.router.post( + "/prompt", + response_model=schemas.PromptDefault, + )(self.add_prompt) + + self.router.get("/prompt", response_model=List[schemas.PromptDefault])( + self.get_prompt + ) + + self.router.patch("/prompt", response_model=schemas.PromptDefault)( + self.update_prompt + ) + + self.router.delete("/prompt")(self.delete_prompt) + + async def add_prompt( + self, prompt: schemas.PromptDefault, db: Session = Depends(get_db) + ): + return crud.add_prompt(db=db, prompt=prompt) + + async def update_prompt( + self, prompt: schemas.PromptDefault, db: Session = Depends(get_db) + ): + return crud.update_prompt(db, prompt) + + async def get_prompt( + self, + prompt_id: int = None, + name: str = None, + label: str = None, + db: Session = Depends(get_db), + ): + return crud.get_prompt(db, prompt_id=prompt_id, name=name, label=label) + + async def delete_prompt( + self, prompt: schemas.PromptDefault, db: Session = Depends(get_db) + ): + return crud.delete_prompt(db, prompt) diff --git a/libs/tracker/llmstudio_tracker/prompt_management/manager.py b/libs/tracker/llmstudio_tracker/prompt_management/manager.py new file mode 100644 index 00000000..d4bd2af5 --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_management/manager.py @@ -0,0 +1,50 @@ +import json + +import requests +from llmstudio_tracker.prompt_management.schemas import PromptDefault +from llmstudio_tracker.tracker import TrackingConfig + + +class PromptManagement: + def __init__(self, tracking_config: TrackingConfig): + self.tracking_url = tracking_config.url + self._session = requests.Session() + + def add_prompt(self, prompt: PromptDefault): + req = self._session.post( + f"{self.tracking_url}/api/tracking/prompt", + headers={"accept": "application/json", "Content-Type": "application/json"}, + data=prompt.model_dump(), + timeout=100, + ) + return req + + def delete_prompt(self, prompt: PromptDefault): + req = self._session.delete( + f"{self.tracking_url}/api/tracking/prompt", + headers={"accept": "application/json", "Content-Type": "application/json"}, + data=prompt.model_dump(), + timeout=100, + ) + return req + + def update_prompt(self, prompt: PromptDefault): + req = self._session.get( + f"{self.tracking_url}/api/tracking/prompt", + headers={"accept": "application/json", "Content-Type": "application/json"}, + data=prompt.model_dump(), + timeout=100, + ) + return req + + def get_prompt(self, prompt_id: str = None, name: str = None, label=None): + + data = {"prompt_id": prompt_id, "name": name, "label": label} + + req = self._session.get( + f"{self.tracking_url}/api/tracking/prompt", + headers={"accept": "application/json", "Content-Type": "application/json"}, + timeout=100, + data=json.dumps(data), + ) + return req diff --git a/libs/tracker/llmstudio_tracker/prompt_management/models.py b/libs/tracker/llmstudio_tracker/prompt_management/models.py new file mode 100644 index 00000000..398f6c7e --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_management/models.py @@ -0,0 +1,47 @@ +from datetime import datetime, timezone + +from llmstudio_tracker.config import DB_TYPE +from llmstudio_tracker.database import Base +from llmstudio_tracker.db_utils import JSONEncodedDict +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + Integer, + String, + UniqueConstraint, +) + + +class PromptDefault(Base): + __tablename__ = "prompts" + + if DB_TYPE == "bigquery": + prompt_id = Column( + Integer, + primary_key=True, + default=lambda: int( + datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f")[:-1] + ), + ) + config = Column(JSONEncodedDict) + else: + prompt_id = Column(Integer, primary_key=True, index=True) + config = Column(JSON) + + prompt = Column(String) + is_active = Column(Boolean) + name = Column(String) + version = Column(Integer) + label = Column(String) + updated_at = Column( + DateTime(timezone=True), + onupdate=lambda: datetime.now(timezone.utc), + default=lambda: datetime.now(timezone.utc), + ) + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + __table_args__ = (UniqueConstraint("name", "label", name="uq_name_label"),) diff --git a/libs/tracker/llmstudio_tracker/prompt_management/schemas.py b/libs/tracker/llmstudio_tracker/prompt_management/schemas.py new file mode 100644 index 00000000..05c7fc30 --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_management/schemas.py @@ -0,0 +1,16 @@ +from datetime import datetime +from typing import Dict + +from pydantic import BaseModel + + +class PromptDefault(BaseModel): + prompt_id: str + config: Dict + prompt: str + is_active: bool + name: str + version: int + label: str + updated_at: datetime + created_at: datetime From 8c5638548b85346af78834a3e62a35f2f856c00a Mon Sep 17 00:00:00 2001 From: brunoalho Date: Wed, 22 Jan 2025 16:57:05 +0000 Subject: [PATCH 2/3] [feat] testing --- .../prompt_management/crud.py | 74 -------------- .../prompt_management/models.py | 47 --------- .../prompt_management/schemas.py | 16 --- .../__init__.py | 0 .../llmstudio_tracker/prompt_manager/crud.py | 97 +++++++++++++++++++ .../endpoints.py | 19 ++-- .../manager.py | 30 ++++-- .../prompt_manager/models.py | 84 ++++++++++++++++ .../prompt_manager/schemas.py | 25 +++++ libs/tracker/llmstudio_tracker/server.py | 2 + 10 files changed, 239 insertions(+), 155 deletions(-) delete mode 100644 libs/tracker/llmstudio_tracker/prompt_management/crud.py delete mode 100644 libs/tracker/llmstudio_tracker/prompt_management/models.py delete mode 100644 libs/tracker/llmstudio_tracker/prompt_management/schemas.py rename libs/tracker/llmstudio_tracker/{prompt_management => prompt_manager}/__init__.py (100%) create mode 100644 libs/tracker/llmstudio_tracker/prompt_manager/crud.py rename libs/tracker/llmstudio_tracker/{prompt_management => prompt_manager}/endpoints.py (75%) rename libs/tracker/llmstudio_tracker/{prompt_management => prompt_manager}/manager.py (69%) create mode 100644 libs/tracker/llmstudio_tracker/prompt_manager/models.py create mode 100644 libs/tracker/llmstudio_tracker/prompt_manager/schemas.py diff --git a/libs/tracker/llmstudio_tracker/prompt_management/crud.py b/libs/tracker/llmstudio_tracker/prompt_management/crud.py deleted file mode 100644 index f0401272..00000000 --- a/libs/tracker/llmstudio_tracker/prompt_management/crud.py +++ /dev/null @@ -1,74 +0,0 @@ -from llmstudio_tracker.prompt_management import models, schemas -from sqlalchemy.orm import Session - - -def get_project_by_name(db: Session, name: str): - return db.query(models.Project).filter(models.Project.name == name).first() - - -def get_prompt_by_name_and_label( - db: Session, name: str, label: str = "production", skip: int = 0, limit: int = 100 -): - return ( - db.query(models.PromptDefault) - .filter(models.PromptDefault.name == name, models.PromptDefault.label == label) - .order_by(models.PromptDefault.created_at.asc()) - .offset(skip) - .limit(limit) - .first() - ) - - -def get_prompt_by_id(db: Session, prompt_id: int): - return ( - db.query(models.PromptDefault) - .filter(models.PromptDefault.prompt_id == prompt_id) - .first() - ) - - -def get_prompt(db: Session, prompt_id: int = None, name: str = None, label: str = None): - if prompt_id: - return get_prompt_by_id(db, prompt_id) - else: - return get_prompt_by_name_and_label(db, name, label) - - -def add_prompt(db: Session, prompt: schemas.PromptDefault): - db_session = models.PromptDefault(**prompt.dict()) - db.add(db_session) - db.commit() - db.refresh(db_session) - return db_session - - -def update_prompt(db: Session, prompt: schemas.PromptDefault): - if prompt.prompt_id: - existing_prompt = get_prompt_by_id(db, prompt.prompt_id) - else: - existing_prompt = get_prompt_by_name_and_label(db, prompt.name, prompt.label) - - existing_prompt.config = prompt.config - existing_prompt.prompt = prompt.prompt - existing_prompt.is_active = prompt.is_active - existing_prompt.name = prompt.name - existing_prompt.version = prompt.version - existing_prompt.label = prompt.label - - db.commit() - db.refresh(existing_prompt) - return existing_prompt - - -def delete_prompt(db: Session, prompt: schemas.PromptDefault): - db_prompt = ( - db.query(schemas.PromptDefault) - .filter( - models.PromptDefault.name == prompt.name, - models.PromptDefault.label == prompt.label, - ) - .one() - ) - - db.delete(db_prompt) - db.commit() diff --git a/libs/tracker/llmstudio_tracker/prompt_management/models.py b/libs/tracker/llmstudio_tracker/prompt_management/models.py deleted file mode 100644 index 398f6c7e..00000000 --- a/libs/tracker/llmstudio_tracker/prompt_management/models.py +++ /dev/null @@ -1,47 +0,0 @@ -from datetime import datetime, timezone - -from llmstudio_tracker.config import DB_TYPE -from llmstudio_tracker.database import Base -from llmstudio_tracker.db_utils import JSONEncodedDict -from sqlalchemy import ( - JSON, - Boolean, - Column, - DateTime, - Integer, - String, - UniqueConstraint, -) - - -class PromptDefault(Base): - __tablename__ = "prompts" - - if DB_TYPE == "bigquery": - prompt_id = Column( - Integer, - primary_key=True, - default=lambda: int( - datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f")[:-1] - ), - ) - config = Column(JSONEncodedDict) - else: - prompt_id = Column(Integer, primary_key=True, index=True) - config = Column(JSON) - - prompt = Column(String) - is_active = Column(Boolean) - name = Column(String) - version = Column(Integer) - label = Column(String) - updated_at = Column( - DateTime(timezone=True), - onupdate=lambda: datetime.now(timezone.utc), - default=lambda: datetime.now(timezone.utc), - ) - created_at = Column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) - ) - - __table_args__ = (UniqueConstraint("name", "label", name="uq_name_label"),) diff --git a/libs/tracker/llmstudio_tracker/prompt_management/schemas.py b/libs/tracker/llmstudio_tracker/prompt_management/schemas.py deleted file mode 100644 index 05c7fc30..00000000 --- a/libs/tracker/llmstudio_tracker/prompt_management/schemas.py +++ /dev/null @@ -1,16 +0,0 @@ -from datetime import datetime -from typing import Dict - -from pydantic import BaseModel - - -class PromptDefault(BaseModel): - prompt_id: str - config: Dict - prompt: str - is_active: bool - name: str - version: int - label: str - updated_at: datetime - created_at: datetime diff --git a/libs/tracker/llmstudio_tracker/prompt_management/__init__.py b/libs/tracker/llmstudio_tracker/prompt_manager/__init__.py similarity index 100% rename from libs/tracker/llmstudio_tracker/prompt_management/__init__.py rename to libs/tracker/llmstudio_tracker/prompt_manager/__init__.py diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/crud.py b/libs/tracker/llmstudio_tracker/prompt_manager/crud.py new file mode 100644 index 00000000..4cf7b951 --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_manager/crud.py @@ -0,0 +1,97 @@ +from llmstudio_tracker.prompt_manager import models, schemas +from sqlalchemy.orm import Session + + +def get_project_by_name(db: Session, name: str): + return db.query(models.Project).filter(models.Project.name == name).first() + + +def get_prompt_by_name_model_provider( + db: Session, name: str, model: str, provider: str, skip: int = 0, limit: int = 100 +): + return ( + db.query(models.PromptDefault) + .filter( + models.PromptDefault.name == name, + models.PromptDefault.model == model, + models.PromptDefault.provider == provider, + models.PromptDefault.is_active == True, + ) + .order_by(models.PromptDefault.version.desc()) + .offset(skip) + .limit(limit) + .first() + ) + + +def get_prompt_by_id(db: Session, prompt_id: str): + return ( + db.query(models.PromptDefault) + .filter(models.PromptDefault.prompt_id == prompt_id) + .first() + ) + + +def get_prompt( + db: Session, + prompt_id: str = None, + name: str = None, + model: str = None, + provider: str = None, +): + if prompt_id: + return get_prompt_by_id(db, prompt_id) + else: + return get_prompt_by_name_model_provider(db, name, model, provider) + + +def add_prompt(db: Session, prompt: schemas.PromptDefault): + + prompt_created = models.PromptDefault.create_with_incremental_version( + db, + config=prompt.config, + prompt=prompt.prompt, + is_active=prompt.is_active, + name=prompt.name, + label=prompt.label, + model=prompt.model, + provider=prompt.provider, + ) + db.add(prompt_created) + db.commit() + db.refresh(prompt_created) + return prompt_created + + +def update_prompt(db: Session, prompt: schemas.PromptDefault): + if prompt.prompt_id: + existing_prompt = get_prompt_by_id(db, prompt.prompt_id) + else: + existing_prompt = get_prompt_by_name_model_provider( + db, prompt.name, prompt.model, prompt.provider + ) + + existing_prompt.config = prompt.config + existing_prompt.prompt = prompt.prompt + existing_prompt.is_active = prompt.is_active + existing_prompt.name = prompt.name + existing_prompt.model = prompt.model + existing_prompt.provider = prompt.provider + existing_prompt.version = prompt.version + existing_prompt.label = prompt.label + + db.commit() + db.refresh(existing_prompt) + return existing_prompt + + +def delete_prompt(db: Session, prompt: schemas.PromptDefault): + if prompt.prompt_id: + existing_prompt = get_prompt_by_id(db, prompt.prompt_id) + else: + existing_prompt = get_prompt_by_name_model_provider( + db, prompt.name, prompt.model, prompt.provider + ) + + db.delete(existing_prompt) + db.commit() diff --git a/libs/tracker/llmstudio_tracker/prompt_management/endpoints.py b/libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py similarity index 75% rename from libs/tracker/llmstudio_tracker/prompt_management/endpoints.py rename to libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py index eba670b1..4f9b8af9 100644 --- a/libs/tracker/llmstudio_tracker/prompt_management/endpoints.py +++ b/libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py @@ -1,8 +1,6 @@ -from typing import List - from fastapi import APIRouter, Depends from llmstudio_tracker.database import engine, get_db -from llmstudio_tracker.prompt_management import crud, models, schemas +from llmstudio_tracker.prompt_manager import crud, models, schemas from sqlalchemy.orm import Session models.Base.metadata.create_all(bind=engine) @@ -14,13 +12,12 @@ def __init__(self, router: APIRouter): self.define_routes() def define_routes(self): - # Add session self.router.post( "/prompt", response_model=schemas.PromptDefault, )(self.add_prompt) - self.router.get("/prompt", response_model=List[schemas.PromptDefault])( + self.router.get("/prompt", response_model=schemas.PromptDefault)( self.get_prompt ) @@ -42,12 +39,16 @@ async def update_prompt( async def get_prompt( self, - prompt_id: int = None, - name: str = None, - label: str = None, + prompt_info: schemas.PromptInfo, db: Session = Depends(get_db), ): - return crud.get_prompt(db, prompt_id=prompt_id, name=name, label=label) + return crud.get_prompt( + db, + prompt_id=prompt_info.prompt_id, + name=prompt_info.name, + model=prompt_info.model, + provider=prompt_info.provider, + ) async def delete_prompt( self, prompt: schemas.PromptDefault, db: Session = Depends(get_db) diff --git a/libs/tracker/llmstudio_tracker/prompt_management/manager.py b/libs/tracker/llmstudio_tracker/prompt_manager/manager.py similarity index 69% rename from libs/tracker/llmstudio_tracker/prompt_management/manager.py rename to libs/tracker/llmstudio_tracker/prompt_manager/manager.py index d4bd2af5..0e130707 100644 --- a/libs/tracker/llmstudio_tracker/prompt_management/manager.py +++ b/libs/tracker/llmstudio_tracker/prompt_manager/manager.py @@ -1,11 +1,11 @@ import json import requests -from llmstudio_tracker.prompt_management.schemas import PromptDefault +from llmstudio_tracker.prompt_manager.schemas import PromptDefault from llmstudio_tracker.tracker import TrackingConfig -class PromptManagement: +class PromptManager: def __init__(self, tracking_config: TrackingConfig): self.tracking_url = tracking_config.url self._session = requests.Session() @@ -14,32 +14,44 @@ def add_prompt(self, prompt: PromptDefault): req = self._session.post( f"{self.tracking_url}/api/tracking/prompt", headers={"accept": "application/json", "Content-Type": "application/json"}, - data=prompt.model_dump(), + data=prompt.model_dump_json(), timeout=100, ) + print(req) return req def delete_prompt(self, prompt: PromptDefault): req = self._session.delete( f"{self.tracking_url}/api/tracking/prompt", headers={"accept": "application/json", "Content-Type": "application/json"}, - data=prompt.model_dump(), + data=prompt.model_dump_json(), timeout=100, ) return req def update_prompt(self, prompt: PromptDefault): - req = self._session.get( + req = self._session.patch( f"{self.tracking_url}/api/tracking/prompt", headers={"accept": "application/json", "Content-Type": "application/json"}, - data=prompt.model_dump(), + data=prompt.model_dump_json(), timeout=100, ) return req - def get_prompt(self, prompt_id: str = None, name: str = None, label=None): - - data = {"prompt_id": prompt_id, "name": name, "label": label} + def get_prompt( + self, + prompt_id: str = None, + name: str = None, + model: str = None, + provider: str = None, + ): + + data = { + "prompt_id": prompt_id, + "name": name, + "model": model, + "provider": provider, + } req = self._session.get( f"{self.tracking_url}/api/tracking/prompt", diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/models.py b/libs/tracker/llmstudio_tracker/prompt_manager/models.py new file mode 100644 index 00000000..0a2b4900 --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_manager/models.py @@ -0,0 +1,84 @@ +import uuid +from datetime import datetime, timezone + +from llmstudio_tracker.config import DB_TYPE +from llmstudio_tracker.database import Base +from llmstudio_tracker.db_utils import JSONEncodedDict +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + Integer, + String, + UniqueConstraint, + func, +) + + +class PromptDefault(Base): + __tablename__ = "prompts" + + if DB_TYPE == "bigquery": + prompt_id = Column( + String, + primary_key=True, + default=lambda: str(uuid.uuid4()), + ) + config = Column(JSONEncodedDict, nullable=True) + else: + prompt_id = Column( + String, primary_key=True, default=lambda: str(uuid.uuid4()) + ) # Generate UUID as a string + config = Column(JSON, nullable=True) + + prompt = Column(String) + is_active = Column(Boolean, default=True) + name = Column(String) + model = Column(String) + provider = Column(String) + version = Column(Integer) + label = Column(String) + updated_at = Column( + DateTime(timezone=True), + onupdate=lambda: datetime.now(timezone.utc), + default=lambda: datetime.now(timezone.utc), + ) + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + __table_args__ = ( + UniqueConstraint("name", "provider", "model", "version", name="uq_name_label"), + ) + + @staticmethod + def get_next_version(session, name, model, provider): + """ + Get the next version number for a combination of name, model, and provider. + """ + max_version = ( + session.query(func.max(PromptDefault.version)) + .filter_by(name=name, model=model, provider=provider) + .scalar() + ) + return (max_version or 0) + 1 + + @classmethod + def create_with_incremental_version(cls, session, **kwargs): + """ + Create a new PromptDefault entry with an incremental version. + """ + name = kwargs.get("name") + model = kwargs.get("model") + provider = kwargs.get("provider") + if not all([name, model, provider]): + raise ValueError("name, model, and provider must be provided") + + # Determine the next version + kwargs["version"] = cls.get_next_version(session, name, model, provider) + + # Create and add the new instance + instance = cls(**kwargs) + session.add(instance) + return instance diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/schemas.py b/libs/tracker/llmstudio_tracker/prompt_manager/schemas.py new file mode 100644 index 00000000..3209f8ee --- /dev/null +++ b/libs/tracker/llmstudio_tracker/prompt_manager/schemas.py @@ -0,0 +1,25 @@ +from datetime import datetime +from typing import Dict, Optional + +from pydantic import BaseModel + + +class PromptInfo(BaseModel): + prompt_id: Optional[str] = None + name: Optional[str] = None + model: Optional[str] = None + provider: Optional[str] = None + + +class PromptDefault(BaseModel): + prompt_id: Optional[str] = None + config: Optional[Dict] = {} + prompt: str + is_active: Optional[bool] = None + name: str + version: Optional[int] = None + label: Optional[str] = "production" + model: str + provider: str + updated_at: Optional[datetime] = None + created_at: Optional[datetime] = None diff --git a/libs/tracker/llmstudio_tracker/server.py b/libs/tracker/llmstudio_tracker/server.py index ed22c94c..8caff406 100644 --- a/libs/tracker/llmstudio_tracker/server.py +++ b/libs/tracker/llmstudio_tracker/server.py @@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware from llmstudio_tracker.config import TRACKING_HOST, TRACKING_PORT from llmstudio_tracker.logs.endpoints import LogsRoutes +from llmstudio_tracker.prompt_manager.endpoints import PromptsRoutes from llmstudio_tracker.session.endpoints import SessionsRoutes from llmstudio_tracker.utils import get_current_version @@ -42,6 +43,7 @@ def health_check(): tracking_router = APIRouter(prefix=TRACKING_BASE_ENDPOINT) LogsRoutes(tracking_router) SessionsRoutes(tracking_router) + PromptsRoutes(tracking_router) app.include_router(tracking_router) From 22def94d644a858b43baeb0a0fb502939bc3d966 Mon Sep 17 00:00:00 2001 From: brunoalho Date: Thu, 23 Jan 2025 11:33:36 +0000 Subject: [PATCH 3/3] [feat] only one active prompt --- .../llmstudio_tracker/prompt_manager/crud.py | 8 +---- .../prompt_manager/endpoints.py | 8 ++--- .../prompt_manager/manager.py | 9 +++-- .../prompt_manager/models.py | 34 ++++++++++++++----- 4 files changed, 35 insertions(+), 24 deletions(-) diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/crud.py b/libs/tracker/llmstudio_tracker/prompt_manager/crud.py index 4cf7b951..f4341592 100644 --- a/libs/tracker/llmstudio_tracker/prompt_manager/crud.py +++ b/libs/tracker/llmstudio_tracker/prompt_manager/crud.py @@ -2,12 +2,8 @@ from sqlalchemy.orm import Session -def get_project_by_name(db: Session, name: str): - return db.query(models.Project).filter(models.Project.name == name).first() - - def get_prompt_by_name_model_provider( - db: Session, name: str, model: str, provider: str, skip: int = 0, limit: int = 100 + db: Session, name: str, model: str, provider: str ): return ( db.query(models.PromptDefault) @@ -18,8 +14,6 @@ def get_prompt_by_name_model_provider( models.PromptDefault.is_active == True, ) .order_by(models.PromptDefault.version.desc()) - .offset(skip) - .limit(limit) .first() ) diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py b/libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py index 4f9b8af9..bcffc07b 100644 --- a/libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py +++ b/libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py @@ -13,19 +13,19 @@ def __init__(self, router: APIRouter): def define_routes(self): self.router.post( - "/prompt", + "/add/prompt", response_model=schemas.PromptDefault, )(self.add_prompt) - self.router.get("/prompt", response_model=schemas.PromptDefault)( + self.router.get("/get/prompt", response_model=schemas.PromptDefault)( self.get_prompt ) - self.router.patch("/prompt", response_model=schemas.PromptDefault)( + self.router.patch("/update/prompt", response_model=schemas.PromptDefault)( self.update_prompt ) - self.router.delete("/prompt")(self.delete_prompt) + self.router.delete("/delete/prompt")(self.delete_prompt) async def add_prompt( self, prompt: schemas.PromptDefault, db: Session = Depends(get_db) diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/manager.py b/libs/tracker/llmstudio_tracker/prompt_manager/manager.py index 0e130707..7f47ec09 100644 --- a/libs/tracker/llmstudio_tracker/prompt_manager/manager.py +++ b/libs/tracker/llmstudio_tracker/prompt_manager/manager.py @@ -12,17 +12,16 @@ def __init__(self, tracking_config: TrackingConfig): def add_prompt(self, prompt: PromptDefault): req = self._session.post( - f"{self.tracking_url}/api/tracking/prompt", + f"{self.tracking_url}/api/tracking/add/prompt", headers={"accept": "application/json", "Content-Type": "application/json"}, data=prompt.model_dump_json(), timeout=100, ) - print(req) return req def delete_prompt(self, prompt: PromptDefault): req = self._session.delete( - f"{self.tracking_url}/api/tracking/prompt", + f"{self.tracking_url}/api/tracking/delete/prompt", headers={"accept": "application/json", "Content-Type": "application/json"}, data=prompt.model_dump_json(), timeout=100, @@ -31,7 +30,7 @@ def delete_prompt(self, prompt: PromptDefault): def update_prompt(self, prompt: PromptDefault): req = self._session.patch( - f"{self.tracking_url}/api/tracking/prompt", + f"{self.tracking_url}/api/tracking/update/prompt", headers={"accept": "application/json", "Content-Type": "application/json"}, data=prompt.model_dump_json(), timeout=100, @@ -54,7 +53,7 @@ def get_prompt( } req = self._session.get( - f"{self.tracking_url}/api/tracking/prompt", + f"{self.tracking_url}/api/tracking/get/prompt", headers={"accept": "application/json", "Content-Type": "application/json"}, timeout=100, data=json.dumps(data), diff --git a/libs/tracker/llmstudio_tracker/prompt_manager/models.py b/libs/tracker/llmstudio_tracker/prompt_manager/models.py index 0a2b4900..2d96e35b 100644 --- a/libs/tracker/llmstudio_tracker/prompt_manager/models.py +++ b/libs/tracker/llmstudio_tracker/prompt_manager/models.py @@ -12,8 +12,10 @@ Integer, String, UniqueConstraint, + event, func, ) +from sqlalchemy.orm import Session class PromptDefault(Base): @@ -33,11 +35,11 @@ class PromptDefault(Base): config = Column(JSON, nullable=True) prompt = Column(String) - is_active = Column(Boolean, default=True) - name = Column(String) - model = Column(String) - provider = Column(String) - version = Column(Integer) + is_active = Column(Boolean, default=False) + name = Column(String, nullable=False) + model = Column(String, nullable=False) + provider = Column(String, nullable=False) + version = Column(Integer, nullable=False) label = Column(String) updated_at = Column( DateTime(timezone=True), @@ -49,7 +51,9 @@ class PromptDefault(Base): ) __table_args__ = ( - UniqueConstraint("name", "provider", "model", "version", name="uq_name_label"), + UniqueConstraint( + "name", "provider", "model", "version", name="uq_prompt_version" + ), ) @staticmethod @@ -75,10 +79,24 @@ def create_with_incremental_version(cls, session, **kwargs): if not all([name, model, provider]): raise ValueError("name, model, and provider must be provided") - # Determine the next version kwargs["version"] = cls.get_next_version(session, name, model, provider) - # Create and add the new instance instance = cls(**kwargs) session.add(instance) return instance + + @event.listens_for(Session, "before_flush") + def ensure_single_active_prompt(session, flush_context, instances): + """ + Ensures only one PromptDefault entry per (name, model, provider) can have is_active=True. + If a new entry is set as is_active=True, deactivate others in the same group. + """ + for instance in session.new.union(session.dirty): + if isinstance(instance, PromptDefault) and instance.is_active: + session.query(PromptDefault).filter( + PromptDefault.name == instance.name, + PromptDefault.model == instance.model, + PromptDefault.provider == instance.provider, + PromptDefault.is_active == True, + PromptDefault.prompt_id != instance.prompt_id, + ).update({"is_active": False}, synchronize_session="fetch")