# NVT Simulation: Constant-T
import numpy as np
from math import sqrt, nan
from ase import Atoms
from ase.data import atomic_masses
from ase.md.velocitydistribution import  MaxwellBoltzmannDistribution
import torch
import torch_sim as ts
from torch_sim import fire_init, fire_step
from torch_sim.units import MetalUnits
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.integrators.nvt import nvt_nose_hoover_init, nvt_nose_hoover_step, nvt_nose_hoover_invariant
import matplotlib.pyplot as plt


# Run multiple simulations: 10 (ideally)
# --------------------------------
# Physical constants:
k_b = 1.3806503*10**(-23)  # Boltzmann's constant (J/K)
# --------------------------------
# Simulation / TorchSim adjustable parameters:
run_n_simulations = 1
n_equilibrium_steps = 1000
simulation_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lennard-Jones parameters (SI values, then converted to TorchSim metal units):
epsilon_J = 1.6512577588 * 10 ** (-21)  # LJ energy parameter (J)
sigma_m = 3.405 * 10 ** (-10)  # LJ length parameter (m)
epsilon_eV = epsilon_J * 6.241509e18  # J -> eV
sigma_A = sigma_m * 1e10  # m -> Angstrom
cutoff_distance_m = 3 * sigma_m  # interaction cutoff = 3 sigma (m)
cutoff_distance_A = 3 * sigma_A  # interaction cutoff = 3 sigma (Angstrom)


amu_to_kg = 1.660539e-27  # atomic mass unit -> kg
# Simulation parameters taken from the paper:
# Number of argon atoms in the simulation box.
n_atoms = 2592
# Derive the total mass from n_atoms so the two values can never drift apart.
system_mass_amu = n_atoms * atomic_masses[18]
system_mass_kg = system_mass_amu * amu_to_kg
atom_mass_kg = atomic_masses[18] * amu_to_kg
n_slabs = 20  # number of slabs the simulation box is partitioned into
temperature_K = 83.122
# Thermal energy in eV; assumes MetalUnits.temperature converts K -> eV (k_B in eV/K).
temperature_eV = temperature_K * MetalUnits.temperature
# Timestep given in reduced LJ units, converted to SI through tau = sqrt(m * sigma^2 / epsilon).
timestep_reduced = 6.965e-3
timestep_s = timestep_reduced * sqrt(atom_mass_kg * sigma_m ** 2 / epsilon_J)
timestep_ps = timestep_s / 1e-12  # s -> ps
# s -> fs. The divisor must be 1e-15; the previous `10e-15` equals 1e-14 (10x error).
timestep_fs = timestep_s / 1e-15
# Timestep already scaled into TorchSim's internal time unit; pass it to low-level
# init/step functions, NOT to ts.integrate (which takes ps and converts itself).
simulation_timestep = timestep_ps*MetalUnits.time

for j in range(run_n_simulations):
    # Scatter the atoms uniformly at random inside the simulation box.
    rng = np.random.default_rng()
    x_lower, x_upper = 0, 34.2498735  # lower and upper bound for x coord (Angstrom)
    Yl, Yu = 0, 34.2498735  # lower and upper bound for y coord (Angstrom)
    Zl, Zu = 0, 102.749961  # lower and upper bound for z coord (Angstrom)

    # Uniformly distribute the atoms:
    system_coords = rng.uniform([x_lower, Yl, Zl], [x_upper, Yu, Zu], size=(n_atoms, 3))

    # Argon-filled system, periodic only along z.
    # NOTE(review): x/y are non-periodic although a full cell is given -- confirm pbc choice.
    ar_system = Atoms(
        numbers=[18] * n_atoms,
        positions=system_coords,
        cell=[34.2498735, 34.2498735, 102.749961],
        pbc=[False, False, True],
    )
    # Draw Maxwell-Boltzmann momenta at the target simulation temperature
    # (previously hard-coded to 1 K, contradicting the intent stated here).
    MaxwellBoltzmannDistribution(ar_system, temperature_K=temperature_K)

    #--------------------------------------------------------------------------------------------
    # Create the Lennard-Jones model, build TorchSim's SimState from ar_system,
    # relax it, and then initialise the NVT (Nose-Hoover) ensemble.
    lj_model = LennardJonesModel(
        sigma=sigma_A,  # Angstrom, 3.405
        epsilon=epsilon_eV,  # eV, 0.01030634
        cutoff=cutoff_distance_A,  # Angstrom, 10.215
        device=simulation_device,
        # Forces are on by default; listed explicitly for readability.
        compute_forces=True,
        compute_stress=True,
        per_atom_energies=True,
        per_atom_stresses=True,
        use_neighbor_list=True,
    )

    # argon_simstate is TorchSim's SimState
    argon_simstate = ts.state.initialize_state(
        system=ar_system,
        device=simulation_device,
        dtype=torch.float32,
    )

    # Randomly placed atoms can overlap badly, so relax them before MD:
    # first to an energy tolerance...
    relaxed_state = ts.runners.optimize(
        system=argon_simstate,
        model=lj_model,
        optimizer=(fire_init, fire_step),
        convergence_fn=ts.runners.generate_energy_convergence_fn(energy_tol=1e-6),
    )
    print("------ Relaxed State ------")
    model_output = lj_model(relaxed_state)
    print(f"Energy: {model_output['energy'].item()}")
    print(relaxed_state.positions)
    print("------ Original State ------")
    model_output = lj_model(argon_simstate)
    print(f"Energy: {model_output['energy'].item()}")
    print(argon_simstate.positions)

    # ...then to a force tolerance.
    # NOTE(review): force_tol=1e-6 is very tight for float32 and may only stop at the step limit.
    relaxed_state = ts.runners.optimize(
        system=relaxed_state,
        model=lj_model,
        optimizer=(fire_init, fire_step),
        convergence_fn=ts.runners.generate_force_convergence_fn(force_tol=1e-6, include_cell_forces=False),
    )

    print("------ Relaxed State ------")
    model_output = lj_model(relaxed_state)
    print(f"Forces: {model_output['forces']}")
    print(relaxed_state.positions)
    print("------ Original State ------")
    model_output = lj_model(argon_simstate)
    print(f"Forces: {model_output['forces']}")
    print(argon_simstate.positions)

    # Thermostat target and timestep, both already in TorchSim internal units
    # (temperature_eV = T * MetalUnits.temperature, simulation_timestep = dt_ps * MetalUnits.time).
    # The same kT is used for init and stepping so the chain is built for the target temperature.
    kT = torch.tensor(temperature_eV, dtype=torch.float32, device=simulation_device)
    dt = torch.tensor(simulation_timestep, dtype=torch.float32, device=simulation_device)

    # Build TorchSim's NVTNoseHooverState (system_state) from the relaxed SimState.
    system_state = nvt_nose_hoover_init(
        state=relaxed_state,
        model=lj_model,
        kT=kT,
        dt=dt,
        # Original note: defaults used -- tau=100*dt, chain_length=3, chain_steps=3, sy_steps=3
    )
    sys_temp_k = ts.quantities.calc_temperature(
        masses=system_state.masses,
        momenta=system_state.momenta,
    )
    print(f"Pre-Integrate: System Temperature: {sys_temp_k:.2f} K")
    # NOTE(review): unused below; may alias system_state if the step mutates state in place.
    alt_system_state = system_state

    stablization_phase_temp_k = []
    stablization_phase_energy_eV = []
    # Approach 1: run until equilibrium, advancing the already-initialised Nose-Hoover
    # state one step at a time. This replaces ts.integrate(n_steps=1, ...), which
    # (a) takes `timestep` in ps and converts it itself, so passing simulation_timestep
    # applied MetalUnits.time twice, and (b) re-runs the integrator's init on every call.
    for i in range(n_equilibrium_steps):
        system_state = nvt_nose_hoover_step(
            state=system_state,
            model=lj_model,
            dt=dt,
            kT=kT,
        )

        if torch.isnan(system_state.momenta).any():
            print(f"sys: NaN in momenta at step {i}")
            # The trajectory has diverged; further steps only propagate NaNs.
            break

        sys_temp_k = ts.quantities.calc_temperature(
            masses=system_state.masses,
            momenta=system_state.momenta,
        )
        stablization_phase_temp_k.append(sys_temp_k.cpu().numpy())
        stablization_phase_energy_eV.append(system_state.energy.cpu().numpy())
        print(f"Step {i}: System Temperature: {sys_temp_k:.2f} K")

    # Plot only the steps that actually completed (the loop may stop early on NaN).
    steps = range(1, len(stablization_phase_temp_k) + 1)

    plt.plot(steps, stablization_phase_temp_k, color="g", label="System Temperature")
    plt.xlabel("Step")
    plt.ylabel("Temperature (K)")
    plt.legend()
    plt.show()

    # "w" (white) made this curve invisible on the default background.
    plt.plot(steps, stablization_phase_energy_eV, color="b", label="System Energy")
    plt.xlabel("Step")
    plt.ylabel("Energy (eV)")
    plt.legend()
    plt.show()
    # No trailing `break`: every one of run_n_simulations runs is executed.