diff --git a/README.md b/README.md
index 5fe96c3e..3454e1f5 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,7 @@ As much as we can, we follow the [CleanRL](https://github.com/vwxyzjn/cleanrl) p
[pixel_cnn_pp_2d](engiopt/pixel_cnn_pp_2d) | Inverse Design | 2D | ✅ | PixelCNN++ Autoregressive Model
## Dashboards
-The integration with WandB allows us to access live dashboards of our runs (on the cluster or not). We also upload the trained models there. You can access some of our runs at https://wandb.ai/engibench/engiopt.
+The integration with WandB allows us to access live dashboards of our runs (on the cluster or not). New checkpoint packages are stored on the Hugging Face Hub by default, while WandB keeps experiment tracking, metadata, and links back to the canonical checkpoint location. Historical WandB model artifacts remain supported for backward compatibility. You can access some of our runs at https://wandb.ai/engibench/engiopt.
@@ -66,6 +66,11 @@ First, if you want to use weights and biases, you need to set the `WANDB_API_KEY
wandb login
```
+If you want to save or load checkpoints from the Hugging Face Hub, make sure your environment is authenticated there as well:
+```
+huggingface-cli login
+```
+
### Inverse design
Usually, we provide two scripts per algorithm: one to train the model, and one to evaluate it.
@@ -77,6 +82,20 @@ python engiopt/cgan_cnn_2d/cgan_cnn_2d.py --problem-id "beams2d" --track --wandb
This will run a CGAN 2D using CNN model on the beams2d problem. `--track` will track the run on wandb, `--wandb-entity None` will use the default wandb entity, `--save-model` will save the model, `--n-epochs 200` will run for 200 epochs, and `--seed 1` will set the random seed.
+`--save-model` now stores a self-contained checkpoint package on the Hugging Face Hub. The default backend is:
+```
+--checkpoint-backend hf
+```
+You can still force legacy or hybrid behavior when needed:
+```
+--checkpoint-backend wandb
+--checkpoint-backend both
+--checkpoint-backend none
+```
+
+All HF-backed checkpoint packages contain the model files together with `run_config.json` and `metadata.json`, so evaluation does not depend on live WandB run config state.
+When W&B tracking is active, the HF package metadata also records the originating W&B run identity. In the other direction, the W&B run summary records the HF repo, the seed-based convenience path, the exact uploaded HF revision, and an immutable run-specific HF package path.
+
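+If you download a package (for example via `huggingface_hub.snapshot_download`), you can inspect it directly. A minimal sketch, assuming a hypothetical local package directory:
+```
+import json
+from pathlib import Path
+
+pkg = Path("/path/to/package")  # hypothetical: a downloaded checkpoint package
+run_config = json.loads((pkg / "run_config.json").read_text())
+metadata = json.loads((pkg / "metadata.json").read_text())
+print(run_config["seed"], metadata.get("wandb_run_id"))
+```
+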
For reproducible debugging runs, you can additionally enable strict deterministic mode:
```
python engiopt/cgan_cnn_2d/cgan_cnn_2d.py --problem-id "beams2d" --seed 1 --strict-determinism
@@ -95,7 +114,30 @@ Then you can restore a trained model and evaluate it:
```
python engiopt/cgan_cnn_2d/evaluate_cgan_cnn_2d.py --problem-id "beams2d" --wandb-entity None --seed 1 --n-samples 10
```
-This will generate 10 designs from the trained model and run some [metrics](https://github.com/IDEALLab/EngiOpt/blob/main/engiopt/metrics.py) on them. This is what we used to generate the results in the paper. This by default will pull the model from wandb. It is possible to restore a model from a local file but is not currently supported.
+This will generate 10 designs from the trained model and run some [metrics](https://github.com/IDEALLab/EngiOpt/blob/main/engiopt/metrics.py) on them. This is what we used to generate the results in the paper.
+
+Evaluation now defaults to:
+```
+--model-source auto
+```
+In `auto` mode, EngiOpt tries to resolve checkpoints in this order:
+1. Hugging Face package for the model family, problem, and seed
+2. Legacy WandB model artifact
+3. Explicit local checkpoint package directory if you pass `--local-model-dir`
+
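+Under the hood this is the shared `resolve_named_checkpoint` helper, which the evaluation scripts call with their CLI arguments. A minimal sketch with illustrative values (the W&B project name is an assumption):
+```
+from engiopt.checkpoint_store import resolve_named_checkpoint
+
+resolved = resolve_named_checkpoint(
+    model_source="auto",
+    problem_id="beams2d",
+    algo="cgan_cnn_2d",
+    seed=1,
+    hf_entity="IDEALLab",
+    hf_repo_prefix="engiopt",
+    required_files=["generator.pth"],
+    wandb_project="engiopt",  # assumption: your W&B project name
+    wandb_entity=None,
+    wandb_artifact_names={"generator.pth": "beams2d_cgan_cnn_2d_generator"},
+)
+print(resolved.source, resolved.files["generator.pth"])
+```
+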
+For new HF-backed runs, EngiOpt maintains both:
+- a seed-based convenience path such as `beams2d/seed_1`
+- an immutable run-specific path such as `beams2d/seed_1/run_<wandb_run_id>`
+
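+Both paths come from the package-path builders in `engiopt/checkpoint_store.py`; the run id below is illustrative:
+```
+from engiopt.checkpoint_store import build_hf_package_path, build_hf_run_package_path
+
+path = build_hf_package_path("beams2d", seed=1)  # "beams2d/seed_1"
+run_path = build_hf_run_package_path(path, "abc123xy")  # "beams2d/seed_1/run_abc123xy"
+```
+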
+You can force legacy WandB loading for historical runs:
+```
+python engiopt/cgan_cnn_2d/evaluate_cgan_cnn_2d.py --problem-id "beams2d" --seed 1 --model-source wandb
+```
+
+You can also point evaluation at a local package directory:
+```
+python engiopt/cgan_cnn_2d/evaluate_cgan_cnn_2d.py --problem-id "beams2d" --seed 1 --model-source local --local-model-dir /path/to/package
+```
### Surrogate model
@@ -107,6 +149,13 @@ The current surrogate model comprises several steps:
See this [notebook](https://github.com/IDEALLab/EngiOpt/blob/main/engiopt/surrogate_model/case_study_pe_notebook.ipynb) for an example.
+Surrogate-model optimization paths now use the same checkpoint abstraction. For example, the power-electronics optimizer can consume:
+* legacy WandB artifact refs
+* HF package refs such as `hf://IDEALLab/engiopt-mlp-tabular-only/power_electronics/DcGain/seed_42`
+* local checkpoint package directories
+
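+These references go through the shared `resolve_checkpoint_reference` helper. A minimal sketch, assuming the package's weight file is named `model.pth` (the actual file name depends on the surrogate):
+```
+from engiopt.checkpoint_store import resolve_checkpoint_reference
+
+resolved = resolve_checkpoint_reference(
+    model_source="auto",  # infers "hf" from the hf:// prefix
+    model_ref="hf://IDEALLab/engiopt-mlp-tabular-only/power_electronics/DcGain/seed_42",
+    required_files=["model.pth"],  # assumption: illustrative file name
+)
+ckpt_path = resolved.files["model.pth"]
+```
+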
+For guidance on later migrating historical checkpoint subsets from WandB to the IDEALLab HF organization, see [docs/checkpoint_migration_playbook.md](docs/checkpoint_migration_playbook.md).
+
diff --git a/docs/checkpoint_migration_playbook.md b/docs/checkpoint_migration_playbook.md
new file mode 100644
index 00000000..243d2a6d
--- /dev/null
+++ b/docs/checkpoint_migration_playbook.md
@@ -0,0 +1,96 @@
+# Checkpoint Migration Playbook
+
+This playbook describes how to migrate historical EngiOpt checkpoints from Weights & Biases (W&B) to the Hugging Face Hub (HF) without breaking backward compatibility.
+
+This guidance covers only phase 1 of the HF checkpoint backend rollout.
+
+What this phase does not do:
+- delete historical W&B artifacts
+- mutate historical W&B artifacts in place
+- assume every historical run should be migrated
+
+## Goals
+
+1. Stop adding new long-lived checkpoint storage pressure in W&B by making HF the default backend for newly saved models.
+2. Keep historical W&B checkpoints loadable while HF-backed packages roll out.
+3. Create a safe path to reclaim W&B storage later, after validation.
+
+## Recommended Scope
+
+Start with the official or release subset of checkpoints, not the full historical backlog.
+
+Good early candidates:
+- checkpoints referenced in papers, benchmarks, or public notebooks
+- checkpoints linked from README examples
+- checkpoints used by downstream evaluation scripts or case studies
+
+## Migration Package Format
+
+Each migrated HF checkpoint package should be self-contained and include:
+- the original model weight file or files
+- `run_config.json`
+- `metadata.json`
+
+`metadata.json` should record at least:
+- `problem_id`
+- `algo`
+- `seed`
+- original W&B project, run id, and artifact names
+- HF repo id and package path
+- primary file list
+
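+A small helper can make this check mechanical. A sketch, assuming the key names written by `engiopt/checkpoint_store.py` (W&B lineage keys are only present for tracked runs):
+```
+import json
+
+REQUIRED_KEYS = {"problem_id", "algo", "seed", "primary_files", "hf_repo_id", "hf_package_path"}
+
+
+def missing_metadata_keys(path: str) -> list[str]:
+    """Return the required keys absent from a migrated metadata.json."""
+    with open(path, encoding="utf-8") as handle:
+        payload = json.load(handle)
+    return sorted(REQUIRED_KEYS - payload.keys())
+```
+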
+## Recommended Procedure
+
+1. Inventory the target artifacts.
+ Create a manifest with artifact name, aliases, run id, problem, algorithm, seed, expected files, and whether the artifact is part of the official/release subset.
+
+2. Freeze the migration manifest before uploads.
+ Use the manifest as the source of truth so the migration is reproducible and reviewable.
+
+3. Download the original W&B artifacts.
+ For each target artifact, download the stored files and extract the associated W&B run config needed to rebuild the model outside W&B.
+
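+   A download sketch using the W&B public API; the artifact path is illustrative:
+```
+import wandb
+
+api = wandb.Api()
+artifact = api.artifact("engibench/engiopt/beams2d_cgan_cnn_2d_generator:seed_1", type="model")
+artifact_dir = artifact.download()
+run = artifact.logged_by()  # may be None if the originating run was deleted
+run_config = dict(run.config)  # serialize this as run_config.json
+```
+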
+4. Build the HF package.
+ Upload the weights together with `run_config.json` and `metadata.json` into the per-model-family HF repo under the deterministic package path:
+ `problem_id[/extra_parts]/seed_<seed>`
+
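+   An upload sketch using `huggingface_hub`; the repo id and package path are illustrative:
+```
+from huggingface_hub import HfApi
+
+api = HfApi()
+repo_id = "IDEALLab/engiopt-cgan-cnn-2d"  # illustrative model-family repo
+api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
+api.upload_folder(
+    repo_id=repo_id,
+    repo_type="model",
+    folder_path="staging/beams2d/seed_1",  # weights + run_config.json + metadata.json
+    path_in_repo="beams2d/seed_1",
+    commit_message="Migrate beams2d cgan_cnn_2d seed 1 from W&B",
+)
+```
+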
+5. Record the mapping.
+ For every migrated checkpoint, store a durable mapping from the W&B artifact alias to the HF repo id, revision, and package path.
+
+6. Validate restores before cleanup.
+ Test a representative sample end-to-end with EngiOpt’s evaluation or surrogate-model loading path and confirm the HF package reproduces the expected model restore behavior.
+
+7. Add pointers back into W&B metadata.
+ Once validated, update run metadata, summaries, or notes so the W&B run points to the canonical HF checkpoint location.
+
+8. Keep an overlap period.
+ Retain both HF and W&B copies long enough to verify that downstream users and scripts are not relying on the old blob storage path.
+
+9. Only then define deletion policy.
+ Deletion or retention changes should happen in a separate maintenance pass, using the manifest as the authoritative record.
+
+## Validation Checklist
+
+Before any cleanup is considered for a migrated checkpoint:
+
+1. The HF package contains all required weight files.
+2. `run_config.json` is present and sufficient to rebuild the model.
+3. `metadata.json` is present and points back to the original W&B lineage.
+4. A real EngiOpt load path has succeeded against the HF package.
+5. The corresponding W&B run contains the HF pointer or mapping information.
+
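+For item 4, forcing HF-only resolution keeps the W&B fallback from masking an incomplete package. A sketch with illustrative values:
+```
+from engiopt.checkpoint_store import resolve_named_checkpoint
+
+# model_source="hf" raises instead of falling back to W&B.
+resolved = resolve_named_checkpoint(
+    model_source="hf",
+    problem_id="beams2d",  # illustrative
+    algo="cgan_cnn_2d",
+    seed=1,
+    hf_entity="IDEALLab",
+    hf_repo_prefix="engiopt",
+    required_files=["generator.pth"],
+    wandb_project="engiopt",  # unused in hf mode but still required
+    wandb_entity=None,
+    wandb_artifact_names={},
+)
+assert "generator.pth" in resolved.files
+```
+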
+## Cleanup Guidance for Later
+
+When the project is ready to reclaim W&B storage:
+
+1. Start with official checkpoints only.
+2. Delete in small batches, not all at once.
+3. Confirm the manifest entry is complete before each deletion.
+4. Re-run a small restore audit after each batch.
+5. Keep at least one validated overlap window where both copies coexist.
+
+## Notes
+
+- W&B remains part of the lineage story even after HF becomes the canonical checkpoint host.
+- HF should be treated as the long-lived storage backend for public or durable checkpoints.
+- Backward compatibility matters more than immediate cleanup.
diff --git a/engiopt/cgan_1d/cgan_1d.py b/engiopt/cgan_1d/cgan_1d.py
index 211996e1..d8394a6e 100644
--- a/engiopt/cgan_1d/cgan_1d.py
+++ b/engiopt/cgan_1d/cgan_1d.py
@@ -22,6 +22,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -64,6 +66,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -398,13 +408,21 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]:
th.save(ckpt_gen, "generator.pth")
th.save(ckpt_disc, "discriminator.pth")
- if args.track:
- artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model")
- artifact_gen.add_file("generator.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model")
- artifact_disc.add_file("discriminator.pth")
-
- wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"generator.pth": "generator.pth", "discriminator.pth": "discriminator.pth"},
+ run_config=vars(args),
+ primary_files=["generator.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_generator": "generator.pth",
+ f"{args.problem_id}_{args.algo}_discriminator": "discriminator.pth",
+ },
+ )
wandb.finish()
diff --git a/engiopt/cgan_1d/evaluate_cgan_1d.py b/engiopt/cgan_1d/evaluate_cgan_1d.py
index a8beb031..9d74cd4f 100644
--- a/engiopt/cgan_1d/evaluate_cgan_1d.py
+++ b/engiopt/cgan_1d/evaluate_cgan_1d.py
@@ -15,8 +15,9 @@
from engiopt import metrics
from engiopt.cgan_1d.cgan_1d import Generator
from engiopt.cgan_1d.cgan_1d import prepare_data
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
-import wandb
@dataclasses.dataclass
@@ -31,6 +32,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 10
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -74,30 +83,28 @@ class Args:
)
### Set Up Generator ###
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_cgan_1d_generator:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_cgan_1d_generator:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="cgan_1d",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["generator.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"generator.pth": f"{args.problem_id}_cgan_1d_generator"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- artifact_dir = artifact.download()
- ckpt_path = os.path.join(artifact_dir, "generator.pth")
+ ckpt_path = resolved.files["generator.pth"]
ckpt = th.load(ckpt_path, map_location=device)
_, conds_normalizer, design_normalizer = prepare_data(problem, device)
model = Generator(
- latent_dim=run.config["latent_dim"],
+ latent_dim=run_config["latent_dim"],
n_conds=len(problem.conditions_keys),
design_shape=design_shape,
design_normalizer=design_normalizer,
@@ -107,7 +114,7 @@ def __init__(self):
model.eval()
# Sample noise and generate designs
- z = th.randn((args.n_samples, run.config["latent_dim"]), device=device)
+ z = th.randn((args.n_samples, run_config["latent_dim"]), device=device)
gen_designs = model(z, conditions_tensor)
gen_designs_np = gen_designs.detach().cpu().numpy()
print(gen_designs_np.shape)
diff --git a/engiopt/cgan_2d/cgan_2d.py b/engiopt/cgan_2d/cgan_2d.py
index fb68207b..79ee4e53 100644
--- a/engiopt/cgan_2d/cgan_2d.py
+++ b/engiopt/cgan_2d/cgan_2d.py
@@ -19,6 +19,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -40,6 +42,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -317,13 +327,21 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]:
th.save(ckpt_gen, "generator.pth")
th.save(ckpt_disc, "discriminator.pth")
- if args.track:
- artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model")
- artifact_gen.add_file("generator.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model")
- artifact_disc.add_file("discriminator.pth")
-
- wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"generator.pth": "generator.pth", "discriminator.pth": "discriminator.pth"},
+ run_config=vars(args),
+ primary_files=["generator.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_generator": "generator.pth",
+ f"{args.problem_id}_{args.algo}_discriminator": "discriminator.pth",
+ },
+ )
wandb.finish()
diff --git a/engiopt/cgan_2d/evaluate_cgan_2d.py b/engiopt/cgan_2d/evaluate_cgan_2d.py
index c0c93d4d..aa6d7cba 100644
--- a/engiopt/cgan_2d/evaluate_cgan_2d.py
+++ b/engiopt/cgan_2d/evaluate_cgan_2d.py
@@ -13,8 +13,9 @@
from engiopt import metrics
from engiopt.cgan_2d.cgan_2d import Generator
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
-import wandb
@dataclasses.dataclass
@@ -29,6 +30,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name (if any)."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -66,28 +75,26 @@ class Args:
)
### Set Up Generator ###
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_cgan_2d_generator:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_cgan_2d_generator:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="cgan_2d",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["generator.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"generator.pth": f"{args.problem_id}_cgan_2d_generator"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- artifact_dir = artifact.download()
- ckpt_path = os.path.join(artifact_dir, "generator.pth")
+ ckpt_path = resolved.files["generator.pth"]
ckpt = th.load(ckpt_path, map_location=device)
model = Generator(
- latent_dim=run.config["latent_dim"],
+ latent_dim=run_config["latent_dim"],
n_conds=len(problem.conditions_keys),
design_shape=problem.design_space.shape,
).to(device)
@@ -95,7 +102,7 @@ def __init__(self):
model.eval()
# Sample noise and generate designs
- z = th.randn((args.n_samples, run.config["latent_dim"]), device=device)
+ z = th.randn((args.n_samples, run_config["latent_dim"]), device=device)
gen_designs = model(z, conditions_tensor)
gen_designs_np = gen_designs.detach().cpu().numpy()
gen_designs_np = np.clip(gen_designs_np, 1e-3, 1.0)
diff --git a/engiopt/cgan_bezier/cgan_bezier.py b/engiopt/cgan_bezier/cgan_bezier.py
index 053ad9f4..2e419686 100644
--- a/engiopt/cgan_bezier/cgan_bezier.py
+++ b/engiopt/cgan_bezier/cgan_bezier.py
@@ -18,11 +18,13 @@
from torch import nn
import torch.nn.functional as f
import tyro
-import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
+import wandb
if TYPE_CHECKING:
from collections.abc import Callable
@@ -46,6 +48,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 6
"""Random seed."""
@@ -640,14 +650,25 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
th.save(ckpt_gen, "bezier_generator.pth")
th.save(ckpt_disc, "bezier_discriminator.pth")
- if args.track:
- artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model")
- artifact_gen.add_file("bezier_generator.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model")
- artifact_disc.add_file("bezier_discriminator.pth")
-
- wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={
+ "bezier_generator.pth": "bezier_generator.pth",
+ "bezier_discriminator.pth": "bezier_discriminator.pth",
+ },
+ run_config=vars(args),
+ primary_files=["bezier_generator.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_generator": "bezier_generator.pth",
+ f"{args.problem_id}_{args.algo}_discriminator": "bezier_discriminator.pth",
+ },
+ )
if args.track:
wandb.finish()
diff --git a/engiopt/cgan_cnn_2d/cgan_cnn_2d.py b/engiopt/cgan_cnn_2d/cgan_cnn_2d.py
index 8da8c704..ef084124 100644
--- a/engiopt/cgan_cnn_2d/cgan_cnn_2d.py
+++ b/engiopt/cgan_cnn_2d/cgan_cnn_2d.py
@@ -18,6 +18,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -39,6 +41,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -413,13 +423,24 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]:
th.save(ckpt_gen, "generator.pth")
th.save(ckpt_disc, "discriminator.pth")
- if args.track:
- artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model")
- artifact_gen.add_file("generator.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model")
- artifact_disc.add_file("discriminator.pth")
-
- wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={
+ "generator.pth": "generator.pth",
+ "discriminator.pth": "discriminator.pth",
+ },
+ run_config=vars(args),
+ primary_files=["generator.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_generator": "generator.pth",
+ f"{args.problem_id}_{args.algo}_discriminator": "discriminator.pth",
+ },
+ )
wandb.finish()
diff --git a/engiopt/cgan_cnn_2d/evaluate_cgan_cnn_2d.py b/engiopt/cgan_cnn_2d/evaluate_cgan_cnn_2d.py
index 5e63e1e2..f9dfd5a3 100644
--- a/engiopt/cgan_cnn_2d/evaluate_cgan_cnn_2d.py
+++ b/engiopt/cgan_cnn_2d/evaluate_cgan_cnn_2d.py
@@ -13,8 +13,9 @@
from engiopt import metrics
from engiopt.cgan_cnn_2d.cgan_cnn_2d import Generator
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
-import wandb
@dataclasses.dataclass
@@ -29,6 +30,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -66,35 +75,34 @@ class Args:
### Set Up Generator ###
- # Restores the pytorch model from wandb
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_cgan_cnn_2d_generator:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_cgan_cnn_2d_generator:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
- artifact_dir = artifact.download()
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="cgan_cnn_2d",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["generator.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"generator.pth": f"{args.problem_id}_cgan_cnn_2d_generator"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- ckpt_path = os.path.join(artifact_dir, "generator.pth")
+ ckpt_path = resolved.files["generator.pth"]
ckpt = th.load(ckpt_path, map_location=th.device(device))
model = Generator(
- latent_dim=run.config["latent_dim"], n_conds=len(problem.conditions_keys), design_shape=problem.design_space.shape
+ latent_dim=run_config["latent_dim"],
+ n_conds=len(problem.conditions_keys),
+ design_shape=problem.design_space.shape,
)
model.load_state_dict(ckpt["generator"])
model.eval() # Set to evaluation mode
model.to(device)
# Sample noise as generator input
- z = th.randn((args.n_samples, run.config["latent_dim"], 1, 1), device=device, dtype=th.float)
+ z = th.randn((args.n_samples, run_config["latent_dim"], 1, 1), device=device, dtype=th.float)
# Generate a batch of designs
gen_designs = model(z, conditions_tensor)
diff --git a/engiopt/cgan_cnn_3d/cgan_cnn_3d.py b/engiopt/cgan_cnn_3d/cgan_cnn_3d.py
index 9f94df9d..d47ff232 100644
--- a/engiopt/cgan_cnn_3d/cgan_cnn_3d.py
+++ b/engiopt/cgan_cnn_3d/cgan_cnn_3d.py
@@ -22,6 +22,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.metrics import dpp_diversity
from engiopt.metrics import mmd
from engiopt.reproducibility import enable_strict_determinism
@@ -47,6 +49,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -681,15 +691,22 @@ def sample_3d_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]:
th.save(ckpt_gen, "generator_3d.pth")
th.save(ckpt_disc, "discriminator_3d.pth")
-
- if args.track:
- artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator_3d", type="model")
- artifact_gen.add_file("generator_3d.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator_3d", type="model")
- artifact_disc.add_file("discriminator_3d.pth")
-
- wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"generator_3d.pth": "generator_3d.pth", "discriminator_3d.pth": "discriminator_3d.pth"},
+ run_config=vars(args),
+ primary_files=["generator_3d.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_generator_3d": "generator_3d.pth",
+ f"{args.problem_id}_{args.algo}_discriminator_3d": "discriminator_3d.pth",
+ },
+ )
print("3D models saved successfully!")
diff --git a/engiopt/cgan_cnn_3d/evaluate_cgan_cnn_3d.py b/engiopt/cgan_cnn_3d/evaluate_cgan_cnn_3d.py
index 13602960..b3248ce5 100644
--- a/engiopt/cgan_cnn_3d/evaluate_cgan_cnn_3d.py
+++ b/engiopt/cgan_cnn_3d/evaluate_cgan_cnn_3d.py
@@ -13,8 +13,9 @@
from engiopt import metrics
from engiopt.cgan_cnn_3d.cgan_cnn_3d import Generator3D
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
-import wandb
@dataclasses.dataclass
@@ -29,6 +30,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -67,38 +76,35 @@ class Args:
### Set Up Generator ###
- # Restores the pytorch model from wandb
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_cgan_cnn_3d_generator_3d:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_cgan_cnn_3d_generator_3d:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
- artifact_dir = artifact.download()
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="cgan_cnn_3d",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["generator_3d.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"generator_3d.pth": f"{args.problem_id}_cgan_cnn_3d_generator_3d"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- ckpt_path = os.path.join(artifact_dir, "generator_3d.pth")
+ ckpt_path = resolved.files["generator_3d.pth"]
ckpt = th.load(ckpt_path, map_location=th.device(device))
# Safer debug output
for key in ckpt:
print("Checkpoint key:", key)
model = Generator3D(
- latent_dim=run.config["latent_dim"], n_conds=len(problem.conditions), design_shape=problem.design_space.shape
+ latent_dim=run_config["latent_dim"], n_conds=len(problem.conditions), design_shape=problem.design_space.shape
)
model.load_state_dict(ckpt["generator"])
model.eval() # Set to evaluation mode
model.to(device)
# Sample noise as generator input
- z = th.randn((args.n_samples, run.config["latent_dim"], 1, 1, 1), device=device, dtype=th.float)
+ z = th.randn((args.n_samples, run_config["latent_dim"], 1, 1, 1), device=device, dtype=th.float)
# Generate a batch of designs
gen_designs = model(z, conditions_tensor)
diff --git a/engiopt/cgan_vae/cgan_vae.py b/engiopt/cgan_vae/cgan_vae.py
index 70fabda2..946cd2b7 100644
--- a/engiopt/cgan_vae/cgan_vae.py
+++ b/engiopt/cgan_vae/cgan_vae.py
@@ -22,6 +22,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.metrics import dpp_diversity
from engiopt.metrics import mmd
from engiopt.reproducibility import enable_strict_determinism
@@ -47,6 +49,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -808,10 +818,19 @@ def sample_3d_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]:
"multiview_3d_vaegan.pth",
)
- if args.track:
- artifact = wandb.Artifact(f"{args.problem_id}_{args.algo}_models", type="model")
- artifact.add_file("multiview_3d_vaegan.pth")
- wandb.log_artifact(artifact, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"multiview_3d_vaegan.pth": "multiview_3d_vaegan.pth"},
+ run_config=vars(args),
+ primary_files=["multiview_3d_vaegan.pth"],
+ wandb_artifacts={f"{args.problem_id}_{args.algo}_models": "multiview_3d_vaegan.pth"},
+ )
print("3D vae models saved successfully!")
diff --git a/engiopt/cgan_vae/evaluate_cgan_vae.py b/engiopt/cgan_vae/evaluate_cgan_vae.py
index fff54a88..9367de46 100644
--- a/engiopt/cgan_vae/evaluate_cgan_vae.py
+++ b/engiopt/cgan_vae/evaluate_cgan_vae.py
@@ -13,8 +13,9 @@
from engiopt import metrics
from engiopt.cgan_vae.cgan_vae import Generator3D
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
-import wandb
@dataclasses.dataclass
@@ -29,6 +30,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -67,38 +76,35 @@ class Args:
### Set Up Generator ###
- # Restores the pytorch model from wandb
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_cgan_vae_models:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_cgan_vae_models:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
- artifact_dir = artifact.download()
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="cgan_vae",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["multiview_3d_vaegan.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"multiview_3d_vaegan.pth": f"{args.problem_id}_cgan_vae_models"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- ckpt_path = os.path.join(artifact_dir, "multiview_3d_vaegan.pth")
+ ckpt_path = resolved.files["multiview_3d_vaegan.pth"]
ckpt = th.load(ckpt_path, map_location=th.device(device))
# Safer debug output
for key in ckpt:
print("Checkpoint key:", key)
model = Generator3D(
- latent_dim=run.config["latent_dim"], n_conds=len(problem.conditions_keys), design_shape=problem.design_space.shape
+ latent_dim=run_config["latent_dim"], n_conds=len(problem.conditions_keys), design_shape=problem.design_space.shape
)
model.load_state_dict(ckpt["generator"])
model.eval() # Set to evaluation mode
model.to(device)
# Sample noise as generator input
- z = th.randn((args.n_samples, run.config["latent_dim"], 1, 1, 1), device=device, dtype=th.float)
+ z = th.randn((args.n_samples, run_config["latent_dim"], 1, 1, 1), device=device, dtype=th.float)
# Generate a batch of designs
gen_designs = model(z, conditions_tensor)
diff --git a/engiopt/checkpoint_store.py b/engiopt/checkpoint_store.py
new file mode 100644
index 00000000..d0aee320
--- /dev/null
+++ b/engiopt/checkpoint_store.py
@@ -0,0 +1,524 @@
+"""Shared checkpoint save/load helpers for EngiOpt model artifacts."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+import json
+import os
+from pathlib import Path
+import shutil
+import tempfile
+from typing import Any, Literal
+
+from huggingface_hub import HfApi
+from huggingface_hub import snapshot_download
+
+import wandb
+
+CheckpointBackend = Literal["hf", "wandb", "both", "none"]
+ModelSource = Literal["auto", "hf", "wandb", "local"]
+
+
+@dataclass(frozen=True)
+class ResolvedCheckpoint:
+ """Resolved checkpoint package with local files and serialized config."""
+
+ source: Literal["hf", "wandb", "local"]
+ root_dir: str
+ files: dict[str, str]
+ run_config: dict[str, Any]
+ metadata: dict[str, Any]
+
+
+def build_hf_repo_id(hf_entity: str, hf_repo_prefix: str, algo: str) -> str:
+ """Return the canonical HF repo id for an EngiOpt model family."""
+ repo_suffix = algo.replace("_", "-")
+ return f"{hf_entity}/{hf_repo_prefix}-{repo_suffix}"
+
+
+def build_hf_package_path(problem_id: str, seed: int, extra_parts: list[str] | None = None) -> str:
+ """Return the canonical package path inside an HF repo."""
+ parts = [_sanitize_path_component(problem_id)]
+ if extra_parts:
+ parts.extend(_sanitize_path_component(part) for part in extra_parts)
+ parts.append(f"seed_{seed}")
+ return "/".join(parts)
+
+
+def build_hf_run_package_path(package_path: str, wandb_run_id: str | None) -> str | None:
+ """Return an immutable run-specific package path when a W&B run id is available."""
+ if not wandb_run_id:
+ return None
+ return f"{package_path}/run_{_sanitize_path_component(wandb_run_id)}"
+
+
+def save_checkpoint_package( # noqa: PLR0913
+ *,
+ checkpoint_backend: CheckpointBackend,
+ hf_entity: str,
+ hf_repo_prefix: str,
+ hf_private: bool,
+ problem_id: str,
+ algo: str,
+ seed: int,
+ checkpoint_files: dict[str, str],
+ run_config: dict[str, Any],
+ metadata: dict[str, Any] | None = None,
+ primary_files: list[str] | None = None,
+ extra_path_parts: list[str] | None = None,
+ wandb_artifacts: dict[str, str] | None = None,
+) -> dict[str, Any]:
+ """Save a checkpoint package to the configured remote backends."""
+ info: dict[str, Any] = {
+ "checkpoint_backend": checkpoint_backend,
+ "hf_repo_id": None,
+ "hf_package_path": None,
+ "hf_run_package_path": None,
+ "hf_revision": None,
+ "hf_run_revision": None,
+ }
+ metadata_payload = _build_metadata(
+ problem_id=problem_id,
+ algo=algo,
+ seed=seed,
+ checkpoint_backend=checkpoint_backend,
+ checkpoint_files=checkpoint_files,
+ primary_files=primary_files,
+ metadata=metadata,
+ )
+ metadata_payload.update(_build_wandb_run_metadata())
+
+ if checkpoint_backend in {"hf", "both"}:
+ repo_id = build_hf_repo_id(hf_entity, hf_repo_prefix, algo)
+ package_path = build_hf_package_path(problem_id, seed, extra_path_parts)
+ info["hf_repo_id"] = repo_id
+ info["hf_package_path"] = package_path
+ metadata_payload["hf_repo_id"] = repo_id
+ metadata_payload["hf_package_path"] = package_path
+
+ run_package_path = build_hf_run_package_path(package_path, metadata_payload.get("wandb_run_id"))
+ info["hf_run_package_path"] = run_package_path
+ if run_package_path is not None:
+ metadata_payload["hf_run_package_path"] = run_package_path
+ run_revision = _upload_package_to_hf(
+ repo_id=repo_id,
+ hf_private=hf_private,
+ package_path=run_package_path,
+ checkpoint_files=checkpoint_files,
+ run_config=run_config,
+ metadata=metadata_payload,
+ algo=algo,
+ )
+ info["hf_run_revision"] = run_revision
+ metadata_payload["hf_run_revision"] = run_revision
+
+ revision = _upload_package_to_hf(
+ repo_id=repo_id,
+ hf_private=hf_private,
+ package_path=package_path,
+ checkpoint_files=checkpoint_files,
+ run_config=run_config,
+ metadata=metadata_payload,
+ algo=algo,
+ )
+ info["hf_revision"] = revision
+ metadata_payload["hf_revision"] = revision
+
+ if checkpoint_backend in {"wandb", "both"} and wandb_artifacts and wandb.run is not None:
+ for artifact_name, file_path in wandb_artifacts.items():
+ artifact = wandb.Artifact(artifact_name, type="model", metadata=metadata_payload)
+ artifact.add_file(file_path)
+ wandb.log_artifact(artifact, aliases=[f"seed_{seed}"])
+
+ if wandb.run is not None:
+ _log_checkpoint_summary_to_wandb(metadata_payload, info)
+
+ return info
+
+
+def resolve_named_checkpoint( # noqa: PLR0913
+ *,
+ model_source: ModelSource,
+ problem_id: str,
+ algo: str,
+ seed: int,
+ hf_entity: str,
+ hf_repo_prefix: str,
+ required_files: list[str],
+ wandb_project: str,
+ wandb_entity: str | None,
+ wandb_artifact_names: dict[str, str],
+ wandb_config_artifact_name: str | None = None,
+ local_model_dir: str | None = None,
+ extra_path_parts: list[str] | None = None,
+) -> ResolvedCheckpoint:
+ """Resolve a checkpoint package by the standard EngiOpt problem/algo/seed naming."""
+ errors: list[str] = []
+ if model_source in {"auto", "hf"}:
+ try:
+ return _resolve_hf_package(
+ repo_id=build_hf_repo_id(hf_entity, hf_repo_prefix, algo),
+ package_path=build_hf_package_path(problem_id, seed, extra_path_parts),
+ required_files=required_files,
+ )
+ except Exception as exc:
+ if model_source == "hf":
+ raise
+ errors.append(f"hf: {exc}")
+
+ if model_source in {"auto", "wandb"}:
+ try:
+ return _resolve_wandb_package(
+ _required_files=required_files,
+ artifact_names=wandb_artifact_names,
+ wandb_project=wandb_project,
+ wandb_entity=wandb_entity,
+ seed=seed,
+ config_artifact_name=wandb_config_artifact_name,
+ )
+ except Exception as exc:
+ if model_source == "wandb":
+ raise
+ errors.append(f"wandb: {exc}")
+
+ if local_model_dir is not None and model_source in {"auto", "local"}:
+ return _resolve_local_package(local_model_dir, required_files)
+
+ attempted = ", ".join(errors) if errors else "no backends attempted"
+ raise FileNotFoundError(f"Unable to resolve checkpoint for {algo}/{problem_id}/seed_{seed}: {attempted}")
+
+
+def resolve_checkpoint_reference(
+ *,
+ model_source: ModelSource,
+ model_ref: str,
+ required_files: list[str] | None = None,
+ active_wandb_run: wandb.sdk.wandb_run.Run | None = None,
+) -> ResolvedCheckpoint:
+ """Resolve a checkpoint package from an explicit HF/W&B/local reference."""
+ inferred_source = model_source
+ normalized_ref = model_ref
+ if model_source == "auto":
+ if model_ref.startswith("hf://"):
+ inferred_source = "hf"
+ elif model_ref.startswith("wandb://"):
+ inferred_source = "wandb"
+ elif model_ref.startswith("file://") or os.path.isdir(model_ref):
+ inferred_source = "local"
+ else:
+ inferred_source = "wandb"
+
+ if inferred_source == "hf":
+ repo_id, package_path = _parse_hf_reference(model_ref)
+ return _resolve_hf_package(repo_id=repo_id, package_path=package_path, required_files=required_files or [])
+ if inferred_source == "wandb":
+ artifact_path = model_ref.removeprefix("wandb://")
+ return _resolve_wandb_reference(
+ artifact_path=artifact_path,
+ required_files=required_files or [],
+ active_wandb_run=active_wandb_run,
+ )
+ if inferred_source == "local":
+ normalized_ref = model_ref.removeprefix("file://")
+ return _resolve_local_package(normalized_ref, required_files or [])
+
+ raise ValueError(f"Unsupported model source: {model_source}")
+
+
+def _upload_package_to_hf( # noqa: PLR0913
+ *,
+ repo_id: str,
+ hf_private: bool,
+ package_path: str,
+ checkpoint_files: dict[str, str],
+ run_config: dict[str, Any],
+ metadata: dict[str, Any],
+ algo: str,
+) -> str | None:
+ api = HfApi()
+ api.create_repo(repo_id=repo_id, repo_type="model", private=hf_private, exist_ok=True)
+ _ensure_hf_repo_readme(api, repo_id, algo)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ stage_dir = Path(tmpdir)
+ for package_name, file_path in checkpoint_files.items():
+ shutil.copy2(file_path, stage_dir / package_name)
+ _write_json(stage_dir / "run_config.json", run_config)
+ _write_json(stage_dir / "metadata.json", metadata)
+
+ commit_info = api.upload_folder(
+ repo_id=repo_id,
+ repo_type="model",
+ folder_path=tmpdir,
+ path_in_repo=package_path,
+ commit_message=f"Upload checkpoint for {algo} {package_path}",
+ )
+ return getattr(commit_info, "oid", None)
+
+
+def _resolve_hf_package(*, repo_id: str, package_path: str, required_files: list[str]) -> ResolvedCheckpoint:
+ repo_snapshot = snapshot_download(
+ repo_id=repo_id,
+ repo_type="model",
+ allow_patterns=[f"{package_path}/*"],
+ )
+ root_dir = os.path.join(repo_snapshot, package_path)
+ if not os.path.isdir(root_dir):
+ raise FileNotFoundError(f"HF package path not found: {repo_id}/{package_path}")
+ return _load_package_from_directory(root_dir=root_dir, required_files=required_files, source="hf")
+
+
+def _resolve_local_package(local_model_dir: str, required_files: list[str]) -> ResolvedCheckpoint:
+ if not os.path.isdir(local_model_dir):
+ raise FileNotFoundError(f"Local checkpoint directory not found: {local_model_dir}")
+ return _load_package_from_directory(root_dir=local_model_dir, required_files=required_files, source="local")
+
+
+def _resolve_wandb_package(
+ *,
+ _required_files: list[str],
+ artifact_names: dict[str, str],
+ wandb_project: str,
+ wandb_entity: str | None,
+ seed: int,
+ config_artifact_name: str | None,
+) -> ResolvedCheckpoint:
+ api = wandb.Api()
+ files: dict[str, str] = {}
+ for file_name, artifact_name in artifact_names.items():
+ artifact_path = _build_wandb_artifact_path(
+ artifact_name=artifact_name,
+ wandb_project=wandb_project,
+ wandb_entity=wandb_entity,
+ seed=seed,
+ )
+ artifact = api.artifact(artifact_path, type="model")
+ artifact_dir = artifact.download()
+ files[file_name] = os.path.join(artifact_dir, file_name)
+
+ config_artifact = config_artifact_name or next(iter(artifact_names.values()))
+ config_artifact_path = _build_wandb_artifact_path(
+ artifact_name=config_artifact,
+ wandb_project=wandb_project,
+ wandb_entity=wandb_entity,
+ seed=seed,
+ )
+ artifact = api.artifact(config_artifact_path, type="model")
+ run = artifact.logged_by()
+ if run is None or not hasattr(run, "config"):
+ raise ValueError("Failed to retrieve W&B run config from artifact")
+
+ return ResolvedCheckpoint(
+ source="wandb",
+ root_dir=os.path.dirname(next(iter(files.values()))),
+ files=files,
+ run_config=dict(run.config),
+ metadata={
+ "source": "wandb",
+ "artifact_paths": {
+ file_name: _build_wandb_artifact_path(
+ artifact_name=artifact_name,
+ wandb_project=wandb_project,
+ wandb_entity=wandb_entity,
+ seed=seed,
+ )
+ for file_name, artifact_name in artifact_names.items()
+ },
+ },
+ )
+
+
+def _resolve_wandb_reference(
+ *,
+ artifact_path: str,
+ required_files: list[str],
+ active_wandb_run: wandb.sdk.wandb_run.Run | None,
+) -> ResolvedCheckpoint:
+ artifact = (
+ active_wandb_run.use_artifact(artifact_path, type="model")
+ if active_wandb_run is not None
+ else wandb.Api().artifact(artifact_path, type="model")
+ )
+ artifact_dir = artifact.download()
+ files = _discover_reference_files(artifact_dir, required_files)
+ run = artifact.logged_by()
+ if run is None or not hasattr(run, "config"):
+ raise ValueError("Failed to retrieve W&B run config from artifact reference")
+ return ResolvedCheckpoint(
+ source="wandb",
+ root_dir=artifact_dir,
+ files=files,
+ run_config=dict(run.config),
+ metadata={"source": "wandb", "artifact_path": artifact_path},
+ )
+
+
+def _load_package_from_directory(
+ *, root_dir: str, required_files: list[str], source: Literal["hf", "local"]
+) -> ResolvedCheckpoint:
+ run_config_path = os.path.join(root_dir, "run_config.json")
+ metadata_path = os.path.join(root_dir, "metadata.json")
+ if not os.path.exists(run_config_path):
+ raise FileNotFoundError(f"Missing run_config.json in {root_dir}")
+ run_config = _read_json(run_config_path)
+ metadata = _read_json(metadata_path) if os.path.exists(metadata_path) else {}
+ package_files = required_files or _discover_package_files(root_dir, metadata)
+ files = {file_name: os.path.join(root_dir, file_name) for file_name in package_files}
+ for file_name, file_path in files.items():
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"Missing checkpoint file {file_name} in {root_dir}")
+ return ResolvedCheckpoint(source=source, root_dir=root_dir, files=files, run_config=run_config, metadata=metadata)
+
+
+def _discover_reference_files(root_dir: str, required_files: list[str]) -> dict[str, str]:
+ if required_files:
+ files = {file_name: os.path.join(root_dir, file_name) for file_name in required_files}
+ for file_name, file_path in files.items():
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"Missing checkpoint file {file_name} in {root_dir}")
+ return files
+
+ discovered = [entry for entry in os.listdir(root_dir) if os.path.isfile(os.path.join(root_dir, entry))]
+ if len(discovered) != 1:
+ raise ValueError(f"Expected exactly one file in {root_dir}, found {discovered}")
+ file_name = discovered[0]
+ return {file_name: os.path.join(root_dir, file_name)}
+
+
+def _discover_package_files(root_dir: str, metadata: dict[str, Any]) -> list[str]:
+ primary_files = metadata.get("primary_files")
+ if isinstance(primary_files, list) and primary_files:
+ return [str(file_name) for file_name in primary_files]
+
+ excluded_names = {"metadata.json", "run_config.json"}
+ discovered = [
+ entry
+ for entry in os.listdir(root_dir)
+ if os.path.isfile(os.path.join(root_dir, entry)) and entry not in excluded_names
+ ]
+ if not discovered:
+ raise FileNotFoundError(f"No checkpoint files found in {root_dir}")
+ return sorted(discovered)
+
+
+def _parse_hf_reference(model_ref: str) -> tuple[str, str]:
+ normalized = model_ref.removeprefix("hf://")
+ first_sep = normalized.find("/")
+ second_sep = normalized.find("/", first_sep + 1)
+ if first_sep == -1 or second_sep == -1:
+ raise ValueError(f"HF model references must look like hf:////, got {model_ref}")
+ repo_id = normalized[:second_sep]
+ package_path = normalized[second_sep + 1 :]
+ return repo_id, package_path
+
+
+def _build_wandb_artifact_path(
+ *,
+ artifact_name: str,
+ wandb_project: str,
+ wandb_entity: str | None,
+ seed: int,
+) -> str:
+ project_path = f"{wandb_entity}/{wandb_project}" if wandb_entity is not None else wandb_project
+ return f"{project_path}/{artifact_name}:seed_{seed}"
+
+
+def _build_metadata( # noqa: PLR0913
+ *,
+ problem_id: str,
+ algo: str,
+ seed: int,
+ checkpoint_backend: CheckpointBackend,
+ checkpoint_files: dict[str, str],
+ primary_files: list[str] | None,
+ metadata: dict[str, Any] | None,
+) -> dict[str, Any]:
+ payload = dict(metadata or {})
+ payload.update(
+ {
+ "problem_id": problem_id,
+ "algo": algo,
+ "seed": seed,
+ "checkpoint_backend": checkpoint_backend,
+ "checkpoint_files": sorted(checkpoint_files),
+ "primary_files": primary_files or sorted(checkpoint_files),
+ }
+ )
+ return payload
+
+
+def _build_wandb_run_metadata() -> dict[str, Any]:
+ """Return W&B run identity fields for checkpoint metadata when tracking is active."""
+ if wandb.run is None:
+ return {}
+
+ run = wandb.run
+ entity = getattr(run, "entity", None)
+ project = getattr(run, "project", None)
+ run_id = getattr(run, "id", None)
+ run_url = None
+ if entity and project and run_id:
+ run_url = f"https://wandb.ai/{entity}/{project}/runs/{run_id}"
+
+ return {
+ "wandb_entity": entity,
+ "wandb_project": project,
+ "wandb_run_id": run_id,
+ "wandb_run_url": run_url,
+ }
+
+
+def _ensure_hf_repo_readme(api: HfApi, repo_id: str, algo: str) -> None:
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model")
+ if "README.md" in files:
+ return
+ with tempfile.NamedTemporaryFile("w", delete=False, suffix=".md") as tmpfile:
+ tmpfile.write(_repo_readme_text(algo))
+ tmp_path = tmpfile.name
+ try:
+ api.upload_file(
+ repo_id=repo_id,
+ repo_type="model",
+ path_or_fileobj=tmp_path,
+ path_in_repo="README.md",
+ commit_message=f"Add README for {algo} checkpoint repo",
+ )
+ finally:
+ os.unlink(tmp_path)
+
+
+def _repo_readme_text(algo: str) -> str:
+ return (
+ f"# EngiOpt {algo}\n\n"
+ "This repository stores EngiOpt checkpoint packages for one model family.\n\n"
+ "Each checkpoint package contains model weight files together with `run_config.json` "
+ "and `metadata.json` so evaluation can run without depending on W&B run config state.\n"
+ )
+
+
+def _log_checkpoint_summary_to_wandb(metadata: dict[str, Any], info: dict[str, Any]) -> None:
+ if wandb.run is None:
+ return
+ wandb.summary["checkpoint_backend"] = info["checkpoint_backend"]
+ if info["hf_repo_id"] is not None:
+ wandb.summary["hf_repo_id"] = info["hf_repo_id"]
+ wandb.summary["hf_package_path"] = info["hf_package_path"]
+ wandb.summary["hf_revision"] = info["hf_revision"]
+ if info["hf_run_package_path"] is not None:
+ wandb.summary["hf_run_package_path"] = info["hf_run_package_path"]
+ wandb.summary["hf_run_revision"] = info["hf_run_revision"]
+ wandb.summary["checkpoint_primary_files"] = metadata["primary_files"]
+
+
+def _write_json(path: Path, payload: dict[str, Any]) -> None:
+ with path.open("w", encoding="utf-8") as handle:
+ json.dump(payload, handle, indent=2, sort_keys=True, default=str)
+
+
+def _read_json(path: str) -> dict[str, Any]:
+ with open(path, encoding="utf-8") as handle:
+ return json.load(handle)
+
+
+def _sanitize_path_component(value: str) -> str:
+ return value.replace(os.sep, "_").replace(" ", "_")
diff --git a/engiopt/diffusion_1d/diffusion_1d.py b/engiopt/diffusion_1d/diffusion_1d.py
index 7c6d1a40..971d85cc 100644
--- a/engiopt/diffusion_1d/diffusion_1d.py
+++ b/engiopt/diffusion_1d/diffusion_1d.py
@@ -22,6 +22,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -100,6 +102,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -274,10 +284,18 @@ class Args:
}
th.save(ckpt, "model.pth")
- if args.track:
- artifact = wandb.Artifact(f"{args.problem_id}_{args.algo}_model", type="model")
- artifact.add_file("model.pth")
-
- wandb.log_artifact(artifact, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"model.pth": "model.pth"},
+ run_config=vars(args),
+ primary_files=["model.pth"],
+ wandb_artifacts={f"{args.problem_id}_{args.algo}_model": "model.pth"},
+ )
wandb.finish()
diff --git a/engiopt/diffusion_1d/evaluate_diffusion_1d.py b/engiopt/diffusion_1d/evaluate_diffusion_1d.py
index bfc961ef..eb503866 100644
--- a/engiopt/diffusion_1d/evaluate_diffusion_1d.py
+++ b/engiopt/diffusion_1d/evaluate_diffusion_1d.py
@@ -15,9 +15,10 @@
import tyro
from engiopt import metrics
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.diffusion_1d.diffusion_1d import prepare_data
-import wandb
@dataclasses.dataclass
@@ -32,6 +33,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 10
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -83,31 +92,29 @@ class Args:
)
### Load Diffusion Model ###
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_diffusion_1d_model:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_diffusion_1d_model:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
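+ # Resolve the checkpoint package; "auto" tries the HF package first and falls
+ # back to the legacy W&B artifact. resolved.files maps logical file names to
+ # local paths.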
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="diffusion_1d",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["model.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"model.pth": f"{args.problem_id}_diffusion_1d_model"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- artifact_dir = artifact.download()
- ckpt_path = os.path.join(artifact_dir, "model.pth")
+ ckpt_path = resolved.files["model.pth"]
ckpt = th.load(ckpt_path, map_location=device)
_, design_normalizer = prepare_data(problem, padding_size, device)
model = Unet1D(
- dim=run.config["unet_dim"], # Used for the sinusoidal positional embeddings
- channels=run.config["n_channels"], # Number of channels in the input
+ dim=run_config["unet_dim"], # Used for the sinusoidal positional embeddings
+ channels=run_config["n_channels"], # Number of channels in the input
).to(device)
diffusion = GaussianDiffusion1D(
diff --git a/engiopt/diffusion_2d_cond/diffusion_2d_cond.py b/engiopt/diffusion_2d_cond/diffusion_2d_cond.py
index 5f071d7e..a3cbb41d 100644
--- a/engiopt/diffusion_2d_cond/diffusion_2d_cond.py
+++ b/engiopt/diffusion_2d_cond/diffusion_2d_cond.py
@@ -18,6 +18,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -42,6 +44,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -442,10 +452,18 @@ def sample_designs(model: UNet2DConditionModel, n_designs: int = 25) -> tuple[th
}
th.save(ckpt_model, "model.pth")
- if args.track:
- artifact_model = wandb.Artifact(f"{args.problem_id}_{args.algo}_model", type="model")
- artifact_model.add_file("model.pth")
-
- wandb.log_artifact(artifact_model, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"model.pth": "model.pth"},
+ run_config=vars(args),
+ primary_files=["model.pth"],
+ wandb_artifacts={f"{args.problem_id}_{args.algo}_model": "model.pth"},
+ )
wandb.finish()
diff --git a/engiopt/diffusion_2d_cond/evaluate_diffusion_2d_cond.py b/engiopt/diffusion_2d_cond/evaluate_diffusion_2d_cond.py
index 8403908b..9bd585f8 100644
--- a/engiopt/diffusion_2d_cond/evaluate_diffusion_2d_cond.py
+++ b/engiopt/diffusion_2d_cond/evaluate_diffusion_2d_cond.py
@@ -13,10 +13,11 @@
import tyro
from engiopt import metrics
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.diffusion_2d_cond.diffusion_2d_cond import beta_schedule
from engiopt.diffusion_2d_cond.diffusion_2d_cond import DiffusionSampler
-import wandb
@dataclasses.dataclass
@@ -31,6 +32,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -70,24 +79,22 @@ class Args:
conditions_tensor = conditions_tensor.unsqueeze(1)
### Set Up Diffusion Model ###
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_diffusion_2d_cond_model:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_diffusion_2d_cond_model:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="diffusion_2d_cond",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["model.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"model.pth": f"{args.problem_id}_diffusion_2d_cond_model"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- artifact_dir = artifact.download()
- ckpt_path = os.path.join(artifact_dir, "model.pth")
+ ckpt_path = resolved.files["model.pth"]
ckpt = th.load(ckpt_path, map_location=device)
# Build UNet
@@ -99,7 +106,7 @@ def __init__(self):
block_out_channels=(32, 64, 128, 256),
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
- layers_per_block=run.config["layers_per_block"],
+ layers_per_block=run_config["layers_per_block"],
transformer_layers_per_block=1,
encoder_hid_dim=len(problem.conditions_keys),
only_cross_attention=True,
@@ -107,18 +114,18 @@ def __init__(self):
# Noise schedule
options = {
- "cosine": run.config["noise_schedule"] == "cosine",
- "exp_biasing": run.config["noise_schedule"] == "exp",
+ "cosine": run_config["noise_schedule"] == "cosine",
+ "exp_biasing": run_config["noise_schedule"] == "exp",
"exp_bias_factor": 1,
}
betas = beta_schedule(
- t=run.config["num_timesteps"],
+ t=run_config["num_timesteps"],
start=1e-4,
end=0.02,
scale=1.0,
options=options,
)
- ddm_sampler = DiffusionSampler(run.config["num_timesteps"], betas)
+ ddm_sampler = DiffusionSampler(run_config["num_timesteps"], betas)
model.load_state_dict(ckpt["model"])
model.eval()
@@ -126,8 +133,8 @@ def __init__(self):
# Generate and reshape
design_shape: tuple = problem.design_space.shape
gen_designs = th.randn((args.n_samples, 1, *design_shape), device=device)
- assert run.config["num_timesteps"] is not None
- for i in reversed(range(run.config["num_timesteps"])):
+ assert run_config["num_timesteps"] is not None
+ for i in reversed(range(run_config["num_timesteps"])):
t = th.full((args.n_samples,), i, device=device, dtype=th.long)
gen_designs = ddm_sampler.sample_timestep(model, gen_designs, t, conditions_tensor)
diff --git a/engiopt/gan_1d/evaluate_gan_1d.py b/engiopt/gan_1d/evaluate_gan_1d.py
index 1df4d887..9adb0392 100644
--- a/engiopt/gan_1d/evaluate_gan_1d.py
+++ b/engiopt/gan_1d/evaluate_gan_1d.py
@@ -13,10 +13,11 @@
import tyro
from engiopt import metrics
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.gan_1d.gan_1d import Generator
from engiopt.gan_1d.gan_1d import prepare_data
-import wandb
@dataclasses.dataclass
@@ -31,6 +32,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -74,36 +83,34 @@ class Args:
)
### Set Up Generator ###
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_gan_1d_generator:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_gan_1d_generator:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="gan_1d",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["generator.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"generator.pth": f"{args.problem_id}_gan_1d_generator"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- artifact_dir = artifact.download()
- ckpt_path = os.path.join(artifact_dir, "generator.pth")
+ ckpt_path = resolved.files["generator.pth"]
ckpt = th.load(ckpt_path, map_location=device)
_, design_normalizer = prepare_data(problem, device)
model = Generator(
- latent_dim=run.config["latent_dim"], design_shape=design_shape, design_normalizer=design_normalizer
+ latent_dim=run_config["latent_dim"], design_shape=design_shape, design_normalizer=design_normalizer
).to(device)
model.load_state_dict(ckpt["generator"])
model.eval()
# Sample noise and generate designs
- z = th.randn((args.n_samples, run.config["latent_dim"]), device=device)
+ z = th.randn((args.n_samples, run_config["latent_dim"]), device=device)
gen_designs = model(z)
gen_designs_np = gen_designs.detach().cpu().numpy()
diff --git a/engiopt/gan_1d/gan_1d.py b/engiopt/gan_1d/gan_1d.py
index c36067b4..8c8da8bc 100644
--- a/engiopt/gan_1d/gan_1d.py
+++ b/engiopt/gan_1d/gan_1d.py
@@ -21,6 +21,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -46,6 +48,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -339,13 +349,21 @@ def prepare_data(problem: Problem, device: th.device) -> tuple[th.utils.data.Ten
th.save(ckpt_gen, "generator.pth")
th.save(ckpt_disc, "discriminator.pth")
- if args.track:
- artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model")
- artifact_gen.add_file("generator.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model")
- artifact_disc.add_file("discriminator.pth")
-
- wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
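+ # Generator and discriminator ship in a single package; primary_files marks the
+ # generator in the package metadata.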
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"generator.pth": "generator.pth", "discriminator.pth": "discriminator.pth"},
+ run_config=vars(args),
+ primary_files=["generator.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_generator": "generator.pth",
+ f"{args.problem_id}_{args.algo}_discriminator": "discriminator.pth",
+ },
+ )
wandb.finish()
diff --git a/engiopt/gan_2d/evaluate_gan_2d.py b/engiopt/gan_2d/evaluate_gan_2d.py
index be3e9429..78fb8331 100644
--- a/engiopt/gan_2d/evaluate_gan_2d.py
+++ b/engiopt/gan_2d/evaluate_gan_2d.py
@@ -12,9 +12,10 @@
import tyro
from engiopt import metrics
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.gan_2d.gan_2d import Generator
-import wandb
@dataclasses.dataclass
@@ -29,6 +30,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -66,35 +75,33 @@ class Args:
)
### Set Up Generator ###
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_gan_2d_generator:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_gan_2d_generator:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="gan_2d",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["generator.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"generator.pth": f"{args.problem_id}_gan_2d_generator"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- artifact_dir = artifact.download()
- ckpt_path = os.path.join(artifact_dir, "generator.pth")
+ ckpt_path = resolved.files["generator.pth"]
ckpt = th.load(ckpt_path, map_location=device)
model = Generator(
- latent_dim=run.config["latent_dim"],
+ latent_dim=run_config["latent_dim"],
design_shape=problem.design_space.shape,
).to(device)
model.load_state_dict(ckpt["generator"])
model.eval()
# Sample noise and generate designs
- z = th.randn((args.n_samples, run.config["latent_dim"]), device=device)
+ z = th.randn((args.n_samples, run_config["latent_dim"]), device=device)
gen_designs = model(z)
gen_designs_np = gen_designs.detach().cpu().numpy()
gen_designs_np = np.clip(gen_designs_np, 1e-3, 1.0)
diff --git a/engiopt/gan_2d/gan_2d.py b/engiopt/gan_2d/gan_2d.py
index 8cfd31b0..4a78cc13 100644
--- a/engiopt/gan_2d/gan_2d.py
+++ b/engiopt/gan_2d/gan_2d.py
@@ -18,6 +18,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -39,6 +41,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -277,13 +287,21 @@ def sample_designs(n_designs: int) -> th.Tensor:
th.save(ckpt_gen, "generator.pth")
th.save(ckpt_disc, "discriminator.pth")
- if args.track:
- artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model")
- artifact_gen.add_file("generator.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model")
- artifact_disc.add_file("discriminator.pth")
-
- wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"generator.pth": "generator.pth", "discriminator.pth": "discriminator.pth"},
+ run_config=vars(args),
+ primary_files=["generator.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_generator": "generator.pth",
+ f"{args.problem_id}_{args.algo}_discriminator": "discriminator.pth",
+ },
+ )
wandb.finish()
diff --git a/engiopt/gan_bezier/evaluate_gan_bezier.py b/engiopt/gan_bezier/evaluate_gan_bezier.py
index 46efa540..0af5cbf4 100644
--- a/engiopt/gan_bezier/evaluate_gan_bezier.py
+++ b/engiopt/gan_bezier/evaluate_gan_bezier.py
@@ -13,11 +13,12 @@
import tyro
from engiopt import metrics
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.gan_bezier.gan_bezier import Generator
from engiopt.gan_bezier.gan_bezier import prepare_data
from engiopt.transforms import flatten_dict_factory
-import wandb
if TYPE_CHECKING:
from gymnasium import spaces
@@ -37,6 +38,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -76,32 +85,30 @@ class Args:
)
### Set Up Generator ###
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_gan_bezier_generator:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_gan_bezier_generator:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="gan_bezier",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["bezier_generator.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"bezier_generator.pth": f"{args.problem_id}_gan_bezier_generator"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- artifact_dir = artifact.download()
- ckpt_path = os.path.join(artifact_dir, "bezier_generator.pth")
+ ckpt_path = resolved.files["bezier_generator.pth"]
ckpt = th.load(ckpt_path, map_location=device)
_, design_scalars_normalizer, _ = prepare_data(problem, args.n_samples, device)
model = Generator(
- latent_dim=run.config["latent_dim"],
- noise_dim=run.config["noise_dim"],
- n_control_points=run.config["bezier_control_pts"],
+ latent_dim=run_config["latent_dim"],
+ noise_dim=run_config["noise_dim"],
+ n_control_points=run_config["bezier_control_pts"],
n_data_points=coords_space.shape[1],
design_scalars_normalizer=design_scalars_normalizer,
eps=_EPS,
@@ -112,8 +119,8 @@ def __init__(self):
# Sample noise and generate designs
bounds = (0.0, 1.0) # Bounds for angle of attack
- c = (bounds[1] - bounds[0]) * th.rand(args.n_samples, run.config["latent_dim"], device=device) + bounds[0]
- z = 0.5 * th.randn(args.n_samples, run.config["noise_dim"], device=device)
+ c = (bounds[1] - bounds[0]) * th.rand(args.n_samples, run_config["latent_dim"], device=device) + bounds[0]
+ z = 0.5 * th.randn(args.n_samples, run_config["noise_dim"], device=device)
gen_designs, _, _, _, _, alphas = model(c, z)
gen_designs_np = gen_designs.detach().cpu().numpy()
diff --git a/engiopt/gan_bezier/gan_bezier.py b/engiopt/gan_bezier/gan_bezier.py
index 85a45f35..7e6a7f64 100644
--- a/engiopt/gan_bezier/gan_bezier.py
+++ b/engiopt/gan_bezier/gan_bezier.py
@@ -18,11 +18,13 @@
from torch import nn
import torch.nn.functional as f
import tyro
-import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
+import wandb
if TYPE_CHECKING:
from collections.abc import Callable
@@ -46,6 +48,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -619,14 +629,25 @@ def prepare_data(problem, batch_size, device, seed=None):
th.save(ckpt_gen, "bezier_generator.pth")
th.save(ckpt_disc, "bezier_discriminator.pth")
- if args.track:
- artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model")
- artifact_gen.add_file("bezier_generator.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model")
- artifact_disc.add_file("bezier_discriminator.pth")
-
- wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={
+ "bezier_generator.pth": "bezier_generator.pth",
+ "bezier_discriminator.pth": "bezier_discriminator.pth",
+ },
+ run_config=vars(args),
+ primary_files=["bezier_generator.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_generator": "bezier_generator.pth",
+ f"{args.problem_id}_{args.algo}_discriminator": "bezier_discriminator.pth",
+ },
+ )
if args.track:
wandb.finish()
diff --git a/engiopt/gan_cnn_2d/evaluate_gan_cnn_2d.py b/engiopt/gan_cnn_2d/evaluate_gan_cnn_2d.py
index bb3d1e2f..c524750e 100644
--- a/engiopt/gan_cnn_2d/evaluate_gan_cnn_2d.py
+++ b/engiopt/gan_cnn_2d/evaluate_gan_cnn_2d.py
@@ -12,9 +12,10 @@
import tyro
from engiopt import metrics
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.gan_cnn_2d.gan_cnn_2d import Generator
-import wandb
@dataclasses.dataclass
@@ -29,6 +30,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name (if any)."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -66,33 +75,30 @@ class Args:
### Set Up Generator ###
- # Restores the pytorch model from wandb
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_gan_cnn_2d_generator:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_gan_cnn_2d_generator:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
- artifact_dir = artifact.download()
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="gan_cnn_2d",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["generator.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"generator.pth": f"{args.problem_id}_gan_cnn_2d_generator"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- ckpt_path = os.path.join(artifact_dir, "generator.pth")
+ ckpt_path = resolved.files["generator.pth"]
ckpt = th.load(ckpt_path, map_location=th.device(device))
- model = Generator(latent_dim=run.config["latent_dim"], design_shape=problem.design_space.shape)
+ model = Generator(latent_dim=run_config["latent_dim"], design_shape=problem.design_space.shape)
model.load_state_dict(ckpt["generator"])
model.eval() # Set to evaluation mode
model.to(device)
# Sample noise as generator input
- z = th.randn((args.n_samples, run.config["latent_dim"], 1, 1), device=device, dtype=th.float)
+ z = th.randn((args.n_samples, run_config["latent_dim"], 1, 1), device=device, dtype=th.float)
# Generate a batch of designs
gen_designs = model(z)
diff --git a/engiopt/gan_cnn_2d/gan_cnn_2d.py b/engiopt/gan_cnn_2d/gan_cnn_2d.py
index 8f46f1b3..94140d48 100644
--- a/engiopt/gan_cnn_2d/gan_cnn_2d.py
+++ b/engiopt/gan_cnn_2d/gan_cnn_2d.py
@@ -18,6 +18,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -39,6 +41,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -332,13 +342,21 @@ def sample_designs(n_designs: int) -> th.Tensor:
th.save(ckpt_gen, "generator.pth")
th.save(ckpt_disc, "discriminator.pth")
- if args.track:
- artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model")
- artifact_gen.add_file("generator.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model")
- artifact_disc.add_file("discriminator.pth")
-
- wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"generator.pth": "generator.pth", "discriminator.pth": "discriminator.pth"},
+ run_config=vars(args),
+ primary_files=["generator.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_generator": "generator.pth",
+ f"{args.problem_id}_{args.algo}_discriminator": "discriminator.pth",
+ },
+ )
wandb.finish()
diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py
index 224340ca..ab1be33c 100644
--- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py
+++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py
@@ -10,9 +10,10 @@
import pandas as pd
import torch as th
import tyro
-import wandb
from engiopt import metrics
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.pixel_cnn_pp_2d.pixel_cnn_pp_2d import PixelCNNpp
from engiopt.pixel_cnn_pp_2d.pixel_cnn_pp_2d import sample_from_discretized_mix_logistic
@@ -32,6 +33,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -73,32 +82,30 @@ class Args:
design_shape = (problem.design_space.shape[0], problem.design_space.shape[1])
# Set up PixelCNN++ model
- if args.wandb_entity is not None:
- artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_pixel_cnn_pp_2d_model:seed_{seed}"
- else:
- artifact_path = f"{args.wandb_project}/{args.problem_id}_pixel_cnn_pp_2d_model:seed_{seed}"
-
- api = wandb.Api()
- artifact = api.artifact(artifact_path, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact.logged_by()
- if run is None or not hasattr(run, "config"):
- raise RunRetrievalError
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="pixel_cnn_pp_2d",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["model.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"model.pth": f"{args.problem_id}_pixel_cnn_pp_2d_model"},
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
- artifact_dir = artifact.download()
- ckpt_path = os.path.join(artifact_dir, "model.pth")
+ ckpt_path = resolved.files["model.pth"]
ckpt = th.load(ckpt_path, map_location=device)
model = PixelCNNpp(
- nr_resnet=run.config["nr_resnet"],
- nr_filters=run.config["nr_filters"],
- nr_logistic_mix=run.config["nr_logistic_mix"],
- resnet_nonlinearity=run.config["resnet_nonlinearity"],
- dropout_p=run.config["dropout_p"],
+ nr_resnet=run_config["nr_resnet"],
+ nr_filters=run_config["nr_filters"],
+ nr_logistic_mix=run_config["nr_logistic_mix"],
+ resnet_nonlinearity=run_config["resnet_nonlinearity"],
+ dropout_p=run_config["dropout_p"],
input_channels=1,
nr_conditions=conditions_tensor.shape[1],
)
@@ -122,7 +129,7 @@ def __init__(self):
for i in range(design_shape[0]):
for j in range(design_shape[1]):
out = model(data, batch_conds)
- out_sample = sample_from_discretized_mix_logistic(out, run.config["nr_logistic_mix"])
+ out_sample = sample_from_discretized_mix_logistic(out, run_config["nr_logistic_mix"])
data[:, :, i, j] = out_sample.data[:, :, i, j]
# move completed batch to CPU to free GPU memory and store
diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py
index c2a257ff..ce3b9b44 100644
--- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py
+++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py
@@ -25,6 +25,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -46,6 +48,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved model weights."""
+ hf_entity: str = "IDEALLab"
+ """HF org/user where checkpoints are stored."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used for model-family repositories."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -891,10 +901,18 @@ def sample_designs( # noqa: PLR0913
}
th.save(ckpt_model, "model.pth")
- if args.track:
- artifact_model = wandb.Artifact(f"{args.problem_id}_{args.algo}_model", type="model")
- artifact_model.add_file("model.pth")
-
- wandb.log_artifact(artifact_model, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"model.pth": "model.pth"},
+ run_config=vars(args),
+ primary_files=["model.pth"],
+ wandb_artifacts={f"{args.problem_id}_{args.algo}_model": "model.pth"},
+ )
wandb.finish()
diff --git a/engiopt/surrogate_model/mlp_tabular_only.py b/engiopt/surrogate_model/mlp_tabular_only.py
index 135da9ad..3fe0a859 100644
--- a/engiopt/surrogate_model/mlp_tabular_only.py
+++ b/engiopt/surrogate_model/mlp_tabular_only.py
@@ -24,6 +24,8 @@
from engiopt.args_utils import parse_list_from_single_item_list
from engiopt.args_utils import parse_list_from_string
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -94,6 +96,10 @@ class Args:
track: bool = True
wandb_project: str = "engiopt"
wandb_entity: str | None = None
+ checkpoint_backend: CheckpointBackend = "hf"
+ hf_entity: str = "IDEALLab"
+ hf_repo_prefix: str = "engiopt"
+ hf_private: bool = False
seed: int = 42
strict_determinism: bool = False
n_ensembles: int = 1
@@ -350,11 +356,23 @@ def main(args: Args) -> float: # noqa: PLR0915
if args.save_model:
pipeline.save(pipeline_filename, device=device)
print(f"[INFO] Saved pipeline to {pipeline_filename}")
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={os.path.basename(pipeline_filename): pipeline_filename},
+ run_config=vars(args),
+ metadata={"target_col": args.target_col},
+ primary_files=[os.path.basename(pipeline_filename)],
+ extra_path_parts=[args.target_col],
+ wandb_artifacts={f"{run_name}_model": pipeline_filename},
+ )
if args.track:
- artifact = wandb.Artifact(f"{run_name}_model", type="model")
- artifact.add_file(pipeline_filename)
- wandb.log_artifact(artifact, aliases=[f"seed_{args.seed}"])
- print("[INFO] Uploaded model artifact to W&B.")
+ print("[INFO] Uploaded model artifact to configured checkpoint backend(s).")
# Evaluate on test set if requested
if args.test_model:
diff --git a/engiopt/surrogate_model/run_pe_optimization.py b/engiopt/surrogate_model/run_pe_optimization.py
index 82905f55..a7fdbb35 100644
--- a/engiopt/surrogate_model/run_pe_optimization.py
+++ b/engiopt/surrogate_model/run_pe_optimization.py
@@ -34,11 +34,13 @@
from pymoo.optimize import minimize
from pymoo.termination import get_termination
import tyro
+import wandb
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_checkpoint_reference
from engiopt.surrogate_model.model_pipeline import ModelPipeline
from engiopt.surrogate_model.pymoo_pe_problem import PymooPowerElecProblem
from engiopt.surrogate_model.training_utils import get_device
-import wandb
if TYPE_CHECKING:
from pymoo.core.algorithm import Algorithm
@@ -52,13 +54,15 @@
class Args:
# Surrogate pipelines
model_gain_path: str
- """Path to the W&B artifact for the gain model, e.g. "engibench/engiopt/power_electronics__mlp_tabular_only__DcGain__50__1746000531_model:latest"."""
+ """Model ref for the gain surrogate (W&B artifact, HF package ref, or local package directory)."""
model_ripple_path: str
- """Path to the W&B artifact for the ripple model, e.g. "engibench/engiopt/power_electronics__mlp_tabular_only__Voltage_Ripple__50__1746001046_model:latest"."""
+ """Model ref for the ripple surrogate (W&B artifact, HF package ref, or local package directory)."""
# Optimisation hyperparameters
seed: int
"""Random seed for the optimization, must match the seed used to train the models."""
+ model_source: ModelSource = "auto"
+ """Where to load the surrogate pipelines from."""
pop_size: int = 500
n_gen: int = 100
@@ -171,21 +175,30 @@ def save_front(res: Result, output_dir: str) -> tuple[str, str, str, str, str]:
return evals_csv, designs_csv, pareto_csv, evals_txt, designs_txt
-def load_model_from_wandb(artifact_path: str, run) -> ModelPipeline:
- """Load a model pipeline from a W&B artifact.
+def load_model_from_reference(
+ model_ref: str,
+ *,
+ model_source: ModelSource,
+ active_wandb_run: wandb.sdk.wandb_run.Run | None,
+) -> ModelPipeline:
+ """Load a model pipeline from a W&B artifact, HF package, or local directory.
Args:
- artifact_path: Path to the W&B artifact.
- run: Active W&B run to use for downloading the artifact.
+ model_ref: Reference to the stored model package or artifact.
+ model_source: Checkpoint backend to use when resolving the reference.
+ active_wandb_run: Optional active W&B run for artifact access.
Returns:
Loaded model pipeline.
"""
- artifact = run.use_artifact(artifact_path, type="model")
- artifact_dir = artifact.download()
- # Find the .pkl file in the directory (assuming exactly one)
- model_file = next(f for f in os.listdir(artifact_dir) if f.endswith(".pkl"))
- return ModelPipeline.load(os.path.join(artifact_dir, model_file))
+ resolved = resolve_checkpoint_reference(
+ model_source=model_source,
+ model_ref=model_ref,
+ required_files=[],
+ active_wandb_run=active_wandb_run,
+ )
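+ # Pick the serialized pipeline (.pkl) from the resolved files; the package is
+ # expected to contain exactly one.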
+ model_file = next((path for path in resolved.files.values() if path.endswith(".pkl")), None)
+ if model_file is None:
+ raise FileNotFoundError(f"No .pkl pipeline file found for model ref {model_ref!r}.")
+ return ModelPipeline.load(model_file)
# ---------------------------------------------------------------------------
@@ -208,11 +221,17 @@ def main(args: Args) -> None:
wandb.define_metric("generation")
wandb.define_metric("*", step_metric="generation")
- # load models from weights and biases
- assert wandb.run is not None, f"W&B run not found for run_name={run_name} in {args.wandb_entity}/{args.wandb_project}"
- # Load both models using the helper function
- pipeline_g = load_model_from_wandb(args.model_gain_path, wandb.run)
- pipeline_r = load_model_from_wandb(args.model_ripple_path, wandb.run)
+ active_wandb_run = wandb.run if args.track else None
+ pipeline_g = load_model_from_reference(
+ args.model_gain_path,
+ model_source=args.model_source,
+ active_wandb_run=active_wandb_run,
+ )
+ pipeline_r = load_model_from_reference(
+ args.model_ripple_path,
+ model_source=args.model_source,
+ active_wandb_run=active_wandb_run,
+ )
problem = PymooPowerElecProblem(
pipeline_r=pipeline_r,
diff --git a/engiopt/vqgan/evaluate_vqgan.py b/engiopt/vqgan/evaluate_vqgan.py
index bc2d3edf..51c7b358 100644
--- a/engiopt/vqgan/evaluate_vqgan.py
+++ b/engiopt/vqgan/evaluate_vqgan.py
@@ -10,9 +10,10 @@
import pandas as pd
import torch as th
import tyro
-import wandb
from engiopt import metrics
+from engiopt.checkpoint_store import ModelSource
+from engiopt.checkpoint_store import resolve_named_checkpoint
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.transforms import drop_constant
from engiopt.transforms import normalize
@@ -33,6 +34,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ model_source: ModelSource = "auto"
+ """Where to load the checkpoint package from."""
+ hf_entity: str = "IDEALLab"
+ """HF organization or user for checkpoint storage."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used to build per-family model repos."""
+ local_model_dir: str | None = None
+ """Optional local checkpoint package directory."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
@@ -62,57 +71,67 @@ class Args:
### Set Up Transformer ###
- # Restores the pytorch model from wandb
- if args.wandb_entity is not None:
- artifact_path_cvqgan = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}"
- artifact_path_vqgan = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}"
- artifact_path_transformer = (
- f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}"
+ resolved = resolve_named_checkpoint(
+ model_source=args.model_source,
+ problem_id=args.problem_id,
+ algo="vqgan",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["vqgan.pth", "transformer.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={
+ "vqgan.pth": f"{args.problem_id}_vqgan_vqgan",
+ "transformer.pth": f"{args.problem_id}_vqgan_transformer",
+ },
+ wandb_config_artifact_name=f"{args.problem_id}_vqgan_transformer",
+ local_model_dir=args.local_model_dir,
+ )
+ run_config = resolved.run_config
+
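+ # The conditional branch needs cvqgan.pth. Newer packages bundle it with the
+ # transformer weights; older runs logged it as a separate W&B artifact, so fall
+ # back to resolving it as its own named checkpoint.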
+ ckpt_path_cvqgan = os.path.join(resolved.root_dir, "cvqgan.pth")
+ if run_config["conditional"] and not os.path.exists(ckpt_path_cvqgan):
+ cvqgan_resolved = resolve_named_checkpoint(
+ model_source="wandb" if args.model_source == "wandb" else "auto",
+ problem_id=args.problem_id,
+ algo="vqgan",
+ seed=seed,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ required_files=["cvqgan.pth"],
+ wandb_project=args.wandb_project,
+ wandb_entity=args.wandb_entity,
+ wandb_artifact_names={"cvqgan.pth": f"{args.problem_id}_vqgan_cvqgan"},
+ wandb_config_artifact_name=f"{args.problem_id}_vqgan_transformer",
+ local_model_dir=args.local_model_dir,
)
- else:
- artifact_path_cvqgan = f"{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}"
- artifact_path_vqgan = f"{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}"
- artifact_path_transformer = f"{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}"
-
- api = wandb.Api()
- artifact_cvqgan = api.artifact(artifact_path_cvqgan, type="model")
- artifact_vqgan = api.artifact(artifact_path_vqgan, type="model")
- artifact_transformer = api.artifact(artifact_path_transformer, type="model")
-
- class RunRetrievalError(ValueError):
- def __init__(self):
- super().__init__("Failed to retrieve the run")
-
- run = artifact_transformer.logged_by()
- if run is None:
- raise RunRetrievalError
- run = api.run(f"{run.entity}/{run.project}/{run.id}")
-
- artifact_dir_cvqgan = artifact_cvqgan.download()
- artifact_dir_vqgan = artifact_vqgan.download()
- artifact_dir_transformer = artifact_transformer.download()
-
- ckpt_path_cvqgan = os.path.join(artifact_dir_cvqgan, "cvqgan.pth")
- ckpt_path_vqgan = os.path.join(artifact_dir_vqgan, "vqgan.pth")
- ckpt_path_transformer = os.path.join(artifact_dir_transformer, "transformer.pth")
- ckpt_cvqgan = th.load(ckpt_path_cvqgan, map_location=th.device(device), weights_only=False)
+ ckpt_path_cvqgan = cvqgan_resolved.files["cvqgan.pth"]
+
+ ckpt_path_vqgan = resolved.files["vqgan.pth"]
+ ckpt_path_transformer = resolved.files["transformer.pth"]
+ ckpt_cvqgan = None
+ if os.path.exists(ckpt_path_cvqgan):
+ ckpt_cvqgan = th.load(ckpt_path_cvqgan, map_location=th.device(device), weights_only=False)
+ elif run_config["conditional"]:
+ raise FileNotFoundError("Conditional VQGAN evaluation requires cvqgan.pth, but no checkpoint was found.")
ckpt_vqgan = th.load(ckpt_path_vqgan, map_location=th.device(device), weights_only=False)
ckpt_transformer = th.load(ckpt_path_transformer, map_location=th.device(device), weights_only=False)
vqgan = VQGAN(
device=device,
is_c=False,
- encoder_channels=run.config["encoder_channels"],
- encoder_start_resolution=run.config["image_size"],
- encoder_attn_resolutions=run.config["encoder_attn_resolutions"],
- encoder_num_res_blocks=run.config["encoder_num_res_blocks"],
- decoder_channels=run.config["decoder_channels"],
- decoder_start_resolution=run.config["latent_size"],
- decoder_attn_resolutions=run.config["decoder_attn_resolutions"],
- decoder_num_res_blocks=run.config["decoder_num_res_blocks"],
- image_channels=run.config["image_channels"],
- latent_dim=run.config["latent_dim"],
- num_codebook_vectors=run.config["num_codebook_vectors"],
+ encoder_channels=run_config["encoder_channels"],
+ encoder_start_resolution=run_config["image_size"],
+ encoder_attn_resolutions=run_config["encoder_attn_resolutions"],
+ encoder_num_res_blocks=run_config["encoder_num_res_blocks"],
+ decoder_channels=run_config["decoder_channels"],
+ decoder_start_resolution=run_config["latent_size"],
+ decoder_attn_resolutions=run_config["decoder_attn_resolutions"],
+ decoder_num_res_blocks=run_config["decoder_num_res_blocks"],
+ image_channels=run_config["image_channels"],
+ latent_dim=run_config["latent_dim"],
+ num_codebook_vectors=run_config["num_codebook_vectors"],
)
vqgan.load_state_dict(ckpt_vqgan["vqgan"])
vqgan.eval() # Set to evaluation mode
@@ -121,35 +140,36 @@ def __init__(self):
cvqgan = VQGAN(
device=device,
is_c=True,
- cond_feature_map_dim=run.config["cond_feature_map_dim"],
- cond_dim=run.config["cond_dim"],
- cond_hidden_dim=run.config["cond_hidden_dim"],
- cond_latent_dim=run.config["cond_latent_dim"],
- cond_codebook_vectors=run.config["cond_codebook_vectors"],
+ cond_feature_map_dim=run_config["cond_feature_map_dim"],
+ cond_dim=run_config["cond_dim"],
+ cond_hidden_dim=run_config["cond_hidden_dim"],
+ cond_latent_dim=run_config["cond_latent_dim"],
+ cond_codebook_vectors=run_config["cond_codebook_vectors"],
)
- cvqgan.load_state_dict(ckpt_cvqgan["cvqgan"])
+ if ckpt_cvqgan is not None:
+ cvqgan.load_state_dict(ckpt_cvqgan["cvqgan"])
cvqgan.eval() # Set to evaluation mode
cvqgan.to(device)
model = VQGANTransformer(
- conditional=run.config["conditional"],
+ conditional=run_config["conditional"],
vqgan=vqgan,
cvqgan=cvqgan,
- image_size=run.config["image_size"],
- decoder_channels=run.config["decoder_channels"],
- cond_feature_map_dim=run.config["cond_feature_map_dim"],
- num_codebook_vectors=run.config["num_codebook_vectors"],
- n_layer=run.config["n_layer"],
- n_head=run.config["n_head"],
- n_embd=run.config["n_embd"],
- dropout=run.config["dropout"],
+ image_size=run_config["image_size"],
+ decoder_channels=run_config["decoder_channels"],
+ cond_feature_map_dim=run_config["cond_feature_map_dim"],
+ num_codebook_vectors=run_config["num_codebook_vectors"],
+ n_layer=run_config["n_layer"],
+ n_head=run_config["n_head"],
+ n_embd=run_config["n_embd"],
+ dropout=run_config["dropout"],
)
model.load_state_dict(ckpt_transformer["transformer"])
model.eval() # Set to evaluation mode
model.to(device)
### Set up testing conditions ###
- _, sampled_conditions, sampled_designs_np, _ = sample_conditions(
+ _, sampled_conditions, sampled_designs_np, _sampled_designs_tensor = sample_conditions(
problem=problem, n_samples=args.n_samples, device=device, seed=seed
)
@@ -158,25 +178,25 @@ def __init__(self):
conditions = sampled_conditions_new.column_names
# Drop constant condition columns if enabled
- if run.config["drop_constant_conditions"]:
+ if run_config["drop_constant_conditions"]:
sampled_conditions_new, conditions = drop_constant(sampled_conditions_new, sampled_conditions_new.column_names)
# Normalize condition columns if enabled
- if run.config["normalize_conditions"]:
+ if run_config["normalize_conditions"]:
sampled_conditions_new, mean, std = normalize(sampled_conditions_new, conditions)
# Convert to tensor
conditions_tensor = th.stack([th.as_tensor(sampled_conditions_new[c][:]).float() for c in conditions], dim=1).to(device)
# Set the start-of-sequence tokens for the transformer using the CVQGAN to discretize the conditions if enabled
- if run.config["conditional"]:
+ if run_config["conditional"]:
c = model.encode_to_z(x=conditions_tensor, is_c=True)[1]
else:
c = th.ones(args.n_samples, 1, dtype=th.int64, device=device) * model.sos_token
# Generate a batch of designs
latent_designs = model.sample(
- x=th.empty(args.n_samples, 0, dtype=th.int64, device=device), c=c, steps=(run.config["latent_size"] ** 2)
+ x=th.empty(args.n_samples, 0, dtype=th.int64, device=device), c=c, steps=(run_config["latent_size"] ** 2)
)
gen_designs = resize_to(
data=model.z_to_image(latent_designs), h=problem.design_space.shape[0], w=problem.design_space.shape[1]
diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py
index fade92d9..36052647 100644
--- a/engiopt/vqgan/vqgan.py
+++ b/engiopt/vqgan/vqgan.py
@@ -32,6 +32,8 @@
import tyro
import wandb
+from engiopt.checkpoint_store import CheckpointBackend
+from engiopt.checkpoint_store import save_checkpoint_package
from engiopt.reproducibility import enable_strict_determinism
from engiopt.reproducibility import make_dataloader_generator
from engiopt.reproducibility import seed_training
@@ -68,6 +70,14 @@ class Args:
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
+ checkpoint_backend: CheckpointBackend = "hf"
+ """Checkpoint backend for saved models."""
+ hf_entity: str = "IDEALLab"
+ """HF organization or user for checkpoint storage."""
+ hf_repo_prefix: str = "engiopt"
+ """HF repo prefix used to build per-family model repos."""
+ hf_private: bool = False
+ """Whether newly created HF repos should be private."""
seed: int = 1
"""Random seed."""
@@ -978,9 +988,20 @@ def sample_designs_transformer(n_designs: int) -> tuple[th.Tensor, th.Tensor]:
}
th.save(ckpt_cvq, "cvqgan.pth")
- artifact_cvq = wandb.Artifact(f"{args.problem_id}_{args.algo}_cvqgan", type="model")
- artifact_cvq.add_file("cvqgan.pth")
- wandb.log_artifact(artifact_cvq, aliases=[f"seed_{args.seed}"])
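+ # Each VQGAN training stage saves its own package; metadata["stage"] tells the
+ # stage packages apart.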
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={"cvqgan.pth": "cvqgan.pth"},
+ run_config=vars(args),
+ metadata={"stage": "cvqgan"},
+ primary_files=["cvqgan.pth"],
+ wandb_artifacts={f"{args.problem_id}_{args.algo}_cvqgan": "cvqgan.pth"},
+ )
# Freeze CVQGAN for later use in Stage 2 Transformer
for p in cvqgan.parameters():
@@ -1084,13 +1105,26 @@ def sample_designs_transformer(n_designs: int) -> tuple[th.Tensor, th.Tensor]:
th.save(ckpt_vq, "vqgan.pth")
th.save(ckpt_disc, "discriminator.pth")
- artifact_vq = wandb.Artifact(f"{args.problem_id}_{args.algo}_vqgan", type="model")
- artifact_vq.add_file("vqgan.pth")
- artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model")
- artifact_disc.add_file("discriminator.pth")
-
- wandb.log_artifact(artifact_vq, aliases=[f"seed_{args.seed}"])
- wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files={
+ "vqgan.pth": "vqgan.pth",
+ "discriminator.pth": "discriminator.pth",
+ },
+ run_config=vars(args),
+ metadata={"stage": "vqgan"},
+ primary_files=["vqgan.pth"],
+ wandb_artifacts={
+ f"{args.problem_id}_{args.algo}_vqgan": "vqgan.pth",
+ f"{args.problem_id}_{args.algo}_discriminator": "discriminator.pth",
+ },
+ )
# Freeze VQGAN for later use in Stage 2 Transformer
for p in vqgan.parameters():
@@ -1211,8 +1245,27 @@ def sample_designs_transformer(n_designs: int) -> tuple[th.Tensor, th.Tensor]:
}
th.save(ckpt_tr, "transformer.pth")
- artifact_tr = wandb.Artifact(f"{args.problem_id}_{args.algo}_transformer", type="model")
- artifact_tr.add_file("transformer.pth")
- wandb.log_artifact(artifact_tr, aliases=[f"seed_{args.seed}"])
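+ # Bundle any earlier-stage weights that exist locally into the transformer
+ # package so evaluation can restore everything from a single package.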
+ checkpoint_files = {"transformer.pth": "transformer.pth"}
+ if os.path.exists("vqgan.pth"):
+ checkpoint_files["vqgan.pth"] = "vqgan.pth"
+ if os.path.exists("discriminator.pth"):
+ checkpoint_files["discriminator.pth"] = "discriminator.pth"
+ if os.path.exists("cvqgan.pth"):
+ checkpoint_files["cvqgan.pth"] = "cvqgan.pth"
+
+ save_checkpoint_package(
+ checkpoint_backend=args.checkpoint_backend,
+ hf_entity=args.hf_entity,
+ hf_repo_prefix=args.hf_repo_prefix,
+ hf_private=args.hf_private,
+ problem_id=args.problem_id,
+ algo=args.algo,
+ seed=args.seed,
+ checkpoint_files=checkpoint_files,
+ run_config=vars(args),
+ metadata={"stage": "transformer"},
+ primary_files=["transformer.pth"],
+ wandb_artifacts={f"{args.problem_id}_{args.algo}_transformer": "transformer.pth"},
+ )
wandb.finish()
diff --git a/pyproject.toml b/pyproject.toml
index 36881c4a..d2412ffc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
"datasets >=4.0.0",
"einops >=0.8.0",
"requests >=2.31.0",
+ "huggingface-hub >= 0.30.0",
]
dynamic = ["version"]
@@ -294,6 +295,8 @@ module = [
"plotly.*",
"einops",
"einops.*",
+ "huggingface_hub",
+ "huggingface_hub.*",
"torchvision.*",
"requests",
]