diff --git a/src/apm_cli/registry/client.py b/src/apm_cli/registry/client.py index 024000a7..8f174183 100644 --- a/src/apm_cli/registry/client.py +++ b/src/apm_cli/registry/client.py @@ -141,19 +141,18 @@ def get_server_by_name(self, name: str) -> Optional[Dict[str, Any]]: Optional[Dict[str, Any]]: Server metadata dictionary or None if not found. Raises: - requests.RequestException: If the request fails. + requests.RequestException: If the registry API request fails. """ # Use search API to find by name - more efficient than listing all servers - try: - search_results = self.search_servers(name) - - # Look for an exact match in search results - for server in search_results: - if server.get("name") == name: + search_results = self.search_servers(name) + + # Look for an exact match in search results + for server in search_results: + if server.get("name") == name: + try: return self.get_server_info(server["id"]) - - except Exception: - pass + except ValueError: + continue return None @@ -173,35 +172,37 @@ def find_server_by_reference(self, reference: str) -> Optional[Dict[str, Any]]: Optional[Dict[str, Any]]: Server metadata dictionary or None if not found. Raises: - requests.RequestException: If the request fails. + requests.RequestException: If the registry API request fails. """ # Strategy 1: Try as server ID first (direct lookup) try: # Check if it looks like a UUID (contains hyphens and is 36 chars) if len(reference) == 36 and reference.count('-') == 4: return self.get_server_info(reference) - except (ValueError, Exception): + except ValueError: pass # Strategy 2: Use search API to find by name # search_servers now handles extracting repository names internally - try: - search_results = self.search_servers(reference) - - # Pass 1: exact full-name match (prevents slug collisions) - for server in search_results: - server_name = server.get("name", "") - if server_name == reference: + search_results = self.search_servers(reference) + + # Pass 1: exact full-name match (prevents slug collisions) + for server in search_results: + server_name = server.get("name", "") + if server_name == reference: + try: return self.get_server_info(server["id"]) - - # Pass 2: fuzzy slug match (only when reference has no namespace) - for server in search_results: - server_name = server.get("name", "") - if self._is_server_match(reference, server_name): + except ValueError: + continue + + # Pass 2: fuzzy slug match (only when reference has no namespace) + for server in search_results: + server_name = server.get("name", "") + if self._is_server_match(reference, server_name): + try: return self.get_server_info(server["id"]) - - except Exception: - pass + except ValueError: + continue # If not found by ID or exact name, server is not in registry return None diff --git a/src/apm_cli/registry/operations.py b/src/apm_cli/registry/operations.py index c7109597..5fc9fa2d 100644 --- a/src/apm_cli/registry/operations.py +++ b/src/apm_cli/registry/operations.py @@ -1,11 +1,16 @@ """MCP server operations and installation logic.""" +import logging import os from typing import List, Dict, Set, Optional, Tuple from pathlib import Path +import requests + from .client import SimpleRegistryClient +logger = logging.getLogger(__name__) + class MCPServerOperations: """Handles MCP server operations like conflict detection and installation status.""" @@ -133,6 +138,8 @@ def validate_servers_exist(self, server_references: List[str]) -> Tuple[List[str """Validate that all servers exist in the registry before attempting installation. This implements fail-fast validation similar to npm's behavior. + Network errors are treated as transient — the server is assumed valid + so a flaky registry API does not block installation. Args: server_references: List of MCP server references to validate @@ -150,8 +157,15 @@ def validate_servers_exist(self, server_references: List[str]) -> Tuple[List[str valid_servers.append(server_ref) else: invalid_servers.append(server_ref) - except Exception: - invalid_servers.append(server_ref) + except requests.RequestException: + # Network/transient error — assume server exists and let + # downstream installation attempt the actual resolution. + logger.debug( + "Registry lookup failed for %s, assuming valid (transient error)", + server_ref, + exc_info=True, + ) + valid_servers.append(server_ref) return valid_servers, invalid_servers diff --git a/tests/unit/test_registry_client.py b/tests/unit/test_registry_client.py index 8730c463..83a6a9da 100644 --- a/tests/unit/test_registry_client.py +++ b/tests/unit/test_registry_client.py @@ -3,6 +3,7 @@ import unittest import os from unittest import mock +import requests from apm_cli.registry.client import SimpleRegistryClient from apm_cli.utils import github_host @@ -279,7 +280,7 @@ def test_find_server_by_reference_name_not_found(self, mock_search_servers): @mock.patch('apm_cli.registry.client.SimpleRegistryClient.get_server_info') @mock.patch('apm_cli.registry.client.SimpleRegistryClient.search_servers') def test_find_server_by_reference_name_match_get_server_info_fails(self, mock_search_servers, mock_get_server_info): - """Test finding a server by name when get_server_info fails.""" + """Test finding a server by name when get_server_info raises ValueError (stale ID).""" # Mock search_servers mock_search_servers.return_value = [ { @@ -288,16 +289,30 @@ def test_find_server_by_reference_name_match_get_server_info_fails(self, mock_se } ] - # Mock get_server_info to fail - mock_get_server_info.side_effect = Exception("Network error") + # Mock get_server_info to fail with ValueError (server not found by ID) + mock_get_server_info.side_effect = ValueError("Server not found") - # Call the method + # Should return None when get_server_info fails with ValueError result = self.client.find_server_by_reference("test-server") - # Should return None when get_server_info fails self.assertIsNone(result) mock_search_servers.assert_called_once_with("test-server") - mock_get_server_info.assert_called_once_with("123e4567-e89b-12d3-a456-426614174000") + + @mock.patch('apm_cli.registry.client.SimpleRegistryClient.get_server_info') + @mock.patch('apm_cli.registry.client.SimpleRegistryClient.search_servers') + def test_find_server_by_reference_name_match_network_error_propagates(self, mock_search_servers, mock_get_server_info): + """Test that network errors in get_server_info propagate to the caller.""" + mock_search_servers.return_value = [ + { + "id": "123e4567-e89b-12d3-a456-426614174000", + "name": "test-server" + } + ] + + mock_get_server_info.side_effect = requests.ConnectionError("Network error") + + with self.assertRaises(requests.ConnectionError): + self.client.find_server_by_reference("test-server") @mock.patch('apm_cli.registry.client.SimpleRegistryClient.search_servers') def test_find_server_by_reference_invalid_format(self, mock_search_servers): diff --git a/tests/unit/test_registry_integration.py b/tests/unit/test_registry_integration.py index fa18fd3b..06e03ddd 100644 --- a/tests/unit/test_registry_integration.py +++ b/tests/unit/test_registry_integration.py @@ -174,5 +174,60 @@ def test_get_latest_version(self, mock_get_package_info): self.integration.get_latest_version("test-package") +class TestMCPServerOperationsValidation(unittest.TestCase): + """Tests for MCPServerOperations.validate_servers_exist resilience.""" + + def _make_ops(self): + """Create an MCPServerOperations with a mocked registry client.""" + from apm_cli.registry.operations import MCPServerOperations + ops = MCPServerOperations.__new__(MCPServerOperations) + ops.registry_client = mock.MagicMock() + return ops + + def test_valid_server(self): + """Server found in registry → valid.""" + ops = self._make_ops() + ops.registry_client.find_server_by_reference.return_value = {"id": "abc", "name": "srv"} + + valid, invalid = ops.validate_servers_exist(["io.github.test/srv"]) + self.assertEqual(valid, ["io.github.test/srv"]) + self.assertEqual(invalid, []) + + def test_missing_server(self): + """Server not in registry (None) → invalid.""" + ops = self._make_ops() + ops.registry_client.find_server_by_reference.return_value = None + + valid, invalid = ops.validate_servers_exist(["io.github.test/no-such"]) + self.assertEqual(valid, []) + self.assertEqual(invalid, ["io.github.test/no-such"]) + + def test_network_error_assumes_valid(self): + """Transient network error → assume server valid (not invalid).""" + ops = self._make_ops() + ops.registry_client.find_server_by_reference.side_effect = requests.ConnectionError("flaky") + + valid, invalid = ops.validate_servers_exist(["io.github.test/flaky-srv"]) + self.assertEqual(valid, ["io.github.test/flaky-srv"]) + self.assertEqual(invalid, []) + + def test_mixed_results(self): + """Mix of found, missing, and errored servers.""" + ops = self._make_ops() + + def side_effect(ref): + if ref == "found": + return {"id": "1", "name": "found"} + if ref == "missing": + return None + raise requests.Timeout("timeout") + + ops.registry_client.find_server_by_reference.side_effect = side_effect + + valid, invalid = ops.validate_servers_exist(["found", "missing", "flaky"]) + self.assertEqual(sorted(valid), ["flaky", "found"]) + self.assertEqual(invalid, ["missing"]) + + if __name__ == "__main__": unittest.main() \ No newline at end of file