From d9268aa65cb8e5cb36c4ee8d386da1823516a7e7 Mon Sep 17 00:00:00 2001 From: haykh Date: Mon, 16 Mar 2026 18:02:07 -0400 Subject: [PATCH 01/10] formatting --- .clang-format | 4 ++ src/archetypes/particle_injector.h | 1 - src/engines/grpic.hpp | 4 +- src/engines/srpic/currents.h | 1 - src/engines/srpic/fields_bcs.h | 1 - src/engines/srpic/fieldsolvers.h | 1 - src/engines/srpic/particles_bcs.h | 1 - src/framework/containers/particles_comm.cpp | 1 - src/framework/containers/particles_io.cpp | 3 +- src/framework/domain/metadomain_chckpt.cpp | 3 +- src/framework/domain/metadomain_io.cpp | 1 - src/framework/domain/metadomain_stats.cpp | 28 ++++++------ src/framework/parameters/algorithms.cpp | 3 +- src/framework/parameters/algorithms.h | 4 +- src/framework/parameters/extra.h | 4 +- src/framework/parameters/grid.cpp | 3 +- src/framework/parameters/grid.h | 4 +- src/framework/parameters/output.cpp | 1 + src/framework/parameters/output.h | 4 +- src/framework/simulation.cpp | 1 + src/framework/simulation.h | 3 +- src/framework/tests/comm-mpi.cpp | 4 +- src/framework/tests/comm-nompi.cpp | 4 +- src/global/arch/directions.h | 4 +- src/global/arch/kokkos_aliases.cpp | 36 +++++++-------- src/global/arch/mpi_tags.h | 9 +--- src/global/utils/cargs.cpp | 4 +- src/global/utils/cargs.h | 3 +- src/global/utils/comparators.h | 16 +++---- src/global/utils/param_container.cpp | 26 +++++------ src/global/utils/progressbar.cpp | 7 ++- src/global/utils/progressbar.h | 5 +-- src/global/utils/tools.h | 32 +++++++------- src/kernels/digital_filter.hpp | 49 +++++++++++---------- src/kernels/reduced_stats.hpp | 4 +- src/kernels/tests/deposit.cpp | 2 +- src/metrics/kerr_schild.h | 6 +-- src/metrics/qkerr_schild.h | 24 +++++----- src/output/checkpoint.h | 3 +- src/output/stats.cpp | 7 +-- src/output/stats.h | 4 +- src/output/utils/interpret_prompt.h | 4 +- 42 files changed, 159 insertions(+), 170 deletions(-) diff --git a/.clang-format b/.clang-format index b3bb8d132..6f7fef667 100644 --- a/.clang-format +++ b/.clang-format @@ -108,6 +108,10 @@ IncludeCategories: Priority: 4 - Regex: '^"checkpoint\/.*\.h"' Priority: 4 + - Regex: '^"kernels\/.*\.hpp"' + Priority: 4 + - Regex: '^"kernels\/.*\.h"' + Priority: 4 - Regex: '^"output\/.*\.h"' Priority: 4 - Regex: '^"archetypes\/.*\.h"' diff --git a/src/archetypes/particle_injector.h b/src/archetypes/particle_injector.h index d41996393..e1ca1b7ca 100644 --- a/src/archetypes/particle_injector.h +++ b/src/archetypes/particle_injector.h @@ -26,7 +26,6 @@ #include "framework/domain/domain.h" #include "framework/domain/metadomain.h" - #include "kernels/injectors.hpp" #include diff --git a/src/engines/grpic.hpp b/src/engines/grpic.hpp index d204ad74b..7952153ba 100644 --- a/src/engines/grpic.hpp +++ b/src/engines/grpic.hpp @@ -22,8 +22,6 @@ #include "framework/domain/domain.h" #include "framework/parameters/parameters.h" - -#include "engines/engine.hpp" #include "kernels/ampere_gr.hpp" #include "kernels/aux_fields_gr.hpp" #include "kernels/currents_deposit.hpp" @@ -31,6 +29,8 @@ #include "kernels/faraday_gr.hpp" #include "kernels/fields_bcs.hpp" #include "kernels/particle_pusher_gr.hpp" + +#include "engines/engine.hpp" #include "pgen.hpp" #include diff --git a/src/engines/srpic/currents.h b/src/engines/srpic/currents.h index 479ba0bc6..5733875b4 100644 --- a/src/engines/srpic/currents.h +++ b/src/engines/srpic/currents.h @@ -13,7 +13,6 @@ #include "engines/srpic/utils.h" #include "framework/domain/domain.h" #include "framework/domain/metadomain.h" - #include "kernels/currents_deposit.hpp" #include "kernels/digital_filter.hpp" diff --git a/src/engines/srpic/fields_bcs.h b/src/engines/srpic/fields_bcs.h index 07792002b..cf4c44dc8 100644 --- a/src/engines/srpic/fields_bcs.h +++ b/src/engines/srpic/fields_bcs.h @@ -13,7 +13,6 @@ #include "engines/srpic/utils.h" #include "framework/domain/domain.h" #include "framework/parameters/parameters.h" - #include "kernels/fields_bcs.hpp" namespace ntt { diff --git a/src/engines/srpic/fieldsolvers.h b/src/engines/srpic/fieldsolvers.h index 04bca2744..bc68cee3a 100644 --- a/src/engines/srpic/fieldsolvers.h +++ b/src/engines/srpic/fieldsolvers.h @@ -12,7 +12,6 @@ #include "engines/srpic/utils.h" #include "framework/domain/domain.h" #include "framework/parameters/parameters.h" - #include "kernels/ampere_mink.hpp" #include "kernels/ampere_sr.hpp" #include "kernels/faraday_mink.hpp" diff --git a/src/engines/srpic/particles_bcs.h b/src/engines/srpic/particles_bcs.h index 1f921f033..5d444385f 100644 --- a/src/engines/srpic/particles_bcs.h +++ b/src/engines/srpic/particles_bcs.h @@ -16,7 +16,6 @@ #include "framework/domain/domain.h" #include "framework/domain/metadomain.h" #include "framework/parameters/parameters.h" - #include "kernels/particle_moments.hpp" namespace ntt { diff --git a/src/framework/containers/particles_comm.cpp b/src/framework/containers/particles_comm.cpp index 4d6d67118..1cf17efef 100644 --- a/src/framework/containers/particles_comm.cpp +++ b/src/framework/containers/particles_comm.cpp @@ -10,7 +10,6 @@ #include "utils/log.h" #include "framework/containers/particles.h" - #include "kernels/comm.hpp" #include diff --git a/src/framework/containers/particles_io.cpp b/src/framework/containers/particles_io.cpp index d35ec59f5..9a234f5b9 100644 --- a/src/framework/containers/particles_io.cpp +++ b/src/framework/containers/particles_io.cpp @@ -7,11 +7,10 @@ #include "framework/containers/particles.h" #include "framework/specialization_registry.h" +#include "kernels/prtls_to_phys.hpp" #include "output/utils/readers.h" #include "output/utils/writers.h" -#include "kernels/prtls_to_phys.hpp" - #include #include diff --git a/src/framework/domain/metadomain_chckpt.cpp b/src/framework/domain/metadomain_chckpt.cpp index 693cafc1a..bd6b87cfd 100644 --- a/src/framework/domain/metadomain_chckpt.cpp +++ b/src/framework/domain/metadomain_chckpt.cpp @@ -1,5 +1,3 @@ -#include "output/checkpoint.h" - #include "enums.h" #include "global.h" @@ -10,6 +8,7 @@ #include "framework/domain/metadomain.h" #include "framework/parameters/parameters.h" #include "framework/specialization_registry.h" +#include "output/checkpoint.h" namespace ntt { diff --git a/src/framework/domain/metadomain_io.cpp b/src/framework/domain/metadomain_io.cpp index 1d29c4102..bea98e21e 100644 --- a/src/framework/domain/metadomain_io.cpp +++ b/src/framework/domain/metadomain_io.cpp @@ -11,7 +11,6 @@ #include "framework/domain/metadomain.h" #include "framework/parameters/parameters.h" #include "framework/specialization_registry.h" - #include "kernels/divergences.hpp" #include "kernels/fields_to_phys.hpp" #include "kernels/particle_moments.hpp" diff --git a/src/framework/domain/metadomain_stats.cpp b/src/framework/domain/metadomain_stats.cpp index 811b35111..21176a836 100644 --- a/src/framework/domain/metadomain_stats.cpp +++ b/src/framework/domain/metadomain_stats.cpp @@ -11,7 +11,6 @@ #include "framework/domain/metadomain.h" #include "framework/parameters/parameters.h" #include "framework/specialization_registry.h" - #include "kernels/reduced_stats.hpp" #include @@ -190,8 +189,8 @@ namespace ntt { timestep_t finished_step, simtime_t current_time, simtime_t finished_time, - std::function&)> - CustomStat) -> bool { + std::function&)> CustomStat) + -> bool { if (not(params.template get("output.stats.enable") and g_stats_writer.shouldWrite(finished_step, finished_time))) { return false; @@ -281,17 +280,18 @@ namespace ntt { return true; } -#define METADOMAIN_STATS(S, M, D) \ - template void Metadomain>::InitStatsWriter(const SimulationParams&, \ - bool); \ - template auto Metadomain>::WriteStats( \ - const SimulationParams&, \ - timestep_t, \ - timestep_t, \ - simtime_t, \ - simtime_t, \ - std::function< \ - real_t(const std::string&, timestep_t, simtime_t, const Domain>&)>) -> bool; +#define METADOMAIN_STATS(S, M, D) \ + template void Metadomain>::InitStatsWriter(const SimulationParams&, \ + bool); \ + template auto Metadomain>::WriteStats( \ + const SimulationParams&, \ + timestep_t, \ + timestep_t, \ + simtime_t, \ + simtime_t, \ + std::function< \ + real_t(const std::string&, timestep_t, simtime_t, const Domain>&)>) \ + -> bool; NTT_FOREACH_SPECIALIZATION(METADOMAIN_STATS) diff --git a/src/framework/parameters/algorithms.cpp b/src/framework/parameters/algorithms.cpp index 944c323f8..856d05fc7 100644 --- a/src/framework/parameters/algorithms.cpp +++ b/src/framework/parameters/algorithms.cpp @@ -4,10 +4,11 @@ #include "global.h" #include "utils/numeric.h" -#include #include "framework/parameters/parameters.h" +#include + namespace ntt { namespace params { diff --git a/src/framework/parameters/algorithms.h b/src/framework/parameters/algorithms.h index fb41d731c..a496cb7b6 100644 --- a/src/framework/parameters/algorithms.h +++ b/src/framework/parameters/algorithms.h @@ -14,10 +14,10 @@ #include "global.h" -#include - #include "framework/parameters/parameters.h" +#include + #include #include diff --git a/src/framework/parameters/extra.h b/src/framework/parameters/extra.h index cec6f619b..cf5c0849c 100644 --- a/src/framework/parameters/extra.h +++ b/src/framework/parameters/extra.h @@ -14,10 +14,10 @@ #include "global.h" -#include - #include "framework/parameters/parameters.h" +#include + #include #include diff --git a/src/framework/parameters/grid.cpp b/src/framework/parameters/grid.cpp index 78072680f..9e5d39df5 100644 --- a/src/framework/parameters/grid.cpp +++ b/src/framework/parameters/grid.cpp @@ -6,7 +6,6 @@ #include "utils/error.h" #include "utils/formatting.h" #include "utils/numeric.h" -#include #include "metrics/kerr_schild.h" #include "metrics/kerr_schild_0.h" @@ -17,6 +16,8 @@ #include "framework/parameters/parameters.h" +#include + #include #include #include diff --git a/src/framework/parameters/grid.h b/src/framework/parameters/grid.h index bc73063d7..978b7f329 100644 --- a/src/framework/parameters/grid.h +++ b/src/framework/parameters/grid.h @@ -15,10 +15,10 @@ #include "enums.h" #include "global.h" -#include - #include "framework/parameters/parameters.h" +#include + #include #include #include diff --git a/src/framework/parameters/output.cpp b/src/framework/parameters/output.cpp index 3d06ff054..547800971 100644 --- a/src/framework/parameters/output.cpp +++ b/src/framework/parameters/output.cpp @@ -5,6 +5,7 @@ #include "utils/error.h" #include "utils/log.h" + #include namespace ntt { diff --git a/src/framework/parameters/output.h b/src/framework/parameters/output.h index 08b3377d2..e301fc745 100644 --- a/src/framework/parameters/output.h +++ b/src/framework/parameters/output.h @@ -14,10 +14,10 @@ #include "global.h" -#include - #include "framework/parameters/parameters.h" +#include + #include #include #include diff --git a/src/framework/simulation.cpp b/src/framework/simulation.cpp index db133c0f8..4e0da3eda 100644 --- a/src/framework/simulation.cpp +++ b/src/framework/simulation.cpp @@ -9,6 +9,7 @@ #include "utils/formatting.h" #include "utils/log.h" #include "utils/plog.h" + #include #include diff --git a/src/framework/simulation.h b/src/framework/simulation.h index 745c76db4..7387d62dd 100644 --- a/src/framework/simulation.h +++ b/src/framework/simulation.h @@ -17,11 +17,12 @@ #include "enums.h" #include "utils/error.h" -#include #include "engines/traits.h" #include "framework/parameters/parameters.h" +#include + namespace ntt { class Simulation { diff --git a/src/framework/tests/comm-mpi.cpp b/src/framework/tests/comm-mpi.cpp index 487976f73..8da0fa54b 100644 --- a/src/framework/tests/comm-mpi.cpp +++ b/src/framework/tests/comm-mpi.cpp @@ -1,5 +1,3 @@ -#include "framework/domain/comm_mpi.hpp" - #include "enums.h" #include "global.h" @@ -8,6 +6,8 @@ #include "utils/error.h" #include "utils/numeric.h" +#include "framework/domain/comm_mpi.hpp" + #include #include diff --git a/src/framework/tests/comm-nompi.cpp b/src/framework/tests/comm-nompi.cpp index c7646ef03..059c21cdb 100644 --- a/src/framework/tests/comm-nompi.cpp +++ b/src/framework/tests/comm-nompi.cpp @@ -1,5 +1,3 @@ -#include "framework/domain/comm_nompi.hpp" - #include "enums.h" #include "global.h" @@ -7,6 +5,8 @@ #include "arch/kokkos_aliases.h" #include "utils/numeric.h" +#include "framework/domain/comm_nompi.hpp" + #include #include diff --git a/src/global/arch/directions.h b/src/global/arch/directions.h index 850bc130d..5f8281ed3 100644 --- a/src/global/arch/directions.h +++ b/src/global/arch/directions.h @@ -132,8 +132,8 @@ namespace dir { using dirs_t = std::vector>; template - inline auto operator<<(std::ostream& os, - const direction_t& dir) -> std::ostream& { + inline auto operator<<(std::ostream& os, const direction_t& dir) + -> std::ostream& { for (auto& d : dir) { os << std::setw(2) << std::left; if (d > 0) { diff --git a/src/global/arch/kokkos_aliases.cpp b/src/global/arch/kokkos_aliases.cpp index e81d41280..25397af8e 100644 --- a/src/global/arch/kokkos_aliases.cpp +++ b/src/global/arch/kokkos_aliases.cpp @@ -9,18 +9,18 @@ auto CreateParticleRangePolicy(npart_t p1, npart_t p2) -> range_t { } template <> -auto CreateRangePolicy( - const tuple_t& i1, - const tuple_t& i2) -> range_t { +auto CreateRangePolicy(const tuple_t& i1, + const tuple_t& i2) + -> range_t { index_t i1min = i1[0]; index_t i1max = i2[0]; return Kokkos::RangePolicy(i1min, i1max); } template <> -auto CreateRangePolicy( - const tuple_t& i1, - const tuple_t& i2) -> range_t { +auto CreateRangePolicy(const tuple_t& i1, + const tuple_t& i2) + -> range_t { index_t i1min = i1[0]; index_t i1max = i2[0]; index_t i2min = i1[1]; @@ -31,9 +31,9 @@ auto CreateRangePolicy( } template <> -auto CreateRangePolicy( - const tuple_t& i1, - const tuple_t& i2) -> range_t { +auto CreateRangePolicy(const tuple_t& i1, + const tuple_t& i2) + -> range_t { index_t i1min = i1[0]; index_t i1max = i2[0]; index_t i2min = i1[1]; @@ -46,18 +46,18 @@ auto CreateRangePolicy( } template <> -auto CreateRangePolicyOnHost( - const tuple_t& i1, - const tuple_t& i2) -> range_h_t { +auto CreateRangePolicyOnHost(const tuple_t& i1, + const tuple_t& i2) + -> range_h_t { index_t i1min = i1[0]; index_t i1max = i2[0]; return Kokkos::RangePolicy(i1min, i1max); } template <> -auto CreateRangePolicyOnHost( - const tuple_t& i1, - const tuple_t& i2) -> range_h_t { +auto CreateRangePolicyOnHost(const tuple_t& i1, + const tuple_t& i2) + -> range_h_t { index_t i1min = i1[0]; index_t i1max = i2[0]; index_t i2min = i1[1]; @@ -68,9 +68,9 @@ auto CreateRangePolicyOnHost( } template <> -auto CreateRangePolicyOnHost( - const tuple_t& i1, - const tuple_t& i2) -> range_h_t { +auto CreateRangePolicyOnHost(const tuple_t& i1, + const tuple_t& i2) + -> range_h_t { index_t i1min = i1[0]; index_t i1max = i2[0]; index_t i2min = i1[1]; diff --git a/src/global/arch/mpi_tags.h b/src/global/arch/mpi_tags.h index aaf38a8f4..775670614 100644 --- a/src/global/arch/mpi_tags.h +++ b/src/global/arch/mpi_tags.h @@ -190,13 +190,8 @@ namespace mpi { tag; } - Inline auto SendTag(short tag, - bool im1, - bool ip1, - bool jm1, - bool jp1, - bool km1, - bool kp1) -> short { + Inline auto SendTag(short tag, bool im1, bool ip1, bool jm1, bool jp1, bool km1, bool kp1) + -> short { return ((im1 && jm1 && km1) * (PrtlSendTag::im1_jm1_km1 - 1) + (im1 && jm1 && kp1) * (PrtlSendTag::im1_jm1_kp1 - 1) + (im1 && jp1 && km1) * (PrtlSendTag::im1_jp1_km1 - 1) + diff --git a/src/global/utils/cargs.cpp b/src/global/utils/cargs.cpp index 57b79f33b..8f641214e 100644 --- a/src/global/utils/cargs.cpp +++ b/src/global/utils/cargs.cpp @@ -18,8 +18,8 @@ namespace cargs { _initialized = true; } - auto CommandLineArguments::getArgument(std::string_view key, - std::string_view def) -> std::string_view { + auto CommandLineArguments::getArgument(std::string_view key, std::string_view def) + -> std::string_view { if (!_initialized) { throw std::runtime_error( "# Error: command line arguments have not been parsed."); diff --git a/src/global/utils/cargs.h b/src/global/utils/cargs.h index 530969912..7c02146b7 100644 --- a/src/global/utils/cargs.h +++ b/src/global/utils/cargs.h @@ -25,7 +25,8 @@ namespace cargs { public: void readCommandLineArguments(int argc, char* argv[]); [[nodiscard]] - auto getArgument(std::string_view key, std::string_view def) -> std::string_view; + auto getArgument(std::string_view key, std::string_view def) + -> std::string_view; [[nodiscard]] auto getArgument(std::string_view key) -> std::string_view; auto isSpecified(std::string_view key) -> bool; diff --git a/src/global/utils/comparators.h b/src/global/utils/comparators.h index a12d55e73..d86fe868c 100644 --- a/src/global/utils/comparators.h +++ b/src/global/utils/comparators.h @@ -27,31 +27,31 @@ namespace cmp { template - Inline auto AlmostEqual(T a, - T b, - T eps = Kokkos::Experimental::epsilon::value) -> bool { + Inline auto AlmostEqual(T a, T b, T eps = Kokkos::Experimental::epsilon::value) + -> bool { static_assert(std::is_floating_point_v, "T must be a floating point type"); return (a == b) || (math::abs(a - b) <= math::min(math::abs(a), math::abs(b)) * eps); } template - Inline auto AlmostZero(T a, T eps = Kokkos::Experimental::epsilon::value) -> bool { + Inline auto AlmostZero(T a, T eps = Kokkos::Experimental::epsilon::value) + -> bool { static_assert(std::is_floating_point_v, "T must be a floating point type"); return math::abs(a) <= eps; } template - inline auto AlmostEqual_host(T a, - T b, - T eps = std::numeric_limits::epsilon()) -> bool { + inline auto AlmostEqual_host(T a, T b, T eps = std::numeric_limits::epsilon()) + -> bool { static_assert(std::is_floating_point_v, "T must be a floating point type"); return (a == b) || (std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * eps); } template - inline auto AlmostZero_host(T a, T eps = std::numeric_limits::epsilon()) -> bool { + inline auto AlmostZero_host(T a, T eps = std::numeric_limits::epsilon()) + -> bool { static_assert(std::is_floating_point_v, "T must be a floating point type"); return std::abs(a) <= eps; } diff --git a/src/global/utils/param_container.cpp b/src/global/utils/param_container.cpp index bd8399bc8..6769c6f40 100644 --- a/src/global/utils/param_container.cpp +++ b/src/global/utils/param_container.cpp @@ -31,8 +31,8 @@ namespace prm { } template - auto write(adios2::IO& io, const std::string& name, T var) -> decltype(void(T()), - void()) { + auto write(adios2::IO& io, const std::string& name, T var) + -> decltype(void(T()), void()) { io.DefineAttribute(name, var); } @@ -47,8 +47,8 @@ namespace prm { } template - auto write_pair(adios2::IO& io, const std::string& name, std::pair var) -> - typename std::enable_if::value, void>::type { + auto write_pair(adios2::IO& io, const std::string& name, std::pair var) + -> typename std::enable_if::value, void>::type { std::vector var_str; var_str.push_back(var.first.to_string()); var_str.push_back(var.second.to_string()); @@ -56,9 +56,8 @@ namespace prm { } template - auto write_pair(adios2::IO& io, - const std::string& name, - std::pair var) -> decltype(void(T()), void()) { + auto write_pair(adios2::IO& io, const std::string& name, std::pair var) + -> decltype(void(T()), void()) { std::vector var_vec; var_vec.push_back(var.first); var_vec.push_back(var.second); @@ -76,9 +75,8 @@ namespace prm { } template - auto write_vec(adios2::IO& io, - const std::string& name, - std::vector var) -> decltype(void(T()), void()) { + auto write_vec(adios2::IO& io, const std::string& name, std::vector var) + -> decltype(void(T()), void()) { io.DefineAttribute(name, var.data(), var.size()); } @@ -98,8 +96,8 @@ namespace prm { template auto write_vec_pair(adios2::IO& io, const std::string& name, - std::vector> var) -> decltype(void(T()), - void()) { + std::vector> var) + -> decltype(void(T()), void()) { std::vector var_vec; for (const auto& v : var) { var_vec.push_back(v.first); @@ -125,8 +123,8 @@ namespace prm { template auto write_vec_vec(adios2::IO& io, const std::string& name, - std::vector> var) -> decltype(void(T()), - void()) { + std::vector> var) + -> decltype(void(T()), void()) { std::vector var_vec; for (const auto& vec : var) { for (const auto& v : vec) { diff --git a/src/global/utils/progressbar.cpp b/src/global/utils/progressbar.cpp index eaa8118fc..994927054 100644 --- a/src/global/utils/progressbar.cpp +++ b/src/global/utils/progressbar.cpp @@ -21,11 +21,10 @@ namespace pbar { - auto normalize_duration_fmt( - duration_t t, - const std::string& u) -> std::pair { + auto normalize_duration_fmt(duration_t t, const std::string& u) + -> std::pair { const std::vector> units { - { "µs", 1e0 }, + { "µs", 1e0 }, { "ms", 1e3 }, { "s", 1e6 }, { "min", 6e7 }, diff --git a/src/global/utils/progressbar.h b/src/global/utils/progressbar.h index 588413cb4..982433fa3 100644 --- a/src/global/utils/progressbar.h +++ b/src/global/utils/progressbar.h @@ -76,9 +76,8 @@ namespace pbar { } }; - auto normalize_duration_fmt( - duration_t t, - const std::string& u) -> std::pair; + auto normalize_duration_fmt(duration_t t, const std::string& u) + -> std::pair; auto to_human_readable(duration_t t, const std::string& u) -> std::string; diff --git a/src/global/utils/tools.h b/src/global/utils/tools.h index 8a568ae20..27c7e22c8 100644 --- a/src/global/utils/tools.h +++ b/src/global/utils/tools.h @@ -60,8 +60,8 @@ namespace tools { * @return Tensor product of list */ template - inline auto TensorProduct( - const std::vector>& list) -> std::vector> { + inline auto TensorProduct(const std::vector>& list) + -> std::vector> { std::vector> result = { {} }; for (const auto& sublist : list) { std::vector> temp; @@ -81,8 +81,8 @@ namespace tools { * @param ndomains Number of domains * @param ncells Number of cells */ - inline auto decompose1D(unsigned int ndomains, - ncells_t ncells) -> std::vector { + inline auto decompose1D(unsigned int ndomains, ncells_t ncells) + -> std::vector { auto size = (ncells_t)((double)ncells / (double)ndomains); auto ncells_domain = std::vector(ndomains, size); for (auto i { 0u }; i < ncells - size * ndomains; ++i) { @@ -107,10 +107,8 @@ namespace tools { * @param s1 Proportion of the first dimension * @param s2 Proportion of the second dimension */ - inline auto divideInProportions2D( - unsigned int ntot, - unsigned int s1, - unsigned int s2) -> std::tuple { + inline auto divideInProportions2D(unsigned int ntot, unsigned int s1, unsigned int s2) + -> std::tuple { auto n1 = (unsigned int)(std::sqrt((double)ntot * (double)s1 / (double)s2)); if (n1 == 0) { return { 1, ntot }; @@ -132,11 +130,11 @@ namespace tools { * @param s2 Proportion of the second dimension * @param s3 Proportion of the third dimension */ - inline auto divideInProportions3D( - unsigned int ntot, - unsigned int s1, - unsigned int s2, - unsigned int s3) -> std::tuple { + inline auto divideInProportions3D(unsigned int ntot, + unsigned int s1, + unsigned int s2, + unsigned int s3) + -> std::tuple { auto n1 = (unsigned int)(std::cbrt( (double)ntot * (double)(SQR(s1)) / (double)(s2 * s3))); if (n1 > ntot) { @@ -165,10 +163,10 @@ namespace tools { * * @note If decomposition has -1, it will be calculated automatically */ - inline auto Decompose( - unsigned int ndomains, - const std::vector& ncells, - const std::vector& decomposition) -> std::vector> { + inline auto Decompose(unsigned int ndomains, + const std::vector& ncells, + const std::vector& decomposition) + -> std::vector> { const auto dimension = ncells.size(); raise::ErrorIf(dimension != decomposition.size(), "Decomposition error: dimension != decomposition.size", diff --git a/src/kernels/digital_filter.hpp b/src/kernels/digital_filter.hpp index 8676a3c71..5ac60327d 100644 --- a/src/kernels/digital_filter.hpp +++ b/src/kernels/digital_filter.hpp @@ -19,35 +19,38 @@ #define FILTER2D_IN_I1(ARR, COMP, I, J) \ INV_2*(ARR)((I), (J), (COMP)) + \ - INV_4*((ARR)((I)-1, (J), (COMP)) + (ARR)((I) + 1, (J), (COMP))) + INV_4*((ARR)((I) - 1, (J), (COMP)) + (ARR)((I) + 1, (J), (COMP))) #define FILTER2D_IN_I2(ARR, COMP, I, J) \ INV_2*(ARR)((I), (J), (COMP)) + \ - INV_4*((ARR)((I), (J)-1, (COMP)) + (ARR)((I), (J) + 1, (COMP))) + INV_4*((ARR)((I), (J) - 1, (COMP)) + (ARR)((I), (J) + 1, (COMP))) -#define FILTER3D_IN_I1_I2(ARR, COMP, I, J, K) \ - INV_4*(ARR)(I, J, K, (COMP)) + \ - INV_8*((ARR)((I)-1, (J), (K), (COMP)) + (ARR)((I) + 1, (J), (K), (COMP)) + \ - (ARR)((I), (J)-1, (K), (COMP)) + (ARR)((I), (J) + 1, (K), (COMP))) + \ - INV_16*( \ - (ARR)((I)-1, (J)-1, (K), (COMP)) + (ARR)((I) + 1, (J) + 1, (K), (COMP)) + \ - (ARR)((I)-1, (J) + 1, (K), (COMP)) + (ARR)((I) + 1, (J)-1, (K), (COMP))) +#define FILTER3D_IN_I1_I2(ARR, COMP, I, J, K) \ + INV_4*(ARR)(I, J, K, (COMP)) + \ + INV_8*((ARR)((I) - 1, (J), (K), (COMP)) + (ARR)((I) + 1, (J), (K), (COMP)) + \ + (ARR)((I), (J) - 1, (K), (COMP)) + (ARR)((I), (J) + 1, (K), (COMP))) + \ + INV_16*((ARR)((I) - 1, (J) - 1, (K), (COMP)) + \ + (ARR)((I) + 1, (J) + 1, (K), (COMP)) + \ + (ARR)((I) - 1, (J) + 1, (K), (COMP)) + \ + (ARR)((I) + 1, (J) - 1, (K), (COMP))) -#define FILTER3D_IN_I2_I3(ARR, COMP, I, J, K) \ - INV_4*(ARR)(I, J, K, (COMP)) + \ - INV_8*((ARR)((I), (J)-1, (K), (COMP)) + (ARR)((I), (J) + 1, (K), (COMP)) + \ - (ARR)((I), (J), (K)-1, (COMP)) + (ARR)((I), (J), (K) + 1, (COMP))) + \ - INV_16*( \ - (ARR)((I), (J)-1, (K)-1, (COMP)) + (ARR)((I), (J) + 1, (K) + 1, (COMP)) + \ - (ARR)((I), (J)-1, (K) + 1, (COMP)) + (ARR)((I), (J) + 1, (K)-1, (COMP))) +#define FILTER3D_IN_I2_I3(ARR, COMP, I, J, K) \ + INV_4*(ARR)(I, J, K, (COMP)) + \ + INV_8*((ARR)((I), (J) - 1, (K), (COMP)) + (ARR)((I), (J) + 1, (K), (COMP)) + \ + (ARR)((I), (J), (K) - 1, (COMP)) + (ARR)((I), (J), (K) + 1, (COMP))) + \ + INV_16*((ARR)((I), (J) - 1, (K) - 1, (COMP)) + \ + (ARR)((I), (J) + 1, (K) + 1, (COMP)) + \ + (ARR)((I), (J) - 1, (K) + 1, (COMP)) + \ + (ARR)((I), (J) + 1, (K) - 1, (COMP))) -#define FILTER3D_IN_I1_I3(ARR, COMP, I, J, K) \ - INV_4*(ARR)(I, J, K, (COMP)) + \ - INV_8*((ARR)((I)-1, (J), (K), (COMP)) + (ARR)((I) + 1, (J), (K), (COMP)) + \ - (ARR)((I), (J), (K)-1, (COMP)) + (ARR)((I), (J), (K) + 1, (COMP))) + \ - INV_16*( \ - (ARR)((I)-1, (J), (K)-1, (COMP)) + (ARR)((I) + 1, (J), (K) + 1, (COMP)) + \ - (ARR)((I)-1, (J), (K) + 1, (COMP)) + (ARR)((I) + 1, (J), (K)-1, (COMP))) +#define FILTER3D_IN_I1_I3(ARR, COMP, I, J, K) \ + INV_4*(ARR)(I, J, K, (COMP)) + \ + INV_8*((ARR)((I) - 1, (J), (K), (COMP)) + (ARR)((I) + 1, (J), (K), (COMP)) + \ + (ARR)((I), (J), (K) - 1, (COMP)) + (ARR)((I), (J), (K) + 1, (COMP))) + \ + INV_16*((ARR)((I) - 1, (J), (K) - 1, (COMP)) + \ + (ARR)((I) + 1, (J), (K) + 1, (COMP)) + \ + (ARR)((I) - 1, (J), (K) + 1, (COMP)) + \ + (ARR)((I) + 1, (J), (K) - 1, (COMP))) namespace kernel { using namespace ntt; diff --git a/src/kernels/reduced_stats.hpp b/src/kernels/reduced_stats.hpp index 392d24749..a6824fa28 100644 --- a/src/kernels/reduced_stats.hpp +++ b/src/kernels/reduced_stats.hpp @@ -25,8 +25,8 @@ namespace kernel { template requires metric::traits::HasD && (metric::traits::HasTransform_i || I == 0) && - metric::traits::HasTransform && metric::traits::HasSqrtDetH && - (I <= 3) + metric::traits::HasTransform && + metric::traits::HasSqrtDetH && (I <= 3) class ReducedFields_kernel { static constexpr auto D = M::Dim; diff --git a/src/kernels/tests/deposit.cpp b/src/kernels/tests/deposit.cpp index 75ecf530c..8c7d32eb9 100644 --- a/src/kernels/tests/deposit.cpp +++ b/src/kernels/tests/deposit.cpp @@ -142,7 +142,7 @@ void testDeposit(const std::vector& res, put_value(ux3, uz, 0); put_value(weight, 1.0, 0); put_value(tag, ParticleTag::alive, 0); - + auto J_scat = Kokkos::Experimental::create_scatter_view(J); // clang-format off diff --git a/src/metrics/kerr_schild.h b/src/metrics/kerr_schild.h index 4df88cf4c..9ce53c52e 100644 --- a/src/metrics/kerr_schild.h +++ b/src/metrics/kerr_schild.h @@ -450,15 +450,15 @@ namespace metric { Inline auto polar_area(real_t x1) const -> real_t { if (small_angle) { return dr * (SQR(x1 * dr + x1_min) + SQR(a)) * - math::sqrt(ONE + TWO * (x1 * dr + x1_min) / + math::sqrt(ONE + TWO * (x1 * dr + x1_min) / (SQR(x1 * dr + x1_min) + SQR(a))) * (static_cast(48) - SQR(dtheta)) * SQR(dtheta) / static_cast(384); } else { return dr * (SQR(x1 * dr + x1_min) + SQR(a)) * - math::sqrt(ONE + TWO * (x1 * dr + x1_min) / + math::sqrt(ONE + TWO * (x1 * dr + x1_min) / (SQR(x1 * dr + x1_min) + SQR(a))) * - (ONE - math::cos(HALF * dtheta)); + (ONE - math::cos(HALF * dtheta)); } } diff --git a/src/metrics/qkerr_schild.h b/src/metrics/qkerr_schild.h index 90f366efb..4a4756d98 100644 --- a/src/metrics/qkerr_schild.h +++ b/src/metrics/qkerr_schild.h @@ -90,7 +90,7 @@ namespace metric { , dphi { (x3_max - phi_min) / nx3 } , dchi_inv { ONE / dchi } , deta_inv { ONE / deta } - , dphi_inv { ONE / dphi } + , dphi_inv { ONE / dphi } , small_angle { eta2theta(HALF * deta) < constant::SMALL_ANGLE } { set_dxMin(find_dxMin()); } @@ -509,22 +509,22 @@ namespace metric { */ Inline auto polar_area(real_t x1) const -> real_t { if constexpr (D != Dim::_1D) { - if (small_angle) { + if (small_angle) { const real_t dtheta = eta2theta(HALF * deta); return dchi * math::exp(x1 * dchi + chi_min) * - (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a)) * - math::sqrt( - ONE + TWO * (r0 + math::exp(x1 * dchi + chi_min)) / - (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a))) * - (static_cast(48) - SQR(dtheta)) * SQR(dtheta) / + (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a)) * + math::sqrt( + ONE + TWO * (r0 + math::exp(x1 * dchi + chi_min)) / + (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a))) * + (static_cast(48) - SQR(dtheta)) * SQR(dtheta) / static_cast(384); } else { return dchi * math::exp(x1 * dchi + chi_min) * - (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a)) * - math::sqrt( - ONE + TWO * (r0 + math::exp(x1 * dchi + chi_min)) / - (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a))) * - (ONE - math::cos(eta2theta(HALF * deta))); + (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a)) * + math::sqrt( + ONE + TWO * (r0 + math::exp(x1 * dchi + chi_min)) / + (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a))) * + (ONE - math::cos(eta2theta(HALF * deta))); } } } diff --git a/src/output/checkpoint.h b/src/output/checkpoint.h index 495b97b73..283584136 100644 --- a/src/output/checkpoint.h +++ b/src/output/checkpoint.h @@ -68,7 +68,8 @@ namespace checkpoint { } [[nodiscard]] - auto written() const -> const std::vector>& { + auto written() const + -> const std::vector>& { return m_written; } diff --git a/src/output/stats.cpp b/src/output/stats.cpp index 1d1038d5f..ca270bc91 100644 --- a/src/output/stats.cpp +++ b/src/output/stats.cpp @@ -81,11 +81,8 @@ namespace stats { CallOnce( [this](auto& fname, auto& stat_writers) { std::fstream StatsOut(fname, std::fstream::out | std::fstream::app); - StatsOut << std::setw(io_precision + 8) - << "step" - << "," - << std::setw(io_precision + 8) - << "time" + StatsOut << std::setw(io_precision + 8) << "step" + << "," << std::setw(io_precision + 8) << "time" << ","; for (const auto& stat : stat_writers) { if (stat.is_vector()) { diff --git a/src/output/stats.h b/src/output/stats.h index b687e4107..d3b7e1c82 100644 --- a/src/output/stats.h +++ b/src/output/stats.h @@ -193,9 +193,7 @@ namespace stats { [this](auto&& fname, auto&& value) { std::fstream StatsOut(fname, std::fstream::out | std::fstream::app); StatsOut << std::setw(io_precision + 8) - << std::setprecision(io_precision) - << value - << ","; + << std::setprecision(io_precision) << value << ","; StatsOut.close(); }, m_fname, diff --git a/src/output/utils/interpret_prompt.h b/src/output/utils/interpret_prompt.h index 032482cf8..488d81101 100644 --- a/src/output/utils/interpret_prompt.h +++ b/src/output/utils/interpret_prompt.h @@ -26,8 +26,8 @@ namespace out { auto InterpretSpecies(const std::string&) -> std::vector; - auto InterpretComponents( - const std::vector&) -> std::vector>; + auto InterpretComponents(const std::vector&) + -> std::vector>; } // namespace out From 21e7eee69905927687bf466163f782f1cf547f16 Mon Sep 17 00:00:00 2001 From: haykh Date: Mon, 16 Mar 2026 18:02:54 -0400 Subject: [PATCH 02/10] custom emission policy (pgen-defined) + unit test --- src/archetypes/traits.h | 11 + src/engines/srpic/particle_pusher.h | 103 +++++--- src/framework/containers/particles.cpp | 25 ++ src/framework/containers/particles.h | 7 + src/framework/parameters/particles.cpp | 2 + src/global/arch/traits.h | 39 --- src/global/enums.h | 5 +- src/kernels/emission/compton.hpp | 26 +- src/kernels/emission/synchrotron.hpp | 26 +- src/kernels/emission/traits.h | 75 ++++++ src/kernels/particle_pusher_sr.hpp | 49 +++- src/kernels/tests/CMakeLists.txt | 3 +- src/kernels/tests/custom_emission.cpp | 343 +++++++++++++++++++++++++ 13 files changed, 622 insertions(+), 92 deletions(-) create mode 100644 src/kernels/emission/traits.h create mode 100644 src/kernels/tests/custom_emission.cpp diff --git a/src/archetypes/traits.h b/src/archetypes/traits.h index 522d3f24b..d538e1faf 100644 --- a/src/archetypes/traits.h +++ b/src/archetypes/traits.h @@ -30,6 +30,8 @@ #include "arch/kokkos_aliases.h" +#include "framework/parameters/parameters.h" + namespace arch { namespace traits { @@ -77,6 +79,15 @@ namespace arch { template concept HasInitFlds = requires(const PG& pgen) { pgen.init_flds; }; + template + concept HasEmissionPolicy = requires(const PG& pgen, + simtime_t time, + spidx_t sp, + const D& domain, + const ntt::SimulationParams& params) { + pgen.EmissionPolicy(time, sp, domain, params); + }; + template concept HasInitPrtls = requires(PG& pgen, D& domain) { { pgen.InitPrtls(domain) } -> std::same_as; diff --git a/src/engines/srpic/particle_pusher.h b/src/engines/srpic/particle_pusher.h index 3ba120ef3..067746d59 100644 --- a/src/engines/srpic/particle_pusher.h +++ b/src/engines/srpic/particle_pusher.h @@ -16,7 +16,6 @@ #include "framework/domain/domain.h" #include "framework/domain/grid.h" #include "framework/parameters/parameters.h" - #include "kernels/emission/compton.hpp" #include "kernels/emission/emission.hpp" #include "kernels/emission/synchrotron.hpp" @@ -25,7 +24,7 @@ namespace ntt { namespace srpic { - template + template void CallPusher(Domain& domain, const SimulationParams& params, const kernel::sr::PusherParams& pusher_params, @@ -34,6 +33,7 @@ namespace ntt { const range_t& range, const ndfield_t& EB, const M& metric, + const PG& pgen, const F& external_fields) { if (emission_policy_flag == EmissionType::NONE) { const auto no_emission = kernel::NoEmissionPolicy_t {}; @@ -62,6 +62,7 @@ namespace ntt { HERE); const auto emission_policy = kernel::emission::Synchrotron( emitted_species, + photon_species, pusher_params.mass, pusher_params.charge, pusher_params.radiative_drag_flags, @@ -79,11 +80,14 @@ namespace ntt { metric, external_fields, emission_policy)); - const auto n_inj = emission_policy.number_injected(); + const auto n_inj = emission_policy.numbers_injected(); + raise::ErrorIf(n_inj.size() != 1, + "Synchrotron emission should only inject one species", + HERE); domain.species[photon_species - 1].set_npart( - emitted_species.npart() + n_inj); + emitted_species.npart() + n_inj[0]); domain.species[photon_species - 1].set_counter( - emitted_species.counter() + n_inj); + emitted_species.counter() + n_inj[0]); } else if (emission_policy_flag == EmissionType::COMPTON) { const auto photon_species = params.get( "radiation.emission.compton.photon_species"); @@ -99,6 +103,7 @@ namespace ntt { HERE); const auto emission_policy = kernel::emission::Compton( emitted_species, + photon_species, pusher_params.mass, pusher_params.charge, pusher_params.radiative_drag_flags, @@ -116,11 +121,59 @@ namespace ntt { metric, external_fields, emission_policy)); - const auto n_inj = emission_policy.number_injected(); + const auto n_inj = emission_policy.numbers_injected(); + raise::ErrorIf(n_inj.size() != 1, + "Compton emission should only inject one species", + HERE); domain.species[photon_species - 1].set_npart( - emitted_species.npart() + n_inj); + emitted_species.npart() + n_inj[0]); domain.species[photon_species - 1].set_counter( - emitted_species.counter() + n_inj); + emitted_species.counter() + n_inj[0]); + } else if (emission_policy_flag == EmissionType::CUSTOM) { + if constexpr ( + arch::traits::pgen::HasEmissionPolicy) { + const auto emission_policy = pgen.EmissionPolicy(pusher_params.time, + pusher_params.species_index, + domain, + params); + static_assert( + kernel::traits::emission::IsValid, + "Custom emission policy does not satisfy the required " + "interface"); + Kokkos::parallel_for( + "ParticlePusher", + range, + kernel::sr::Pusher_kernel( + pusher_params, + pusher_arrays, + EB, + metric, + external_fields, + emission_policy)); + const auto emitted_species = emission_policy.emitted_species_indices(); + const auto n_inj = emission_policy.number_injected(); + raise::ErrorIf(emitted_species.size() != n_inj.size(), + "Emission policy emitted_species_indices and " + "numbers_injected must have the same size", + HERE); + for (auto i = 0u; i < emitted_species.size(); ++i) { + const auto sp_idx = emitted_species[i]; + raise::ErrorIf(sp_idx > domain.species.size(), + "Invalid emitted species index from custom " + "emission policy", + HERE); + domain.species[sp_idx - 1].set_npart( + domain.species[sp_idx - 1].npart() + n_inj[i]); + domain.species[sp_idx - 1].set_counter( + domain.species[sp_idx - 1].counter() + n_inj[i]); + } + } else { + raise::Error("Custom emission policy flag is set but problem " + "generator does not define an emission policy", + HERE); + } + } else { + raise::Error("Unrecognized emission policy flag", HERE); } } @@ -184,6 +237,7 @@ namespace ntt { HERE); kernel::sr::PusherParams pusher_params {}; + pusher_params.species_index = species.index(); pusher_params.pusher_flags = species.pusher(); pusher_params.radiative_drag_flags = species.radiative_drag_flags(); pusher_params.mass = species.mass(); @@ -226,26 +280,7 @@ namespace ntt { params.template get("radiation.drag.compton.gamma_rad")); } - kernel::sr::PusherArrays pusher_arrays {}; - pusher_arrays.sp = species.index(); - pusher_arrays.i1 = species.i1; - pusher_arrays.i2 = species.i2; - pusher_arrays.i3 = species.i3; - pusher_arrays.i1_prev = species.i1_prev; - pusher_arrays.i2_prev = species.i2_prev; - pusher_arrays.i3_prev = species.i3_prev; - pusher_arrays.dx1 = species.dx1; - pusher_arrays.dx2 = species.dx2; - pusher_arrays.dx3 = species.dx3; - pusher_arrays.dx1_prev = species.dx1_prev; - pusher_arrays.dx2_prev = species.dx2_prev; - pusher_arrays.dx3_prev = species.dx3_prev; - pusher_arrays.ux1 = species.ux1; - pusher_arrays.ux2 = species.ux2; - pusher_arrays.ux3 = species.ux3; - pusher_arrays.phi = species.phi; - pusher_arrays.weight = species.weight; - pusher_arrays.tag = species.tag; + auto pusher_arrays = species.PusherKernelArrays(); // toggle to indicate whether pgen defines the external force bool has_extfields = false; @@ -261,7 +296,7 @@ namespace ntt { } if (not has_atmosphere and not has_extfields) { - CallPusher( + CallPusher( domain, params, pusher_params, @@ -270,9 +305,10 @@ namespace ntt { species.rangeActiveParticles(), domain.fields.em, domain.mesh.metric, + pgen, kernel::sr::NoField_t {}); } else if (has_atmosphere and not has_extfields) { - CallPusher( + CallPusher( domain, params, pusher_params, @@ -281,10 +317,11 @@ namespace ntt { species.rangeActiveParticles(), domain.fields.em, domain.mesh.metric, + pgen, kernel::sr::NoField_t {}); } else if (not has_atmosphere and has_extfields) { if constexpr (arch::traits::pgen::HasExtFields) { - CallPusher( + CallPusher( domain, params, pusher_params, @@ -293,13 +330,14 @@ namespace ntt { species.rangeActiveParticles(), domain.fields.em, domain.mesh.metric, + pgen, pgen.ext_fields); } else { raise::Error("External fields not implemented", HERE); } } else { // has_atmosphere and has_extforce if constexpr (arch::traits::pgen::HasExtFields) { - CallPusher( + CallPusher( domain, params, pusher_params, @@ -308,6 +346,7 @@ namespace ntt { species.rangeActiveParticles(), domain.fields.em, domain.mesh.metric, + pgen, pgen.ext_fields); } else { raise::Error("External fields not implemented", HERE); diff --git a/src/framework/containers/particles.cpp b/src/framework/containers/particles.cpp index 31b08a40c..7682718a3 100644 --- a/src/framework/containers/particles.cpp +++ b/src/framework/containers/particles.cpp @@ -84,6 +84,31 @@ namespace ntt { } } + template + auto Particles::PusherKernelArrays() -> kernel::sr::PusherArrays { + kernel::sr::PusherArrays pusher_arrays {}; + pusher_arrays.sp = index(); + pusher_arrays.i1 = i1; + pusher_arrays.i2 = i2; + pusher_arrays.i3 = i3; + pusher_arrays.i1_prev = i1_prev; + pusher_arrays.i2_prev = i2_prev; + pusher_arrays.i3_prev = i3_prev; + pusher_arrays.dx1 = dx1; + pusher_arrays.dx2 = dx2; + pusher_arrays.dx3 = dx3; + pusher_arrays.dx1_prev = dx1_prev; + pusher_arrays.dx2_prev = dx2_prev; + pusher_arrays.dx3_prev = dx3_prev; + pusher_arrays.ux1 = ux1; + pusher_arrays.ux2 = ux2; + pusher_arrays.ux3 = ux3; + pusher_arrays.phi = phi; + pusher_arrays.weight = weight; + pusher_arrays.tag = tag; + return pusher_arrays; + } + template struct Particles; template struct Particles; template struct Particles; diff --git a/src/framework/containers/particles.h b/src/framework/containers/particles.h index c7ba5acc5..c3ee1c1b5 100644 --- a/src/framework/containers/particles.h +++ b/src/framework/containers/particles.h @@ -25,6 +25,7 @@ #include "framework/containers/species.h" #include "framework/domain/grid.h" +#include "kernels/particle_pusher_sr.hpp" #include @@ -270,6 +271,12 @@ namespace ntt { */ void SyncHostDevice(); + /** + * @brief Get the arrays required for the particle pusher kernel + * @returns The struct of arrays for the particle pusher kernel + */ + auto PusherKernelArrays() -> kernel::sr::PusherArrays; + #if defined(MPI_ENABLED) /** * @brief Communicate particles across neighboring meshblocks diff --git a/src/framework/parameters/particles.cpp b/src/framework/parameters/particles.cpp index 0b15e20b4..afc795a9a 100644 --- a/src/framework/parameters/particles.cpp +++ b/src/framework/parameters/particles.cpp @@ -86,6 +86,8 @@ namespace ntt { return EmissionType::SYNCHROTRON; } else if (fmt::toLower(emission_policy_str) == "compton") { return EmissionType::COMPTON; + } else if (fmt::toLower(emission_policy_str) == "custom") { + return EmissionType::CUSTOM; } else { raise::Error(fmt::format("Invalid emission_policy value: %s", emission_policy_str.c_str()), diff --git a/src/global/arch/traits.h b/src/global/arch/traits.h index be3b377c5..a2d40adac 100644 --- a/src/global/arch/traits.h +++ b/src/global/arch/traits.h @@ -169,45 +169,6 @@ namespace traits { } // namespace external - namespace emission { - - template - concept HasPayloadType = requires { typename E::Payload; }; - - template - concept HasShouldEmit = HasPayloadType and requires( - const E& e, - const coord_t& x_Cd, - const coord_t& x_Ph, - const vec_t& u_Ph, - const vec_t& ep, - const vec_t& bp, - vec_t& delta_u_Ph, - typename E::Payload& payload) { - { - e.shouldEmit(x_Cd, x_Ph, u_Ph, ep, bp, delta_u_Ph, payload) - } -> std::convertible_to>; - }; - - template - concept HasEmit = HasPayloadType and - requires(const E& e, - const tuple_t& xi_Cd, - const tuple_t& dxi_Cd, - const vec_t& direction, - real_t weight, - real_t phi, - const typename E::Payload& payload) { - { - e.emit(xi_Cd, dxi_Cd, direction, weight, phi, payload) - } -> std::same_as; - }; - - template - concept IsValidEmissionPolicy = HasShouldEmit and HasEmit; - - } // namespace emission - template