diff --git a/src/bub/framework.py b/src/bub/framework.py index 85cbb6ed..005f2221 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -11,6 +11,7 @@ from dotenv import load_dotenv from loguru import logger from republic import AsyncTapeStore, RepublicError, TapeContext +from republic.core.errors import ErrorKind from republic.tape import TapeStore from bub.envelope import content_of, field_of, unpack_batch @@ -146,8 +147,13 @@ async def _run_model( if event.kind == "text": parts.append(str(event.data.get("delta", ""))) elif event.kind == "error": + # Turn "kind" to enum type otherwise the RepublicError's __str__ won't work well + data = { + **event.data, + "kind": ErrorKind(event.data.get("kind", "unknown")), + } await self._hook_runtime.notify_error( - stage="run_model", error=RepublicError(**event.data), message=inbound + stage="run_model", error=RepublicError(**data), message=inbound ) return "".join(parts)