echo/server/dembrane/reply_utils.py: 49 changes (30 additions, 19 deletions)
@@ -1,18 +1,25 @@
+import re
 from typing import AsyncGenerator
 from logging import getLogger
 
-import litellm
+from litellm import acompletion
 from pydantic import BaseModel
 
+from litellm.utils import token_counter
+
+from dembrane.config import (
+    MEDIUM_LITELLM_MODEL,
+    MEDIUM_LITELLM_API_KEY,
+    MEDIUM_LITELLM_API_BASE,
+    MEDIUM_LITELLM_API_VERSION,
+)
 from dembrane.prompts import render_prompt
 from dembrane.directus import directus
-from dembrane.anthropic import count_tokens_anthropic
 
 logger = getLogger("reply_utils")
 
 # Constants for token limits and conversation sizing
-GET_REPLY_TOKEN_LIMIT = 40000
-GET_REPLY_TARGET_TOKENS_PER_CONV = 2000
+GET_REPLY_TOKEN_LIMIT = 80000
+GET_REPLY_TARGET_TOKENS_PER_CONV = 4000
 GET_REPLY_TAG_BUFFER_MAX_SIZE = 100
 GET_REPLY_TAG_BUFFER_TRIM_SIZE = 20
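Reviewer note: swapping count_tokens_anthropic for litellm's token_counter keeps the budgeting logic model-agnostic, since token_counter selects a tokenizer from the model string (with a tiktoken fallback). The doubled limits (80k total, 4k per conversation) presumably track the larger context window of the configured model. A minimal sketch of how these pieces combine; fits_budget is an illustrative helper, not part of this diff, and the model string is a placeholder:

from litellm.utils import token_counter

MEDIUM_LITELLM_MODEL = "azure/gpt-4o"  # placeholder; the PR reads this from dembrane.config
GET_REPLY_TOKEN_LIMIT = 80000

def fits_budget(formatted_conv: str, used_tokens: int) -> bool:
    # Count tokens with the tokenizer matching the configured model
    tokens = token_counter(text=formatted_conv, model=MEDIUM_LITELLM_MODEL)
    return used_tokens + tokens <= GET_REPLY_TOKEN_LIMIT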
@@ -219,7 +226,7 @@ async def generate_reply_for_conversation(
 
             # Check tokens for this conversation
             formatted_conv = format_conversation(c)
-            tokens = count_tokens_anthropic(formatted_conv)
+            tokens = token_counter(text=formatted_conv, model=MEDIUM_LITELLM_MODEL)
 
             candidate_conversations.append((formatted_conv, tokens))
         else:
@@ -239,7 +246,7 @@
 
         # First check tokens for this conversation
         formatted_conv = format_conversation(c)
-        tokens = count_tokens_anthropic(formatted_conv)
+        tokens = token_counter(text=formatted_conv, model=MEDIUM_LITELLM_MODEL)
 
         # If conversation is too large, truncate it
         if tokens > target_tokens_per_conv:
@@ -248,7 +255,7 @@
             truncated_transcript = c.transcript[: int(len(c.transcript) * truncation_ratio)]
             c.transcript = truncated_transcript + "\n[Truncated for brevity...]"
             formatted_conv = format_conversation(c)
-            tokens = count_tokens_anthropic(formatted_conv)
+            tokens = token_counter(text=formatted_conv, model=MEDIUM_LITELLM_MODEL)
 
         candidate_conversations.append((formatted_conv, tokens))
 
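The truncation branch cuts the transcript by characters and then re-measures in tokens. truncation_ratio itself is defined in the collapsed lines above; a plausible reading, assuming it is the ratio of the per-conversation budget to the measured count, would be:

from litellm.utils import token_counter

def truncate_to_budget(transcript: str, tokens: int, target_tokens: int, model: str) -> str:
    # Assumption: ratio of budget to measured tokens. Characters track tokens
    # roughly linearly, so a proportional character cut lands near the target.
    truncation_ratio = target_tokens / tokens
    truncated = transcript[: int(len(transcript) * truncation_ratio)]
    return truncated + "\n[Truncated for brevity...]"

This is only a sketch: the committed definition may clamp the ratio or leave headroom for the truncation marker.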
@@ -331,11 +338,13 @@ async def generate_reply_for_conversation(
     in_response_section = False
 
     # Stream the response
-    response = await litellm.acompletion(
-        model="anthropic/claude-3-5-sonnet-20240620",
+    response = await acompletion(
+        model=MEDIUM_LITELLM_MODEL,
+        api_key=MEDIUM_LITELLM_API_KEY,
+        api_version=MEDIUM_LITELLM_API_VERSION,
+        api_base=MEDIUM_LITELLM_API_BASE,
         messages=[
             {"role": "user", "content": prompt},
-            {"role": "assistant", "content": ""},
         ],
         stream=True,
     )
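Dropping the empty assistant turn is the right call here: prefilling the assistant message is an Anthropic-specific idiom, and it serves no purpose once the model is provider-configurable through litellm. Downstream, consuming the stream follows the usual litellm/OpenAI delta pattern, roughly:

accumulated_response = ""
async for chunk in response:
    delta = chunk.choices[0].delta
    if delta.content:
        accumulated_response += delta.content
        # ... the surrounding code feeds tag_buffer and the
        # in_response_section state machine from these deltas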
@@ -403,14 +412,16 @@ async def generate_reply_for_conversation(
             tag_buffer = tag_buffer[-GET_REPLY_TAG_BUFFER_TRIM_SIZE:]
 
     try:
-        response_content = ""
-        if "<response>" in accumulated_response and "</response>" in accumulated_response:
-            start_idx = accumulated_response.find("<response>") + len("<response>")
-            end_idx = accumulated_response.find("</response>")
-            if start_idx < end_idx:
-                response_content = accumulated_response[start_idx:end_idx].strip()
-        else:
-            response_content = accumulated_response.strip()
+        response_content = accumulated_response
+
+        # Remove everything between <detailed_analysis> and </detailed_analysis> tags
+        response_content = re.sub(r'<detailed_analysis>.*?</detailed_analysis>', '', response_content, flags=re.DOTALL)
+
+        # Replace <response> and </response> tags with empty strings
+        response_content = response_content.replace("<response>", "").replace("</response>", "")
+
+        # Strip whitespace
+        response_content = response_content.strip()
 
         directus.create_item(
             "conversation_reply",
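The rewritten parsing is also more forgiving than the old index arithmetic: if the stream ended before </response> arrived, the old code fell through to saving the raw accumulated text, tags and analysis included, whereas the new path always strips <detailed_analysis> blocks and tag markers before persisting. A quick check of the new behavior on a hypothetical response:

import re

raw = (
    "<detailed_analysis>private reasoning...</detailed_analysis>\n"
    "<response> Thanks for sharing! </response>"
)
cleaned = re.sub(r"<detailed_analysis>.*?</detailed_analysis>", "", raw, flags=re.DOTALL)
cleaned = cleaned.replace("<response>", "").replace("</response>", "").strip()
assert cleaned == "Thanks for sharing!"

The non-greedy DOTALL pattern removes each analysis block even when it spans multiple lines, so the saved reply is clean regardless of where the stream was cut off.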