Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import copy
from datetime import datetime
from datetime import timezone
import json
import logging
from typing import Any
Expand Down Expand Up @@ -144,6 +145,21 @@ class StorageSession(Base):
def __repr__(self):
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"

@property
def _dialect_name(self) -> Optional[str]:
session = inspect(self).session
return session.bind.dialect.name if session else None

@property
def update_timestamp_tz(self) -> datetime:
"""Returns the time zone aware update timestamp."""
if self._dialect_name == "sqlite":
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
# object without timezone information. We need to convert it to UTC
# manually.
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
return self.update_time.timestamp()


class StorageEvent(Base):
"""Represents an event stored in the database."""
Expand Down Expand Up @@ -412,7 +428,7 @@ async def create_session(
user_id=str(storage_session.user_id),
id=str(storage_session.id),
state=merged_state,
last_update_time=storage_session.update_time.timestamp(),
last_update_time=storage_session.update_timestamp_tz,
)
return session

Expand Down Expand Up @@ -473,7 +489,7 @@ async def get_session(
user_id=user_id,
id=session_id,
state=merged_state,
last_update_time=storage_session.update_time.timestamp(),
last_update_time=storage_session.update_timestamp_tz,
)
session.events = [e.to_event() for e in reversed(storage_events)]
return session
Expand All @@ -496,7 +512,7 @@ async def list_sessions(
user_id=user_id,
id=storage_session.id,
state={},
last_update_time=storage_session.update_time.timestamp(),
last_update_time=storage_session.update_timestamp_tz,
)
sessions.append(session)
return ListSessionsResponse(sessions=sessions)
Expand Down Expand Up @@ -529,13 +545,13 @@ async def append_event(self, session: Session, event: Event) -> Event:
StorageSession, (session.app_name, session.user_id, session.id)
)

if storage_session.update_time.timestamp() > session.last_update_time:
if storage_session.update_timestamp_tz > session.last_update_time:
raise ValueError(
"The last_update_time provided in the session object"
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
" earlier than the update_time in the storage_session"
f" {storage_session.update_time:'%Y-%m-%d %H:%M:%S'}. Please check"
" if it is a stale session."
f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}."
" Please check if it is a stale session."
)

# Fetch states from storage
Expand Down Expand Up @@ -577,7 +593,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
session_factory.refresh(storage_session)

# Update timestamp with commit time
session.last_update_time = storage_session.update_time.timestamp()
session.last_update_time = storage_session.update_timestamp_tz

# Also update the in-memory session
await super().append_event(session=session, event=event)
Expand Down
17 changes: 13 additions & 4 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from datetime import timezone
import enum

from google.adk.events import Event
Expand Down Expand Up @@ -66,10 +68,17 @@ async def test_create_get_session(service_type):
assert session.id
assert session.state == state
assert (
await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
session.last_update_time
<= datetime.now().astimezone(timezone.utc).timestamp()
)

got_session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert got_session == session
assert (
got_session.last_update_time
<= datetime.now().astimezone(timezone.utc).timestamp()
)

session_id = session.id
Expand Down