diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 2ccd60083d..bf83b0409b 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -15,6 +15,7 @@ import copy from datetime import datetime +from datetime import timezone import json import logging from typing import Any @@ -144,6 +145,21 @@ class StorageSession(Base): def __repr__(self): return f"" + @property + def _dialect_name(self) -> Optional[str]: + session = inspect(self).session + return session.bind.dialect.name if session else None + + @property + def update_timestamp_tz(self) -> datetime: + """Returns the time zone aware update timestamp.""" + if self._dialect_name == "sqlite": + # SQLite does not support timezone. SQLAlchemy returns a naive datetime + # object without timezone information. We need to convert it to UTC + # manually. + return self.update_time.replace(tzinfo=timezone.utc).timestamp() + return self.update_time.timestamp() + class StorageEvent(Base): """Represents an event stored in the database.""" @@ -412,7 +428,7 @@ async def create_session( user_id=str(storage_session.user_id), id=str(storage_session.id), state=merged_state, - last_update_time=storage_session.update_time.timestamp(), + last_update_time=storage_session.update_timestamp_tz, ) return session @@ -473,7 +489,7 @@ async def get_session( user_id=user_id, id=session_id, state=merged_state, - last_update_time=storage_session.update_time.timestamp(), + last_update_time=storage_session.update_timestamp_tz, ) session.events = [e.to_event() for e in reversed(storage_events)] return session @@ -496,7 +512,7 @@ async def list_sessions( user_id=user_id, id=storage_session.id, state={}, - last_update_time=storage_session.update_time.timestamp(), + last_update_time=storage_session.update_timestamp_tz, ) sessions.append(session) return ListSessionsResponse(sessions=sessions) @@ -529,13 +545,13 @@ async def append_event(self, session: Session, event: Event) -> Event: StorageSession, (session.app_name, session.user_id, session.id) ) - if storage_session.update_time.timestamp() > session.last_update_time: + if storage_session.update_timestamp_tz > session.last_update_time: raise ValueError( "The last_update_time provided in the session object" f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is" " earlier than the update_time in the storage_session" - f" {storage_session.update_time:'%Y-%m-%d %H:%M:%S'}. Please check" - " if it is a stale session." + f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}." + " Please check if it is a stale session." ) # Fetch states from storage @@ -577,7 +593,7 @@ async def append_event(self, session: Session, event: Event) -> Event: session_factory.refresh(storage_session) # Update timestamp with commit time - session.last_update_time = storage_session.update_time.timestamp() + session.last_update_time = storage_session.update_timestamp_tz # Also update the in-memory session await super().append_event(session=session, event=event) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index ec93caafbb..d8344194f4 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from datetime import datetime +from datetime import timezone import enum from google.adk.events import Event @@ -66,10 +68,17 @@ async def test_create_get_session(service_type): assert session.id assert session.state == state assert ( - await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - == session + session.last_update_time + <= datetime.now().astimezone(timezone.utc).timestamp() + ) + + got_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert got_session == session + assert ( + got_session.last_update_time + <= datetime.now().astimezone(timezone.utc).timestamp() ) session_id = session.id