Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions src/apm_cli/registry/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,18 @@ def get_server_by_name(self, name: str) -> Optional[Dict[str, Any]]:
Optional[Dict[str, Any]]: Server metadata dictionary or None if not found.

Raises:
requests.RequestException: If the request fails.
requests.RequestException: If the registry API request fails.
"""
# Use search API to find by name - more efficient than listing all servers
try:
search_results = self.search_servers(name)
# Look for an exact match in search results
for server in search_results:
if server.get("name") == name:
search_results = self.search_servers(name)

# Look for an exact match in search results
for server in search_results:
if server.get("name") == name:
try:
return self.get_server_info(server["id"])

except Exception:
pass
except ValueError:
continue

return None

Expand All @@ -173,35 +172,37 @@ def find_server_by_reference(self, reference: str) -> Optional[Dict[str, Any]]:
Optional[Dict[str, Any]]: Server metadata dictionary or None if not found.

Raises:
requests.RequestException: If the request fails.
requests.RequestException: If the registry API request fails.
"""
# Strategy 1: Try as server ID first (direct lookup)
try:
# Check if it looks like a UUID (contains hyphens and is 36 chars)
if len(reference) == 36 and reference.count('-') == 4:
return self.get_server_info(reference)
except (ValueError, Exception):
except ValueError:
pass

# Strategy 2: Use search API to find by name
# search_servers now handles extracting repository names internally
try:
search_results = self.search_servers(reference)
# Pass 1: exact full-name match (prevents slug collisions)
for server in search_results:
server_name = server.get("name", "")
if server_name == reference:
search_results = self.search_servers(reference)

# Pass 1: exact full-name match (prevents slug collisions)
for server in search_results:
server_name = server.get("name", "")
if server_name == reference:
try:
return self.get_server_info(server["id"])

# Pass 2: fuzzy slug match (only when reference has no namespace)
for server in search_results:
server_name = server.get("name", "")
if self._is_server_match(reference, server_name):
except ValueError:
continue

# Pass 2: fuzzy slug match (only when reference has no namespace)
for server in search_results:
server_name = server.get("name", "")
if self._is_server_match(reference, server_name):
try:
return self.get_server_info(server["id"])

except Exception:
pass
except ValueError:
continue

# If not found by ID or exact name, server is not in registry
return None
Expand Down
18 changes: 16 additions & 2 deletions src/apm_cli/registry/operations.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
"""MCP server operations and installation logic."""

import logging
import os
from typing import List, Dict, Set, Optional, Tuple
from pathlib import Path

import requests

from .client import SimpleRegistryClient

logger = logging.getLogger(__name__)


class MCPServerOperations:
"""Handles MCP server operations like conflict detection and installation status."""
Expand Down Expand Up @@ -133,6 +138,8 @@ def validate_servers_exist(self, server_references: List[str]) -> Tuple[List[str
"""Validate that all servers exist in the registry before attempting installation.

This implements fail-fast validation similar to npm's behavior.
Network errors are treated as transient — the server is assumed valid
so a flaky registry API does not block installation.

Args:
server_references: List of MCP server references to validate
Expand All @@ -150,8 +157,15 @@ def validate_servers_exist(self, server_references: List[str]) -> Tuple[List[str
valid_servers.append(server_ref)
else:
invalid_servers.append(server_ref)
except Exception:
invalid_servers.append(server_ref)
except requests.RequestException:
# Network/transient error — assume server exists and let
# downstream installation attempt the actual resolution.
logger.debug(
"Registry lookup failed for %s, assuming valid (transient error)",
server_ref,
exc_info=True,
)
valid_servers.append(server_ref)

return valid_servers, invalid_servers

Expand Down
27 changes: 21 additions & 6 deletions tests/unit/test_registry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest
import os
from unittest import mock
import requests
from apm_cli.registry.client import SimpleRegistryClient
from apm_cli.utils import github_host

Expand Down Expand Up @@ -279,7 +280,7 @@ def test_find_server_by_reference_name_not_found(self, mock_search_servers):
@mock.patch('apm_cli.registry.client.SimpleRegistryClient.get_server_info')
@mock.patch('apm_cli.registry.client.SimpleRegistryClient.search_servers')
def test_find_server_by_reference_name_match_get_server_info_fails(self, mock_search_servers, mock_get_server_info):
"""Test finding a server by name when get_server_info fails."""
"""Test finding a server by name when get_server_info raises ValueError (stale ID)."""
# Mock search_servers
mock_search_servers.return_value = [
{
Expand All @@ -288,16 +289,30 @@ def test_find_server_by_reference_name_match_get_server_info_fails(self, mock_se
}
]

# Mock get_server_info to fail
mock_get_server_info.side_effect = Exception("Network error")
# Mock get_server_info to fail with ValueError (server not found by ID)
mock_get_server_info.side_effect = ValueError("Server not found")

# Call the method
# Should return None when get_server_info fails with ValueError
result = self.client.find_server_by_reference("test-server")

# Should return None when get_server_info fails
self.assertIsNone(result)
mock_search_servers.assert_called_once_with("test-server")
mock_get_server_info.assert_called_once_with("123e4567-e89b-12d3-a456-426614174000")

@mock.patch('apm_cli.registry.client.SimpleRegistryClient.get_server_info')
@mock.patch('apm_cli.registry.client.SimpleRegistryClient.search_servers')
def test_find_server_by_reference_name_match_network_error_propagates(self, mock_search_servers, mock_get_server_info):
"""Test that network errors in get_server_info propagate to the caller."""
mock_search_servers.return_value = [
{
"id": "123e4567-e89b-12d3-a456-426614174000",
"name": "test-server"
}
]

mock_get_server_info.side_effect = requests.ConnectionError("Network error")

with self.assertRaises(requests.ConnectionError):
self.client.find_server_by_reference("test-server")

@mock.patch('apm_cli.registry.client.SimpleRegistryClient.search_servers')
def test_find_server_by_reference_invalid_format(self, mock_search_servers):
Expand Down
55 changes: 55 additions & 0 deletions tests/unit/test_registry_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,60 @@ def test_get_latest_version(self, mock_get_package_info):
self.integration.get_latest_version("test-package")


class TestMCPServerOperationsValidation(unittest.TestCase):
"""Tests for MCPServerOperations.validate_servers_exist resilience."""

def _make_ops(self):
"""Create an MCPServerOperations with a mocked registry client."""
from apm_cli.registry.operations import MCPServerOperations
ops = MCPServerOperations.__new__(MCPServerOperations)
ops.registry_client = mock.MagicMock()
return ops

def test_valid_server(self):
"""Server found in registry → valid."""
ops = self._make_ops()
ops.registry_client.find_server_by_reference.return_value = {"id": "abc", "name": "srv"}

valid, invalid = ops.validate_servers_exist(["io.github.test/srv"])
self.assertEqual(valid, ["io.github.test/srv"])
self.assertEqual(invalid, [])

def test_missing_server(self):
"""Server not in registry (None) → invalid."""
ops = self._make_ops()
ops.registry_client.find_server_by_reference.return_value = None

valid, invalid = ops.validate_servers_exist(["io.github.test/no-such"])
self.assertEqual(valid, [])
self.assertEqual(invalid, ["io.github.test/no-such"])

def test_network_error_assumes_valid(self):
"""Transient network error → assume server valid (not invalid)."""
ops = self._make_ops()
ops.registry_client.find_server_by_reference.side_effect = requests.ConnectionError("flaky")

valid, invalid = ops.validate_servers_exist(["io.github.test/flaky-srv"])
self.assertEqual(valid, ["io.github.test/flaky-srv"])
self.assertEqual(invalid, [])

def test_mixed_results(self):
"""Mix of found, missing, and errored servers."""
ops = self._make_ops()

def side_effect(ref):
if ref == "found":
return {"id": "1", "name": "found"}
if ref == "missing":
return None
raise requests.Timeout("timeout")

ops.registry_client.find_server_by_reference.side_effect = side_effect

valid, invalid = ops.validate_servers_exist(["found", "missing", "flaky"])
self.assertEqual(sorted(valid), ["flaky", "found"])
self.assertEqual(invalid, ["missing"])


if __name__ == "__main__":
unittest.main()
Loading