Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/cvec/cvec.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
Expand All @@ -12,6 +13,8 @@
metric_data_points_to_arrow,
)

logger = logging.getLogger(__name__)


class CVec:
"""
Expand Down Expand Up @@ -91,6 +94,7 @@ def _get_headers(self) -> Dict[str, str]:
return {
"Authorization": f"Bearer {self._access_token}",
"Content-Type": "application/json",
"Accept": "application/json",
}

def _make_request(
Expand All @@ -117,7 +121,6 @@ def _make_request(
data=data,
)

# If we get a 401 and we have Supabase tokens, try to refresh and retry
if response.status_code == 401 and self._access_token and self._refresh_token:
try:
self._refresh_supabase_token()
Expand All @@ -135,9 +138,14 @@ def _make_request(
json=json,
data=data,
)
except Exception:
print("Token refresh failed")
# If refresh fails, continue with the original error
except (requests.RequestException, ValueError, KeyError) as e:
logger.warning(
"Token refresh failed, continuing with original request: %s",
e,
exc_info=True,
)
# If refresh fails, continue with the original error response
# which will be raised by raise_for_status() below
pass

response.raise_for_status()
Expand Down Expand Up @@ -410,6 +418,7 @@ def _login_with_supabase(self, email: str, password: str) -> None:

headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"apikey": self._publishable_key,
}

Expand All @@ -434,6 +443,7 @@ def _refresh_supabase_token(self) -> None:

headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"apikey": self._publishable_key,
}

Expand Down
136 changes: 136 additions & 0 deletions tests/test_token_refresh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Tests for token refresh functionality."""

import pytest
import requests
from typing import Any
from unittest.mock import Mock, patch

from cvec import CVec


class TestTokenRefresh:
"""Test cases for automatic token refresh functionality."""

@patch.object(CVec, "_login_with_supabase", return_value=None)
@patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key")
@patch("cvec.cvec.requests.request")
def test_token_refresh_on_401(
self,
mock_request: Any,
mock_fetch_key: Any,
mock_login: Any,
) -> None:
"""Test that token refresh is triggered on 401 Unauthorized."""
client = CVec(
host="https://test.example.com",
api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O",
)

client._access_token = "expired_token"
client._refresh_token = "valid_refresh_token"

# Mock response sequence
mock_response_401 = Mock()
mock_response_401.status_code = 401

mock_response_success = Mock()
mock_response_success.status_code = 200
mock_response_success.headers = {"content-type": "application/json"}
mock_response_success.json.return_value = []

mock_request.side_effect = [
mock_response_401,
mock_response_success,
]

# Mock refresh method
refresh_called: list[bool] = []

def mock_refresh() -> None:
refresh_called.append(True)
client._access_token = "new_token"

with patch.object(client, "_refresh_supabase_token", side_effect=mock_refresh):
# Execute request
result = client.get_metrics()

# Verify refresh was called
assert len(refresh_called) == 1
assert client._access_token == "new_token"
assert result == []

@patch.object(CVec, "_login_with_supabase", return_value=None)
@patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key")
@patch("cvec.cvec.requests.request")
def test_token_refresh_handles_network_errors_gracefully(
self,
mock_request: Any,
mock_fetch_key: Any,
mock_login: Any,
) -> None:
"""Test that network errors during refresh don't crash, returns original error."""
client = CVec(
host="https://test.example.com",
api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O",
)

client._access_token = "expired_token"
client._refresh_token = "valid_refresh_token"

# Mock response: 401 triggers refresh
mock_response_401 = Mock()
mock_response_401.status_code = 401
mock_response_401.raise_for_status.side_effect = requests.HTTPError(
"401 Client Error: Unauthorized"
)

mock_request.return_value = mock_response_401

# Mock refresh to raise network error
def mock_refresh_with_error() -> None:
raise requests.ConnectionError("Network unreachable")

with patch.object(
client, "_refresh_supabase_token", side_effect=mock_refresh_with_error
):
# Should not crash, should raise the original 401 error
with pytest.raises(requests.HTTPError, match="401"):
client.get_metrics()

@patch.object(CVec, "_login_with_supabase", return_value=None)
@patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key")
@patch("cvec.cvec.requests.request")
def test_token_refresh_handles_missing_refresh_token(
self,
mock_request: Any,
mock_fetch_key: Any,
mock_login: Any,
) -> None:
"""Test that missing refresh token is handled gracefully."""
client = CVec(
host="https://test.example.com",
api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O",
)

client._access_token = "expired_token"
client._refresh_token = "valid_refresh_token"

# Mock response: 401 triggers refresh
mock_response_401 = Mock()
mock_response_401.status_code = 401
mock_response_401.raise_for_status.side_effect = requests.HTTPError(
"401 Client Error: Unauthorized"
)

mock_request.return_value = mock_response_401

# Mock refresh to raise ValueError (missing refresh token)
def mock_refresh_with_error() -> None:
raise ValueError("No refresh token available")

with patch.object(
client, "_refresh_supabase_token", side_effect=mock_refresh_with_error
):
# Should not crash, should raise the original 401 error
with pytest.raises(requests.HTTPError, match="401"):
client.get_metrics()