Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 229 additions & 24 deletions ably/realtime/connectionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import asyncio
import logging
from collections import deque
from datetime import datetime
from queue import Queue
from typing import TYPE_CHECKING

import httpx
Expand All @@ -24,6 +24,88 @@
log = logging.getLogger(__name__)


class PendingMessage:
    """A protocol message paired with the future used to report its ACK/NACK."""

    def __init__(self, message: dict):
        # Only these protocol actions are acknowledged:
        # MESSAGE, PRESENCE, ANNOTATION and OBJECT.
        acked_actions = (
            ProtocolMessageAction.MESSAGE,
            ProtocolMessageAction.PRESENCE,
            ProtocolMessageAction.ANNOTATION,
            ProtocolMessageAction.OBJECT,
        )

        self.message = message
        self.ack_required = message.get('action') in acked_actions
        # A future is created only when an acknowledgment is expected;
        # every other message carries ``None``.
        self.future: asyncio.Future | None = asyncio.Future() if self.ack_required else None


class PendingMessageQueue:
    """Ordered list of messages that are still waiting for an ACK or NACK."""

    def __init__(self):
        # Oldest message first; the head's msgSerial anchors ACK/NACK ranges.
        self.messages: list[PendingMessage] = []

    def push(self, pending_message: PendingMessage) -> None:
        """Append a message to the tail of the queue."""
        self.messages.append(pending_message)

    def count(self) -> int:
        """Return the number of messages currently awaiting acknowledgment."""
        return len(self.messages)

    def complete_messages(self, serial: int, count: int, err: AblyException | None = None) -> None:
        """Settle the messages covered by an ACK/NACK and drop them from the queue.

        The acknowledged range is ``[serial, serial + count)``. Messages are
        taken from the head of the queue, starting at the head's own
        ``msgSerial``, up to the end of that range (clamped to the queue size).

        Args:
            serial: The msgSerial of the first message being acknowledged
            count: The number of messages being acknowledged
            err: Error from NACK, or None for successful ACK
        """
        log.debug(f'MessageQueue.complete_messages(): serial={serial}, count={count}, err={err}')

        if not self.messages:
            log.warning('MessageQueue.complete_messages(): called on empty queue')
            return

        start_serial = self.messages[0].message.get('msgSerial')
        if start_serial is None:
            log.warning('MessageQueue.complete_messages(): first message has no msgSerial')
            return

        end_serial = serial + count
        if end_serial <= start_serial:
            # The whole range lies before the head of the queue: nothing to do.
            return

        num_to_complete = min(end_serial - start_serial, len(self.messages))
        completed_messages = self.messages[:num_to_complete]
        self.messages = self.messages[num_to_complete:]

        for msg in completed_messages:
            # Messages without a future, or already settled, are just dropped.
            if msg.future is None or msg.future.done():
                continue
            if err is not None:
                msg.future.set_exception(err)
            else:
                msg.future.set_result(None)

    def complete_all_messages(self, err: AblyException) -> None:
        """Empty the queue, failing every unsettled message with ``err``."""
        # Snapshot then clear in place: a single O(n) pass instead of repeated
        # pop(0), and the queue is already empty while futures are settled.
        drained = list(self.messages)
        self.messages.clear()
        for msg in drained:
            if msg.future is not None and not msg.future.done():
                msg.future.set_exception(err)

    def clear(self) -> None:
        """Drop all messages without settling their futures."""
        self.messages.clear()


class ConnectionManager(EventEmitter):
def __init__(self, realtime: AblyRealtime, initial_state):
self.options = realtime.options
Expand All @@ -41,8 +123,10 @@ def __init__(self, realtime: AblyRealtime, initial_state):
self.connect_base_task: asyncio.Task | None = None
self.disconnect_transport_task: asyncio.Task | None = None
self.__fallback_hosts: list[str] = self.options.get_fallback_realtime_hosts()
self.queued_messages: Queue = Queue()
self.queued_messages: deque[PendingMessage] = deque()
self.__error_reason: AblyException | None = None
self.msg_serial: int = 0
self.pending_message_queue: PendingMessageQueue = PendingMessageQueue()
super().__init__()

def enact_state_change(self, state: ConnectionState, reason: AblyException | None = None) -> None:
Expand Down Expand Up @@ -88,37 +172,109 @@ async def close_impl(self) -> None:
self.notify_state(ConnectionState.CLOSED)

async def send_protocol_message(self, protocol_message: dict) -> None:
    """Send a protocol message and optionally track it for acknowledgment

    Messages that require an acknowledgment are stamped with the next
    ``msgSerial`` and registered in ``pending_message_queue``; this coroutine
    then waits on the message's future before returning.

    While the connection is DISCONNECTED or CONNECTING the message is placed in
    ``queued_messages`` instead of being written to the transport.

    Args:
        protocol_message: protocol message dict (new message); mutated in
            place to carry ``msgSerial`` when an acknowledgment is required
    Returns:
        None
    Raises:
        AblyException: if the connection state is not DISCONNECTED, CONNECTING
            or CONNECTED. Whatever exception the message's future is failed
            with is also propagated to the caller.
    """
    if self.state not in (ConnectionState.DISCONNECTED, ConnectionState.CONNECTING, ConnectionState.CONNECTED):
        raise AblyException(f"ConnectionManager.send_protocol_message(): called in {self.state}", 500, 50000)

    pending_message = PendingMessage(protocol_message)

    # Assign msgSerial to messages that need acknowledgment
    if pending_message.ack_required:
        # New message - assign fresh serial
        protocol_message['msgSerial'] = self.msg_serial
        self.pending_message_queue.push(pending_message)
        self.msg_serial += 1

    if self.state in (ConnectionState.DISCONNECTED, ConnectionState.CONNECTING):
        # New messages enter on the left; send_queued_messages() pops from the
        # right, so queued messages go out oldest-first.
        self.queued_messages.appendleft(pending_message)
        if pending_message.ack_required:
            await pending_message.future
        return None

    return await self._send_protocol_message_on_connected_state(pending_message)

async def _send_protocol_message_on_connected_state(self, pending_message: PendingMessage) -> None:
    """Write a message to the active transport and wait for its acknowledgment.

    If the connection is not CONNECTED, or there is no transport, the
    message's future (when it has one) is failed with an ``AblyException``
    instead of sending.

    Args:
        pending_message: the message to send; may be a new message or one
            being resent from ``queued_messages``
    Raises:
        Whatever the message's future is failed with, when an acknowledgment
        is required; exceptions raised by ``transport.send`` propagate as-is.
    """
    if self.state == ConnectionState.CONNECTED and self.transport:
        # Add to pending queue before sending (for messages being resent from queue)
        if pending_message.ack_required and pending_message not in self.pending_message_queue.messages:
            self.pending_message_queue.push(pending_message)
        await self.transport.send(pending_message.message)
    else:
        # Not inside an exception handler, so log.error rather than
        # log.exception (which would append a meaningless "NoneType: None").
        log.error(
            "ConnectionManager.send_protocol_message(): can not send message with no active transport"
        )
        future = pending_message.future
        # Guard against a future that was already settled elsewhere;
        # set_exception on a done future raises InvalidStateError.
        if future is not None and not future.done():
            future.set_exception(
                AblyException("No active transport", 500, 50000)
            )
    if pending_message.ack_required:
        await pending_message.future
    return None

def send_queued_messages(self) -> None:
    """Schedule every queued message for sending, oldest first.

    Each message is handed to ``_send_protocol_message_on_connected_state`` in
    its own task; this method does not wait for the sends to finish.
    """
    log.info(f'ConnectionManager.send_queued_messages(): sending {len(self.queued_messages)} message(s)')

    # The event loop only keeps weak references to tasks, so hold a strong
    # reference until each one finishes; otherwise a send that is waiting for
    # its ACK could be garbage-collected mid-flight.
    in_flight = getattr(self, '_queued_send_tasks', None)
    if in_flight is None:
        in_flight = self._queued_send_tasks = set()

    while self.queued_messages:
        pending_message = self.queued_messages.pop()
        task = asyncio.create_task(self._send_protocol_message_on_connected_state(pending_message))
        in_flight.add(task)
        task.add_done_callback(in_flight.discard)

def requeue_pending_messages(self) -> None:
    """RTN19a: Requeue messages awaiting ACK/NACK when transport disconnects

    Moves everything in ``pending_message_queue`` back into
    ``queued_messages`` so that ``send_queued_messages()`` sends it again, and
    then empties ``pending_message_queue``. The ``PendingMessage`` objects are
    moved as-is, so each keeps its future and its ``msgSerial``.
    """
    pending_count = self.pending_message_queue.count()
    if pending_count == 0:
        return

    log.info(
        f'ConnectionManager.requeue_pending_messages(): '
        f'requeuing {pending_count} pending message(s) for resend'
    )

    pending_messages = list(self.pending_message_queue.messages)

    # A message submitted while DISCONNECTED/CONNECTING is registered in both
    # queues by send_protocol_message(). Appending it again would send it
    # twice and ahead of older messages, so skip anything already queued.
    already_queued = {id(queued) for queued in self.queued_messages}

    # send_queued_messages() pops from the right, so append in reverse: the
    # oldest pending message ends up right-most and is sent first, ahead of
    # messages that were queued but never sent.
    self.queued_messages.extend(
        pending_msg
        for pending_msg in reversed(pending_messages)
        if id(pending_msg) not in already_queued
    )

    # Resending re-registers each message in the pending queue, so drop the
    # current entries without settling their futures.
    self.pending_message_queue.clear()

def fail_queued_messages(self, err) -> None:
    """Discard all queued and pending messages, failing their futures.

    Args:
        err: the reason to fail the messages with; when falsy a generic
            "Connection failed" AblyException (status 500, code 80000) is used
    """
    log.info(
        f"ConnectionManager.fail_queued_messages(): discarding {len(self.queued_messages)} messages;" +
        f" reason = {err}"
    )
    # AblyException takes (message, status_code, code), as in the other calls
    # in this module: HTTP-style status first, Ably error code second.
    error = err or AblyException("Connection failed", 500, 80000)

    while self.queued_messages:
        pending_msg = self.queued_messages.pop()
        # Not inside an exception handler, so log.error rather than
        # log.exception (which would append a meaningless "NoneType: None").
        log.error(
            f"ConnectionManager.fail_queued_messages(): Failed to send protocol message: "
            f"{pending_msg.message}"
        )
        # Fail the Future if it exists and has not been settled already
        if pending_msg.future is not None and not pending_msg.future.done():
            pending_msg.future.set_exception(error)

    # Also fail all pending messages awaiting acknowledgment
    count = self.pending_message_queue.count()
    if count > 0:
        log.info(
            f"ConnectionManager.fail_queued_messages(): failing {count} pending messages"
        )
        self.pending_message_queue.complete_all_messages(error)

async def ping(self) -> float:
if self.__ping_future:
Expand Down Expand Up @@ -149,6 +305,16 @@ def on_connected(self, connection_details: ConnectionDetails, connection_id: str
reason: AblyException | None = None) -> None:
self.__fail_state = ConnectionState.DISCONNECTED

# RTN19a2: Reset msgSerial if connectionId changed (new connection)
prev_connection_id = self.connection_id
connection_id_changed = prev_connection_id is not None and prev_connection_id != connection_id

if connection_id_changed:
log.info('ConnectionManager.on_connected(): New connectionId; resetting msgSerial')
self.msg_serial = 0
# Note: In JS they call resetSendAttempted() here, but we don't need it
# because we fail all pending messages on disconnect per RTN7e

self.__connection_details = connection_details
self.connection_id = connection_id

Expand Down Expand Up @@ -244,7 +410,36 @@ def on_heartbeat(self, id: str | None) -> None:
self.__ping_future.set_result(None)
self.__ping_future = None

def on_ack(self, serial: int, count: int) -> None:
    """Resolve the pending messages covered by an ACK protocol message.

    Args:
        serial: msgSerial of the first acknowledged message
        count: how many messages the ACK covers
    """
    log.debug(f'ConnectionManager.on_ack(): serial={serial}, count={count}')
    # No error is passed, so the covered messages complete successfully.
    awaiting_ack = self.pending_message_queue
    awaiting_ack.complete_messages(serial, count)

def on_nack(self, serial: int, count: int, err: AblyException | None) -> None:
    """Handle NACK protocol message from server

    Fails the pending messages covered by the NACK with ``err``.

    Args:
        serial: The msgSerial of the first message being rejected
        count: The number of messages being rejected
        err: Error information from the server; when missing, a generic
            AblyException (status 500, code 50001) is used instead
    """
    if not err:
        # AblyException takes (message, status_code, code), as in the other
        # calls in this module: HTTP-style status first, Ably code second.
        err = AblyException('Unable to send message; channel not responding', 500, 50001)

    log.error(f'ConnectionManager.on_nack(): serial={serial}, count={count}, err={err}')
    self.pending_message_queue.complete_messages(serial, count, err)

def deactivate_transport(self, reason: AblyException | None = None):
    """Drop the current transport and move the connection to DISCONNECTED.

    Args:
        reason: optional error passed on to ``notify_state``
    """
    had_transport = bool(self.transport)
    if had_transport:
        # RTN19a: messages still awaiting ACK/NACK go back on the send queue
        # so they are resent after reconnecting.
        log.info('ConnectionManager.deactivate_transport(): requeuing pending messages')
        self.requeue_pending_messages()

    self.transport = None
    self.notify_state(ConnectionState.DISCONNECTED, reason)

Expand Down Expand Up @@ -383,8 +578,16 @@ def notify_state(self, state: ConnectionState, reason: AblyException | None = No
ConnectionState.SUSPENDED,
ConnectionState.FAILED,
):
# RTN7e: Fail pending messages on SUSPENDED, CLOSED, FAILED
self.fail_queued_messages(reason)
self.ably.channels._propagate_connection_interruption(state, reason)
elif state == ConnectionState.DISCONNECTED and not self.options.queue_messages:
# RTN7d: If queueMessages is false, fail pending messages on DISCONNECTED
log.info(
'ConnectionManager.notify_state(): queueMessages is false; '
'failing pending messages on DISCONNECTED'
)
self.fail_queued_messages(reason)

def start_transition_timer(self, state: ConnectionState, fail_state: ConnectionState | None = None) -> None:
log.debug(f'ConnectionManager.start_transition_timer(): transition state = {state}')
Expand Down Expand Up @@ -466,6 +669,8 @@ def cancel_retry_timer(self) -> None:
def disconnect_transport(self) -> None:
    """Start disposing of the active transport, if there is one.

    The disposal runs as a task stored in ``disconnect_transport_task``; this
    method does not wait for it.
    """
    log.info('ConnectionManager.disconnect_transport()')
    if not self.transport:
        return

    # RTN19a: put unacknowledged messages back on the send queue first.
    self.requeue_pending_messages()
    dispose_coro = self.transport.dispose()
    self.disconnect_transport_task = asyncio.create_task(dispose_coro)

async def on_auth_updated(self, token_details: TokenDetails):
Expand Down
Loading