Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions src/py_code_mode/artifacts/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ def _index_key(self) -> str:
"""Build index hash key."""
return f"{self._prefix}{self.INDEX_SUFFIX}"

def _decode_entry(self, entry_json: str | bytes) -> dict[str, Any]:
"""Decode a stored index entry."""
if isinstance(entry_json, bytes):
entry_json = entry_json.decode()
return cast(dict[str, Any], json.loads(entry_json))

def _drop_index_entries(self, names: list[str]) -> None:
"""Remove stale metadata entries from the Redis hash."""
if names:
self._redis.hdel(self._index_key(), *names)

def _payload_exists(self, name: str) -> bool:
"""Check whether the tracked payload key still exists."""
return bool(self._redis.exists(self._data_key(name)))

def save(
self,
name: str,
Expand Down Expand Up @@ -107,21 +122,20 @@ def load(self, name: str) -> Any:
Raises:
ArtifactNotFoundError: If artifact doesn't exist.
"""
entry_json = cast(str | bytes | None, self._redis.hget(self._index_key(), name))
if entry_json is None:
raise ArtifactNotFoundError(name)

entry = self._decode_entry(entry_json)
data_key = self._data_key(name)
content = cast(str | bytes | None, self._redis.get(data_key))

if content is None:
self._drop_index_entries([name])
raise ArtifactNotFoundError(name)

# Check metadata for data type
data_type = None
try:
entry_json = cast(str | bytes | None, self._redis.hget(self._index_key(), name))
if entry_json and isinstance(entry_json, str | bytes):
entry = json.loads(entry_json)
data_type = entry.get("metadata", {}).get("_data_type")
except (json.JSONDecodeError, TypeError):
pass # Use fallback logic below
data_type = entry.get("metadata", {}).get("_data_type")

# Load based on stored type
if data_type == "bytes":
Expand Down Expand Up @@ -150,7 +164,11 @@ def get(self, name: str) -> Artifact | None:
if entry_json is None:
return None

entry = json.loads(entry_json)
if not self._payload_exists(name):
self._drop_index_entries([name])
return None

entry = self._decode_entry(entry_json)
return Artifact(
name=name,
path=self._data_key(name),
Expand All @@ -170,12 +188,14 @@ def list(self) -> list[Artifact]:
return []

artifacts = []
stale_names: list[str] = []
for name, entry_json in index_data.items():
if isinstance(name, bytes):
name = name.decode()
if isinstance(entry_json, bytes):
entry_json = entry_json.decode()
entry = json.loads(entry_json)
if not self._payload_exists(name):
stale_names.append(name)
continue
entry = self._decode_entry(entry_json)
artifacts.append(
Artifact(
name=name,
Expand All @@ -185,6 +205,7 @@ def list(self) -> list[Artifact]:
created_at=datetime.fromisoformat(entry["created_at"]),
)
)
self._drop_index_entries(stale_names)
return artifacts

def exists(self, name: str) -> bool:
Expand All @@ -194,9 +215,14 @@ def exists(self, name: str) -> bool:
name: Artifact name.

Returns:
True if artifact exists in index.
True if artifact is tracked in metadata and its payload still exists.
"""
return bool(self._redis.hexists(self._index_key(), name))
if not self._redis.hexists(self._index_key(), name):
return False
if not self._payload_exists(name):
self._drop_index_entries([name])
return False
return True

def delete(self, name: str) -> None:
"""Delete artifact and its index entry.
Expand Down
56 changes: 32 additions & 24 deletions src/py_code_mode/deps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def _normalize_package_name(package: str) -> str:
return normalized_name + rest


def _package_base_name(package: str) -> str:
"""Return the normalized package name without extras or version specifiers."""
normalized = _normalize_package_name(package)
return re.split(r"[\[<>=!~]", normalized, maxsplit=1)[0]


def _compute_hash(packages: list[str]) -> str:
"""Compute SHA256 hash of sorted package list.

Expand Down Expand Up @@ -187,12 +193,12 @@ def add(self, package: str) -> None:
normalized = _normalize_package_name(package)

# Check for duplicates with different formatting
base_name = re.split(r"[\[<>=!~]", normalized)[0]
base_name = _package_base_name(normalized)

# Remove existing entries with same base name
to_remove = []
for existing in self._packages:
existing_base = re.split(r"[\[<>=!~]", existing)[0]
existing_base = _package_base_name(existing)
if existing_base == base_name:
to_remove.append(existing)

Expand All @@ -210,13 +216,12 @@ def remove(self, package: str) -> bool:
if not package or not package.strip():
return False

normalized = _normalize_package_name(package)
base_name = re.split(r"[\[<>=!~]", normalized)[0]
base_name = _package_base_name(package)

# Find and remove any package with matching base name
to_remove = []
for existing in self._packages:
existing_base = re.split(r"[\[<>=!~]", existing)[0]
existing_base = _package_base_name(existing)
if existing_base == base_name:
to_remove.append(existing)

Expand Down Expand Up @@ -247,11 +252,10 @@ def exists(self, package: str) -> bool:
if not package or not package.strip():
return False

normalized = _normalize_package_name(package)
base_name = re.split(r"[\[<>=!~]", normalized)[0]
base_name = _package_base_name(package)

for existing in self._packages:
existing_base = re.split(r"[\[<>=!~]", existing)[0]
existing_base = _package_base_name(existing)
if existing_base == base_name:
return True
return False
Expand All @@ -270,10 +274,11 @@ def __init__(self, base_path: Path) -> None:
self._deps_dir = self._base_path / "deps"
self._requirements_file = self._deps_dir / "requirements.txt"
self._packages: set[str] = set()
self._load()
self._refresh()

def _load(self) -> None:
"""Load packages from requirements.txt if it exists."""
def _refresh(self) -> None:
"""Reload packages from requirements.txt so external edits stay visible."""
self._packages = set()
if not self._requirements_file.exists():
return

Expand All @@ -294,21 +299,23 @@ def _save(self) -> None:

def list(self) -> list[str]:
"""Return list of all packages."""
self._refresh()
return list(self._packages)

def add(self, package: str) -> None:
"""Add a package to the store."""
self._refresh()
_validate_package_name(package)
normalized = _normalize_package_name(package)

# Check for duplicates with different formatting (e.g., my_package vs my-package)
# Extract base name without version specifiers for comparison
base_name = re.split(r"[\[<>=!~]", normalized)[0]
base_name = _package_base_name(normalized)

# Remove existing entries with same base name
to_remove = []
for existing in self._packages:
existing_base = re.split(r"[\[<>=!~]", existing)[0]
existing_base = _package_base_name(existing)
if existing_base == base_name:
to_remove.append(existing)

Expand All @@ -324,16 +331,16 @@ def remove(self, package: str) -> bool:
Matches by base package name, so remove("requests") will remove
"requests>=2.0" if present.
"""
self._refresh()
if not package or not package.strip():
return False

normalized = _normalize_package_name(package)
base_name = re.split(r"[\[<>=!~]", normalized)[0]
base_name = _package_base_name(package)

# Find and remove any package with matching base name
to_remove = []
for existing in self._packages:
existing_base = re.split(r"[\[<>=!~]", existing)[0]
existing_base = _package_base_name(existing)
if existing_base == base_name:
to_remove.append(existing)

Expand All @@ -347,11 +354,13 @@ def remove(self, package: str) -> bool:

def clear(self) -> None:
"""Remove all packages from the store."""
self._refresh()
self._packages.clear()
self._save()

def hash(self) -> str:
"""Compute hash of current package list."""
self._refresh()
return _compute_hash(list(self._packages))

def exists(self, package: str) -> bool:
Expand All @@ -363,14 +372,14 @@ def exists(self, package: str) -> bool:
Returns:
True if package (by base name) exists, False otherwise.
"""
self._refresh()
if not package or not package.strip():
return False

normalized = _normalize_package_name(package)
base_name = re.split(r"[\[<>=!~]", normalized)[0]
base_name = _package_base_name(package)

for existing in self._packages:
existing_base = re.split(r"[\[<>=!~]", existing)[0]
existing_base = _package_base_name(existing)
if existing_base == base_name:
return True
return False
Expand Down Expand Up @@ -410,12 +419,12 @@ def add(self, package: str) -> None:
normalized = _normalize_package_name(package)

# Check for duplicates with different formatting
base_name = re.split(r"[\[<>=!~]", normalized)[0]
base_name = _package_base_name(normalized)

# Get all members and check for conflicts
existing = self.list()
for pkg in existing:
existing_base = re.split(r"[\[<>=!~]", pkg)[0]
existing_base = _package_base_name(pkg)
if existing_base == base_name:
self._redis.srem(self._key, pkg)

Expand All @@ -430,14 +439,13 @@ def remove(self, package: str) -> bool:
if not package or not package.strip():
return False

normalized = _normalize_package_name(package)
base_name = re.split(r"[\[<>=!~]", normalized)[0]
base_name = _package_base_name(package)

# Find and remove any package with matching base name
existing = self.list()
removed = False
for pkg in existing:
existing_base = re.split(r"[\[<>=!~]", pkg)[0]
existing_base = _package_base_name(pkg)
if existing_base == base_name:
self._redis.srem(self._key, pkg)
removed = True
Expand Down
4 changes: 2 additions & 2 deletions src/py_code_mode/execution/container/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,13 +487,13 @@ async def api_add_dep(self, package: str) -> dict[str, Any]:
return response.json()

async def api_remove_dep(self, package: str) -> dict[str, Any]:
"""Remove a package from configuration.
"""Remove a package from configuration and uninstall it.

Args:
package: Package specification to remove.

Returns:
Dict with removal status.
Dict with keys: removed, not_found, failed, removed_from_config.

Raises:
RuntimeError: If removal is disabled.
Expand Down
4 changes: 2 additions & 2 deletions src/py_code_mode/execution/container/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,15 +754,15 @@ async def add_dep(self, package: str) -> dict[str, Any]:
return await self._client.api_add_dep(package)

async def remove_dep(self, package: str) -> dict[str, Any]:
"""Remove a package from configuration.
"""Remove a package from configuration and uninstall it.

This respects allow_runtime_deps configuration on the server.

Args:
package: Package specification to remove.

Returns:
Dict with removal status.
Dict with keys: removed, not_found, failed, removed_from_config.

Raises:
RuntimeError: If container is not started or runtime deps are disabled.
Expand Down
Loading
Loading