Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import shutil
import subprocess
import tomllib
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any
from typing import cast

import pytest
import yaml
Expand Down Expand Up @@ -66,19 +67,19 @@ def get_default_command_list(test_dir: Path) -> list[str]:
]


def load_copier_answers(project_dir: Path) -> dict[str, Any]:
def load_copier_answers(project_dir: Path) -> dict[str, object]:
"""Load ``.copier-answers.yml`` from a generated project."""
answers_path = project_dir / ".copier-answers.yml"
assert answers_path.is_file(), f"Missing {answers_path}"
raw = yaml.safe_load(answers_path.read_text(encoding="utf-8"))
assert isinstance(raw, dict)
return raw
raw = cast(object, yaml.safe_load(answers_path.read_text(encoding="utf-8")))
raw_map = require_mapping(raw, name="copier_answers")
return dict(raw_map)


def git_commit_all(project_dir: Path, message: str) -> None:
"""Create a commit containing all tracked and new files (initial project snapshot)."""
run_command(["git", "add", "-A"], cwd=project_dir)
run_command(
_ = run_command(["git", "add", "-A"], cwd=project_dir)
_ = run_command(
[
"git",
"-c",
Expand All @@ -105,18 +106,32 @@ def copy_with_data(
if skip_tasks:
cmd.append("--skip-tasks")
for key, value in data.items():
if isinstance(value, bool):
rendered = "true" if value else "false"
else:
rendered = str(value)
rendered = ("true" if value else "false") if isinstance(value, bool) else str(value)
cmd.extend(["--data", f"{key}={rendered}"])
run_command(cmd)
_ = run_command(cmd)


def load_pyproject(project_dir: Path) -> dict[str, Any]:
def load_pyproject(project_dir: Path) -> dict[str, object]:
"""Parse ``pyproject.toml`` from a generated project."""
with (project_dir / "pyproject.toml").open("rb") as handle:
return tomllib.load(handle)
raw = cast(object, tomllib.load(handle))
assert isinstance(raw, dict)
return cast(dict[str, object], raw)


def require_mapping(value: object, *, name: str) -> Mapping[str, object]:
if not isinstance(value, Mapping):
raise AssertionError(f"{name} must be a mapping, got {type(value).__name__}")
value_map = cast(Mapping[object, object], value)
if not all(isinstance(key, str) for key in value_map):
raise AssertionError(f"{name} must have string keys")
return cast(Mapping[str, object], value_map)


def require_sequence(value: object, *, name: str) -> Sequence[object]:
if not isinstance(value, Sequence) or isinstance(value, (str, bytes, bytearray)):
raise AssertionError(f"{name} must be a sequence, got {type(value).__name__}")
return value


def test_skip_if_exists_preserves_readme_on_update() -> None:
Expand Down Expand Up @@ -202,7 +217,8 @@ def test_generate_defaults_only_cli(tmp_path: Path) -> None:
assert answers.get("package_name") == "my_library"
assert answers.get("project_slug") == "my-library"
pyproject = load_pyproject(test_dir)
assert pyproject["project"]["name"] == "my_library"
project = require_mapping(pyproject.get("project"), name="pyproject.project")
assert project["name"] == "my_library"


def test_codecov_token_not_stored_in_answers_file(tmp_path: Path) -> None:
Expand Down Expand Up @@ -400,25 +416,25 @@ def test_copier_update_exits_zero_after_copy_and_commit(tmp_path: Path) -> None:
template_repo = tmp_path / "template_repo"
test_dir = tmp_path / "update_clean"

shutil.copytree(
_ = shutil.copytree(
Path("."),
template_repo,
dirs_exist_ok=True,
ignore=shutil.ignore_patterns(
".git", ".venv", "__pycache__", "*.pyc", ".ruff_cache", ".pytest_cache"
),
)
run_command(["git", "init"], cwd=template_repo)
run_command(["git", "config", "user.email", "test@example.com"], cwd=template_repo)
run_command(["git", "config", "user.name", "Template Test"], cwd=template_repo)
run_command(["git", "add", "-A"], cwd=template_repo)
run_command(
_ = run_command(["git", "init"], cwd=template_repo)
_ = run_command(["git", "config", "user.email", "test@example.com"], cwd=template_repo)
_ = run_command(["git", "config", "user.name", "Template Test"], cwd=template_repo)
_ = run_command(["git", "add", "-A"], cwd=template_repo)
_ = run_command(
["git", "commit", "--no-verify", "-m", "chore: init template repo"],
cwd=template_repo,
)

vcs_src = f"git+file://{template_repo}"
run_command(
_ = run_command(
[
"copier",
"copy",
Expand All @@ -434,7 +450,7 @@ def test_copier_update_exits_zero_after_copy_and_commit(tmp_path: Path) -> None:
]
)

run_command(["git", "init"], cwd=test_dir)
_ = run_command(["git", "init"], cwd=test_dir)
git_commit_all(test_dir, "chore: initial generated project")

result = run_command(
Expand Down Expand Up @@ -491,16 +507,17 @@ def test_pyproject_and_tree_match_explicit_copy_data(tmp_path: Path) -> None:
)

pyproject = load_pyproject(test_dir)
proj = pyproject["project"]
proj = require_mapping(pyproject.get("project"), name="pyproject.project")
assert proj["name"] == "ocean_buoy"
assert proj["description"] == "Marine sensor ingestion"
assert proj["requires-python"] == ">=3.13"
assert proj["license"] == {"text": "Apache-2.0"}
authors = proj["authors"]
assert len(authors) == 1
assert authors[0] == {"name": "Harbor Lab", "email": "dev@harbor.lab"}
authors_seq = require_sequence(proj["authors"], name="pyproject.project.authors")
assert len(authors_seq) == 1
assert authors_seq[0] == {"name": "Harbor Lab", "email": "dev@harbor.lab"}

deps: list[str] = proj["dependencies"]
deps_seq = require_sequence(proj["dependencies"], name="pyproject.project.dependencies")
deps = [cast(str, d) for d in deps_seq]
assert not any("pandas" in d for d in deps)
assert not any("numpy" in d for d in deps)

Expand All @@ -510,6 +527,6 @@ def test_pyproject_and_tree_match_explicit_copy_data(tmp_path: Path) -> None:

readme = (test_dir / "README.md").read_text(encoding="utf-8")
assert "Ocean Buoy" in readme
urls = proj["urls"]
urls = require_mapping(proj["urls"], name="pyproject.project.urls")
assert urls["Homepage"] == "https://github.com/harbor-lab/ocean-buoy"
assert urls["Repository"] == "https://github.com/harbor-lab/ocean-buoy"