diff --git a/echo/server/dembrane/chat_utils.py b/echo/server/dembrane/chat_utils.py index 0c73bd8b..f317eb0a 100644 --- a/echo/server/dembrane/chat_utils.py +++ b/echo/server/dembrane/chat_utils.py @@ -141,7 +141,7 @@ async def create_system_messages_for_chat( if conversation.created_at else None, "duration": conversation.duration, - "transcript": get_conversation_transcript( + "transcript": await get_conversation_transcript( conversation.id, # fake auth to get this fn call DirectusSession(user_id="none", is_admin=True), @@ -427,7 +427,7 @@ async def _process_single_batch( else: # Use transcript as fallback try: - transcript = get_conversation_transcript( + transcript = await get_conversation_transcript( conv.id, DirectusSession(user_id="none", is_admin=True), ) diff --git a/echo/server/dembrane/report_utils.py b/echo/server/dembrane/report_utils.py index 0e163368..bfe23cdb 100644 --- a/echo/server/dembrane/report_utils.py +++ b/echo/server/dembrane/report_utils.py @@ -111,7 +111,7 @@ async def get_report_content_for_project(project_id: str, language: str) -> str: if conversation["id"] not in conversation_data_dict: continue - transcript = get_conversation_transcript( + transcript = await get_conversation_transcript( conversation["id"], DirectusSession(user_id="none", is_admin=True), ) diff --git a/echo/server/tests/api/test_conversation.py b/echo/server/tests/api/test_conversation.py index 26ae1a59..a997da0b 100644 --- a/echo/server/tests/api/test_conversation.py +++ b/echo/server/tests/api/test_conversation.py @@ -1,5 +1,7 @@ import logging +import pytest + from tests.common import ( create_project, delete_project, @@ -15,7 +17,8 @@ logger = logging.getLogger("dembrane.tests.api.test_conversation") -def test_get_conversation_transcript(): +@pytest.mark.asyncio +async def test_get_conversation_transcript(): project = create_project("test", "en") conversation = create_conversation(project["id"], "test") chunks = [ @@ -23,7 +26,7 @@ def test_get_conversation_transcript(): create_conversation_chunk(conversation["id"], "check456"), ] - transcript = get_conversation_transcript( + transcript = await get_conversation_transcript( conversation["id"], auth=DirectusSession(user_id="none", is_admin=True) ) assert transcript == "check123\ncheck456" @@ -35,11 +38,12 @@ def test_get_conversation_transcript(): delete_project(project["id"]) -def test_summarize_conversation(): +@pytest.mark.asyncio +async def test_summarize_conversation(): project = create_project("test", "en") conversation = create_conversation(project["id"], "test") - response = summarize_conversation( + response = await summarize_conversation( conversation["id"], auth=DirectusSession(user_id="none", is_admin=True) ) @@ -48,7 +52,7 @@ def test_summarize_conversation(): chunk = create_conversation_chunk(conversation["id"], "Hello, how are you?") - response = summarize_conversation( + response = await summarize_conversation( conversation["id"], auth=DirectusSession(user_id="none", is_admin=True) )