Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions alignai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from alignai.ingestion.v1alpha.event_pb2 import Event, EventProperties
from alignai.logger import get_logger
from alignai.utils import datetime_to_timestamp
from alignai.utils import CustomProperties, datetime_to_timestamp, serialize_custom_properties
from alignai.worker import Worker


Expand Down Expand Up @@ -61,20 +61,30 @@ def __init__(
self.worker.setup(self.buffer_storage)
self.worker.start()

def open_session(self, session_id: str, user_id: str, assistant_id: str = DEFAULT_ASSISTANT_ID) -> None:
def open_session(
self,
session_id: str,
user_id: str,
assistant_id: str = DEFAULT_ASSISTANT_ID,
custom_properties: CustomProperties | None = None,
) -> None:
"""Record the initiation of a session.

Args:
session_id (str): Session ID.
user_id (str): User ID associated with the session.
assistant_id (str, optional): Assistant ID. Defaults to "DEFAULT".
custom_properties (dict[str, str] | None, optional): Custom properties associated with the session. Defaults to None.
"""
session_properties_args = {"session_id": session_id, "user_id": user_id, "assistant_id": assistant_id}
open_session_event = Event(
id=uuid.uuid4().hex,
type=EventTypes.SESSION_OPEN,
create_time=datetime_to_timestamp(pendulum.now()),
properties=EventProperties(session_properties=EventProperties.SessionProperties(**session_properties_args)),
properties=EventProperties(
session_properties=EventProperties.SessionProperties(**session_properties_args),
custom_properties=serialize_custom_properties(custom_properties) if custom_properties else None,
),
project_id=self.project_id,
)
self._collect(open_session_event)
Expand Down Expand Up @@ -102,6 +112,7 @@ def identify_user(
ip: str | None = None,
country_code: str | None = None,
create_time: datetime | None = None,
custom_properties: CustomProperties | None = None,
) -> None:
"""Record a user.

Expand All @@ -112,6 +123,7 @@ def identify_user(
ip (str | None, optional): User IPv4 address. Provide either ip or country code for user location. If both are given, country code overrides ip. Defaults to None.
country_code (str | None, optional): User country code in ISO Alpha-2. Provide either ip or country code for user location. If both are given, country code overrides ip. Defaults to None.
create_time (datetime | None, optional): User creation time. Defaults to None.
custom_properties (dict[str, str] | None, optional): Custom properties associated with the user. Defaults to None.
""" # noqa: E501
user_properties_args = {"user_id": user_id}
if display_name is not None:
Expand All @@ -129,19 +141,30 @@ def identify_user(
id=uuid.uuid4().hex,
type=EventTypes.USER_RECOGNIZE,
create_time=datetime_to_timestamp(pendulum.now()),
properties=EventProperties(user_properties=EventProperties.UserProperties(**user_properties_args)),
properties=EventProperties(
user_properties=EventProperties.UserProperties(**user_properties_args),
custom_properties=serialize_custom_properties(custom_properties) if custom_properties else None,
),
project_id=self.project_id,
)
self._collect(identify_user_event)

def create_message(self, session_id: str, message_index: int, role: str, content: str) -> None:
def create_message(
self,
session_id: str,
message_index: int,
role: str,
content: str,
custom_properties: CustomProperties | None = None,
) -> None:
"""Record an individual message within a session.

Args:
session_id (str): Session ID associated with the message.
message_index (int): Message index used to sort messages in a chronological order within a session. Must be a positive integer.
role (str): alignai.constants.ROLE_USER or alignai.constants.ROLE_ASSISTANT.
content (str): Content of the message.
custom_properties (dict[str, str] | None, optional): Custom properties associated with the message. Defaults to None.
""" # noqa: E501
if message_index <= 0:
self.logger.error(f"Invalid message index '{message_index}': Message index must be a positive integer")
Expand All @@ -158,11 +181,14 @@ def create_message(self, session_id: str, message_index: int, role: str, content
message_properties=EventProperties.MessageProperties(
session_id=session_id,
message_index_hint=message_index,
message_role=EventProperties.MessageProperties.Role.ROLE_ASSISTANT
if role == ROLE_ASSISTANT
else EventProperties.MessageProperties.Role.ROLE_USER,
message_role=(
EventProperties.MessageProperties.Role.ROLE_ASSISTANT
if role == ROLE_ASSISTANT
else EventProperties.MessageProperties.Role.ROLE_USER
),
message_content=content,
)
),
custom_properties=serialize_custom_properties(custom_properties) if custom_properties else None,
),
project_id=self.project_id,
)
Expand Down
Loading