diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 7939d09..bcb0e3a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,80 @@ Changelog ========= +Unreleased +========== + +Added +----- +- **Firmware Payload Capture Tool**: New example script + ``examples/advanced/firmware_payload_capture.py`` for capturing raw MQTT + payloads to detect firmware-introduced protocol changes. Subscribes to all + response and event topics via wildcards, requests the full scheduling + data set (weekly reservations, TOU, device info), and saves everything to a + timestamped JSON file suitable for ``jq``/``diff`` comparison across firmware + versions. + +Fixed +----- +- **Timezone-naive datetime in token expiry checks**: ``AuthTokens.is_expired``, + ``are_aws_credentials_expired``, and ``time_until_expiry`` used + ``datetime.now()`` (naive, local time). During DST transitions or timezone + changes this could cause incorrect expiry detection, leading to premature + re-authentication or use of an actually-expired token. Fixed by using + ``datetime.now(UTC)`` throughout, switching the ``issued_at`` field default + to ``datetime.now(UTC)``, and adding a field validator to normalize any + timezone-naive ``issued_at`` values loaded from old stored token files to UTC + (previously this would raise a ``TypeError`` at comparison time). +- **Duplicate AWS IoT subscribe calls on reconnect**: ``resubscribe_all()`` + called ``connection.subscribe()`` (a network round-trip to AWS IoT) once per + handler per topic. If a topic had N handlers, N identical subscribe requests + were sent on every reconnect. Fixed by making one network call per unique + topic and registering remaining handlers directly into ``_message_handlers``. +- **Anti-Legionella set-period State Preservation**: ``nwp-cli anti-legionella + set-period`` was calling ``enable_anti_legionella()`` in both the enabled and + disabled branches, silently re-enabling the feature when it was off. The + command now informs the user that the period can only be updated while the + feature is enabled and directs them to ``anti-legionella enable``. +- **Subscription State Lost After Failed Resubscription**: ``resubscribe_all()`` + cleared ``_subscriptions`` and ``_message_handlers`` before the re-subscribe + loop. Topics that failed to resubscribe were permanently dropped from internal + state and could not be retried on the next reconnection. Failed topics are now + restored so they are retried automatically. +- **Unit System Detection Returns None on Timeout**: ``_detect_unit_system()`` + declared return type ``UnitSystemType`` but returned ``None`` on + ``TimeoutError``, violating the type contract. Now returns + ``"us_customary"`` consistent with the warning message. +- **Once-Listener Becomes Permanent With Duplicate Callbacks**: ``emit()`` + identified once-listeners via a ``set`` of ``(event, callback)`` tuples. If + the same callback was registered twice with ``once=True``, the set + deduplicated the tuple — after the first emit the second listener lost its + once-status and became permanent. Fixed by checking ``listener.once`` + directly on the ``EventListener`` object. +- **Auth Session Leaked on Client Construction Failure**: In + ``create_navien_clients()``, if ``NavienAPIClient`` or + ``NavienMqttClient`` construction raised after a successful + ``auth_client.__aenter__()``, the auth session and its underlying + ``aiohttp`` session would leak. Client construction is now wrapped in a + ``try/except`` that calls ``auth_client.__aexit__()`` on failure. + Additionally, both ``except BaseException`` blocks have been replaced with + ``except Exception`` (passing real exception info to ``__aexit__``) plus a + separate ``except asyncio.CancelledError`` block that uses + ``asyncio.shield()`` to ensure cleanup completes even when the task is + being cancelled. +- **Hypothesis Tests Broke All Test Collection**: ``test_mqtt_hypothesis.py`` + imported ``hypothesis`` at module level; when it was not installed, pytest + failed to collect every test in the suite. ``hypothesis`` is now mandated + as a ``[testing]`` extra dependency, restoring correct collection behaviour. + +Changed +------- +- **Dependency updates**: Bumped minimum versions to track current releases: + ``aiohttp >= 3.13.5``, ``pydantic >= 2.12.5``, ``click >= 8.3.0``, + ``rich >= 14.3.0``. +- **Dependency: awsiotsdk >= 1.28.2**: Bumped minimum ``awsiotsdk`` version + from ``>=1.27.0`` to ``>=1.28.2`` to track the current patch release. + ``awscrt`` 0.31.3 is pulled in transitively. + Version 7.4.8 (2026-02-17) ========================== diff --git a/examples/advanced/firmware_payload_capture.py b/examples/advanced/firmware_payload_capture.py new file mode 100644 index 0000000..3a1c715 --- /dev/null +++ b/examples/advanced/firmware_payload_capture.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 +""" +Firmware Payload Capture Tool. + +Captures raw MQTT payloads for all scheduling-related topics and dumps them +to a timestamped JSON file. Use this to detect changes introduced by firmware +updates by diffing captures taken before and after an update. + +Specifically captures: + - Weekly reservations (rsv/rd) + - Time-of-Use schedule (tou/rd) + - Device info (firmware versions, capabilities) + - Device status (current operating state) + - All other response/event topics (via wildcards) + +Usage: + NAVIEN_EMAIL=your@email.com NAVIEN_PASSWORD=password python3 firmware_payload_capture.py + +Output: + payload_capture_YYYYMMDD_HHMMSS.json — all captured payloads with topics + and timestamps. Sensitive fields + (MAC address, session IDs, client + IDs) are redacted in the output. + +Comparing two captures to find firmware changes: + diff <(jq '.payloads[] | select(.topic | contains("rsv"))' before.json) \\ + <(jq '.payloads[] | select(.topic | contains("rsv"))' after.json) +""" + +import asyncio +import json +import logging +import os +import sys +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from nwp500 import NavienAPIClient, NavienAuthClient, NavienMqttClient +from nwp500.models import DeviceFeature +from nwp500.mqtt.utils import redact, redact_topic +from nwp500.topic_builder import MqttTopicBuilder + +logging.basicConfig( + level=logging.WARNING, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", +) +_logger = logging.getLogger(__name__) + + +class PayloadCapture: + """Captures and records raw MQTT payloads.""" + + def __init__(self) -> None: + self.payloads: list[dict[str, Any]] = [] + + def record(self, topic: str, message: dict[str, Any]) -> None: + entry = { + "timestamp": datetime.now(UTC).isoformat(), + "topic": topic, + "payload": message, + } + self.payloads.append(entry) + print(f" ← {redact_topic(topic)}") + + def save(self, path: Path) -> None: + # Redact sensitive fields (MAC, session IDs, client IDs) before saving + # so the output file is safe to share. Protocol structure and payload + # field values used for firmware analysis are preserved. + redacted_payloads = [ + { + "timestamp": e["timestamp"], + "topic": redact_topic(e["topic"]), + "payload": redact(e["payload"]), + } + for e in self.payloads + ] + data = { + "captured_at": datetime.now(UTC).isoformat(), + "total_payloads": len(self.payloads), + "payloads": redacted_payloads, + } + path.write_text(json.dumps(data, indent=2, default=str)) + print(f"\nSaved {len(self.payloads)} payloads → {path}") + + +async def main() -> None: + email = os.getenv("NAVIEN_EMAIL") + password = os.getenv("NAVIEN_PASSWORD") + + if not email or not password: + print("Error: set NAVIEN_EMAIL and NAVIEN_PASSWORD environment variables") + sys.exit(1) + + capture = PayloadCapture() + + async with NavienAuthClient(email, password) as auth_client: + api_client = NavienAPIClient(auth_client=auth_client) + device = await api_client.get_first_device() + if not device: + print("No devices found for this account") + return + + device_type = str(device.device_info.device_type) + mac = device.device_info.mac_address + print(f"Device: {device.device_info.device_name} [{device_type}]") + + mqtt_client = NavienMqttClient(auth_client) + await mqtt_client.connect() + + client_id = mqtt_client.client_id + + # --- Wildcard subscriptions to catch everything --- + + # All response messages back to this client + res_wildcard = MqttTopicBuilder.response_topic(device_type, client_id, "#") + # All event messages pushed by the device + evt_wildcard = MqttTopicBuilder.event_topic(device_type, mac, "#") + + print( + f"\nSubscribing to:\n {redact_topic(res_wildcard)}\n" + f" {redact_topic(evt_wildcard)}\n" + ) + print("Captured topics:") + + await mqtt_client.subscribe(res_wildcard, capture.record) + await mqtt_client.subscribe(evt_wildcard, capture.record) + + # --- Step 1: fetch device info (needed for firmware version + serial) --- + device_info_event: asyncio.Event = asyncio.Event() + device_feature: DeviceFeature | None = None + + def on_feature(feature: DeviceFeature) -> None: + nonlocal device_feature + device_feature = feature + device_info_event.set() + + await mqtt_client.subscribe_device_feature(device, on_feature) + await mqtt_client.control.request_device_info(device) + await asyncio.wait_for(device_info_event.wait(), timeout=30.0) + + if device_feature: + print( + f"\nFirmware: controller={device_feature.controller_sw_version} " + f"panel={device_feature.panel_sw_version} " + f"wifi={device_feature.wifi_sw_version}" + ) + + # --- Step 2: request device status --- + await mqtt_client.control.request_device_status(device) + await asyncio.sleep(3) + + # --- Step 3: request reservation (weekly) schedule --- + print("\nRequesting weekly reservation schedule...") + await mqtt_client.control.request_reservations(device) + await asyncio.sleep(5) + + # --- Step 4: request TOU schedule (requires controller serial number) --- + if device_feature and device_feature.program_reservation_use: + serial = device_feature.controller_serial_number + if serial: + print("Requesting TOU schedule...") + try: + await mqtt_client.control.request_tou_settings(device, serial) + await asyncio.sleep(5) + except Exception as exc: + print(f" TOU request failed: {exc}") + + # --- Step 5: wait a bit more to catch any late-arriving messages --- + print("\nWaiting for any remaining messages...") + await asyncio.sleep(5) + + await mqtt_client.disconnect() + + # --- Save results --- + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") + output_path = Path(f"payload_capture_{timestamp}.json") + capture.save(output_path) + + # Print a summary grouped by topic + print("\n--- Summary by topic ---") + by_topic: dict[str, int] = {} + for entry in capture.payloads: + by_topic[entry["topic"]] = by_topic.get(entry["topic"], 0) + 1 + for topic, count in sorted(by_topic.items()): + print(f" {count:2d}x {redact_topic(topic)}") + + if device_feature: + print( + f"\nFirmware captured: controller_sw_version=" + f"{device_feature.controller_sw_version}" + ) + print( + "Compare this file against a capture from a different firmware version " + "to detect scheduling changes.\n" + "Useful diff command:\n" + " diff <(jq '.payloads[] | select(.topic | contains(\"rsv\"))' " + f"before.json) \\\n" + " <(jq '.payloads[] | select(.topic | contains(\"rsv\"))' " + f"{output_path})" + ) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nCancelled by user") + except TimeoutError: + print("\nError: timed out waiting for device response. Is the device online?") + sys.exit(1) diff --git a/setup.cfg b/setup.cfg index 89dee02..7145b1f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,9 +51,9 @@ python_requires = >=3.13 # new major versions. This works if the required packages follow Semantic Versioning. # For more information, check out https://semver.org/. install_requires = - aiohttp>=3.8.0 - awsiotsdk>=1.27.0 - pydantic>=2.0.0 + aiohttp>=3.13.5 + awsiotsdk>=1.28.2 + pydantic>=2.12.5 [options.packages.find] @@ -68,8 +68,8 @@ exclude = # CLI - command line interface with optional rich formatting cli = - click>=8.0.0 - rich>=13.0.0 + click>=8.3.0 + rich>=14.3.0 # Add here test requirements (semicolon/line-separated) testing = diff --git a/src/nwp500/auth.py b/src/nwp500/auth.py index 5a2f228..85dd61b 100644 --- a/src/nwp500/auth.py +++ b/src/nwp500/auth.py @@ -15,11 +15,18 @@ import json import logging -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from typing import Any, Literal, Self, cast import aiohttp -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + field_validator, + model_validator, +) from pydantic.alias_generators import to_camel from . import __version__ @@ -79,11 +86,22 @@ class AuthTokens(NavienBaseModel): authorization_expires_in: int | None = None # Calculated fields - issued_at: datetime = Field(default_factory=datetime.now) + issued_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) _expires_at: datetime = PrivateAttr() _aws_expires_at: datetime | None = PrivateAttr(default=None) + @field_validator("issued_at", mode="before") + @classmethod + def _normalize_issued_at_tz(cls, v: Any) -> Any: + """Assume UTC for timezone-naive datetimes. + + Handles old stored tokens that may not have timezone info. + """ + if isinstance(v, datetime) and v.tzinfo is None: + return v.replace(tzinfo=UTC) + return v + @model_validator(mode="before") @classmethod def handle_empty_aliases(cls, data: Any) -> Any: @@ -159,7 +177,7 @@ def expires_at(self) -> datetime: def is_expired(self) -> bool: """Check if the access token has expired (cached calculation).""" # Consider expired if within 5 minutes of expiration - return datetime.now() >= (self._expires_at - timedelta(minutes=5)) + return datetime.now(UTC) >= (self._expires_at - timedelta(minutes=5)) @property def are_aws_credentials_expired(self) -> bool: @@ -178,7 +196,9 @@ def are_aws_credentials_expired(self) -> bool: # This handles cases where authorization_expires_in wasn't provided return False # Consider expired if within 5 minutes of expiration - return datetime.now() >= (self._aws_expires_at - timedelta(minutes=5)) + return datetime.now(UTC) >= ( + self._aws_expires_at - timedelta(minutes=5) + ) @property def time_until_expiry(self) -> timedelta: @@ -186,7 +206,7 @@ def time_until_expiry(self) -> timedelta: Uses cached expiration time for efficiency. """ - return self._expires_at - datetime.now() + return self._expires_at - datetime.now(UTC) @property def bearer_token(self) -> str: diff --git a/src/nwp500/cli/__main__.py b/src/nwp500/cli/__main__.py index dd5749a..76a06ff 100644 --- a/src/nwp500/cli/__main__.py +++ b/src/nwp500/cli/__main__.py @@ -66,7 +66,7 @@ def _on_status(status: DeviceStatus) -> None: _logger.warning( "Timed out detecting unit system, defaulting to us_customary" ) - return None + return "us_customary" def async_command(f: Any) -> Any: diff --git a/src/nwp500/cli/handlers.py b/src/nwp500/cli/handlers.py index 9c0ffce..c10c49c 100644 --- a/src/nwp500/cli/handlers.py +++ b/src/nwp500/cli/handlers.py @@ -452,14 +452,15 @@ def _on_status(status: DeviceStatus) -> None: # Get current enabled state use = getattr(status, "anti_legionella_use", None) - # If enabled, keep it enabled; otherwise, enable it - # (period only, no disable-state for set operation) if use: await mqtt.control.enable_anti_legionella(device, period_days) + print(f"Anti-Legionella period set to {period_days} day(s)") else: - await mqtt.control.enable_anti_legionella(device, period_days) - - print(f"✓ Anti-Legionella period set to {period_days} day(s)") + print( + "Anti-Legionella is currently disabled. " + "Enable it first to set the period, or use " + "'anti-legionella enable' with the desired period." + ) except (RangeValidationError, ValidationError) as e: _logger.error(f"Failed to set Anti-Legionella period: {e}") except DeviceError as e: diff --git a/src/nwp500/events.py b/src/nwp500/events.py index 69b64b2..fe2da2c 100644 --- a/src/nwp500/events.py +++ b/src/nwp500/events.py @@ -252,8 +252,8 @@ async def emit(self, event: str, *args: Any, **kwargs: Any) -> int: called_count += 1 - # Check if this is a once listener using O(1) set lookup - if (event, listener.callback) in self._once_callbacks: + # Check if this is a once listener + if listener.once: listeners_to_remove.append(listener) self._once_callbacks.discard((event, listener.callback)) diff --git a/src/nwp500/factory.py b/src/nwp500/factory.py index 9edf6f3..378b2bd 100644 --- a/src/nwp500/factory.py +++ b/src/nwp500/factory.py @@ -18,6 +18,8 @@ ... devices = await api.list_devices() """ +import asyncio + from .api_client import NavienAPIClient from .auth import NavienAuthClient from .mqtt import NavienMqttClient @@ -75,13 +77,20 @@ async def create_navien_clients( # Authenticate and enter context manager try: await auth_client.__aenter__() - except BaseException: - # Ensure session is cleaned up if authentication fails - await auth_client.__aexit__(None, None, None) + except asyncio.CancelledError: + # Shield cleanup from further cancellation + await asyncio.shield(auth_client.__aexit__(None, None, None)) + raise + except Exception as exc: + await auth_client.__aexit__(type(exc), exc, exc.__traceback__) raise # Create API and MQTT clients that share the session - api_client = NavienAPIClient(auth_client=auth_client) - mqtt_client = NavienMqttClient(auth_client=auth_client) + try: + api_client = NavienAPIClient(auth_client=auth_client) + mqtt_client = NavienMqttClient(auth_client=auth_client) + except Exception as exc: + await auth_client.__aexit__(type(exc), exc, exc.__traceback__) + raise return auth_client, api_client, mqtt_client diff --git a/src/nwp500/mqtt/subscriptions.py b/src/nwp500/mqtt/subscriptions.py index fb50bfd..577bc26 100644 --- a/src/nwp500/mqtt/subscriptions.py +++ b/src/nwp500/mqtt/subscriptions.py @@ -304,26 +304,42 @@ async def resubscribe_all(self) -> None: self._subscriptions.clear() self._message_handlers.clear() - # Re-establish each subscription + # Re-establish each subscription — one network call per topic, + # regardless of how many handlers are registered for it. failed_subscriptions: set[str] = set() for topic, qos in subscriptions_to_restore: handlers = handlers_to_restore.get(topic, []) - for handler in handlers: - try: - await self.subscribe(topic, handler, qos) - except (AwsCrtError, RuntimeError) as e: - _logger.error( - f"Failed to re-subscribe to " - f"'{redact_topic(topic)}': {e}" - ) - # Mark topic as failed and skip remaining handlers - # since they will fail for the same reason - failed_subscriptions.add(topic) - break # Exit handler loop, move to next topic + if not handlers: + continue + try: + # One network subscribe for the first handler + await self.subscribe(topic, handlers[0], qos) + except (AwsCrtError, RuntimeError) as e: + _logger.error( + f"Failed to re-subscribe to '{redact_topic(topic)}': {e}" + ) + failed_subscriptions.add(topic) + continue + + # Register remaining handlers without extra network calls + for handler in handlers[1:]: + if handler not in self._message_handlers[topic]: + self._message_handlers[topic].append(handler) if failed_subscriptions: + # Restore failed subscriptions to internal state so they can be + # retried on the next reconnection cycle. + qos_map = dict(subscriptions_to_restore) + for topic in failed_subscriptions: + self._subscriptions[topic] = qos_map.get( + topic, mqtt.QoS.AT_LEAST_ONCE + ) + self._message_handlers[topic] = handlers_to_restore.get( + topic, [] + ) _logger.warning( - f"Failed to restore {len(failed_subscriptions)} subscription(s)" + f"Failed to restore {len(failed_subscriptions)} " + "subscription(s); will retry on next reconnection" ) else: _logger.info("All subscriptions re-established successfully") diff --git a/tests/test_auth.py b/tests/test_auth.py index c2a4daf..bf4b58c 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,6 +1,6 @@ """Tests for authentication functionality.""" -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from unittest.mock import AsyncMock, MagicMock, patch import aiohttp @@ -122,7 +122,7 @@ def test_auth_tokens_creation(): def test_auth_tokens_expires_at_calculation(): """Test AuthTokens expires_at property.""" - now = datetime.now() + now = datetime.now(UTC) tokens = AuthTokens( id_token="test", access_token="test", @@ -149,7 +149,7 @@ def test_auth_tokens_is_expired_false(): def test_auth_tokens_is_expired_true(): """Test AuthTokens.is_expired when token is expired.""" - old_time = datetime.now() - timedelta(seconds=7200) # 2 hours ago + old_time = datetime.now(UTC) - timedelta(seconds=7200) # 2 hours ago tokens = AuthTokens( id_token="test", access_token="test", @@ -164,7 +164,7 @@ def test_auth_tokens_is_expired_true(): def test_auth_tokens_is_expired_near_expiry(): """Test AuthTokens.is_expired within 5-minute buffer.""" # Token expires in 3 minutes - should be considered expired - near_expiry = datetime.now() - timedelta(seconds=3420) # 57 minutes ago + near_expiry = datetime.now(UTC) - timedelta(seconds=3420) # 57 minutes ago tokens = AuthTokens( id_token="test", access_token="test", @@ -194,7 +194,7 @@ def test_auth_tokens_aws_credentials_expired_false(): def test_auth_tokens_aws_credentials_expired_true(): """Test are_aws_credentials_expired when AWS credentials are expired.""" - old_time = datetime.now() - timedelta(seconds=7200) # 2 hours ago + old_time = datetime.now(UTC) - timedelta(seconds=7200) # 2 hours ago tokens = AuthTokens( id_token="test", access_token="test", @@ -583,7 +583,7 @@ async def test_ensure_valid_token_aws_credentials_expired(): ) # Create tokens with expired AWS credentials but valid JWT - old_time = datetime.now() - timedelta(seconds=3900) # 65 minutes ago + old_time = datetime.now(UTC) - timedelta(seconds=3900) # 65 minutes ago tokens = AuthTokens( id_token="test", access_token="test", @@ -660,7 +660,7 @@ async def test_ensure_valid_token_jwt_expired(): ) # Create tokens with expired JWT - old_time = datetime.now() - timedelta(seconds=3900) # 65 minutes ago + old_time = datetime.now(UTC) - timedelta(seconds=3900) # 65 minutes ago tokens = AuthTokens( id_token="test", access_token="test", @@ -792,7 +792,7 @@ async def test_context_manager(): def test_aws_credentials_preservation_in_token_refresh(): """Test that AWS credentials are preserved during token refresh.""" - old_time = datetime.now() - timedelta(seconds=1800) # 30 minutes ago + old_time = datetime.now(UTC) - timedelta(seconds=1800) # 30 minutes ago old_tokens = AuthTokens( id_token="old_id", @@ -842,7 +842,7 @@ def test_aws_credentials_preservation_in_token_refresh(): # Test token restoration functionality def test_auth_tokens_to_dict(): """Test AuthTokens.to_dict serialization.""" - issued_at = datetime.now() + issued_at = datetime.now(UTC) tokens = AuthTokens( id_token="test_id", access_token="test_access", @@ -865,12 +865,13 @@ def test_auth_tokens_to_dict(): assert result["secret_key"] == "test_secret" assert result["session_token"] == "test_session" assert result["authorization_expires_in"] == 1800 - assert result["issued_at"] == issued_at.isoformat() + expected_issued_at = issued_at.strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z" + assert result["issued_at"] == expected_issued_at def test_auth_tokens_from_dict_with_issued_at(): """Test AuthTokens.from_dict with issued_at timestamp.""" - issued_at = datetime.now() - timedelta(seconds=1800) + issued_at = datetime.now(UTC) - timedelta(seconds=1800) data = { "id_token": "test_id", "access_token": "test_access", @@ -899,7 +900,7 @@ def test_auth_tokens_from_dict_with_issued_at(): def test_auth_tokens_serialization_roundtrip(): """Test that tokens can be serialized and deserialized without data loss.""" - issued_at = datetime.now() - timedelta(seconds=1800) + issued_at = datetime.now(UTC) - timedelta(seconds=1800) original = AuthTokens( id_token="test_id", access_token="test_access", @@ -1021,7 +1022,7 @@ async def test_context_manager_with_valid_stored_tokens(): @pytest.mark.asyncio async def test_context_manager_with_expired_jwt_stored_tokens(): """Test async context manager with expired JWT refreshes tokens.""" - old_time = datetime.now() - timedelta(seconds=3900) # 65 minutes ago + old_time = datetime.now(UTC) - timedelta(seconds=3900) # 65 minutes ago stored_tokens = AuthTokens( id_token="stored_id", access_token="stored_access", @@ -1055,7 +1056,7 @@ async def test_context_manager_with_expired_jwt_stored_tokens(): @pytest.mark.asyncio async def test_context_manager_with_expired_aws_credentials(): """Test async context manager re-authenticates on AWS creds expiry.""" - old_time = datetime.now() - timedelta(seconds=3900) # 65 minutes ago + old_time = datetime.now(UTC) - timedelta(seconds=3900) # 65 minutes ago stored_tokens = AuthTokens( id_token="stored_id", access_token="stored_access", diff --git a/tests/test_mqtt_client_init.py b/tests/test_mqtt_client_init.py index 724e88b..2ecc21f 100644 --- a/tests/test_mqtt_client_init.py +++ b/tests/test_mqtt_client_init.py @@ -1,6 +1,6 @@ """Tests for MQTT client initialization and token validation.""" -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta import pytest @@ -39,7 +39,7 @@ def auth_client_with_valid_tokens(): def auth_client_with_expired_jwt(): """Create an auth client with expired JWT token.""" auth_client = NavienAuthClient("test@example.com", "password") - old_time = datetime.now() - timedelta(seconds=7200) + old_time = datetime.now(UTC) - timedelta(seconds=7200) expired_tokens = AuthTokens( id_token="test_id", access_token="test_access", @@ -62,7 +62,7 @@ def auth_client_with_expired_jwt(): def auth_client_with_expired_aws_credentials(): """Create an auth client with expired AWS credentials.""" auth_client = NavienAuthClient("test@example.com", "password") - old_time = datetime.now() - timedelta(seconds=7200) + old_time = datetime.now(UTC) - timedelta(seconds=7200) expired_tokens = AuthTokens( id_token="test_id", access_token="test_access", @@ -221,7 +221,7 @@ def test_has_valid_tokens_expired_jwt_only(self): auth_client = NavienAuthClient("test@example.com", "password") # Create tokens with expired JWT but valid AWS credentials - old_time = datetime.now() - timedelta(seconds=7200) + old_time = datetime.now(UTC) - timedelta(seconds=7200) expired_jwt_tokens = AuthTokens( id_token="test_id", access_token="test_access", @@ -248,7 +248,7 @@ def test_has_valid_tokens_expired_aws_credentials_only(self): auth_client = NavienAuthClient("test@example.com", "password") # Create tokens with valid JWT but expired AWS credentials - old_time = datetime.now() - timedelta(seconds=7200) + old_time = datetime.now(UTC) - timedelta(seconds=7200) expired_aws_tokens = AuthTokens( id_token="test_id", access_token="test_access", @@ -275,7 +275,7 @@ def test_has_valid_tokens_both_expired(self): auth_client = NavienAuthClient("test@example.com", "password") # Create tokens with both JWT and AWS credentials expired - old_time = datetime.now() - timedelta(seconds=7200) + old_time = datetime.now(UTC) - timedelta(seconds=7200) both_expired_tokens = AuthTokens( id_token="test_id", access_token="test_access", @@ -351,7 +351,7 @@ def test_has_valid_tokens_jwt_near_expiry_buffer(self): auth_client = NavienAuthClient("test@example.com", "password") # Token expires in 3 minutes (within 5-minute buffer) - near_expiry = datetime.now() - timedelta(seconds=3420) + near_expiry = datetime.now(UTC) - timedelta(seconds=3420) near_expiry_tokens = AuthTokens( id_token="test_id", access_token="test_access", @@ -376,7 +376,7 @@ def test_has_valid_tokens_aws_near_expiry_buffer(self): auth_client = NavienAuthClient("test@example.com", "password") # AWS creds expire in 3 minutes (within 5-minute buffer) - near_expiry = datetime.now() - timedelta(seconds=3420) + near_expiry = datetime.now(UTC) - timedelta(seconds=3420) near_expiry_tokens = AuthTokens( id_token="test_id", access_token="test_access", @@ -482,7 +482,7 @@ def test_expired_jwt_near_expiry_buffer(self): """ auth_client = NavienAuthClient("test@example.com", "password") # Token expires in 3 minutes - should be considered expired - near_expiry = datetime.now() - timedelta(seconds=3420) + near_expiry = datetime.now(UTC) - timedelta(seconds=3420) tokens = AuthTokens( id_token="test_id", access_token="test_access", @@ -763,7 +763,7 @@ async def test_recover_connection_with_expired_auth_client( # Manually expire tokens to simulate them expiring after creation mqtt_client._auth_client._auth_response.tokens._expires_at = ( - datetime.now() - timedelta(minutes=10) + datetime.now(UTC) - timedelta(minutes=10) ) # Mock ensure_valid_token to refresh the tokens