diff --git a/src/KOKKOS/pair_metatomic_kokkos.cpp b/src/KOKKOS/pair_metatomic_kokkos.cpp index 5420b9f66ed..79d01f109b4 100644 --- a/src/KOKKOS/pair_metatomic_kokkos.cpp +++ b/src/KOKKOS/pair_metatomic_kokkos.cpp @@ -70,11 +70,12 @@ void PairMetatomicKokkos::init_style() { this->type_mapping_kk = Kokkos::View("type_mapping_kk", atom->ntypes + 1); Kokkos::deep_copy(this->type_mapping_kk, type_mapping_kk_host); + using NCMode = PairMetatomicData::NonConservativeMode; auto options = MetatomicSystemOptions{ this->type_mapping_kk.data(), mta_data->max_cutoff, mta_data->check_consistency, - !(mta_data->non_conservative), + mta_data->non_conservative != NCMode::ON, // autograd needed for OFF/FORCES/STRESS }; // override the system adaptor with the kokkos version @@ -112,6 +113,7 @@ void PairMetatomicKokkos::pick_device(torch::Device& device, const c template void PairMetatomicKokkos::store_forces(const at::Tensor& forces_tensor) { + using NCMode = PairMetatomicData::NonConservativeMode; assert(forces_tensor.scalar_type() == torch::kFloat64); auto forces = forces_tensor.contiguous(); @@ -131,8 +133,8 @@ void PairMetatomicKokkos::store_forces(const at::Tensor& forces_tens } ); - // in non-conservative mode we do not need to update forces on ghost atoms - if (!mta_data->non_conservative) { + // ghost atom forces only exist when forces come from autograd + if (mta_data->non_conservative == NCMode::OFF || mta_data->non_conservative == NCMode::STRESS) { auto system_adaptor_kk = dynamic_cast*>(this->system_adaptor.get()); assert(system_adaptor_kk != nullptr); auto mta_to_lmp_kk = UnmanagedView( diff --git a/src/ML-METATOMIC/metatomic_types.h b/src/ML-METATOMIC/metatomic_types.h index 9565e76a35a..f8ba8aefae2 100644 --- a/src/ML-METATOMIC/metatomic_types.h +++ b/src/ML-METATOMIC/metatomic_types.h @@ -69,8 +69,9 @@ struct PairMetatomicData: public CommonMetatomicData { metatomic_torch::ModelOutput nc_forces_output; metatomic_torch::ModelOutput nc_stress_output; - // whether non-conservative forces and stresses should be used - bool non_conservative = false; + // which non-conservative outputs to use + enum class NonConservativeMode { OFF, ON, FORCES, STRESS }; + NonConservativeMode non_conservative = NonConservativeMode::OFF; // energy key for the model std::string energy_key; diff --git a/src/ML-METATOMIC/pair_metatomic.cpp b/src/ML-METATOMIC/pair_metatomic.cpp index bd067026e86..947590500ad 100644 --- a/src/ML-METATOMIC/pair_metatomic.cpp +++ b/src/ML-METATOMIC/pair_metatomic.cpp @@ -152,13 +152,17 @@ void PairMetatomic::settings(int argc, char ** argv) { i += 1; } else if (strcmp(argv[i], "non_conservative") == 0) { if (i == argc - 1) { - error->one(FLERR, "expected after 'non_conservative' in pair_style metatomic, got nothing"); + error->one(FLERR, "expected after 'non_conservative' in pair_style metatomic, got nothing"); } else if (strcmp(argv[i + 1], "on") == 0) { - mta_data->non_conservative = true; + mta_data->non_conservative = PairMetatomicData::NonConservativeMode::ON; } else if (strcmp(argv[i + 1], "off") == 0) { - mta_data->non_conservative = false; + mta_data->non_conservative = PairMetatomicData::NonConservativeMode::OFF; + } else if (strcmp(argv[i + 1], "forces") == 0) { + mta_data->non_conservative = PairMetatomicData::NonConservativeMode::FORCES; + } else if (strcmp(argv[i + 1], "stress") == 0) { + mta_data->non_conservative = PairMetatomicData::NonConservativeMode::STRESS; } else { - error->one(FLERR, "expected after 'non_conservative' in pair_style metatomic, got '{}'", argv[i + 1]); + error->one(FLERR, "expected after 'non_conservative' in pair_style metatomic, got '{}'", argv[i + 1]); } i += 1; @@ -267,29 +271,68 @@ void PairMetatomic::settings(int argc, char ** argv) { } // Handle non-conservative variants - if (mta_data->non_conservative) { - // Error if *both* nc-force and nc-stress were provided by user AND one is Null - bool user_set_forces = (variant_nc_forces != nullptr); - bool user_set_stress = (variant_nc_stress != nullptr); - - if (user_set_forces && user_set_stress) { - - bool forces_none = !normalize_variant(variant_nc_forces).has_value(); - bool stress_none = !normalize_variant(variant_nc_stress).has_value(); + using NCMode = PairMetatomicData::NonConservativeMode; + const auto nc_mode = mta_data->non_conservative; + + bool user_set_forces = (variant_nc_forces != nullptr); + bool user_set_stress = (variant_nc_stress != nullptr); + + // Warn if the user set an explicit variant for an output that the chosen + // mode does not use. + if (user_set_forces && nc_mode != NCMode::ON && nc_mode != NCMode::FORCES) { + error->warning(FLERR, + "'variant/non_conservative_forces' was set but the current 'non_conservative' mode " + "does not use non-conservative forces; the variant will be ignored." + ); + } + if (user_set_stress && nc_mode != NCMode::ON && nc_mode != NCMode::STRESS) { + error->warning(FLERR, + "'variant/non_conservative_stress' was set but the current 'non_conservative' mode " + "does not use non-conservative stress; the variant will be ignored." + ); + } - if (forces_none != stress_none) { - error->one(FLERR, - "if both 'variant/non_conservative_stress' and " - "'variant/non_conservative_forces' are present, they " - "must either both be 'off' or both not 'off'"); - } + // Error if *both* nc-force and nc-stress were provided by user AND one is Null + if (nc_mode == NCMode::ON && user_set_forces && user_set_stress) { + bool forces_none = !normalize_variant(variant_nc_forces).has_value(); + bool stress_none = !normalize_variant(variant_nc_stress).has_value(); + if (forces_none != stress_none) { + error->one(FLERR, + "if both 'variant/non_conservative_stress' and " + "'variant/non_conservative_forces' are present with 'non_conservative on', " + "they must either both be 'off' or both not 'off'"); } + } + bool do_nc_forces = (nc_mode == NCMode::ON || nc_mode == NCMode::FORCES); + if (do_nc_forces) { try { mta_data->nc_forces_key = pick_output("non_conservative_forces", outputs, v_nc_forces); + } catch (std::exception& e) { + error->one(FLERR, + "{}\nFailed to select 'non_conservative_forces' output. " + "If the model does not support non-conservative forces, use " + "'non_conservative stress' or 'non_conservative off'. " + "If the model provides multiple variants, select one with " + "'variant/non_conservative_forces '.", + e.what() + ); + } + } + + bool do_nc_stress = (nc_mode == NCMode::ON || nc_mode == NCMode::STRESS); + if (do_nc_stress) { + try { mta_data->nc_stress_key = pick_output("non_conservative_stress", outputs, v_nc_stress); } catch (std::exception& e) { - error->one(FLERR, e.what()); + error->one(FLERR, + "{}\nFailed to select 'non_conservative_stress' output. " + "If the model does not support non-conservative stress, use " + "'non_conservative forces' or 'non_conservative off'. " + "If the model provides multiple variants, select one with " + "'variant/non_conservative_stress '.", + e.what() + ); } } @@ -331,7 +374,7 @@ void PairMetatomic::settings(int argc, char ** argv) { } } - if (mta_data->non_conservative) { + if (do_nc_forces) { auto nc_forces = outputs.find(mta_data->nc_forces_key); if (nc_forces == outputs.end()) { error->one(FLERR, @@ -348,12 +391,13 @@ void PairMetatomic::settings(int argc, char ** argv) { mta_data->nc_forces_key, model_path ); } - mta_data->nc_forces_output = torch::make_intrusive(); mta_data->nc_forces_output->set_quantity("force"); mta_data->nc_forces_output->set_unit(this->energy_unit + "/" + this->length_unit); mta_data->nc_forces_output->per_atom = true; + } + if (do_nc_stress) { auto nc_stress = outputs.find(mta_data->nc_stress_key); if (nc_stress != outputs.end()) { mta_data->nc_stress_output = torch::make_intrusive(); @@ -505,6 +549,9 @@ void PairMetatomic::coeff(int argc, char ** argv) { // called when the run starts void PairMetatomic::init_style() { + using NCMode = PairMetatomicData::NonConservativeMode; + const auto nc_mode = mta_data->non_conservative; + // Require newton pair on since we need to communicate forces accumulated on // ghost atoms to neighboring domains. These forces contributions come from // gradient of a local descriptor w.r.t. domain ghosts (periodic images @@ -549,7 +596,7 @@ void PairMetatomic::init_style() { this->type_mapping, mta_data->max_cutoff, mta_data->check_consistency, - !(mta_data->non_conservative), + nc_mode != NCMode::ON, // autograd needed for OFF/FORCES/STRESS }; this->system_adaptor = std::make_unique(lmp, options); @@ -576,6 +623,9 @@ void PairMetatomic::init_list(int id, NeighList *ptr) { } void PairMetatomic::compute(int eflag, int vflag) { + using NCMode = PairMetatomicData::NonConservativeMode; + const auto nc_mode = mta_data->non_conservative; + if (std::getenv("LAMMPS_METATOMIC_PROFILE") != nullptr) { MetatomicTimer::enable(true); } else { @@ -589,8 +639,15 @@ void PairMetatomic::compute(int eflag, int vflag) { mta_data->evaluation_options->outputs.clear(); // we need an energy output if the energy was explicitly requested (through // `eflag_either`), or when running in standard/conservative mode, because - // we'll get the forces as the gradient of the energy through autodiff. - if (eflag_either || !mta_data->non_conservative) { + // we'll get the forces and stress as the gradient of the energy through autodiff. + auto need_energy_for_autograd = (nc_mode == NCMode::OFF + || nc_mode == NCMode::STRESS + || (nc_mode == NCMode::FORCES && vflag_global)); + + auto do_nc_forces = nc_mode == NCMode::ON || nc_mode == NCMode::FORCES; + auto do_nc_stress = nc_mode == NCMode::ON || nc_mode == NCMode::STRESS; + + if (eflag_either || need_energy_for_autograd) { if (eflag_atom) { if (!mta_data->is_energy_output_per_atom) { error->one(FLERR, @@ -609,18 +666,11 @@ void PairMetatomic::compute(int eflag, int vflag) { mta_data->evaluation_options->outputs.insert(mta_data->energy_uq_key, mta_data->uncertainty_output); } - if (mta_data->non_conservative) { + if (do_nc_forces) { mta_data->evaluation_options->outputs.insert(mta_data->nc_forces_key, mta_data->nc_forces_output); - if (vflag_global) { - if (mta_data->nc_stress_output == nullptr) { - error->one(FLERR, - "the model at '{}' does not have a '{}' output, " - "we can not run non_conservative simulations that require computing the stress/virial", - mta_data->model_path, mta_data->nc_stress_key - ); - } - mta_data->evaluation_options->outputs.insert(mta_data->nc_stress_key, mta_data->nc_stress_output); - } + } + if (vflag_global && do_nc_stress) { + mta_data->evaluation_options->outputs.insert(mta_data->nc_stress_key, mta_data->nc_stress_output); } auto dtype = torch::kFloat64; @@ -635,7 +685,7 @@ void PairMetatomic::compute(int eflag, int vflag) { // transform from LAMMPS to metatomic System auto system = this->system_adaptor->system_from_lmp( mta_list, - static_cast(vflag_global), + vflag_global && !do_nc_stress, dtype, mta_data->device ); @@ -712,7 +762,7 @@ void PairMetatomic::compute(int eflag, int vflag) { // get the energy if we need to compute the energy, or if we are using it to // get the forces/virial with autograd - if (eflag_either || !mta_data->non_conservative) { + if (eflag_either || need_energy_for_autograd) { auto energy = results.at(mta_data->energy_key).toCustomClass(); auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); energy_tensor = energy_block->values(); @@ -722,30 +772,37 @@ void PairMetatomic::compute(int eflag, int vflag) { torch::Tensor forces_tensor; torch::Tensor virial_tensor; - if (mta_data->non_conservative) { + // get non-conservative forces + if (do_nc_forces) { auto forces = results.at(mta_data->nc_forces_key).toCustomClass(); auto forces_block = metatensor_torch::TensorMapHolder::block_by_id(forces, 0); forces_tensor = forces_block->values().squeeze(-1); forces_tensor = forces_tensor.to(torch::kCPU).to(torch::kFloat64); + } - if (vflag_global) { - auto stress = results.at(mta_data->nc_stress_key).toCustomClass(); - auto stress_block = metatensor_torch::TensorMapHolder::block_by_id(stress, 0); - auto stress_tensor = stress_block->values().squeeze(0).squeeze(-1); - virial_tensor = - stress_tensor * compute_volume(domain); - virial_tensor = virial_tensor.to(torch::kCPU).to(torch::kFloat64); - } - } else { - // compute forces/virial on device with backward propagation - // reset gradients to zero before calling backward + // get non-conservative stress + if (vflag_global && do_nc_stress) { + auto stress = results.at(mta_data->nc_stress_key).toCustomClass(); + auto stress_block = metatensor_torch::TensorMapHolder::block_by_id(stress, 0); + auto stress_tensor = stress_block->values().squeeze(0).squeeze(-1); + virial_tensor = - stress_tensor * compute_volume(domain); + virial_tensor = virial_tensor.to(torch::kCPU).to(torch::kFloat64); + } + + // compute conservative quantities through autograd + if (need_energy_for_autograd) { this->system_adaptor->positions.mutable_grad() = torch::Tensor(); this->system_adaptor->strain.mutable_grad() = torch::Tensor(); auto _ = MetatomicTimer("running Model::backward"); energy_tensor.backward(-torch::ones_like(energy_tensor)); - forces_tensor = this->system_adaptor->positions.grad(); - virial_tensor = this->system_adaptor->strain.grad(); + if (!do_nc_forces) { + forces_tensor = this->system_adaptor->positions.grad(); + } + if (vflag_global && !do_nc_stress) { + virial_tensor = this->system_adaptor->strain.grad(); + } } { @@ -802,7 +859,7 @@ void PairMetatomic::compute(int eflag, int vflag) { assert(!vflag_fdotr); - if (vflag_global) { + if (vflag_global && virial_tensor.defined()) { auto virial_cpu = virial_tensor.to(torch::kCPU); assert(virial_cpu.is_cpu() && virial_cpu.scalar_type() == torch::kFloat64); @@ -833,8 +890,10 @@ void PairMetatomic::store_forces(const at::Tensor& forces_tensor) { atom->f[i][2] += this->scale * forces[i][2]; } - // in non-conservative mode we do not need to update forces on ghost atoms - if (!mta_data->non_conservative) { + // ghost atom forces only exist when forces come from autograd + using NCMode = PairMetatomicData::NonConservativeMode; + const auto nc_mode = mta_data->non_conservative; + if (nc_mode == NCMode::OFF || nc_mode == NCMode::STRESS) { const auto& mta_to_lmp = this->system_adaptor->mta_to_lmp; for (int i=atom->nlocal; if[mta_to_lmp[i]][0] += this->scale * forces[i][0];