-
Notifications
You must be signed in to change notification settings - Fork 41
Feature/prompt management #200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
from typing import Optional

from llmstudio_tracker.prompt_manager import models, schemas
from sqlalchemy.orm import Session
|
|
||
|
|
||
def get_prompt_by_name_model_provider(
    db: Session, name: str, model: str, provider: str
):
    """Return the newest active prompt for a (name, model, provider) group.

    Only rows with ``is_active=True`` are considered; among those, the one
    with the highest ``version`` wins.

    Args:
        db: Open SQLAlchemy session.
        name: Prompt name to match exactly.
        model: Model identifier to match exactly.
        provider: Provider identifier to match exactly.

    Returns:
        The matching ``models.PromptDefault`` row, or ``None`` if no active
        prompt exists for that combination.
    """
    return (
        db.query(models.PromptDefault)
        .filter(
            models.PromptDefault.name == name,
            models.PromptDefault.model == model,
            models.PromptDefault.provider == provider,
            # .is_(True) instead of "== True": emits proper SQL boolean
            # comparison and avoids the E712 lint violation.
            models.PromptDefault.is_active.is_(True),
        )
        .order_by(models.PromptDefault.version.desc())
        .first()
    )
|
|
||
|
|
||
def get_prompt_by_id(db: Session, prompt_id: str):
    """Fetch a single prompt row by its primary key, or None if absent."""
    query = db.query(models.PromptDefault).filter(
        models.PromptDefault.prompt_id == prompt_id
    )
    return query.first()
|
|
||
|
|
||
def get_prompt(
    db: Session,
    prompt_id: Optional[str] = None,
    name: Optional[str] = None,
    model: Optional[str] = None,
    provider: Optional[str] = None,
):
    """Fetch a prompt either by id or by (name, model, provider).

    ``prompt_id`` takes precedence when provided; otherwise the lookup
    falls back to the newest active prompt for the given
    (name, model, provider) combination.

    Args:
        db: Open SQLAlchemy session.
        prompt_id: Primary-key lookup; wins over the other criteria.
        name: Prompt name (used only when ``prompt_id`` is falsy).
        model: Model identifier (used only when ``prompt_id`` is falsy).
        provider: Provider identifier (used only when ``prompt_id`` is falsy).

    Returns:
        The matching ``models.PromptDefault`` row, or ``None``.
    """
    if prompt_id:
        return get_prompt_by_id(db, prompt_id)
    # No id given: resolve by the logical (name, model, provider) key.
    return get_prompt_by_name_model_provider(db, name, model, provider)
|
|
||
|
|
||
def add_prompt(db: Session, prompt: schemas.PromptDefault):
    """Persist a new prompt, auto-assigning the next version number.

    Versioning is delegated to
    ``models.PromptDefault.create_with_incremental_version``, which also
    adds the instance to the session — so no extra ``db.add`` is needed
    here (the original duplicate add was redundant).

    Args:
        db: Open SQLAlchemy session.
        prompt: Incoming prompt payload; ``version`` is ignored and
            computed server-side.

    Returns:
        The freshly committed and refreshed ``models.PromptDefault`` row.
    """
    prompt_created = models.PromptDefault.create_with_incremental_version(
        db,
        config=prompt.config,
        prompt=prompt.prompt,
        is_active=prompt.is_active,
        name=prompt.name,
        label=prompt.label,
        model=prompt.model,
        provider=prompt.provider,
    )
    db.commit()
    db.refresh(prompt_created)
    return prompt_created
|
|
||
|
|
||
def update_prompt(db: Session, prompt: schemas.PromptDefault):
    """Update an existing prompt row in place.

    The target row is located by ``prompt_id`` when present, otherwise by
    the newest active (name, model, provider) match.

    Args:
        db: Open SQLAlchemy session.
        prompt: Payload whose fields overwrite the stored row.

    Returns:
        The committed, refreshed ``models.PromptDefault`` row.

    Raises:
        ValueError: If no matching prompt exists (previously this crashed
            with an opaque AttributeError on ``None``).
    """
    if prompt.prompt_id:
        existing_prompt = get_prompt_by_id(db, prompt.prompt_id)
    else:
        existing_prompt = get_prompt_by_name_model_provider(
            db, prompt.name, prompt.model, prompt.provider
        )

    # Guard against a missing row instead of failing on attribute access.
    if existing_prompt is None:
        raise ValueError("No matching prompt found to update.")

    existing_prompt.config = prompt.config
    existing_prompt.prompt = prompt.prompt
    existing_prompt.is_active = prompt.is_active
    existing_prompt.name = prompt.name
    existing_prompt.model = prompt.model
    existing_prompt.provider = prompt.provider
    existing_prompt.version = prompt.version
    existing_prompt.label = prompt.label

    db.commit()
    db.refresh(existing_prompt)
    return existing_prompt
|
|
||
|
|
||
def delete_prompt(db: Session, prompt: schemas.PromptDefault):
    """Delete a prompt row.

    The target row is located by ``prompt_id`` when present, otherwise by
    the newest active (name, model, provider) match.

    Args:
        db: Open SQLAlchemy session.
        prompt: Payload identifying the row to remove.

    Raises:
        ValueError: If no matching prompt exists (previously
            ``db.delete(None)`` raised an unhelpful error).
    """
    if prompt.prompt_id:
        existing_prompt = get_prompt_by_id(db, prompt.prompt_id)
    else:
        existing_prompt = get_prompt_by_name_model_provider(
            db, prompt.name, prompt.model, prompt.provider
        )

    # Guard against a missing row instead of passing None to db.delete.
    if existing_prompt is None:
        raise ValueError("No matching prompt found to delete.")

    db.delete(existing_prompt)
    db.commit()
56 changes: 56 additions & 0 deletions
56
libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| from fastapi import APIRouter, Depends | ||
| from llmstudio_tracker.database import engine, get_db | ||
| from llmstudio_tracker.prompt_manager import crud, models, schemas | ||
| from sqlalchemy.orm import Session | ||
|
|
||
# Create the prompt tables at import time if they do not already exist.
models.Base.metadata.create_all(bind=engine)
|
|
||
|
|
||
class PromptsRoutes:
    """Registers the prompt-management CRUD endpoints on a FastAPI router.

    Routes delegate straight to the ``crud`` layer; request/response
    shapes are the pydantic schemas in ``schemas``.
    """

    def __init__(self, router: APIRouter):
        self.router = router
        self.define_routes()

    def define_routes(self):
        """Attach each handler to its path/verb on the shared router."""
        register_add = self.router.post(
            "/add/prompt", response_model=schemas.PromptDefault
        )
        register_add(self.add_prompt)

        register_get = self.router.get(
            "/get/prompt", response_model=schemas.PromptDefault
        )
        register_get(self.get_prompt)

        register_update = self.router.patch(
            "/update/prompt", response_model=schemas.PromptDefault
        )
        register_update(self.update_prompt)

        register_delete = self.router.delete("/delete/prompt")
        register_delete(self.delete_prompt)

    async def add_prompt(
        self, prompt: schemas.PromptDefault, db: Session = Depends(get_db)
    ):
        """Create a new prompt version."""
        return crud.add_prompt(db=db, prompt=prompt)

    async def update_prompt(
        self, prompt: schemas.PromptDefault, db: Session = Depends(get_db)
    ):
        """Overwrite an existing prompt's fields."""
        return crud.update_prompt(db, prompt)

    async def get_prompt(
        self,
        prompt_info: schemas.PromptInfo,
        db: Session = Depends(get_db),
    ):
        """Look up a prompt by id or by (name, model, provider)."""
        return crud.get_prompt(
            db,
            prompt_id=prompt_info.prompt_id,
            name=prompt_info.name,
            model=prompt_info.model,
            provider=prompt_info.provider,
        )

    async def delete_prompt(
        self, prompt: schemas.PromptDefault, db: Session = Depends(get_db)
    ):
        """Remove a prompt row."""
        return crud.delete_prompt(db, prompt)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| import json | ||
|
|
||
| import requests | ||
| from llmstudio_tracker.prompt_manager.schemas import PromptDefault | ||
| from llmstudio_tracker.tracker import TrackingConfig | ||
|
|
||
|
|
||
class PromptManager:
    """HTTP client for the tracker service's prompt-management endpoints.

    Thin wrapper over ``requests.Session``; every method returns the raw
    ``requests.Response`` for the caller to inspect.
    """

    # All endpoints speak JSON; shared constants remove the header/timeout
    # duplication the original had in every method.
    _HEADERS = {"accept": "application/json", "Content-Type": "application/json"}
    _TIMEOUT = 100  # seconds, matching the original per-request value

    def __init__(self, tracking_config: TrackingConfig):
        self.tracking_url = tracking_config.url
        self._session = requests.Session()

    def add_prompt(self, prompt: PromptDefault):
        """POST a new prompt definition."""
        return self._session.post(
            f"{self.tracking_url}/api/tracking/add/prompt",
            headers=self._HEADERS,
            data=prompt.model_dump_json(),
            timeout=self._TIMEOUT,
        )

    def delete_prompt(self, prompt: PromptDefault):
        """DELETE an existing prompt definition."""
        return self._session.delete(
            f"{self.tracking_url}/api/tracking/delete/prompt",
            headers=self._HEADERS,
            data=prompt.model_dump_json(),
            timeout=self._TIMEOUT,
        )

    def update_prompt(self, prompt: PromptDefault):
        """PATCH an existing prompt definition."""
        return self._session.patch(
            f"{self.tracking_url}/api/tracking/update/prompt",
            headers=self._HEADERS,
            data=prompt.model_dump_json(),
            timeout=self._TIMEOUT,
        )

    def get_prompt(
        self,
        prompt_id: str = None,
        name: str = None,
        model: str = None,
        provider: str = None,
    ):
        """GET a prompt by id or by (name, model, provider)."""
        data = {
            "prompt_id": prompt_id,
            "name": name,
            "model": model,
            "provider": provider,
        }

        # NOTE(review): sending a JSON body with GET is non-standard HTTP;
        # some proxies/servers drop GET bodies. Consider moving these to
        # query parameters (requires a matching server-side change).
        return self._session.get(
            f"{self.tracking_url}/api/tracking/get/prompt",
            headers=self._HEADERS,
            timeout=self._TIMEOUT,
            data=json.dumps(data),
        )
102 changes: 102 additions & 0 deletions
102
libs/tracker/llmstudio_tracker/prompt_manager/models.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| import uuid | ||
| from datetime import datetime, timezone | ||
|
|
||
| from llmstudio_tracker.config import DB_TYPE | ||
| from llmstudio_tracker.database import Base | ||
| from llmstudio_tracker.db_utils import JSONEncodedDict | ||
| from sqlalchemy import ( | ||
| JSON, | ||
| Boolean, | ||
| Column, | ||
| DateTime, | ||
| Integer, | ||
| String, | ||
| UniqueConstraint, | ||
| event, | ||
| func, | ||
| ) | ||
| from sqlalchemy.orm import Session | ||
|
|
||
|
|
||
class PromptDefault(Base):
    """ORM model for versioned prompt definitions.

    A (name, provider, model) group may hold many versions; the unique
    constraint guarantees one row per version number within a group, and
    the module-level ``before_flush`` listener keeps at most one row per
    group flagged active.
    """

    __tablename__ = "prompts"

    # BigQuery needs the JSON-encoding TEXT wrapper for config; other
    # backends use the native JSON column type.
    if DB_TYPE == "bigquery":
        prompt_id = Column(
            String,
            primary_key=True,
            default=lambda: str(uuid.uuid4()),
        )
        config = Column(JSONEncodedDict, nullable=True)
    else:
        prompt_id = Column(
            String, primary_key=True, default=lambda: str(uuid.uuid4())
        )  # Generate UUID as a string
        config = Column(JSON, nullable=True)

    prompt = Column(String)  # the prompt text/template itself
    is_active = Column(Boolean, default=False)  # at most one True per group
    name = Column(String, nullable=False)
    model = Column(String, nullable=False)
    provider = Column(String, nullable=False)
    version = Column(Integer, nullable=False)  # assigned by get_next_version
    label = Column(String)
    # Timestamps are timezone-aware UTC, set/refreshed Python-side.
    updated_at = Column(
        DateTime(timezone=True),
        onupdate=lambda: datetime.now(timezone.utc),
        default=lambda: datetime.now(timezone.utc),
    )
    created_at = Column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )

    # One row per (name, provider, model, version).
    __table_args__ = (
        UniqueConstraint(
            "name", "provider", "model", "version", name="uq_prompt_version"
        ),
    )

    @staticmethod
    def get_next_version(session, name, model, provider):
        """
        Get the next version number for a combination of name, model, and provider.

        NOTE(review): MAX(version)+1 is not concurrency-safe; two parallel
        inserts can compute the same version and one will then violate
        uq_prompt_version — confirm whether concurrent writes are expected.
        """
        max_version = (
            session.query(func.max(PromptDefault.version))
            .filter_by(name=name, model=model, provider=provider)
            .scalar()
        )
        # No prior versions -> MAX() is NULL/None -> start at 1.
        return (max_version or 0) + 1

    @classmethod
    def create_with_incremental_version(cls, session, **kwargs):
        """
        Create a new PromptDefault entry with an incremental version.

        Requires ``name``, ``model``, and ``provider`` in kwargs; computes
        ``version`` itself and adds the instance to the session (caller is
        responsible for committing).
        """
        name = kwargs.get("name")
        model = kwargs.get("model")
        provider = kwargs.get("provider")
        if not all([name, model, provider]):
            raise ValueError("name, model, and provider must be provided")

        kwargs["version"] = cls.get_next_version(session, name, model, provider)

        instance = cls(**kwargs)
        session.add(instance)
        return instance
|
|
||
@event.listens_for(Session, "before_flush")
def ensure_single_active_prompt(session, flush_context, instances):
    """
    Ensures only one PromptDefault entry per (name, model, provider) can have is_active=True.
    If a new entry is set as is_active=True, deactivate others in the same group.
    """
    for instance in session.new.union(session.dirty):
        if isinstance(instance, PromptDefault) and instance.is_active:
            deactivate = session.query(PromptDefault).filter(
                PromptDefault.name == instance.name,
                PromptDefault.model == instance.model,
                PromptDefault.provider == instance.provider,
                PromptDefault.is_active.is_(True),
            )
            # For brand-new rows prompt_id may still be None in before_flush
            # (the column default fires at INSERT time, after this hook).
            # "prompt_id != NULL" evaluates to NULL in SQL, which made the
            # original filter match nothing and silently skip deactivation
            # for new prompts. Only exclude self when an id already exists;
            # a new, unpersisted row cannot match itself anyway.
            if instance.prompt_id is not None:
                deactivate = deactivate.filter(
                    PromptDefault.prompt_id != instance.prompt_id
                )
            deactivate.update({"is_active": False}, synchronize_session="fetch")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| from datetime import datetime | ||
| from typing import Dict, Optional | ||
|
|
||
| from pydantic import BaseModel | ||
|
|
||
|
|
||
class PromptInfo(BaseModel):
    """Lookup criteria for fetching a prompt.

    Either ``prompt_id`` alone, or the (name, model, provider) triple.
    """

    prompt_id: Optional[str] = None
    name: Optional[str] = None
    model: Optional[str] = None
    provider: Optional[str] = None
|
|
||
|
|
||
class PromptDefault(BaseModel):
    """API schema mirroring a row of the ``prompts`` ORM table."""

    prompt_id: Optional[str] = None  # server-generated UUID string
    config: Optional[Dict] = {}  # arbitrary per-prompt configuration
    prompt: str  # prompt text/template (required)
    is_active: Optional[bool] = None  # at most one active per group server-side
    name: str
    version: Optional[int] = None  # assigned incrementally server-side
    label: Optional[str] = "production"
    model: str
    provider: str
    updated_at: Optional[datetime] = None  # set server-side
    created_at: Optional[datetime] = None  # set server-side
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.