diff --git a/.github/workflows/amorphouspy_api.yml b/.github/workflows/amorphouspy_api.yml index ab87b17e..ceaebab9 100644 --- a/.github/workflows/amorphouspy_api.yml +++ b/.github/workflows/amorphouspy_api.yml @@ -36,8 +36,21 @@ jobs: shell: bash -l {0} working-directory: amorphouspy_api run: | - amorphouspy_INTEGRATION=1 uvicorn amorphouspy_api.app:app --port 8002 & - pytest -m integration -s --durations=0 --cov=src/amorphouspy_api --cov-report=xml --cov-report=term --cov-append + cat > test.sh << 'EOF' + #!/bin/bash + uvicorn amorphouspy_api.app:app --port 8002 & + pytest -m integration -s \ + --durations=0 \ + --cov=src/amorphouspy_api \ + --cov-report=xml \ + --cov-report=term \ + --cov-append + EOF + chmod +x test.sh + flux start ./test.sh + env: + AMORPHOUSPY_INTEGRATION: "1" + EXECUTOR_TYPE: "flux" - name: Pytest coverage comment uses: MishaKav/pytest-coverage-comment@main diff --git a/amorphouspy/src/amorphouspy/analysis/bond_angle_distribution.py b/amorphouspy/src/amorphouspy/analysis/bond_angle_distribution.py index a6db409e..3b6aeb79 100644 --- a/amorphouspy/src/amorphouspy/analysis/bond_angle_distribution.py +++ b/amorphouspy/src/amorphouspy/analysis/bond_angle_distribution.py @@ -49,7 +49,7 @@ def compute_angles( >>> bins, hist = compute_angles(structure, center_type=1, neighbor_type=2, cutoff=3.0) """ - ids, types, coords, box_size = get_properties_for_structure_analysis(structure) + _ids, types, coords, box_size = get_properties_for_structure_analysis(structure) neighbors = get_neighbors( coords, diff --git a/amorphouspy/src/amorphouspy/analysis/cavities.py b/amorphouspy/src/amorphouspy/analysis/cavities.py index 31eae7ac..1ef6cc26 100644 --- a/amorphouspy/src/amorphouspy/analysis/cavities.py +++ b/amorphouspy/src/amorphouspy/analysis/cavities.py @@ -56,7 +56,7 @@ def compute_cavities( """ # Extract properties using the provided helper - ids, types, coords, box_size = get_properties_for_structure_analysis(structure) + _ids, types, coords, box_size = 
get_properties_for_structure_analysis(structure) type_dict = type_to_dict(types) # Use a context manager to ensure the temporary file is cleaned up diff --git a/amorphouspy/src/amorphouspy/analysis/cte.py b/amorphouspy/src/amorphouspy/analysis/cte.py index 860a4e00..39d16fc9 100644 --- a/amorphouspy/src/amorphouspy/analysis/cte.py +++ b/amorphouspy/src/amorphouspy/analysis/cte.py @@ -126,7 +126,7 @@ def cte_from_volume_temperature_data( volume = np.array(volume)[sorted_indices] # fit and calculate CTE - slope, intercept = np.polyfit(temperature, volume, 1) + slope, _intercept = np.polyfit(temperature, volume, 1) CTE = slope / volume[0] return float(CTE) diff --git a/amorphouspy/src/amorphouspy/analysis/radial_distribution_functions.py b/amorphouspy/src/amorphouspy/analysis/radial_distribution_functions.py index 78e00550..a66da3ab 100644 --- a/amorphouspy/src/amorphouspy/analysis/radial_distribution_functions.py +++ b/amorphouspy/src/amorphouspy/analysis/radial_distribution_functions.py @@ -116,7 +116,7 @@ def compute_rdf( >>> r, rdfs, cn = compute_rdf(structure, r_max=10.0, n_bins=500) """ - ids, types, coords, box_size = get_properties_for_structure_analysis(structure) + _ids, types, coords, box_size = get_properties_for_structure_analysis(structure) # Input validation and type conversion coords = np.asarray(coords, dtype=np.float64) types = np.asarray(types, dtype=np.int64) diff --git a/amorphouspy/src/amorphouspy/analysis/rings.py b/amorphouspy/src/amorphouspy/analysis/rings.py index c0f8b5b9..6e3adfd7 100644 --- a/amorphouspy/src/amorphouspy/analysis/rings.py +++ b/amorphouspy/src/amorphouspy/analysis/rings.py @@ -68,7 +68,7 @@ def compute_guttmann_rings( ... 
) """ - ids, types, coords, box_size = get_properties_for_structure_analysis(structure) + _ids, types, coords, box_size = get_properties_for_structure_analysis(structure) type_dict = type_to_dict(types) with tempfile.NamedTemporaryFile("w+", suffix=".xyz", delete=True) as tmp: write_xyz(filename=tmp.name, coords=coords, types=types, box_size=box_size, type_dict=type_dict) diff --git a/amorphouspy/src/amorphouspy/structure.py b/amorphouspy/src/amorphouspy/structure.py index 207f013a..1e4c533f 100644 --- a/amorphouspy/src/amorphouspy/structure.py +++ b/amorphouspy/src/amorphouspy/structure.py @@ -172,7 +172,7 @@ def _integer_fu_from_total(Nfu_target: int, mol_frac: dict[str, float]) -> dict[ n = {ox: int(np.floor(w[ox])) for ox in x} rem = Nfu_target - sum(n.values()) if rem > 0: - order = sorted(x.keys(), key=lambda k: (w[k] - n[k]), reverse=True) + order = sorted(x.keys(), key=lambda k: w[k] - n[k], reverse=True) for i in range(rem): n[order[i % len(order)]] += 1 return n diff --git a/amorphouspy/src/amorphouspy/workflows/structural_analysis.py b/amorphouspy/src/amorphouspy/workflows/structural_analysis.py index 24bfeb95..90f659b8 100644 --- a/amorphouspy/src/amorphouspy/workflows/structural_analysis.py +++ b/amorphouspy/src/amorphouspy/workflows/structural_analysis.py @@ -203,7 +203,7 @@ def analyze_structure(atoms: Atoms) -> StructureData: # noqa: C901, PLR0912, PL total_mass_g = atoms.get_masses().sum() / avogadro_number # Convert amu to g density = total_mass_g / volume_cm3 - type_map, network_formers, modifiers, oxygen_present = _classify_elements(unique_z) + type_map, network_formers, modifiers, _oxygen_present = _classify_elements(unique_z) former_types = [z for z, sym in type_map.items() if sym in network_formers] modifier_types = [z for z, sym in type_map.items() if sym in modifiers] O_type = [z for z, sym in type_map.items() if sym == "O"] diff --git a/amorphouspy/src/amorphouspy/workflows/viscosity.py b/amorphouspy/src/amorphouspy/workflows/viscosity.py 
index 9e0dc538..2a00ae80 100644 --- a/amorphouspy/src/amorphouspy/workflows/viscosity.py +++ b/amorphouspy/src/amorphouspy/workflows/viscosity.py @@ -209,7 +209,7 @@ def viscosity_simulation( ) # Stage 2: Production simulation for viscosity at T - structure_final, parsed_output = _run_lammps_md( + _structure_final, parsed_output = _run_lammps_md( structure=structure1, potential=potential, tmp_working_directory=tmp_working_directory, diff --git a/amorphouspy/src/tests/test_structure.py b/amorphouspy/src/tests/test_structure.py index 52855b16..89bfa229 100644 --- a/amorphouspy/src/tests/test_structure.py +++ b/amorphouspy/src/tests/test_structure.py @@ -66,7 +66,7 @@ def test_structure_atom_counts_molar() -> None: assert atom_counts[elem] == expected, f"{elem} atoms should be {expected} for {n_molecules} mode." # Test with target_atoms - atoms, atom_counts = ps.create_random_atoms( + _atoms, atom_counts = ps.create_random_atoms( composition=composition, n_molecules=None, target_atoms=target_atoms, @@ -110,7 +110,7 @@ def test_structure_atom_counts_weight() -> None: assert atom_counts[elem] == expected, f"{elem} atoms should be {expected} for {n_molecules} mode." 
# Test with target_atoms - atoms, atom_counts = ps.create_random_atoms( + _atoms, atom_counts = ps.create_random_atoms( composition=weight_composition, n_molecules=None, target_atoms=target_atoms, diff --git a/amorphouspy_api/README.md b/amorphouspy_api/README.md index 4c1e57a1..a9cb0a74 100644 --- a/amorphouspy_api/README.md +++ b/amorphouspy_api/README.md @@ -10,11 +10,11 @@ This FastAPI-based service provides a Model Context Protocol (MCP) interface for ``` ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ -│ FastAPI App │ ── │ SQLite Cache │ ── │ Worker Process │ +│ FastAPI App │ ── │ SQLite Cache │ ── │ executorlib │ │ │ │ │ │ │ -│ • Request hash │ │ • Task metadata │ │ • amorphouspy │ -│ • Cache lookup │ │ • Results │ │ • LAMMPS sims │ -│ • Task creation │ │ • Hash index │ │ • File cleanup │ +│ • Request hash │ │ • Task metadata │ │ • Local exec │ +│ • Cache lookup │ │ • Results │ │ • SLURM cluster │ +│ • Task creation │ │ • Hash index │ │ • Job caching │ └─────────────────┘ └─────────────────┘ └─────────────────┘ ``` @@ -32,17 +32,28 @@ This FastAPI-based service provides a Model Context Protocol (MCP) interface for - Tracks task states: `processing` → `complete`/`error` - Survives server restarts and process crashes -#### 3. **Async Processing with Process Isolation** -- Uses `ProcessPoolExecutor` to run simulations in separate processes -- Avoids blocking the FastAPI event loop -- Proper signal handling for subprocess management -- Automatic temporary file cleanup using `tempfile.TemporaryDirectory()` +#### 3. **Job Execution with executorlib** +- Supports local execution (`SingleNodeExecutor`) or SLURM cluster (`SlurmClusterExecutor`) +- Executor type configured via environment variables +- Built-in job caching at the executor level +- Re-submitting same job returns cached result or running future #### 4. 
**Model Context Protocol (MCP) Integration** - Exposes simulation capabilities as MCP tools - Compatible with Claude, VS Code, and other MCP clients - Server-Sent Events (SSE) endpoint at `/mcp` +## Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `EXECUTOR_TYPE` | Executor backend: `local` or `slurm` | `local` | +| `EXECUTOR_CORES` | Number of CPU cores per worker | `4` | +| `SLURM_PARTITION` | SLURM partition name (slurm only) | - | +| `SLURM_TIME` | SLURM job time limit (slurm only) | - | +| `AMORPHOUSPY_PROJECTS` | Directory for project/cache files | `./projects` | +| `API_BASE_URL` | Base URL for visualization links | - | + ## Installation diff --git a/amorphouspy_api/conftest.py b/amorphouspy_api/conftest.py new file mode 100644 index 00000000..80f81107 --- /dev/null +++ b/amorphouspy_api/conftest.py @@ -0,0 +1,22 @@ +"""Shared test fixtures for amorphouspy_api tests.""" + +from pathlib import Path + +import pytest + +from amorphouspy_api.database import close_task_store, init_task_store + + +@pytest.fixture(autouse=True) +def _fresh_task_store(tmp_path: Path) -> None: + """Provide a fresh temporary task store for every test. + + This ensures tests are isolated from each other and from any + persistent database left over from previous runs. + """ + # Re-initialise the singleton so every call to get_task_store() + # (in routers, visualization, tests, …) returns the fresh instance. 
+ db_path = tmp_path / "test_tasks.db" + init_task_store(db_path) + yield + close_task_store() diff --git a/amorphouspy_api/pyproject.toml b/amorphouspy_api/pyproject.toml index 8c7cee6f..8329462e 100644 --- a/amorphouspy_api/pyproject.toml +++ b/amorphouspy_api/pyproject.toml @@ -32,7 +32,7 @@ markers = [ addopts = ["-m", "not integration"] filterwarnings = [ "ignore::DeprecationWarning:defusedxml.*", - "ignore:.*__get_pydantic_core_schema__.*", + "ignore::pydantic.PydanticDeprecatedSince211", "ignore:.*multi-threaded.*fork.*", ] diff --git a/amorphouspy_api/src/amorphouspy_api/app.py b/amorphouspy_api/src/amorphouspy_api/app.py index e1b87efa..5aafef46 100644 --- a/amorphouspy_api/src/amorphouspy_api/app.py +++ b/amorphouspy_api/src/amorphouspy_api/app.py @@ -1,160 +1,56 @@ """amorphouspy Simulation API. -This module provides a FastAPI server for managing long-running glass simulation tasks. -It supports meltquench simulations for multi-component oxide glasses using the PMMCS -interatomic potential from Pedone et al. - -Supported simulation types: - - Meltquench simulations: Complete heating/cooling cycles for glass formation - -Supported elements (PMMCS potential): - Ag, Al, Ba, Be, Ca, Co, Cr, Cu, Er, Fe, Fe3, Gd, Ge, K, Li, Mg, Mn, Na, Nd, Ni, O, P, Sc, Si, Sn, Sr, Ti, Zn, Zr - -Example usage: - 1. Start meltquench: POST /submit_meltquench -> returns task_id - 2. Check status: GET /check/{task_id} -> returns current status or results +FastAPI application that manages long-running glass simulation tasks. +Routers handle the individual simulation types (meltquench, etc.). 
""" -import asyncio -import concurrent.futures -import hashlib import logging -import os -from importlib.metadata import version +from contextlib import asynccontextmanager from pathlib import Path -from uuid import uuid4 -import cloudpickle -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi_mcp import FastApiMCP -from .database import get_task_store, init_task_store -from .models import MeltquenchRequest, MeltquenchResult +from .config import DB_PATH, PROJECTS_FOLDER +from .database import close_task_store, init_task_store +from .routers.meltquench import router as meltquench_router from .visualization import router as visualization_router -from .worker import meltquench_worker -# Configure logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler(), logging.FileHandler("glass_api.log")], -) +# Configure logging - use stream handler by default, file handler only if not in test logger = logging.getLogger(__name__) +if not logger.handlers: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) + logger.addHandler(handler) + logger.setLevel(logging.INFO) -# Get amorphouspy version for project directory naming -try: - amorphouspy_version = version("amorphouspy") - logger.info("Using amorphouspy version: %s", amorphouspy_version) -except Exception: - amorphouspy_version = "unknown" - logger.warning("Could not determine amorphouspy version, using 'unknown'") - -# Setup shared project directory -PROJECTS_FOLDER = Path(__file__).resolve().parent.parent.parent / "projects" - -# Check for AMORPHOUSPY_PROJECTS environment variable -if "AMORPHOUSPY_PROJECTS" in os.environ: - PROJECTS_FOLDER = Path(os.environ["AMORPHOUSPY_PROJECTS"]) 
- logger.info("Using project directory from AMORPHOUSPY_PROJECTS: %s", PROJECTS_FOLDER) -else: - logger.info("Using default project directory: %s", PROJECTS_FOLDER) - -MELTQUENCH_PROJECT_DIR = PROJECTS_FOLDER / f"amorphouspy_{amorphouspy_version}" / "meltquench" - - -# Configure API base URL for visualization links -API_BASE_URL = os.environ.get("API_BASE_URL", "") -if API_BASE_URL: - logger.info("Using API base URL for visualization links: %s", API_BASE_URL) -else: - logger.info("No API base URL configured, using relative paths") +logger.info("Using project directory: %s", PROJECTS_FOLDER) # Ensure the projects directory exists PROJECTS_FOLDER.mkdir(parents=True, exist_ok=True) -logger.info("Ensured projects directory exists: %s", PROJECTS_FOLDER) - -# Initialize persistent task store -DB_PATH = PROJECTS_FOLDER / "tasks.db" -logger.info("Task store database path: %s", DB_PATH) -logger.info( - "Directory exists: %s, Directory writable: %s", - PROJECTS_FOLDER.exists(), - os.access(PROJECTS_FOLDER, os.W_OK) if PROJECTS_FOLDER.exists() else "N/A", -) -init_task_store(DB_PATH) -_task_store = get_task_store() - - -def get_meltquench_hash(request: MeltquenchRequest) -> str: - """Compute hash for a meltquench request to enable caching. - Args: - request: The meltquench request object to hash. - Returns: - First 16 characters of the SHA256 hash of the request parameters. 
- """ - # Create sorted component-value pairs for consistent hashing - comp_value_pairs = sorted(zip(request.components, request.values, strict=True)) +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage application lifespan - startup and shutdown.""" + # Startup: Initialize persistent task store + logger.info("Task store database path: %s", DB_PATH) + init_task_store(DB_PATH) + yield + # Shutdown: Close database connections + logger.info("Closing task store database connection") + close_task_store() - hash_params = { - "composition": comp_value_pairs, - "unit": request.unit, - "heating_rate": request.heating_rate, - "cooling_rate": request.cooling_rate, - "n_print": request.n_print, - "n_atoms": request.n_atoms, - } - # Use cloudpickle for consistent serialization, then hash with sha256 - binary_data = cloudpickle.dumps(hash_params) - return hashlib.sha256(binary_data).hexdigest()[:16] # First 16 chars for brevity - - -def get_visualization_url(task_id: str) -> str: - """Construct the full visualization URL for a given task ID. - - Args: - task_id: The unique identifier for the task. - - Returns: - The full URL or relative path to the visualization page. - """ - relative_path = f"/visualize/meltquench/{task_id}" - if API_BASE_URL: - # Remove trailing slash from base URL if present, then combine - base_url = API_BASE_URL.rstrip("/") - return f"{base_url}{relative_path}" - return relative_path - - -async def _meltquench_worker(task_id: str, request: MeltquenchRequest) -> None: - """Async wrapper for meltquench simulation that runs the synchronous worker in a process executor. 
- - Args: - task_id: Unique identifier for the task - request: Validated meltquench parameters - """ - loop = asyncio.get_event_loop() - - # Convert request to dict for serialization across processes - request_dict = request.model_dump() - - # Run the synchronous worker in a process executor - with concurrent.futures.ProcessPoolExecutor() as executor: - await loop.run_in_executor( - executor, meltquench_worker, task_id, request_dict, DB_PATH, str(MELTQUENCH_PROJECT_DIR) - ) - - -# Create FastAPI app +# Create FastAPI app with lifespan manager app = FastAPI( title="amorphouspy Simulation API", description="API for managing long-running glass simulation tasks using amorphouspy", version="0.1.0", + lifespan=lifespan, ) # Enable CORS for all origins (customize as needed) @@ -170,132 +66,16 @@ async def _meltquench_worker(task_id: str, request: MeltquenchRequest) -> None: static_dir = Path(__file__).parent / "static" app.mount("/static", StaticFiles(directory=str(static_dir)), name="static") -# Include visualization router +# Include routers +app.include_router(meltquench_router, tags=["meltquench"]) app.include_router(visualization_router, tags=["visualization"]) -@app.post("/cache/meltquench", tags=["tool"]) -async def check_cached_result(request: MeltquenchRequest) -> MeltquenchResult | None: - """Check if a result for the given meltquench request is already available in cache. - - Args: - request: The meltquench request to check. - - Returns: - The cached result if found, otherwise None. - - Raises: - HTTPException: If an error occurs during the check. 
- """ - try: - request_hash = get_meltquench_hash(request) - logger.info("Checking for cached result with hash: %s", request_hash) - - # Use database's efficient hash-based lookup - cached_result = _task_store.find_cached_result(request_hash) - - if cached_result: - logger.info("Found cached result") - # Return just the result, not the task_id (for API compatibility) - return cached_result[1] - - logger.info("No cached result found") - return None - - except Exception: - logger.exception("Error checking cached result") - raise HTTPException(status_code=500, detail="Internal server error") from None - - -@app.post("/submit/meltquench", tags=["tool"]) -async def submit_meltquench(request: MeltquenchRequest) -> dict: - """Start a new meltquench simulation task. - - Note: Results can be visualized at /visualize/meltquench/{task_id} - - Args: - request: The meltquench request parameters. - - Returns: - A dictionary containing the task ID, status, and visualization URL. - - Raises: - HTTPException: If the task cannot be started. 
- """ - try: - # Check if we already have a cached result - request_hash = get_meltquench_hash(request) - cached_result = _task_store.find_cached_result(request_hash) - - if cached_result: - cached_task_id, cached_meltquench_result = cached_result - logger.info("Returning cached result from task %s instead of starting new task", cached_task_id) - return { - "task_id": cached_task_id, - "status": "completed_from_cache", - "visualization_url": get_visualization_url(cached_task_id), - "result": cached_meltquench_result.model_dump(), - } - - task_id = str(uuid4()) - logger.info("Creating new meltquench task with ID: %s, hash: %s", task_id, request_hash) - - # Store task in database - _task_store.set( - task_id, - { - "state": "processing", - "status": "Initializing", - "request_hash": request_hash, - "request_data": request.model_dump(), # Store original request for reference - }, - ) - - # Always run as background task using process executor - task = asyncio.create_task(_meltquench_worker(task_id, request)) - # Store task reference to prevent garbage collection - task.add_done_callback(lambda _: None) - - return {"task_id": task_id, "status": "started", "visualization_url": get_visualization_url(task_id)} - except Exception: - logger.exception("Error starting meltquench task") - raise HTTPException(status_code=500, detail="Internal server error") from None - - -@app.get("/check/{task_id}", tags=["tool"]) -async def check(task_id: str) -> dict: - """Check the current status of a simulation task by its ID. - - Note: When ready, visualize results at /visualize/meltquench/{task_id} - - Args: - task_id: The ID of the task to check. - - Returns: - A dictionary containing the task status, result (if available), and visualization URL. - - Raises: - HTTPException: If the task is not found. 
- """ - meta = _task_store.get(task_id) - if not meta: - raise HTTPException(status_code=404, detail="Task not found") - - return { - "task_id": task_id, - "state": meta["state"], - "status": meta.get("status", "processing"), - "visualization_url": get_visualization_url(task_id), - "error": meta.get("error"), - "result": meta.get("result"), - } - - mcp = FastApiMCP(app, include_tags=["tool"]) mcp.mount_http(mount_path="/mcp") @app.get("/") -async def root() -> RedirectResponse: +def root() -> RedirectResponse: """Root endpoint redirects to API documentation.""" return RedirectResponse(url="/docs") diff --git a/amorphouspy_api/src/amorphouspy_api/config.py b/amorphouspy_api/src/amorphouspy_api/config.py new file mode 100644 index 00000000..694f75ee --- /dev/null +++ b/amorphouspy_api/src/amorphouspy_api/config.py @@ -0,0 +1,40 @@ +"""Shared configuration for the amorphouspy API.""" + +import logging +import os +from importlib.metadata import version +from pathlib import Path + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# amorphouspy version (used in project directory naming) +# --------------------------------------------------------------------------- + +try: + amorphouspy_version = version("amorphouspy") + logger.info("Using amorphouspy version: %s", amorphouspy_version) +except Exception: + amorphouspy_version = "unknown" + logger.warning("Could not determine amorphouspy version, using 'unknown'") + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- + +PROJECTS_FOLDER = Path( + os.environ.get( + "AMORPHOUSPY_PROJECTS", + str(Path(__file__).resolve().parent.parent.parent / "projects"), + ), +) + +MELTQUENCH_PROJECT_DIR = PROJECTS_FOLDER / f"amorphouspy_{amorphouspy_version}" / "meltquench" + +DB_PATH = PROJECTS_FOLDER / "tasks.db" + +# 
--------------------------------------------------------------------------- +# API base URL for visualization links (e.g. behind a reverse proxy) +# --------------------------------------------------------------------------- + +API_BASE_URL = os.environ.get("API_BASE_URL", "") diff --git a/amorphouspy_api/src/amorphouspy_api/database.py b/amorphouspy_api/src/amorphouspy_api/database.py index 18045701..a41c2625 100644 --- a/amorphouspy_api/src/amorphouspy_api/database.py +++ b/amorphouspy_api/src/amorphouspy_api/database.py @@ -12,6 +12,7 @@ from sqlalchemy import JSON, Column, DateTime, Index, String, Text, create_engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker +from sqlalchemy.pool import NullPool from .models import MeltquenchResult, serialize_atoms @@ -46,7 +47,11 @@ class Task(Base): # Timestamps created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) - updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) + updated_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + ) # Index for efficient cache lookups __table_args__ = (Index("ix_request_hash_state", "request_hash", "state"),) @@ -72,10 +77,11 @@ def __init__(self, db_path: Path | None = None) -> None: self.db_url = f"sqlite:///{db_path}" # Create engine with SQLite-specific settings + # Use NullPool to disable connection pooling - ensures connections are properly closed self.engine = create_engine( self.db_url, echo=False, # Set to True for SQL debugging - pool_pre_ping=True, # Verify connections before use + poolclass=NullPool, # Disable connection pooling for better cleanup connect_args={ "check_same_thread": False, # Allow use from multiple threads "timeout": 30, # 30 second timeout for busy database @@ -99,6 +105,12 @@ def _create_tables(self) -> None: logger.exception("Error creating 
database tables") raise + def close(self) -> None: + """Close the database engine and dispose of all connections.""" + if self.engine: + self.engine.dispose() + logger.info("Closed task store database connection") + def get_session(self) -> Session: """Get a new database session.""" return self.SessionLocal() @@ -179,12 +191,20 @@ def find_cached_result(self, request_hash: str) -> tuple[str, MeltquenchResult] with self.get_session() as session: task = ( session.query(Task) - .filter(Task.request_hash == request_hash, Task.state == "complete", Task.result_data.isnot(None)) + .filter( + Task.request_hash == request_hash, + Task.state == "complete", + Task.result_data.isnot(None), + ) .first() ) if task and task.result_data: - logger.info("Found cached result for hash %s in task %s", request_hash, task.task_id) + logger.info( + "Found cached result for hash %s in task %s", + request_hash, + task.task_id, + ) return (task.task_id, MeltquenchResult(**task.result_data)) return None @@ -208,7 +228,10 @@ def cleanup_old_tasks(self, days: int = 30) -> int: with self.get_session() as session: deleted_count = ( session.query(Task) - .filter(Task.state.in_(["complete", "error"]), Task.updated_at < cutoff_date) + .filter( + Task.state.in_(["complete", "error"]), + Task.updated_at < cutoff_date, + ) .delete() ) @@ -237,6 +260,9 @@ def _task_to_dict(self, task: Task) -> dict[str, Any]: if task.error_message: task_dict["error"] = task.error_message + if task.request_data: + task_dict["request_data"] = task.request_data + return task_dict def _update_task_from_dict(self, task: Task, task_data: dict[str, Any]) -> None: @@ -251,15 +277,19 @@ def _update_task_from_dict(self, task: Task, task_data: dict[str, Any]) -> None: task.request_hash = task_data["request_hash"] if "result" in task_data: - # Handle ASE Atoms serialization in final_structure - result_data = task_data["result"].copy() - if "final_structure" in result_data: - from ase import Atoms - - if 
isinstance(result_data["final_structure"], Atoms): - # Serialize ASE Atoms to JSON string for storage - result_data["final_structure"] = serialize_atoms(result_data["final_structure"]) - task.result_data = result_data + result = task_data["result"] + if result is not None: + # Handle ASE Atoms serialization in final_structure + result_data = result.copy() + if "final_structure" in result_data: + from ase import Atoms + + if isinstance(result_data["final_structure"], Atoms): + # Serialize ASE Atoms to JSON string for storage + result_data["final_structure"] = serialize_atoms(result_data["final_structure"]) + task.result_data = result_data + else: + task.result_data = None if "error" in task_data: task.error_message = task_data["error"] @@ -297,3 +327,11 @@ def init_task_store(db_path: Path | None = None) -> TaskStore: global _task_store_instance _task_store_instance = TaskStore(db_path) return _task_store_instance + + +def close_task_store() -> None: + """Close and reset the global task store instance.""" + global _task_store_instance + if _task_store_instance is not None: + _task_store_instance.close() + _task_store_instance = None diff --git a/amorphouspy_api/src/amorphouspy_api/jobs.py b/amorphouspy_api/src/amorphouspy_api/jobs.py new file mode 100644 index 00000000..357e0418 --- /dev/null +++ b/amorphouspy_api/src/amorphouspy_api/jobs.py @@ -0,0 +1,107 @@ +"""Job submission utilities for amorphouspy API. + +This module provides utilities for selecting and configuring executorlib executors +(TestClusterExecutor for local or SlurmClusterExecutor for SLURM). + +Both executors use wait=False to allow non-blocking exit from the context manager, +enabling the API to check job status without blocking. 
+ +Configure via environment variables: + EXECUTOR_TYPE: "local" (default) or "slurm" + EXECUTOR_CORES: Number of cores per worker (default: 4) + LAMMPS_CORES: Number of cores for LAMMPS simulations (default: EXECUTOR_CORES or 4) + SLURM_PARTITION: SLURM partition name (optional, slurm only) + SLURM_TIME: SLURM time limit (optional, slurm only) +""" + +import logging +import os +from pathlib import Path +from typing import Any + +import executorlib +from executorlib.api import TestClusterExecutor + +logger = logging.getLogger(__name__) + + +def get_executor_class() -> type: + """Get the appropriate executor class based on environment. + + Note: the executor classes behave differently with respect to cache and `wait`ing: + - Only the SlurmClusterExecutor and the FluxClusterExecutor support cache and `wait`ing as expected + - SingleNodeExecutor: uses socket-based communication, so cache is created only once results are computed + and calling `get_future_from_cache` earlier results in `FileNotFoundError` + - TestClusterExecutor: uses Python's `subprocess` module which does not provide task dependency management. + When chaining futures, the next future is thus submitted only once the previous one is completed + + Returns: + BaseExecutor subclass based on environment. + """ + executor_type = os.environ.get("EXECUTOR_TYPE", "local").lower() + + executor_classes = { + "slurm": executorlib.SlurmClusterExecutor, + "flux": executorlib.FluxClusterExecutor, + "single": executorlib.SingleNodeExecutor, + "test": TestClusterExecutor, + } + + if executor_type not in executor_classes: + msg = f"Unknown EXECUTOR_TYPE '{executor_type}'. Valid options are: {list(executor_classes.keys())}" + raise ValueError(msg) + + return executor_classes[executor_type] + + +def get_executor_config() -> dict[str, Any]: + """Build executor configuration from environment variables. + + Returns: + Dictionary of executor configuration options. 
+ """ + config: dict[str, Any] = {} + cores = os.environ.get("EXECUTOR_CORES") + if cores: + config["cores_per_worker"] = int(cores) + + # SLURM-specific config + if os.environ.get("EXECUTOR_TYPE", "local").lower() == "slurm": + if os.environ.get("SLURM_PARTITION"): + config["partition"] = os.environ["SLURM_PARTITION"] + if os.environ.get("SLURM_TIME"): + config["time"] = os.environ["SLURM_TIME"] + + return config + + +def get_lammps_resource_dict() -> dict[str, Any]: + """Get resource dictionary for LAMMPS simulations. + + Returns: + Dictionary with LAMMPS-specific resource settings. + """ + cores = int(os.environ.get("LAMMPS_CORES", os.environ.get("EXECUTOR_CORES", "4"))) + return {"cores": cores} + + +def get_executor(cache_directory: Path) -> executorlib.BaseExecutor: + """Create a fresh executor instance. + + Args: + cache_directory: Directory for executor disk cache. + + Returns: + The executor instance. + """ + # Create new executor each time to properly detect cached results + executor_class = get_executor_class() + executor_config = get_executor_config() + + logger.info( + "Creating executor: %s with cache_directory=%s", + executor_class.__name__, + cache_directory, + ) + + return executor_class(cache_directory=cache_directory, **executor_config) diff --git a/amorphouspy_api/src/amorphouspy_api/models.py b/amorphouspy_api/src/amorphouspy_api/models.py index b2d25371..2586158c 100644 --- a/amorphouspy_api/src/amorphouspy_api/models.py +++ b/amorphouspy_api/src/amorphouspy_api/models.py @@ -4,13 +4,21 @@ including meltquench simulations and other glass modeling workflows. 
""" +from enum import StrEnum from io import StringIO from typing import Annotated, Literal from amorphouspy.workflows.structural_analysis import StructureData from ase import Atoms from ase.io import read, write -from pydantic import BaseModel, Field, PlainSerializer, PlainValidator, ValidationInfo, field_validator +from pydantic import ( + BaseModel, + Field, + PlainSerializer, + PlainValidator, + ValidationInfo, + field_validator, +) # Constants for composition validation PERCENTAGE_THRESHOLD = 1.1 @@ -84,7 +92,25 @@ def validate_atoms(v: Atoms | dict | str | None) -> Atoms | None: # Export the serialization functions for use in other modules -__all__ = ["AtomsType", "MeltquenchRequest", "MeltquenchResult", "serialize_atoms", "validate_atoms"] +__all__ = [ + "AtomsType", + "MeltquenchRequest", + "MeltquenchResult", + "TaskResponse", + "TaskStatus", + "serialize_atoms", + "validate_atoms", +] + + +class TaskStatus(StrEnum): + """Status of a simulation task.""" + + STARTED = "started" + RUNNING = "running" + COMPLETED = "completed" + COMPLETED_FROM_CACHE = "completed_from_cache" + ERROR = "error" class MeltquenchRequest(BaseModel): @@ -107,9 +133,13 @@ class MeltquenchRequest(BaseModel): heating_rate: int = Field(default=int(1e14), description="Heating rate in K/s (default: 100K/ps)") cooling_rate: int = Field(default=int(1e12), description="Cooling rate in K/s (default: 1K/ps)") n_print: int = Field(default=1000, description="Print interval for simulation output (default: 1000)") - n_atoms: int = Field(default=5000, description="Target number of atoms for the generated structure (default: 5000)") + n_atoms: int = Field( + default=5000, + description="Target number of atoms for the generated structure (default: 5000)", + ) potential_type: Literal["shik", "bjp", "pmmcs"] = Field( - default="pmmcs", description="Type of interatomic potential to use (default: 'pmmcs')" + default="pmmcs", + description="Type of interatomic potential to use (default: 'pmmcs')", ) 
@field_validator("values") @@ -151,3 +181,16 @@ class MeltquenchResult(BaseModel): mean_temperature: float = Field(..., description="Mean temperature during final phase (K)") simulation_steps: int = Field(..., description="Total simulation steps completed") structural_analysis: StructureData | dict = Field(..., description="Structural analysis results") + + +class TaskResponse(BaseModel): + """Response model for task submission and status check endpoints. + + Provides a consistent response format for both /submit and /check endpoints. + """ + + task_id: str = Field(..., description="Unique identifier for the task") + status: TaskStatus = Field(..., description="Current status of the task") + visualization_url: str = Field(..., description="URL to visualize results when complete") + result: MeltquenchResult | None = Field(default=None, description="Simulation result if completed") + error: str | None = Field(default=None, description="Error message if failed") diff --git a/amorphouspy_api/src/amorphouspy_api/routers/__init__.py b/amorphouspy_api/src/amorphouspy_api/routers/__init__.py new file mode 100644 index 00000000..78a88058 --- /dev/null +++ b/amorphouspy_api/src/amorphouspy_api/routers/__init__.py @@ -0,0 +1 @@ +"""FastAPI routers for the amorphouspy API.""" diff --git a/amorphouspy_api/src/amorphouspy_api/routers/meltquench.py b/amorphouspy_api/src/amorphouspy_api/routers/meltquench.py new file mode 100644 index 00000000..4a52dd53 --- /dev/null +++ b/amorphouspy_api/src/amorphouspy_api/routers/meltquench.py @@ -0,0 +1,339 @@ +"""Meltquench simulation router. + +Endpoints for submitting, checking, and caching meltquench simulations. 
+""" + +import hashlib +import logging +from uuid import uuid4 + +import cloudpickle +from executorlib import get_future_from_cache +from fastapi import APIRouter, HTTPException + +from amorphouspy_api.config import API_BASE_URL, MELTQUENCH_PROJECT_DIR +from amorphouspy_api.database import get_task_store +from amorphouspy_api.jobs import get_executor, get_lammps_resource_dict +from amorphouspy_api.models import ( + MeltquenchRequest, + MeltquenchResult, + TaskResponse, + TaskStatus, +) +from amorphouspy_api.workflows import run_meltquench_workflow + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +router = APIRouter() + + +def get_meltquench_hash(request: MeltquenchRequest) -> str: + """Compute hash for a meltquench request to enable caching. + + Args: + request: The meltquench request object to hash. + + Returns: + First 16 characters of the SHA256 hash of the request parameters. + """ + comp_value_pairs = sorted(zip(request.components, request.values, strict=True)) + + hash_params = { + "composition": comp_value_pairs, + "unit": request.unit, + "heating_rate": request.heating_rate, + "cooling_rate": request.cooling_rate, + "n_print": request.n_print, + "n_atoms": request.n_atoms, + } + + binary_data = cloudpickle.dumps(hash_params) + return hashlib.sha256(binary_data).hexdigest()[:16] + + +def get_visualization_url(task_id: str) -> str: + """Construct the full visualization URL for a given task ID. + + Args: + task_id: The unique identifier for the task. + + Returns: + The full URL or relative path to the visualization page. 
+ """ + relative_path = f"/visualize/meltquench/{task_id}" + if API_BASE_URL: + base_url = API_BASE_URL.rstrip("/") + return f"{base_url}{relative_path}" + return relative_path + + +def resolve_future(future, task_id: str) -> dict: + """Extract state, result, and error from a resolved or pending future. + + Args: + future: A concurrent.futures.Future-like object. + task_id: The task identifier (used for logging). + + Returns: + A dict with 'state' and optionally 'result' or 'error' keys. + """ + if not future.done(): + return {"state": "running"} + + exc = future.exception() + if exc is not None: + error_msg = str(exc) + logger.error("Task %s failed: %s", task_id, error_msg) + return {"state": "error", "error": error_msg} + + serialized = MeltquenchResult(**future.result()).model_dump() + return {"state": "complete", "result": serialized} + + +def build_task_response( + task_id: str, + job_status: dict, + *, + from_cache: bool = False, +) -> TaskResponse: + """Build a TaskResponse from job status. + + Args: + task_id: The task identifier. + job_status: Dictionary with 'state', 'result', and 'error' keys. + from_cache: Whether this result was retrieved from cache. + + Returns: + A TaskResponse model instance. + """ + state = job_status["state"] + + if state == "complete": + status = TaskStatus.COMPLETED_FROM_CACHE if from_cache else TaskStatus.COMPLETED + result = MeltquenchResult(**job_status["result"]) if job_status.get("result") else None + elif state == "error": + status = TaskStatus.ERROR + result = None + else: # running + status = TaskStatus.RUNNING + result = None + + return TaskResponse( + task_id=task_id, + status=status, + visualization_url=get_visualization_url(task_id), + result=result, + error=job_status.get("error"), + ) + + +def submit_to_executor( + request_data: dict, + task_id: str, + request_hash: str, + *, + cache_key: str | None = None, +) -> dict: + """Submit a meltquench job to the executor and resolve its status. 
+ + The executor's disk cache (``MELTQUENCH_PROJECT_DIR``) means that a + previously-completed job will have ``done() == True`` immediately. + + Args: + request_data: Dictionary with the meltquench request parameters. + task_id: The unique task identifier. + request_hash: Hash of the request for caching. + cache_key: Optional explicit cache key for the final workflow step, + enabling later retrieval via ``get_future_from_cache``. + + Returns: + A job-status dict with 'state', 'result', and 'error' keys. + """ + exe = get_executor(cache_directory=MELTQUENCH_PROJECT_DIR) + lammps_resource_dict = get_lammps_resource_dict() + future = run_meltquench_workflow( + executor=exe, + components=request_data["components"], + values=request_data["values"], + n_atoms=request_data["n_atoms"], + potential_type=request_data["potential_type"], + heating_rate=request_data["heating_rate"], + cooling_rate=request_data["cooling_rate"], + n_print=request_data["n_print"], + lammps_resource_dict=lammps_resource_dict, + cache_key=cache_key, + ) + + # Resolve the future while the executor is still active + task_store = get_task_store() + + meta = { + "request_hash": request_hash, + "request_data": request_data, + **resolve_future(future, task_id), + } + + task_store.set(task_id, meta) + exe.shutdown(wait=False, cancel_futures=False) + + # Note: after shutdown of executor, do not touch the future anymore + # E.g. the FluxClusterExecutor will cancel the Future object (while not cancelling the underlying job) + # See https://github.com/pyiron/executorlib/issues/921#issuecomment-3919953044 + + return meta + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@router.post("/cache/meltquench", tags=["tool"]) +def check_cached_result(request: MeltquenchRequest) -> MeltquenchResult | None: + """Check if a result for the given meltquench request is already available in cache. 
+ + Args: + request: The meltquench request to check. + + Returns: + The cached result if found, otherwise None. + + Raises: + HTTPException: If an error occurs during the check. + """ + try: + task_store = get_task_store() + request_hash = get_meltquench_hash(request) + logger.info("Checking for cached result with hash: %s", request_hash) + + cached_result = task_store.find_cached_result(request_hash) + + if cached_result: + logger.info("Found cached result") + return cached_result[1] + + logger.info("No cached result found") + return None + + except Exception: + logger.exception("Error checking cached result") + raise HTTPException(status_code=500, detail="Internal server error") from None + + +@router.post("/submit/meltquench", tags=["tool"]) +def submit_meltquench(request: MeltquenchRequest) -> TaskResponse: + """Start a new meltquench simulation task. + + Submit a melt-quench simulation for multi-component oxide glasses. + The calculation uses the PMMCS interatomic potential (Pedone et al.) + and runs a complete heating / cooling cycle for glass formation. + + Supported elements (PMMCS potential): + Ag, Al, Ba, Be, Ca, Co, Cr, Cu, Er, Fe, Fe3, Gd, Ge, K, Li, + Mg, Mn, Na, Nd, Ni, O, P, Sc, Si, Sn, Sr, Ti, Zn, Zr + + If the job with identical parameters has already been submitted, + it will return the cached result or current status. + + Note: Results can be visualized at /visualize/meltquench/{task_id} + + Args: + request: The meltquench request parameters. + + Returns: + TaskResponse with task ID, status, and result if available. + + Raises: + HTTPException: If the task cannot be started. 
+ """ + try: + task_store = get_task_store() + request_hash = get_meltquench_hash(request) + request_data = request.model_dump() + + # Check if we already have a cached result in our database + cached_result = task_store.find_cached_result(request_hash) + if cached_result: + cached_task_id, cached_meltquench_result = cached_result + logger.info("Returning cached result from task %s", cached_task_id) + return build_task_response( + cached_task_id, + {"state": "complete", "result": cached_meltquench_result.model_dump()}, + from_cache=True, + ) + + task_id = str(uuid4()) + logger.info("Submitting meltquench task with ID: %s, hash: %s", task_id, request_hash) + status = submit_to_executor(request_data, task_id, request_hash, cache_key=request_hash) + return build_task_response(task_id, status) + + except HTTPException: + raise + except Exception: + logger.exception("Error submitting meltquench task") + raise HTTPException(status_code=500, detail="Internal server error") from None + + +@router.get("/check/{task_id}", tags=["tool"]) +def check(task_id: str) -> TaskResponse: + """Check the current status of a simulation task by its ID. + + Uses ``get_future_from_cache()`` to recreate the future from the + executor's disk cache, avoiding re-submission of the entire workflow. + + Note: When ready, visualize results at /visualize/meltquench/{task_id} + + Args: + task_id: The ID of the task to check. + + Returns: + TaskResponse with current status, result (if available), and visualization URL. + + Raises: + HTTPException: If the task is not found. 
+ """ + task_store = get_task_store() + meta = task_store.get(task_id) + if not meta: + raise HTTPException(status_code=404, detail="Task not found") + logger.info("check %s: state=%s", task_id, meta["state"]) + + if meta["state"] != "running": + return build_task_response(task_id, meta) + + request_hash = meta.get("request_hash", "") + request_data = meta.get("request_data", {}) + + if not request_hash: + raise HTTPException(status_code=500, detail="Task is missing request hash") + + # Recreate the future from the executor's disk cache instead of + # re-submitting the entire workflow. See + # https://github.com/pyiron/executorlib/pull/915 + try: + future = get_future_from_cache( + cache_directory=str(MELTQUENCH_PROJECT_DIR), + cache_key=request_hash, + ) + + status = { + "request_hash": request_hash, + "request_data": request_data, + **resolve_future(future, task_id), + } + + task_store.set(task_id, status) + except FileNotFoundError: + # Cache files not yet written - job is still starting up + logger.info("Cache files not yet available for task %s", task_id) + status = {"state": "running", "request_hash": request_hash, "request_data": request_data} + except Exception as exc: + logger.exception("Failed to check task %s", task_id) + error_msg = str(exc) + status = {"state": "error", "error": error_msg, "request_hash": request_hash, "request_data": request_data} + task_store.set(task_id, status) + return build_task_response(task_id, status) diff --git a/amorphouspy_api/src/amorphouspy_api/worker.py b/amorphouspy_api/src/amorphouspy_api/worker.py deleted file mode 100644 index 5e1d8fd0..00000000 --- a/amorphouspy_api/src/amorphouspy_api/worker.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Worker module for amorphouspy simulations. - -This module contains the actual simulation logic that runs in separate processes, -isolated from the FastAPI server code to avoid unnecessary imports and potential -conflicts with signal handling. 
-""" - -import logging -from typing import Any - -from .models import MeltquenchRequest - - -def setup_worker_logging(task_id: str) -> logging.Logger: - """Set up logging for worker process. - - Args: - task_id: The unique identifier for the task. - - Returns: - Configured logger instance describing the worker process. - """ - logger = logging.getLogger(f"worker.{task_id}") - if not logger.handlers: - handler = logging.StreamHandler() - formatter = logging.Formatter(f"%(asctime)s - WORKER-{task_id} - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - return logger - - -def meltquench_worker(task_id: str, request_dict: dict[str, Any], db_path: str, shared_project_dir: str) -> None: - """Run synchronous meltquench simulation. - - This runs in a separate process to avoid blocking the event loop. - - Args: - task_id: Unique identifier for the task. - request_dict: Serialized meltquench parameters. - db_path: Path to SQLite database for task store. - shared_project_dir: Path to the shared project directory. 
- """ - from pathlib import Path - - from .database import TaskStore - - logger = setup_worker_logging(task_id) - logger.info(f"Starting meltquench simulation for task {task_id}") - - # Create task store instance for this worker process - task_store = TaskStore(Path(db_path)) - - # Reconstruct the request object from the dict - request = MeltquenchRequest(**request_dict) - logger.info(f"Request parameters: {request.model_dump()}") - - try: - # Import amorphouspy modules (import here to avoid startup dependencies) - import numpy as np - from amorphouspy import ( - generate_potential, - get_ase_structure, - get_structure_dict, - melt_quench_simulation, - ) - from amorphouspy.workflows.structural_analysis import analyze_structure - from executorlib import SingleNodeExecutor - - # Create composition string from request - comp_parts = [] - for component, value in zip(request.components, request.values, strict=False): - # Convert to fractions if percentages were provided - fraction = value / 100.0 if sum(request.values) > 1.1 else value - comp_parts.append(f"{fraction}{component}") - - composition = "-".join(comp_parts) - logger.info(f"Task {task_id}: Generated composition string: {composition}") - - # Update task status - current_task = task_store.get(task_id) or {"state": "processing"} - current_task["status"] = "Creating structure" - task_store.set(task_id, current_task) - logger.info(f"Task {task_id}: Creating structure") - - # Use the shared project directory passed from the main process - project_path = Path(shared_project_dir) - logger.info(f"Task {task_id}: Using shared project directory: {project_path}") - - # Create executor for caching workflow results - with SingleNodeExecutor(cache_directory=project_path) as exe: - atoms_dict = exe.submit( - get_structure_dict, - composition=composition, - # n_molecules=5000, # Default number of molecules - target_atoms=request.n_atoms, - ).result() - logger.info(f"Task {task_id}: Structure dictionary created with 
{len(atoms_dict['atoms'])} atoms") - - structure_future = exe.submit( - get_ase_structure, - atoms_dict=atoms_dict, - ) - logger.info(f"Task {task_id}: ASE structure created") - - potential_future = exe.submit( - generate_potential, - atoms_dict=atoms_dict, - potential_type=request.potential_type, - ) - logger.info(f"Task {task_id}: Potential generated") - - # Update task status - current_task = task_store.get(task_id) or {"state": "processing"} - current_task["status"] = "Running meltquench simulation" - task_store.set(task_id, current_task) - logger.info(f"Task {task_id}: Starting meltquench simulation") - - # Use simulation parameters from the request - logger.info( - f"Task {task_id}: Using heating_rate={request.heating_rate}, cooling_rate={request.cooling_rate}, n_print={request.n_print}" - ) - - # Run meltquench simulation - logger.info(f"Task {task_id}: Executing simulation workflow") - result = exe.submit( - melt_quench_simulation, - structure=structure_future, - potential=potential_future, - n_print=request.n_print, - # tmp_working_directory=str(tmp_dir_base), # note: if provided needs to be static - or prevents caching at executor level - heating_rate=request.heating_rate, - cooling_rate=request.cooling_rate, - langevin=False, - server_kwargs={}, - ).result() - logger.info(f"Task {task_id}: Simulation completed successfully") - - # Update task status for structural analysis - current_task = task_store.get(task_id) or {"state": "processing"} - current_task["status"] = "Running structural analysis" - task_store.set(task_id, current_task) - logger.info(f"Task {task_id}: Starting structural analysis") - - # Perform structural analysis on the final structure (includes density calculation) - final_structure = result["structure"] - logger.info(f"Task {task_id}: Analyzing structure with {len(final_structure)} atoms") - - # Run structural analysis - structural_data = exe.submit( - analyze_structure, - atoms=final_structure, - ).result() - logger.info(f"Task 
{task_id}: Structural analysis completed successfully") - - # Debug: Check what fields are present in the structural_data object - logger.info(f"Task {task_id}: StructureData type: {type(structural_data)}") - if hasattr(structural_data, "model_fields"): - logger.info(f"Task {task_id}: StructureData model fields: {list(structural_data.model_fields.keys())}") - if hasattr(structural_data, "__dict__"): - logger.info(f"Task {task_id}: StructureData attributes: {list(structural_data.__dict__.keys())}") - - # Use the structural data directly (it's now a Pydantic model with proper serialization) - structural_summary = structural_data.model_dump() if hasattr(structural_data, "model_dump") else structural_data - logger.info(f"Task {task_id}: Structural analysis data prepared") - logger.info( - f"Task {task_id}: Structural summary keys: {list(structural_summary.keys()) if isinstance(structural_summary, dict) else 'Not a dict'}" - ) - - # Store results including structural analysis - current_task = task_store.get(task_id) or {} - current_task.update( - { - "state": "complete", - "status": "Completed", - "result": { - "composition": composition, - "final_structure": result["structure"], # Store ASE Atoms object directly - "mean_temperature": float(np.mean(result["result"]["temperature"])), - "simulation_steps": len(result["result"]["steps"]), - "structural_analysis": structural_summary, - }, - } - ) - task_store.set(task_id, current_task) - - logger.info(f"Task {task_id}: Results stored, simulation complete") - - except Exception as exc: - logger.error(f"Task {task_id}: Simulation failed with error: {exc!s}", exc_info=True) - current_task = task_store.get(task_id) or {} - current_task.update({"state": "error", "status": "Failed", "error": str(exc)}) - task_store.set(task_id, current_task) diff --git a/amorphouspy_api/src/amorphouspy_api/workflows/__init__.py b/amorphouspy_api/src/amorphouspy_api/workflows/__init__.py new file mode 100644 index 00000000..d90c3918 --- /dev/null 
+++ b/amorphouspy_api/src/amorphouspy_api/workflows/__init__.py @@ -0,0 +1,5 @@ +"""Workflow functions for amorphouspy API.""" + +from .meltquench import run_meltquench_workflow + +__all__ = ["run_meltquench_workflow"] diff --git a/amorphouspy_api/src/amorphouspy_api/workflows/meltquench.py b/amorphouspy_api/src/amorphouspy_api/workflows/meltquench.py new file mode 100644 index 00000000..465bd83e --- /dev/null +++ b/amorphouspy_api/src/amorphouspy_api/workflows/meltquench.py @@ -0,0 +1,128 @@ +"""Meltquench workflow for glass simulation. + +This module contains the meltquench workflow function that uses executorlib +to submit different parts of the workflow with appropriate resources. + +The workflow is structured as: +1. Structure generation and potential setup (lightweight, no special resources) +2. LAMMPS melt-quench simulation (compute-intensive, uses LAMMPS_CORES) +3. Structural analysis (post-processing, no special resources) +""" + +import logging +from concurrent.futures import Future +from typing import TYPE_CHECKING, Any + +import numpy as np +from amorphouspy import ( + generate_potential, + get_ase_structure, + get_structure_dict, + melt_quench_simulation, +) +from amorphouspy.workflows.structural_analysis import analyze_structure + +if TYPE_CHECKING: + from executorlib.executor.base import BaseExecutor + +logger = logging.getLogger(__name__) + + +def run_meltquench_workflow( + executor: "BaseExecutor", + components: list[str], + values: list[float], + n_atoms: int, + potential_type: str, + heating_rate: float, + cooling_rate: float, + n_print: int, + lammps_resource_dict: dict[str, Any] | None = None, + cache_key: str | None = None, +) -> Future[dict[str, Any]]: + """Submit the complete meltquench workflow to the executor. + + This function submits multiple jobs to the executor with proper dependency + tracking. 
Different parts of the workflow can use different resources:
+    - Structure/potential generation: lightweight, default resources
+    - LAMMPS simulation: compute-intensive, uses lammps_resource_dict
+
+    Args:
+        executor: The executorlib executor to submit jobs to.
+        components: List of oxide components (e.g., ["SiO2", "Na2O", "B2O3"]).
+        values: List of corresponding values (fractions or percentages).
+        n_atoms: Target number of atoms in the simulation.
+        potential_type: Type of interatomic potential to use.
+        heating_rate: Heating rate in K/s.
+        cooling_rate: Cooling rate in K/s.
+        n_print: Number of steps between output prints.
+        lammps_resource_dict: Resource dict for LAMMPS (e.g., {"cores": 4}).
+        cache_key: Optional explicit cache key for the final workflow step.
+            When set, the result can later be retrieved via
+            ``get_future_from_cache(cache_directory, cache_key)``.
+
+    Returns:
+        Future that will resolve to the final result dictionary.
+    """
+    if lammps_resource_dict is None:
+        lammps_resource_dict = {}
+
+    # Build composition string from components and values
+    comp_parts = []
+    for component, value in zip(components, values, strict=False):
+        fraction = value / 100.0 if sum(values) > 1.1 else value
+        comp_parts.append(f"{fraction}{component}")
+    composition = "-".join(comp_parts)
+    logger.info("Submitting meltquench workflow for composition: %s", composition)
+
+    # Step 1-3: Submit structure and potential generation (lightweight)
+    atoms_dict_future = executor.submit(get_structure_dict, composition=composition, target_atoms=n_atoms)
+    structure_future = executor.submit(get_ase_structure, atoms_dict=atoms_dict_future)
+    potential_future = executor.submit(generate_potential, atoms_dict=atoms_dict_future, potential_type=potential_type)
+
+    # Step 4: Submit LAMMPS melt-quench simulation (compute-intensive)
+    meltquench_future = executor.submit(
+        melt_quench_simulation,
+        structure=structure_future,
+        potential=potential_future,
+        n_print=n_print,
+        
heating_rate=heating_rate, + cooling_rate=cooling_rate, + langevin=False, + server_kwargs=lammps_resource_dict, + ) + + # Step 5: Submit structural analysis and result assembly + final_resource_dict = {} + if cache_key is not None: + final_resource_dict["cache_key"] = cache_key + return executor.submit( + _assemble_results, + composition=composition, + meltquench_result=meltquench_future, + resource_dict=final_resource_dict if final_resource_dict else {}, + ) + + +def _assemble_results(composition: str, meltquench_result: dict[str, Any]) -> dict[str, Any]: + """Perform structural analysis and assemble final results. + + Args: + composition: Composition string. + meltquench_result: Result from melt_quench_simulation. + + Returns: + Final result dictionary with structural analysis. + """ + final_structure = meltquench_result["structure"] + structural_data = analyze_structure(atoms=final_structure) + + structural_summary = structural_data.model_dump() if hasattr(structural_data, "model_dump") else structural_data + + return { + "composition": composition, + "final_structure": final_structure, + "mean_temperature": float(np.mean(meltquench_result["result"]["temperature"])), + "simulation_steps": len(meltquench_result["result"]["steps"]), + "structural_analysis": structural_summary, + } diff --git a/amorphouspy_api/src/tests/test_database.py b/amorphouspy_api/src/tests/test_database.py index b980ba58..49a6a5f7 100644 --- a/amorphouspy_api/src/tests/test_database.py +++ b/amorphouspy_api/src/tests/test_database.py @@ -1,6 +1,7 @@ """Test database functionality for the task store.""" import tempfile +import threading from pathlib import Path from amorphouspy_api.database import TaskStore @@ -14,7 +15,11 @@ def test_task_store_basic_operations() -> None: store = TaskStore(db_path) # Test set and get - task_data = {"state": "processing", "status": "Starting", "request_hash": "abc123def456"} + task_data = { + "state": "processing", + "status": "Starting", + "request_hash": 
"abc123def456", + } store.set("test_task_1", task_data) retrieved = store.get("test_task_1") @@ -24,6 +29,8 @@ def test_task_store_basic_operations() -> None: assert retrieved["status"] == "Starting" assert retrieved["request_hash"] == "abc123def456" + store.close() + def test_task_store_cached_result_lookup() -> None: """Test efficient cached result lookup by hash.""" @@ -48,7 +55,11 @@ def test_task_store_cached_result_lookup() -> None: "structural_analysis": { "density": 2.5, "coordination": {"oxygen": {}, "formers": {}, "modifiers": {}}, - "network": {"connectivity": 3.0, "Qn_distribution": {}, "Qn_distribution_partial": {}}, + "network": { + "connectivity": 3.0, + "Qn_distribution": {}, + "Qn_distribution_partial": {}, + }, "distributions": {"bond_angles": {}, "rings": {}}, "rdfs": {"r": [], "rdfs": {}, "cumulative_coordination": {}}, "elements": {"formers": ["Si"], "modifiers": ["Na"], "cutoffs": {}}, @@ -82,6 +93,8 @@ def test_task_store_cached_result_lookup() -> None: no_result = store.find_cached_result("nonexistent_hash") assert no_result is None + store.close() + def test_task_store_items() -> None: """Test getting all tasks.""" @@ -101,6 +114,8 @@ def test_task_store_items() -> None: assert "task1" in task_ids assert "task2" in task_ids + store.close() + def test_task_store_persistence() -> None: """Test that data persists across TaskStore instances.""" @@ -109,7 +124,10 @@ def test_task_store_persistence() -> None: # Create store and add data store1 = TaskStore(db_path) - store1.set("persistent_task", {"state": "complete", "status": "Done", "request_hash": "persistent_hash"}) + store1.set( + "persistent_task", + {"state": "complete", "status": "Done", "request_hash": "persistent_hash"}, + ) # Create new store instance with same database store2 = TaskStore(db_path) @@ -119,3 +137,113 @@ def test_task_store_persistence() -> None: assert retrieved["state"] == "complete" assert retrieved["status"] == "Done" assert retrieved["request_hash"] == 
"persistent_hash" + + store1.close() + store2.close() + + +def test_task_store_concurrent_writes() -> None: + """Test that multiple threads can write to the task store simultaneously.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_tasks.db" + store = TaskStore(db_path) + + errors: list[Exception] = [] + n_threads = 10 + + def write_task(i: int) -> None: + try: + store.set( + f"thread-task-{i}", + {"state": "processing", "request_hash": f"hash-{i}"}, + ) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=write_task, args=(i,)) for i in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [], f"Concurrent writes failed: {errors}" + + # Verify all tasks were written + items = store.items() + assert len(items) == n_threads + + store.close() + + +def test_task_store_concurrent_cache_lookup() -> None: + """Test that find_cached_result works correctly from multiple threads. + + This simulates the pattern where FastAPI runs sync endpoints in a + threadpool — multiple /check or /cache requests hitting the DB at once. 
+ """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_tasks.db" + store = TaskStore(db_path) + + mock_structure = { + "numbers": [14, 8, 8], + "positions": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + "cell": [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]], + "pbc": [True, True, True], + } + + store.set( + "cached-task", + { + "state": "complete", + "request_hash": "shared-hash", + "result": { + "composition": "SiO2", + "final_structure": mock_structure, + "mean_temperature": 300.0, + "simulation_steps": 100, + "structural_analysis": { + "density": 2.2, + "coordination": {"oxygen": {}, "formers": {}, "modifiers": {}}, + "network": { + "connectivity": 4.0, + "Qn_distribution": {}, + "Qn_distribution_partial": {}, + }, + "distributions": {"bond_angles": {}, "rings": {}}, + "rdfs": {"r": [], "rdfs": {}, "cumulative_coordination": {}}, + "elements": {"formers": ["Si"], "modifiers": [], "cutoffs": {}}, + }, + }, + }, + ) + + errors: list[Exception] = [] + results: list[tuple | None] = [] + lock = threading.Lock() + n_threads = 10 + + def lookup() -> None: + try: + result = store.find_cached_result("shared-hash") + with lock: + results.append(result) + except Exception as e: + with lock: + errors.append(e) + + threads = [threading.Thread(target=lookup) for _ in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [], f"Concurrent cache lookups failed: {errors}" + assert len(results) == n_threads + for r in results: + assert r is not None + task_id, mq_result = r + assert task_id == "cached-task" + assert mq_result.composition == "SiO2" + + store.close() diff --git a/amorphouspy_api/src/tests/test_hash_caching.py b/amorphouspy_api/src/tests/test_hash_caching.py index 62fa284c..0958330e 100644 --- a/amorphouspy_api/src/tests/test_hash_caching.py +++ b/amorphouspy_api/src/tests/test_hash_caching.py @@ -6,8 +6,8 @@ 3. 
The caching logic can be imported and executed without errors """ -from amorphouspy_api.app import get_meltquench_hash from amorphouspy_api.models import MeltquenchRequest +from amorphouspy_api.routers.meltquench import get_meltquench_hash def test_hash_consistency() -> None: diff --git a/amorphouspy_api/src/tests/test_meltquench.py b/amorphouspy_api/src/tests/test_meltquench.py index a963fdb3..ef9669a4 100644 --- a/amorphouspy_api/src/tests/test_meltquench.py +++ b/amorphouspy_api/src/tests/test_meltquench.py @@ -1,80 +1,29 @@ -"""Unit tests for meltquench API functionality.""" +"""Unit tests for meltquench API functionality. + +Tests insert tasks directly into the task store rather than mocking the executor, +except for tests that specifically exercise the /submit endpoint. +""" import time +from collections.abc import Generator +from contextlib import contextmanager +from types import SimpleNamespace from typing import Any from unittest.mock import MagicMock, patch -import pytest from fastapi.testclient import TestClient from amorphouspy_api.app import app +from amorphouspy_api.database import get_task_store from amorphouspy_api.models import MeltquenchRequest +from amorphouspy_api.routers.meltquench import get_meltquench_hash client = TestClient(app) -@pytest.fixture(autouse=True) -def _patch_worker(monkeypatch) -> None: - """Replace background worker with a no-op that writes a completed result. - - This keeps tests fully in-process and avoids spawning real child processes. 
- """ - from amorphouspy_api import app as app_module - - async def fake_worker(task_id: str, request: MeltquenchRequest) -> None: - from amorphouspy_api.database import get_task_store - - ts = get_task_store() - ts.set( - task_id, - { - "state": "complete", - "status": "Completed", - "result": { - "composition": "0.6SiO2-0.25CaO-0.15Al2O3", - "final_structure": create_mock_structure_dict(), - "mean_temperature": 302.3333333333, - "simulation_steps": 3, - "structural_analysis": create_mock_structural_analysis_data(), - }, - }, - ) - - monkeypatch.setattr(app_module, "_meltquench_worker", fake_worker) - - -class MockAtoms: - """Mock ASE Atoms-like object that can be serialized.""" - - def __init__(self, atoms_dict: dict[str, Any]) -> None: - """Initialize mock atoms with dictionary data.""" - self._dict = atoms_dict - - def get_masses(self) -> object: - """Return a mock that has a sum method.""" - - class MockMasses: - def sum(self) -> int: - return 1000 # mock mass - - return MockMasses() - - def __str__(self) -> str: - """Return string representation of mock atoms.""" - return "Mock ASE structure with 100 atoms" - - def __getstate__(self) -> dict[str, Any]: - """Return a fully serializable dictionary - avoid any ASE objects.""" - return { - "numbers": self._dict["numbers"], - "positions": self._dict["positions"], - "cell": self._dict["cell"], # Keep as nested list, not Cell object - "pbc": self._dict["pbc"], - } - - def __setstate__(self, state: dict[str, Any]) -> None: - """Restore state from serialized dictionary.""" - self._dict = state +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- def create_mock_structure_dict() -> dict[str, Any]: @@ -92,71 +41,75 @@ def create_mock_structural_analysis_data() -> dict[str, Any]: return { "density": 2.5, "coordination": {"oxygen": {}, "formers": {}, "modifiers": {}}, - "network": {"Qn_distribution": {}, 
"Qn_distribution_partial": {}, "connectivity": 0.0}, + "network": { + "Qn_distribution": {}, + "Qn_distribution_partial": {}, + "connectivity": 0.0, + }, "distributions": {"bond_angles": {}, "rings": {}}, "rdfs": {"r": [], "rdfs": {}, "cumulative_coordination": {}}, "elements": {"formers": [], "modifiers": [], "cutoffs": {}}, } -def create_mock_result_data() -> dict[str, Any]: - """Create mock simulation result data.""" +def create_mock_result( + composition: str = "0.6SiO2-0.25CaO-0.15Al2O3", +) -> dict[str, Any]: + """Create a complete mock meltquench result.""" return { - "structure": create_mock_structure_dict(), - "result": { - "volume": [1000, 1000, 1000], # cm³ - "temperature": [300, 305, 302], # K - "steps": [1, 2, 3], - }, + "composition": composition, + "final_structure": create_mock_structure_dict(), + "mean_temperature": 302.3333333333, + "simulation_steps": 3, + "structural_analysis": create_mock_structural_analysis_data(), } -def setup_common_mocks( - mock_project: MagicMock, - mock_get_structure_dict: MagicMock, - mock_get_ase_structure: MagicMock, - mock_generate_potential: MagicMock, - mock_melt_quench_simulation: MagicMock, - mock_analyze_structure: MagicMock, +def insert_completed_task( + task_id: str, + *, + request_hash: str = "test-hash", + composition: str = "0.6SiO2-0.25CaO-0.15Al2O3", + request_data: dict[str, Any] | None = None, ) -> None: - """Set up common mock objects for meltquench tests.""" - # Mock the simulation components - mock_atoms_dict = {"atoms": [{"element": "Si", "position": [0, 0, 0]}] * 100} - mock_get_structure_dict.return_value.pull.return_value = mock_atoms_dict - - # Create mock structure - mock_structure_dict = create_mock_structure_dict() - mock_structure = MockAtoms(mock_structure_dict) - mock_get_ase_structure.return_value = mock_structure - - # Mock potential - mock_potential = "mock_potential_content" - mock_generate_potential.return_value = mock_potential - - # Mock structural analysis - 
mock_analyze_structure.return_value.pull.return_value = create_mock_structural_analysis_data() - - # Mock simulation result - mock_melt_quench_simulation.return_value.pull.return_value = create_mock_result_data() - - -def wait_for_task_completion(task_id: str, max_wait: float = 10.0) -> dict[str, Any]: - """Wait for a task to complete and return the final check data.""" - waited = 0.0 - while waited < max_wait: - check_response = client.get(f"/check/{task_id}") - assert check_response.status_code == 200 - check_data = check_response.json() - - if check_data["state"] == "complete": - return check_data - if check_data["state"] == "error": - pytest.fail(f"Simulation failed: {check_data.get('error')}") + """Insert a completed task into the task store.""" + get_task_store().set( + task_id, + { + "state": "complete", + "request_hash": request_hash, + "request_data": request_data, + "result": create_mock_result(composition), + }, + ) - time.sleep(0.5) - waited += 0.5 - pytest.fail(f"Task {task_id} did not complete within {max_wait} seconds") +def insert_running_task( + task_id: str, + *, + request_hash: str = "test-hash-running", + request_data: dict[str, Any] | None = None, +) -> None: + """Insert a running task into the task store.""" + if request_data is None: + request_data = { + "components": ["SiO2"], + "values": [100.0], + "unit": "wt", + "n_atoms": 3, + "potential_type": "pmmcs", + "heating_rate": 1e12, + "cooling_rate": 1e12, + "n_print": 100, + } + get_task_store().set( + task_id, + { + "state": "running", + "request_hash": request_hash, + "request_data": request_data, + }, + ) def validate_result_structure(result: dict[str, Any]) -> None: @@ -167,9 +120,7 @@ def validate_result_structure(result: dict[str, Any]) -> None: assert "structural_analysis" in result assert "simulation_steps" in result - # Validate numerical values assert isinstance(result["mean_temperature"], float) - # Handle both dict and StructureData object cases if 
isinstance(result["structural_analysis"], dict): assert isinstance(result["structural_analysis"]["density"], float) else: @@ -177,42 +128,102 @@ def validate_result_structure(result: dict[str, Any]) -> None: assert isinstance(result["simulation_steps"], int) -def test_submit_meltquench_and_check() -> None: - """Test the complete meltquench workflow without real background processes.""" - # Submit meltquench task - payload = {"components": ["SiO2", "CaO", "Al2O3"], "values": [60.0, 25.0, 15.0], "unit": "wt"} - response = client.post("/submit/meltquench", json=payload) +# --------------------------------------------------------------------------- +# /submit/meltquench tests +# --------------------------------------------------------------------------- + + +@contextmanager +def _mock_executor_context() -> Generator[SimpleNamespace, None, None]: + """Context manager that patches get_executor and run_meltquench_workflow.""" + mock_future = MagicMock() + mock_future.result.return_value = create_mock_result() + mock_future.done.return_value = True + mock_future.exception.return_value = None + + with ( + patch("amorphouspy_api.routers.meltquench.get_executor") as mock_get_exe, + patch( + "amorphouspy_api.routers.meltquench.run_meltquench_workflow", + return_value=mock_future, + ) as mock_workflow, + ): + mock_get_exe.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_get_exe.return_value.__exit__ = MagicMock(return_value=False) + yield SimpleNamespace(mock_workflow=mock_workflow, mock_future=mock_future) + + +def test_submit_meltquench_new_task() -> None: + """Test submitting a new task runs the executor and returns completed.""" + with _mock_executor_context(): + payload = { + "components": ["SiO2", "CaO", "Al2O3"], + "values": [60.0, 25.0, 15.0], + "unit": "wt", + } + response = client.post("/submit/meltquench", json=payload) + assert response.status_code == 200 data = response.json() assert "task_id" in data - assert "status" in data - - # Handle cached 
results - if data["status"] == "completed_from_cache": - assert "result" in data - validate_result_structure(data["result"]) - return + assert data["status"] == "completed" + assert data["result"] is not None + validate_result_structure(data["result"]) + + # Verify task was stored as complete + stored = get_task_store().get(data["task_id"]) + assert stored is not None + assert stored["state"] == "complete" + + +def test_submit_meltquench_returns_cached() -> None: + """Test that submitting a duplicate request returns the cached result.""" + # Pre-insert a completed task with a known hash + request = MeltquenchRequest( + components=["SiO2", "BaO"], + values=[80.0, 20.0], + unit="wt", + ) + request_hash = get_meltquench_hash(request) + insert_completed_task("cached-task-1", request_hash=request_hash, composition="0.8SiO2-0.2BaO") - # Wait for completion and validate - assert data["status"] == "started" - check_data = wait_for_task_completion(data["task_id"]) + # Submit with the same parameters — should return cached + response = client.post("/submit/meltquench", json=request.model_dump()) - assert check_data["task_id"] == data["task_id"] - assert check_data["state"] == "complete" - assert check_data["result"] is not None + assert response.status_code == 200 + data = response.json() + assert data["status"] == "completed_from_cache" + assert data["task_id"] == "cached-task-1" - # Validate the result structure - validate_result_structure(check_data["result"]) - # Validate composition format - assert check_data["result"]["composition"] == "0.6SiO2-0.25CaO-0.15Al2O3" +def test_submit_meltquench_stores_request_data() -> None: + """Test that submitting a new task stores request_data.""" + with _mock_executor_context(): + payload = { + "components": ["SiO2", "ZnO"], + "values": [90.0, 10.0], + "unit": "wt", + } + response = client.post("/submit/meltquench", json=payload) + assert response.status_code == 200 + stored = get_task_store().get(response.json()["task_id"]) + assert 
stored is not None + assert stored["request_data"]["components"] == ["SiO2", "ZnO"] + assert stored["request_data"]["values"] == [90.0, 10.0] + + +def test_submit_meltquench_executor_error_returns_500() -> None: + """Test that an executor error returns HTTP 500.""" + with patch("amorphouspy_api.routers.meltquench.get_executor", side_effect=RuntimeError): + payload = { + "components": ["SiO2", "TiO2"], + "values": [95.0, 5.0], + "unit": "wt", + } + response = client.post("/submit/meltquench", json=payload) -def test_check_nonexistent_task() -> None: - """Test checking a task that doesn't exist.""" - response = client.get("/check/nonexistent-task-id") - assert response.status_code == 404 - assert "Task not found" in response.json()["detail"] + assert response.status_code == 500 def test_invalid_payload() -> None: @@ -223,169 +234,148 @@ def test_invalid_payload() -> None: "unit": "wt", } response = client.post("/submit/meltquench", json=payload) - assert response.status_code == 422 # Validation error + assert response.status_code == 422 -def test_root_redirect() -> None: - """Test that root redirects to docs.""" - # FastAPI TestClient follows redirects by default, so we need to check differently - # We can verify that accessing "/" eventually serves docs content - response = client.get("/") +# --------------------------------------------------------------------------- +# /check/{task_id} tests +# --------------------------------------------------------------------------- + + +def test_check_completed_task() -> None: + """Test that checking a completed task returns the stored result.""" + insert_completed_task("check-complete-1", request_hash="hash-check-1") + + response = client.get("/check/check-complete-1") assert response.status_code == 200 - # The response should contain swagger/docs content when redirected - assert "swagger" in response.text.lower() or "openapi" in response.text.lower() + data = response.json() + assert data["status"] == "completed" + assert 
data["result"] is not None + validate_result_structure(data["result"]) -def validate_cached_result(data: dict[str, Any] | None) -> None: - """Validate cached result structure if it exists.""" - if data is not None: - assert "composition" in data - assert "structural_analysis" in data - # Handle both dict and StructureData object cases - if isinstance(data["structural_analysis"], dict): - assert "density" in data["structural_analysis"] - else: - assert hasattr(data["structural_analysis"], "density") - assert "final_structure" in data - assert "mean_temperature" in data - assert "simulation_steps" in data +def test_check_errored_task() -> None: + """Test that checking an errored task returns the error.""" + get_task_store().set( + "check-error-1", + { + "state": "error", + "request_hash": "hash-check-error-1", + "error": "LAMMPS crashed", + }, + ) + response = client.get("/check/check-error-1") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "error" + assert data["error"] == "LAMMPS crashed" -def test_check_cached_result_found() -> None: - """Test checking for cached results with a specific composition.""" - payload = { - "components": ["SiO2", "K2O"], # Different from other tests - "values": [85.0, 15.0], - "unit": "wt", - } - response = client.post("/cache/meltquench", json=payload) +def test_check_nonexistent_task() -> None: + """Test checking a task that doesn't exist.""" + response = client.get("/check/nonexistent-task-id") + assert response.status_code == 404 + assert "Task not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# /cache/meltquench tests +# --------------------------------------------------------------------------- + + +def test_cache_hit() -> None: + """Test cache endpoint returns a result when one exists.""" + request = MeltquenchRequest( + components=["SiO2", "K2O"], + values=[85.0, 15.0], + unit="wt", + ) + request_hash = 
get_meltquench_hash(request) + insert_completed_task("cache-hit-1", request_hash=request_hash, composition="0.85SiO2-0.15K2O") + + response = client.post("/cache/meltquench", json=request.model_dump()) assert response.status_code == 200 - validate_cached_result(response.json()) + data = response.json() + assert data is not None + assert data["composition"] == "0.85SiO2-0.15K2O" -def test_check_cached_result_not_found() -> None: - """Test checking for cached results with another unique composition.""" +def test_cache_miss() -> None: + """Test cache endpoint returns null when no result exists.""" payload = { - "components": ["SiO2", "Li2O"], # Different from other tests + "components": ["SiO2", "Li2O"], "values": [90.0, 10.0], "unit": "wt", } - response = client.post("/cache/meltquench", json=payload) assert response.status_code == 200 - validate_cached_result(response.json()) + assert response.json() is None -def test_caching_behavior() -> None: - """Test that caching actually works by submitting and then checking cache.""" - unique_payload = { - "components": ["SiO2", "MgO"], - "values": [70.0, 30.0], - "unit": "wt", - "heating_rate": int(1e15), # Fast for testing - "cooling_rate": int(1e15), - "n_print": 100, - } - - # Check cache first - cache_response = client.post("/cache/meltquench", json=unique_payload) - assert cache_response.status_code == 200 - - # Submit the simulation (will be mocked by autouse fixture) - submit_response = client.post("/submit/meltquench", json=unique_payload) - assert submit_response.status_code == 200 - submit_data = submit_response.json() - - # Should either start a new task or return cached result - assert "task_id" in submit_data - assert "status" in submit_data - assert submit_data["status"] in ["started", "completed_from_cache"] +# --------------------------------------------------------------------------- +# Visualization tests +# --------------------------------------------------------------------------- 
@patch("amorphouspy.workflows.structural_analysis.plot_analysis_results_plotly") def test_visualization_endpoint(mock_plot_analysis_results_plotly: MagicMock) -> None: - """Test the visualization endpoint with mocked plot generation.""" - # Create a mock figure for the plot + """Test the visualization endpoint returns HTML for a completed task.""" mock_fig = MagicMock() - mock_fig.to_dict.return_value = {"data": [], "layout": {}} # Mock Plotly figure dict + mock_fig.to_dict.return_value = {"data": [], "layout": {}} mock_plot_analysis_results_plotly.return_value = mock_fig - # Submit task with unique payload to avoid caching - unique_suffix = str(int(time.time() * 1000)) # millisecond timestamp - payload = { - "components": ["SiO2", "Na2O"], - "values": [75.0, 25.0], - "unit": "wt", - "heating_rate": int(unique_suffix[-6:]), # Use last 6 digits - } - - submit_response = client.post("/submit/meltquench", json=payload) - assert submit_response.status_code == 200 - submit_data = submit_response.json() - task_id = submit_data["task_id"] - - # Overwrite task result directly to tailor the visualization data - from amorphouspy_api.database import get_task_store - + task_id = f"viz-task-{int(time.time() * 1000)}" get_task_store().set( task_id, { "state": "complete", - "status": "Completed", + "request_hash": f"viz-hash-{task_id}", "result": { - "composition": "0.75SiO2-0.25Na2O", - "final_structure": create_mock_structure_dict(), - "mean_temperature": 300.0, - "simulation_steps": 3, - "structural_analysis": {**create_mock_structural_analysis_data(), "density": 2.65}, + **create_mock_result("0.75SiO2-0.25Na2O"), + "structural_analysis": { + **create_mock_structural_analysis_data(), + "density": 2.65, + }, }, }, ) - # Test the visualization endpoint - viz_response = client.get(f"/visualize/meltquench/{task_id}") - assert viz_response.status_code == 200 - - # Check that we get HTML content - assert viz_response.headers["content-type"] == "text/html; charset=utf-8" - 
html_content = viz_response.text + response = client.get(f"/visualize/meltquench/{task_id}") + assert response.status_code == 200 + assert response.headers["content-type"] == "text/html; charset=utf-8" - # Verify HTML contains expected elements + html_content = response.text assert "Melt-Quench Simulation Results" in html_content assert task_id in html_content assert "plotlyData" in html_content or "plotly-div" in html_content - - # Verify the plot function was called mock_plot_analysis_results_plotly.assert_called_once() -def test_visualization_endpoint_task_not_found() -> None: +def test_visualization_task_not_found() -> None: """Test visualization endpoint with non-existent task.""" response = client.get("/visualize/meltquench/nonexistent-task") assert response.status_code == 404 assert "Task not found" in response.json()["detail"] -def test_visualization_endpoint_incomplete_task() -> None: - """Test visualization endpoint with incomplete task.""" - # Create a task manually in the database with 'running' state - from amorphouspy_api.app import get_meltquench_hash - from amorphouspy_api.database import get_task_store - from amorphouspy_api.models import MeltquenchRequest +def test_visualization_incomplete_task() -> None: + """Test visualization endpoint with an incomplete task.""" + task_id = "viz-incomplete-task" + insert_running_task(task_id, request_hash="viz-incomplete-hash") - task_store = get_task_store() - fake_task_id = "test-incomplete-task-123" + response = client.get(f"/visualize/meltquench/{task_id}") + assert response.status_code == 400 + assert "not completed yet" in response.json()["detail"] - # Create a proper request to generate hash - request_data = {"components": ["SiO2"], "values": [100.0], "unit": "wt"} - request = MeltquenchRequest(**request_data) - request_hash = get_meltquench_hash(request) - # Add incomplete task to database - task_store.set(fake_task_id, {"state": "running", "request_data": request_data, "request_hash": request_hash}) +# 
--------------------------------------------------------------------------- +# General tests +# --------------------------------------------------------------------------- - # Try to visualize incomplete task - viz_response = client.get(f"/visualize/meltquench/{fake_task_id}") - assert viz_response.status_code == 400 - assert "not completed yet" in viz_response.json()["detail"] + +def test_root_redirect() -> None: + """Test that root redirects to docs.""" + response = client.get("/") + assert response.status_code == 200 + assert "swagger" in response.text.lower() or "openapi" in response.text.lower() diff --git a/amorphouspy_api/src/tests/test_meltquench_integration.py b/amorphouspy_api/src/tests/test_meltquench_integration.py index 1650c2cf..a822b7d3 100644 --- a/amorphouspy_api/src/tests/test_meltquench_integration.py +++ b/amorphouspy_api/src/tests/test_meltquench_integration.py @@ -1,6 +1,7 @@ """Integration tests for meltquench API with live server.""" import logging +import os import time import pytest @@ -29,15 +30,20 @@ def is_api_server_running(url: str) -> bool: @pytest.mark.integration def test_meltquench_api_integration() -> None: """Full integration test for the meltquench API using a running server. - Requires: API server running in main thread with amorphouspy_INTEGRATION=1 + Requires: API server running in main thread with AMORPHOUSPY_INTEGRATION=1 Example: - amorphouspy_INTEGRATION=1 uvicorn amorphouspy_api.src.amorphouspy_api.app:app --port 8002 + AMORPHOUSPY_INTEGRATION=1 uvicorn amorphouspy_api.src.amorphouspy_api.app:app --port 8002 pytest -m integration. 
""" API_URL = "http://127.0.0.1:8002" root_url = f"{API_URL}/" logger.info("Checking API server status...") if not is_api_server_running(root_url): + if os.environ.get("AMORPHOUSPY_INTEGRATION"): + pytest.fail( + "API server not running at http://127.0.0.1:8002/ " + "but AMORPHOUSPY_INTEGRATION is set — the server should have started" + ) pytest.skip("API server not running at http://127.0.0.1:8002/") # Use faster rates for integration testing @@ -71,20 +77,22 @@ def test_meltquench_api_integration() -> None: r = requests.get(f"{API_URL}/check/{task_id}", timeout=30) r.raise_for_status() check_data = r.json() - state = check_data["state"] - logger.info("Polling: state=%s", state) - if state == "complete": + status = check_data["status"] + logger.info("Polling: status=%s", status) + if status == "completed": logger.info("Result: %s", check_data["result"]) result = check_data["result"] break - if state == "error": + if status == "error": logger.error("Meltquench task errored: %s", check_data.get("error")) pytest.fail(f"Meltquench task errored: {check_data.get('error')}") if time.time() - start > timeout: logger.error( - "Timeout: Meltquench task did not complete within %s seconds. Last state: %s", timeout, state + "Timeout: Meltquench task did not complete within %s seconds. Last status: %s", + timeout, + status, ) - pytest.fail(f"Meltquench task did not complete within {timeout} seconds. Last state: {state}") + pytest.fail(f"Meltquench task did not complete within {timeout} seconds. 
Last status: {status}") time.sleep(poll_interval) assert result is not None @@ -137,4 +145,7 @@ def test_meltquench_api_integration() -> None: logger.info("✓ Temperature: %.1f K", temp) logger.info("✓ Density: %.2f g/cm³", density) logger.info("✓ Steps: %s", steps) - logger.info("✓ Structural analysis: %s", {k: v for k, v in structural_analysis.items() if k != "error"}) + logger.info( + "✓ Structural analysis: %s", + {k: v for k, v in structural_analysis.items() if k != "error"}, + ) diff --git a/environment.yml b/environment.yml index ed75f161..2bff412f 100644 --- a/environment.yml +++ b/environment.yml @@ -1,23 +1,25 @@ name: amorphouspy channels: -- conda-forge + - conda-forge dependencies: -- python =3.13 -- ase >=3.25.0 -- cryptography =45.0.7 -- executorlib =1.7.4 -- hatchling -- jupyter -- lammps =2024.08.29=*_openmpi_* -- networkx ~=3.4 -- pandas =2.3.3 -- numpy =2.3.3 -- pygraphviz =1.14 -- lammpsparser =0.0.1 -- pymatgen =2025.10.07 -- scipy =1.16.2 -- sqlalchemy -- numba -- uvicorn -- fastapi-mcp =0.4.0 -- sovapy =0.8.3 + - python =3.13 + - ase >=3.25.0 + - cryptography =45.0.7 + - executorlib >=1.8.2 + - flux-core >=0.81.0 + - pysqa >=0.3.4 + - hatchling + - jupyter + - lammps =2024.08.29=*_openmpi_* + - networkx ~=3.4 + - pandas =2.3.3 + - numpy =2.3.3 + - pygraphviz =1.14 + - lammpsparser =0.0.1 + - pymatgen =2025.10.07 + - scipy =1.16.2 + - sqlalchemy + - numba + - uvicorn + - fastapi-mcp =0.4.0 + - sovapy =0.8.3