Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions examples/scripts/1_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@

from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.telemetry import configure_logging, get_logger


configure_logging(log_file="1_introduction.log")
log = get_logger(name="1_introduction")

# Set up the device and data type
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
Expand All @@ -28,9 +32,9 @@
# ============================================================================
# SECTION 1: Lennard-Jones Model - Simple Classical Potential
# ============================================================================
print("\n" + "=" * 70)
print("SECTION 1: Lennard-Jones Model")
print("=" * 70)
log.info("=" * 70)
log.info("SECTION 1: Lennard-Jones Model")
log.info("=" * 70)

# Create face-centered cubic (FCC) Argon
# 5.26 Å is a typical lattice constant for Ar
Expand Down Expand Up @@ -103,19 +107,19 @@
results = lj_model(state)

# Print the results
print(f"Energy: {results['energy']}")
print(f"Forces shape: {results['forces'].shape}")
print(f"Stress shape: {results['stress'].shape}")
print(f"Per-atom energies shape: {results['energies'].shape}")
print(f"Per-atom stresses shape: {results['stresses'].shape}")
log.info(f"Energy: {results['energy']}")
log.info(f"Forces shape: {results['forces'].shape}")
log.info(f"Stress shape: {results['stress'].shape}")
log.info(f"Per-atom energies shape: {results['energies'].shape}")
log.info(f"Per-atom stresses shape: {results['stresses'].shape}")


# ============================================================================
# SECTION 2: MACE Model - Machine Learning Potential (Batched)
# ============================================================================
print("\n" + "=" * 70)
print("SECTION 2: MACE Model with Batched Input")
print("=" * 70)
log.info("=" * 70)
log.info("SECTION 2: MACE Model with Batched Input")
log.info("=" * 70)

# Load the raw model from the downloaded model
loaded_model = mace_mp(
Expand Down Expand Up @@ -164,10 +168,10 @@
)

# You can see their shapes are as expected
print(f"Positions: {positions.shape}")
print(f"Cell: {cell.shape}")
print(f"Atomic numbers: {atomic_numbers.shape}")
print(f"System indices: {system_idx.shape}")
log.info(f"Positions: {positions.shape}")
log.info(f"Cell: {cell.shape}")
log.info(f"Atomic numbers: {atomic_numbers.shape}")
log.info(f"System indices: {system_idx.shape}")

# Masses for Silicon (28.085 amu)
masses_si = torch.full((positions.shape[0],), 28.085, device=device, dtype=dtype)
Expand All @@ -185,13 +189,13 @@
)

# The energy has shape (n_systems,) as the structures in a batch
print(f"Energy shape: {results['energy'].shape}")
log.info(f"Energy shape: {results['energy'].shape}")

# The forces have shape (n_atoms, 3) same as positions
print(f"Forces shape: {results['forces'].shape}")
log.info(f"Forces shape: {results['forces'].shape}")

# The stress has shape (n_systems, 3, 3) same as cell
print(f"Stress shape: {results['stress'].shape}")
log.info(f"Stress shape: {results['stress'].shape}")

# Check if the energy, forces, and stress are the same for the Si system across batches
# Each system has 64 atoms (2x2x2 supercell of 8-atom Si diamond)
Expand All @@ -204,10 +208,10 @@
)
stress_diff = torch.max(torch.abs(results["stress"][0] - results["stress"][1]))

print(f"\nMax energy difference: {energy_diff}")
print(f"Max forces difference: {forces_diff}")
print(f"Max stress difference: {stress_diff}")
log.info(f"Max energy difference: {energy_diff}")
log.info(f"Max forces difference: {forces_diff}")
log.info(f"Max stress difference: {stress_diff}")

print("\n" + "=" * 70)
print("Introduction examples completed!")
print("=" * 70)
log.info("=" * 70)
log.info("Introduction examples completed!")
log.info("=" * 70)
Loading
Loading