Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
91 changes: 91 additions & 0 deletions libs/tracker/llmstudio_tracker/prompt_manager/crud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from llmstudio_tracker.prompt_manager import models, schemas
from sqlalchemy.orm import Session


def get_prompt_by_name_model_provider(
    db: Session, name: str, model: str, provider: str
):
    """Return the newest active prompt for (name, model, provider).

    Filters to rows flagged ``is_active`` and orders by ``version``
    descending, so the highest active version wins.

    Args:
        db: Open SQLAlchemy session.
        name: Prompt name to match.
        model: Model identifier to match.
        provider: Provider identifier to match.

    Returns:
        The matching ``models.PromptDefault`` row, or ``None`` if no
        active prompt exists for the combination.
    """
    return (
        db.query(models.PromptDefault)
        .filter(
            models.PromptDefault.name == name,
            models.PromptDefault.model == model,
            models.PromptDefault.provider == provider,
            # ``.is_(True)`` emits the same SQL as ``== True`` without
            # tripping the E712 lint.
            models.PromptDefault.is_active.is_(True),
        )
        .order_by(models.PromptDefault.version.desc())
        .first()
    )


def get_prompt_by_id(db: Session, prompt_id: str):
    """Return the prompt row with the given primary key.

    Args:
        db: Open SQLAlchemy session.
        prompt_id: UUID string primary key of the prompt.

    Returns:
        The matching ``models.PromptDefault`` row, or ``None`` if absent.
    """
    return (
        db.query(models.PromptDefault)
        .filter(models.PromptDefault.prompt_id == prompt_id)
        .first()
    )


def get_prompt(
    db: Session,
    prompt_id: str = None,
    name: str = None,
    model: str = None,
    provider: str = None,
):
    """Fetch a prompt either by id or by (name, model, provider).

    A truthy ``prompt_id`` takes precedence; otherwise the lookup falls
    back to the newest active (name, model, provider) match.
    """
    if not prompt_id:
        return get_prompt_by_name_model_provider(db, name, model, provider)
    return get_prompt_by_id(db, prompt_id)


def add_prompt(db: Session, prompt: schemas.PromptDefault):
    """Persist a new prompt row, letting the model assign the next version.

    Delegates version numbering to
    ``models.PromptDefault.create_with_incremental_version`` and returns
    the freshly committed, refreshed row.
    """
    fields = {
        "config": prompt.config,
        "prompt": prompt.prompt,
        "is_active": prompt.is_active,
        "name": prompt.name,
        "label": prompt.label,
        "model": prompt.model,
        "provider": prompt.provider,
    }
    new_row = models.PromptDefault.create_with_incremental_version(db, **fields)
    db.add(new_row)
    db.commit()
    db.refresh(new_row)
    return new_row


def update_prompt(db: Session, prompt: schemas.PromptDefault):
    """Update an existing prompt in place and return the refreshed row.

    The row is located by ``prompt_id`` when provided, otherwise by the
    newest active (name, model, provider) match.

    Args:
        db: Open SQLAlchemy session.
        prompt: Schema carrying the new field values.

    Returns:
        The updated ``models.PromptDefault`` row.

    Raises:
        ValueError: if no matching prompt exists. (The original code
            crashed with ``AttributeError`` on the ``None`` lookup result.)
    """
    if prompt.prompt_id:
        existing_prompt = get_prompt_by_id(db, prompt.prompt_id)
    else:
        existing_prompt = get_prompt_by_name_model_provider(
            db, prompt.name, prompt.model, prompt.provider
        )

    if existing_prompt is None:
        raise ValueError("prompt not found; cannot update")

    existing_prompt.config = prompt.config
    existing_prompt.prompt = prompt.prompt
    existing_prompt.is_active = prompt.is_active
    existing_prompt.name = prompt.name
    existing_prompt.model = prompt.model
    existing_prompt.provider = prompt.provider
    # ``version`` is NOT NULL on the model; keep the stored version when
    # the caller omitted one instead of writing NULL.
    if prompt.version is not None:
        existing_prompt.version = prompt.version
    existing_prompt.label = prompt.label

    db.commit()
    db.refresh(existing_prompt)
    return existing_prompt


def delete_prompt(db: Session, prompt: schemas.PromptDefault):
    """Delete a prompt located by id or by (name, model, provider).

    Args:
        db: Open SQLAlchemy session.
        prompt: Schema identifying the row to delete.

    Raises:
        ValueError: if no matching prompt exists. (The original code
            passed ``None`` to ``db.delete`` and failed deeper in
            SQLAlchemy.)
    """
    if prompt.prompt_id:
        existing_prompt = get_prompt_by_id(db, prompt.prompt_id)
    else:
        existing_prompt = get_prompt_by_name_model_provider(
            db, prompt.name, prompt.model, prompt.provider
        )

    if existing_prompt is None:
        raise ValueError("prompt not found; cannot delete")

    db.delete(existing_prompt)
    db.commit()
56 changes: 56 additions & 0 deletions libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from fastapi import APIRouter, Depends
from llmstudio_tracker.database import engine, get_db
from llmstudio_tracker.prompt_manager import crud, models, schemas
from sqlalchemy.orm import Session

models.Base.metadata.create_all(bind=engine)


class PromptsRoutes:
    """Registers the prompt CRUD endpoints on a FastAPI ``APIRouter``."""

    def __init__(self, router: APIRouter):
        self.router = router
        self.define_routes()

    def define_routes(self):
        """Wire the four prompt endpoints onto the router."""
        self.router.post(
            "/add/prompt",
            response_model=schemas.PromptDefault,
        )(self.add_prompt)

        # NOTE(review): this GET reads a JSON request body (PromptInfo),
        # which many HTTP clients and proxies do not support on GET —
        # consider query parameters. Left unchanged because PromptManager
        # sends the body this way.
        self.router.get("/get/prompt", response_model=schemas.PromptDefault)(
            self.get_prompt
        )

        self.router.patch("/update/prompt", response_model=schemas.PromptDefault)(
            self.update_prompt
        )

        self.router.delete("/delete/prompt")(self.delete_prompt)

    async def add_prompt(
        self, prompt: schemas.PromptDefault, db: Session = Depends(get_db)
    ):
        """Create a new prompt; versioning is handled server-side."""
        return crud.add_prompt(db=db, prompt=prompt)

    async def update_prompt(
        self, prompt: schemas.PromptDefault, db: Session = Depends(get_db)
    ):
        """Update an existing prompt located by id or (name, model, provider)."""
        return crud.update_prompt(db, prompt)

    async def get_prompt(
        self,
        prompt_info: schemas.PromptInfo,
        db: Session = Depends(get_db),
    ):
        """Fetch one prompt by id, or the newest active (name, model, provider) match."""
        return crud.get_prompt(
            db,
            prompt_id=prompt_info.prompt_id,
            name=prompt_info.name,
            model=prompt_info.model,
            provider=prompt_info.provider,
        )

    async def delete_prompt(
        self, prompt: schemas.PromptDefault, db: Session = Depends(get_db)
    ):
        """Delete a prompt located by id or (name, model, provider)."""
        return crud.delete_prompt(db, prompt)
61 changes: 61 additions & 0 deletions libs/tracker/llmstudio_tracker/prompt_manager/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import json

import requests
from llmstudio_tracker.prompt_manager.schemas import PromptDefault
from llmstudio_tracker.tracker import TrackingConfig


class PromptManager:
    """Thin HTTP client for the tracker's prompt CRUD endpoints.

    Each method returns the raw ``requests.Response`` from the tracker
    service; no status checking or decoding is performed here.
    """

    def __init__(self, tracking_config: TrackingConfig):
        self.tracking_url = tracking_config.url
        self._session = requests.Session()

    def _json_headers(self):
        # Headers shared by every prompt request.
        return {"accept": "application/json", "Content-Type": "application/json"}

    def add_prompt(self, prompt: PromptDefault):
        """POST a new prompt to the tracker."""
        return self._session.post(
            f"{self.tracking_url}/api/tracking/add/prompt",
            headers=self._json_headers(),
            data=prompt.model_dump_json(),
            timeout=100,
        )

    def delete_prompt(self, prompt: PromptDefault):
        """DELETE a prompt identified by the payload's id or name/model/provider."""
        return self._session.delete(
            f"{self.tracking_url}/api/tracking/delete/prompt",
            headers=self._json_headers(),
            data=prompt.model_dump_json(),
            timeout=100,
        )

    def update_prompt(self, prompt: PromptDefault):
        """PATCH an existing prompt with the payload's field values."""
        return self._session.patch(
            f"{self.tracking_url}/api/tracking/update/prompt",
            headers=self._json_headers(),
            data=prompt.model_dump_json(),
            timeout=100,
        )

    def get_prompt(
        self,
        prompt_id: str = None,
        name: str = None,
        model: str = None,
        provider: str = None,
    ):
        """GET one prompt by id or by (name, model, provider).

        NOTE: the server reads these fields from a JSON request body on
        GET, so they are serialized into ``data`` rather than sent as
        query parameters.
        """
        payload = json.dumps(
            {
                "prompt_id": prompt_id,
                "name": name,
                "model": model,
                "provider": provider,
            }
        )
        return self._session.get(
            f"{self.tracking_url}/api/tracking/get/prompt",
            headers=self._json_headers(),
            timeout=100,
            data=payload,
        )
102 changes: 102 additions & 0 deletions libs/tracker/llmstudio_tracker/prompt_manager/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import uuid
from datetime import datetime, timezone

from llmstudio_tracker.config import DB_TYPE
from llmstudio_tracker.database import Base
from llmstudio_tracker.db_utils import JSONEncodedDict
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
Integer,
String,
UniqueConstraint,
event,
func,
)
from sqlalchemy.orm import Session


class PromptDefault(Base):
    """Versioned prompt row, unique per (name, provider, model, version).

    Only one row per (name, model, provider) group may be active at a
    time; the ``before_flush`` listener below deactivates competitors.
    """

    __tablename__ = "prompts"

    # BigQuery needs the project's JSONEncodedDict shim for the config
    # column; every other backend uses the native JSON type.
    if DB_TYPE == "bigquery":
        prompt_id = Column(
            String,
            primary_key=True,
            default=lambda: str(uuid.uuid4()),
        )
        config = Column(JSONEncodedDict, nullable=True)
    else:
        prompt_id = Column(
            String, primary_key=True, default=lambda: str(uuid.uuid4())
        )  # Generate UUID as a string
        config = Column(JSON, nullable=True)

    prompt = Column(String)
    is_active = Column(Boolean, default=False)
    name = Column(String, nullable=False)
    model = Column(String, nullable=False)
    provider = Column(String, nullable=False)
    version = Column(Integer, nullable=False)
    label = Column(String)
    # Timestamps are stored timezone-aware in UTC.
    updated_at = Column(
        DateTime(timezone=True),
        onupdate=lambda: datetime.now(timezone.utc),
        default=lambda: datetime.now(timezone.utc),
    )
    created_at = Column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )

    __table_args__ = (
        UniqueConstraint(
            "name", "provider", "model", "version", name="uq_prompt_version"
        ),
    )

    @staticmethod
    def get_next_version(session, name, model, provider):
        """
        Get the next version number for a combination of name, model, and provider.
        """
        max_version = (
            session.query(func.max(PromptDefault.version))
            .filter_by(name=name, model=model, provider=provider)
            .scalar()
        )
        # ``scalar()`` is None when no rows exist yet, so versions start at 1.
        return (max_version or 0) + 1

    @classmethod
    def create_with_incremental_version(cls, session, **kwargs):
        """
        Create a new PromptDefault entry with an incremental version.

        Raises:
            ValueError: if name, model, or provider is missing.
        """
        name = kwargs.get("name")
        model = kwargs.get("model")
        provider = kwargs.get("provider")
        if not all([name, model, provider]):
            raise ValueError("name, model, and provider must be provided")

        kwargs["version"] = cls.get_next_version(session, name, model, provider)

        instance = cls(**kwargs)
        session.add(instance)
        return instance

    @event.listens_for(Session, "before_flush")
    def ensure_single_active_prompt(session, flush_context, instances):
        """
        Ensures only one PromptDefault entry per (name, model, provider) can have is_active=True.
        If a new entry is set as is_active=True, deactivate others in the same group.
        """
        for instance in session.new.union(session.dirty):
            if isinstance(instance, PromptDefault) and instance.is_active:
                session.query(PromptDefault).filter(
                    PromptDefault.name == instance.name,
                    PromptDefault.model == instance.model,
                    PromptDefault.provider == instance.provider,
                    # ``.is_(True)`` emits the same SQL as ``== True``
                    # without the E712 lint.
                    PromptDefault.is_active.is_(True),
                    PromptDefault.prompt_id != instance.prompt_id,
                ).update({"is_active": False}, synchronize_session="fetch")
25 changes: 25 additions & 0 deletions libs/tracker/llmstudio_tracker/prompt_manager/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from datetime import datetime
from typing import Dict, Optional

from pydantic import BaseModel


class PromptInfo(BaseModel):
    """Lookup key for fetching a prompt: by id, or by name/model/provider."""

    prompt_id: Optional[str] = None  # takes precedence over the other fields
    name: Optional[str] = None
    model: Optional[str] = None
    provider: Optional[str] = None


class PromptDefault(BaseModel):
    """API schema for a stored prompt (mirrors the ``prompts`` table)."""

    prompt_id: Optional[str] = None  # server generates a UUID when omitted
    # NOTE: pydantic copies field defaults per instance, so the shared {}
    # is safe here (unlike a plain-Python mutable default argument).
    config: Optional[Dict] = {}
    prompt: str
    is_active: Optional[bool] = None
    name: str
    version: Optional[int] = None  # assigned server-side on create
    label: Optional[str] = "production"
    model: str
    provider: str
    updated_at: Optional[datetime] = None
    created_at: Optional[datetime] = None
2 changes: 2 additions & 0 deletions libs/tracker/llmstudio_tracker/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi.middleware.cors import CORSMiddleware
from llmstudio_tracker.config import TRACKING_HOST, TRACKING_PORT
from llmstudio_tracker.logs.endpoints import LogsRoutes
from llmstudio_tracker.prompt_manager.endpoints import PromptsRoutes
from llmstudio_tracker.session.endpoints import SessionsRoutes
from llmstudio_tracker.utils import get_current_version

Expand Down Expand Up @@ -42,6 +43,7 @@ def health_check():
tracking_router = APIRouter(prefix=TRACKING_BASE_ENDPOINT)
LogsRoutes(tracking_router)
SessionsRoutes(tracking_router)
PromptsRoutes(tracking_router)

app.include_router(tracking_router)

Expand Down