diff --git a/.clang-format b/.clang-format index ed150af9b..b3bb8d132 100644 --- a/.clang-format +++ b/.clang-format @@ -106,6 +106,8 @@ IncludeCategories: Priority: 4 - Regex: '^"engines\/.*\.h"' Priority: 4 + - Regex: '^"checkpoint\/.*\.h"' + Priority: 4 - Regex: '^"output\/.*\.h"' Priority: 4 - Regex: '^"archetypes\/.*\.h"' diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 04bc34050..a5c546328 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -1,26 +1,44 @@ name: Unit tests on: - pull_request: - branches: - - '**rc' - - 'master' + push: jobs: + check-commit: + runs-on: ubuntu-latest + outputs: + run_tests: ${{ steps.check_message.outputs.run_tests }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Check commit message + id: check_message + run: | + if git log -1 --pretty=%B | grep -q "RUNTEST"; then + echo "run_tests=true" >> "$GITHUB_OUTPUT" + else + echo "run_tests=false" >> "$GITHUB_OUTPUT" + fi tests: + needs: check-commit + if: needs.check-commit.outputs.run_tests == 'true' strategy: fail-fast: false matrix: - device: [amd-gpu, nvidia-gpu] + device: [cpu, amd-gpu, nvidia-gpu] precision: [double, single] - exclude: + mpi: [serial, parallel] + exclude: # my AMD GPU doesn't support fp64 atomics : ( - device: amd-gpu precision: double - # my AMD GPUs doesn't support fp64 atomics : ( + - device: amd-gpu + mpi: parallel + - device: nvidia-gpu + mpi: parallel runs-on: [self-hosted, "${{ matrix.device }}"] steps: - name: Checkout - uses: actions/checkout@v3.3.0 + uses: actions/checkout@v4 - name: Configure run: | if [ "${{ matrix.device }}" = "nvidia-gpu" ]; then @@ -34,6 +52,8 @@ jobs: fi elif [ "${{ matrix.device }}" = "amd-gpu" ]; then FLAGS="-D Kokkos_ENABLE_HIP=ON -D Kokkos_ARCH_AMD_GFX1100=ON" + elif [ "${{ matrix.mpi }}" = "parallel" ]; then + FLAGS="-D mpi=ON" fi cmake -B build -D TESTS=ON -D output=ON -D precision=${{ matrix.precision }} $FLAGS - name: Compile diff 
--git a/.gitignore b/.gitignore index 20bfe33a3..9a167b9d5 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,10 @@ venv/ # CMake testing files Testing/ +tags +.clangd .schema.json *_old/ action-token +*.vim +ignore-* diff --git a/.gitmodules b/.gitmodules index bb2c39c87..835fbe5b8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,13 +1,12 @@ -[submodule "extern/toml11"] - path = extern/toml11 - url = https://github.com/ToruNiina/toml11.git [submodule "extern/plog"] path = extern/plog url = https://github.com/SergiusTheBest/plog.git [submodule "extern/adios2"] path = extern/adios2 url = https://github.com/ornladios/ADIOS2.git - branch = master [submodule "extern/Kokkos"] path = extern/Kokkos url = https://github.com/kokkos/kokkos.git +[submodule "extern/entity-pgens"] + path = extern/entity-pgens + url = https://github.com/entity-toolkit/entity-pgens.git diff --git a/.taplo.toml b/.taplo.toml new file mode 100644 index 000000000..423a47594 --- /dev/null +++ b/.taplo.toml @@ -0,0 +1,6 @@ +[formatting] + align_entries = true + indent_tables = true + indent_entries = true + trailing_newline = true + align_comments = true diff --git a/.vscode/.taplo.toml b/.vscode/.taplo.toml index 0bfa6bec9..c24ab0926 100644 --- a/.vscode/.taplo.toml +++ b/.vscode/.taplo.toml @@ -1,4 +1,4 @@ -include = ["input.example.toml", "setups/**/*.toml"] +include = ["input.example.toml", "pgens/**/*.toml"] exclude = [".taplo.toml"] [formatting] diff --git a/CMakeLists.txt b/CMakeLists.txt index 2152ebcf2..06a7690d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,3 +1,5 @@ +# cmake-lint: disable=C0103,C0111,E1120,R0913,R0915 + cmake_minimum_required(VERSION 3.16) cmake_policy(SET CMP0110 NEW) @@ -5,15 +7,16 @@ set(PROJECT_NAME entity) project( ${PROJECT_NAME} - VERSION 1.1.1 + VERSION 1.2.0 LANGUAGES CXX C) add_compile_options("-D ENTITY_VERSION=\"${PROJECT_VERSION}\"") -execute_process(COMMAND - bash -c "git diff --quiet src/ && echo $(git rev-parse HEAD) || echo $(git rev-parse 
HEAD)-mod" +set(hash_cmd "git diff --quiet src/ && echo $(git rev-parse HEAD) ") +string(APPEND hash_cmd "|| echo $(git rev-parse HEAD)-mod") +execute_process( + COMMAND bash -c ${hash_cmd} WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" OUTPUT_VARIABLE GIT_HASH - ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE -) + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) message(STATUS "Git hash: ${GIT_HASH}") add_compile_options("-D ENTITY_GIT_HASH=\"${GIT_HASH}\"") @@ -25,100 +28,115 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/defaults.cmake) # defaults set(DEBUG - ${default_debug} - CACHE BOOL "Debug mode") + ${default_debug} + CACHE BOOL "Debug mode") set(precision - ${default_precision} - CACHE STRING "Precision") + ${default_precision} + CACHE STRING "Precision") set(pgen - ${default_pgen} - CACHE STRING "Problem generator") + ${default_pgen} + CACHE STRING "Problem generator") set(gui - ${default_gui} - CACHE BOOL "Use GUI [nttiny]") + ${default_gui} + CACHE BOOL "Use GUI [nttiny]") set(output - ${default_output} - CACHE BOOL "Enable output") + ${default_output} + CACHE BOOL "Enable output") set(mpi - ${default_mpi} - CACHE BOOL "Use MPI") + ${default_mpi} + CACHE BOOL "Use MPI") + +set(gpu_aware_mpi + ${default_gpu_aware_mpi} + CACHE BOOL "Enable GPU-aware MPI") # -------------------------- Compilation settings -------------------------- # set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) if(${DEBUG} STREQUAL "OFF") set(CMAKE_BUILD_TYPE - Release - CACHE STRING "CMake build type") + Release + CACHE STRING "CMake build type") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNDEBUG") else() set(CMAKE_BUILD_TYPE - Debug - CACHE STRING "CMake build type") + Debug + CACHE STRING "CMake build type") set(CMAKE_CXX_FLAGS - "${CMAKE_CXX_FLAGS} -DDEBUG -Wall -Wextra -Wno-unknown-pragmas") + "${CMAKE_CXX_FLAGS} -DDEBUG -Wall -Wextra -Wno-unknown-pragmas") endif() -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-local-typedefs") - # options 
set(precisions - "single" "double" - CACHE STRING "Precisions") + "single" "double" + CACHE STRING "Precisions") include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/config.cmake) # ------------------------- Third-Party Tests ------------------------------ # set(BUILD_TESTING - OFF - CACHE BOOL "Build tests") + OFF + CACHE BOOL "Build tests") # ------------------------ Third-party dependencies ------------------------ # -include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/kokkosConfig.cmake) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/dependencies.cmake) -find_or_fetch_dependency(Kokkos FALSE) -find_or_fetch_dependency(plog TRUE) -find_or_fetch_dependency(toml11 TRUE) +find_or_fetch_dependency(Kokkos FALSE QUIET) +find_or_fetch_dependency(plog TRUE QUIET) set(DEPENDENCIES Kokkos::kokkos) include_directories(${plog_SRC}/include) -include_directories(${toml11_SRC}) # -------------------------------- Main code ------------------------------- # set_precision(${precision}) +if("${Kokkos_DEVICES}" MATCHES "CUDA") + add_compile_options("-D CUDA_ENABLED") + set(DEVICE_ENABLED ON) + add_compile_options("-D DEVICE_ENABLED") +elseif("${Kokkos_DEVICES}" MATCHES "HIP") + add_compile_options("-D HIP_ENABLED") + set(DEVICE_ENABLED ON) + add_compile_options("-D DEVICE_ENABLED") +elseif("${Kokkos_DEVICES}" MATCHES "SYCL") + add_compile_options("-D SYCL_ENABLED") + set(DEVICE_ENABLED ON) + add_compile_options("-D DEVICE_ENABLED") +else() + set(DEVICE_ENABLED OFF) +endif() + +if(("${Kokkos_DEVICES}" MATCHES "CUDA") + OR ("${Kokkos_DEVICES}" MATCHES "HIP") + OR ("${Kokkos_DEVICES}" MATCHES "SYCL")) + set(DEVICE_ENABLED ON) +else() + set(DEVICE_ENABLED OFF) +endif() # MPI if(${mpi}) - include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/MPIConfig.cmake) + find_or_fetch_dependency(MPI FALSE REQUIRED) + include_directories(${MPI_CXX_INCLUDE_PATH}) + add_compile_options("-D MPI_ENABLED") set(DEPENDENCIES ${DEPENDENCIES} MPI::MPI_CXX) + if(${DEVICE_ENABLED}) + if(${gpu_aware_mpi}) + add_compile_options("-D GPU_AWARE_MPI") 
+ endif() + else() + set(gpu_aware_mpi + OFF + CACHE BOOL "Use explicit copy when using MPI + GPU") + endif() endif() # Output if(${output}) - include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/adios2Config.cmake) - find_or_fetch_dependency(adios2 FALSE) - if (NOT DEFINED ENV{HDF5_ROOT}) - set(USE_CUSTOM_HDF5 OFF) - if (DEFINED ENV{CONDA_PREFIX}) - execute_process(COMMAND bash -c "conda list | grep \"hdf5\" -q" - RESULT_VARIABLE HDF5_INSTALLED) - if (HDF5_INSTALLED EQUAL 0) - set(HDF5_ROOT $ENV{CONDA_PREFIX}) - else() - set(USE_CUSTOM_HDF5 ON) - endif() - else() - set(USE_CUSTOM_HDF5 ON) - endif() - if (USE_CUSTOM_HDF5) - message(FATAL_ERROR "HDF5_ROOT is not set. Please set it to the root of the HDF5 installation") - endif() - endif() - find_package(HDF5 REQUIRED) - + find_or_fetch_dependency(adios2 FALSE QUIET) + add_compile_options("-D OUTPUT_ENABLED") if(${mpi}) set(DEPENDENCIES ${DEPENDENCIES} adios2::cxx11_mpi) else() @@ -131,14 +149,18 @@ link_libraries(${DEPENDENCIES}) if(TESTS) # ---------------------------------- Tests --------------------------------- # include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/tests.cmake) +elseif(BENCHMARK) + # ------------------------------ Benchmark --------------------------------- # + include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/benchmark.cmake) else() # ----------------------------------- GUI ---------------------------------- # if(${gui}) - find_or_fetch_dependency(nttiny FALSE) + find_or_fetch_dependency(nttiny FALSE QUIET) endif() # ------------------------------- Main source ------------------------------ # set_problem_generator(${pgen}) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src src) - include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/report.cmake) endif() + +include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/report.cmake) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..cfd678063 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,29 @@ +This code of conduct outlines shared principles and expectations for 
all participants in our open-source project. It's here to keep things open, respectful, and simple for everyone. Anyone involved in using or developing the code is expected to follow these guidelines, starting from release version `1.2.0`. + +# Open-source Principles +- This project is fully open-source and does not belong to any individual or institution. +- It is developed and maintained by our team of incredibly talented people, whose goal is to make it accessible to everyone with no expectation of credit. +- Anyone is free to use, copy, modify, or distribute the code under the project's open license. +- If you contribute something to the repository, it becomes part of the project and thus will also be regarded as open-source. + +# Contributions and Credit +- The only attribution we strongly encourage is a citation of either the code repository, or the corresponding method papers (coming soon). +- All contributions are made voluntarily, and there is no expectation of recognition of isolated individuals. +- There's no built-in expectation of credit or authorship for modules or changes pushed to the repository. Anyone is free to use any part of the code with no attribution to the author of any specific module or algorithm. +- The code is there for everyone to use, and its only goal is to enable the community to produce exciting science! + +# Roles in the Project +- There are three informal roles: + - users: anyone using the code; + - contributors: users who have write access to the repository; + - maintainers: contributors who also take care of organizational, administrative, and tech-support duties. +- Role changes (user -> contributor or vice versa) are easy and open, and can be done by asking the maintainers. +- Contributors are also expected to follow any shared guidelines -- whether discussed informally or written down -- around code-development things, such as committing, merging, and creating pull requests.
+- Maintainers have the most strict obligations of keeping the repository clean, managing pull-requests, issuing releases, documenting all the features and additions, writing unit tests, helping other users and contributors with any problems they encounter, etc. +- To reiterate, regardless of the extent of involvement and help, there is absolutely no expectation of recognition or attribution on maintainers' end! + +# Community and Participation +- Everyone is welcome in the community. +- Joining meetings on Zoom, using the Slack workspace, giving feedback, taking part in decision making and planning, or following development doesn't require any special status -- it's open to all. + +Finally, while we cannot enforce it, we strongly encourage any projects that build on this code to be open source too. If you build upon this project, we welcome transparency and openness in spirit. You may contribute to the code as much or as little as you like; all effort is appreciated, none is required. diff --git a/LICENSE b/LICENSE index 554820782..639d2b41f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,19 +1,8 @@ BSD 3-Clause License -Copyright (c) 2024, Princeton University (PU), and -Princeton Plasma Physics Lab (PPPL). +Copyright (c) 2021-present, Entity development team. All rights reserved. -Core developers (alphabetical order): -* Alexander Chernoglazov -* Benjamin Crinquand -* Alisa Galishnikova -* Hayk Hakobyan -* Jens Mahlmann -* Sasha Philippov -* Arno Vanthieghem -* Muni Zhou - Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/README.md b/README.md index d6f4597f5..d1afa2487 100644 --- a/README.md +++ b/README.md @@ -2,34 +2,42 @@ tl;dr: One particle-in-cell code to rule them all. -`Entity` is an open-source coordinate-agnostic particle-in-cell (PIC) code written in C++17 specifically targeted to study plasma physics in relativistic astrophysical systems. 
The main algorithms of the code are written in covariant form, allowing to easily implement arbitrary grid geometries. The code is highly modular, and is written in the architecture-agnostic way using the [`Kokkos`](https://kokkos.org/kokkos-core-wiki/) performance portability library, allowing the code to efficiently use device parallelization on CPU and GPU architectures of different types. The multi-node parallelization is implemented using the `MPI` library, and the data output is done via the [`ADIOS2`](https://github.com/ornladios/ADIOS2) library which supports multiple output formats, including `HDF5` and `BP5`. +`Entity` is a community-driven open-source coordinate-agnostic general-relativistic (GR) particle-in-cell (PIC) code written in C++17 specifically targeted to study plasma physics in relativistic astrophysical systems. The main algorithms of the code are written in covariant form, allowing to easily implement arbitrary grid geometries. The code is highly modular, and is written in the architecture-agnostic way using the [`Kokkos`](https://kokkos.org/kokkos-core-wiki/) performance portability library, allowing the code to efficiently use device parallelization on CPU and GPU architectures of different types. The multi-node parallelization is implemented using the `MPI` library, and the data output is done via the [`ADIOS2`](https://github.com/ornladios/ADIOS2) library which supports multiple output formats, including `HDF5` and `BP5`. `Entity` is part of the `Entity toolkit` framework, which also includes a Python library for fast and efficient data analysis and visualization of the simulation data: [`nt2py`](https://pypi.org/project/nt2py/). -Our [detailed documentation](https://entity-toolkit.github.io/) includes everything you need to know to get started with using and/or contributing to the `Entity toolkit`. If you find bugs or issues, please feel free to add a GitHub issue or submit a pull request. 
Users with significant contributions to the code will be added to the list of developers, and assigned an emoji of their choice (important). +Our [detailed documentation](https://entity-toolkit.github.io/) includes everything you need to know to get started with using and/or contributing to the `Entity toolkit`. If you find bugs or issues, please feel free to add a GitHub issue or submit a pull request. Users with significant contributions to the code will be added to the list of developers, and assigned an emoji of their choice (important!). [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) -## Core developers (alphabetical) +## Join the community -👀 __Yangyang Cai__ {[@StaticObserver](https://github.com/StaticObserver): GRPIC} +Everyone is welcome to join our small yet steadily growing community of code users and developers; regardless of how much you are planning to contribute -- we always welcome fresh ideas and feedback. We hold weekly Zoom meetings on Mondays at 12pm NY time, and have a dedicated Slack channel where you can be easily added by [emailing](mailto:haykh.astro@gmail.com) [one of the maintainers](mailto:genegorbs@gmail.com). If you prefer to just join our Zoom meetings without the Slack involvement -- that's totally fine, just email, and we can send you the Zoom link. -💁‍♂️ __Alexander Chernoglazov__ {[@SChernoglazov](https://github.com/SChernoglazov): PIC} +Another way of contacting us is via GitHub issues and/or pull requests. Make sure to check out our [F.A.Q.](https://entity-toolkit.github.io/wiki/content/1-getting-started/9-faq/), as it might help you answer your question.
-🍡 __Benjamin Crinquand__ {[@bcrinquand](https://github.com/bcrinquand): GRPIC, cubed-sphere} +> Keep in mind, you are free to use the code in any capacity, and there is absolutely no requirement on our end of including any of the developers in your project/proposal (as highlighted in our Code of Conduct). When contributing, also keep in mind that the code you upload to the repository automatically becomes public and open-source, and the same standards will be applied to it as to the rest of the code. -πŸ§‹ __Alisa Galishnikova__ {[@alisagk](https://github.com/alisagk): GRPIC} +## Contributors (alphabetical) -β˜• __Hayk Hakobyan__ {[@haykh](https://github.com/haykh): framework, PIC, GRPIC, cubed-sphere} - -πŸ₯” __Jens Mahlmann__ {[@jmahlmann](https://github.com/jmahlmann): framework, MPI, cubed-sphere} - -🐬 __Sasha Philippov__ {[@sashaph](https://github.com/sashaph): all-around} - -🀷 __Arno Vanthieghem__ {[@vanthieg](https://github.com/vanthieg): framework, PIC} - -😺 __Muni Zhou__ {[@munizhou](https://github.com/munizhou): PIC} +* :guitar: Ludwig BΓΆss {[@LudwigBoess](https://github.com/LudwigBoess)} +* :eyes: Yangyang Cai {[@StaticObserver](https://github.com/StaticObserver)} +* :tipping_hand_person: Alexander Chernoglazov {[@SChernoglazov](https://github.com/SChernoglazov)} +* :tea: Benjamin Crinquand {[@bcrinquand](https://github.com/bcrinquand)} +* :bubble_tea: Alisa Galishnikova {[@alisagk](https://github.com/alisagk)} +* :steam_locomotive: Evgeny Gorbunov {[@Alcauchy](https://github.com/Alcauchy)} +* :coffee: Hayk Hakobyan {[@haykh](https://github.com/haykh)} +* :potato: Jens Mahlmann {[@jmahlmann](https://github.com/jmahlmann)} +* :dolphin: Sasha Philippov {[@sashaph](https://github.com/sashaph)} +* :radio: Siddhant Solanki {[@sidruns30](https://github.com/sidruns30)} +* :shrug: Arno Vanthieghem {[@vanthieg](https://github.com/vanthieg)} +* :cat: Muni Zhou {[@munizhou](https://github.com/munizhou)} ## Branch policy -Master branch contains the latest stable 
version of the code which has already been released. Development on the core is done on branches starting with `dev/`, while fixing bugs is done in branches that start with `bug/`. User-specific modifications (i.e., new problem generators plus perhaps minor corrections in the core) are done on branches starting with `pgen/`. Before merging to the master branch, all the branches must first be merged to the latest release-candidate branch, which ends with `rc`, via a pull request. After which, when all the release goals are met, the `rc` branch is merged to the master and released as a new stable version. Stale branches will be archived with a tag starting with `archive/` (can still be accessed via the "Tags" tab) and removed. +- The `master` branch contains the latest stable version of the code which has already been released. +- Development on the core is done on branches starting with `dev/`. +- Bug fixes are pushed to branches starting with `bug/`. +- Every `bug/` or `dev/` branch must have an open pull request describing its purpose in detail. +- Before merging to the master branch, all the branches must first be merged to the latest release-candidate branch, which ends with `rc`, via a pull request. This can be either a major release, `1.X.0rc`, or a patch release, `1.X.Yrc`. +- Stale branches will be archived with a tag starting with `archive/` (can still be accessed via the "Tags" tab) and removed.
diff --git a/TASKLIST.md b/TASKLIST.md deleted file mode 100644 index 82f44d0ea..000000000 --- a/TASKLIST.md +++ /dev/null @@ -1,35 +0,0 @@ -# v0.8 - -- [x] thick layer boundary for the monopole -- [x] test with filters -- [x] add diagnostics for nans in fields and particles -- [x] add gravitationally bound atmosphere -- [x] rewrite UniformInjector with global random pool -- [x] add particle deletion routine -- [x] make more user-friendly and understandable boundary conditions -- [x] refine output -- [x] add different moments (momX, momY, momZ, meanGamma) -- [x] add charge -- [x] add per species densities - -# v0.9 - -- [x] add current deposit/filtering for GR -- [x] add moments for GR -- [x] add Maxwellian for GR - -# v1.0.0 - -- [x] particle output -- [x] BUG in MPI particles/currents - -# v1.1.0 - -- [ ] custom boundary conditions for particles and fields -- [ ] transfer GR from v0.9 - -### Performance improvements to try - -- [ ] removing temporary variables in interpolation -- [ ] passing by value vs const ref in metric -- [ ] return physical coords one-by-one instead of by passing full vector diff --git a/benchmark/benchmark.cpp b/benchmark/benchmark.cpp new file mode 100644 index 000000000..98306c92b --- /dev/null +++ b/benchmark/benchmark.cpp @@ -0,0 +1,17 @@ +#include "global.h" + +#include +#include + +auto main(int argc, char* argv[]) -> int { + ntt::GlobalInitialize(argc, argv); + try { + // ... 
+ } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + GlobalFinalize(); + return 1; + } + GlobalFinalize(); + return 0; +} diff --git a/cmake/MPIConfig.cmake b/cmake/MPIConfig.cmake deleted file mode 100644 index b426641ec..000000000 --- a/cmake/MPIConfig.cmake +++ /dev/null @@ -1,3 +0,0 @@ -find_package(MPI REQUIRED) -include_directories(${MPI_CXX_INCLUDE_PATH}) -add_compile_options("-D MPI_ENABLED") \ No newline at end of file diff --git a/cmake/adios2Config.cmake b/cmake/adios2Config.cmake index 16c0c30c7..a4ce46179 100644 --- a/cmake/adios2Config.cmake +++ b/cmake/adios2Config.cmake @@ -1,15 +1,29 @@ # ----------------------------- Adios2 settings ---------------------------- # -set(ADIOS2_BUILD_EXAMPLES OFF CACHE BOOL "Build ADIOS2 examples") +set(ADIOS2_BUILD_EXAMPLES + OFF + CACHE BOOL "Build ADIOS2 examples") # Language support -set(ADIOS2_USE_Python OFF CACHE BOOL "Use Python for ADIOS2") -set(ADIOS2_USE_Fortran OFF CACHE BOOL "Use Fortran for ADIOS2") +set(ADIOS2_USE_Python + OFF + CACHE BOOL "Use Python for ADIOS2") +set(ADIOS2_USE_Fortran + OFF + CACHE BOOL "Use Fortran for ADIOS2") # Format/compression support -set(ADIOS2_USE_ZeroMQ OFF CACHE BOOL "Use ZeroMQ for ADIOS2") +set(ADIOS2_USE_HDF5 + ON + CACHE BOOL "Use HDF5 for ADIOS2") -set(ADIOS2_USE_MPI ${mpi} CACHE BOOL "Use MPI for ADIOS2") +set(ADIOS2_USE_MPI + ${mpi} + CACHE BOOL "Use MPI for ADIOS2") -set(ADIOS2_USE_CUDA OFF CACHE BOOL "Use CUDA for ADIOS2") +set(ADIOS2_USE_ZeroMQ + OFF + CACHE BOOL "Use ZeroMQ for ADIOS2") -add_compile_options("-D OUTPUT_ENABLED") +set(ADIOS2_USE_CUDA + OFF + CACHE BOOL "Use CUDA for ADIOS2") diff --git a/cmake/benchmark.cmake b/cmake/benchmark.cmake new file mode 100644 index 000000000..39b075716 --- /dev/null +++ b/cmake/benchmark.cmake @@ -0,0 +1,26 @@ +# cmake-lint: disable=C0103 + +set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src) + +add_subdirectory(${SRC_DIR}/global ${CMAKE_CURRENT_BINARY_DIR}/global) 
+add_subdirectory(${SRC_DIR}/metrics ${CMAKE_CURRENT_BINARY_DIR}/metrics) +add_subdirectory(${SRC_DIR}/kernels ${CMAKE_CURRENT_BINARY_DIR}/kernels) +add_subdirectory(${SRC_DIR}/archetypes ${CMAKE_CURRENT_BINARY_DIR}/archetypes) +add_subdirectory(${SRC_DIR}/framework ${CMAKE_CURRENT_BINARY_DIR}/framework) + +if(${output}) + add_subdirectory(${SRC_DIR}/output ${CMAKE_CURRENT_BINARY_DIR}/output) + add_subdirectory(${SRC_DIR}/checkpoint ${CMAKE_CURRENT_BINARY_DIR}/checkpoint) +endif() + +set(exec benchmark.xc) +set(src ${CMAKE_CURRENT_SOURCE_DIR}/benchmark/benchmark.cpp) + +add_executable(${exec} ${src}) + +set(libs ntt_global ntt_metrics ntt_kernels ntt_archetypes ntt_framework) +if(${output}) + list(APPEND libs ntt_output ntt_checkpoint) +endif() +add_dependencies(${exec} ${libs}) +target_link_libraries(${exec} PRIVATE ${libs} stdc++fs) diff --git a/cmake/config.cmake b/cmake/config.cmake index fa18a87eb..97ed658e3 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -1,9 +1,14 @@ +# cmake-lint: disable=C0103 + # -------------------------------- Precision ------------------------------- # function(set_precision precision_name) list(FIND precisions ${precision_name} PRECISION_FOUND) if(${PRECISION_FOUND} EQUAL -1) - message(FATAL_ERROR "Invalid precision: ${precision_name}\nValid options are: ${precisions}") + message( + FATAL_ERROR + "Invalid precision: ${precision_name}\nValid options are: ${precisions}" + ) endif() if(${precision_name} STREQUAL "single") @@ -13,19 +18,64 @@ endfunction() # ---------------------------- Problem generator --------------------------- # function(set_problem_generator pgen_name) - file(GLOB_RECURSE PGENS "${CMAKE_CURRENT_SOURCE_DIR}/setups/**/pgen.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/setups/pgen.hpp") + if(pgen_name STREQUAL ".") + message(FATAL_ERROR "Problem generator not specified") + endif() + + file(GLOB_RECURSE PGENS "${CMAKE_CURRENT_SOURCE_DIR}/pgens/**/pgen.hpp") + foreach(PGEN ${PGENS}) get_filename_component(PGEN_NAME 
${PGEN} DIRECTORY) - string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/setups/" "" PGEN_NAME ${PGEN_NAME}) - string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/setups" "" PGEN_NAME ${PGEN_NAME}) + string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/pgens/" "" PGEN_NAME + ${PGEN_NAME}) list(APPEND PGEN_NAMES ${PGEN_NAME}) endforeach() + list(FIND PGEN_NAMES ${pgen_name} PGEN_FOUND) - if(NOT ${pgen_name} STREQUAL "." AND ${PGEN_FOUND} EQUAL -1) - message(FATAL_ERROR "Invalid problem generator: ${pgen_name}\nValid options are: ${PGEN_NAMES}") + + file(GLOB_RECURSE EXTRA_PGENS + "${CMAKE_CURRENT_SOURCE_DIR}/extern/entity-pgens/**/pgen.hpp") + foreach(EXTRA_PGEN ${EXTRA_PGENS}) + get_filename_component(EXTRA_PGEN_NAME ${EXTRA_PGEN} DIRECTORY) + string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/extern/entity-pgens/" "" + EXTRA_PGEN_NAME ${EXTRA_PGEN_NAME}) + list(APPEND PGEN_NAMES "pgens/${EXTRA_PGEN_NAME}") + endforeach() + + if(${PGEN_FOUND} EQUAL -1) + if(${pgen_name} MATCHES "^pgens/") + get_filename_component(pgen_name ${pgen_name} NAME) + set(pgen_path + "${CMAKE_CURRENT_SOURCE_DIR}/extern/entity-pgens/${pgen_name}") + set(pgen_name "pgens/${pgen_name}") + else() + set(pgen_path ${pgen_name}) + get_filename_component(pgen_path ${pgen_path} ABSOLUTE) + string(REGEX REPLACE ".*/" "" pgen_name ${pgen_name}) + list(APPEND PGEN_NAMES ${pgen_name}) + endif() + else() + set(pgen_path ${CMAKE_CURRENT_SOURCE_DIR}/pgens/${pgen_name}) + endif() + + file(GLOB_RECURSE PGEN_FILES "${pgen_path}/pgen.hpp") + if(NOT PGEN_FILES) + message(FATAL_ERROR "pgen.hpp file not found in ${pgen_path}") endif() - set(PGEN ${pgen_name} PARENT_SCOPE) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/setups/${pgen_name}) - set(PGEN_FOUND TRUE PARENT_SCOPE) - set(problem_generators ${PGEN_NAMES} PARENT_SCOPE) + + add_library(ntt_pgen INTERFACE) + target_link_libraries(ntt_pgen INTERFACE ntt_global ntt_framework + ntt_archetypes ntt_kernels) + + target_include_directories(ntt_pgen INTERFACE ${pgen_path}) + + set(PGEN + 
${pgen_name} + PARENT_SCOPE) + set(PGEN_FOUND + TRUE + PARENT_SCOPE) + set(problem_generators + ${PGEN_NAMES} + PARENT_SCOPE) endfunction() diff --git a/cmake/defaults.cmake b/cmake/defaults.cmake index f70120e0d..2bfa9a61c 100644 --- a/cmake/defaults.cmake +++ b/cmake/defaults.cmake @@ -1,62 +1,88 @@ +# cmake-lint: disable=C0103 + # ----------------------------- Defaults ---------------------------------- # if(DEFINED ENV{Entity_ENABLE_DEBUG}) - set(default_debug $ENV{Entity_ENABLE_DEBUG} CACHE INTERNAL "Default flag for debug mode") + set(default_debug + $ENV{Entity_ENABLE_DEBUG} + CACHE INTERNAL "Default flag for debug mode") else() - set(default_debug OFF CACHE INTERNAL "Default flag for debug mode") + set(default_debug + OFF + CACHE INTERNAL "Default flag for debug mode") endif() set_property(CACHE default_debug PROPERTY TYPE BOOL) -set(default_engine "pic" CACHE INTERNAL "Default engine") -set(default_precision "single" CACHE INTERNAL "Default precision") -set(default_pgen "." CACHE INTERNAL "Default problem generator") -set(default_sr_metric "minkowski" CACHE INTERNAL "Default SR metric") -set(default_gr_metric "kerr_schild" CACHE INTERNAL "Default GR metric") +set(default_engine + "pic" + CACHE INTERNAL "Default engine") +set(default_precision + "single" + CACHE INTERNAL "Default precision") +set(default_pgen + "." 
+ CACHE INTERNAL "Default problem generator") +set(default_sr_metric + "minkowski" + CACHE INTERNAL "Default SR metric") +set(default_gr_metric + "kerr_schild" + CACHE INTERNAL "Default GR metric") if(DEFINED ENV{Entity_ENABLE_OUTPUT}) - set(default_output $ENV{Entity_ENABLE_OUTPUT} CACHE INTERNAL "Default flag for output") + set(default_output + $ENV{Entity_ENABLE_OUTPUT} + CACHE INTERNAL "Default flag for output") else() - set(default_output OFF CACHE INTERNAL "Default flag for output") + set(default_output + ON + CACHE INTERNAL "Default flag for output") endif() set_property(CACHE default_output PROPERTY TYPE BOOL) if(DEFINED ENV{Entity_ENABLE_GUI}) - set(default_gui $ENV{Entity_ENABLE_GUI} CACHE INTERNAL "Default flag for GUI") + set(default_gui + $ENV{Entity_ENABLE_GUI} + CACHE INTERNAL "Default flag for GUI") else() - set(default_gui OFF CACHE INTERNAL "Default flag for GUI") + set(default_gui + OFF + CACHE INTERNAL "Default flag for GUI") endif() set_property(CACHE default_gui PROPERTY TYPE BOOL) -if(DEFINED ENV{Kokkos_ENABLE_CUDA}) - set(default_KOKKOS_ENABLE_CUDA $ENV{Kokkos_ENABLE_CUDA} CACHE INTERNAL "Default flag for CUDA") -else() - set(default_KOKKOS_ENABLE_CUDA OFF CACHE INTERNAL "Default flag for CUDA") -endif() - -set_property(CACHE default_KOKKOS_ENABLE_CUDA PROPERTY TYPE BOOL) - -if(DEFINED ENV{Kokkos_ENABLE_HIP}) - set(default_KOKKOS_ENABLE_HIP $ENV{Kokkos_ENABLE_HIP} CACHE INTERNAL "Default flag for HIP") +if(DEFINED ENV{Entity_ENABLE_MPI}) + set(default_mpi + $ENV{Entity_ENABLE_MPI} + CACHE INTERNAL "Default flag for MPI") else() - set(default_KOKKOS_ENABLE_HIP OFF CACHE INTERNAL "Default flag for HIP") + set(default_mpi + OFF + CACHE INTERNAL "Default flag for MPI") endif() -set_property(CACHE default_KOKKOS_ENABLE_HIP PROPERTY TYPE BOOL) - -if(DEFINED ENV{Kokkos_ENABLE_OPENMP}) - set(default_KOKKOS_ENABLE_OPENMP $ENV{Kokkos_ENABLE_OPENMP} CACHE INTERNAL "Default flag for OpenMP") +if(DEFINED ENV{Entity_MPI_DEVICE_COPY}) + 
set(default_mpi_device_copy + $ENV{Entity_MPI_DEVICE_COPY} + CACHE INTERNAL "Default flag for copying from device to host for MPI") else() - set(default_KOKKOS_ENABLE_OPENMP OFF CACHE INTERNAL "Default flag for OpenMP") + set(default_mpi_device_copy + OFF + CACHE INTERNAL "Default flag for copying from device to host for MPI") endif() -set_property(CACHE default_KOKKOS_ENABLE_OPENMP PROPERTY TYPE BOOL) +set_property(CACHE default_mpi PROPERTY TYPE BOOL) -if(DEFINED ENV{Entity_ENABLE_MPI}) - set(default_mpi $ENV{Entity_ENABLE_MPI} CACHE INTERNAL "Default flag for MPI") +if(DEFINED ENV{Entity_ENABLE_GPU_AWARE_MPI}) + set(default_gpu_aware_mpi + $ENV{Entity_ENABLE_GPU_AWARE_MPI} + CACHE INTERNAL "Default flag for GPU-aware MPI") else() - set(default_mpi OFF CACHE INTERNAL "Default flag for MPI") + set(default_gpu_aware_mpi + ON + CACHE INTERNAL "Default flag for GPU-aware MPI") endif() -set_property(CACHE default_mpi PROPERTY TYPE BOOL) +set_property(CACHE default_gpu_aware_mpi PROPERTY TYPE BOOL) diff --git a/cmake/dependencies.cmake b/cmake/dependencies.cmake index b143befdf..1780bf97e 100644 --- a/cmake/dependencies.cmake +++ b/cmake/dependencies.cmake @@ -1,97 +1,170 @@ -set(Kokkos_REPOSITORY https://github.com/kokkos/kokkos.git CACHE STRING "Kokkos repository") -set(plog_REPOSITORY https://github.com/SergiusTheBest/plog.git CACHE STRING "plog repository") -set(toml11_REPOSITORY https://github.com/ToruNiina/toml11 CACHE STRING "toml11 repository") +# cmake-lint: disable=C0103,C0111,R0915,R0912 + +set(Kokkos_REPOSITORY + https://github.com/kokkos/kokkos.git + CACHE STRING "Kokkos repository") +set(plog_REPOSITORY + https://github.com/SergiusTheBest/plog.git + CACHE STRING "plog repository") +set(adios2_REPOSITORY + https://github.com/ornladios/ADIOS2.git + CACHE STRING "ADIOS2 repository") -# set (adios2_REPOSITORY https://github.com/ornladios/ADIOS2.git CACHE STRING "ADIOS2 repository") function(check_internet_connection) if(OFFLINE STREQUAL "ON") - 
set(FETCHCONTENT_FULLY_DISCONNECTED ON CACHE BOOL "Connection status") + set(FETCHCONTENT_FULLY_DISCONNECTED + ON + CACHE BOOL "Connection status") message(STATUS "${Blue}Offline mode.${ColorReset}") else() execute_process( COMMAND ping 8.8.8.8 -c 2 RESULT_VARIABLE NO_CONNECTION - OUTPUT_QUIET - ) + OUTPUT_QUIET) if(NO_CONNECTION GREATER 0) - set(FETCHCONTENT_FULLY_DISCONNECTED ON CACHE BOOL "Connection status") - message(STATUS "${Red}No internet connection. Fetching disabled.${ColorReset}") + set(FETCHCONTENT_FULLY_DISCONNECTED + ON + CACHE BOOL "Connection status") + message( + STATUS "${Red}No internet connection. Fetching disabled.${ColorReset}") else() - set(FETCHCONTENT_FULLY_DISCONNECTED OFF CACHE BOOL "Connection status") + set(FETCHCONTENT_FULLY_DISCONNECTED + OFF + CACHE BOOL "Connection status") message(STATUS "${Green}Internet connection established.${ColorReset}") endif() endif() endfunction() -function(find_or_fetch_dependency package_name header_only) +function(find_or_fetch_dependency package_name header_only mode) if(NOT header_only) - find_package(${package_name} QUIET) + find_package(${package_name} ${mode}) endif() if(NOT ${package_name}_FOUND) - if(DEFINED ${package_name}_REPOSITORY AND NOT FETCHCONTENT_FULLY_DISCONNECTED) + if(${package_name} STREQUAL "Kokkos") + include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/kokkosConfig.cmake) + elseif(${package_name} STREQUAL "adios2") + include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/adios2Config.cmake) + endif() + if(DEFINED ${package_name}_REPOSITORY AND NOT + FETCHCONTENT_FULLY_DISCONNECTED) # fetching package - message(STATUS "${Blue}${package_name} not found. Fetching from ${${package_name}_REPOSITORY}${ColorReset}") + message(STATUS "${Blue}${package_name} not found. 
" + "Fetching from ${${package_name}_REPOSITORY}${ColorReset}") include(FetchContent) if(${package_name} STREQUAL "Kokkos") FetchContent_Declare( ${package_name} GIT_REPOSITORY ${${package_name}_REPOSITORY} - GIT_TAG 4.3.00 - ) + GIT_TAG 4.6.01) else() - FetchContent_Declare( - ${package_name} - GIT_REPOSITORY ${${package_name}_REPOSITORY} - ) + FetchContent_Declare(${package_name} + GIT_REPOSITORY ${${package_name}_REPOSITORY}) endif() FetchContent_MakeAvailable(${package_name}) set(lower_pckg_name ${package_name}) string(TOLOWER ${lower_pckg_name} lower_pckg_name) - set(${package_name}_SRC ${CMAKE_CURRENT_BINARY_DIR}/_deps/${lower_pckg_name}-src CACHE PATH "Path to ${package_name} src") - set(${package_name}_FETCHED TRUE CACHE BOOL "Whether ${package_name} was fetched") + set(${package_name}_SRC + ${CMAKE_CURRENT_BINARY_DIR}/_deps/${lower_pckg_name}-src + CACHE PATH "Path to ${package_name} src") + set(${package_name}_BUILD_DIR + ${CMAKE_CURRENT_BINARY_DIR}/_deps/${lower_pckg_name}-build + CACHE PATH "Path to ${package_name} build") + set(${package_name}_FETCHED + TRUE + CACHE BOOL "Whether ${package_name} was fetched") message(STATUS "${Green}${package_name} fetched.${ColorReset}") else() # get as submodule - message(STATUS "${Yellow}${package_name} not found. Using as submodule.${ColorReset}") + message( + STATUS + "${Yellow}${package_name} not found. 
Using as submodule.${ColorReset}" + ) - set(${package_name}_FETCHED FALSE CACHE BOOL "Whether ${package_name} was fetched") + set(${package_name}_FETCHED + FALSE + CACHE BOOL "Whether ${package_name} was fetched") if(NOT FETCHCONTENT_FULLY_DISCONNECTED) - message(STATUS "${GREEN}Updating ${package_name} submodule.${ColorReset}") + message( + STATUS "${GREEN}Updating ${package_name} submodule.${ColorReset}") execute_process( - COMMAND git submodule update --init --remote ${CMAKE_CURRENT_SOURCE_DIR}/extern/${package_name} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - ) + COMMAND git submodule update --init --remote + ${CMAKE_CURRENT_SOURCE_DIR}/extern/${package_name} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) endif() - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extern/${package_name} extern/${package_name}) - set(${package_name}_SRC ${CMAKE_CURRENT_SOURCE_DIR}/extern/${package_name} CACHE PATH "Path to ${package_name} src") - set(${package_name}_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/build/extern/${package_name} CACHE PATH "Path to ${package_name} build") + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extern/${package_name} + extern/${package_name}) + set(${package_name}_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/extern/${package_name} + CACHE PATH "Path to ${package_name} src") + set(${package_name}_BUILD_DIR + ${CMAKE_CURRENT_SOURCE_DIR}/build/extern/${package_name} + CACHE PATH "Path to ${package_name} build") endif() else() message(STATUS "${Green}${package_name} found.${ColorReset}") - set(${package_name}_FETCHED FALSE CACHE BOOL "Whether ${package_name} was fetched") - set(${package_name}_VERSION ${${package_name}_VERSION} CACHE INTERNAL "${package_name} version") + set(${package_name}_FETCHED + FALSE + CACHE BOOL "Whether ${package_name} was fetched") + set(${package_name}_VERSION + ${${package_name}_VERSION} + CACHE INTERNAL "${package_name} version") endif() if(${package_name} STREQUAL "adios2") if(NOT DEFINED adios2_VERSION OR adios2_VERSION STREQUAL "") 
- get_directory_property(adios2_VERSION DIRECTORY ${adios2_BUILD_DIR} DEFINITION ADIOS2_VERSION) - set(adios2_VERSION ${adios2_VERSION} CACHE INTERNAL "ADIOS2 version") + get_directory_property(adios2_VERSION DIRECTORY ${adios2_BUILD_DIR} + DEFINITION ADIOS2_VERSION) + set(adios2_VERSION + ${adios2_VERSION} + CACHE INTERNAL "ADIOS2 version") endif() endif() if(${package_name} STREQUAL "Kokkos") if(NOT DEFINED Kokkos_VERSION OR Kokkos_VERSION STREQUAL "") - get_directory_property(Kokkos_VERSION DIRECTORY ${Kokkos_SRC} DEFINITION Kokkos_VERSION) - set(Kokkos_VERSION ${Kokkos_VERSION} CACHE INTERNAL "Kokkos version") + get_directory_property(Kokkos_VERSION DIRECTORY ${Kokkos_SRC} DEFINITION + Kokkos_VERSION) + set(Kokkos_VERSION + ${Kokkos_VERSION} + CACHE INTERNAL "Kokkos version") + endif() + if(NOT DEFINED Kokkos_ARCH + OR Kokkos_ARCH STREQUAL "" + OR NOT DEFINED Kokkos_DEVICES + OR Kokkos_DEVICES STREQUAL "") + if(${Kokkos_FOUND}) + include(${Kokkos_DIR}/KokkosConfigCommon.cmake) + elseif(NOT ${Kokkos_BUILD_DIR} STREQUAL "") + include(${Kokkos_BUILD_DIR}/KokkosConfigCommon.cmake) + else() + message( + STATUS "${Red}Kokkos_DIR and Kokkos_BUILD_DIR not set.${ColorReset}") + endif() endif() + set(Kokkos_ARCH + ${Kokkos_ARCH} + PARENT_SCOPE) + set(Kokkos_DEVICES + ${Kokkos_DEVICES} + PARENT_SCOPE) endif() + set(${package_name}_FOUND + ${${package_name}_FOUND} + PARENT_SCOPE) + set(${package_name}_FETCHED + ${${package_name}_FETCHED} + PARENT_SCOPE) + set(${package_name}_BUILD_DIR + ${${package_name}_BUILD_DIR} + PARENT_SCOPE) endfunction() check_internet_connection() diff --git a/cmake/kokkosConfig.cmake b/cmake/kokkosConfig.cmake index 8928253ae..f1a1cf207 100644 --- a/cmake/kokkosConfig.cmake +++ b/cmake/kokkosConfig.cmake @@ -1,57 +1,40 @@ +# cmake-lint: disable=C0103 + # ----------------------------- Kokkos settings ---------------------------- # if(${DEBUG} STREQUAL "OFF") - set(Kokkos_ENABLE_AGGRESSIVE_VECTORIZATION ON CACHE BOOL "Kokkos aggressive 
vectorization") - set(Kokkos_ENABLE_COMPILER_WARNINGS OFF CACHE BOOL "Kokkos compiler warnings") - set(Kokkos_ENABLE_DEBUG OFF CACHE BOOL "Kokkos debug") - set(Kokkos_ENABLE_DEBUG_BOUNDS_CHECK OFF CACHE BOOL "Kokkos debug bounds check") + set(Kokkos_ENABLE_AGGRESSIVE_VECTORIZATION + ON + CACHE BOOL "Kokkos aggressive vectorization") + set(Kokkos_ENABLE_COMPILER_WARNINGS + OFF + CACHE BOOL "Kokkos compiler warnings") + set(Kokkos_ENABLE_DEBUG + OFF + CACHE BOOL "Kokkos debug") + set(Kokkos_ENABLE_DEBUG_BOUNDS_CHECK + OFF + CACHE BOOL "Kokkos debug bounds check") else() - set(Kokkos_ENABLE_AGGRESSIVE_VECTORIZATION OFF CACHE BOOL "Kokkos aggressive vectorization") - set(Kokkos_ENABLE_COMPILER_WARNINGS ON CACHE BOOL "Kokkos compiler warnings") - set(Kokkos_ENABLE_DEBUG ON CACHE BOOL "Kokkos debug") - set(Kokkos_ENABLE_DEBUG_BOUNDS_CHECK ON CACHE BOOL "Kokkos debug bounds check") + set(Kokkos_ENABLE_AGGRESSIVE_VECTORIZATION + OFF + CACHE BOOL "Kokkos aggressive vectorization") + set(Kokkos_ENABLE_COMPILER_WARNINGS + ON + CACHE BOOL "Kokkos compiler warnings") + set(Kokkos_ENABLE_DEBUG + ON + CACHE BOOL "Kokkos debug") + set(Kokkos_ENABLE_DEBUG_BOUNDS_CHECK + ON + CACHE BOOL "Kokkos debug bounds check") endif() -set(Kokkos_ENABLE_HIP ${default_KOKKOS_ENABLE_HIP} CACHE BOOL "Enable HIP") -set(Kokkos_ENABLE_CUDA ${default_KOKKOS_ENABLE_CUDA} CACHE BOOL "Enable CUDA") -set(Kokkos_ENABLE_OPENMP ${default_KOKKOS_ENABLE_OPENMP} CACHE BOOL "Enable OpenMP") - -# set memory space -if(${Kokkos_ENABLE_CUDA}) - add_compile_definitions(CUDA_ENABLED) - set(ACC_MEM_SPACE Kokkos::CudaSpace) -elseif(${Kokkos_ENABLE_HIP}) - add_compile_definitions(HIP_ENABLED) - set(ACC_MEM_SPACE Kokkos::HIPSpace) -else() - set(ACC_MEM_SPACE Kokkos::HostSpace) -endif() - -set(HOST_MEM_SPACE Kokkos::HostSpace) - -# set execution space -if(${Kokkos_ENABLE_CUDA}) - set(ACC_EXE_SPACE Kokkos::Cuda) -elseif(${Kokkos_ENABLE_HIP}) - set(ACC_EXE_SPACE Kokkos::HIP) -elseif(${Kokkos_ENABLE_OPENMP}) - 
set(ACC_EXE_SPACE Kokkos::OpenMP) -else() - set(ACC_EXE_SPACE Kokkos::Serial) -endif() - -if(${Kokkos_ENABLE_OPENMP}) - set(HOST_EXE_SPACE Kokkos::OpenMP) -else() - set(HOST_EXE_SPACE Kokkos::Serial) -endif() - -add_compile_options("-D AccelExeSpace=${ACC_EXE_SPACE}") -add_compile_options("-D AccelMemSpace=${ACC_MEM_SPACE}") -add_compile_options("-D HostExeSpace=${HOST_EXE_SPACE}") -add_compile_options("-D HostMemSpace=${HOST_MEM_SPACE}") - if(${BUILD_TESTING} STREQUAL "OFF") - set(Kokkos_ENABLE_TESTS OFF CACHE BOOL "Kokkos tests") + set(Kokkos_ENABLE_TESTS + OFF + CACHE BOOL "Kokkos tests") else() - set(Kokkos_ENABLE_TESTS ON CACHE BOOL "Kokkos tests") + set(Kokkos_ENABLE_TESTS + ON + CACHE BOOL "Kokkos tests") endif() diff --git a/cmake/report.cmake b/cmake/report.cmake index fe914baa8..7bd623943 100644 --- a/cmake/report.cmake +++ b/cmake/report.cmake @@ -1,264 +1,93 @@ -function(PadTo Text Padding Target Result) - set(rt ${Text}) - string(FIND ${rt} "${Magenta}" mg_fnd) - - if(mg_fnd GREATER -1) - string(REGEX REPLACE "${Esc}\\[35m" "" rt ${rt}) - endif() - - string(LENGTH "${rt}" TextLength) - math(EXPR PaddingNeeded "${Target} - ${TextLength}") - set(rt ${Text}) - - if(PaddingNeeded GREATER 0) - foreach(i RANGE 0 ${PaddingNeeded}) - set(rt "${rt}${Padding}") - endforeach() - else() - set(${rt} "${rt}") - endif() - - set(${Result} "${rt}" PARENT_SCOPE) -endfunction() - -function(PrintChoices Label Flag Choices Value Default Color OutputString Multiline Padding) - list(LENGTH "${Choices}" nchoices) - set(rstring "") - set(counter 0) - - foreach(ch ${Choices}) - if(${counter} EQUAL 0) - set(rstring_i "- ${Label}") - - if(NOT "${Flag}" STREQUAL "") - set(rstring_i "${rstring_i} [${Magenta}${Flag}${ColorReset}]") - endif() - - set(rstring_i "${rstring_i}:") - PadTo("${rstring_i}" " " ${Padding} rstring_i) - else() - set(rstring_i "") - - if(NOT ${counter} EQUAL ${nchoices}) - if(${Multiline} EQUAL 1) - set(rstring_i "${rstring_i}\n") - PadTo("${rstring_i}" " " 
${Padding} rstring_i) - else() - set(rstring_i "${rstring_i}/") - endif() - endif() - endif() - - if(${ch} STREQUAL ${Value}) - if(${ch} STREQUAL "ON") - set(col ${Green}) - elseif(${ch} STREQUAL "OFF") - set(col ${Red}) - else() - set(col ${Color}) - endif() - else() - set(col ${Dim}) - endif() - - if(${ch} STREQUAL ${Default}) - set(col ${Underline}${col}) - endif() - - set(rstring_i "${rstring_i}${col}${ch}${ColorReset}") - math(EXPR counter "${counter} + 1") - set(rstring "${rstring}${rstring_i}") - set(rstring_i "") - endforeach() - - set(${OutputString} "${rstring}" PARENT_SCOPE) -endfunction() - -set(ON_OFF_VALUES "ON" "OFF") - if(${PGEN_FOUND}) - PrintChoices("Problem generator" + printchoices( + "Problem generator" "pgen" "${problem_generators}" ${PGEN} - ${default_pgen} + "" "${Blue}" PGEN_REPORT - 1 - 36 - ) + 0) +elseif(${TESTS}) + set(TEST_NAMES "") + foreach(test_dir IN LISTS TEST_DIRECTORIES) + get_property( + LOCAL_TEST_NAMES + DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/${test_dir}/tests + PROPERTY TESTS) + list(APPEND TEST_NAMES ${LOCAL_TEST_NAMES}) + endforeach() + printchoices( + "Test cases" + "" + "${TEST_NAMES}" + "" + "${ColorReset}" + "" + TESTS_REPORT + 0) endif() -PrintChoices("Precision" +printchoices( + "Precision" "precision" "${precisions}" ${precision} ${default_precision} "${Blue}" PRECISION_REPORT - 1 - 36 -) -PrintChoices("Output" + 46) +printchoices( + "Output" "output" "${ON_OFF_VALUES}" ${output} ${default_output} "${Green}" OUTPUT_REPORT - 0 - 36 -) -PrintChoices("GUI" - "gui" - "${ON_OFF_VALUES}" - ${gui} - ${default_gui} - "${Green}" - GUI_REPORT - 0 - 36 -) -PrintChoices("MPI" + 46) +printchoices( + "MPI" "mpi" "${ON_OFF_VALUES}" ${mpi} OFF "${Green}" MPI_REPORT - 0 - 42 -) -PrintChoices("Debug mode" + 46) +if(${mpi} AND ${DEVICE_ENABLED}) + printchoices( + "GPU-aware MPI" + "gpu_aware_mpi" + "${ON_OFF_VALUES}" + ${gpu_aware_mpi} + OFF + "${Green}" + GPU_AWARE_MPI_REPORT + 46) +endif() +printchoices( + "Debug mode" "DEBUG" 
"${ON_OFF_VALUES}" ${DEBUG} OFF "${Green}" DEBUG_REPORT - 0 - 42 -) - -PrintChoices("CUDA" - "Kokkos_ENABLE_CUDA" - "${ON_OFF_VALUES}" - ${Kokkos_ENABLE_CUDA} - "OFF" - "${Green}" - CUDA_REPORT - 0 - 42 -) -PrintChoices("HIP" - "Kokkos_ENABLE_HIP" - "${ON_OFF_VALUES}" - ${Kokkos_ENABLE_HIP} - "OFF" - "${Green}" - HIP_REPORT - 0 - 42 -) -PrintChoices("OpenMP" - "Kokkos_ENABLE_OPENMP" - "${ON_OFF_VALUES}" - ${Kokkos_ENABLE_OPENMP} - "OFF" - "${Green}" - OPENMP_REPORT - 0 - 42 -) - -PrintChoices("C++ compiler" - "CMAKE_CXX_COMPILER" - "${CMAKE_CXX_COMPILER} v${CMAKE_CXX_COMPILER_VERSION}" - "${CMAKE_CXX_COMPILER} v${CMAKE_CXX_COMPILER_VERSION}" - "N/A" - "${ColorReset}" - CXX_COMPILER_REPORT - 0 - 42 -) - -PrintChoices("C compiler" - "CMAKE_C_COMPILER" - "${CMAKE_C_COMPILER} v${CMAKE_C_COMPILER_VERSION}" - "${CMAKE_C_COMPILER} v${CMAKE_C_COMPILER_VERSION}" - "N/A" - "${ColorReset}" - C_COMPILER_REPORT - 0 - 42 -) - -get_cmake_property(_variableNames VARIABLES) -foreach (_variableName ${_variableNames}) - string(REGEX MATCH "Kokkos_ARCH_*" _isMatched ${_variableName}) - if(_isMatched) - get_property(isSet CACHE ${_variableName} PROPERTY VALUE) - if(isSet STREQUAL "ON") - string(REGEX REPLACE "Kokkos_ARCH_" "" ARCH ${_variableName}) - break() - endif() - endif() -endforeach() -PrintChoices("Architecture" - "Kokkos_ARCH_*" - "${ARCH}" - "${ARCH}" - "N/A" - "${ColorReset}" - ARCH_REPORT - 0 - 42 -) - -if(${Kokkos_ENABLE_CUDA}) - if("${CMAKE_CUDA_COMPILER}" STREQUAL "") - execute_process(COMMAND which nvcc OUTPUT_VARIABLE CUDACOMP) - else() - set(CUDACOMP ${CMAKE_CUDA_COMPILER}) - endif() - - string(STRIP ${CUDACOMP} CUDACOMP) - - message(STATUS "CUDA compiler: ${CUDACOMP}") - execute_process(COMMAND bash -c "${CUDACOMP} --version | grep release | sed -e 's/.*release //' -e 's/,.*//'" - OUTPUT_VARIABLE CUDACOMP_VERSION - OUTPUT_STRIP_TRAILING_WHITESPACE) - - PrintChoices("CUDA compiler" - "CMAKE_CUDA_COMPILER" - "${CUDACOMP}" - "${CUDACOMP}" - "N/A" - "${ColorReset}" - 
CUDA_COMPILER_REPORT - 0 - 42 - ) -endif() - -if (${Kokkos_ENABLE_HIP}) - execute_process(COMMAND bash -c "hipcc --version | grep HIP | cut -d ':' -f 2 | tr -d ' '" - OUTPUT_VARIABLE ROCM_VERSION - OUTPUT_STRIP_TRAILING_WHITESPACE) -endif() - -set(DOT_SYMBOL "${ColorReset}.") -set(DOTTED_LINE_SYMBOL "${ColorReset}. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ") - -set(DASHED_LINE_SYMBOL "${ColorReset}....................................................................... ") + 46) if(NOT ${PROJECT_VERSION_TWEAK} EQUAL 0) - set(VERSION_SYMBOL "v${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}-rc${PROJECT_VERSION_TWEAK}") + set(VERSION_SYMBOL "v${PROJECT_VERSION_MAJOR}." "${PROJECT_VERSION_MINOR}.") + string(APPEND VERSION_SYMBOL + "${PROJECT_VERSION_PATCH}-rc${PROJECT_VERSION_TWEAK}") else() - set(VERSION_SYMBOL "v${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH} ") + set(VERSION_SYMBOL "v${PROJECT_VERSION_MAJOR}.") + string(APPEND VERSION_SYMBOL + "${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH} ") endif() -message("${Blue} __ __ +set(REPORT_TEXT + "${Blue} __ __ /\\ \\__ __/\\ \\__ __ ___\\ \\ _\\/\\_\\ \\ _\\ __ __ / __ \\ / __ \\ \\ \\/\\/\\ \\ \\ \\/ /\\ \\/\\ \\ @@ -267,52 +96,145 @@ message("${Blue} __ __ \\/____/\\/_/\\/_/\\/__/ \\/_/\\/__/ \\/___/ \\/_/ /\\___/ Entity ${VERSION_SYMBOL}\t\t \\/__/") -message("${DASHED_LINE_SYMBOL} -Main configurations") +string(APPEND REPORT_TEXT ${ColorReset} "\n") + +string(APPEND REPORT_TEXT ${DASHED_LINE_SYMBOL} "\n" "Configurations" "\n") if(${PGEN_FOUND}) - message(" ${PGEN_REPORT}") + string(APPEND REPORT_TEXT " " ${PGEN_REPORT} "\n") +elseif(${TESTS}) + string(APPEND REPORT_TEXT " " ${TESTS_REPORT} "\n") endif() -message(" ${PRECISION_REPORT}") -message(" ${OUTPUT_REPORT}") -message("${DASHED_LINE_SYMBOL} -Compile configurations") - -message(" ${ARCH_REPORT}") -message(" ${CUDA_REPORT}") -message(" ${HIP_REPORT}") -message(" 
${OPENMP_REPORT}") - -message(" ${C_COMPILER_REPORT}") - -message(" ${CXX_COMPILER_REPORT}") - -if(NOT "${CUDA_COMPILER_REPORT}" STREQUAL "") - message(" ${CUDA_COMPILER_REPORT}") +string( + APPEND + REPORT_TEXT + " " + ${PRECISION_REPORT} + "\n" + " " + ${OUTPUT_REPORT} + "\n") + +string(REPLACE ";" "+" Kokkos_ARCH "${Kokkos_ARCH}") +string(REPLACE ";" "+" Kokkos_DEVICES "${Kokkos_DEVICES}") + +string( + APPEND + REPORT_TEXT + " - ARCH [${Magenta}Kokkos_ARCH_***${ColorReset}]: " + "${Kokkos_ARCH}" + "\n" + " - DEVICES [${Magenta}Kokkos_ENABLE_***${ColorReset}]: " + "${Kokkos_DEVICES}" + "\n" + " " + ${MPI_REPORT} + "\n") + +if(${mpi} AND ${DEVICE_ENABLED}) + string(APPEND REPORT_TEXT " " ${GPU_AWARE_MPI_REPORT} "\n") endif() -message(" ${MPI_REPORT}") - -message(" ${DEBUG_REPORT}") - -message("${DASHED_LINE_SYMBOL}\nDependencies") +string( + APPEND + REPORT_TEXT + " " + ${DEBUG_REPORT} + "\n" + ${DASHED_LINE_SYMBOL} + "\n" + "Compilers & dependencies" + "\n") + +string( + APPEND + REPORT_TEXT + " - C compiler [${Magenta}CMAKE_C_COMPILER${ColorReset}]: v" + ${CMAKE_C_COMPILER_VERSION} + "\n" + " ${Dim}" + ${CMAKE_C_COMPILER} + "${ColorReset}\n" + " - C++ compiler [${Magenta}CMAKE_CXX_COMPILER${ColorReset}]: v" + ${CMAKE_CXX_COMPILER_VERSION} + "\n" + " ${Dim}" + ${CMAKE_CXX_COMPILER} + "${ColorReset}\n") + +if(${Kokkos_DEVICES} MATCHES "CUDA") + if("${CMAKE_CUDA_COMPILER}" STREQUAL "") + execute_process(COMMAND which nvcc OUTPUT_VARIABLE CUDACOMP) + else() + set(CUDACOMP ${CMAKE_CUDA_COMPILER}) + endif() + string(STRIP ${CUDACOMP} CUDACOMP) + set(cmd "${CUDACOMP} --version |") + string(APPEND cmd " grep release | sed -e 's/.*release //' -e 's/,.*//'") + execute_process( + COMMAND bash -c ${cmd} + OUTPUT_VARIABLE CUDACOMP_VERSION + OUTPUT_STRIP_TRAILING_WHITESPACE) + string( + APPEND + REPORT_TEXT + " - CUDA compiler: v" + ${CUDACOMP_VERSION} + "\n" + " ${Dim}" + ${CUDACOMP} + "${ColorReset}\n") +elseif(${Kokkos_DEVICES} MATCHES "HIP") + set(cmd "hipcc --version | 
grep HIP | cut -d ':' -f 2 | tr -d ' '") + execute_process( + COMMAND bash -c ${cmd} + OUTPUT_VARIABLE ROCM_VERSION + OUTPUT_STRIP_TRAILING_WHITESPACE) + string(APPEND REPORT_TEXT " - ROCm: v" ${ROCM_VERSION} "\n") +endif() -if (NOT "${CUDACOMP_VERSION}" STREQUAL "") - message(" - CUDA:\tv${CUDACOMP_VERSION}") -elseif(NOT "${ROCM_VERSION}" STREQUAL "") - message(" - ROCm:\tv${ROCM_VERSION}") +string(APPEND REPORT_TEXT " - Kokkos: v" ${Kokkos_VERSION} "\n") +if(${Kokkos_FOUND}) + string(APPEND REPORT_TEXT " " ${Dim}${Kokkos_DIR}${ColorReset} "\n") +else() + string(APPEND REPORT_TEXT " " ${Dim}${Kokkos_BUILD_DIR}${ColorReset} "\n") endif() -message(" - Kokkos:\tv${Kokkos_VERSION}") + if(${output}) - message(" - ADIOS2:\tv${adios2_VERSION}") + string(APPEND REPORT_TEXT " - ADIOS2: v" ${adios2_VERSION} "\n") + if(${adios2_FOUND}) + string(APPEND REPORT_TEXT " " "${Dim}${adios2_DIR}${ColorReset}" "\n") + else() + string(APPEND REPORT_TEXT " " "${Dim}${adios2_BUILD_DIR}${ColorReset}" + "\n") + endif() endif() -if(${HDF5_FOUND}) - message(" - HDF5:\tv${HDF5_VERSION}") + +string( + APPEND + REPORT_TEXT + ${DASHED_LINE_SYMBOL} + "\n" + "Notes" + "\n" + " ${Dim}: Set flags with `cmake ... -D " + "${Magenta}${ColorReset}${Dim}=`, " + "the ${Underline}default${ColorReset}${Dim} value" + "\n" + " : will be used unless the variable is explicitly set.${ColorReset}") + +if(${TESTS}) + string( + APPEND + REPORT_TEXT + "\n" + " ${Dim}: Run the tests with the following command:" + "\n" + " : ctest --test-dir ${CMAKE_CURRENT_BINARY_DIR}${ColorReset}" + "\n") endif() -message("${DASHED_LINE_SYMBOL} -Notes - ${Dim}: Set flags with `cmake ... 
-D ${Magenta}${ColorReset}${Dim}=`, the ${Underline}default${ColorReset}${Dim} value - : will be used unless the variable is explicitly set.${ColorReset} -") +string(APPEND REPORT_TEXT "\n") + +message(${REPORT_TEXT}) diff --git a/cmake/styling.cmake b/cmake/styling.cmake index fb9cfcc87..5f1e4a7ad 100644 --- a/cmake/styling.cmake +++ b/cmake/styling.cmake @@ -1,3 +1,5 @@ +# cmake-lint: disable=C0103,C0301,C0111,E1120,R0913,R0915 + if(NOT WIN32) string(ASCII 27 Esc) set(ColorReset "${Esc}[m") @@ -23,20 +25,147 @@ if(NOT WIN32) set(StrikeEnd "${Esc}[0m") endif() -# message("This is normal") -# message("${Red}This is Red${ColorReset}") -# message("${Green}This is Green${ColorReset}") -# message("${Yellow}This is Yellow${ColorReset}") -# message("${Blue}This is Blue${ColorReset}") -# message("${Magenta}This is Magenta${ColorReset}") -# message("${Cyan}This is Cyan${ColorReset}") -# message("${White}This is White${ColorReset}") -# message("${BoldRed}This is BoldRed${ColorReset}") -# message("${BoldGreen}This is BoldGreen${ColorReset}") -# message("${BoldYellow}This is BoldYellow${ColorReset}") -# message("${BoldBlue}This is BoldBlue${ColorReset}") -# message("${BoldMagenta}This is BoldMagenta${ColorReset}") -# message("${BoldCyan}This is BoldCyan${ColorReset}") -# message("${BoldWhite}This is BoldWhite\n\n${ColorReset}") - -# message() \ No newline at end of file +set(DOTTED_LINE_SYMBOL "${ColorReset}. . . . . . . . . . . . . . . .") +string(APPEND DOTTED_LINE_SYMBOL " . . . . . . . . . . . . . . . . . . . . ") + +set(DASHED_LINE_SYMBOL "${ColorReset}.................................") +string(APPEND DASHED_LINE_SYMBOL "...................................... 
") + +set(ON_OFF_VALUES "ON" "OFF") + +function(PureLength Text Result) + set(rt ${Text}) + string(FIND ${rt} "${Magenta}" mg_fnd) + + if(mg_fnd GREATER -1) + string(REGEX REPLACE "${Esc}\\[35m" "" rt ${rt}) + endif() + + string(LENGTH "${rt}" TextLength) + set(${Result} + "${TextLength}" + PARENT_SCOPE) +endfunction() + +function(PadTo Text Padding Target Result) + purelength("${Text}" TextLength) + math(EXPR PaddingNeeded "${Target} - ${TextLength}") + set(rt ${Text}) + + if(PaddingNeeded GREATER 0) + foreach(i RANGE 0 ${PaddingNeeded}) + set(rt "${rt}${Padding}") + endforeach() + else() + set(rt "${rt}") + endif() + + set(${Result} + "${rt}" + PARENT_SCOPE) +endfunction() + +function( + PrintChoices + Label + Flag + Choices + Value + Default + Color + OutputString + Padding) + set(rstring "- ${Label}") + + if(NOT "${Flag}" STREQUAL "") + string(APPEND rstring " [${Magenta}${Flag}${ColorReset}]") + endif() + + string(APPEND rstring ":") + + if(${Padding} EQUAL 0) + list(LENGTH "${Choices}" nchoices) + math(EXPR lastchoice "${nchoices} - 1") + + set(longest 0) + foreach(ch IN LISTS Choices) + string(LENGTH ${ch} clen) + if(clen GREATER longest) + set(longest ${clen}) + endif() + endforeach() + + if(longest GREATER 20) + set(ncols 3) + else() + set(ncols 4) + endif() + math(EXPR lastcol "${ncols} - 1") + + set(counter 0) + foreach(ch IN LISTS Choices) + if(NOT ${Value} STREQUAL "") + if(${ch} STREQUAL ${Value}) + set(col ${Color}) + else() + set(col ${Dim}) + endif() + else() + set(col ${Dim}) + endif() + + if(NOT ${Default} STREQUAL "") + if(${ch} STREQUAL ${Default}) + set(col ${Underline}${col}) + endif() + endif() + + string(LENGTH "${ch}" clen) + math(EXPR PaddingNeeded "${longest} - ${clen} + 4") + + if(counter EQUAL ${lastcol} AND NOT ${counter} EQUAL ${lastchoice}) + string(APPEND rstring "${col}~ ${ch}${ColorReset}") + else() + if(counter EQUAL 0) + string(APPEND rstring "\n ") + endif() + string(APPEND rstring "${col}~ ${ch}${ColorReset}") + foreach(i 
RANGE 0 ${PaddingNeeded}) + string(APPEND rstring " ") + endforeach() + endif() + + math(EXPR counter "(${counter} + 1) % ${ncols}") + endforeach() + else() + padto("${rstring}" " " ${Padding} rstring) + + set(new_choices ${Choices}) + foreach(ch IN LISTS new_choices) + string(REPLACE ${ch} "${Dim}${ch}${ColorReset}" new_choices + "${new_choices}") + endforeach() + set(Choices ${new_choices}) + if(${Value} STREQUAL "ON") + set(col ${Green}) + elseif(${Value} STREQUAL "OFF") + set(col ${Red}) + else() + set(col ${Color}) + endif() + if(NOT "${Value}" STREQUAL "") + string(REPLACE ${Value} "${col}${Value}${ColorReset}" Choices + "${Choices}") + endif() + if(NOT "${Default}" STREQUAL "") + string(REPLACE ${Default} "${Underline}${Default}${ColorReset}" Choices + "${Choices}") + endif() + string(REPLACE ";" "/" Choices "${Choices}") + string(APPEND rstring "${Choices}") + endif() + + set(${OutputString} + "${rstring}" + PARENT_SCOPE) +endfunction() diff --git a/cmake/tests.cmake b/cmake/tests.cmake index b53626723..189cc2cc4 100644 --- a/cmake/tests.cmake +++ b/cmake/tests.cmake @@ -8,24 +8,30 @@ add_subdirectory(${SRC_DIR}/metrics ${CMAKE_CURRENT_BINARY_DIR}/metrics) add_subdirectory(${SRC_DIR}/kernels ${CMAKE_CURRENT_BINARY_DIR}/kernels) add_subdirectory(${SRC_DIR}/archetypes ${CMAKE_CURRENT_BINARY_DIR}/archetypes) add_subdirectory(${SRC_DIR}/framework ${CMAKE_CURRENT_BINARY_DIR}/framework) -if (${output}) - add_subdirectory(${SRC_DIR}/output ${CMAKE_CURRENT_BINARY_DIR}/output) +add_subdirectory(${SRC_DIR}/output ${CMAKE_CURRENT_BINARY_DIR}/output) +if(${output}) + add_subdirectory(${SRC_DIR}/checkpoint ${CMAKE_CURRENT_BINARY_DIR}/checkpoint) endif() -if (${mpi}) - # tests with mpi - if (${output}) - add_subdirectory(${SRC_DIR}/output/tests ${CMAKE_CURRENT_BINARY_DIR}/output/tests) - add_subdirectory(${SRC_DIR}/framework/tests ${CMAKE_CURRENT_BINARY_DIR}/framework/tests) - endif() -else() - # tests without mpi - add_subdirectory(${SRC_DIR}/global/tests 
${CMAKE_CURRENT_BINARY_DIR}/global/tests) - add_subdirectory(${SRC_DIR}/metrics/tests ${CMAKE_CURRENT_BINARY_DIR}/metrics/tests) - add_subdirectory(${SRC_DIR}/kernels/tests ${CMAKE_CURRENT_BINARY_DIR}/kernels/tests) - add_subdirectory(${SRC_DIR}/archetypes/tests ${CMAKE_CURRENT_BINARY_DIR}/archetypes/tests) - add_subdirectory(${SRC_DIR}/framework/tests ${CMAKE_CURRENT_BINARY_DIR}/framework/tests) - if (${output}) - add_subdirectory(${SRC_DIR}/output/tests ${CMAKE_CURRENT_BINARY_DIR}/output/tests) - endif() +set(TEST_DIRECTORIES "") + +if(NOT ${mpi}) + list(APPEND TEST_DIRECTORIES global) + list(APPEND TEST_DIRECTORIES metrics) + list(APPEND TEST_DIRECTORIES kernels) + list(APPEND TEST_DIRECTORIES archetypes) + list(APPEND TEST_DIRECTORIES framework) +elseif(${mpi} AND ${output}) + list(APPEND TEST_DIRECTORIES framework) endif() + +list(APPEND TEST_DIRECTORIES output) + +if(${output}) + list(APPEND TEST_DIRECTORIES checkpoint) +endif() + +foreach(test_dir IN LISTS TEST_DIRECTORIES) + add_subdirectory(${SRC_DIR}/${test_dir}/tests + ${CMAKE_CURRENT_BINARY_DIR}/${test_dir}/tests) +endforeach() diff --git a/dev/nix/adios2.nix b/dev/nix/adios2.nix new file mode 100644 index 000000000..0418b71cd --- /dev/null +++ b/dev/nix/adios2.nix @@ -0,0 +1,67 @@ +{ + pkgs ? 
import { }, + hdf5, + mpi, +}: + +let + name = "adios2"; + version = "2.10.2"; + cmakeFlags = { + CMAKE_CXX_STANDARD = "17"; + CMAKE_CXX_EXTENSIONS = "OFF"; + CMAKE_POSITION_INDEPENDENT_CODE = "TRUE"; + BUILD_SHARED_LIBS = "ON"; + ADIOS2_USE_HDF5 = if hdf5 then "ON" else "OFF"; + ADIOS2_USE_Python = "OFF"; + ADIOS2_USE_Fortran = "OFF"; + ADIOS2_USE_ZeroMQ = "OFF"; + BUILD_TESTING = "OFF"; + ADIOS2_BUILD_EXAMPLES = "OFF"; + ADIOS2_USE_MPI = if mpi then "ON" else "OFF"; + ADIOS2_HAVE_HDF5_VOL = if mpi then "ON" else "OFF"; + CMAKE_BUILD_TYPE = "Release"; + }; + stdenv = pkgs.gcc13Stdenv; +in +stdenv.mkDerivation { + pname = "${name}${if hdf5 then "-hdf5" else ""}${if mpi then "-mpi" else ""}"; + version = "${version}"; + src = pkgs.fetchgit { + url = "https://github.com/ornladios/ADIOS2/"; + rev = "v${version}"; + sha256 = "sha256-NVyw7xoPutXeUS87jjVv1YxJnwNGZAT4QfkBLzvQbwg="; + }; + + nativeBuildInputs = with pkgs; [ + cmake + perl + ]; + + propagatedBuildInputs = [ + pkgs.gcc13 + ] ++ (if hdf5 then (if mpi then [ pkgs.hdf5-mpi ] else [ pkgs.hdf5-cpp ]) else [ ]); + # ++ (if mpi then [ pkgs.openmpi ] else [ ]); + + configurePhase = '' + cmake -B build $src ${ + pkgs.lib.attrsets.foldlAttrs ( + acc: key: value: + acc + " -D ${key}=${value}" + ) "" cmakeFlags + } + ''; + + buildPhase = '' + cmake --build build -j + ''; + + installPhase = '' + sed -i '/if(CMAKE_INSTALL_COMPONENT/,/^[[:space:]]&endif()$/d' build/cmake/install/post/cmake_install.cmake + cmake --install build --prefix $out + chmod +x build/cmake/install/post/generate-adios2-config.sh + sh build/cmake/install/post/generate-adios2-config.sh $out + ''; + + enableParallelBuilding = true; +} diff --git a/dev/nix/kokkos.nix b/dev/nix/kokkos.nix new file mode 100644 index 000000000..2f6ee6b99 --- /dev/null +++ b/dev/nix/kokkos.nix @@ -0,0 +1,106 @@ +{ + pkgs ? 
import <nixpkgs> { }, + stdenv, + arch, + gpu, +}: + +let + name = "kokkos"; + pversion = "4.6.01"; + compilerPkgs = { + "HIP" = with pkgs.rocmPackages; [ + llvm.rocm-merged-llvm + rocm-core + clr + rocthrust + rocprim + rocminfo + rocm-smi + ]; + "CUDA" = with pkgs.cudaPackages; [ + llvmPackages_18.clang-tools + cudatoolkit + cuda_cudart + pkgs.gcc13 + ]; + "NONE" = [ + pkgs.llvmPackages_18.clang-tools + pkgs.gcc13 + ]; + }; + getArch = + _: + if gpu != "NONE" && arch == "NATIVE" then + throw "Please specify an architecture when the GPU support is enabled. Available architectures: https://kokkos.org/kokkos-core-wiki/get-started/configuration-guide.html#gpu-architectures" + else + arch; + cmakeExtraFlags = { + "HIP" = [ + "-D Kokkos_ENABLE_HIP=ON" + "-D Kokkos_ARCH_${getArch { }}=ON" + "-D AMDGPU_TARGETS=${builtins.replaceStrings [ "amd_" ] [ "" ] (pkgs.lib.toLower (getArch { }))}" + "-D CMAKE_CXX_COMPILER=hipcc" + ]; + "CUDA" = [ + "-D Kokkos_ENABLE_CUDA=ON" + "-D Kokkos_ARCH_${getArch { }}=ON" + "-D CMAKE_CXX_COMPILER=$WRAPPER_PATH" + ]; + "NONE" = [ ]; + }; +in +pkgs.stdenv.mkDerivation rec { + pname = "${name}"; + version = "${pversion}"; + src = pkgs.fetchgit { + url = "https://github.com/kokkos/kokkos/"; + rev = "${pversion}"; + sha256 = "sha256-+yszUbdHqhIkJZiGLZ9Ln4DYUosuJWKhO8FkbrY0/tY="; + }; + + nativeBuildInputs = with pkgs; [ + cmake + ]; + + propagatedBuildInputs = compilerPkgs.${gpu}; + + patchPhase = + if gpu == "CUDA" then + '' + export WRAPPER_PATH="$(mktemp -d)/nvcc_wrapper" + cp ${src}/bin/nvcc_wrapper $WRAPPER_PATH + substituteInPlace $WRAPPER_PATH --replace-fail "#!/usr/bin/env bash" "#!${stdenv.shell}" + chmod +x "$WRAPPER_PATH" + '' + else + ""; + + configurePhase = '' + cmake -B build -D CMAKE_BUILD_TYPE=Release \ + -D CMAKE_CXX_STANDARD=17 \ + -D CMAKE_CXX_EXTENSIONS=OFF \ + -D CMAKE_POSITION_INDEPENDENT_CODE=TRUE \ + ${pkgs.lib.concatStringsSep " " cmakeExtraFlags.${gpu}} \ + -D CMAKE_INSTALL_PREFIX=$out + ''; + + buildPhase = '' + cmake --build 
build -j + ''; + + installPhase = '' + cmake --install build + ''; + + # cmakeFlags = [ + # "-D CMAKE_CXX_STANDARD=17" + # "-D CMAKE_CXX_EXTENSIONS=OFF" + # "-D CMAKE_POSITION_INDEPENDENT_CODE=TRUE" + # "-D Kokkos_ARCH_${getArch { }}=ON" + # (if gpu != "none" then "-D Kokkos_ENABLE_${gpu}=ON" else "") + # "-D CMAKE_BUILD_TYPE=Release" + # ] ++ (cmakeExtraFlags.${gpu} src); + + # enableParallelBuilding = true; +} diff --git a/dev/nix/shell.nix b/dev/nix/shell.nix new file mode 100644 index 000000000..33ae57095 --- /dev/null +++ b/dev/nix/shell.nix @@ -0,0 +1,88 @@ +{ + pkgs ? import <nixpkgs> { + config.allowUnfree = true; + config.cudaSupport = gpu == "CUDA"; + }, + gpu ? "NONE", + arch ? "NATIVE", + hdf5 ? true, + mpi ? false, +}: + +let + gpuUpper = pkgs.lib.toUpper gpu; + archUpper = pkgs.lib.toUpper arch; + name = "entity-dev"; + adios2Pkg = (pkgs.callPackage ./adios2.nix { inherit pkgs mpi hdf5; }); + kokkosPkg = ( + pkgs.callPackage ./kokkos.nix { + inherit pkgs; + stdenv = pkgs.stdenv; + arch = archUpper; + gpu = gpuUpper; + } + ); + envVars = { + compiler = { + NONE = { + CXX = "g++"; + CC = "gcc"; + }; + HIP = { + CXX = "hipcc"; + CC = "hipcc"; + }; + CUDA = { }; + }; + }; +in +pkgs.mkShell { + name = "${name}-env"; + nativeBuildInputs = with pkgs; [ + zlib + cmake + + adios2Pkg + kokkosPkg + + python312 + python312Packages.jupyter + + cmake-format + cmake-lint + neocmakelsp + black + pyright + taplo + vscode-langservers-extracted + ]; + + LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath ([ + pkgs.stdenv.cc.cc + pkgs.zlib + ]); + + shellHook = + '' + BLUE='\033[0;34m' + NC='\033[0m' + + echo "following environment variables are set:" + '' + + pkgs.lib.concatStringsSep "" ( + pkgs.lib.mapAttrsToList ( + category: vars: + pkgs.lib.concatStringsSep "" ( + pkgs.lib.mapAttrsToList (name: value: '' + export ${name}=${value} + echo -e " ''\${BLUE}${name}''\${NC}=${value}" + '') vars.${gpuUpper} + ) + ) envVars + ) + + '' + echo "" + echo -e "${name} nix-shell activated" + ''; + 
+} diff --git a/dev/runners/Dockerfile.runner.cpu b/dev/runners/Dockerfile.runner.cpu new file mode 100644 index 000000000..3c2cf4926 --- /dev/null +++ b/dev/runners/Dockerfile.runner.cpu @@ -0,0 +1,73 @@ +FROM ubuntu:22.04 + +ARG DEBIAN_FRONTEND=noninteractive + +# upgrade +RUN apt-get update && apt-get upgrade -y + +# cmake & build tools +RUN apt-get remove -y --purge cmake && \ + apt-get install -y sudo wget curl build-essential openmpi-bin openmpi-common libopenmpi-dev && \ + wget "https://github.com/Kitware/CMake/releases/download/v3.29.6/cmake-3.29.6-linux-x86_64.tar.gz" -P /opt && \ + tar xvf /opt/cmake-3.29.6-linux-x86_64.tar.gz -C /opt && \ + rm /opt/cmake-3.29.6-linux-x86_64.tar.gz +ENV PATH=/opt/cmake-3.29.6-linux-x86_64/bin:$PATH + +# adios2 +RUN apt-get update && apt-get install -y git libhdf5-openmpi-dev && \ + git clone https://github.com/ornladios/ADIOS2.git /opt/adios2-src && \ + cd /opt/adios2-src && \ + cmake -B build \ + -D CMAKE_CXX_STANDARD=17 \ + -D CMAKE_CXX_EXTENSIONS=OFF \ + -D CMAKE_POSITION_INDEPENDENT_CODE=TRUE \ + -D BUILD_SHARED_LIBS=ON \ + -D ADIOS2_USE_HDF5=ON \ + -D ADIOS2_USE_Python=OFF \ + -D ADIOS2_USE_Fortran=OFF \ + -D ADIOS2_USE_ZeroMQ=OFF \ + -D BUILD_TESTING=OFF \ + -D ADIOS2_BUILD_EXAMPLES=OFF \ + -D ADIOS2_USE_MPI=ON \ + -D ADIOS2_HAVE_HDF5_VOL=OFF \ + -D CMAKE_INSTALL_PREFIX=/opt/adios2 && \ + cmake --build build -j && \ + cmake --install build && \ + rm -rf /opt/adios2-src + +ENV HDF5_ROOT=/usr +ENV ADIOS2_DIR=/opt/adios2 +ENV PATH=/opt/adios2/bin:$PATH + +# cleanup +RUN apt-get clean && \ + apt-get autoclean && \ + apt-get autoremove -y && \ + rm -rf /var/lib/cache/* && \ + rm -rf /var/lib/log/* && \ + rm -rf /var/lib/apt/lists/* + +ARG USER=runner +RUN useradd -ms /usr/bin/zsh $USER && \ + usermod -aG sudo $USER && \ + echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers + +USER $USER +ARG HOME=/home/$USER +WORKDIR $HOME + +# gh runner +ARG RUNNER_VERSION=2.317.0 +RUN mkdir actions-runner +WORKDIR $HOME/actions-runner 
+ +RUN curl -o actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz \ + -L https://github.com/actions/runner/releases/download/v${RUNNER_VERSION}/actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz && \ + tar xzf ./actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz && \ + sudo ./bin/installdependencies.sh + +ADD start.sh start.sh +RUN sudo chown $USER:$USER start.sh && \ + sudo chmod +x start.sh + +ENTRYPOINT ["./start.sh"] diff --git a/dev/runners/Dockerfile.runner.cuda b/dev/runners/Dockerfile.runner.cuda index 4ff132990..6e0b0755c 100644 --- a/dev/runners/Dockerfile.runner.cuda +++ b/dev/runners/Dockerfile.runner.cuda @@ -59,14 +59,17 @@ ARG HOME=/home/$USER WORKDIR $HOME # gh runner +ARG RUNNER_VERSION=2.317.0 RUN mkdir actions-runner WORKDIR $HOME/actions-runner -RUN --mount=type=secret,id=ghtoken \ - curl -o actions-runner-linux-x64-2.317.0.tar.gz \ - -L https://github.com/actions/runner/releases/download/v2.317.0/actions-runner-linux-x64-2.317.0.tar.gz && \ - tar xzf ./actions-runner-linux-x64-2.317.0.tar.gz && \ - sudo ./bin/installdependencies.sh && \ - ./config.sh --url https://github.com/entity-toolkit/entity --token "$(sudo cat /run/secrets/ghtoken)" --labels nvidia-gpu +RUN curl -o actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz \ + -L https://github.com/actions/runner/releases/download/v${RUNNER_VERSION}/actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz && \ + tar xzf ./actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz && \ + sudo ./bin/installdependencies.sh -ENTRYPOINT ["./run.sh"] +ADD start.sh start.sh +RUN sudo chown $USER:$USER start.sh && \ + sudo chmod +x start.sh + +ENTRYPOINT ["./start.sh"] diff --git a/dev/runners/README.md b/dev/runners/README.md index 08d0cd176..0aac76e2d 100644 --- a/dev/runners/README.md +++ b/dev/runners/README.md @@ -19,3 +19,10 @@ docker run -e TOKEN= -e LABEL=nvidia-gpu --runtime=nvidia --gpus=all -dt docker build -t ghrunner:amd -f Dockerfile.runner.rocm . 
docker run -e TOKEN= -e LABEL=amd-gpu --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video -dt ghrunner:amd ``` + +### CPU + +```sh +docker build -t ghrunner:cpu -f Dockerfile.runner.cpu . +docker run -e TOKEN= -e LABEL=cpu -dt ghrunner:cpu +``` diff --git a/extern/Kokkos b/extern/Kokkos index eb11070f6..1b1383c60 160000 --- a/extern/Kokkos +++ b/extern/Kokkos @@ -1 +1 @@ -Subproject commit eb11070f67565b2e660659f5207f0363bdf3b882 +Subproject commit 1b1383c6001f3bfe9fe309ca923c2d786600cc79 diff --git a/extern/adios2 b/extern/adios2 index b8761e2af..a19dad6ce 160000 --- a/extern/adios2 +++ b/extern/adios2 @@ -1 +1 @@ -Subproject commit b8761e2afab2cd05b89d09b2ee4da1cd7a834225 +Subproject commit a19dad6cecb00319825f20fd9f455ebbab903d34 diff --git a/extern/entity-pgens b/extern/entity-pgens new file mode 160000 index 000000000..386eefc80 --- /dev/null +++ b/extern/entity-pgens @@ -0,0 +1 @@ +Subproject commit 386eefc80e2f63d0e29168869f881dd0b288952d diff --git a/extern/plog b/extern/plog index 85a871b13..e21baecd4 160000 --- a/extern/plog +++ b/extern/plog @@ -1 +1 @@ -Subproject commit 85a871b13be0bd1a9e0110744fa60cc9bd1e8380 +Subproject commit e21baecd4753f14da64ede979c5a19302618b752 diff --git a/extern/toml11 b/extern/toml11 deleted file mode 160000 index 12c0f379f..000000000 --- a/extern/toml11 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 12c0f379f2e865b4ce984758d5ae004f9de07d69 diff --git a/input.example.toml b/input.example.toml index 88589495c..6444a16f0 100644 --- a/input.example.toml +++ b/input.example.toml @@ -1,73 +1,72 @@ [simulation] - # Name of the simulation: + # Name of the simulation # @required # @type: string - # @example: "MySim" - # @note: The name is used for the output files. 
+ # @note: The name is used for the output files name = "" # Simulation engine to use # @required # @type: string - # @valid: "SRPIC", "GRPIC" + # @enum: "SRPIC", "GRPIC" engine = "" - # Max runtime in physical (code) units: + # Max runtime in physical (code) units # @required - # @type: float: > 0 + # @type: float [> 0] # @example: 1e5 runtime = "" [simulation.domain] # Number of domains - # @type int - # @default: 1 (no MPI) - # @default: MPI_SIZE (MPI) + # @type: int + # @default: 1 [no MPI]; MPI_SIZE [MPI] number = "" - # Decomposition of the domain (e.g., for MPI) in each of the directions - # @type array of int of size 1, 2 or 3 - # @example: [2, 2, 2] (for a total of 8 domains) + # Decomposition of the domain (for MPI) in each of the directions + # @type: array [size 1 :->: 3] # @default: [-1, -1, -1] # @note: -1 means the code will determine the decomposition in the specific direction automatically - # @note: automatic detection is either done by inference from # of MPI tasks, or by balancing the grid size on each domain + # @note: Automatic detection is either done by inference from # of MPI tasks, or by balancing the grid size on each domain + # @example: [2, 2, 2] (total of 8 domains) decomposition = "" [grid] - # Spatial resolution of the grid: + # Spatial resolution of the grid # @required - # @type: array of uint of size 1, 2 or 3 - # @example: [1024, 1024, 1024] + # @type: array [size 1 :->: 3] # @note: Dimensionality is inferred from the size of this array + # @example: [1024, 1024, 1024] resolution = "" - # Physical extent of the grid: + # Physical extent of the grid # @required - # @type: 1/2/3-size array of float tuples, each of size 2 + # @type: array> [size 1 :->: 3] + # @note: For spherical geometry, only specify `[[rmin, rmax]]`, other values are set automatically + # @note: For cartesian geometry, cell aspect ratio has to be 1: `dx=dy=dz` # @example: [[0.0, 1.0], [-1.0, 1.0]] - # @note: For spherical geometry, only specify [[rmin, rmax]], 
other values are set automatically - # @note: For cartesian geometry, cell aspect ratio has to be 1, i.e., dx=dy=dz extent = "" # @inferred: # - dim # @brief: Dimensionality of the grid - # @type: short (1, 2, 3) + # @type: short + # @enum: 1, 2, 3 # @from: `grid.resolution` [grid.metric] # Metric on the grid # @required # @type: string - # @valid: "Minkowski", "Spherical", "QSpherical", "Kerr_Schild", "QKerr_Schild", "Kerr_Schild_0" + # @enum: "Minkowski", "Spherical", "QSpherical", "Kerr_Schild", "QKerr_Schild", "Kerr_Schild_0" metric = "" - # r0 paramter for the QSpherical metric, x1 = log(r-r0): - # @type: float: -inf < ... < rmin - # @default: 0.0 (e.g., x1 = log(r)) - # @note: negative values produce almost uniform grid in r + # `r0` parameter for the QSpherical metric `x1 = log(r-r0)` + # @type: float [-inf -> rmin] + # @default: 0.0 + # @note: Negative values produce almost uniform grid in r qsph_r0 = "" - # h paramter for the QSpherical metric, th = x2 + 2*h x2 (pi-2*x2)*(pi-x2)/pi^2: - # @type: float: -1 < ... < 1 - # @default: 0.0 (e.g., x2 = th) + # `h` parameter for the QSpherical metric `th = x2 + 2*h x2 (pi-2*x2)*(pi-x2)/pi^2` + # @type: float [-1 :->: 1] + # @default: 0.0 qsph_h = "" - # Spin parameter for the Kerr Schild metric: - # @type: float: 0 < ... 
< 1 + # Spin parameter for the Kerr Schild metric + # @type: float [0 :-> 1] # @default: 0.0 ks_a = "" @@ -75,7 +74,7 @@ # - coord # @brief: Coordinate system on the grid # @type: string - # @valid: "cartesian", "spherical", "qspherical" + # @enum: "cartesian", "spherical", "qspherical" # @from: `grid.metric.metric` # - ks_rh # @brief: Size of the horizon for GR Kerr Schild @@ -87,50 +86,56 @@ # @from: `grid.metric` [grid.boundaries] - # Boundary conditions for fields: + # Boundary conditions for fields # @required - # @type: 1/2/3-size array of string tuples, each of size 1 or 2 - # @valid: "PERIODIC", "ABSORB", "ATMOSPHERE", "CUSTOM", "HORIZON" - # @example: [["CUSTOM", "ABSORB"]] (for 2D spherical [[rmin, rmax]]) - # @note: When periodic in any of the directions, you should only set one value [..., ["PERIODIC"], ...] - # @note: In spherical, bondaries in theta/phi are set automatically (only specify bc @ [rmin, rmax]) [["ATMOSPHERE", "ABSORB"]] - # @note: In GR, the horizon boundary is set automatically (only specify bc @ rmax): [["ABSORB"]] + # @type: array> [size 1 :->: 3] + # @enum: "PERIODIC", "MATCH", "FIXED", "ATMOSPHERE", "CUSTOM", "HORIZON", "CONDUCTOR" + # @note: When periodic in any of the directions, you should only set one value: [..., ["PERIODIC"], ...] 
+ # @note: In spherical, boundaries in theta/phi are set automatically (only specify bc @ `[rmin, rmax]`): [["ATMOSPHERE", "MATCH"]] + # @note: In GR, the horizon boundary is set automatically (only specify bc @ rmax): [["MATCH"]] + # @example: [["CUSTOM", "MATCH"]] (for 2D spherical `[[rmin, rmax]]`) fields = "" - # Boundary conditions for fields: + # Boundary conditions for particles # @required - # @type: 1/2/3-size array of string tuples, each of size 1 or 2 - # @valid: "PERIODIC", "ABSORB", "ATMOSPHERE", "CUSTOM", "REFLECT", "HORIZON" - # @example: [["PERIODIC"], ["PERIODIC"]] + # @type: array<tuple<string>> [size 1 :->: 3] + # @enum: "PERIODIC", "ABSORB", "ATMOSPHERE", "CUSTOM", "REFLECT", "HORIZON" # @note: When periodic in any of the directions, you should only set one value [..., ["PERIODIC"], ...] - # @note: In spherical, bondaries in theta/phi are set automatically (only specify bc @ [rmin, rmax]) [["ATMOSPHERE", "ABSORB"]] - # @note: In GR, the horizon boundary is set automatically (only specify bc @ rmax): [["ABSORB"]] + # @note: In spherical, boundaries in theta/phi are set automatically (only specify bc @ `[rmin, rmax]`) [["ATMOSPHERE", "ABSORB"]] + # @note: In GR, the horizon boundary is set automatically (only specify bc @ `rmax`): [["ABSORB"]] + # @example: [["PERIODIC"], ["PERIODIC"]] particles = "" - + + [grid.boundaries.match] + # Size of the matching layer in each direction for fields in physical (code) units + # @type: float | array<tuple<float>> + # @default: 1% of the domain size (in shortest dimension) + # @note: In spherical, this is the size of the layer in `r` from the outer wall + # @example: `ds = 1.5` (will set the same for all directions) + # @example: `ds = [[1.5], [2.0, 1.0], [1.1]]` (will duplicate 1.5 for +/- `x1` and 1.1 for +/- `x3`) + # @example: `ds = [[], [1.5], []]` (will only set for x2) + ds = "" + [grid.boundaries.absorb] - # Size of the absorption layer in physical (code) units: + # Size of the absorption layer for particles in physical (code) units 
# @type: float # @default: 1% of the domain size (in shortest dimension) - # @note: In spherical, this is the size of the layer in r from the outer wall + # @note: In spherical, this is the size of the layer in `r` from the outer wall # @note: In cartesian, this is the same for all dimensions where applicable ds = "" - # Absorption coefficient for fields: - # @type: float: -inf < ... < inf, != 0 - # @default: 1.0 - coeff = "" [grid.boundaries.atmosphere] - # @required: if ATMOSPHERE is one of the boundaries - # Temperature of the atmosphere in units of m0 c^2 + # Temperature of the atmosphere in units of `m0 c^2` # @type: float + # @note: [required] if `ATMOSPHERE` is one of the boundaries temperature = "" - # Peak number density of the atmosphere at base in units of n0 + # Peak number density of the atmosphere at base in units of `n0` # @type: float density = "" # Pressure scale-height in physical units # @type: float height = "" # Species indices of particles that populate the atmosphere - # @type: array of ints of size 2 + # @type: array [size 2] species = "" # Distance from the edge to which the gravity is imposed in physical units # @type: float @@ -139,19 +144,20 @@ ds = "" # @inferred: - # - g [= temperature / height] + # - g # @brief: Acceleration due to imposed gravity # @type: float # @from: `grid.boundaries.atmosphere.temperature`, `grid.boundaries.atmosphere.height` + # @value: `temperature / height` [scales] - # Fiducial larmor radius: + # Fiducial larmor radius # @required - # @type: float: > 0.0 + # @type: float [> 0.0] larmor0 = "" - # Fiducial plasma skin depth: + # Fiducial plasma skin depth # @required - # @type: float: > 0.0 + # @type: float [> 0.0] skindepth0 = "" # @inferred: @@ -163,234 +169,252 @@ # @brief: fiducial elementary volume # @type: float # @from: `grid` - # - n0 [= ppc0 / V0] + # - n0 # @brief: Fiducial number density # @type: float # @from: `particles.ppc0`, `grid` - # - q0 [= 1 / (n0 * skindepth0^2)] + # @value: `ppc0 / V0` + # 
- q0 # @brief: Fiducial elementary charge # @type: float # @from: `scales.skindepth0`, `scales.n0` - # - sigma0 [= (skindepth0 / larmor0)^2] + # @value: `1 / (n0 * skindepth0^2)` + # - sigma0 # @brief: Fiducial magnetization parameter # @type: float # @from: `scales.larmor0`, `scales.skindepth0` - # - B0 [= 1 / larmor0] + # @value: `(skindepth0 / larmor0)^2` + # - B0 # @brief: Fiducial magnetic field # @type: float # @from: `scales.larmor0` - # - omegaB0 [= 1 / larmor0] + # @value: `1 / larmor0` + # - omegaB0 # @brief: Fiducial cyclotron frequency # @type: float # @from: `scales.larmor0` + # @value: `1 / larmor0` [algorithms] - # Number of current smoothing passes: - # @type: unsigned short: >= 0 + # Number of current smoothing passes + # @type: ushort [>= 0] # @default: 0 current_filters = "" [algorithms.toggles] - # Toggle for the field solver: - # @type bool + # Toggle for the field solver + # @type: bool # @default: true fieldsolver = "" - # Toggle for the current deposition: - # @type bool + # Toggle for the current deposition + # @type: bool # @default: true deposit = "" [algorithms.timestep] - # Courant-Friedrichs-Lewy number: - # @type: float: 0.0 < ... < 1.0 + # Courant-Friedrichs-Lewy number + # @type: float [0.0 -> 1.0] # @default: 0.95 - # @note: CFL number determines the timestep duration. 
+ # @note: CFL number determines the timestep duration CFL = "" - # Correction factor for the speed of light used in field solver: - # @type: float: ~1 + # Correction factor for the speed of light used in field solver + # @type: float # @default: 1.0 correction = "" - + # @inferred: - # - dt [= CFL * dx0] + # - dt # @brief: timestep duration # @type: float + # @from: `algorithms.timestep.CFL`, `scales.dx0` + # @value: `CFL * dx0` [algorithms.gr] - # Stepsize for numerical differentiation in GR pusher: - # @type: float: > 0 + # Stepsize for numerical differentiation in GR pusher + # @type: float [> 0.0] # @default: 1e-6 pusher_eps = "" - # Number of iterations for the Newton-Raphson method in GR pusher: - # @type: unsigned short: > 0 + # Number of iterations for the Newton-Raphson method in GR pusher + # @type: ushort [> 0] # @default: 10 pusher_niter = "" [algorithms.gca] - # Maximum value for E/B allowed for GCA particles: - # @type: float: 0.0 < ... < 1.0 + # Maximum value for E/B allowed for GCA particles + # @type: float [0.0 -> 1.0] # @default: 0.9 e_ovr_b_max = "" - # Maximum Larmor radius allowed for GCA particles (in physical units): - # @type: float: > 0 + # Maximum Larmor radius allowed for GCA particles (in physical units) + # @type: float # @default: 0.0 # @note: When `larmor_max` == 0, the limit is disabled larmor_max = "" [algorithms.synchrotron] - # Radiation reaction limit gamma-factor for synchrotron: - # @required [if one of the species has `cooling = "synchrotron"`] - # @type: float: > 0 + # Radiation reaction limit gamma-factor for synchrotron + # @type: float [> 0.0] + # @default: 1.0 + # @note: [required] if one of the species has `cooling = "synchrotron"` gamma_rad = "" [particles] - # Fiducial number of particles per cell: + # Fiducial number of particles per cell # @required - # @type: float: > 0 + # @type: float [> 0.0] ppc0 = "" - # Toggle for using particle weights: + # Toggle for using particle weights # @type: bool # @default: false 
use_weights = "" - # Timesteps between particle re-sorting: - # @type: unsigned int: >= 0 + # Timesteps between particle re-sorting (removing dead particles) + # @type: uint # @default: 100 - # @note: When MPI is enable, particles are sorted every step. - # @note: When `sort_interval` == 0, the sorting is disabled. - sort_interval = "" + # @note: Set to 0 to disable re-sorting + clear_interval = "" # @inferred: # - nspec # @brief: Number of particle species - # @type: unsigned int - # @from: `particles.species` - # - species - # @brief: An object containing information about all the species - # @type: vector of ParticleSpecies + # @type: uint # @from: `particles.species` [[particles.species]] - # Label of the species: + # Label of the species # @type: string - # @default: "s*" (where "*" is the species index starting at 1) + # @default: "s" # @example: "e-" + # @note: `` is the index of the species in the list starting from 1 label = "" - # Mass of the species (in units of fiducial mass): + # Mass of the species (in units of fiducial mass) # @required - # @type: float + # @type: float [>= 0.0] mass = "" - # Charge of the species (in units of fiducial charge): + # Charge of the species (in units of fiducial charge) # @required # @type: float charge = "" - # Maximum number of particles per task: + # Maximum number of particles per task # @required - # @type: unsigned int: > 0 + # @type: uint [> 0] maxnpart = "" - # Pusher algorithm for the species: + # Pusher algorithm for the species # @type: string - # @default: "Boris" for massive and "Photon" for massless - # @valid: "Boris", "Vay", "Boris,GCA", "Vay,GCA", "Photon", "None" + # @default: "Boris" [massive]; "Photon" [massless] + # @enum: "Boris", "Vay", "Boris,GCA", "Vay,GCA", "Photon", "None" pusher = "" - # Number of additional (payload) variables for each particle of the given species: - # @type: unsigned short: >= 0 + # Number of additional (payload) variables for each particle of the given species + # @type: 
ushort # @default: 0 n_payloads = "" - # Radiation reaction to use for the species: + # Radiation reaction to use for the species # @type: string # @default: "None" - # @valid: "None", "Synchrotron" + # @enum: "None", "Synchrotron" cooling = "" -# Parameters for specific problem generators and setups: +# Parameters for specific problem generators and setups [setup] [output] - # Output format: + # Output format # @type: string - # @valid: "disabled", "hdf5", "BPFile" # @default: "hdf5" + # @enum: "disabled", "hdf5", "BPFile" format = "" - # Number of timesteps between all outputs (overriden by specific output interval below): - # @type: unsigned int: > 0 + # Number of timesteps between all outputs + # @type: uint [> 0] # @default: 1 + # @note: Value is overridden by output intervals for specific outputs interval = "" - # Physical (code) time interval between all outputs (overriden by specific output intervals below): - # @type: float: > 0 - # @default: -1.0 (disabled) + # Physical (code) time interval between all outputs + # @type: float + # @default: -1.0 # @note: When `interval_time` < 0, the output is controlled by `interval`, otherwise by `interval_time` + # @note: Value is overridden by output intervals for specific outputs interval_time = "" + # Whether to output each timestep into separate files + # @type: bool + # @default: true + # @deprecated: starting v1.3.0 + separate_files = "" [output.fields] - # Toggle for the field output: + # Toggle for the field output # @type: bool # @default: true enable = "" - # Field quantities to output: - # @type: array of strings - # @valid: fields: "E", "B", "J", "divE" - # @valid: moments: "Rho", "Charge", "N", "Nppc", "T0i", "Tij" - # @valid: for GR: "D", "H", "divD", "A" + # Field quantities to output + # @type: array<string> # @default: [] - # @note: For T, you can use unspecified indices, e.g., Tij, T0i, or specific ones, e.g., Ttt, T00, T02, T23 - # @note: For T, in cartesian can also use "x" "y" "z" instead of "1" "2" "3" - # 
@note: By default, we accumulate moments from all massive species, one can specify only specific species: e.g., Ttt_1_2, Rho_1, Rho_3_4 + # @enum: "E", "B", "J", "divE", "Rho", "Charge", "N", "Nppc", "T0i", "Tij", "Vi", "D", "H", "divD", "A" + # @note: For `T`, you can use unspecified indices: `Tij`, `T0i`, or specific ones: `Ttt`, `T00`, `T02`, `T23` + # @note: For `T`, in cartesian can also use "x" "y" "z" instead of "1" "2" "3" + # @note: By default, we accumulate moments from all massive species, one can specify only specific species: `Ttt_1_2`, `Rho_1`, `Rho_3_4` quantities = "" - # Custom (user-defined) field quantities: - # @type: array of strings + # Custom (user-defined) field quantities + # @type: array # @default: [] custom = "" - # @NOT_IMPLEMENTED: Stride for the output of fields: - # @type: unsigned short: > 1 - # @default: 1 - stride = "" - # Smoothing window for the output of moments (e.g., "Rho", "Charge", "T", etc.): - # @type: unsigned short: >= 0 + # Smoothing window for the output of moments ("Rho", "Charge", "T", ...) 
+ # @type: ushort # @default: 0 mom_smooth = "" - # Number of timesteps between field outputs (overrides `output.interval`): - # @type: unsigned int: > 0 - # @default: 0 (use `output.interval`) + # Number of timesteps between field outputs + # @type: uint + # @default: 0 + # @note: When `!= 0`, overrides `output.interval` + # @note: When `== 0`, `output.interval` is used interval = "" - # Physical (code) time interval between field outputs (overrides `output.interval_time`): - # @type: float: > 0 - # @default: -1.0 (use `output.interval_time`) - # @note: When `interval_time` < 0, the output is controlled by `interval`, otherwise by `interval_time` + # Physical (code) time interval between field outputs + # @type: float + # @default: -1.0 + # @note: When `< 0`, the output is controlled by `interval` + # @note: When specified, overrides `output.interval_time` interval_time = "" + # Downsample factor for the output of fields + # @type: uint | array [>= 1] + # @default: [1, 1, 1] + # @note: The output is downsampled by the given factors in each direction + # @note: If a scalar is given, it is applied to all directions + downsampling = "" [output.particles] - # Toggle for the particles output: + # Toggle for the particles output # @type: bool # @default: true enable = "" - # Particle species indices to output: - # @type: array of ints - # @default: [] = all species + # Particle species indices to output + # @type: array + # @default: [] + # @note: If empty, all species are output species = "" - # Stride for the output of particles: - # @type: unsigned int: > 1 + # Stride for the output of particles + # @type: uint [> 1] # @default: 100 stride = "" - # Number of timesteps between particle outputs (overrides `output.interval`): - # @type: unsigned int: > 0 - # @default: 0 (use `output.interval`) + # Number of timesteps between particle outputs + # @type: uint + # @default: 0 + # @note: When `!= 0`, overrides `output.interval` + # @note: When `== 0`, `output.interval` is 
used interval = "" - # Physical (code) time interval between field outputs (overrides `output.interval_time`): - # @type: float: > 0 - # @default: -1.0 (use `output.interval_time`) - # @note: When `interval_time` < 0, the output is controlled by `interval`, otherwise by `interval_time` + # Physical (code) time interval between particle outputs + # @type: float + # @default: -1.0 + # @note: When `< 0`, the output is controlled by `interval` + # @note: When specified, overrides `output.interval_time` interval_time = "" [output.spectra] - # Toggle for the spectra output: + # Toggle for the spectra output # @type: bool # @default: true enable = "" - # Minimum energy for the spectra output: + # Minimum energy for the spectra output # @type: float # @default: 1e-3 e_min = "" - # Maximum energy for the spectra output: + # Maximum energy for the spectra output # @type: float # @default: 1e3 e_max = "" @@ -398,36 +422,112 @@ # @type: bool # @default: true log_bins = "" - # Number of bins for the spectra output: - # @type: unsigned int: > 0 + # Number of bins for the spectra output + # @type: uint [> 0] # @default: 200 n_bins = "" - # Number of timesteps between spectra outputs (overrides `output.interval`): - # @type: unsigned int: > 0 - # @default: 0 (use `output.interval`) + # Number of timesteps between spectra outputs + # @type: uint + # @default: 0 + # @note: When `!= 0`, overrides `output.interval` + # @note: When `== 0`, `output.interval` is used interval = "" - # Physical (code) time interval between spectra outputs (overrides `output.interval_time`): - # @type: float: > 0 - # @default: -1.0 (use `output.interval_time`) - # @note: When `interval_time` < 0, the output is controlled by `interval`, otherwise by `interval_time` + # Physical (code) time interval between spectra outputs + # @type: float + # @default: -1.0 + # @note: When `< 0`, the output is controlled by `interval` + # @note: When specified, overrides `output.interval_time` interval_time = "" 
[output.debug] - # Output fields "as is" without conversions: + # Output fields "as is" without conversions # @type: bool # @default: false as_is = "" - # Output fields with values in ghost cells: + # Output fields with values in ghost cells # @type: bool # @default: false ghosts = "" + [output.stats] + # Toggle for the stats output + # @type: bool + # @default: true + enable = "" + # Number of timesteps between stat outputs + # @type: uint [> 0] + # @default: 100 + # @note: Overridden if `output.stats.interval_time != -1` + interval = "" + # Physical (code) time interval between stat outputs + # @type: float + # @default: -1.0 + # @note: When `< 0`, the output is controlled by `interval` + interval_time = "" + # Field quantities to output + # @type: array<string> + # @default: ["B^2", "E^2", "ExB", "Rho", "T00"] + # @enum: "B^2", "E^2", "ExB", "N", "Charge", "Rho", "T00", "T0i", "Tij" + # @note: Same notation as for `output.fields.quantities` + quantities = "" + # Custom (user-defined) stats + # @type: array<string> + # @default: [] + custom = "" + +[checkpoint] + # Number of timesteps between checkpoints + # @type: uint [> 0] + # @default: 1000 + interval = "" + # Physical (code) time interval between checkpoints + # @type: float + # @default: -1.0 + # @note: When `< 0`, the output is controlled by `interval` + interval_time = "" + # Number of checkpoints to keep + # @type: int + # @default: 2 + # @note: 0 = disable checkpointing + # @note: -1 = keep all checkpoints + keep = "" + # Write a checkpoint once after a fixed walltime + # @type: string + # @default: "00:00:00" + # @note: The format is "HH:MM:SS" + # @note: Empty string or "00:00:00" disables this functionality + # @note: Writing checkpoint at walltime does not stop the simulation + walltime = "" + # Parent directory to write checkpoints to + # @type: string + # @default: `.ckpt` + # @note: The directory is created if it does not exist + write_path = "" + # Parent directory to use when resuming from a checkpoint + 
# @type: string + # @default: inherit `write_path` + read_path = "" + + # @inferred: + # - is_resuming + # @brief: Whether the simulation is resuming from a checkpoint + # @type: bool + # @from: command-line flag + # - start_step + # @brief: Timestep of the checkpoint used to resume + # @type: uint + # @from: automatically determined during restart + # - start_time + # @brief: Time of the checkpoint used to resume + # @type: float + # @from: automatically determined during restart + [diagnostics] - # Number of timesteps between diagnostic logs: - # @type: int: > 0 + # Number of timesteps between diagnostic logs + # @type: int [> 0] # @default: 1 interval = "" - # Blocking timers between successive algorithms: + # Blocking timers between successive algorithms # @type: bool # @default: false blocking_timers = "" @@ -435,3 +535,9 @@ # @type: bool # @default: true colored_stdout = "" + # Specify the log level + # @type: string + # @default: "VERBOSE" + # @enum: "VERBOSE", "WARNING", "ERROR" + # @note: "VERBOSE" prints all messages, "WARNING" prints only warnings and errors, "ERROR" prints only errors + log_level = "" diff --git a/setups/srpic/magnetosphere/magnetosphere.toml b/legacy/_monopole/monopole.toml similarity index 51% rename from setups/srpic/magnetosphere/magnetosphere.toml rename to legacy/_monopole/monopole.toml index 34e04b02d..cf735fce8 100644 --- a/setups/srpic/magnetosphere/magnetosphere.toml +++ b/legacy/_monopole/monopole.toml @@ -1,31 +1,31 @@ [simulation] - name = "magnetosphere" - engine = "srpic" + name = "monopole" + engine = "srpic" runtime = 60.0 [grid] resolution = [2048, 1024] - extent = [[1.0, 50.0]] + extent = [[1.0, 50.0]] [grid.metric] metric = "qspherical" [grid.boundaries] - fields = [["ATMOSPHERE", "ABSORB"]] + fields = [["ATMOSPHERE", "ABSORB"]] particles = [["ATMOSPHERE", "ABSORB"]] - + [grid.boundaries.absorb] ds = 1.0 [grid.boundaries.atmosphere] temperature = 0.1 - density = 10.0 - height = 0.02 - species = [1, 2] - ds = 2.0 - + 
density = 10.0 + height = 0.02 + species = [1, 2] + ds = 2.0 + [scales] - larmor0 = 2e-5 + larmor0 = 2e-5 skindepth0 = 0.01 [algorithms] @@ -36,37 +36,38 @@ [algorithms.gca] e_ovr_b_max = 0.9 - larmor_max = 1.0 + larmor_max = 1.0 [particles] - ppc0 = 5.0 - use_weights = true - sort_interval = 100 + ppc0 = 5.0 + use_weights = true + clear_interval = 100 [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e8 - pusher = "Boris,GCA" + label = "e-" + mass = 1.0 + charge = -1.0 + maxnpart = 1e8 + pusher = "Boris,GCA" [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 1e8 - pusher = "Boris,GCA" + label = "e+" + mass = 1.0 + charge = 1.0 + maxnpart = 1e8 + pusher = "Boris,GCA" [setup] - Bsurf = 1.0 + Bsurf = 1.0 period = 60.0 [output] format = "hdf5" - + [output.fields] interval_time = 0.1 - quantities = ["N_1", "N_2", "E", "B", "T00"] + quantities = ["N_1", "N_2", "E", "B", "T00"] + mom_smooth = 2 [output.particles] enable = false @@ -75,5 +76,5 @@ enable = false [diagnostics] - interval = 50 + interval = 50 colored_stdout = true diff --git a/setups/srpic/monopole/pgen.hpp b/legacy/_monopole/pgen.hpp similarity index 97% rename from setups/srpic/monopole/pgen.hpp rename to legacy/_monopole/pgen.hpp index 389a6c6f7..ed8877b71 100644 --- a/setups/srpic/monopole/pgen.hpp +++ b/legacy/_monopole/pgen.hpp @@ -86,7 +86,7 @@ namespace user { inline PGen() {} - auto FieldDriver(real_t time) const -> DriveFields { + auto AtmFields(real_t time) const -> DriveFields { return DriveFields { time, Bsurf, Rstar, Omega }; } }; diff --git a/legacy/benchmark.cpp b/legacy/benchmark.cpp new file mode 100644 index 000000000..54fc17cf9 --- /dev/null +++ b/legacy/benchmark.cpp @@ -0,0 +1,273 @@ +#include "enums.h" +#include "global.h" + +#include "utils/error.h" + +#include "metrics/metric_base.h" +#include "metrics/minkowski.h" + +#include "framework/containers/species.h" +#include "framework/domain/domain.h" +#include 
"framework/domain/metadomain.h" + +#include + +#include "framework/domain/communications.cpp" +#include "mpi.h" +#include "mpi-ext.h" + +#define TIMER_START(label) \ + Kokkos::fence(); \ + auto start_##label = std::chrono::high_resolution_clock::now(); + +#define TIMER_STOP(label) \ + Kokkos::fence(); \ + auto stop_##label = std::chrono::high_resolution_clock::now(); \ + auto duration_##label = std::chrono::duration_cast( \ + stop_##label - start_##label) \ + .count(); \ + std::cout << "Timer [" #label "]: " << duration_##label << " microseconds" \ + << std::endl; + +/* + Test to check the performance of the new particle allocation scheme + - Create a metadomain object main() + - Set npart + initialize tags InitializeParticleArrays() + - 'Push' the particles by randomly updating the tags PushParticles() + - Communicate particles to neighbors and time the communication + - Compute the time taken for best of N iterations for the communication + */ +using namespace ntt; + +// Set npart and set the particle tags to alive +template +void InitializeParticleArrays(Domain& domain, const int npart) { + raise::ErrorIf(npart > domain.species[0].maxnpart(), + "Npart cannot be greater than maxnpart", + HERE); + const auto nspecies = domain.species.size(); + for (int i_spec = 0; i_spec < nspecies; i_spec++) { + domain.species[i_spec].set_npart(npart); + domain.species[i_spec].SyncHostDevice(); + auto& this_tag = domain.species[i_spec].tag; + Kokkos::parallel_for( + "Initialize particles", + npart, + Lambda(const std::size_t i) { this_tag(i) = ParticleTag::alive; }); + } + return; +} + +// Randomly reassign tags to particles for a fraction of particles +template +void PushParticles(Domain& domain, + const double send_frac, + const int seed_ind, + const int seed_tag) { + raise::ErrorIf(send_frac > 1.0, "send_frac cannot be greater than 1.0", HERE); + const auto nspecies = domain.species.size(); + for (int i_spec = 0; i_spec < nspecies; i_spec++) { + 
domain.species[i_spec].set_unsorted(); + const auto nparticles = domain.species[i_spec].npart(); + const auto nparticles_to_send = static_cast(send_frac * nparticles); + // Generate random indices to send + // Kokkos::Random_XorShift64_Pool<> random_pool(seed_ind); + Kokkos::View indices_to_send("indices_to_send", nparticles_to_send); + Kokkos::fill_random(indices_to_send, domain.random_pool, 0, nparticles); + // Generate random tags to send + // Kokkos::Random_XorShift64_Pool<> random_pool_tag(seed_tag); + Kokkos::View tags_to_send("tags_to_send", nparticles_to_send); + Kokkos::fill_random(tags_to_send, + domain.random_pool, + 0, + domain.species[i_spec].ntags()); + auto& this_tag = domain.species[i_spec].tag; + Kokkos::parallel_for( + "Push particles", + nparticles_to_send, + Lambda(const std::size_t i) { + auto prtl_to_send = indices_to_send(i); + auto tag_to_send = tags_to_send(i); + this_tag(prtl_to_send) = tag_to_send; + }); + domain.species[i_spec].npart_per_tag(); + domain.species[i_spec].SyncHostDevice(); + } + return; +} + +auto main(int argc, char* argv[]) -> int { + GlobalInitialize(argc, argv); + { + /* + MPI checks + */ + printf("Compile time check:\n"); +#if defined(MPIX_CUDA_AWARE_SUPPORT) && MPIX_CUDA_AWARE_SUPPORT + printf("This MPI library has CUDA-aware support.\n", MPIX_CUDA_AWARE_SUPPORT); +#elif defined(MPIX_CUDA_AWARE_SUPPORT) && !MPIX_CUDA_AWARE_SUPPORT + printf("This MPI library does not have CUDA-aware support.\n"); +#else + printf("This MPI library cannot determine if there is CUDA-aware support.\n"); +#endif /* MPIX_CUDA_AWARE_SUPPORT */ +printf("Run time check:\n"); +#if defined(MPIX_CUDA_AWARE_SUPPORT) + if (1 == MPIX_Query_cuda_support()) { + printf("This MPI library has CUDA-aware support.\n"); + } else { + printf("This MPI library does not have CUDA-aware support.\n"); + } +#else /* !defined(MPIX_CUDA_AWARE_SUPPORT) */ + printf("This MPI library cannot determine if there is CUDA-aware support.\n"); +#endif /* 
MPIX_CUDA_AWARE_SUPPORT */ + + /* + Test to send and receive Kokkos arrays + */ + int sender_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &sender_rank); + + int neighbor_rank = 0; + if (sender_rank == 0) { + neighbor_rank = 1; + } + else if (sender_rank == 1) { + neighbor_rank = 0; + } + else { + raise::Error("This test is only for 2 ranks", HERE); + } + Kokkos::View send_array("send_array", 10); + Kokkos::View recv_array("recv_array", 10); + if (sender_rank == 0) { + Kokkos::deep_copy(send_array, 10); + } + else { + Kokkos::deep_copy(send_array, 20); + } + + auto send_array_host = Kokkos::create_mirror_view(send_array); + Kokkos::deep_copy(send_array_host, send_array); + auto host_recv_array = Kokkos::create_mirror_view(recv_array); + + MPI_Sendrecv(send_array.data(), send_array.extent(0), MPI_INT, neighbor_rank, 0, + recv_array.data(), recv_array.extent(0), MPI_INT, neighbor_rank, 0, + MPI_COMM_WORLD, MPI_STATUS_IGNORE); + + // Print the received array + Kokkos::deep_copy(host_recv_array, recv_array); + for (int i = 0; i < 10; ++i) { + printf("Rank %d: Received %d\n", sender_rank, host_recv_array(i)); + } + + + std::cout << "Constructing the domain" << std::endl; + // Create a Metadomain object + const unsigned int ndomains = 2; + const std::vector global_decomposition = { + {-1, -1, -1} + }; + const std::vector global_ncells = { 32, 32, 32 }; + const boundaries_t global_extent = { + {0.0, 3.0}, + {0.0, 3.0}, + {0.0, 3.0} + }; + const boundaries_t global_flds_bc = { + {FldsBC::PERIODIC, FldsBC::PERIODIC}, + {FldsBC::PERIODIC, FldsBC::PERIODIC}, + {FldsBC::PERIODIC, FldsBC::PERIODIC} + }; + const boundaries_t global_prtl_bc = { + {PrtlBC::PERIODIC, PrtlBC::PERIODIC}, + {PrtlBC::PERIODIC, PrtlBC::PERIODIC}, + {PrtlBC::PERIODIC, PrtlBC::PERIODIC} + }; + const std::map metric_params = {}; + const int maxnpart = argc > 1 ? 
std::stoi(argv[1]) : 1000; + const double npart_to_send_frac = 0.01; + const int npart = static_cast(maxnpart * (1 - 2 * npart_to_send_frac)); + auto species = ntt::ParticleSpecies(1u, + "test_e", + 1.0f, + 1.0f, + maxnpart, + ntt::PrtlPusher::BORIS, + false, + ntt::Cooling::NONE); + auto metadomain = Metadomain>( + ndomains, + global_decomposition, + global_ncells, + global_extent, + global_flds_bc, + global_prtl_bc, + metric_params, + { species }); + + const auto local_subdomain_idx = metadomain.l_subdomain_indices()[0]; + auto local_domain = metadomain.subdomain_ptr(local_subdomain_idx); + auto timers = timer::Timers { { "Communication" }, nullptr, false }; + InitializeParticleArrays(*local_domain, npart); + // Timers for both the communication routines + auto total_time_elapsed_old = 0; + auto total_time_elapsed_new = 0; + + int seed_ind = 0; + int seed_tag = 1; + Kokkos::fence(); + + for (int i = 0; i < 10; ++i) { + { + // Push + seed_ind += 2; + seed_tag += 3; + PushParticles(*local_domain, npart_to_send_frac, seed_ind, seed_tag); + // Sort new + Kokkos::fence(); + auto start_new = std::chrono::high_resolution_clock::now(); + metadomain.CommunicateParticlesBuffer(*local_domain, &timers); + auto stop_new = std::chrono::high_resolution_clock::now(); + auto duration_new = std::chrono::duration_cast( + stop_new - start_new) + .count(); + total_time_elapsed_new += duration_new; + Kokkos::fence(); + } + { + // Push + seed_ind += 2; + seed_tag += 3; + PushParticles(*local_domain, npart_to_send_frac, seed_ind, seed_tag); + // Sort old + Kokkos::fence(); + auto start_old = std::chrono::high_resolution_clock::now(); + metadomain.CommunicateParticles(*local_domain, &timers); + auto stop_old = std::chrono::high_resolution_clock::now(); + auto duration_old = std::chrono::duration_cast( + stop_old - start_old) + .count(); + total_time_elapsed_old += duration_old; + Kokkos::fence(); + } + } + printf("Total time elapsed for old: %f us : %f us/prtl\n", + 
total_time_elapsed_old / 10.0, + total_time_elapsed_old / 10.0 / npart); + printf("Total time elapsed for new: %f us : %f us/prtl\n", + total_time_elapsed_new / 10.0, + total_time_elapsed_new / 10.0 / npart); + } + GlobalFinalize(); + return 0; +} + +/* + Buggy behavior: + Consider a single domain with a single mpi rank + The particle tag array is set to [0, 0, 1, 1, 2, 3, ...] for a single domain + CommunicateParticles() discounts all the dead particles and reassigns the + other tags to alive + CommunicateParticlesBuffer() only keeps the ParticleTag::Alive particles + and discounts the rest +*/ diff --git a/legacy/src/framework/utils/particle_injectors.hpp b/legacy/src/framework/utils/particle_injectors.hpp index 21e3dad72..c275f170a 100644 --- a/legacy/src/framework/utils/particle_injectors.hpp +++ b/legacy/src/framework/utils/particle_injectors.hpp @@ -165,7 +165,7 @@ namespace ntt { * @brief Volumetrically uniform particle injector parallelized over particles. * @tparam D dimension. * @tparam S simulation engine. - * @tparam EnDist energy distribution [default = ColdDist]. + * @tparam EnDist energy distribution [default = Cold]. * * @param params simulation parameters. * @param mblock meshblock. @@ -174,7 +174,7 @@ namespace ntt { * @param region region to inject particles as a list of coordinates [optional]. * @param time current time [optional]. */ - template class EnDist = ColdDist> + template class EnDist = Cold> inline void InjectUniform(const SimulationParams& params, Meshblock& mblock, const std::vector& species, @@ -613,8 +613,8 @@ namespace ntt { * @brief Particle injector parallelized by cells in a volume. * @tparam D dimension. * @tparam S simulation engine. - * @tparam EnDist energy distribution [default = ColdDist]. - * @tparam SpDist spatial distribution [default = UniformDist]. + * @tparam EnDist energy distribution [default = Cold]. + * @tparam SpDist spatial distribution [default = Uniform].
* @tparam InjCrit injection criterion [default = NoCriterion]. * * @param params simulation parameters. @@ -626,8 +626,8 @@ namespace ntt { */ template class EnDist = ColdDist, - template class SpDist = UniformDist, + template class EnDist = Cold, + template class SpDist = Uniform, template class InjCrit = NoCriterion> inline void InjectInVolume(const SimulationParams& params, Meshblock& mblock, @@ -928,7 +928,7 @@ namespace ntt { * @brief ... up to certain number density. * @tparam D dimension. * @tparam S simulation engine. - * @tparam EnDist energy distribution [default = ColdDist]. + * @tparam EnDist energy distribution [default = Cold]. * @tparam InjCrit injection criterion [default = NoCriterion]. * * @param params simulation parameters. @@ -940,7 +940,7 @@ namespace ntt { */ template class EnDist = ColdDist, + template class EnDist = Cold, template class InjCrit = NoCriterion> inline void InjectNonUniform(const SimulationParams& params, Meshblock& mblock, diff --git a/legacy/src/pic/particles/particle_pusher.hpp b/legacy/src/pic/particles/particle_pusher.hpp index 4c0ec639a..7991a95a4 100644 --- a/legacy/src/pic/particles/particle_pusher.hpp +++ b/legacy/src/pic/particles/particle_pusher.hpp @@ -1,14 +1,13 @@ #ifndef PIC_PARTICLE_PUSHER_H #define PIC_PARTICLE_PUSHER_H -#include "wrapper.h" - -#include "pic.h" +#include "utils/qmath.h" #include "io/output.h" #include "meshblock/meshblock.h" #include "meshblock/particles.h" -#include "utils/qmath.h" +#include "pic.h" +#include "wrapper.h" #include METRIC_HEADER #include @@ -73,35 +72,34 @@ namespace ntt { real_t time, real_t coeff, real_t dt, - ProblemGenerator& pgen) : - EB { mblock.em }, - i1 { particles.i1 }, - i2 { particles.i2 }, - i3 { particles.i3 }, - i1_prev { particles.i1_prev }, - i2_prev { particles.i2_prev }, - i3_prev { particles.i3_prev }, - dx1 { particles.dx1 }, - dx2 { particles.dx2 }, - dx3 { particles.dx3 }, - dx1_prev { particles.dx1_prev }, - dx2_prev { particles.dx2_prev }, - dx3_prev { 
particles.dx3_prev }, - ux1 { particles.ux1 }, - ux2 { particles.ux2 }, - ux3 { particles.ux3 }, - phi { particles.phi }, - tag { particles.tag }, - metric { mblock.metric }, - time { time }, - coeff { coeff }, - dt { dt }, - ni1 { (int)mblock.Ni1() }, - ni2 { (int)mblock.Ni2() }, - ni3 { (int)mblock.Ni3() } + ProblemGenerator& pgen) + : EB { mblock.em } + , i1 { particles.i1 } + , i2 { particles.i2 } + , i3 { particles.i3 } + , i1_prev { particles.i1_prev } + , i2_prev { particles.i2_prev } + , i3_prev { particles.i3_prev } + , dx1 { particles.dx1 } + , dx2 { particles.dx2 } + , dx3 { particles.dx3 } + , dx1_prev { particles.dx1_prev } + , dx2_prev { particles.dx2_prev } + , dx3_prev { particles.dx3_prev } + , ux1 { particles.ux1 } + , ux2 { particles.ux2 } + , ux3 { particles.ux3 } + , phi { particles.phi } + , tag { particles.tag } + , metric { mblock.metric } + , time { time } + , coeff { coeff } + , dt { dt } + , ni1 { (int)mblock.Ni1() } + , ni2 { (int)mblock.Ni2() } + , ni3 { (int)mblock.Ni3() } #ifdef EXTERNAL_FORCE - , - pgen { pgen } + , pgen { pgen } #endif { (void)pgen; @@ -237,7 +235,7 @@ namespace ntt { const auto coeff = charge_ovr_mass * HALF * dt * params.B0(); Kokkos::parallel_for( "ParticlesPush", - Kokkos::RangePolicy(0, particles.npart()), + Kokkos::RangePolicy(0, particles.npart()), Pusher_kernel(mblock, particles, time, coeff, dt, pgen)); } @@ -638,9 +636,9 @@ namespace ntt { template template Inline void Pusher_kernel::get3VelCntrv(T, - index_t& p, + index_t& p, vec_t& xp, - vec_t& v) const { + vec_t& v) const { metric.v3_Cart2Cntrv(xp, { ux1(p), ux2(p), ux3(p) }, v); auto inv_energy { ONE / getEnergy(T {}, p) }; v[0] *= inv_energy; @@ -666,7 +664,8 @@ namespace ntt { } template <> - Inline void Pusher_kernel::getPrtlPos(index_t& p, coord_t& xp) const { + Inline void Pusher_kernel::getPrtlPos(index_t& p, + coord_t& xp) const { xp[0] = static_cast(i1(p)) + static_cast(dx1(p)); xp[1] = static_cast(i2(p)) + static_cast(dx2(p)); xp[2] = phi(p); 
@@ -1066,7 +1065,7 @@ namespace ntt { #else template Inline void Pusher_kernel::initForce(coord_t& xp, - vec_t& force_Cart) const { + vec_t& force_Cart) const { coord_t xp_Ph { ZERO }; coord_t xp_Code { ZERO }; for (short d { 0 }; d < static_cast(PrtlCoordD); ++d) { diff --git a/legacy/tests/kernels-gr.cpp b/legacy/tests/kernels-gr.cpp index 84a0c952b..6962f7c9f 100644 --- a/legacy/tests/kernels-gr.cpp +++ b/legacy/tests/kernels-gr.cpp @@ -1,16 +1,16 @@ -#include "wrapper.h" - #include #include #include -#include METRIC_HEADER +#include "wrapper.h" -#include "particle_macros.h" +#include METRIC_HEADER #include "kernels/particle_pusher_gr.hpp" +#include "particle_macros.h" + template void put_value(ntt::array_t& arr, T value, int i) { auto arr_h = Kokkos::create_mirror_view(arr); @@ -154,9 +154,10 @@ auto main(int argc, char* argv[]) -> int { static_cast(1.0e-5), 10, boundaries); - Kokkos::parallel_for("ParticlesPush", - Kokkos::RangePolicy(0, 1), - kernel); + Kokkos::parallel_for( + "ParticlesPush", + Kokkos::RangePolicy(0, 1), + kernel); auto [ra, tha] = get_physical_coord(0, i1, i2, dx1, dx2, metric); const real_t pha = get_value(phi, 0); @@ -207,4 +208,4 @@ auto main(int argc, char* argv[]) -> int { ntt::GlobalFinalize(); return 0; -} \ No newline at end of file +} diff --git a/legacy/tests/kernels-sr.cpp b/legacy/tests/kernels-sr.cpp index 59ce0646b..3f64122cd 100644 --- a/legacy/tests/kernels-sr.cpp +++ b/legacy/tests/kernels-sr.cpp @@ -1,17 +1,17 @@ -#include "wrapper.h" - #include #include #include +#include "wrapper.h" + #include METRIC_HEADER #include PGEN_HEADER -#include "particle_macros.h" - #include "kernels/particle_pusher_sr.hpp" +#include "particle_macros.h" + template void put_value(ntt::array_t& arr, T value, int i) { auto arr_h = Kokkos::create_mirror_view(arr); @@ -181,9 +181,10 @@ auto main(int argc, char* argv[]) -> int { ZERO, ZERO, ZERO); - Kokkos::parallel_for("ParticlesPush", - Kokkos::RangePolicy(0, 1), - kernel); + Kokkos::parallel_for( 
+ "ParticlesPush", + Kokkos::RangePolicy(0, 1), + kernel); auto [xa, ya] = get_cartesian_coord(0, i1, i2, dx1, dx2, phi, metric); if (!ntt::AlmostEqual(xa, @@ -221,4 +222,4 @@ auto main(int argc, char* argv[]) -> int { ntt::GlobalFinalize(); return 0; -} \ No newline at end of file +} diff --git a/minimal/CMakeLists.txt b/minimal/CMakeLists.txt new file mode 100644 index 000000000..b21dd0fec --- /dev/null +++ b/minimal/CMakeLists.txt @@ -0,0 +1,180 @@ +# cmake-lint: disable=C0103,C0111,E1120,R0913,R0915 +cmake_minimum_required(VERSION 3.16) +cmake_policy(SET CMP0110 NEW) + +set(PROJECT_NAME minimal-test) + +project(${PROJECT_NAME} LANGUAGES CXX C) +set(CMAKE_CXX_EXTENSIONS OFF) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(NOT DEBUG) + set(CMAKE_BUILD_TYPE + Release + CACHE STRING "CMake build type") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNDEBUG") +else() + set(CMAKE_BUILD_TYPE + Debug + CACHE STRING "CMake build type") + set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS} -DDEBUG -Wall -Wextra -Wno-unknown-pragmas") +endif() + +set(BUILD_TESTING + OFF + CACHE BOOL "Build tests") + +set(MODES + "KOKKOS;ADIOS2_NOMPI" + CACHE STRING "Build modes") + +function(find_kokkos) + find_package(Kokkos QUIET) + if(NOT Kokkos_FOUND) + include(FetchContent) + FetchContent_Declare( + Kokkos + GIT_REPOSITORY https://github.com/kokkos/kokkos.git + GIT_TAG 4.6.01) + FetchContent_MakeAvailable(Kokkos) + endif() + if(NOT DEFINED Kokkos_ARCH + OR Kokkos_ARCH STREQUAL "" + OR NOT DEFINED Kokkos_DEVICES + OR Kokkos_DEVICES STREQUAL "") + if(${Kokkos_FOUND}) + include(${Kokkos_DIR}/KokkosConfigCommon.cmake) + elseif(NOT ${Kokkos_BUILD_DIR} STREQUAL "") + include(${Kokkos_BUILD_DIR}/KokkosConfigCommon.cmake) + else() + message( + STATUS "${Red}Kokkos_DIR and Kokkos_BUILD_DIR not set.${ColorReset}") + endif() + endif() +endfunction() + +function(find_adios2) + find_package(adios2 QUIET) + if(NOT adios2_FOUND) + include(FetchContent)
+ FetchContent_Declare( + adios2 + GIT_REPOSITORY https://github.com/ornladios/ADIOS2.git + GIT_TAG 2.10.2) + FetchContent_MakeAvailable(adios2) + endif() +endfunction() + +if("KOKKOS" IN_LIST MODES) + set(libs "") + set(exec kokkos.xc) + set(src ${CMAKE_CURRENT_SOURCE_DIR}/kokkos.cpp) + + find_kokkos() + list(APPEND libs Kokkos::kokkos) + + add_executable(${exec} ${src}) + + target_link_libraries(${exec} ${libs}) +endif() + +if("ADIOS2_NOMPI" IN_LIST MODES) + set(libs stdc++fs) + set(exec adios2-nompi.xc) + set(src ${CMAKE_CURRENT_SOURCE_DIR}/adios2.cpp) + + find_kokkos() + find_adios2() + list(APPEND libs Kokkos::kokkos adios2::cxx11) + + add_executable(${exec} ${src}) + + target_link_libraries(${exec} ${libs}) +endif() + +if("ADIOS2_MPI" IN_LIST MODES) + set(libs stdc++fs) + set(exec adios2-mpi.xc) + set(src ${CMAKE_CURRENT_SOURCE_DIR}/adios2.cpp) + + find_package(MPI REQUIRED) + find_kokkos() + find_adios2() + list(APPEND libs MPI::MPI_CXX Kokkos::kokkos adios2::cxx11_mpi) + + add_executable(${exec} ${src}) + + target_include_directories(${exec} PUBLIC ${MPI_CXX_INCLUDE_PATH}) + target_compile_options(${exec} PUBLIC "-D MPI_ENABLED") + target_link_libraries(${exec} ${libs}) +endif() + +if("MPI" IN_LIST MODES) + set(libs "") + set(exec mpi-simple.xc) + set(src ${CMAKE_CURRENT_SOURCE_DIR}/mpi-simple.cpp) + + find_package(MPI REQUIRED) + find_kokkos() + list(APPEND libs MPI::MPI_CXX Kokkos::kokkos) + + add_executable(${exec} ${src}) + + target_include_directories(${exec} PUBLIC ${MPI_CXX_INCLUDE_PATH}) + target_link_libraries(${exec} ${libs}) + + set(GPU_AWARE_MPI + ON + CACHE BOOL "Enable GPU-aware MPI support") + + if(("${Kokkos_DEVICES}" MATCHES "CUDA") + OR ("${Kokkos_DEVICES}" MATCHES "HIP") + OR ("${Kokkos_DEVICES}" MATCHES "SYCL")) + set(DEVICE_ENABLED ON) + target_compile_options(${exec} PRIVATE -DDEVICE_ENABLED) + else() + set(DEVICE_ENABLED OFF) + endif() + + if(${GPU_AWARE_MPI}) + target_compile_options(${exec} PRIVATE -DGPU_AWARE_MPI) + endif() 
+endif() + +if("MPI_SIMPLE" IN_LIST MODES) + set(libs "") + set(exec mpi-simple.xc) + set(src ${CMAKE_CURRENT_SOURCE_DIR}/mpi-simple.cpp) + + find_package(MPI REQUIRED) + find_kokkos() + list(APPEND libs MPI::MPI_CXX Kokkos::kokkos) + + add_executable(${exec} ${src}) + + target_include_directories(${exec} PUBLIC ${MPI_CXX_INCLUDE_PATH}) + target_link_libraries(${exec} ${libs}) + + set(GPU_AWARE_MPI + ON + CACHE BOOL "Enable GPU-aware MPI support") + + if(("${Kokkos_DEVICES}" MATCHES "CUDA") + OR ("${Kokkos_DEVICES}" MATCHES "HIP") + OR ("${Kokkos_DEVICES}" MATCHES "SYCL")) + set(DEVICE_ENABLED ON) + target_compile_options(${exec} PRIVATE -DDEVICE_ENABLED) + else() + set(DEVICE_ENABLED OFF) + endif() + + if(${GPU_AWARE_MPI}) + target_compile_options(${exec} PRIVATE -DGPU_AWARE_MPI) + endif() +endif() + +message(STATUS "Build modes: ${MODES}") diff --git a/minimal/README.md b/minimal/README.md new file mode 100644 index 000000000..b7e5691a2 --- /dev/null +++ b/minimal/README.md @@ -0,0 +1,21 @@ +# Minimal third-party tests + +These minimal tests are designed to test the third-party libraries outside of the `Entity` scope. These tests will show whether there is an issue with the way third-party libraries are installed (or the cluster is set up). + +To compile: + +```sh +cmake -B build -D MODES="MPI;MPI_SIMPLE;ADIOS2_NOMPI;ADIOS2_MPI" +cmake --build build -j +``` + +This will produce executables, one for each test, in the `build` directory. + +The `MODES` flag determines the tests it will generate and can be a subset of the following (separated with a `;`): + +- `MPI` test of pure MPI + Kokkos (can also add `-D GPU_AWARE_MPI=OFF` to disable the GPU-aware MPI explicitly); +- `MPI_SIMPLE` a simpler test of pure MPI + Kokkos; +- `ADIOS2_NOMPI` test of ADIOS2 library without MPI; +- `ADIOS2_MPI` same but with MPI. + +All tests also use `Kokkos`. To build `ADIOS2` or `Kokkos` in-tree, you may pass the regular `-D Kokkos_***` and `-D ADIOS2_***` flags to `cmake`.
diff --git a/minimal/adios2.cpp b/minimal/adios2.cpp new file mode 100644 index 000000000..cd4ca3d6f --- /dev/null +++ b/minimal/adios2.cpp @@ -0,0 +1,438 @@ +#include +#include +#include + +#if defined(MPI_ENABLED) + #include + #define MPI_ROOT_RANK 0 +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +auto pad(const std::string&, std::size_t, char, bool = false) -> std::string; + +template +void CallOnce(Func, Args&&...); + +template +auto define_constdim_array(adios2::IO&, + const std::vector&, + const std::vector&, + const std::vector&) -> std::string; + +template +auto define_unknowndim_array(adios2::IO&) -> std::string; + +template +void put_constdim_array(adios2::IO&, adios2::Engine&, const A&, const std::string&); + +template +void put_unknowndim_array(adios2::IO&, + adios2::Engine&, + const Kokkos::View&, + std::size_t, + const std::string&); + +auto main(int argc, char** argv) -> int { + try { + Kokkos::initialize(argc, argv); +#if defined(MPI_ENABLED) + MPI_Init(&argc, &argv); + adios2::ADIOS adios { MPI_COMM_WORLD }; +#else + adios2::ADIOS adios; +#endif + + std::string engine = "hdf5"; + if (argc > 1) { + engine = std::string(argv[1]); + if (engine != "hdf5" && engine != "bp") { + throw std::invalid_argument("Engine must be either 'hdf5' or 'bp'"); + } + } + const std::string format = (engine == "hdf5") ? 
"h5" : "bp"; + + auto io = adios.DeclareIO("Test::Output"); + io.SetEngine(engine); + + io.DefineAttribute("Attr::Int", 42); + io.DefineAttribute("Attr::Float", 42.0f); + io.DefineAttribute("Attr::Double", 42.0); + io.DefineAttribute("Attr::String", engine); + + io.DefineVariable("Var::Int"); + io.DefineVariable("Var::Size_t"); + + int rank = 0, size = 1; +#if defined(MPI_ENABLED) + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); +#endif + + // global sizes + const std::size_t Sx_1d = (size - 1) * 1000 + 230; + const std::size_t Sx_2d = 100, Sy_2d = (size - 1) * 100 + 23; + const std::size_t Sx_3d = 10, Sy_3d = 10, Sz_3d = (size - 1) * 10 + 3; + + // local sizes + const std::size_t sx_1d = (rank != size - 1) ? 1000 : 230; + const std::size_t sx_2d = 100, sy_2d = (rank != size - 1) ? 100 : 23; + const std::size_t sx_3d = 10, sy_3d = 10, sz_3d = (rank != size - 1) ? 10 : 3; + + // displacements + const std::size_t ox_1d = rank * 1000; + const std::size_t ox_2d = 0, oy_2d = rank * 100; + const std::size_t ox_3d = 0, oy_3d = 0, oz_3d = rank * 10; + + CallOnce( + [](auto&& size) { + std::cout << "Running ADIOS2 test" << std::endl; +#if defined(MPI_ENABLED) + std::cout << "- Number of MPI ranks: " << size << std::endl; +#else + (void)size; + std::cout << "- No MPI" << std::endl; +#endif + }, + size); + + std::vector vars; + + { + vars.push_back( + define_constdim_array(io, { Sx_1d }, { ox_1d }, { sx_1d })); + vars.push_back(define_constdim_array(io, + { Sx_2d, Sy_2d }, + { ox_2d, oy_2d }, + { sx_2d, sy_2d })); + vars.push_back(define_constdim_array(io, + { Sx_3d, Sy_3d, Sz_3d }, + { ox_3d, oy_3d, oz_3d }, + { sx_3d, sy_3d, sz_3d })); + vars.push_back( + define_constdim_array(io, { Sx_1d }, { ox_1d }, { sx_1d })); + vars.push_back(define_constdim_array(io, + { Sx_2d, Sy_2d }, + { ox_2d, oy_2d }, + { sx_2d, sy_2d })); + vars.push_back(define_constdim_array(io, + { Sx_3d, Sy_3d, Sz_3d }, + { ox_3d, oy_3d, oz_3d }, + { sx_3d, sy_3d, sz_3d })); + 
} + + { + vars.push_back(define_unknowndim_array(io)); + vars.push_back(define_unknowndim_array(io)); + vars.push_back(define_unknowndim_array(io)); + } + + Kokkos::View constdim_1d_f { "constdim_1d_f", sx_1d }; + Kokkos::View constdim_2d_f { "constdim_2d_f", sx_2d, sy_2d }; + Kokkos::View constdim_3d_f { "constdim_3d_f", sx_3d, sy_3d, sz_3d }; + + Kokkos::View constdim_1d_d { "constdim_1d_d", sx_1d }; + Kokkos::View constdim_2d_d { "constdim_2d_d", sx_2d, sy_2d }; + Kokkos::View constdim_3d_d { "constdim_3d_d", sx_3d, sy_3d, sz_3d }; + + { + // fill 1d + Kokkos::parallel_for( + "fill_constdim_1d_f", + Kokkos::RangePolicy<>(0, sx_1d), + KOKKOS_LAMBDA(std::size_t i) { + constdim_1d_f(i) = static_cast(ox_1d + i); + constdim_1d_d(i) = static_cast(ox_1d + i); + }); + + // fill 2d + Kokkos::parallel_for( + "fill_constdim_2d_f", + Kokkos::MDRangePolicy>({ 0, 0 }, { sx_2d, sy_2d }), + KOKKOS_LAMBDA(std::size_t i, std::size_t j) { + constdim_2d_f(i, j) = static_cast(ox_2d + i + (oy_2d + j) * Sx_2d); + constdim_2d_d(i, j) = static_cast(ox_2d + i + (oy_2d + j) * Sx_2d); + }); + + // fill 3d + Kokkos::parallel_for( + "fill_constdim_3d_f", + Kokkos::MDRangePolicy>({ 0, 0, 0 }, { sx_3d, sy_3d, sz_3d }), + KOKKOS_LAMBDA(std::size_t i, std::size_t j, std::size_t k) { + constdim_3d_f(i, j, k) = static_cast( + ox_3d + i + (oy_3d + j + (oz_3d + k) * Sy_3d) * Sx_3d); + constdim_3d_d(i, j, k) = static_cast( + ox_3d + i + (oy_3d + j + (oz_3d + k) * Sy_3d) * Sx_3d); + }); + } + + { + // test multiple file mode + const std::string path = "steps"; + CallOnce( + [](auto&& path) { + const std::filesystem::path parent_path { path }; + if (std::filesystem::exists(parent_path)) { + std::filesystem::remove_all(parent_path); + } + std::filesystem::create_directory(path); + }, + path); + for (auto step { 0u }; step < 5u; ++step) { + const std::string filename = path + "/step_" + + pad(std::to_string(step * 20u), 6, '0') + + "." 
+ format; + auto writer = io.Open(filename, adios2::Mode::Write); + writer.BeginStep(); + + { + // constant dim arrays + put_constdim_array(io, + writer, + constdim_1d_f, + vars[0]); + put_constdim_array(io, + writer, + constdim_2d_f, + vars[1]); + put_constdim_array(io, + writer, + constdim_3d_f, + vars[2]); + put_constdim_array(io, + writer, + constdim_1d_d, + vars[3]); + put_constdim_array(io, + writer, + constdim_2d_d, + vars[4]); + put_constdim_array(io, + writer, + constdim_3d_d, + vars[5]); + } + + { + // unknown dim arrays + const std::size_t nelems = static_cast( + (std::sin((step + 1 + rank) * 0.25) + 2.0) * 1000.0); + + Kokkos::View unknowndim_f { "unknowndim_f", nelems }; + Kokkos::View unknowndim_d { "unknowndim_d", nelems }; + Kokkos::View unknowndim_i { "unknowndim_i", nelems }; + + // fill unknown dim arrays + Kokkos::parallel_for( + "fill_unknowndim", + Kokkos::RangePolicy<>(0, nelems), + KOKKOS_LAMBDA(std::size_t i) { + unknowndim_f(i) = static_cast(i + step * 1000); + unknowndim_d(i) = static_cast(i + step * 1000); + unknowndim_i(i) = static_cast(i + step * 1000); + }); + + put_unknowndim_array(io, writer, unknowndim_f, nelems, vars[6]); + put_unknowndim_array(io, writer, unknowndim_d, nelems, vars[7]); + put_unknowndim_array(io, writer, unknowndim_i, nelems, vars[8]); + } + + writer.EndStep(); + writer.Close(); + } + } + { + // test single file mode + const std::string filename = "allsteps." 
+ format; + adios2::Mode mode = adios2::Mode::Write; + for (auto step { 0u }; step < 5u; ++step) { + auto writer = io.Open(filename, mode); + writer.BeginStep(); + + { + // constant dim arrays + put_constdim_array(io, + writer, + constdim_1d_f, + vars[0]); + put_constdim_array(io, + writer, + constdim_2d_f, + vars[1]); + put_constdim_array(io, + writer, + constdim_3d_f, + vars[2]); + put_constdim_array(io, + writer, + constdim_1d_d, + vars[3]); + put_constdim_array(io, + writer, + constdim_2d_d, + vars[4]); + put_constdim_array(io, + writer, + constdim_3d_d, + vars[5]); + } + + { + // unknown dim arrays + const std::size_t nelems = static_cast( + (std::sin((step + 1 + rank) * 0.25) + 2.0) * 1000.0); + + Kokkos::View unknowndim_f { "unknowndim_f", nelems }; + Kokkos::View unknowndim_d { "unknowndim_d", nelems }; + Kokkos::View unknowndim_i { "unknowndim_i", nelems }; + + // fill unknown dim arrays + Kokkos::parallel_for( + "fill_unknowndim", + Kokkos::RangePolicy<>(0, nelems), + KOKKOS_LAMBDA(std::size_t i) { + unknowndim_f(i) = static_cast(i + step * 1000); + unknowndim_d(i) = static_cast(i + step * 1000); + unknowndim_i(i) = static_cast(i + step * 1000); + }); + + put_unknowndim_array(io, writer, unknowndim_f, nelems, vars[6]); + put_unknowndim_array(io, writer, unknowndim_d, nelems, vars[7]); + put_unknowndim_array(io, writer, unknowndim_i, nelems, vars[8]); + } + + writer.EndStep(); + writer.Close(); + mode = adios2::Mode::Append; + } + } + } catch (const std::exception& e) { +#if defined(MPI_ENABLED) + if (MPI_COMM_WORLD != MPI_COMM_NULL) { + MPI_Finalize(); + } +#endif + if (Kokkos::is_initialized()) { + Kokkos::finalize(); + } + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } + +#if defined(MPI_ENABLED) + MPI_Finalize(); +#endif + Kokkos::finalize(); + return 0; +} + +auto pad(const std::string& str, std::size_t n, char c, bool right) -> std::string { + if (n <= str.size()) { + return str; + } + if (right) { + return str + std::string(n - 
str.size(), c); + } + return std::string(n - str.size(), c) + str; +} + +#if !defined(MPI_ENABLED) + +template +void CallOnce(Func func, Args&&... args) { + func(std::forward(args)...); +} + +#else + +template +void CallOnce(Func func, Args&&... args) { + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + if (rank == MPI_ROOT_RANK) { + func(std::forward(args)...); + } +} +#endif + +template +auto define_constdim_array(adios2::IO& io, + const std::vector& glob_shape, + const std::vector& loc_corner, + const std::vector& loc_shape) -> std::string { + const std::string arrname = "ConstantDimArr" + + std::to_string(glob_shape.size()) + + "D::" + std::string(typeid(T).name()); + io.DefineVariable(arrname, glob_shape, loc_corner, loc_shape, adios2::ConstantDims); + return arrname; +} + +template +auto define_unknowndim_array(adios2::IO& io) -> std::string { + const std::string arrname = "UnknownDimArr::" + std::string(typeid(T).name()); + io.DefineVariable(arrname, + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + return arrname; +} + +template +void put_constdim_array(adios2::IO& io, + adios2::Engine& writer, + const A& array, + const std::string& varname) { + auto var = io.InquireVariable(varname); + if (!var) { + throw std::runtime_error("Variable not found: " + varname); + } + auto array_h = Kokkos::create_mirror_view(array); + Kokkos::deep_copy(array_h, array); + writer.Put(var, array_h); +} + +template +void put_unknowndim_array(adios2::IO& io, + adios2::Engine& writer, + const Kokkos::View& array, + std::size_t nelems, + const std::string& varname) { + auto var = io.InquireVariable(varname); + if (!var) { + throw std::runtime_error("Variable not found: " + varname); + } + std::size_t glob_nelems = nelems; + std::size_t offset_nelems = 0u; +#if defined(MPI_ENABLED) + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + std::vector all_nelems(size); + MPI_Allgather(&nelems, + 1, + 
MPI_UNSIGNED_LONG, + all_nelems.data(), + 1, + MPI_UNSIGNED_LONG, + MPI_COMM_WORLD); + glob_nelems = 0u; + for (int r = 0; r < size; ++r) { + if (r < rank) { + offset_nelems += all_nelems[r]; + } + glob_nelems += all_nelems[r]; + } +#endif + var.SetShape({ glob_nelems }); + var.SetSelection(adios2::Box({ offset_nelems }, { nelems })); + auto array_h = Kokkos::create_mirror_view(array); + Kokkos::deep_copy(array_h, array); + writer.Put(var, array_h); +} diff --git a/minimal/kokkos.cpp b/minimal/kokkos.cpp new file mode 100644 index 000000000..2be2996a8 --- /dev/null +++ b/minimal/kokkos.cpp @@ -0,0 +1,58 @@ +#include + +#include +#include + +auto main(int argc, char** argv) -> int { + try { + Kokkos::initialize(argc, argv); + Kokkos::DefaultExecutionSpace {}.print_configuration(std::cout); + + std::cout << "1D views" << std::endl; + for (const auto& sz : { 100u, 10000u, 1000000u }) { + Kokkos::View view { "test_view", sz }; + Kokkos::parallel_for( + "fill_1d", + Kokkos::RangePolicy<>(0, sz), + KOKKOS_LAMBDA(std::size_t i) { view(i) = static_cast(i); }); + Kokkos::fence(); + std::cout << "- allocated " << view.size() << std::endl; + } + + std::cout << "2D views" << std::endl; + for (const auto& sz : { 10u, 100u, 1000u }) { + Kokkos::View view { "test_view", sz, 2 * sz }; + Kokkos::parallel_for( + "fill_2d", + Kokkos::MDRangePolicy>({ 0, 0 }, { sz, 2 * sz }), + KOKKOS_LAMBDA(std::size_t i, std::size_t j) { + view(i, j) = static_cast(i * 2 * sz + j); + }); + Kokkos::fence(); + std::cout << "- allocated " << view.size() << std::endl; + } + + std::cout << "3D views" << std::endl; + for (const auto& sz : { 10u, 100u }) { + Kokkos::View view { "test_view", sz, 2 * sz, 3 * sz }; + Kokkos::parallel_for( + "fill_3d", + Kokkos::MDRangePolicy>({ 0, 0, 0 }, { sz, 2 * sz, 3 * sz }), + KOKKOS_LAMBDA(std::size_t i, std::size_t j, std::size_t k) { + view(i, j, k) = static_cast(i * 2 * sz * 3 * sz + j * 3 * sz + k); + }); + Kokkos::fence(); + std::cout << "- allocated " << 
view.size() << std::endl; + } + + } catch (const std::exception& e) { + if (Kokkos::is_initialized()) { + Kokkos::finalize(); + } + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } + + Kokkos::finalize(); + return 0; +} diff --git a/minimal/mpi-simple.cpp b/minimal/mpi-simple.cpp new file mode 100644 index 000000000..4663d73be --- /dev/null +++ b/minimal/mpi-simple.cpp @@ -0,0 +1,84 @@ +#include +#include + +#include +#include +#include + +auto main(int argc, char** argv) -> int { + try { + Kokkos::initialize(argc, argv); + MPI_Init(&argc, &argv); + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + + const auto nelems = 500u; + const auto nsend = 10u; + const auto nrecv = 10u; + + if (rank == 0) { + std::cout << "Running the simple MPI communication test" << std::endl; + std::cout << "- Number of MPI ranks: " << size << std::endl; + std::cout << "- Number elements to send/recv (2D): " << nelems << "x" + << nsend << std::endl; +#if defined(GPU_AWARE_MPI) && defined(DEVICE_ENABLED) + std::cout << "- GPU-aware MPI is enabled" << std::endl; +#else + std::cout << "- GPU-aware MPI is disabled" << std::endl; +#endif + } + + Kokkos::View view("view", nelems, nelems); + Kokkos::View send("send", nsend, nelems); + Kokkos::View recv("recv", nrecv, nelems); + Kokkos::deep_copy( + send, + Kokkos::subview(view, std::make_pair(0u, nsend), Kokkos::ALL)); + +#if defined(GPU_AWARE_MPI) || !defined(DEVICE_ENABLED) + MPI_Sendrecv(send.data(), + nsend * nelems, + MPI_FLOAT, + (rank + 1) % size, + 0, + recv.data(), + nrecv * nelems, + MPI_FLOAT, + (rank - 1 + size) % size, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); +#else + auto send_h = Kokkos::create_mirror_view(send); + auto recv_h = Kokkos::create_mirror_view(recv); + Kokkos::deep_copy(send_h, send); + MPI_Sendrecv(send_h.data(), + nsend * nelems, + MPI_FLOAT, + (rank + 1) % size, + 0, + recv_h.data(), + nrecv * nelems, + MPI_FLOAT, + (rank - 1 + size) % size, + 0, + 
MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + Kokkos::deep_copy(recv, recv_h); +#endif + } catch (const std::exception& e) { + if (MPI_COMM_WORLD != MPI_COMM_NULL) { + MPI_Finalize(); + } + if (Kokkos::is_initialized()) { + Kokkos::finalize(); + } + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } + + MPI_Finalize(); + Kokkos::finalize(); + return 0; +} diff --git a/minimal/mpi.cpp b/minimal/mpi.cpp new file mode 100644 index 000000000..821b5fcae --- /dev/null +++ b/minimal/mpi.cpp @@ -0,0 +1,306 @@ +#include +#include + +#include +#include +#include +#include +#include + +#define MPI_ROOT_RANK 0 +#define N_GHOSTS 2 + +template +void CallOnce(Func, Args&&...); + +template +using R = std::conditional_t< + D == 1, + T*, + std::conditional_t>>; + +template +void send_recv(int send_to, + int recv_from, + bool sendxmin, + const Kokkos::View[N]>& view, + std::size_t smallsize) { + const auto mpi_type = std::is_same_v ? MPI_FLOAT : MPI_DOUBLE; + std::size_t nsend = 0; + Kokkos::View[N]> send_buffer; + if (send_to == MPI_PROC_NULL) { + nsend = 0; + } else { + std::pair range = { 0, N_GHOSTS }; + if (not sendxmin) { + range = { view.extent(0) - N_GHOSTS, view.extent(0) }; + } + if constexpr (D == 1) { + nsend = N_GHOSTS * N; + send_buffer = Kokkos::View[N]> { + "comm_1d_send_buffer", N_GHOSTS + }; + Kokkos::deep_copy(send_buffer, Kokkos::subview(view, range, Kokkos::ALL)); + } else if constexpr (D == 2) { + nsend = N_GHOSTS * smallsize * N; + send_buffer = Kokkos::View[N]> { + "comm_2d_send_buffer", N_GHOSTS, smallsize + }; + Kokkos::deep_copy(send_buffer, + Kokkos::subview(view, range, Kokkos::ALL, Kokkos::ALL)); + } else if constexpr (D == 3) { + nsend = N_GHOSTS * smallsize * smallsize * N; + send_buffer = Kokkos::View[N]> { + "comm_3d_send_buffer", N_GHOSTS, smallsize, smallsize + }; + Kokkos::deep_copy( + send_buffer, + Kokkos::subview(view, range, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL)); + } + } + + std::size_t nrecv = 0; + Kokkos::View[N]> recv_buffer; + if 
(recv_from == MPI_PROC_NULL) { + nrecv = 0; + } else { + if constexpr (D == 1) { + nrecv = N_GHOSTS * N; + recv_buffer = Kokkos::View[N]> { + "comm_1d_recv_buffer", N_GHOSTS + }; + } else if constexpr (D == 2) { + nrecv = N_GHOSTS * smallsize * N; + recv_buffer = Kokkos::View[N]> { + "comm_2d_recv_buffer", N_GHOSTS, smallsize + }; + } else if constexpr (D == 3) { + nrecv = N_GHOSTS * smallsize * smallsize * N; + recv_buffer = Kokkos::View[N]> { + "comm_3d_recv_buffer", N_GHOSTS, smallsize, smallsize + }; + } + } + + if (nrecv == 0 and nsend == 0) { + throw std::invalid_argument( + "Both nsend and nrecv are zero, no communication to perform."); + } else if (nrecv > 0 and nsend > 0) { +#if defined(GPU_AWARE_MPI) || !defined(DEVICE_ENABLED) + MPI_Sendrecv(send_buffer.data(), + nsend, + mpi_type, + send_to, + 0, + recv_buffer.data(), + nrecv, + mpi_type, + recv_from, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); +#else + auto send_buffer_h = Kokkos::create_mirror_view(send_buffer); + auto recv_buffer_h = Kokkos::create_mirror_view(recv_buffer); + Kokkos::deep_copy(send_buffer_h, send_buffer); + MPI_Sendrecv(send_buffer_h.data(), + nsend, + mpi_type, + send_to, + 0, + recv_buffer_h.data(), + nrecv, + mpi_type, + recv_from, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + Kokkos::deep_copy(recv_buffer, recv_buffer_h); +#endif + } else if (nrecv > 0) { +#if defined(GPU_AWARE_MPI) || !defined(DEVICE_ENABLED) + MPI_Recv(recv_buffer.data(), + nrecv, + mpi_type, + recv_from, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); +#else + auto recv_buffer_h = Kokkos::create_mirror_view(recv_buffer); + MPI_Recv(recv_buffer_h.data(), + nrecv, + mpi_type, + recv_from, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + Kokkos::deep_copy(recv_buffer, recv_buffer_h); +#endif + } else if (nsend > 0) { +#if defined(GPU_AWARE_MPI) || !defined(DEVICE_ENABLED) + MPI_Send(send_buffer.data(), nsend, mpi_type, send_to, 0, MPI_COMM_WORLD); +#else + auto send_buffer_h = 
Kokkos::create_mirror_view(send_buffer); + Kokkos::deep_copy(send_buffer_h, send_buffer); + MPI_Send(send_buffer_h.data(), nsend, mpi_type, send_to, 0, MPI_COMM_WORLD); +#endif + } + + if (nrecv > 0) { + std::pair range = { view.extent(0) - N_GHOSTS, + view.extent(0) }; + if (not sendxmin) { + range = { 0, N_GHOSTS }; + } + if constexpr (D == 1) { + Kokkos::deep_copy(Kokkos::subview(view, range, Kokkos::ALL), recv_buffer); + } else if constexpr (D == 2) { + Kokkos::deep_copy(Kokkos::subview(view, range, Kokkos::ALL, Kokkos::ALL), + recv_buffer); + } else if constexpr (D == 3) { + Kokkos::deep_copy( + Kokkos::subview(view, range, Kokkos::ALL, Kokkos::ALL, Kokkos::ALL), + recv_buffer); + } + } +} + +template +void comm(int rank, int size, std::size_t bigsize, std::size_t smallsize) { + static_assert(D <= 3 and D != 0, "Only dimensions 1, 2, and 3 are supported."); + static_assert(N == 3 or N == 6, "Only 3 or 6 last indices are supported."); + static_assert(std::is_same_v || std::is_same_v, + "Only float and double types are supported."); + + // smallsize must be the same for all ranks + if (bigsize < 2 * N_GHOSTS) { + throw std::invalid_argument( + "bigsize must be at least 2 * N_GHOSTS for communication to work."); + } + + Kokkos::View[N]> view; + + // define and fill the view + if constexpr (D == 1) { + view = Kokkos::View[N]> { + "comm_1d_view", bigsize + }; + Kokkos::parallel_for( + "fill_comm_1d_view", + Kokkos::MDRangePolicy>({ 0, 0 }, + { view.extent(0), view.extent(1) }), + KOKKOS_LAMBDA(std::size_t i, std::size_t c) { + view(i, c) = static_cast(i * c + rank); + }); + } else if constexpr (D == 2) { + view = Kokkos::View[N]> { + "comm_2d_view", bigsize, smallsize + }; + Kokkos::parallel_for( + "fill_comm_2d_view", + Kokkos::MDRangePolicy>( + { 0, 0, 0 }, + { view.extent(0), view.extent(1), view.extent(2) }), + KOKKOS_LAMBDA(std::size_t i, std::size_t j, std::size_t c) { + view(i, j, c) = static_cast(i * j * c + rank); + }); + } else if constexpr (D == 3) { + 
view = Kokkos::View[N]> { + "comm_3d_view", bigsize, smallsize, smallsize + }; + Kokkos::parallel_for( + "fill_comm_3d_view", + Kokkos::MDRangePolicy>( + { 0, 0, 0, 0 }, + { view.extent(0), view.extent(1), view.extent(2), view.extent(3) }), + KOKKOS_LAMBDA(std::size_t i, std::size_t j, std::size_t k, std::size_t c) { + view(i, j, k, c) = static_cast(i * j * k * c + rank); + }); + } + + // communicate + const int r_neighbor = (rank != size - 1) ? rank + 1 : MPI_PROC_NULL; + const int l_neighbor = (rank != 0) ? rank - 1 : MPI_PROC_NULL; + + send_recv(r_neighbor, l_neighbor, false, view, smallsize); + send_recv(l_neighbor, r_neighbor, true, view, smallsize); + + MPI_Barrier(MPI_COMM_WORLD); + CallOnce([]() { + std::cout << "Finished " << D << "D "; + if constexpr (std::is_same_v) { + std::cout << "float"; + } else { + std::cout << "double"; + } + std::cout << " communication test" << std::endl; + }); +} + +auto main(int argc, char** argv) -> int { + try { + Kokkos::initialize(argc, argv); + MPI_Init(&argc, &argv); + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + + const std::size_t bigsize = (std::sin((rank + 1) * 0.25) + 2) * 1e3; + const std::size_t smallsize = 123; + + CallOnce( + [](auto&& size, auto&& bigsize, auto&& smallsize) { + std::cout << "Running the MPI communication test" << std::endl; + std::cout << "- Number of MPI ranks: " << size << std::endl; + std::cout << "- Big size: " << bigsize << std::endl; + std::cout << "- Small size: " << smallsize << std::endl; +#if defined(GPU_AWARE_MPI) && defined(DEVICE_ENABLED) + std::cout << "- GPU-aware MPI is enabled" << std::endl; +#else + std::cout << "- GPU-aware MPI is disabled" << std::endl; +#endif + }, + size, + bigsize, + smallsize); + + comm(rank, size, bigsize, smallsize); + comm(rank, size, bigsize, smallsize); + comm(rank, size, bigsize, smallsize); + + comm(rank, size, bigsize, smallsize); + comm(rank, size, bigsize, smallsize); + comm(rank, size, 
bigsize, smallsize); + + comm(rank, size, bigsize, smallsize); + comm(rank, size, bigsize, smallsize); + comm(rank, size, bigsize, smallsize); + + comm(rank, size, bigsize, smallsize); + comm(rank, size, bigsize, smallsize); + comm(rank, size, bigsize, smallsize); + } catch (const std::exception& e) { + if (MPI_COMM_WORLD != MPI_COMM_NULL) { + MPI_Finalize(); + } + if (Kokkos::is_initialized()) { + Kokkos::finalize(); + } + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } + + MPI_Finalize(); + Kokkos::finalize(); + return 0; +} + +template +void CallOnce(Func func, Args&&... args) { + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + if (rank == MPI_ROOT_RANK) { + func(std::forward(args)...); + } +} diff --git a/pgens/CMakeLists.txt b/pgens/CMakeLists.txt new file mode 100644 index 000000000..e3d047a98 --- /dev/null +++ b/pgens/CMakeLists.txt @@ -0,0 +1,14 @@ +# ------------------------------ +# @defines: ntt_pgen [INTERFACE] +# +# @includes: +# +# * ../src/ +# ------------------------------ + +add_library(ntt_pgen INTERFACE) +target_link_libraries(ntt_pgen INTERFACE ntt_global ntt_framework + ntt_archetypes ntt_kernels) + +target_include_directories(ntt_pgen + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/${PGEN}) diff --git a/pgens/accretion/accretion.toml b/pgens/accretion/accretion.toml new file mode 100644 index 000000000..1ec641430 --- /dev/null +++ b/pgens/accretion/accretion.toml @@ -0,0 +1,85 @@ +[simulation] + name = "wald" + engine = "grpic" + runtime = 500.0 + +[grid] + resolution = [256, 256] + extent = [[1.0, 6.0]] + + [grid.metric] + metric = "qkerr_schild" + qsph_r0 = 0.0 + qsph_h = 0.0 + ks_a = 0.95 + + [grid.boundaries] + fields = [["MATCH"]] + particles = [["ABSORB"]] + + [grid.boundaries.absorb] + ds = 1.0 + +[scales] + larmor0 = 0.025 + skindepth0 = 0.5 + +[algorithms] + current_filters = 4 + + [algorithms.gr] + pusher_niter = 10 + pusher_eps = 1e-2 + + [algorithms.timestep] + CFL = 0.5 + correction = 1.0 + + [algorithms.toggles] + 
deposit = true + fieldsolver = true + +[particles] + ppc0 = 4.0 + use_weights = true + clear_interval = 100 + + [[particles.species]] + label = "e-" + mass = 1.0 + charge = -1.0 + maxnpart = 2e8 + pusher = "Boris" + + [[particles.species]] + label = "e+" + mass = 1.0 + charge = 1.0 + maxnpart = 2e8 + pusher = "Boris" + +[setup] + multiplicity = 1.0 + sigma_max = 1000.0 + temperature = 0.01 + xi_min = [1.5, 0.0] + xi_max = [4.0, 3.14159265] + m_eps = 1.0 + +[output] + format = "hdf5" + + [output.fields] + interval_time = 1.0 + quantities = ["D", "B", "N_1", "N_2", "A"] + + [output.particles] + enable = false + + [output.spectra] + enable = false + +[diagnostics] + interval = 2 + colored_stdout = true + blocking_timers = true diff --git a/pgens/accretion/pgen.hpp b/pgens/accretion/pgen.hpp new file mode 100644 index 000000000..54a607352 --- /dev/null +++ b/pgens/accretion/pgen.hpp @@ -0,0 +1,283 @@ +#ifndef PROBLEM_GENERATOR_H +#define PROBLEM_GENERATOR_H + +#include "enums.h" +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "arch/traits.h" +#include "utils/numeric.h" + +#include "archetypes/energy_dist.h" +#include "archetypes/particle_injector.h" +#include "archetypes/problem_generator.h" +#include "archetypes/spatial_dist.h" +#include "framework/domain/metadomain.h" + +#include "kernels/particle_moments.hpp" + +namespace user { + using namespace ntt; + + template + struct InitFields { + InitFields(M metric_, real_t m_eps) : metric { metric_ }, m_eps { m_eps } {} + + Inline auto A_3(const coord_t& x_Cd) const -> real_t { + return HALF * (metric.template h_<3, 3>(x_Cd) + + TWO * metric.spin() * metric.template h_<1, 3>(x_Cd) * + metric.beta1(x_Cd)); + } + + Inline auto A_1(const coord_t& x_Cd) const -> real_t { + return HALF * (metric.template h_<1, 3>(x_Cd) + + TWO * metric.spin() * metric.template h_<1, 1>(x_Cd) * + metric.beta1(x_Cd)); + } + + Inline auto A_0(const coord_t& x_Cd) const -> real_t { + real_t g_00 { -metric.alpha(x_Cd) * 
metric.alpha(x_Cd) + + metric.template h_<1, 1>(x_Cd) * metric.beta1(x_Cd) * + metric.beta1(x_Cd) }; + return HALF * (metric.template h_<1, 3>(x_Cd) * metric.beta1(x_Cd) + + TWO * metric.spin() * g_00); + } + + Inline auto bx1(const coord_t& x_Ph) const + -> real_t { // at ( i , j + HALF ) + coord_t xi { ZERO }, x0m { ZERO }, x0p { ZERO }; + metric.template convert(x_Ph, xi); + + x0m[0] = xi[0]; + x0m[1] = xi[1] - HALF * m_eps; + x0p[0] = xi[0]; + x0p[1] = xi[1] + HALF * m_eps; + + real_t inv_sqrt_detH_ijP { ONE / metric.sqrt_det_h({ xi[0], xi[1] }) }; + + if (cmp::AlmostZero(x_Ph[1])) { + return ONE; + } else { + return (A_3(x0p) - A_3(x0m)) * inv_sqrt_detH_ijP / m_eps; + } + } + + Inline auto bx2(const coord_t& x_Ph) const + -> real_t { // at ( i + HALF , j ) + coord_t xi { ZERO }, x0m { ZERO }, x0p { ZERO }; + metric.template convert(x_Ph, xi); + + x0m[0] = xi[0] - HALF * m_eps; + x0m[1] = xi[1]; + x0p[0] = xi[0] + HALF * m_eps; + x0p[1] = xi[1]; + + real_t inv_sqrt_detH_ijP { ONE / metric.sqrt_det_h({ xi[0], xi[1] }) }; + if (cmp::AlmostZero(x_Ph[1])) { + return ZERO; + } else { + return -(A_3(x0p) - A_3(x0m)) * inv_sqrt_detH_ijP / m_eps; + } + } + + Inline auto bx3(const coord_t& x_Ph) const -> real_t { + return ZERO; + } + + Inline auto dx1(const coord_t& x_Ph) const -> real_t { + return ZERO; + } + + Inline auto dx2(const coord_t& x_Ph) const -> real_t { + return ZERO; + } + + Inline auto dx3(const coord_t& x_Ph) const -> real_t { + return ZERO; + } + + private: + const M metric; + const real_t m_eps; + }; + + template + struct PointDistribution : public arch::SpatialDistribution { + PointDistribution(const std::vector& xi_min, + const std::vector& xi_max, + const real_t sigma_thr, + const real_t dens_thr, + const SimulationParams& params, + Domain* domain_ptr) + : arch::SpatialDistribution { domain_ptr->mesh.metric } + , metric { domain_ptr->mesh.metric } + , EM { domain_ptr->fields.em } + , density { domain_ptr->fields.buff } + , sigma_thr { sigma_thr } + 
, inv_n0 { ONE / params.template get("scales.n0") } + , dens_thr { dens_thr } { + std::copy(xi_min.begin(), xi_min.end(), x_min); + std::copy(xi_max.begin(), xi_max.end(), x_max); + + std::vector specs {}; + for (auto& sp : domain_ptr->species) { + if (sp.mass() > 0) { + specs.push_back(sp.index()); + } + } + + Kokkos::deep_copy(density, ZERO); + auto scatter_buff = Kokkos::Experimental::create_scatter_view(density); + // some parameters + auto& mesh = domain_ptr->mesh; + const auto use_weights = params.template get( + "particles.use_weights"); + const auto ni2 = mesh.n_active(in::x2); + + for (const auto& sp : specs) { + auto& prtl_spec = domain_ptr->species[sp - 1]; + // clang-format off + Kokkos::parallel_for( + "ComputeMoments", + prtl_spec.rangeActiveParticles(), + kernel::ParticleMoments_kernel({}, scatter_buff, 0u, + prtl_spec.i1, prtl_spec.i2, prtl_spec.i3, + prtl_spec.dx1, prtl_spec.dx2, prtl_spec.dx3, + prtl_spec.ux1, prtl_spec.ux2, prtl_spec.ux3, + prtl_spec.phi, prtl_spec.weight, prtl_spec.tag, + prtl_spec.mass(), prtl_spec.charge(), + use_weights, + metric, mesh.flds_bc(), + ni2, inv_n0, ZERO)); + // clang-format on + } + Kokkos::Experimental::contribute(density, scatter_buff); + } + + Inline auto sigma_crit(const coord_t& x_Ph) const -> bool { + coord_t xi { ZERO }; + if constexpr (M::Dim == Dim::_2D) { + metric.template convert(x_Ph, xi); + const auto i1 = static_cast(xi[0]) + static_cast(N_GHOSTS); + const auto i2 = static_cast(xi[1]) + static_cast(N_GHOSTS); + const vec_t B_cntrv { EM(i1, i2, em::bx1), + EM(i1, i2, em::bx2), + EM(i1, i2, em::bx3) }; + const vec_t D_cntrv { EM(i1, i2, em::dx1), + EM(i1, i2, em::dx2), + EM(i1, i2, em::dx3) }; + vec_t B_cov { ZERO }; + metric.template transform(xi, B_cntrv, B_cov); + const auto bsqr = + DOT(B_cntrv[0], B_cntrv[1], B_cntrv[2], B_cov[0], B_cov[1], B_cov[2]); + const auto dens = density(i1, i2, 0); + return (bsqr > sigma_thr * dens) || (dens < dens_thr); + } + return false; + } + + Inline auto 
operator()(const coord_t& x_Ph) const -> real_t { + auto fill = true; + for (auto d = 0u; d < M::Dim; ++d) { + fill &= x_Ph[d] > x_min[d] and x_Ph[d] < x_max[d] and sigma_crit(x_Ph); + } + return fill ? ONE : ZERO; + } + + private: + tuple_t x_min; + tuple_t x_max; + const real_t sigma_thr; + const real_t dens_thr; + const real_t inv_n0; + Domain* domain_ptr; + ndfield_t density; + ndfield_t EM; + const M metric; + }; + + template + struct PGen : public arch::ProblemGenerator { + // compatibility traits for the problem generator + static constexpr auto engines { traits::compatible_with::value }; + static constexpr auto metrics { + traits::compatible_with::value + }; + static constexpr auto dimensions { traits::compatible_with::value }; + + // for easy access to variables in the child class + using arch::ProblemGenerator::D; + using arch::ProblemGenerator::C; + using arch::ProblemGenerator::params; + + const std::vector xi_min; + const std::vector xi_max; + const real_t sigma0, sigma_max, multiplicity, nGJ, temperature, m_eps; + + InitFields init_flds; + const Metadomain* metadomain; + + inline PGen(SimulationParams& p, const Metadomain& m) + : arch::ProblemGenerator(p) + , xi_min { p.template get>("setup.xi_min") } + , xi_max { p.template get>("setup.xi_max") } + , sigma_max { p.template get("setup.sigma_max") } + , sigma0 { p.template get("scales.sigma0") } + , multiplicity { p.template get("setup.multiplicity") } + , nGJ { p.template get("scales.B0") * + SQR(p.template get("scales.skindepth0")) } + , temperature { p.template get("setup.temperature") } + , m_eps { p.template get("setup.m_eps") } + , init_flds { m.mesh().metric, m_eps } + , metadomain { &m } {} + + inline void InitPrtls(Domain& local_domain) { + const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, + local_domain.random_pool, + temperature); + const auto spatial_dist = PointDistribution(xi_min, + xi_max, + sigma_max / sigma0, + multiplicity * nGJ, + params, + &local_domain); + + const 
auto injector = + arch::NonUniformInjector( + energy_dist, + spatial_dist, + { 1, 2 }); + arch::InjectNonUniform(params, + local_domain, + injector, + 1.0, + true); + } + + void CustomPostStep(std::size_t, long double time, Domain& local_domain) { + const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, + local_domain.random_pool, + temperature); + const auto spatial_dist = PointDistribution(xi_min, + xi_max, + sigma_max / sigma0, + multiplicity * nGJ, + params, + &local_domain); + + const auto injector = + arch::NonUniformInjector( + energy_dist, + spatial_dist, + { 1, 2 }); + arch::InjectNonUniform(params, + local_domain, + injector, + 1.0, + true); + } + }; + +} // namespace user + +#endif diff --git a/setups/srpic/magnetosphere/magnetosphere.py b/pgens/magnetosphere/magnetosphere.py similarity index 100% rename from setups/srpic/magnetosphere/magnetosphere.py rename to pgens/magnetosphere/magnetosphere.py diff --git a/setups/srpic/monopole/monopole.toml b/pgens/magnetosphere/magnetosphere.toml similarity index 55% rename from setups/srpic/monopole/monopole.toml rename to pgens/magnetosphere/magnetosphere.toml index 169837489..1a4af8a09 100644 --- a/setups/srpic/monopole/monopole.toml +++ b/pgens/magnetosphere/magnetosphere.toml @@ -1,31 +1,31 @@ [simulation] - name = "monopole" - engine = "srpic" + name = "magnetosphere" + engine = "srpic" runtime = 60.0 [grid] resolution = [2048, 1024] - extent = [[1.0, 50.0]] + extent = [[1.0, 50.0]] [grid.metric] metric = "qspherical" [grid.boundaries] - fields = [["ATMOSPHERE", "ABSORB"]] + fields = [["ATMOSPHERE", "MATCH"]] particles = [["ATMOSPHERE", "ABSORB"]] - + [grid.boundaries.absorb] ds = 1.0 [grid.boundaries.atmosphere] temperature = 0.1 - density = 10.0 - height = 0.02 - species = [1, 2] - ds = 2.0 - + density = 10.0 + height = 0.02 + species = [1, 2] + ds = 2.0 + [scales] - larmor0 = 2e-5 + larmor0 = 2e-5 skindepth0 = 0.01 [algorithms] @@ -36,38 +36,37 @@ [algorithms.gca] e_ovr_b_max = 0.9 - 
larmor_max = 1.0 + larmor_max = 1.0 [particles] - ppc0 = 5.0 - use_weights = true - sort_interval = 100 + ppc0 = 10.0 + use_weights = true + clear_interval = 100 [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 + label = "e-" + mass = 1.0 + charge = -1.0 maxnpart = 1e8 - pusher = "Boris,GCA" + pusher = "Boris,GCA" [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 + label = "e+" + mass = 1.0 + charge = 1.0 maxnpart = 1e8 - pusher = "Boris,GCA" + pusher = "Boris,GCA" [setup] - Bsurf = 1.0 + Bsurf = 1.0 period = 60.0 [output] format = "hdf5" - + [output.fields] interval_time = 0.1 - quantities = ["N_1", "N_2", "E", "B", "T00"] - mom_smooth = 2 + quantities = ["N_1", "N_2", "E", "B", "T00"] [output.particles] enable = false @@ -77,4 +76,3 @@ [diagnostics] interval = 50 - colored_stdout = true diff --git a/setups/srpic/magnetosphere/pgen.hpp b/pgens/magnetosphere/pgen.hpp similarity index 94% rename from setups/srpic/magnetosphere/pgen.hpp rename to pgens/magnetosphere/pgen.hpp index 681c4d6d1..64fe13cfe 100644 --- a/setups/srpic/magnetosphere/pgen.hpp +++ b/pgens/magnetosphere/pgen.hpp @@ -86,9 +86,13 @@ namespace user { inline PGen() {} - auto FieldDriver(real_t time) const -> DriveFields { + auto AtmFields(real_t time) const -> DriveFields { return DriveFields { time, Bsurf, Rstar, Omega }; } + + auto MatchFields(real_t) const -> InitFields { + return InitFields { Bsurf, Rstar }; + } }; } // namespace user diff --git a/setups/pgen.hpp b/pgens/pgen.hpp similarity index 100% rename from setups/pgen.hpp rename to pgens/pgen.hpp diff --git a/pgens/reconnection/pgen.hpp b/pgens/reconnection/pgen.hpp new file mode 100644 index 000000000..8e8804b2d --- /dev/null +++ b/pgens/reconnection/pgen.hpp @@ -0,0 +1,341 @@ +#ifndef PROBLEM_GENERATOR_H +#define PROBLEM_GENERATOR_H + +#include "enums.h" +#include "global.h" + +#include "arch/directions.h" +#include "arch/kokkos_aliases.h" +#include "arch/traits.h" +#include "utils/numeric.h" + +#include 
"archetypes/energy_dist.h" +#include "archetypes/particle_injector.h" +#include "archetypes/problem_generator.h" +#include "archetypes/spatial_dist.h" +#include "framework/domain/metadomain.h" + +#include "kernels/particle_moments.hpp" + +namespace user { + using namespace ntt; + + template + struct CurrentLayer : public arch::SpatialDistribution { + CurrentLayer(const M& metric, real_t cs_width, real_t center_x, real_t cs_y) + : arch::SpatialDistribution { metric } + , cs_width { cs_width } + , center_x { center_x } + , cs_y { cs_y } {} + + Inline auto operator()(const coord_t& x_Ph) const -> real_t { + return ONE / SQR(math::cosh((x_Ph[1] - cs_y) / cs_width)) * + (ONE - math::exp(-SQR((x_Ph[0] - center_x) / cs_width))); + } + + private: + const real_t cs_width, center_x, cs_y; + }; + + // field initializer + template + struct InitFields { + InitFields(real_t bg_B, real_t bg_Bguide, real_t cs_width, real_t cs_y) + : bg_B { bg_B } + , bg_Bguide { bg_Bguide } + , cs_width { cs_width } + , cs_y { cs_y } {} + + Inline auto bx1(const coord_t& x_Ph) const -> real_t { + return bg_B * (math::tanh((x_Ph[1] - cs_y) / cs_width)); + } + + Inline auto bx3(const coord_t&) const -> real_t { + return bg_Bguide; + } + + private: + const real_t bg_B, bg_Bguide, cs_width, cs_y; + }; + + template + struct BoundaryFieldsInX1 { + BoundaryFieldsInX1(real_t bg_B, + real_t bg_Bguide, + real_t beta_rec, + real_t cs_width, + real_t cs_x, + real_t cs_y) + : bg_B { bg_B } + , bg_Bguide { bg_Bguide } + , beta_rec { beta_rec } + , cs_width { cs_width } + , cs_x { cs_x } + , cs_y { cs_y } {} + + Inline auto bx1(const coord_t& x_Ph) const -> real_t { + return bg_B * (math::tanh((x_Ph[1] - cs_y) / cs_width)); + } + + Inline auto bx2(const coord_t& x_Ph) const -> real_t { + return beta_rec * bg_B * (math::tanh((x_Ph[0] - cs_x) / cs_width)); + } + + Inline auto bx3(const coord_t&) const -> real_t { + return bg_Bguide; + } + + Inline auto ex1(const coord_t& x_Ph) const -> real_t { + return beta_rec * 
bg_Bguide * math::tanh((x_Ph[1] - cs_y) / cs_width); + } + + Inline auto ex2(const coord_t&) const -> real_t { + return ZERO; + } + + Inline auto ex3(const coord_t&) const -> real_t { + return -beta_rec * bg_B; + } + + private: + const real_t bg_B, bg_Bguide, beta_rec, cs_width, cs_x, cs_y; + }; + + template + struct BoundaryFieldsInX2 { + BoundaryFieldsInX2(real_t bg_B, real_t bg_Bguide, real_t cs_width, real_t cs_y) + : bg_B { bg_B } + , bg_Bguide { bg_Bguide } + , cs_width { cs_width } + , cs_y { cs_y } {} + + Inline auto bx1(const coord_t& x_Ph) const -> real_t { + return bg_B * (math::tanh((x_Ph[1] - cs_y) / cs_width)); + } + + Inline auto bx2(const coord_t&) const -> real_t { + return ZERO; + } + + Inline auto bx3(const coord_t&) const -> real_t { + return bg_Bguide; + } + + Inline auto ex1(const coord_t&) const -> real_t { + return ZERO; + } + + Inline auto ex2(const coord_t&) const -> real_t { + return ZERO; + } + + Inline auto ex3(const coord_t&) const -> real_t { + return ZERO; + } + + private: + const real_t bg_B, bg_Bguide, cs_width, cs_y; + }; + + // constant particle density for particle boundaries + template + struct ConstDens { + Inline auto operator()(const coord_t& x_Ph) const -> real_t { + return ONE; + } + }; + template + using spatial_dist_t = arch::Replenish>; + + template + struct PGen : public arch::ProblemGenerator { + // compatibility traits for the problem generator + static constexpr auto engines { traits::compatible_with::value }; + static constexpr auto metrics { traits::compatible_with::value }; + static constexpr auto dimensions { + traits::compatible_with::value + }; + + // for easy access to variables in the child class + using arch::ProblemGenerator::D; + using arch::ProblemGenerator::C; + using arch::ProblemGenerator::params; + + const real_t bg_B, bg_Bguide, bg_temperature, inj_ypad; + const real_t cs_width, cs_overdensity, cs_x, cs_y; + const real_t ymin, ymax; + const simtime_t t_open; + bool bc_opened { false }; + + 
Metadomain& metadomain; + + InitFields init_flds; + + inline PGen(const SimulationParams& p, Metadomain& m) + : arch::ProblemGenerator(p) + , bg_B { p.template get("setup.bg_B", 1.0) } + , bg_Bguide { p.template get("setup.bg_Bguide", 0.0) } + , bg_temperature { p.template get("setup.bg_temperature", 0.001) } + , inj_ypad { p.template get("setup.inj_ypad", (real_t)0.05) } + , cs_width { p.template get("setup.cs_width") } + , cs_overdensity { p.template get("setup.cs_overdensity") } + , cs_x { INV_2 * + (m.mesh().extent(in::x1).second + m.mesh().extent(in::x1).first) } + , cs_y { INV_2 * + (m.mesh().extent(in::x2).second + m.mesh().extent(in::x2).first) } + , ymin { m.mesh().extent(in::x2).first } + , ymax { m.mesh().extent(in::x2).second } + , t_open { p.template get( + "setup.t_open", + 1.5 * HALF * + (m.mesh().extent(in::x1).second - m.mesh().extent(in::x1).first)) } + , metadomain { m } + , init_flds { bg_B, bg_Bguide, cs_width, cs_y } {} + + inline PGen() {} + + auto MatchFieldsInX1(simtime_t) const -> BoundaryFieldsInX1 { + return BoundaryFieldsInX1 { bg_B, bg_Bguide, (real_t)0.1, + cs_width, cs_x, cs_y }; + } + + auto MatchFieldsInX2(simtime_t) const -> BoundaryFieldsInX2 { + return BoundaryFieldsInX2 { bg_B, bg_Bguide, cs_width, cs_y }; + } + + inline void InitPrtls(Domain& local_domain) { + // background + const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, + local_domain.random_pool, + bg_temperature); + const auto injector = arch::UniformInjector( + energy_dist, + { 1, 2 }); + arch::InjectUniform>( + params, + local_domain, + injector, + ONE); + + const auto sigma = params.template get("scales.sigma0"); + const auto c_omp = params.template get("scales.skindepth0"); + const auto cs_drift_beta = math::sqrt(sigma) * c_omp / + (cs_width * cs_overdensity); + const auto cs_drift_gamma = ONE / math::sqrt(ONE - SQR(cs_drift_beta)); + const auto cs_drift_u = cs_drift_beta * cs_drift_gamma; + const auto cs_temperature = HALF * sigma / 
cs_overdensity; + + // current layer + auto edist_cs = arch::Maxwellian(local_domain.mesh.metric, + local_domain.random_pool, + cs_temperature, + cs_drift_u, + in::x3, + false); + const auto sdist_cs = CurrentLayer(local_domain.mesh.metric, + cs_width, + cs_x, + cs_y); + const auto inj_cs = arch::NonUniformInjector( + edist_cs, + sdist_cs, + { 1, 2 }); + arch::InjectNonUniform(params, + local_domain, + inj_cs, + cs_overdensity); + } + + void CustomPostStep(timestep_t, simtime_t time, Domain& domain) { + // open boundaries if not yet opened at time = t_open + if ((t_open > 0.0) and (not bc_opened) and (time > t_open)) { + bc_opened = true; + metadomain.setFldsBC(bc_in::Mx1, FldsBC::MATCH); + metadomain.setPrtlBC(bc_in::Mx1, PrtlBC::ABSORB); + metadomain.setFldsBC(bc_in::Px1, FldsBC::MATCH); + metadomain.setPrtlBC(bc_in::Px1, PrtlBC::ABSORB); + } + + const auto energy_dist = arch::Maxwellian(domain.mesh.metric, + domain.random_pool, + bg_temperature); + + const auto dx = domain.mesh.metric.template sqrt_h_<1, 1>({}); + + boundaries_t inj_box_up, inj_box_down; + boundaries_t probe_box_up, probe_box_down; + inj_box_up.push_back(Range::All); + inj_box_down.push_back(Range::All); + probe_box_up.push_back(Range::All); + probe_box_down.push_back(Range::All); + inj_box_up.push_back({ ymax - inj_ypad - 10 * dx, ymax - inj_ypad }); + inj_box_down.push_back({ ymin + inj_ypad, ymin + inj_ypad + 10 * dx }); + probe_box_up.push_back({ ymax - inj_ypad - 10 * dx, ymax - inj_ypad }); + probe_box_down.push_back({ ymin + inj_ypad, ymin + inj_ypad + 10 * dx }); + + if constexpr (M::Dim == Dim::_3D) { + inj_box_up.push_back(Range::All); + inj_box_down.push_back(Range::All); + } + + { + // compute density of species #1 and #2 + const auto use_weights = params.template get( + "particles.use_weights"); + const auto ni2 = domain.mesh.n_active(in::x2); + const auto inv_n0 = ONE / params.template get("scales.n0"); + + auto scatter_buff = Kokkos::Experimental::create_scatter_view( + 
domain.fields.buff); + Kokkos::deep_copy(domain.fields.buff, ZERO); + for (const auto sp : std::vector { 1, 2 }) { + const auto& prtl_spec = domain.species[sp - 1]; + // clang-format off + Kokkos::parallel_for( + "ComputeMoments", + prtl_spec.rangeActiveParticles(), + kernel::ParticleMoments_kernel({}, scatter_buff, 0u, + prtl_spec.i1, prtl_spec.i2, prtl_spec.i3, + prtl_spec.dx1, prtl_spec.dx2, prtl_spec.dx3, + prtl_spec.ux1, prtl_spec.ux2, prtl_spec.ux3, + prtl_spec.phi, prtl_spec.weight, prtl_spec.tag, + prtl_spec.mass(), prtl_spec.charge(), + use_weights, + domain.mesh.metric, domain.mesh.flds_bc(), + ni2, inv_n0, 0u)); + // clang-format on + } + Kokkos::Experimental::contribute(domain.fields.buff, scatter_buff); + } + + const auto injector_up = arch::KeepConstantInjector( + energy_dist, + { 1, 2 }, + 0u, + probe_box_up); + const auto injector_down = arch::KeepConstantInjector( + energy_dist, + { 1, 2 }, + 0u, + probe_box_down); + + arch::InjectUniform( + params, + domain, + injector_up, + ONE, + params.template get("particles.use_weights"), + inj_box_up); + arch::InjectUniform( + params, + domain, + injector_down, + ONE, + params.template get("particles.use_weights"), + inj_box_down); + } + }; + +} // namespace user + +#endif diff --git a/pgens/reconnection/reconnection.toml b/pgens/reconnection/reconnection.toml new file mode 100644 index 000000000..76acbf089 --- /dev/null +++ b/pgens/reconnection/reconnection.toml @@ -0,0 +1,71 @@ +[simulation] + name = "reconnection" + engine = "srpic" + runtime = 10.0 + + [simulation.domain] + decomposition = [-1, 2] + +[grid] + resolution = [512, 512] + extent = [[-1.0, 1.0], [-1.0, 1.0]] + + [grid.metric] + metric = "minkowski" + + [grid.boundaries] + fields = [["PERIODIC"], ["MATCH", "MATCH"]] + particles = [["PERIODIC"], ["ABSORB", "ABSORB"]] + + [grid.boundaries.match] + ds = [[0.04], [0.1]] + +[scales] + larmor0 = 2e-4 + skindepth0 = 2e-3 + +[algorithms] + current_filters = 8 + + [algorithms.timestep] + CFL = 0.5 + 
+[particles] + ppc0 = 8.0 + + [[particles.species]] + label = "e-" + mass = 1.0 + charge = -1.0 + maxnpart = 1e7 + + [[particles.species]] + label = "e+" + mass = 1.0 + charge = 1.0 + maxnpart = 1e7 + +[setup] + bg_B = 1.0 + bg_Bguide = 0.0 + bg_temperature = 1e-4 + inj_ypad = 0.25 + cs_width = 0.05 + cs_overdensity = 3.0 + +[output] + format = "hdf5" + interval_time = 0.1 + + [output.fields] + quantities = ["N_1", "N_2", "E", "B", "J"] + + [output.particles] + enable = false + + [output.spectra] + enable = false + +[diagnostics] + colored_stdout = true + interval = 10 diff --git a/pgens/shock/pgen.hpp b/pgens/shock/pgen.hpp new file mode 100644 index 000000000..6bd6f21a9 --- /dev/null +++ b/pgens/shock/pgen.hpp @@ -0,0 +1,343 @@ +#ifndef PROBLEM_GENERATOR_H +#define PROBLEM_GENERATOR_H + +#include "enums.h" +#include "global.h" + +#include "arch/traits.h" +#include "utils/error.h" +#include "utils/numeric.h" + +#include "archetypes/energy_dist.h" +#include "archetypes/field_setter.h" +#include "archetypes/particle_injector.h" +#include "archetypes/problem_generator.h" +#include "framework/domain/metadomain.h" + +#include +#include + +namespace user { + using namespace ntt; + + template + struct InitFields { + /* + Sets up magnetic and electric field components for the simulation. + Must satisfy E = -v x B for Lorentz Force to be zero. 
+ + @param bmag: magnetic field scaling + @param btheta: magnetic field polar angle + @param bphi: magnetic field azimuthal angle + @param drift_ux: drift velocity in the x direction + */ + InitFields(real_t bmag, real_t btheta, real_t bphi, real_t drift_ux) + : Bmag { bmag } + , Btheta { btheta * static_cast(convert::deg2rad) } + , Bphi { bphi * static_cast(convert::deg2rad) } + , Vx { drift_ux } {} + + // magnetic field components + Inline auto bx1(const coord_t&) const -> real_t { + return Bmag * math::cos(Btheta); + } + + Inline auto bx2(const coord_t&) const -> real_t { + return Bmag * math::sin(Btheta) * math::sin(Bphi); + } + + Inline auto bx3(const coord_t&) const -> real_t { + return Bmag * math::sin(Btheta) * math::cos(Bphi); + } + + // electric field components + Inline auto ex1(const coord_t&) const -> real_t { + return ZERO; + } + + Inline auto ex2(const coord_t&) const -> real_t { + return -Vx * Bmag * math::sin(Btheta) * math::cos(Bphi); + } + + Inline auto ex3(const coord_t&) const -> real_t { + return Vx * Bmag * math::sin(Btheta) * math::sin(Bphi); + } + + private: + const real_t Btheta, Bphi, Vx, Bmag; + }; + + + template + struct PGen : public arch::ProblemGenerator { + // compatibility traits for the problem generator + static constexpr auto engines { traits::compatible_with::value }; + static constexpr auto metrics { traits::compatible_with::value }; + static constexpr auto dimensions { + traits::compatible_with::value + }; + + // for easy access to variables in the child class + using arch::ProblemGenerator::D; + using arch::ProblemGenerator::C; + using arch::ProblemGenerator::params; + + // domain properties + const real_t global_xmin, global_xmax; + // gas properties + const real_t drift_ux, temperature, temperature_ratio, filling_fraction; + // injector properties + const real_t injector_velocity, injection_start, dt; + const int injection_frequency; + // magnetic field properties + real_t Btheta, Bphi, Bmag; + InitFields init_flds; + + 
inline PGen(const SimulationParams& p, const Metadomain& global_domain) + : arch::ProblemGenerator { p } + , global_xmin { global_domain.mesh().extent(in::x1).first } + , global_xmax { global_domain.mesh().extent(in::x1).second } + , drift_ux { p.template get("setup.drift_ux") } + , temperature { p.template get("setup.temperature") } + , temperature_ratio { p.template get("setup.temperature_ratio") } + , Bmag { p.template get("setup.Bmag", ZERO) } + , Btheta { p.template get("setup.Btheta", ZERO) } + , Bphi { p.template get("setup.Bphi", ZERO) } + , init_flds { Bmag, Btheta, Bphi, drift_ux } + , filling_fraction { p.template get("setup.filling_fraction", 1.0) } + , injector_velocity { p.template get("setup.injector_velocity", 1.0) } + , injection_start { p.template get("setup.injection_start", 0.0) } + , injection_frequency { p.template get("setup.injection_frequency", 100) } + , dt { p.template get("algorithms.timestep.dt") } {} + + inline PGen() {} + + auto MatchFields(real_t time) const -> InitFields { + return init_flds; + } + + auto FixFieldsConst(const bc_in&, const em& comp) const + -> std::pair { + if (comp == em::ex1) { + return { init_flds.ex1({ ZERO }), true }; + } else if (comp == em::ex2) { + return { ZERO, true }; + } else if (comp == em::ex3) { + return { ZERO, true }; + } else if (comp == em::bx1) { + return { init_flds.bx1({ ZERO }), true }; + } else if (comp == em::bx2) { + return { init_flds.bx2({ ZERO }), true }; + } else if (comp == em::bx3) { + return { init_flds.bx3({ ZERO }), true }; + } else { + raise::Error("Invalid component", HERE); + return { ZERO, false }; + } + } + + inline void InitPrtls(Domain& local_domain) { + + /* + * Plasma setup as partially filled box + * + * Plasma setup: + * + * global_xmin global_xmax + * | | + * V V + * |:::::::::::|..........................| + * ^ + * | + * filling_fraction + */ + + // minimum and maximum position of particles + real_t xg_min = global_xmin; + real_t xg_max = global_xmin + 
filling_fraction * (global_xmax - global_xmin); + + // define box to inject into + boundaries_t box; + // loop over all dimensions + for (auto d { 0u }; d < (unsigned int)M::Dim; ++d) { + // compute the range for the x-direction + if (d == static_cast(in::x1)) { + box.push_back({ xg_min, xg_max }); + } else { + // inject into full range in other directions + box.push_back(Range::All); + } + } + + // species #1 -> e^- + // species #2 -> protons + + // energy distribution of the particles + const auto energy_dist = arch::TwoTemperatureMaxwellian( + local_domain.mesh.metric, + local_domain.random_pool, + { temperature_ratio * temperature * local_domain.species[1].mass() , + temperature }, + { 1, 2 }, + -drift_ux, + in::x1); + + // we want to set up a uniform density distribution + const auto injector = arch::UniformInjector( + energy_dist, + { 1, 2 }); + + // inject uniformly within the defined box + arch::InjectUniform>( + params, + local_domain, + injector, + 1.0, // target density + false, // no weights + box); + } + + void CustomPostStep(timestep_t step, simtime_t time, Domain& domain) { + + /* + * Replenish plasma in a moving injector + * + * Injector setup: + * + * global_xmin purge/replenish global_xmax + * | x_init | | + * V v V V + * |:::::::::::;::::::::::|\\\\\\\\|......| + * xmin xmax + * ^ + * | + * moving injector + */ + + // check if the injector should be active + if (step % injection_frequency != 0) { + return; + } + + // initial position of injector + const auto x_init = global_xmin + + filling_fraction * (global_xmax - global_xmin); + + // compute the position of the injector after the current timestep + auto xmax = x_init + injector_velocity * + (std::max(time - injection_start, ZERO) + dt); + if (xmax >= global_xmax) { + xmax = global_xmax; + } + + // compute the beginning of the injected region + auto xmin = xmax - injection_frequency * dt; + if (xmin <= global_xmin) { + xmin = global_xmin; + } + + // define indice range to reset fields + 
boundaries_t incl_ghosts; + for (auto d = 0; d < M::Dim; ++d) { + incl_ghosts.push_back({ false, false }); + } + + // define box to reset fields + boundaries_t purge_box; + // loop over all dimension + for (auto d = 0u; d < M::Dim; ++d) { + if (d == 0) { + purge_box.push_back({ xmin, global_xmax }); + } else { + purge_box.push_back(Range::All); + } + } + + const auto extent = domain.mesh.ExtentToRange(purge_box, incl_ghosts); + tuple_t x_min { 0 }, x_max { 0 }; + for (auto d = 0; d < M::Dim; ++d) { + x_min[d] = extent[d].first; + x_max[d] = extent[d].second; + } + + Kokkos::parallel_for("ResetFields", + CreateRangePolicy(x_min, x_max), + arch::SetEMFields_kernel { + domain.fields.em, + init_flds, + domain.mesh.metric }); + + /* + tag particles inside the injection zone as dead + */ + const auto& mesh = domain.mesh; + + // loop over particle species + for (auto s { 0u }; s < 2; ++s) { + // get particle properties + auto& species = domain.species[s]; + auto i1 = species.i1; + auto dx1 = species.dx1; + auto tag = species.tag; + + Kokkos::parallel_for( + "RemoveParticles", + species.rangeActiveParticles(), + Lambda(index_t p) { + // check if the particle is already dead + if (tag(p) == ParticleTag::dead) { + return; + } + const auto x_Cd = static_cast(i1(p)) + + static_cast(dx1(p)); + const auto x_Ph = mesh.metric.template convert<1, Crd::Cd, Crd::XYZ>( + x_Cd); + + if (x_Ph > xmin) { + tag(p) = ParticleTag::dead; + } + }); + } + + /* + Inject slab of fresh plasma + */ + + // define box to inject into + boundaries_t inj_box; + // loop over all dimension + for (auto d = 0u; d < M::Dim; ++d) { + if (d == 0) { + inj_box.push_back({ xmin, xmax }); + } else { + inj_box.push_back(Range::All); + } + } + + // same maxwell distribution as above + const auto energy_dist = arch::TwoTemperatureMaxwellian( + domain.mesh.metric, + domain.random_pool, + { temperature_ratio * temperature * domain.species[1].mass(), + temperature }, + { 1, 2 }, + -drift_ux, + in::x1); + + // we want to 
set up a uniform density distribution + const auto injector = arch::UniformInjector( + energy_dist, + { 1, 2 }); + + // inject uniformly within the defined box + arch::InjectUniform>( + params, + domain, + injector, + 1.0, // target density + false, // no weights + inj_box); + } + }; +} // namespace user +#endif diff --git a/setups/srpic/shock/shock.py b/pgens/shock/shock.py similarity index 95% rename from setups/srpic/shock/shock.py rename to pgens/shock/shock.py index 64224c728..dc1565572 100644 --- a/setups/srpic/shock/shock.py +++ b/pgens/shock/shock.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt import matplotlib as mpl -data = nt2r.Data("shock-03.h5") +data = nt2r.Data("shock.h5") def frame(ti, f): @@ -55,7 +55,7 @@ def frame(ti, f): axs = [fig.add_subplot(gs[i]) for i in range(len(quantities))] for ax, q in zip(axs, quantities): - q["compute"](f).coarsen(x=2, y=2).mean().plot( + q["compute"](f.isel(t=ti)).plot( ax=ax, cmap=q["cmap"], norm=q["norm"], diff --git a/pgens/shock/shock.toml b/pgens/shock/shock.toml new file mode 100644 index 000000000..90678488a --- /dev/null +++ b/pgens/shock/shock.toml @@ -0,0 +1,70 @@ +[simulation] + name = "shock_perp" + engine = "srpic" + runtime = 50.0 + + [simulation.domain] + decomposition = [1,-1] + +[grid] + resolution = [4096, 128] + extent = [[0.0, 4.096], [-0.064, 0.064]] + + [grid.metric] + metric = "minkowski" + + [grid.boundaries] + fields = [["CONDUCTOR", "MATCH"], ["PERIODIC"]] + particles = [["REFLECT", "ABSORB"], ["PERIODIC"]] + + +[scales] + larmor0 = 0.057735 + skindepth0 = 0.01 + +[algorithms] + current_filters = 8 + + [algorithms.timestep] + CFL = 0.5 + +[particles] + ppc0 = 8.0 + + [[particles.species]] + label = "e-" + mass = 1.0 + charge = -1.0 + maxnpart = 8e7 + + [[particles.species]] + label = "p+" + mass = 100.0 + charge = 1.0 + maxnpart = 8e7 + +[setup] + drift_ux = 0.15 # speed towards the wall [c] + temperature = 0.001683 # temperature of maxwell distribution [kB T / (m_i c^2)] + 
temperature_ratio = 1.0 # temperature ratio of electrons to protons + Bmag = 1.0 # magnetic field strength as fraction of magnetisation + Btheta = 63.0 # magnetic field angle in the plane + Bphi = 0.0 # magnetic field angle out of plane + filling_fraction = 0.1 # fraction of the shock piston filled with plasma + injector_velocity = 0.2 # speed of injector [c] + injection_start = 0.0 # start time of moving injector + injection_frequency = 100 # inject particles every 100 timesteps + +[output] + interval_time = 0.1 + format = "hdf5" + + [output.fields] + quantities = ["N_1", "N_2", "B", "E"] + + [output.particles] + enable = true + stride = 10 + + [output.spectra] + enable = false diff --git a/pgens/streaming/pgen.hpp b/pgens/streaming/pgen.hpp new file mode 100644 index 000000000..ee14712de --- /dev/null +++ b/pgens/streaming/pgen.hpp @@ -0,0 +1,112 @@ +#ifndef PROBLEM_GENERATOR_H +#define PROBLEM_GENERATOR_H + +#include "enums.h" +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "arch/traits.h" +#include "utils/error.h" +#include "utils/numeric.h" + +#include "archetypes/energy_dist.h" +#include "archetypes/particle_injector.h" +#include "archetypes/problem_generator.h" +#include "framework/domain/domain.h" +#include "framework/domain/metadomain.h" + +namespace user { + using namespace ntt; + + template + struct PGen : public arch::ProblemGenerator { + + // compatibility traits for the problem generator + static constexpr auto engines = traits::compatible_with::value; + static constexpr auto metrics = traits::compatible_with::value; + static constexpr auto dimensions = + traits::compatible_with::value; + + // for easy access to variables in the child class + using arch::ProblemGenerator::D; + using arch::ProblemGenerator::C; + using arch::ProblemGenerator::params; + + using prmvec_t = std::vector; + + prmvec_t drifts_in_x, drifts_in_y, drifts_in_z; + prmvec_t densities, temperatures; + + inline PGen(const SimulationParams& p, const Metadomain& 
global_domain) + : arch::ProblemGenerator { p } + , drifts_in_x { p.template get("setup.drifts_in_x", prmvec_t {}) } + , drifts_in_y { p.template get("setup.drifts_in_y", prmvec_t {}) } + , drifts_in_z { p.template get("setup.drifts_in_z", prmvec_t {}) } + , densities { p.template get("setup.densities", prmvec_t {}) } + , temperatures { p.template get("setup.temperatures", prmvec_t {}) } { + const auto nspec = p.template get("particles.nspec"); + raise::ErrorIf(nspec % 2 != 0, + "Number of species must be even for this setup", + HERE); + for (auto n = 0u; n < nspec; n += 2) { + raise::ErrorIf( + global_domain.species_params()[n].charge() != + -global_domain.species_params()[n + 1].charge(), + "Charges of i-th and i+1-th species must be opposite for this setup", + HERE); + } + for (auto* specs : + { &drifts_in_x, &drifts_in_y, &drifts_in_z, &temperatures }) { + if (specs->empty()) { + for (auto n = 0u; n < nspec; ++n) { + specs->push_back(ZERO); + } + } + raise::ErrorIf(specs->size() != nspec, + "Drift vector and/or temperature vector length does " + "not match number of species", + HERE); + } + if (densities.empty()) { + for (auto n = 0u; n < nspec; n += 2) { + densities.push_back(TWO / static_cast(nspec)); + } + } + raise::ErrorIf(densities.size() != nspec / 2, + "Density vector length must be half of the number of " + "species (per each pair of species)", + HERE); + } + + inline void InitPrtls(Domain& domain) { + const auto nspec = domain.species.size(); + for (auto n = 0u; n < nspec; n += 2) { + const auto drift_1 = prmvec_t { drifts_in_x[n], + drifts_in_y[n], + drifts_in_z[n] }; + const auto drift_2 = prmvec_t { drifts_in_x[n + 1], + drifts_in_y[n + 1], + drifts_in_z[n + 1] }; + const auto injector = arch::experimental:: + UniformInjector( + arch::experimental::Maxwellian(domain.mesh.metric, + domain.random_pool, + temperatures[n], + drift_1), + arch::experimental::Maxwellian(domain.mesh.metric, + domain.random_pool, + temperatures[n + 1], + drift_2), + { n + 
1, n + 2 }); + arch::experimental::InjectUniform( + params, + domain, + injector, + densities[n / 2]); + } + } + }; + +} // namespace user + +#endif diff --git a/pgens/streaming/twostream.toml b/pgens/streaming/twostream.toml new file mode 100644 index 000000000..1b2334777 --- /dev/null +++ b/pgens/streaming/twostream.toml @@ -0,0 +1,83 @@ +[simulation] + name = "twostream" + engine = "srpic" + runtime = 1000.0 + +[grid] + resolution = [12288] + extent = [[0.0, 100.0]] + + [grid.metric] + metric = "minkowski" + + [grid.boundaries] + fields = [["PERIODIC"]] + particles = [["PERIODIC"]] + +[scales] + larmor0 = 100.0 + skindepth0 = 10.0 + +[algorithms] + current_filters = 4 + + [algorithms.timestep] + CFL = 0.5 + +[particles] + ppc0 = 16.0 + + [[particles.species]] + label = "e-Px" + mass = 1.0 + charge = -1.0 + maxnpart = 1e7 + + [[particles.species]] + label = "e+bg1" + mass = 1.0 + charge = 1.0 + maxnpart = 1e7 + pusher = "None" + + [[particles.species]] + label = "e-Mx" + mass = 1.0 + charge = -1.0 + maxnpart = 1e7 + + [[particles.species]] + label = "e+bg2" + mass = 1.0 + charge = 1.0 + maxnpart = 1e7 + pusher = "None" + +[setup] + # Drift 4-velocities for each species in all 3 directions + # @type: array of floats (length = nspec) + # @default: [ 0.0, ... ] + drifts_in_x = [0.1, 0.0, -0.1, 0.0] + drifts_in_y = [0.0, 0.0, 0.0, 0.0] + drifts_in_z = [0.0, 0.0, 0.0, 0.0] + # Pair-wise species densities in units of n0 + # @type: array of floats (length = nspec/2) + # @default: [ 2 / nspec, ... ] + densities = [0.5, 0.5] + # Species temperatures in units of m0 (c^2) + # @type: array of floats (length = nspec) + # @default: [ 0.0, ... 
] + temperatures = [1e-4, 1e-4, 1e-4, 1e-4] + +[output] + interval_time = 2.0 + + [output.fields] + quantities = ["N_1", "N_3", "E", "B", "J", "T0i_1", "T0i_3"] + + [output.particles] + species = [1, 3] + stride = 10 + + [output.spectra] + enable = false diff --git a/pgens/streaming/weibel.toml b/pgens/streaming/weibel.toml new file mode 100644 index 000000000..0d1f15bca --- /dev/null +++ b/pgens/streaming/weibel.toml @@ -0,0 +1,89 @@ +[simulation] + name = "weibel" + engine = "srpic" + runtime = 100.0 + +[grid] + resolution = [1024, 1024] + extent = [[-10.0, 10.0], [-10.0, 10.0]] + + [grid.metric] + metric = "minkowski" + + [grid.boundaries] + fields = [["PERIODIC"], ["PERIODIC"]] + particles = [["PERIODIC"], ["PERIODIC"]] + +[scales] + larmor0 = 1.0 + skindepth0 = 1.0 + +[algorithms] + current_filters = 4 + + [algorithms.timestep] + CFL = 0.5 + +[particles] + ppc0 = 16.0 + + [[particles.species]] + label = "e-_p" + mass = 1.0 + charge = -1.0 + maxnpart = 1e7 + + [[particles.species]] + label = "e+_p" + mass = 1.0 + charge = 1.0 + maxnpart = 1e7 + + [[particles.species]] + label = "e-_b" + mass = 1.0 + charge = -1.0 + maxnpart = 1e7 + + [[particles.species]] + label = "e+_b" + mass = 1.0 + charge = 1.0 + maxnpart = 1e7 + +[setup] + # Drift 4-velocities for each species in all 3 directions + # @type: array of floats (length = nspec) + # @default: [ 0.0, ... ] + drifts_in_x = [0.0, 0.0, 0.0, 0.0] + drifts_in_y = [0.0, 0.0, 0.0, 0.0] + drifts_in_z = [0.3, 0.3, -0.3, -0.3] + # Pair-wise species densities in units of n0 + # @type: array of floats (length = nspec/2) + # @default: [ 2 / nspec, ... ] + densities = [0.5, 0.5] + # Species temperatures in units of m0 (c^2) + # @type: array of floats (length = nspec) + # @default: [ 0.0, ... 
] + temperatures = [1e-4, 1e-4, 1e-4, 1e-4] + +[output] + interval_time = 0.25 + + [output.fields] + quantities = [ + "N_1_2", + "N_3_4", + "E", + "B", + "T0i_1", + "T0i_2", + "T0i_3", + "T0i_4", + ] + + [output.particles] + enable = false + + [output.spectra] + enable = false diff --git a/pgens/turbulence/pgen.hpp b/pgens/turbulence/pgen.hpp new file mode 100644 index 000000000..4c4a2c78e --- /dev/null +++ b/pgens/turbulence/pgen.hpp @@ -0,0 +1,442 @@ +#ifndef PROBLEM_GENERATOR_H +#define PROBLEM_GENERATOR_H + +#include "enums.h" +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/error.h" +#include "utils/numeric.h" + +#include "archetypes/energy_dist.h" +#include "archetypes/particle_injector.h" +#include "archetypes/problem_generator.h" +#include "framework/domain/domain.h" +#include "framework/domain/metadomain.h" + +#if defined(MPI_ENABLED) + #include +#endif // MPI_ENABLED + +namespace user { + using namespace ntt; + + // initializing guide field and curl(B) = J_ext at the initial time step + template + struct InitFields { + InitFields(array_t& k, + array_t& a_real, + array_t& a_imag, + array_t& a_real_inv, + array_t& a_imag_inv) + : k { k } + , a_real { a_real } + , a_imag { a_imag } + , a_real_inv { a_real_inv } + , a_imag_inv { a_imag_inv } + , n_modes { a_real.size() } {}; + + Inline auto bx1(const coord_t& x_Ph) const -> real_t { + auto bx1_0 = ZERO; + if constexpr(D==Dim::_2D){ + for (auto i = 0; i < n_modes; i++) { + auto k_dot_r = k(0, i) * x_Ph[0] + k(1, i) * x_Ph[1]; + bx1_0 -= TWO * k(1, i) * + (a_real(i) * math::sin(k_dot_r) + a_imag(i) * math::cos(k_dot_r)); + bx1_0 -= TWO * k(1, i) * + (a_real_inv(i) * math::sin(k_dot_r) + + a_imag_inv(i) * math::cos(k_dot_r)); + } + return bx1_0; + } + if constexpr (D==Dim::_3D){ + for (auto i = 0; i < n_modes; i++) { + auto k_dot_r = k(0, i) * x_Ph[0] + k(1, i) * x_Ph[1] + k(2, i) * x_Ph[2]; + bx1_0 -= TWO * k(1, i) * + (a_real(i) * math::sin(k_dot_r) + a_imag(i) * math::cos(k_dot_r)); + 
} + return bx1_0; + } + } + + Inline auto bx2(const coord_t& x_Ph) const -> real_t { + auto bx2_0 = ZERO; + if constexpr (D==Dim::_2D){ + for (auto i = 0; i < n_modes; i++) { + auto k_dot_r = k(0, i) * x_Ph[0] + k(1, i) * x_Ph[1]; + bx2_0 += TWO * k(0, i) * + (a_real(i) * math::sin(k_dot_r) + a_imag(i) * math::cos(k_dot_r)); + bx2_0 += TWO * k(0, i) * + (a_real_inv(i) * math::sin(k_dot_r) + + a_imag_inv(i) * math::cos(k_dot_r)); + } + return bx2_0; + } + if constexpr (D==Dim::_3D){ + for (auto i = 0; i < n_modes; i++) { + auto k_dot_r = k(0, i) * x_Ph[0] + k(1, i) * x_Ph[1] + k(2, i) * x_Ph[2]; + bx2_0 += TWO * k(0, i) * + (a_real(i) * math::sin(k_dot_r) + a_imag(i) * math::cos(k_dot_r)); + } + return bx2_0; + } + } + + Inline auto bx3(const coord_t&) const -> real_t { + return ONE; + } + + array_t k; + array_t a_real; + array_t a_imag; + array_t a_real_inv; + array_t a_imag_inv; + std::size_t n_modes; + }; + + inline auto init_pool(int seed) -> unsigned int { + if (seed < 0) { + unsigned int new_seed = static_cast(rand()); +#if defined(MPI_ENABLED) + MPI_Bcast(&new_seed, 1, MPI_UNSIGNED, MPI_ROOT_RANK, MPI_COMM_WORLD); +#endif // MPI_ENABLED + return new_seed; + } else { + return static_cast(seed); + } + } + + template + inline auto init_wavenumbers() -> std::vector> { + if constexpr (D == Dim::_2D) { + return { + { 1, 0 }, + { 0, 1 }, + { 1, 1 }, + { -1, 1 } + }; + } else if constexpr (D == Dim::_3D) { + return { + { 1, 0, 1 }, + { 0, 1, 1 }, + { -1, 0, 1 }, + { 0, -1, 1 }, + { 1, 0,-1 }, + { 0, 1,-1 }, + { -1, 0,-1 }, + { 0, -1,-1 } + }; + } else { + raise::Error("Invalid dimension", HERE); + return {}; + } + } + + // external current definition + template + struct ExternalCurrent { + ExternalCurrent(real_t dB, + real_t om0, + real_t g0, + std::vector>& wavenumbers, + unsigned int seed, + real_t Lx, + real_t Ly, + real_t Lz) + : wavenumbers { wavenumbers } + , n_modes { wavenumbers.size() } + , dB { dB } + , Lx { Lx } + , Ly { Ly } + , Lz { Lz } + , seed { seed 
} + , omega_0 { om0 } + , gamma_0 { g0 } + , k { "wavevector", D, n_modes } + , a_real { "a_real", n_modes } + , a_imag { "a_imag", n_modes } + , a_real_inv { "a_real_inv", n_modes } + , a_imag_inv { "a_imag_inv", n_modes } + , A0 { "A0", n_modes } { + // initializing random generator + srand(seed); + // initializing wavevectors + auto k_host = Kokkos::create_mirror_view(k); + if constexpr (D == Dim::_2D) { + for (auto i = 0u; i < n_modes; i++) { + k_host(0, i) = constant::TWO_PI * wavenumbers[i][0] / Lx; + k_host(1, i) = constant::TWO_PI * wavenumbers[i][1] / Ly; + } + } + if constexpr (D == Dim::_3D) { + for (auto i = 0u; i < n_modes; i++) { + k_host(0, i) = constant::TWO_PI * wavenumbers[i][0] / Lx; + k_host(1, i) = constant::TWO_PI * wavenumbers[i][1] / Ly; + k_host(2, i) = constant::TWO_PI * wavenumbers[i][2] / Lz; + } + } + // initializing initial complex amplitudes + auto a_real_host = Kokkos::create_mirror_view(a_real); + auto a_imag_host = Kokkos::create_mirror_view(a_imag); + auto a_real_inv_host = Kokkos::create_mirror_view(a_real_inv); + auto a_imag_inv_host = Kokkos::create_mirror_view(a_imag_inv); + auto A0_host = Kokkos::create_mirror_view(A0); + + real_t prefac { ZERO }; + if constexpr (D == Dim::_2D) { + prefac = HALF; // HALF = 1/sqrt(twice modes due to reality condition * twice the frequencies due to sign change) + } else if constexpr (D == Dim::_3D) { + prefac = ONE; + } + for (auto i = 0u; i < n_modes; i++) { + auto k_perp = math::sqrt( + k_host(0, i) * k_host(0, i) + k_host(1, i) * k_host(1, i)); + real_t phase = static_cast (rand()) / static_cast (RAND_MAX) * constant::TWO_PI; + A0_host(i) = dB / math::sqrt((real_t)n_modes) / k_perp * prefac; + a_real_host(i) = A0_host(i) * math::cos(phase); + a_imag_host(i) = A0_host(i) * math::sin(phase); + phase = static_cast (rand()) / static_cast (RAND_MAX) * constant::TWO_PI; + a_imag_inv_host(i) = A0_host(i) * math::cos(phase); + a_real_inv_host(i) = A0_host(i) * math::sin(phase); + } + + 
Kokkos::deep_copy(a_real, a_real_host); + Kokkos::deep_copy(a_imag, a_imag_host); + Kokkos::deep_copy(a_real_inv, a_real_inv_host); + Kokkos::deep_copy(a_imag_inv, a_imag_inv_host); + Kokkos::deep_copy(A0, A0_host); + Kokkos::deep_copy(k, k_host); + }; + + Inline auto jx1(const coord_t& x_Ph) const -> real_t { + if constexpr (D == Dim::_2D) { + return ZERO; + } + if constexpr (D == Dim::_3D) { + real_t jx1_ant = ZERO; + for (auto i = 0u; i < n_modes; i++) { + auto k_dot_r = k(0, i) * x_Ph[0] + k(1, i) * x_Ph[1] + k(2, i) * x_Ph[2]; + jx1_ant -= TWO * k(0, i) * k(2, i) * + (a_real(i) * math::cos(k_dot_r) - + a_imag(i) * math::sin(k_dot_r)); + } + return jx1_ant; + } + } + + Inline auto jx2(const coord_t& x_Ph) const -> real_t { + if constexpr (D == Dim::_2D) { + return ZERO; + } else if constexpr (D == Dim::_3D) { + real_t jx2_ant = ZERO; + for (auto i = 0u; i < n_modes; i++) { + auto k_dot_r = k(0, i) * x_Ph[0] + k(1, i) * x_Ph[1] + k(2, i) * x_Ph[2]; + jx2_ant -= TWO * k(1, i) * k(2, i) * + (a_real(i) * math::cos(k_dot_r) - + a_imag(i) * math::sin(k_dot_r)); + } + return jx2_ant; + } + } + + Inline auto jx3(const coord_t& x_Ph) const -> real_t { + if constexpr (D == Dim::_2D) { + real_t jx3_ant = ZERO; + for (auto i = 0u; i < n_modes; i++) { + auto k_perp_sq = k(0, i) * k(0, i) + k(1, i) * k(1, i); + auto k_dot_r = k(0, i) * x_Ph[0] + k(1, i) * x_Ph[1]; + jx3_ant += TWO * k_perp_sq * + (a_real(i) * math::cos(k_dot_r) - + a_imag(i) * math::sin(k_dot_r)); + jx3_ant += TWO * k_perp_sq * + (a_real_inv(i) * math::cos(k_dot_r) - + a_imag_inv(i) * math::sin(k_dot_r)); + } + return jx3_ant; + } else if constexpr (D == Dim::_3D) { + real_t jx3_ant = ZERO; + for (auto i = 0u; i < n_modes; i++) { + auto k_perp_sq = k(0, i) * k(0, i) + k(1, i) * k(1, i); + auto k_dot_r = k(0, i) * x_Ph[0] + k(1, i) * x_Ph[1] + k(2, i) * x_Ph[2]; + jx3_ant += TWO * k_perp_sq * + (a_real(i) * math::cos(k_dot_r) - + a_imag(i) * math::sin(k_dot_r)); + } + return jx3_ant; + } + } + + private: + 
const std::vector> wavenumbers; + const std::size_t n_modes; + const real_t dB, Lx, Ly, Lz; + const int seed; + + public: + const real_t omega_0, gamma_0; + array_t k; + array_t a_real; + array_t a_imag; + array_t a_real_inv; + array_t a_imag_inv; + array_t A0; + }; + + template + struct PGen : public arch::ProblemGenerator { + + // compatibility traits for the problem generator + static constexpr auto engines = traits::compatible_with::value; + static constexpr auto metrics = traits::compatible_with::value; + static constexpr auto dimensions = traits::compatible_with::value; + + // for easy access to variables in the child class + using arch::ProblemGenerator::D; + using arch::ProblemGenerator::C; + using arch::ProblemGenerator::params; + + const real_t temperature, dB, omega_0, gamma_0; + const real_t Lx, Ly, Lz, escape_dist; + const int random_seed; + std::vector> wavenumbers; + random_number_pool_t random_pool; + + // debugging, will delete later + real_t total_sum = ZERO; + real_t total_sum_inv = ZERO; + real_t number_of_timesteps = ZERO; + + ExternalCurrent ext_current; + InitFields init_flds; + + inline PGen(const SimulationParams& p, const Metadomain& global_domain) + : arch::ProblemGenerator { p } + , temperature { p.template get("setup.temperature") } + , dB { p.template get("setup.dB", ONE) } + , omega_0 { p.template get("setup.omega_0") } + , gamma_0 { p.template get("setup.gamma_0") } + , wavenumbers { init_wavenumbers() } + , random_seed { p.template get("setup.seed", -1) } + , random_pool { init_pool(random_seed) } + , Lx { global_domain.mesh().extent(in::x1).second - + global_domain.mesh().extent(in::x1).first } + , Ly { global_domain.mesh().extent(in::x2).second - + global_domain.mesh().extent(in::x2).first } + , Lz { global_domain.mesh().extent(in::x3).second - + global_domain.mesh().extent(in::x3).first } + , escape_dist { p.template get("setup.escape_dist", HALF * Lx) } + , ext_current { dB, omega_0, gamma_0, wavenumbers, init_pool(random_seed), 
Lx, Ly, Lz } + , init_flds { ext_current.k, + ext_current.a_real, + ext_current.a_imag, + ext_current.a_real_inv, + ext_current.a_imag_inv } {}; + + inline void InitPrtls(Domain& local_domain) { + const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, + local_domain.random_pool, + temperature); + const auto spatial_dist = arch::UniformInjector( + energy_dist, + { 1, 2 }); + arch::InjectUniform>( + params, + local_domain, + spatial_dist, + ONE); + }; + + void CustomPostStep(timestep_t, simtime_t, Domain& domain) { + #if defined(MPI_ENABLED) + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + #endif + // update amplitudes of antenna + const auto dt = params.template get("algorithms.timestep.dt"); + const auto& ext_curr = ext_current; + Kokkos::parallel_for( + "Antenna amplitudes", + wavenumbers.size(), + ClassLambda(index_t i) { + auto generator = random_pool.get_state(); + const auto u_imag = Random(generator) - HALF; + const auto u_real = Random(generator) - HALF; + const auto u_real_inv = Random(generator) - HALF; + const auto u_imag_inv = Random(generator) - HALF; + random_pool.free_state(generator); + + auto a_real_prev = ext_curr.a_real(i); + auto a_imag_prev = ext_curr.a_imag(i); + auto a_real_inv_prev = ext_curr.a_real_inv(i); + auto a_imag_inv_prev = ext_curr.a_imag_inv(i); + ext_curr.a_real(i) = (a_real_prev * math::cos(ext_curr.omega_0 * dt) + + a_imag_prev * math::sin(ext_curr.omega_0 * dt)) * + math::exp(-ext_curr.gamma_0 * dt) + + ext_curr.A0(i) * + math::sqrt(TWELVE * ext_curr.gamma_0 / dt) * + u_real * dt; + + ext_curr.a_imag(i) = (a_imag_prev * math::cos(ext_curr.omega_0 * dt) - + a_real_prev * math::sin(ext_curr.omega_0 * dt)) * + math::exp(-ext_curr.gamma_0 * dt) + + ext_curr.A0(i) * + math::sqrt(TWELVE * ext_curr.gamma_0 / dt) * + u_imag * dt; + + ext_curr.a_real_inv( + i) = (a_real_inv_prev * math::cos(-ext_curr.omega_0 * dt) + + a_imag_inv_prev * math::sin(-ext_curr.omega_0 * dt)) * + math::exp(-ext_curr.gamma_0 * dt) + + 
ext_curr.A0(i) * math::sqrt(TWELVE * ext_curr.gamma_0 / dt) * + u_real_inv * dt; + + ext_curr.a_imag_inv( + i) = (a_imag_inv_prev * math::cos(-ext_curr.omega_0 * dt) - + a_real_inv_prev * math::sin(-ext_curr.omega_0 * dt)) * + math::exp(-ext_curr.gamma_0 * dt) + + ext_curr.A0(i) * math::sqrt(TWELVE * ext_curr.gamma_0 / dt) * + u_imag_inv * dt; + }); + + // particle escape (resample velocities) + const auto energy_dist = arch::Maxwellian(domain.mesh.metric, + domain.random_pool, + temperature); + for (const auto& sp : { 0, 1 }) { + if (domain.species[sp].npld() > 1) { + const auto& ux1 = domain.species[sp].ux1; + const auto& ux2 = domain.species[sp].ux2; + const auto& ux3 = domain.species[sp].ux3; + const auto& pld = domain.species[sp].pld; + const auto& tag = domain.species[sp].tag; + const auto L = escape_dist; + Kokkos::parallel_for( + "UpdatePld", + domain.species[sp].npart(), + Lambda(index_t p) { + if (tag(p) == ParticleTag::dead) { + return; + } + const auto gamma = math::sqrt( + ONE + ux1(p) * ux1(p) + ux2(p) * ux2(p) + ux3(p) * ux3(p)); + pld(p, 0) += ux1(p) * dt / gamma; + pld(p, 1) += ux2(p) * dt / gamma; + + if ((math::abs(pld(p, 0)) > L) or (math::abs(pld(p, 1)) > L)) { + coord_t x_Ph { ZERO }; + vec_t u_Mxw { ZERO }; + energy_dist(x_Ph, u_Mxw); + ux1(p) = u_Mxw[0]; + ux2(p) = u_Mxw[1]; + ux3(p) = u_Mxw[2]; + pld(p, 0) = ZERO; + pld(p, 1) = ZERO; + } + }); + } + } + } + }; +} // namespace user + +#endif diff --git a/pgens/turbulence/turbulence.toml b/pgens/turbulence/turbulence.toml new file mode 100644 index 000000000..79cc641ef --- /dev/null +++ b/pgens/turbulence/turbulence.toml @@ -0,0 +1,65 @@ +[simulation] + name = "turbulence" + engine = "srpic" + runtime = 1200.0 + +[grid] + resolution = [1024, 1024] + extent = [[-128.0, 128.0], [-128.0, 128.0]] + + [grid.metric] + metric = "minkowski" + + [grid.boundaries] + fields = [["PERIODIC"], ["PERIODIC"]] + particles = [["PERIODIC"], ["PERIODIC"]] + +[scales] + larmor0 = 1.0 + skindepth0 = 1.0 + 
+[algorithms] + current_filters = 4 + + [algorithms.timestep] + CFL = 0.5 + +[particles] + ppc0 = 32.0 + + [[particles.species]] + label = "e-_p" + mass = 1.0 + charge = -1.0 + maxnpart = 2e7 + + [[particles.species]] + label = "e+_p" + mass = 1.0 + charge = 1.0 + maxnpart = 2e7 + +[setup] + temperature = 1e0 + dB = 1.0 + omega_0 = 0.0156 + gamma_0 = 0.0078 + + +[output] + format = "hdf5" + interval_time = 12.0 + + [output.fields] + quantities = ["N_1_2", "J", "B", "E"] + + [output.particles] + enable = false + + [output.spectra] + enable = false + [output.stats] + enable = false + +[diagnostics] + colored_stdout = true diff --git a/pgens/wald/pgen.hpp b/pgens/wald/pgen.hpp new file mode 100644 index 000000000..71ee905e3 --- /dev/null +++ b/pgens/wald/pgen.hpp @@ -0,0 +1,258 @@ +#ifndef PROBLEM_GENERATOR_H +#define PROBLEM_GENERATOR_H + +#include "enums.h" +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "arch/traits.h" +#include "utils/comparators.h" +#include "utils/error.h" +#include "utils/formatting.h" +#include "utils/log.h" +#include "utils/numeric.h" + +#include "archetypes/energy_dist.h" +#include "archetypes/particle_injector.h" +#include "archetypes/problem_generator.h" +#include "framework/domain/domain.h" +#include "framework/domain/metadomain.h" + +#include +#include + +enum InitFieldGeometry { + Wald, + Vertical, +}; + +namespace user { + using namespace ntt; + + template + struct InitFields { + InitFields(M metric_, const std::string& init_field_geometry) + : metric { metric_ } { + if (init_field_geometry == "wald") { + field_geometry = InitFieldGeometry::Wald; + } else if (init_field_geometry == "vertical") { + field_geometry = InitFieldGeometry::Vertical; + } else { + raise::Error(fmt::format("Unrecognized field geometry: %s", + init_field_geometry.c_str()), + HERE); + } + } + + Inline auto A_3(const coord_t& x_Cd) const -> real_t { + return HALF * (metric.template h_<3, 3>(x_Cd) + + TWO * metric.spin() * metric.template h_<1, 
3>(x_Cd) * + metric.beta1(x_Cd)); + } + + Inline auto A_1(const coord_t& x_Cd) const -> real_t { + return HALF * (metric.template h_<1, 3>(x_Cd) + + TWO * metric.spin() * metric.template h_<1, 1>(x_Cd) * + metric.beta1(x_Cd)); + } + + Inline auto A_0(const coord_t& x_Cd) const -> real_t { + real_t g_00 { -metric.alpha(x_Cd) * metric.alpha(x_Cd) + + metric.template h_<1, 1>(x_Cd) * metric.beta1(x_Cd) * + metric.beta1(x_Cd) }; + return HALF * (metric.template h_<1, 3>(x_Cd) * metric.beta1(x_Cd) + + TWO * metric.spin() * g_00); + } + + Inline auto bx1(const coord_t& x_Ph) const -> real_t { // at ( i , j + HALF ) + coord_t xi { ZERO }, x0m { ZERO }, x0p { ZERO }; + metric.template convert(x_Ph, xi); + + x0m[0] = xi[0]; + x0m[1] = xi[1] - HALF; + x0p[0] = xi[0]; + x0p[1] = xi[1] + HALF; + + real_t inv_sqrt_detH_ijP { ONE / metric.sqrt_det_h({ xi[0], xi[1] }) }; + + if (cmp::AlmostZero(x_Ph[1])) { + return ONE; + } else { + return (A_3(x0p) - A_3(x0m)) * inv_sqrt_detH_ijP; + } + } + + Inline auto bx2(const coord_t& x_Ph) const -> real_t { // at ( i + HALF , j ) + coord_t xi { ZERO }, x0m { ZERO }, x0p { ZERO }; + metric.template convert(x_Ph, xi); + + x0m[0] = xi[0] - HALF; + x0m[1] = xi[1]; + x0p[0] = xi[0] + HALF; + x0p[1] = xi[1]; + + real_t inv_sqrt_detH_ijP { ONE / metric.sqrt_det_h({ xi[0], xi[1] }) }; + if (cmp::AlmostZero(x_Ph[1])) { + return ZERO; + } else { + return -(A_3(x0p) - A_3(x0m)) * inv_sqrt_detH_ijP; + } + } + + Inline auto bx3( + const coord_t& x_Ph) const -> real_t { // at ( i + HALF , j + HALF ) + if (field_geometry == InitFieldGeometry::Wald) { + coord_t xi { ZERO }, x0m { ZERO }, x0p { ZERO }; + metric.template convert(x_Ph, xi); + + x0m[0] = xi[0]; + x0m[1] = xi[1] - HALF; + x0p[0] = xi[0]; + x0p[1] = xi[1] + HALF; + + real_t inv_sqrt_detH_iPjP { ONE / metric.sqrt_det_h({ xi[0], xi[1] }) }; + return -(A_1(x0p) - A_1(x0m)) * inv_sqrt_detH_iPjP; + } else if (field_geometry == InitFieldGeometry::Vertical) { + return ZERO; + } else { + 
raise::KernelError(HERE, "Unrecognized field geometry"); + return ZERO; + } + } + + Inline auto dx1(const coord_t& x_Ph) const -> real_t { // at ( i + HALF , j ) + if (field_geometry == InitFieldGeometry::Wald) { + coord_t xi { ZERO }, x0m { ZERO }, x0p { ZERO }; + metric.template convert(x_Ph, xi); + + real_t alpha_iPj { metric.alpha({ xi[0], xi[1] }) }; + real_t inv_sqrt_detH_ij { ONE / metric.sqrt_det_h({ xi[0] - HALF, xi[1] }) }; + real_t sqrt_detH_ij { metric.sqrt_det_h({ xi[0] - HALF, xi[1] }) }; + real_t beta_ij { metric.beta1({ xi[0] - HALF, xi[1] }) }; + real_t alpha_ij { metric.alpha({ xi[0] - HALF, xi[1] }) }; + + // D1 at ( i + HALF , j ) + x0m[0] = xi[0] - HALF; + x0m[1] = xi[1]; + x0p[0] = xi[0] + HALF; + x0p[1] = xi[1]; + real_t E1d { (A_0(x0p) - A_0(x0m)) }; + real_t D1d { E1d / alpha_iPj }; + + // D3 at ( i , j ) + x0m[0] = xi[0] - HALF - HALF; + x0m[1] = xi[1]; + x0p[0] = xi[0] - HALF + HALF; + x0p[1] = xi[1]; + real_t D3d { (A_3(x0p) - A_3(x0m)) * beta_ij / alpha_ij }; + + real_t D1u { metric.template h<1, 1>({ xi[0], xi[1] }) * D1d + + metric.template h<1, 3>({ xi[0], xi[1] }) * D3d }; + + return D1u; + } else if (field_geometry == InitFieldGeometry::Vertical) { + return ZERO; + } else { + raise::KernelError(HERE, "Unrecognized field geometry"); + return ZERO; + } + } + + Inline auto dx2(const coord_t& x_Ph) const -> real_t { // at ( i , j + HALF ) + if (field_geometry == InitFieldGeometry::Wald) { + coord_t xi { ZERO }, x0m { ZERO }, x0p { ZERO }; + metric.template convert(x_Ph, xi); + x0m[0] = xi[0]; + x0m[1] = xi[1] - HALF; + x0p[0] = xi[0]; + x0p[1] = xi[1] + HALF; + real_t inv_sqrt_detH_ijP { ONE / metric.sqrt_det_h({ xi[0], xi[1] }) }; + real_t sqrt_detH_ijP { metric.sqrt_det_h({ xi[0], xi[1] }) }; + real_t alpha_ijP { metric.alpha({ xi[0], xi[1] }) }; + real_t beta_ijP { metric.beta1({ xi[0], xi[1] }) }; + + real_t E2d { (A_0(x0p) - A_0(x0m)) }; + real_t D2d { E2d / alpha_ijP - + (A_1(x0p) - A_1(x0m)) * beta_ijP / alpha_ijP }; + real_t 
D2u { metric.template h<2, 2>({ xi[0], xi[1] }) * D2d }; + + return D2u; + } else if (field_geometry == InitFieldGeometry::Vertical) { + return ZERO; + } else { + raise::KernelError(HERE, "Unrecognized field geometry"); + return ZERO; + } + } + + Inline auto dx3(const coord_t& x_Ph) const -> real_t { // at ( i , j ) + if (field_geometry == InitFieldGeometry::Wald) { + coord_t xi { ZERO }, x0m { ZERO }, x0p { ZERO }; + metric.template convert(x_Ph, xi); + real_t inv_sqrt_detH_ij { ONE / metric.sqrt_det_h({ xi[0], xi[1] }) }; + real_t sqrt_detH_ij { metric.sqrt_det_h({ xi[0], xi[1] }) }; + real_t beta_ij { metric.beta1({ xi[0], xi[1] }) }; + real_t alpha_ij { metric.alpha({ xi[0], xi[1] }) }; + real_t alpha_iPj { metric.alpha({ xi[0] + HALF, xi[1] }) }; + + // D3 at ( i , j ) + x0m[0] = xi[0] - HALF; + x0m[1] = xi[1]; + x0p[0] = xi[0] + HALF; + x0p[1] = xi[1]; + real_t D3d { (A_3(x0p) - A_3(x0m)) * beta_ij / alpha_ij }; + + // D1 at ( i + HALF , j ) + x0m[0] = xi[0] + HALF - HALF; + x0m[1] = xi[1]; + x0p[0] = xi[0] + HALF + HALF; + x0p[1] = xi[1]; + real_t E1d { (A_0(x0p) - A_0(x0m)) }; + real_t D1d { E1d / alpha_iPj }; + + if (cmp::AlmostZero(x_Ph[1])) { + return metric.template h<1, 3>({ xi[0], xi[1] }) * D1d; + } else { + return metric.template h<3, 3>({ xi[0], xi[1] }) * D3d + + metric.template h<1, 3>({ xi[0], xi[1] }) * D1d; + } + } else if (field_geometry == InitFieldGeometry::Vertical) { + return ZERO; + } else { + raise::KernelError(HERE, "Unrecognized field geometry"); + return ZERO; + } + } + + private: + const M metric; + InitFieldGeometry field_geometry; + }; + + template + struct PGen : public arch::ProblemGenerator { + // compatibility traits for the problem generator + static constexpr auto engines { traits::compatible_with::value }; + static constexpr auto metrics { + traits::compatible_with::value + }; + static constexpr auto dimensions { traits::compatible_with::value }; + + // for easy access to variables in the child class + using 
arch::ProblemGenerator::D; + using arch::ProblemGenerator::C; + using arch::ProblemGenerator::params; + + InitFields init_flds; + const Metadomain& global_domain; + + inline PGen(const SimulationParams& p, const Metadomain& m) + : arch::ProblemGenerator { p } + , global_domain { m } + , init_flds { m.mesh().metric, + p.template get("setup.init_field", "wald") } {} + }; + +} // namespace user + +#endif diff --git a/pgens/wald/wald.toml b/pgens/wald/wald.toml new file mode 100644 index 000000000..2b05fbac9 --- /dev/null +++ b/pgens/wald/wald.toml @@ -0,0 +1,59 @@ +[simulation] + name = "vacuum" + engine = "grpic" + runtime = 100.0 + +[grid] + resolution = [512, 512] + extent = [[1.0, 10.0]] + + [grid.metric] + metric = "qkerr_schild" + qsph_r0 = 0.0 + qsph_h = 0.0 + ks_a = 0.95 + + [grid.boundaries] + fields = [["MATCH"]] + particles = [["ABSORB"]] + + [grid.boundaries.absorb] + ds = 1.0 + +[scales] + larmor0 = 0.0025 + skindepth0 = 0.05 + +[algorithms] + current_filters = 0 + + [algorithms.timestep] + CFL = 0.5 + + [algorithms.toggles] + deposit = false + fieldsolver = true + +[particles] + ppc0 = 2.0 + +[setup] + init_field = "wald" # or "vertical" + +[output] + format = "hdf5" + + [output.fields] + interval_time = 1.0 + quantities = ["D", "H", "B", "A"] + + [output.particles] + enable = false + + [output.spectra] + enable = false + +[diagnostics] + interval = 2 + colored_stdout = true + blocking_timers = true diff --git a/setups/CMakeLists.txt b/setups/CMakeLists.txt deleted file mode 100644 index b1753d7b8..000000000 --- a/setups/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -# ------------------------------ -# @defines: ntt_pgen [INTERFACE] -# @includes: -# - ../ -# @depends: -# - ntt_pgen [required] -# @uses: -# - kokkos [required] -# - plog [required] -# - mpi [optional] -# ------------------------------ - -add_library(ntt_pgen INTERFACE) -target_link_libraries(ntt_pgen INTERFACE - ntt_global - ntt_framework - ntt_archetypes - ntt_kernels -) - 
-target_include_directories(ntt_pgen - INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/${PGEN} -) \ No newline at end of file diff --git a/setups/grpic/pgen_grpic_example.hpp b/setups/grpic/pgen_grpic_example.hpp deleted file mode 100644 index f553ae849..000000000 --- a/setups/grpic/pgen_grpic_example.hpp +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/traits.h" - -#include "archetypes/problem_generator.h" - -namespace user { - using namespace ntt; - - template - struct PGen : public ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines { traits::compatible_with::value }; - static constexpr auto metrics { - traits::compatible_with::value - }; - static constexpr auto dimensions { traits::compatible_with::value }; - - // for easy access to variables in the child class - using ProblemGenerator::D; - using ProblemGenerator::C; - using ProblemGenerator::params; - using ProblemGenerator::domain; - - inline PGen(SimulationParams& p, const Metadomain&) : - ProblemGenerator(p) {} - - inline PGen() {} - }; - -} // namespace user - -#endif diff --git a/setups/srpic/em_vacuum/em_vacuum.py b/setups/srpic/em_vacuum/em_vacuum.py deleted file mode 100644 index 13a62ea3b..000000000 --- a/setups/srpic/em_vacuum/em_vacuum.py +++ /dev/null @@ -1,14 +0,0 @@ -import nt2.read as nt2r -import matplotlib.pyplot as plt - -data = nt2r.Data("em_vacuum.h5") - - -def plot(ti): - fig = plt.figure(figsize=(10, 5), dpi=150) - ax = fig.add_subplot(121) - data.Bz.isel(t=ti).plot(ax=ax, cmap="BrBG") - ax = fig.add_subplot(122) - data.Ey.isel(t=ti).plot(ax=ax, cmap="RdBu_r") - for ax in fig.axes[::2]: - ax.set_aspect("equal") diff --git a/setups/srpic/em_vacuum/em_vacuum.toml b/setups/srpic/em_vacuum/em_vacuum.toml deleted file mode 100644 index 156c8d308..000000000 --- a/setups/srpic/em_vacuum/em_vacuum.toml +++ /dev/null @@ -1,43 +0,0 @@ -[simulation] - name = 
"em_vacuum" - engine = "srpic" - runtime = 2.0 - -[grid] - resolution = [256, 512] - extent = [[-1.0, 1.0], [-2.0, 2.0]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["PERIODIC"], ["PERIODIC"]] - particles = [["PERIODIC"], ["PERIODIC"]] - -[scales] - larmor0 = 0.1 - skindepth0 = 0.01 - -[algorithms] - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 1.0 - -[setup] - amplitude = 1.0 - kx1 = 1 - kx2 = 1 - kx3 = 0 - -[output] - format = "hdf5" - interval_time = 0.1 - - [output.fields] - quantities = ["E", "B"] - -[diagnostics] - colored_stdout = true diff --git a/setups/srpic/em_vacuum/pgen.hpp b/setups/srpic/em_vacuum/pgen.hpp deleted file mode 100644 index 52368cbd8..000000000 --- a/setups/srpic/em_vacuum/pgen.hpp +++ /dev/null @@ -1,108 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" -#include "utils/comparators.h" -#include "utils/numeric.h" - -#include "archetypes/problem_generator.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct InitFields { - InitFields(real_t a, real_t sx1, real_t sx2, real_t sx3, int k1, int k2, int k3) - : amplitude { a } - , kx1 { (sx1 > ZERO) ? (real_t)(constant::TWO_PI) * (real_t)k1 / sx1 : ZERO } - , kx2 { (sx2 > ZERO) ? (real_t)(constant::TWO_PI) * (real_t)k2 / sx2 : ZERO } - , kx3 { (sx3 > ZERO) ? 
(real_t)(constant::TWO_PI) * (real_t)k3 / sx3 : ZERO } - , kmag13 { math::sqrt(SQR(kx1) + SQR(kx3)) } - , kmag { math::sqrt(SQR(kx1) + SQR(kx2) + SQR(kx3)) } { - raise::ErrorIf(cmp::AlmostZero_host(kx1) and cmp::AlmostZero_host(kx3), - "kx1 and kx3 cannot be zero", - HERE); - } - - // B is in k x y - // E is in -k x B - - Inline auto arg(const coord_t& x_Ph) const -> real_t { - if constexpr (D == Dim::_1D) { - return kx1 * x_Ph[0]; - } else if constexpr (D == Dim::_2D) { - return kx1 * x_Ph[0] + kx2 * x_Ph[1]; - } else { - return kx1 * x_Ph[0] + kx2 * x_Ph[1] + kx3 * x_Ph[2]; - } - } - - Inline auto ex1(const coord_t& x_Ph) const -> real_t { - return -amplitude * kx1 * kx2 / (kmag13 * kmag) * math::sin(arg(x_Ph)); - } - - Inline auto ex2(const coord_t& x_Ph) const -> real_t { - return amplitude * (SQR(kx1) + SQR(kx3)) / (kmag13 * kmag) * - math::sin(arg(x_Ph)); - } - - Inline auto ex3(const coord_t& x_Ph) const -> real_t { - return -amplitude * kx3 * kx2 / (kmag13 * kmag) * math::sin(arg(x_Ph)); - } - - Inline auto bx1(const coord_t& x_Ph) const -> real_t { - return -amplitude * (kx3 / kmag13) * math::sin(arg(x_Ph)); - } - - // skipping bx2 - - Inline auto bx3(const coord_t& x_Ph) const -> real_t { - return amplitude * (kx1 / kmag13) * math::sin(arg(x_Ph)); - } - - private: - const real_t amplitude; - const real_t kx1, kx2, kx3, kmag13, kmag; - }; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines = traits::compatible_with::value; - static constexpr auto metrics = traits::compatible_with::value; - static constexpr auto dimensions = - traits::compatible_with::value; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const real_t amplitude; - const int kx1, kx2, kx3; - const real_t sx1, sx2, sx3; - InitFields init_flds; - - inline PGen(const SimulationParams& 
p, const Metadomain& global_domain) - : arch::ProblemGenerator { p } - , amplitude { params.template get("setup.amplitude", 1.0) } - , kx1 { params.template get("setup.kx1", 1) } - , kx2 { params.template get("setup.kx2", 0) } - , kx3 { params.template get("setup.kx3", 0) } - , sx1 { global_domain.mesh().extent(in::x1).second - - global_domain.mesh().extent(in::x1).first } - , sx2 { global_domain.mesh().extent(in::x2).second - - global_domain.mesh().extent(in::x2).first } - , sx3 { global_domain.mesh().extent(in::x3).second - - global_domain.mesh().extent(in::x3).first } - , init_flds { amplitude, sx1, sx2, sx3, kx1, kx2, kx3 } {} - }; - -} // namespace user - -#endif diff --git a/setups/srpic/example/pgen.hpp b/setups/srpic/example/pgen.hpp deleted file mode 100644 index 3739243cd..000000000 --- a/setups/srpic/example/pgen.hpp +++ /dev/null @@ -1,105 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" -#include "utils/numeric.h" - -#include "archetypes/problem_generator.h" -#include "framework/domain/metadomain.h" - -#include - -namespace user { - using namespace ntt; - - template - struct InitFields { - - InitFields(real_t a, real_t sx2, int kx2) - : amplitude { a } - , sx2 { sx2 } - , kx2 { kx2 } {} - - // only set ex2 and bx3 - - Inline auto ex2(const coord_t& x_Ph) const -> real_t { - return amplitude * math::sin(constant::TWO_PI * (x_Ph[1] / sx2) * - static_cast(kx2)); - } - - Inline auto bx3(const coord_t& x_Ph) const -> real_t { - return -amplitude * math::cos(constant::TWO_PI * (x_Ph[1] / sx2) * - static_cast(kx2)); - } - - private: - const real_t amplitude; - const real_t sx2; - const int kx2; - }; - - template - struct ExtForce { - const std::vector species { 1, 2 }; - - ExtForce() = default; - - Inline auto fx1(const unsigned short& sp, - const real_t& time, - const coord_t& x_Ph) const -> real_t { - (void)sp; - (void)time; - 
(void)x_Ph; - return ZERO; - } - - Inline auto fx2(const unsigned short& sp, - const real_t& time, - const coord_t& x_Ph) const -> real_t { - (void)sp; - (void)time; - (void)x_Ph; - return ZERO; - } - - Inline auto fx3(const unsigned short& sp, - const real_t& time, - const coord_t& x_Ph) const -> real_t { - (void)sp; - (void)time; - (void)x_Ph; - return ZERO; - } - }; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines = traits::compatible_with::value; - static constexpr auto metrics = traits::compatible_with::value; - static constexpr auto dimensions = traits::compatible_with::value; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - InitFields init_flds; - ExtForce ext_force; - - inline PGen(const SimulationParams& p, const Metadomain& global_domain) - : arch::ProblemGenerator { p } - , init_flds { params.template get("setup.amplitude", 1.0), - global_domain.mesh().extent(in::x2).second - - global_domain.mesh().extent(in::x2).first, - params.template get("setup.kx2", 2) } - , ext_force {} {} - }; - -} // namespace user - -#endif diff --git a/setups/srpic/langmuir/langmuir.py b/setups/srpic/langmuir/langmuir.py deleted file mode 100644 index a880bc00b..000000000 --- a/setups/srpic/langmuir/langmuir.py +++ /dev/null @@ -1,17 +0,0 @@ -import nt2.read as nt2r -import matplotlib.pyplot as plt - -data = nt2r.Data("langmuir.h5") - - -def plot(ti, d): - # for 2D - fig = plt.figure(figsize=(10, 5), dpi=150) - ax = fig.add_subplot(211) - d.Rho.isel(t=ti).plot(ax=ax, cmap="inferno", vmin=0, vmax=4) - ax = fig.add_subplot(212) - d.Ex.isel(t=ti).plot(ax=ax, cmap="RdBu_r", vmin=-1, vmax=1) - for ax in fig.get_axes()[::2]: - ax.set_aspect("equal") - fig.get_axes()[0].set(xlabel="", xticks=[]) - fig.get_axes()[2].set(title=None) diff --git 
a/setups/srpic/langmuir/langmuir.toml b/setups/srpic/langmuir/langmuir.toml deleted file mode 100644 index 2f3520fc5..000000000 --- a/setups/srpic/langmuir/langmuir.toml +++ /dev/null @@ -1,55 +0,0 @@ -[simulation] - name = "langmuir" - engine = "srpic" - runtime = 1.0 - -[grid] - resolution = [2048, 512] - extent = [[0.0, 1.0], [0.0, 0.25]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["PERIODIC"], ["PERIODIC"]] - particles = [["PERIODIC"], ["PERIODIC"]] - -[scales] - larmor0 = 0.1 - skindepth0 = 0.01 - -[algorithms] - current_filters = 4 - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 14.0 - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e7 - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 1e7 - -[setup] - vmax = 0.1 - nx1 = 4 - nx2 = 2 - -[output] - format = "hdf5" - interval_time = 0.0025 - - [output.fields] - quantities = ["Rho", "E"] - -[diagnostics] - colored_stdout = true diff --git a/setups/srpic/langmuir/pgen.hpp b/setups/srpic/langmuir/pgen.hpp deleted file mode 100644 index 2a23b17f7..000000000 --- a/setups/srpic/langmuir/pgen.hpp +++ /dev/null @@ -1,124 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" -#include "utils/numeric.h" - -#include "archetypes/energy_dist.h" -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "framework/domain/domain.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct SinEDist : public arch::EnergyDistribution { - SinEDist(const M& metric, - real_t v_max, - const std::vector& n, - const std::vector& s) - : arch::EnergyDistribution { metric } - , v_max { v_max } - , kx1 { s.size() > 0 ? static_cast(constant::TWO_PI) * - static_cast(n[0]) / s[0] - : ZERO } - , kx2 { s.size() > 1 ? 
static_cast(constant::TWO_PI) * - static_cast(n[1]) / s[1] - : ZERO } - , kx3 { s.size() > 2 ? static_cast(constant::TWO_PI) * - static_cast(n[2]) / s[2] - : ZERO } {} - - Inline void operator()(const coord_t& x_Ph, - vec_t& v, - unsigned short sp) const override { - if (sp == 1) { - const auto k = math::sqrt(SQR(kx1) + SQR(kx2) + SQR(kx3)); - if constexpr (M::Dim == Dim::_1D) { - v[0] = v_max * math::sin(kx1 * x_Ph[0]); - } else if constexpr (M::Dim == Dim::_2D) { - v[0] = v_max * kx1 / k * math::sin(kx1 * x_Ph[0] + kx2 * x_Ph[1]); - v[1] = v_max * kx2 / k * math::sin(kx1 * x_Ph[0] + kx2 * x_Ph[1]); - } else { - v[0] = v_max * kx1 / k * - math::sin(kx1 * x_Ph[0] + kx2 * x_Ph[1] + kx3 * x_Ph[2]); - v[1] = v_max * kx2 / k * - math::sin(kx1 * x_Ph[0] + kx2 * x_Ph[1] + kx3 * x_Ph[2]); - v[2] = v_max * kx3 / k * - math::sin(kx1 * x_Ph[0] + kx2 * x_Ph[1] + kx3 * x_Ph[2]); - } - } else { - v[0] = ZERO; - v[1] = ZERO; - v[2] = ZERO; - } - } - - private: - const real_t v_max, kx1, kx2, kx3; - }; - - template - struct PGen : public arch::ProblemGenerator { - - // compatibility traits for the problem generator - static constexpr auto engines = traits::compatible_with::value; - static constexpr auto metrics = traits::compatible_with::value; - static constexpr auto dimensions = - traits::compatible_with::value; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const real_t sx1, sx2, sx3; - const real_t vmax; - const int nx1, nx2, nx3; - - std::vector svec; - std::vector nvec; - - inline PGen(const SimulationParams& p, const Metadomain& global_domain) - : arch::ProblemGenerator { p } - , sx1 { global_domain.mesh().extent(in::x1).second - - global_domain.mesh().extent(in::x1).first } - , sx2 { global_domain.mesh().extent(in::x2).second - - global_domain.mesh().extent(in::x2).first } - , sx3 { global_domain.mesh().extent(in::x3).second - - 
global_domain.mesh().extent(in::x3).first } - , vmax { p.get("setup.vmax", 0.01) } - , nx1 { p.get("setup.nx1", 10) } - , nx2 { p.get("setup.nx2", 10) } - , nx3 { p.get("setup.nx3", 10) } { - const auto sxs = std::vector { sx1, sx2, sx3 }; - const auto nxs = std::vector { nx1, nx2, nx3 }; - for (auto d = 0u; d < M::Dim; ++d) { - svec.push_back(sxs[d]); - nvec.push_back(nxs[d]); - } - } - - inline void InitPrtls(Domain& local_domain) { - const auto energy_dist = SinEDist(local_domain.mesh.metric, - vmax, - nvec, - svec); - const auto injector = arch::UniformInjector(energy_dist, - { 1, 2 }); - arch::InjectUniform>(params, - local_domain, - injector, - 1.0); - } - }; - -} // namespace user - -#endif diff --git a/setups/srpic/magnetar/magnetar.py b/setups/srpic/magnetar/magnetar.py deleted file mode 100644 index 0bbf790e5..000000000 --- a/setups/srpic/magnetar/magnetar.py +++ /dev/null @@ -1,9 +0,0 @@ -import nt2.read as nt2r -import matplotlib as mpl - -data = nt2r.Data("magnetar.h5") - -def plot (ti, data): - (data.Bph*(data.r*np.sin(data.th))).isel(t=ti).polar.pcolor( - norm=mpl.colors.Normalize(vmin=-0.075, vmax=0.075), - cmap="PuOr") \ No newline at end of file diff --git a/setups/srpic/magnetar/magnetar.toml b/setups/srpic/magnetar/magnetar.toml deleted file mode 100644 index 2a2260af5..000000000 --- a/setups/srpic/magnetar/magnetar.toml +++ /dev/null @@ -1,108 +0,0 @@ -[simulation] - name = "magnetar" - engine = "srpic" - runtime = 50.0 - -[grid] - resolution = [2048,1024] - extent = [[1.0, 400.0]] - - [grid.metric] - metric = "qspherical" - - [grid.boundaries] - fields = [["ATMOSPHERE", "ABSORB"]] - particles = [["ATMOSPHERE", "ABSORB"]] - - [grid.boundaries.absorb] - ds = 1.0 - - [grid.boundaries.atmosphere] - temperature = 0.1 - density = 40.0 - height = 0.02 - species = [1, 2] - ds = 0.5 - -[scales] - larmor0 = 1e-5 - skindepth0 = 0.01 - -[algorithms] - current_filters = 4 - - [algorithms.timestep] - CFL = 0.5 - - [algorithms.gca] - e_ovr_b_max = 0.9 - 
larmor_max = 100.0 - -[particles] - ppc0 = 4.0 - use_weights = true - sort_interval = 100 - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 5e7 - pusher = "Boris,GCA" - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 5e7 - pusher = "Boris,GCA" - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 5e7 - pusher = "Boris,GCA" - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 5e7 - pusher = "Boris,GCA" - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 5e7 - pusher = "Boris,GCA" - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 5e7 - pusher = "Boris,GCA" - -[setup] - Bsurf = 1.0 - omega = 0.0125 - pp_thres = 10.0 - gamma_pairs = 1.75 - -[output] - format = "hdf5" - - [output.fields] - interval_time = 0.5 - quantities = ["N_1", "N_2", "N_3", "N_4", "N_5", "N_6", "B", "E", "J"] - - [output.particles] - enable = false - - [output.spectra] - enable = false - -[diagnostics] - interval = 1 diff --git a/setups/srpic/magnetar/pgen.hpp b/setups/srpic/magnetar/pgen.hpp deleted file mode 100644 index cacbb7c9a..000000000 --- a/setups/srpic/magnetar/pgen.hpp +++ /dev/null @@ -1,280 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" -#include "utils/numeric.h" - -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct InitFields { - InitFields(real_t bsurf, real_t rstar) : Bsurf { bsurf }, Rstar { rstar } {} - - Inline auto bx1(const coord_t& x_Ph) const -> real_t { - return Bsurf * math::cos(x_Ph[1]) / CUBE(x_Ph[0] / Rstar); - } - - Inline auto bx2(const coord_t& x_Ph) const -> real_t { - return Bsurf * HALF * math::sin(x_Ph[1]) / CUBE(x_Ph[0] 
/ Rstar); - } - - private: - const real_t Bsurf, Rstar; - }; - - template - struct DriveFields : public InitFields { - DriveFields(real_t time, real_t bsurf, real_t rstar, real_t omega) - : InitFields { bsurf, rstar } - , time { time } - , Omega { omega } {} - - using InitFields::bx1; - using InitFields::bx2; - - Inline auto bx3(const coord_t&) const -> real_t { - return ZERO; - } - - Inline auto ex1(const coord_t& x_Ph) const -> real_t { - auto sigma = (x_Ph[1] - HALF * constant::PI) / - (static_cast(0.2) * constant::PI); - return Omega * bx2(x_Ph) * x_Ph[0] * math::sin(x_Ph[1]) * sigma * - math::exp((ONE - SQR(SQR(sigma))) * INV_4); - } - - Inline auto ex2(const coord_t& x_Ph) const -> real_t { - auto sigma = (x_Ph[1] - HALF * constant::PI) / - (static_cast(0.2) * constant::PI); - return -Omega * bx1(x_Ph) * x_Ph[0] * math::sin(x_Ph[1]) * sigma * - math::exp((ONE - SQR(SQR(sigma))) * INV_4); - } - - Inline auto ex3(const coord_t&) const -> real_t { - return ZERO; - } - - private: - const real_t time, Omega; - }; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines { traits::compatible_with::value }; - static constexpr auto metrics { - traits::compatible_with::value - }; - static constexpr auto dimensions { traits::compatible_with::value }; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const Metadomain& global_domain; - - const real_t Bsurf, Rstar, Omega, gamma_pairs, pp_thres; - InitFields init_flds; - - inline PGen(const SimulationParams& p, const Metadomain& m) - : arch::ProblemGenerator(p) - , global_domain { m } - , Bsurf { p.template get("setup.Bsurf", ONE) } - , Rstar { m.mesh().extent(in::x1).first } - , Omega { p.template get("setup.omega") } - , pp_thres { p.template get("setup.pp_thres") } - , gamma_pairs { p.template get("setup.gamma_pairs") 
} - , init_flds { Bsurf, Rstar } { - } - - inline PGen() {} - - auto FieldDriver(real_t time) const -> DriveFields { - const real_t omega_t = - Omega * - ((ONE - math::tanh((static_cast(5.0) - time) * HALF)) * - (ONE + (-ONE + math::tanh((static_cast(45.0) - time) * HALF)) * - HALF)) * - HALF; - return DriveFields { time, Bsurf, Rstar, omega_t }; - } - - void CustomPostStep(std::size_t , long double, Domain& domain) { - - // Ad-hoc PP kernel - { - - auto& species2_e = domain.species[2]; - auto& species2_p = domain.species[3]; - auto& species3_e = domain.species[4]; - auto& species3_p = domain.species[5]; - auto metric = domain.mesh.metric; - auto pp_thres_ = this->pp_thres; - auto gamma_pairs_ = this->gamma_pairs; - - for (std::size_t s { 0 }; s < 6; ++s) { - if (s == 1) { - continue; - } - - array_t elec_ind("elec_ind"); - array_t pos_ind("pos_ind"); - - auto offset_e = species3_e.npart(); - auto offset_p = species3_p.npart(); - - auto ux1_e = species3_e.ux1; - auto ux2_e = species3_e.ux2; - auto ux3_e = species3_e.ux3; - auto i1_e = species3_e.i1; - auto i2_e = species3_e.i2; - auto dx1_e = species3_e.dx1; - auto dx2_e = species3_e.dx2; - auto phi_e = species3_e.phi; - auto weight_e = species3_e.weight; - auto tag_e = species3_e.tag; - - auto ux1_p = species3_p.ux1; - auto ux2_p = species3_p.ux2; - auto ux3_p = species3_p.ux3; - auto i1_p = species3_p.i1; - auto i2_p = species3_p.i2; - auto dx1_p = species3_p.dx1; - auto dx2_p = species3_p.dx2; - auto phi_p = species3_p.phi; - auto weight_p = species3_p.weight; - auto tag_p = species3_p.tag; - - if (s == 0) { - - offset_e = species2_e.npart(); - offset_p = species2_p.npart(); - - ux1_e = species2_e.ux1; - ux2_e = species2_e.ux2; - ux3_e = species2_e.ux3; - i1_e = species2_e.i1; - i2_e = species2_e.i2; - dx1_e = species2_e.dx1; - dx2_e = species2_e.dx2; - phi_e = species2_e.phi; - weight_e = species2_e.weight; - tag_e = species2_e.tag; - - ux1_p = species2_p.ux1; - ux2_p = species2_p.ux2; - ux3_p = species2_p.ux3; 
- i1_p = species2_p.i1; - i2_p = species2_p.i2; - dx1_p = species2_p.dx1; - dx2_p = species2_p.dx2; - phi_p = species2_p.phi; - weight_p = species2_p.weight; - tag_p = species2_p.tag; - - } - - auto& species = domain.species[s]; - auto ux1 = species.ux1; - auto ux2 = species.ux2; - auto ux3 = species.ux3; - auto i1 = species.i1; - auto i2 = species.i2; - auto dx1 = species.dx1; - auto dx2 = species.dx2; - auto phi = species.phi; - auto weight = species.weight; - auto tag = species.tag; - - Kokkos::parallel_for( - "InjectPairs", species.rangeActiveParticles(), Lambda(index_t p) { - if (tag(p) == ParticleTag::dead) { - return; - } - - auto px = ux1(p); - auto py = ux2(p); - auto pz = ux3(p); - auto gamma = math::sqrt(ONE + SQR(px) + SQR(py) + SQR(pz)); - - const coord_t xCd{ - static_cast(i1(p)) + dx1(p), - static_cast(i2(p)) + dx2(p)}; - - coord_t xPh { ZERO }; - metric.template convert(xCd, xPh); - - if ((gamma > pp_thres_) && (math::sin(xPh[1]) > 0.1)) { - - auto new_gamma = gamma - 2.0 * gamma_pairs_; - auto new_fac = math::sqrt(SQR(new_gamma) - 1.0) / math::sqrt(SQR(gamma) - 1.0); - auto pair_fac = math::sqrt(SQR(gamma_pairs_) - 1.0) / math::sqrt(SQR(gamma) - 1.0); - - auto elec_p = Kokkos::atomic_fetch_add(&elec_ind(), 1); - auto pos_p = Kokkos::atomic_fetch_add(&pos_ind(), 1); - - i1_e(elec_p + offset_e) = i1(p); - dx1_e(elec_p + offset_e) = dx1(p); - i2_e(elec_p + offset_e) = i2(p); - dx2_e(elec_p + offset_e) = dx2(p); - phi_e(elec_p + offset_e) = phi(p); - ux1_e(elec_p + offset_e) = px * pair_fac; - ux2_e(elec_p + offset_e) = py * pair_fac; - ux3_e(elec_p + offset_e) = pz * pair_fac; - weight_e(elec_p + offset_e) = weight(p); - tag_e(elec_p + offset_e) = ParticleTag::alive; - - i1_p(pos_p + offset_p) = i1(p); - dx1_p(pos_p + offset_p) = dx1(p); - i2_p(pos_p + offset_p) = i2(p); - dx2_p(pos_p + offset_p) = dx2(p); - phi_p(pos_p + offset_p) = phi(p); - ux1_p(pos_p + offset_p) = px * pair_fac; - ux2_p(pos_p + offset_p) = py * pair_fac; - ux3_p(pos_p + offset_p) 
= pz * pair_fac; - weight_p(pos_p + offset_p) = weight(p); - tag_p(pos_p + offset_p) = ParticleTag::alive; - - ux1(p) *= new_fac; - ux2(p) *= new_fac; - ux3(p) *= new_fac; - } - - }); - - auto elec_ind_h = Kokkos::create_mirror(elec_ind); - Kokkos::deep_copy(elec_ind_h, elec_ind); - if (s == 0) { - species2_e.set_npart(offset_e + elec_ind_h()); - } else { - species3_e.set_npart(offset_e + elec_ind_h()); - } - - auto pos_ind_h = Kokkos::create_mirror(pos_ind); - Kokkos::deep_copy(pos_ind_h, pos_ind); - if (s == 0) { - species2_p.set_npart(offset_p + pos_ind_h()); - } else { - species3_p.set_npart(offset_p + pos_ind_h()); - } - - - } - } // Ad-hoc PP kernel - } - - }; - -} // namespace user - -#endif diff --git a/setups/srpic/pgen_srpic_example.hpp b/setups/srpic/pgen_srpic_example.hpp deleted file mode 100644 index de3bae408..000000000 --- a/setups/srpic/pgen_srpic_example.hpp +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/traits.h" - -#include "archetypes/problem_generator.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines { traits::compatible_with::value }; - static constexpr auto metrics { - traits::compatible_with::value - }; - static constexpr auto dimensions { - traits::compatible_with::value - }; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - inline PGen(const SimulationParams& p, const Metadomain&) - : arch::ProblemGenerator(p) {} - - inline PGen() {} - }; - -} // namespace user - -#endif diff --git a/setups/srpic/shock/pgen.hpp b/setups/srpic/shock/pgen.hpp deleted file mode 100644 index f07b99878..000000000 --- a/setups/srpic/shock/pgen.hpp +++ /dev/null @@ 
-1,59 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/traits.h" - -#include "archetypes/energy_dist.h" -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines { traits::compatible_with::value }; - static constexpr auto metrics { traits::compatible_with::value }; - static constexpr auto dimensions { - traits::compatible_with::value - }; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const real_t drift_ux, temperature; - - inline PGen(const SimulationParams& p, const Metadomain& m) - : arch::ProblemGenerator(p) - , drift_ux { p.template get("setup.drift_ux") } - , temperature { p.template get("setup.temperature") } {} - - inline PGen() {} - - inline void InitPrtls(Domain& local_domain) { - const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - temperature, - -drift_ux, - in::x1); - const auto injector = arch::UniformInjector( - energy_dist, - { 1, 2 }); - arch::InjectUniform>( - params, - local_domain, - injector, - 1.0); - } - }; - -} // namespace user - -#endif diff --git a/setups/srpic/shock/shock.toml b/setups/srpic/shock/shock.toml deleted file mode 100644 index f48edb2d6..000000000 --- a/setups/srpic/shock/shock.toml +++ /dev/null @@ -1,51 +0,0 @@ -[simulation] - name = "shock" - engine = "srpic" - runtime = 50.0 - -[grid] - resolution = [2048, 128] - extent = [[0.0, 10.0], [-0.3125, 0.3125]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["CONDUCTOR", "ABSORB"], ["PERIODIC"]] - particles = [["REFLECT", "ABSORB"], ["PERIODIC"]] - -[scales] 
- larmor0 = 1e-2 - skindepth0 = 1e-2 - -[algorithms] - current_filters = 8 - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 16.0 - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e8 - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 1e8 - -[setup] - drift_ux = 0.1 - temperature = 1e-3 - -[output] - interval_time = 0.1 - format = "hdf5" - - [output.fields] - quantities = ["N_1", "N_2", "E", "B", "T0i_1", "T0i_2", "J"] diff --git a/setups/srpic/turbulence/pgen.hpp b/setups/srpic/turbulence/pgen.hpp deleted file mode 100644 index bd8b0ce41..000000000 --- a/setups/srpic/turbulence/pgen.hpp +++ /dev/null @@ -1,351 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" -#include "utils/numeric.h" - -#include "archetypes/energy_dist.h" -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "framework/domain/metadomain.h" - -#include -#include - -enum { - REAL = 0, - IMAG = 1 -}; - -namespace user { - using namespace ntt; - - template - struct ExtForce { - ExtForce(array_t amplitudes, real_t SX1, real_t SX2, real_t SX3) - : amps { amplitudes } - , sx1 { SX1 } - , sx2 { SX2 } - , sx3 { SX3 } {} - - const std::vector species { 1, 2 }; - - ExtForce() = default; - - Inline auto fx1(const unsigned short&, - const real_t&, - const coord_t& x_Ph) const -> real_t { - real_t k01 = ONE * constant::TWO_PI / sx1; - real_t k02 = ZERO * constant::TWO_PI / sx2; - real_t k03 = ZERO * constant::TWO_PI / sx3; - real_t k04 = ONE; - real_t k11 = ZERO * constant::TWO_PI / sx1; - real_t k12 = ONE * constant::TWO_PI / sx2; - real_t k13 = ZERO * constant::TWO_PI / sx3; - real_t k14 = ONE; - real_t k21 = ZERO * constant::TWO_PI / sx1; - real_t k22 = ZERO * constant::TWO_PI / sx2; - real_t k23 = ONE * constant::TWO_PI / sx3; - real_t k24 = ONE; - - // return 0.1 * 
cos(2.0 * constant::TWO_PI * x_Ph[1]); - - return (k14 * amps(0, REAL) * - math::cos(k11 * x_Ph[0] + k12 * x_Ph[1] + k13 * x_Ph[2]) + - k14 * amps(0, IMAG) * - math::sin(k11 * x_Ph[0] + k12 * x_Ph[1] + k13 * x_Ph[2])) + - (k24 * amps(1, REAL) * - math::cos(k21 * x_Ph[0] + k22 * x_Ph[1] + k23 * x_Ph[2]) + - k24 * amps(1, IMAG) * - math::sin(k21 * x_Ph[0] + k22 * x_Ph[1] + k23 * x_Ph[2])); - } - - Inline auto fx2(const unsigned short&, - const real_t&, - const coord_t& x_Ph) const -> real_t { - real_t k01 = ONE * constant::TWO_PI / sx1; - real_t k02 = ZERO * constant::TWO_PI / sx2; - real_t k03 = ZERO * constant::TWO_PI / sx3; - real_t k04 = ONE; - real_t k11 = ZERO * constant::TWO_PI / sx1; - real_t k12 = ONE * constant::TWO_PI / sx2; - real_t k13 = ZERO * constant::TWO_PI / sx3; - real_t k14 = ONE; - real_t k21 = ZERO * constant::TWO_PI / sx1; - real_t k22 = ZERO * constant::TWO_PI / sx2; - real_t k23 = ONE * constant::TWO_PI / sx3; - real_t k24 = ONE; - return (k04 * amps(2, REAL) * - math::cos(k01 * x_Ph[0] + k02 * x_Ph[1] + k03 * x_Ph[2]) + - k04 * amps(2, IMAG) * - math::sin(k01 * x_Ph[0] + k02 * x_Ph[1] + k03 * x_Ph[2])) + - (k24 * amps(3, REAL) * - math::cos(k21 * x_Ph[0] + k22 * x_Ph[1] + k23 * x_Ph[2]) + - k24 * amps(3, IMAG) * - math::sin(k21 * x_Ph[0] + k22 * x_Ph[1] + k23 * x_Ph[2])); - // return ZERO; - } - - Inline auto fx3(const unsigned short&, - const real_t&, - const coord_t& x_Ph) const -> real_t { - real_t k01 = ONE * constant::TWO_PI / sx1; - real_t k02 = ZERO * constant::TWO_PI / sx2; - real_t k03 = ZERO * constant::TWO_PI / sx3; - real_t k04 = ONE; - real_t k11 = ZERO * constant::TWO_PI / sx1; - real_t k12 = ONE * constant::TWO_PI / sx2; - real_t k13 = ZERO * constant::TWO_PI / sx3; - real_t k14 = ONE; - real_t k21 = ZERO * constant::TWO_PI / sx1; - real_t k22 = ZERO * constant::TWO_PI / sx2; - real_t k23 = ONE * constant::TWO_PI / sx3; - real_t k24 = ONE; - return (k04 * amps(4, REAL) * - math::cos(k01 * x_Ph[0] + k02 * x_Ph[1] + k03 * 
x_Ph[2]) + - k04 * amps(4, IMAG) * - math::sin(k01 * x_Ph[0] + k02 * x_Ph[1] + k03 * x_Ph[2])) + - (k14 * amps(5, REAL) * - math::cos(k11 * x_Ph[0] + k12 * x_Ph[1] + k13 * x_Ph[2]) + - k14 * amps(5, IMAG) * - math::sin(k11 * x_Ph[0] + k12 * x_Ph[1] + k13 * x_Ph[2])); - // return ZERO; - } - - private: - array_t amps; - const real_t sx1, sx2, sx3; - }; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines = traits::compatible_with::value; - static constexpr auto metrics = traits::compatible_with::value; - static constexpr auto dimensions = traits::compatible_with::value; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const real_t SX1, SX2, SX3; - const real_t temperature, machno; - const unsigned int nmodes; - const real_t amp0, phi0; - array_t amplitudes; - ExtForce ext_force; - const real_t dt; - - inline PGen(const SimulationParams& params, const Metadomain& global_domain) - : arch::ProblemGenerator { params } - , SX1 { global_domain.mesh().extent(in::x1).second - - global_domain.mesh().extent(in::x1).first } - , SX2 { global_domain.mesh().extent(in::x2).second - - global_domain.mesh().extent(in::x2).first } - , SX3 { global_domain.mesh().extent(in::x3).second - - global_domain.mesh().extent(in::x3).first } - // , SX1 { 2.0 } - // , SX2 { 2.0 } - // , SX3 { 2.0 } - , temperature { params.template get("problem.temperature", 0.1) } - , machno { params.template get("problem.machno", 0.1) } - , nmodes { params.template get("setup.nmodes", 6) } - , amp0 { machno * temperature / static_cast(nmodes) } - , phi0 { INV_4 } // !TODO: randomize - , amplitudes { "DrivingModes", nmodes } - , ext_force { amplitudes, SX1, SX2, SX3 } - , dt { params.template get("algorithms.timestep.dt") } { - Init(); - } - - void Init() { - // initializing amplitudes - auto amplitudes_ 
= amplitudes; - const auto amp0_ = amp0; - const auto phi0_ = phi0; - Kokkos::parallel_for( - "RandomAmplitudes", - amplitudes.extent(0), - Lambda(index_t i) { - amplitudes_(i, REAL) = amp0_ * math::cos(phi0_); - amplitudes_(i, IMAG) = amp0_ * math::sin(phi0_); - }); - } - - inline void InitPrtls(Domain& local_domain) { - { - const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - temperature); - const auto injector = arch::UniformInjector( - energy_dist, - { 1, 2 }); - const real_t ndens = 1.0; - arch::InjectUniform(params, - local_domain, - injector, - ndens); - } - - { - // const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, - // local_domain.random_pool, - // temperature*10); - // // const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, - // // local_domain.random_pool, - // // temperature * 2, - // // 10.0, - // // 1); - // const auto injector = arch::UniformInjector( - // energy_dist, - // { 1, 2 }); - // const real_t ndens = 0.01; - // arch::InjectUniform(params, - // local_domain, - // injector, - // ndens); - } - } - - void CustomPostStep(std::size_t time, long double, Domain& domain) { - auto omega0 = 0.6 * math::sqrt(temperature * machno * constant::TWO_PI / SX1); - auto gamma0 = 0.5 * math::sqrt(temperature * machno * constant::TWO_PI / SX2); - auto sigma0 = amp0 * math::sqrt(static_cast(nmodes) * gamma0); - auto pool = domain.random_pool; - Kokkos::parallel_for( - "RandomAmplitudes", - amplitudes.extent(0), - ClassLambda(index_t i) { - auto rand_gen = pool.get_state(); - const auto unr = Random(rand_gen) - HALF; - const auto uni = Random(rand_gen) - HALF; - pool.free_state(rand_gen); - const auto ampr_prev = amplitudes(i, REAL); - const auto ampi_prev = amplitudes(i, IMAG); - amplitudes(i, REAL) = (ampr_prev * math::cos(omega0 * dt) + - ampi_prev * math::sin(omega0 * dt)) * - math::exp(-gamma0 * dt) + - unr * sigma0; - amplitudes(i, IMAG) = (-ampr_prev * math::sin(omega0 * dt) + - 
ampi_prev * math::cos(omega0 * dt)) * - math::exp(-gamma0 * dt) + - uni * sigma0; - }); - - // auto fext_en_total = ZERO; - // for (auto& species : domain.species) { - // auto pld = species.pld[0]; - // auto weight = species.weight; - // Kokkos::parallel_reduce( - // "ExtForceEnrg", - // species.rangeActiveParticles(), - // ClassLambda(index_t p, real_t & fext_en) { - // fext_en += pld(p) * weight(p); - // }, - // fext_en_total); - // } - - // auto pkin_en_total = ZERO; - // for (auto& species : domain.species) { - // auto ux1 = species.ux1; - // auto ux2 = species.ux2; - // auto ux3 = species.ux3; - // auto weight = species.weight; - // Kokkos::parallel_reduce( - // "KinEnrg", - // species.rangeActiveParticles(), - // ClassLambda(index_t p, real_t & pkin_en) { - // pkin_en += (math::sqrt(ONE + SQR(ux1(p)) + SQR(ux2(p)) + SQR(ux3(p))) - - // ONE) * - // weight(p); - // }, - // pkin_en_total); - // } - // // Weight the macroparticle integral by sim parameters - // pkin_en_total /= params.template get("scales.n0"); - - // std::ofstream myfile; - // if (time == 0) { - // myfile.open("fextenrg.txt"); - // } else { - // myfile.open("fextenrg.txt", std::ios_base::app); - // } - // myfile << fext_en_total << std::endl; - // myfile.close(); - - // if (time == 0) { - // myfile.open("kenrg.txt"); - // } else { - // myfile.open("kenrg.txt", std::ios_base::app); - // } - // myfile << pkin_en_total << std::endl; - // myfile.close(); - - // if constexpr (D == Dim::_3D) { - - // auto metric = domain.mesh.metric; - - // auto benrg_total = ZERO; - // auto EB = domain.fields.em; - // Kokkos::parallel_reduce( - // "BEnrg", - // domain.mesh.rangeActiveCells(), - // Lambda(index_t i1, index_t i2, index_t i3, real_t & benrg) { - // coord_t x_Cd { ZERO }; - // vec_t b_Cntrv { EB(i1, i2, i3, em::bx1), - // EB(i1, i2, i3, em::bx2), - // EB(i1, i2, i3, em::bx3) }; - // vec_t b_XYZ; - // metric.template transform(x_Cd, - // b_Cntrv, - // b_XYZ); - // benrg += (SQR(b_XYZ[0]) + SQR(b_XYZ[1]) + 
SQR(b_XYZ[2])); - // }, - // benrg_total); - // benrg_total *= params.template get("scales.sigma0") * HALF; - - // if (time == 0) { - // myfile.open("bsqenrg.txt"); - // } else { - // myfile.open("bsqenrg.txt", std::ios_base::app); - // } - // myfile << benrg_total << std::endl; - // myfile.close(); - // auto eenrg_total = ZERO; - // Kokkos::parallel_reduce( - // "BEnrg", - // domain.mesh.rangeActiveCells(), - // Lambda(index_t i1, index_t i2, index_t i3, real_t & eenrg) { - // coord_t x_Cd { ZERO }; - // vec_t e_Cntrv { EB(i1, i2, i3, em::ex1), - // EB(i1, i2, i3, em::ex2), - // EB(i1, i2, i3, em::ex3) }; - // vec_t e_XYZ; - // metric.template transform(x_Cd, - // e_Cntrv, - // e_XYZ); - // eenrg += (SQR(e_XYZ[0]) + SQR(e_XYZ[1]) + SQR(e_XYZ[2])); - // }, - // eenrg_total); - // eenrg_total *= params.template get("scales.sigma0") * HALF; - - - // if (time == 0) { - // myfile.open("esqenrg.txt"); - // } else { - // myfile.open("esqenrg.txt", std::ios_base::app); - // } - // myfile << eenrg_total << std::endl; - // myfile.close(); - // } - } - }; - -} // namespace user - -#endif \ No newline at end of file diff --git a/setups/srpic/turbulence/turbulence.toml b/setups/srpic/turbulence/turbulence.toml deleted file mode 100644 index a28afde15..000000000 --- a/setups/srpic/turbulence/turbulence.toml +++ /dev/null @@ -1,49 +0,0 @@ -[simulation] - name = "turbulence" - engine = "srpic" - runtime = 20.0 - -[grid] - resolution = [184, 184, 184] - extent = [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["PERIODIC"], ["PERIODIC"], ["PERIODIC"]] - particles = [["PERIODIC"], ["PERIODIC"], ["PERIODIC"]] - -[scales] - larmor0 = 0.02 - skindepth0 = 0.02 - -[algorithms] - current_filters = 4 - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 32.0 - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e8 - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - 
maxnpart = 1e8 - -[setup] - -[output] - format = "hdf5" - interval_time = 0.1 - - [output.fields] - quantities = ["N_1", "N_2", "E", "B", "J", "T00_1", "T00_2"] diff --git a/setups/srpic/weibel/pgen.hpp b/setups/srpic/weibel/pgen.hpp deleted file mode 100644 index 21acc8032..000000000 --- a/setups/srpic/weibel/pgen.hpp +++ /dev/null @@ -1,75 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" - -#include "archetypes/energy_dist.h" -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "framework/domain/domain.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct PGen : public arch::ProblemGenerator { - - // compatibility traits for the problem generator - static constexpr auto engines = traits::compatible_with::value; - static constexpr auto metrics = traits::compatible_with::value; - static constexpr auto dimensions = - traits::compatible_with::value; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const real_t temp_1, temp_2; - const real_t drift_u_1, drift_u_2; - - inline PGen(const SimulationParams& p, const Metadomain& global_domain) - : arch::ProblemGenerator { p } - , temp_1 { p.template get("setup.temp_1") } - , temp_2 { p.template get("setup.temp_2") } - , drift_u_1 { p.template get("setup.drift_u_1") } - , drift_u_2 { p.template get("setup.drift_u_2") } {} - - inline void InitPrtls(Domain& local_domain) { - const auto energy_dist_1 = arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - temp_1, - -drift_u_1, - in::x3); - const auto energy_dist_2 = arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - temp_2, - drift_u_2, - in::x3); - const auto injector_1 = arch::UniformInjector( - 
energy_dist_1, - { 1, 2 }); - const auto injector_2 = arch::UniformInjector( - energy_dist_2, - { 3, 4 }); - arch::InjectUniform>( - params, - local_domain, - injector_1, - HALF); - arch::InjectUniform>( - params, - local_domain, - injector_2, - HALF); - } - }; - -} // namespace user - -#endif diff --git a/setups/srpic/weibel/weibel.toml b/setups/srpic/weibel/weibel.toml deleted file mode 100644 index c8e2506f6..000000000 --- a/setups/srpic/weibel/weibel.toml +++ /dev/null @@ -1,74 +0,0 @@ -[simulation] - name = "weibel" - engine = "srpic" - runtime = 100.0 - -[grid] - resolution = [512, 512] - extent = [[-10.0, 10.0], [-10.0, 10.0]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["PERIODIC"], ["PERIODIC"]] - particles = [["PERIODIC"], ["PERIODIC"]] - -[scales] - larmor0 = 1.0 - skindepth0 = 1.0 - -[algorithms] - current_filters = 4 - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 16.0 - - [[particles.species]] - label = "e-_p" - mass = 1.0 - charge = -1.0 - maxnpart = 1e7 - - [[particles.species]] - label = "e+_p" - mass = 1.0 - charge = 1.0 - maxnpart = 1e7 - - [[particles.species]] - label = "e-_b" - mass = 1.0 - charge = -1.0 - maxnpart = 1e7 - - [[particles.species]] - label = "e+_b" - mass = 1.0 - charge = 1.0 - maxnpart = 1e7 - -[setup] - drift_u_1 = 0.2 - drift_u_2 = 0.2 - temp_1 = 1e-4 - temp_2 = 1e-4 - -[output] - format = "hdf5" - interval_time = 0.25 - - [output.fields] - quantities = ["N_1_2", "N_3_4", "B", "E", "T0i_1", "T0i_3"] - - [output.particles] - enable = false - - [output.spectra] - enable = false - -[diagnostics] - colored_stdout = true diff --git a/setups/tests/blob/blob.toml b/setups/tests/blob/blob.toml deleted file mode 100644 index fffa5fff1..000000000 --- a/setups/tests/blob/blob.toml +++ /dev/null @@ -1,66 +0,0 @@ -[simulation] - name = "blob-1x1x2" - engine = "srpic" - runtime = 5.0 - - [simulation.domain] - decomposition = [1, 1, 2] - -[grid] - resolution = [128, 192, 64] - # extent = [[1.0, 
10.0]] - extent = [[-2.0, 2.0], [-3.0, 3.0], [-1.0, 1.0]] - - [grid.metric] - # metric = "qspherical" - metric = "minkowski" - - [grid.boundaries] - # fields = [["ATMOSPHERE", "ABSORB"]] - # particles = [["ATMOSPHERE", "ABSORB"]] - fields = [["PERIODIC"], ["PERIODIC"], ["PERIODIC"]] - particles = [["PERIODIC"], ["PERIODIC"], ["PERIODIC"]] - - # [grid.boundaries.absorb] - # ds = 1.0 - -[scales] - larmor0 = 2e-5 - skindepth0 = 0.01 - -[algorithms] - current_filters = 4 - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 20.0 - # use_weights = true - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e7 - pusher = "Boris" - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 1e7 - pusher = "Boris" - -[setup] - xi_min = [0.55, 1.85, -0.25] - xi_max = [0.65, 2.3, -0.1] - v1 = [0.25, -0.55, 0.0] - v2 = [-0.75, -0.15, 0.0] - -[output] - format = "hdf5" - interval_time = 0.02 - - [output.fields] - quantities = ["Nppc_1", "Nppc_2", "E", "B", "J"] diff --git a/setups/tests/blob/pgen.hpp b/setups/tests/blob/pgen.hpp deleted file mode 100644 index d07240bfd..000000000 --- a/setups/tests/blob/pgen.hpp +++ /dev/null @@ -1,121 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" - -#include "archetypes/energy_dist.h" -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "archetypes/spatial_dist.h" -#include "framework/domain/metadomain.h" - -#include - -namespace user { - using namespace ntt; - - template - struct Beam : public arch::EnergyDistribution { - Beam(const M& metric, - const std::vector& v1_vec, - const std::vector& v2_vec) - : arch::EnergyDistribution { metric } { - std::copy(v1_vec.begin(), v1_vec.end(), v1); - std::copy(v2_vec.begin(), v2_vec.end(), v2); - } - - Inline void operator()(const coord_t&, - vec_t& v_Ph, - unsigned short sp) const 
override { - if (sp == 1) { - v_Ph[0] = v1[0]; - v_Ph[1] = v1[1]; - v_Ph[2] = v1[2]; - } else { - v_Ph[0] = v2[0]; - v_Ph[1] = v2[1]; - v_Ph[2] = v2[2]; - } - } - - private: - vec_t v1; - vec_t v2; - }; - - template - struct PointDistribution : public arch::SpatialDistribution { - PointDistribution(const M& metric, - const std::vector& xi_min, - const std::vector& xi_max) - : arch::SpatialDistribution { metric } { - std::copy(xi_min.begin(), xi_min.end(), x_min); - std::copy(xi_max.begin(), xi_max.end(), x_max); - } - - Inline auto operator()(const coord_t& x_Ph) const -> real_t override { - auto fill = true; - for (auto d = 0u; d < M::Dim; ++d) { - fill &= x_Ph[d] > x_min[d] and x_Ph[d] < x_max[d]; - } - return fill ? ONE : ZERO; - } - - private: - tuple_t x_min; - tuple_t x_max; - }; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines { traits::compatible_with::value }; - static constexpr auto metrics { - traits::compatible_with::value - }; - static constexpr auto dimensions { - traits::compatible_with::value - }; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const std::vector xi_min; - const std::vector xi_max; - const std::vector v1; - const std::vector v2; - - inline PGen(const SimulationParams& p, const Metadomain& m) - : arch::ProblemGenerator(p) - , xi_min { p.template get>("setup.xi_min") } - , xi_max { p.template get>("setup.xi_max") } - , v1 { p.template get>("setup.v1") } - , v2 { p.template get>("setup.v2") } {} - - inline void InitPrtls(Domain& domain) { - const auto energy_dist = Beam(domain.mesh.metric, v1, v2); - const auto spatial_dist = PointDistribution(domain.mesh.metric, - xi_min, - xi_max); - const auto injector = arch::NonUniformInjector( - energy_dist, - spatial_dist, - { 1, 2 }); - - arch::InjectNonUniform>( - params, - domain, 
- injector, - 1.0); - } - }; - -} // namespace user - -#endif diff --git a/setups/tests/customout/customout.toml b/setups/tests/customout/customout.toml deleted file mode 100644 index 497b96dc2..000000000 --- a/setups/tests/customout/customout.toml +++ /dev/null @@ -1,50 +0,0 @@ -[simulation] - name = "customout" - engine = "srpic" - runtime = 10.0 - -[grid] - resolution = [256, 256] - extent = [[-1.0, 1.0], [-1.0, 1.0]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["PERIODIC"], ["PERIODIC"]] - particles = [["PERIODIC"], ["PERIODIC"]] - -[scales] - larmor0 = 0.01 - skindepth0 = 0.01 - -[algorithms] - current_filters = 4 - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 20.0 - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e7 - pusher = "Boris" - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 1e7 - pusher = "Boris" - -[output] - format = "hdf5" - interval_time = 0.02 - - [output.fields] - quantities = ["E", "B", "J"] - custom = ["mybuff", "EdotB+1"] diff --git a/setups/tests/customout/pgen.hpp b/setups/tests/customout/pgen.hpp deleted file mode 100644 index 22c8f6564..000000000 --- a/setups/tests/customout/pgen.hpp +++ /dev/null @@ -1,86 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" - -#include "archetypes/problem_generator.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines { traits::compatible_with::value }; - static constexpr auto metrics { traits::compatible_with::value }; - static constexpr auto dimensions { traits::compatible_with::value }; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using 
arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - array_t cbuff; - - inline PGen(const SimulationParams& p, const Metadomain&) - : arch::ProblemGenerator(p) {} - - inline PGen() {} - - void CustomPostStep(std::size_t step, long double, Domain& domain) { - if (step == 0) { - // allocate the array at time = 0 - cbuff = array_t("cbuff", - domain.mesh.n_all(in::x1), - domain.mesh.n_all(in::x2)); - } - // fill with zeros - Kokkos::deep_copy(cbuff, ZERO); - // populate the array atomically (here it's not strictly necessary) - auto cbuff_sc = Kokkos::Experimental::create_scatter_view(cbuff); - Kokkos::parallel_for( - "FillCbuff", - domain.mesh.rangeActiveCells(), - Lambda(index_t i1, index_t i2) { - auto cbuff_acc = cbuff_sc.access(); - cbuff_acc(i1, i2) += static_cast(i1 + i2); - }); - Kokkos::Experimental::contribute(cbuff, cbuff_sc); - } - - void CustomFieldOutput(const std::string& name, - ndfield_t buffer, - std::size_t index, - const Domain& domain) { - printf("CustomFieldOutput: %s\n", name.c_str()); - // examples for 2D - if (name == "mybuff") { - // copy the custom buffer to the buffer output - Kokkos::deep_copy(Kokkos::subview(buffer, Kokkos::ALL, Kokkos::ALL, index), - cbuff); - } else if (name == "EdotB+1") { - // calculate the custom buffer from EM fields - const auto& EM = domain.fields.em; - Kokkos::parallel_for( - "EdotB+1", - domain.mesh.rangeActiveCells(), - Lambda(index_t i1, index_t i2) { - buffer(i1, i2, index) = EM(i1, i2, em::ex1) * EM(i1, i2, em::bx1) + - EM(i1, i2, em::ex2) * EM(i1, i2, em::bx2) + - EM(i1, i2, em::ex3) * EM(i1, i2, em::bx3) + - ONE; - }); - } else { - raise::Error("Custom output not provided", HERE); - } - } - }; - -} // namespace user - -#endif diff --git a/setups/tests/deposit/deposit.toml b/setups/tests/deposit/deposit.toml deleted file mode 100644 index 04c23ce7d..000000000 --- a/setups/tests/deposit/deposit.toml +++ /dev/null @@ -1,53 +0,0 @@ -[simulation] - name = "deposit-test" - engine = "srpic" - 
runtime = 1.0 - -[grid] - resolution = [256, 256] - extent = [[0.0, 1.0], [0.0, 1.0]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["PERIODIC"], ["PERIODIC"]] - particles = [["PERIODIC"], ["PERIODIC"]] - -[scales] - larmor0 = 0.1 - skindepth0 = 0.1 - -[algorithms] - current_filters = 4 - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 10.0 - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e2 - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 1e2 - -[setup] - -[output] - format = "hdf5" - interval_time = 0.01 - - [output.quantities] - quantities = ["N_1", "N_2", "E", "B", "J"] - -[diagnostics] - colored_stdout = true - blocking_timers = true diff --git a/setups/tests/deposit/pgen.hpp b/setups/tests/deposit/pgen.hpp deleted file mode 100644 index fd9a41c2e..000000000 --- a/setups/tests/deposit/pgen.hpp +++ /dev/null @@ -1,133 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" -#include "utils/comparators.h" -#include "utils/formatting.h" -#include "utils/log.h" -#include "utils/numeric.h" - -#include "archetypes/energy_dist.h" -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "framework/domain/domain.h" -#include "framework/domain/metadomain.h" - -#include - -namespace user { - using namespace ntt; - - template - struct PGen : public arch::ProblemGenerator { - - // compatibility traits for the problem generator - static constexpr auto engines = traits::compatible_with::value; - static constexpr auto metrics = traits::compatible_with::value; - static constexpr auto dimensions = - traits::compatible_with::value; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const Metadomain& 
global_domain; - - inline PGen(const SimulationParams& p, const Metadomain& global_domain) - : arch::ProblemGenerator { p } - , global_domain { global_domain } {} - - inline void InitPrtls(Domain& local_domain) { - const auto empty = std::vector {}; - const auto x1s = params.template get>("setup.x1s", empty); - const auto y1s = params.template get>("setup.y1s", empty); - const auto z1s = params.template get>("setup.z1s", empty); - const auto ux1s = params.template get>("setup.ux1s", - empty); - const auto uy1s = params.template get>("setup.uy1s", - empty); - const auto uz1s = params.template get>("setup.uz1s", - empty); - - const auto x2s = params.template get>("setup.x2s", empty); - const auto y2s = params.template get>("setup.y2s", empty); - const auto z2s = params.template get>("setup.z2s", empty); - const auto ux2s = params.template get>("setup.ux2s", - empty); - const auto uy2s = params.template get>("setup.uy2s", - empty); - const auto uz2s = params.template get>("setup.uz2s", - empty); - // std::vector x, y, z, ux_1, uy_1, uz_1, ux_2, uy_2, uz_2; - // x.push_back(0.85); - // x.push_back(0.123); - // if constexpr (D == Dim::_2D || D == Dim::_3D) { - // y.push_back(0.32); - // y.push_back(0.321); - // } - // if constexpr (D == Dim::_3D) { - // z.push_back(0.231); - // z.push_back(0.687); - // } - // ux_1.push_back(1.0); - // uy_1.push_back(-1.0); - // uz_1.push_back(0.0); - // ux_1.push_back(1.0); - // uy_1.push_back(-2.0); - // uz_1.push_back(1.0); - // - // ux_2.push_back(1.0); - // uy_2.push_back(1.0); - // uz_2.push_back(0.0); - // ux_2.push_back(-2.0); - // uy_2.push_back(3.0); - // uz_2.push_back(-1.0); - // - const std::map> data_1 { - { "x1", x1s}, - { "x2", y1s}, - { "x3", z1s}, - {"ux1", ux1s}, - {"ux2", uy1s}, - {"ux3", uz1s} - }; - const std::map> data_2 { - { "x1", x2s}, - { "x2", y2s}, - { "x3", z2s}, - {"ux1", ux2s}, - {"ux2", uy2s}, - {"ux3", uz2s} - }; - - arch::InjectGlobally(global_domain, local_domain, (arch::spidx_t)1, data_1); - 
arch::InjectGlobally(global_domain, local_domain, (arch::spidx_t)2, data_2); - } - - // void CustomPostStep(std::size_t, long double time, Domain& domain) { - // if (time >= 0.1) { - // for (auto& species : domain.species) { - // auto ux1 = species.ux1; - // auto ux2 = species.ux2; - // auto ux3 = species.ux3; - // Kokkos::parallel_for( - // "Stop", - // species.rangeActiveParticles(), - // Lambda(index_t p) { - // ux1(p) = ZERO; - // ux2(p) = ZERO; - // ux3(p) = ZERO; - // }); - // } - // } - // } - }; - -} // namespace user - -#endif diff --git a/setups/tests/injector/injector.toml b/setups/tests/injector/injector.toml deleted file mode 100644 index 10fdaa251..000000000 --- a/setups/tests/injector/injector.toml +++ /dev/null @@ -1,62 +0,0 @@ -[simulation] - name = "injector-test" - engine = "srpic" - runtime = 2.0 - -[grid] - resolution = [512, 512] - extent = [[-1.0, 1.0], [-1.0, 1.0]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["ABSORB", "ABSORB"], ["ABSORB", "ABSORB"]] - particles = [["ABSORB", "ABSORB"], ["ABSORB", "ABSORB"]] - - [grid.boundaries.absorb] - ds = 0.15 - -[scales] - larmor0 = 0.1 - skindepth0 = 0.1 - -[algorithms] - current_filters = 4 - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 1.0 - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e6 - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 1e6 - -[setup] - period = 0.1 - vmax = 1.0 - x1c = 0.25 - x2c = -0.32 - dr = 1e-2 - rate = 0.1 - -[output] - format = "hdf5" - interval_time = 0.01 - - [output.fields] - quantities = ["N_1", "N_2", "E"] - -[diagnostics] - interval = 10 - colored_stdout = true diff --git a/setups/tests/injector/pgen.hpp b/setups/tests/injector/pgen.hpp deleted file mode 100644 index 17d7f9398..000000000 --- a/setups/tests/injector/pgen.hpp +++ /dev/null @@ -1,103 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" 
- -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" -#include "utils/numeric.h" - -#include "archetypes/energy_dist.h" -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "archetypes/spatial_dist.h" -#include "framework/domain/domain.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct Firehose : public arch::EnergyDistribution { - Firehose(const M& metric, real_t time, real_t period, real_t vmax) - : arch::EnergyDistribution { metric } - , phase { (real_t)(constant::TWO_PI)*time / period } - , vmax { vmax } {} - - Inline void operator()(const coord_t&, - vec_t& v_Ph, - unsigned short) const override { - v_Ph[0] = vmax * math::cos(phase); - v_Ph[1] = vmax * math::sin(phase); - } - - private: - const real_t phase, vmax; - }; - - template - struct PointDistribution : public arch::SpatialDistribution { - PointDistribution(const M& metric, real_t x1c, real_t x2c, real_t dr) - : arch::SpatialDistribution { metric } - , x1c { x1c } - , x2c { x2c } - , dr { dr } {} - - Inline auto operator()(const coord_t& x_Ph) const -> real_t override { - return math::exp(-(SQR(x_Ph[0] - x1c) + SQR(x_Ph[1] - x2c)) / SQR(dr)); - } - - private: - const real_t x1c, x2c, dr; - }; - - template - struct PGen : public arch::ProblemGenerator { - - // compatibility traits for the problem generator - static constexpr auto engines = traits::compatible_with::value; - static constexpr auto metrics = traits::compatible_with::value; - static constexpr auto dimensions = traits::compatible_with::value; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const real_t period, vmax, x1c, x2c, dr, rate; - - inline PGen(const SimulationParams& p, const Metadomain&) - : arch::ProblemGenerator { p } - , period { params.template get("setup.period", 1.0) } - , vmax { params.template 
get("setup.vmax", 1.0) } - , x1c { params.template get("setup.x1c", 0.0) } - , x2c { params.template get("setup.x2c", 0.0) } - , dr { params.template get("setup.dr", 0.1) } - , rate { params.template get("setup.rate", 1.0) } {} - - void CustomPostStep(std::size_t, long double time, Domain& domain) { - const auto energy_dist = Firehose(domain.mesh.metric, - (real_t)time, - period, - vmax); - const auto spatial_dist = PointDistribution(domain.mesh.metric, - x1c, - x2c, - dr); - const auto injector = arch::NonUniformInjector( - energy_dist, - spatial_dist, - { 1, 2 }); - - arch::InjectNonUniform>( - params, - domain, - injector, - rate); - } - }; - -} // namespace user - -#endif diff --git a/setups/wip/rec-gravity/pgen.hpp b/setups/wip/rec-gravity/pgen.hpp deleted file mode 100644 index a4f927113..000000000 --- a/setups/wip/rec-gravity/pgen.hpp +++ /dev/null @@ -1,211 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/directions.h" -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" -#include "utils/numeric.h" - -#include "archetypes/energy_dist.h" -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "archetypes/spatial_dist.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct Gravity { - const std::vector species { 1, 2 }; - - Gravity(real_t f, real_t tscale, real_t ymid) - : force { f } - , tscale { tscale } - , ymid { ymid } {} - - Inline auto fx1(const unsigned short&, const real_t&, const coord_t&) const - -> real_t { - return ZERO; - } - - Inline auto fx2(const unsigned short&, - const real_t& t, - const coord_t& x_Ph) const -> real_t { - const auto sign = (x_Ph[1] < ymid) ? 
ONE : -ONE; - if (t > tscale) { - return sign * force; - } else { - return sign * force * (ONE - math::cos(constant::PI * t / tscale)) / TWO; - } - } - - Inline auto fx3(const unsigned short&, const real_t&, const coord_t&) const - -> real_t { - return ZERO; - } - - private: - const real_t force, tscale, ymid; - }; - - template - struct CurrentLayer : public arch::SpatialDistribution { - CurrentLayer(const M& metric, real_t width, real_t yi) - : arch::SpatialDistribution { metric } - , width { width } - , yi { yi } {} - - Inline auto operator()(const coord_t& x_Ph) const -> real_t override { - return ONE / SQR(math::cosh((x_Ph[1] - yi) / width)); - } - - private: - const real_t yi, width; - }; - - template - struct InitFields { - InitFields(real_t Bmag, real_t width, real_t angle, real_t y1, real_t y2) - : Bmag { Bmag } - , width { width } - , angle { angle } - , y1 { y1 } - , y2 { y2 } {} - - Inline auto bx1(const coord_t& x_Ph) const -> real_t { - return Bmag * math::cos(angle) * - (math::tanh((x_Ph[1] - y1) / width) - - math::tanh((x_Ph[1] - y2) / width) - 1); - } - - Inline auto bx3(const coord_t& x_Ph) const -> real_t { - return Bmag * math::sin(angle) * - (math::tanh((x_Ph[1] - y1) / width) - - math::tanh((x_Ph[1] - y2) / width) - 1); - } - - private: - const real_t Bmag, width, angle, y1, y2; - }; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines { traits::compatible_with::value }; - static constexpr auto metrics { traits::compatible_with::value }; - static constexpr auto dimensions { - traits::compatible_with::value - }; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const real_t Bmag, width, angle, overdensity, y1, y2, bg_temp; - InitFields init_flds; - - Gravity ext_force; - - inline PGen(const SimulationParams& p, const Metadomain& m) - : 
arch::ProblemGenerator(p) - , Bmag { p.template get("setup.Bmag", 1.0) } - , width { p.template get("setup.width") } - , angle { p.template get("setup.angle") } - , overdensity { p.template get("setup.overdensity") } - , y1 { m.mesh().extent(in::x2).first + - INV_4 * - (m.mesh().extent(in::x2).second - m.mesh().extent(in::x2).first) } - , y2 { m.mesh().extent(in::x2).first + - 3 * INV_4 * - (m.mesh().extent(in::x2).second - m.mesh().extent(in::x2).first) } - , init_flds { Bmag, width, angle, y1, y2 } - , bg_temp { p.template get("setup.bg_temp") } - , ext_force { - p.template get("setup.fmag", 0.1), - (m.mesh().extent(in::x1).second - m.mesh().extent(in::x1).first), - INV_2 * (m.mesh().extent(in::x2).second + m.mesh().extent(in::x2).first) - } {} - - inline PGen() {} - - inline void InitPrtls(Domain& local_domain) { - // background - const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - bg_temp); - const auto injector = arch::UniformInjector( - energy_dist, - { 1, 2 }); - arch::InjectUniform(params, - local_domain, - injector, - HALF); - // record npart - const auto npart1 = local_domain.species[0].npart(); - const auto npart2 = local_domain.species[1].npart(); - - const auto sigma = params.template get("scales.sigma0"); - const auto c_omp = params.template get("scales.skindepth0"); - const auto cs_drift_beta = math::sqrt(sigma) * c_omp / (width * overdensity); - const auto cs_drift_gamma = ONE / math::sqrt(ONE - SQR(cs_drift_beta)); - const auto cs_drift_u = cs_drift_beta * cs_drift_gamma; - const auto cs_temp = HALF * sigma / overdensity; - // current layer #1 - auto edist_cs_1 = arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - cs_temp, - cs_drift_u, - in::x3, - false); - const auto sdist_cs_1 = CurrentLayer(local_domain.mesh.metric, width, y1); - const auto inj_cs_1 = arch::NonUniformInjector( - edist_cs_1, - sdist_cs_1, - { 1, 2 }); - arch::InjectNonUniform(params, - local_domain, - inj_cs_1, - 
overdensity); - // current layer #2 - const auto edist_cs_2 = arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - cs_temp, - -cs_drift_u, - in::x3, - false); - const auto sdist_cs_2 = CurrentLayer(local_domain.mesh.metric, width, y2); - const auto inj_cs_2 = arch::NonUniformInjector( - edist_cs_2, - sdist_cs_2, - { 1, 2 }); - arch::InjectNonUniform(params, - local_domain, - inj_cs_2, - overdensity); - auto ux1_1 = local_domain.species[0].ux1; - auto ux3_1 = local_domain.species[0].ux3; - auto ux1_2 = local_domain.species[1].ux1; - auto ux3_2 = local_domain.species[1].ux3; - Kokkos::parallel_for( - "TurnParticles", - CreateRangePolicy({ npart1 }, { local_domain.species[0].npart() }), - ClassLambda(index_t p) { - auto ux1_ = ux1_1(p), ux3_ = ux3_1(p); - ux1_1(p) = math::cos(angle) * ux1_ - math::sin(angle) * ux3_; - ux3_1(p) = math::sin(angle) * ux1_ + math::cos(angle) * ux3_; - - ux1_ = ux1_2(p), ux3_ = ux3_2(p); - ux1_2(p) = math::cos(angle) * ux1_ - math::sin(angle) * ux3_; - ux3_2(p) = math::sin(angle) * ux1_ + math::cos(angle) * ux3_; - }); - } // namespace user - }; - -} // namespace user - -#endif diff --git a/setups/wip/rec-gravity/rec-gravity.toml b/setups/wip/rec-gravity/rec-gravity.toml deleted file mode 100644 index f8d5b94ee..000000000 --- a/setups/wip/rec-gravity/rec-gravity.toml +++ /dev/null @@ -1,54 +0,0 @@ -[simulation] - name = "rec-gravity" - engine = "srpic" - runtime = 20.0 - -[grid] - resolution = [2000, 4000] - extent = [[-0.5, 0.5], [-1.0, 1.0]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["PERIODIC"], ["PERIODIC"]] - particles = [["PERIODIC"], ["PERIODIC"]] - -[scales] - larmor0 = 3.1e-4 - skindepth0 = 1e-3 - -[algorithms] - current_filters = 8 - - [algorithms.timestep] - CFL = 0.45 - -[particles] - ppc0 = 8.0 - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e8 - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 1e8 - -[setup] 
- Bmag = 1.0 - width = 0.04 - bg_temp = 1e-4 - overdensity = 3.0 - angle = 0.0 - -[output] - format = "hdf5" - interval_time = 0.36 - - [output.fields] - quantities = ["N_1", "N_2", "E", "B", "J", "T00_1", "T00_2"] diff --git a/setups/wip/reconnection/pgen.hpp b/setups/wip/reconnection/pgen.hpp deleted file mode 100644 index e97bc518a..000000000 --- a/setups/wip/reconnection/pgen.hpp +++ /dev/null @@ -1,143 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/directions.h" -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" -#include "utils/numeric.h" - -#include "archetypes/energy_dist.h" -#include "archetypes/particle_injector.h" -#include "archetypes/problem_generator.h" -#include "archetypes/spatial_dist.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct CurrentLayer : public arch::SpatialDistribution { - CurrentLayer(const M& metric, real_t width, real_t yi) - : arch::SpatialDistribution { metric } - , width { width } - , yi { yi } {} - - Inline auto operator()(const coord_t& x_Ph) const -> real_t override { - return ONE / SQR(math::cosh((x_Ph[1] - yi) / width)); - } - - private: - const real_t yi, width; - }; - - template - struct InitFields { - InitFields(real_t Bmag, real_t width, real_t y1, real_t y2) - : Bmag { Bmag } - , width { width } - , y1 { y1 } - , y2 { y2 } {} - - Inline auto bx1(const coord_t& x_Ph) const -> real_t { - return Bmag * (math::tanh((x_Ph[1] - y1) / width) - - math::tanh((x_Ph[1] - y2) / width) - 1); - } - - private: - const real_t Bmag, width, y1, y2; - }; - - template - struct PGen : public arch::ProblemGenerator { - // compatibility traits for the problem generator - static constexpr auto engines { traits::compatible_with::value }; - static constexpr auto metrics { traits::compatible_with::value }; - static constexpr auto dimensions { - traits::compatible_with::value - }; - - // for easy 
access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - const real_t Bmag, width, overdensity, y1, y2, bg_temp; - InitFields init_flds; - - inline PGen(const SimulationParams& p, const Metadomain& m) - : arch::ProblemGenerator(p) - , Bmag { p.template get("setup.Bmag", 1.0) } - , width { p.template get("setup.width") } - , overdensity { p.template get("setup.overdensity") } - , y1 { m.mesh().extent(in::x2).first + - INV_4 * - (m.mesh().extent(in::x2).second - m.mesh().extent(in::x2).first) } - , y2 { m.mesh().extent(in::x2).first + - 3 * INV_4 * - (m.mesh().extent(in::x2).second - m.mesh().extent(in::x2).first) } - , init_flds { Bmag, width, y1, y2 } - , bg_temp { p.template get("setup.bg_temp") } {} - - inline PGen() {} - - inline void InitPrtls(Domain& local_domain) { - // background - const auto energy_dist = arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - bg_temp); - const auto injector = arch::UniformInjector( - energy_dist, - { 1, 2 }); - arch::InjectUniform>( - params, - local_domain, - injector, - HALF); - - const auto sigma = params.template get("scales.sigma0"); - const auto c_omp = params.template get("scales.skindepth0"); - const auto cs_drift_beta = math::sqrt(sigma) * c_omp / (width * overdensity); - const auto cs_drift_gamma = ONE / math::sqrt(ONE - SQR(cs_drift_beta)); - const auto cs_drift_u = cs_drift_beta * cs_drift_gamma; - const auto cs_temp = HALF * sigma / overdensity; - // current layer #1 - auto edist_cs_1 = arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - cs_temp, - cs_drift_u, - in::x3, - false); - const auto sdist_cs_1 = CurrentLayer(local_domain.mesh.metric, width, y1); - const auto inj_cs_1 = arch::NonUniformInjector( - edist_cs_1, - sdist_cs_1, - { 1, 2 }); - arch::InjectNonUniform(params, - local_domain, - inj_cs_1, - overdensity); - // current layer #2 - const auto edist_cs_2 = 
arch::Maxwellian(local_domain.mesh.metric, - local_domain.random_pool, - cs_temp, - -cs_drift_u, - in::x3, - false); - const auto sdist_cs_2 = CurrentLayer(local_domain.mesh.metric, width, y2); - const auto inj_cs_2 = arch::NonUniformInjector( - edist_cs_2, - sdist_cs_2, - { 1, 2 }); - arch::InjectNonUniform(params, - local_domain, - inj_cs_2, - overdensity); - } - }; - -} // namespace user - -#endif diff --git a/setups/wip/reconnection/reconnection.toml b/setups/wip/reconnection/reconnection.toml deleted file mode 100644 index fa7b049f4..000000000 --- a/setups/wip/reconnection/reconnection.toml +++ /dev/null @@ -1,53 +0,0 @@ -[simulation] - name = "reconnection" - engine = "srpic" - runtime = 10.0 - -[grid] - resolution = [1024, 2048] - extent = [[-1.0, 1.0], [-2.0, 2.0]] - - [grid.metric] - metric = "minkowski" - - [grid.boundaries] - fields = [["PERIODIC"], ["PERIODIC"]] - particles = [["PERIODIC"], ["PERIODIC"]] - -[scales] - larmor0 = 2e-4 - skindepth0 = 2e-3 - -[algorithms] - current_filters = 4 - - [algorithms.timestep] - CFL = 0.5 - -[particles] - ppc0 = 8.0 - - [[particles.species]] - label = "e-" - mass = 1.0 - charge = -1.0 - maxnpart = 1e7 - - [[particles.species]] - label = "e+" - mass = 1.0 - charge = 1.0 - maxnpart = 1e7 - -[setup] - Bmag = 1.0 - width = 0.01 - bg_temp = 1e-4 - overdensity = 3.0 - -[output] - format = "hdf5" - interval_time = 0.1 - - [output.fields] - quantities = ["N_1", "N_2", "E", "B", "J"] diff --git a/setups/wip/spider/pgen.hpp b/setups/wip/spider/pgen.hpp deleted file mode 100644 index 27d9504b7..000000000 --- a/setups/wip/spider/pgen.hpp +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef PROBLEM_GENERATOR_H -#define PROBLEM_GENERATOR_H - -#include "enums.h" -#include "global.h" - -#include "arch/kokkos_aliases.h" -#include "arch/traits.h" - -#include "archetypes/problem_generator.h" -#include "framework/domain/metadomain.h" - -namespace user { - using namespace ntt; - - template - struct PGen : public arch::ProblemGenerator { - // 
compatibility traits for the problem generator - static constexpr auto engines { traits::compatible_with::value }; - static constexpr auto metrics { traits::compatible_with::value }; - static constexpr auto dimensions { - traits::compatible_with::value - }; - - // for easy access to variables in the child class - using arch::ProblemGenerator::D; - using arch::ProblemGenerator::C; - using arch::ProblemGenerator::params; - - inline PGen(const SimulationParams& p, const Metadomain& m) - : arch::ProblemGenerator(p) {} - - inline PGen() {} - }; - -} // namespace user - -#endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5d7f0abb4..a54a119b5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,27 +1,28 @@ +# cmake-lint: disable=C0103 # ------------------------------ # @defines: entity [STATIC/SHARED] +# # @sources: -# - entity.cpp +# +# * entity.cpp +# # @depends: -# - ntt_global [required] -# - ntt_framework [required] -# - ntt_metrics [required] -# - ntt_engine [required] +# +# * ntt_global [required] +# * ntt_framework [required] +# * ntt_metrics [required] +# * ntt_engine [required] +# * ntt_pgen [required] +# # @uses: -# - kokkos [required] -# - plog [required] -# - toml11 [required] -# - ADIOS2 [optional] -# - mpi [optional] +# +# * kokkos [required] +# * plog [required] +# * ADIOS2 [optional] +# * mpi [optional] # ------------------------------ - -set(ENTITY ${PROJECT_NAME}.xc) set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(SOURCES - ${SRC_DIR}/entity.cpp -) -add_executable(${ENTITY} entity.cpp) # dependencies add_subdirectory(${SRC_DIR}/global ${CMAKE_CURRENT_BINARY_DIR}/global) @@ -30,11 +31,16 @@ add_subdirectory(${SRC_DIR}/kernels ${CMAKE_CURRENT_BINARY_DIR}/kernels) add_subdirectory(${SRC_DIR}/archetypes ${CMAKE_CURRENT_BINARY_DIR}/archetypes) add_subdirectory(${SRC_DIR}/framework ${CMAKE_CURRENT_BINARY_DIR}/framework) add_subdirectory(${SRC_DIR}/engines ${CMAKE_CURRENT_BINARY_DIR}/engines) -if (${output} STREQUAL "ON") 
add_subdirectory(${SRC_DIR}/output ${CMAKE_CURRENT_BINARY_DIR}/output) + +if(${output}) + add_subdirectory(${SRC_DIR}/checkpoint ${CMAKE_CURRENT_BINARY_DIR}/checkpoint) endif() -add_subdirectory(${SRC_DIR}/../setups ${CMAKE_CURRENT_BINARY_DIR}/setups) +set(ENTITY ${PROJECT_NAME}.xc) +set(SOURCES ${SRC_DIR}/entity.cpp) + +add_executable(${ENTITY} ${SOURCES}) set(libs ntt_global ntt_framework ntt_metrics ntt_engines ntt_pgen) add_dependencies(${ENTITY} ${libs}) target_link_libraries(${ENTITY} PUBLIC ${libs}) diff --git a/src/archetypes/CMakeLists.txt b/src/archetypes/CMakeLists.txt index 7883ba6a5..93f8baaaa 100644 --- a/src/archetypes/CMakeLists.txt +++ b/src/archetypes/CMakeLists.txt @@ -1,13 +1,20 @@ +# cmake-lint: disable=C0103 # ------------------------------ # @defines: ntt_archetypes [INTERFACE] +# # @includes: -# - ../ +# +# * ../ +# # @depends: -# - ntt_global [required] -# - ntt_kernels [required] +# +# * ntt_global [required] +# * ntt_kernels [required] +# # @uses: -# - kokkos [required] -# - mpi [optional] +# +# * kokkos [required] +# * mpi [optional] # ------------------------------ add_library(ntt_archetypes INTERFACE) @@ -17,5 +24,4 @@ add_dependencies(ntt_archetypes ${libs}) target_link_libraries(ntt_archetypes INTERFACE ${libs}) target_include_directories(ntt_archetypes - INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../ -) \ No newline at end of file + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../) diff --git a/src/archetypes/energy_dist.h b/src/archetypes/energy_dist.h index c5b9d8a87..bcb88fbc3 100644 --- a/src/archetypes/energy_dist.h +++ b/src/archetypes/energy_dist.h @@ -3,7 +3,8 @@ * @brief Defines an archetype for energy distributions * @implements * - arch::EnergyDistribution<> - * - arch::ColdDist<> : arch::EnergyDistribution<> + * - arch::Cold<> : arch::EnergyDistribution<> + * - arch::Powerlaw<> : arch::EnergyDistribution<> * - arch::Maxwellian<> : arch::EnergyDistribution<> * @namespaces: * - arch:: @@ -39,33 +40,169 @@ namespace arch { 
EnergyDistribution(const M& metric) : metric { metric } {} - // Takes the physical coordinate of the particle and returns - // the velocity in tetrad basis - // last argument -- is the species index (1, ..., nspec) - Inline virtual void operator()(const coord_t&, - vec_t& v, - unsigned short = 0) const { + protected: + const M metric; + }; + + template + struct Cold : public EnergyDistribution { + Cold(const M& metric) : EnergyDistribution { metric } {} + + Inline void operator()(const coord_t&, + vec_t& v, + spidx_t = 0) const { + v[0] = ZERO; v[1] = ZERO; v[2] = ZERO; } - - protected: - const M metric; }; template - struct ColdDist : public EnergyDistribution { - ColdDist(const M& metric) : EnergyDistribution { metric } {} + struct Powerlaw : public EnergyDistribution { + using EnergyDistribution::metric; + + Powerlaw(const M& metric, + random_number_pool_t& pool, + real_t g_min, + real_t g_max, + real_t pl_ind) + : EnergyDistribution { metric } + , g_min { g_min } + , g_max { g_max } + , pl_ind { pl_ind } + , pool { pool } {} Inline void operator()(const coord_t&, vec_t& v, - unsigned short = 0) const override { + spidx_t = 0) const { + auto rand_gen = pool.get_state(); + auto rand_X1 = Random(rand_gen); + auto rand_gam = ONE; + + // Power-law distribution from uniform (see https://mathworld.wolfram.com/RandomNumber.html) + if (pl_ind != -ONE) { + rand_gam += math::pow( + math::pow(g_min, ONE + pl_ind) + + (-math::pow(g_min, ONE + pl_ind) + math::pow(g_max, ONE + pl_ind)) * + rand_X1, + ONE / (ONE + pl_ind)); + } else { + rand_gam += math::pow(g_min, ONE - rand_X1) * math::pow(g_max, rand_X1); + } + auto rand_u = math::sqrt(SQR(rand_gam) - ONE); + auto rand_X2 = Random(rand_gen); + auto rand_X3 = Random(rand_gen); + v[0] = rand_u * (TWO * rand_X2 - ONE); + v[2] = TWO * rand_u * math::sqrt(rand_X2 * (ONE - rand_X2)); + v[1] = v[2] * math::cos(constant::TWO_PI * rand_X3); + v[2] = v[2] * math::sin(constant::TWO_PI * rand_X3); + + pool.free_state(rand_gen); + } + + 
private: + const real_t g_min, g_max, pl_ind; + random_number_pool_t pool; + }; + + Inline void JuttnerSinge(vec_t& v, + const real_t& temp, + const random_number_pool_t& pool) { + auto rand_gen = pool.get_state(); + real_t randX1, randX2; + if (temp < static_cast(0.5)) { + // Juttner-Synge distribution using the Box-Muller method - non-relativistic + randX1 = Random(rand_gen); + while (cmp::AlmostZero(randX1)) { + randX1 = Random(rand_gen); + } + randX1 = math::sqrt(-TWO * math::log(randX1)); + randX2 = constant::TWO_PI * Random(rand_gen); + v[0] = randX1 * math::cos(randX2) * math::sqrt(temp); + + randX1 = Random(rand_gen); + while (cmp::AlmostZero(randX1)) { + randX1 = Random(rand_gen); + } + randX1 = math::sqrt(-TWO * math::log(randX1)); + randX2 = constant::TWO_PI * Random(rand_gen); + v[1] = randX1 * math::cos(randX2) * math::sqrt(temp); + + randX1 = Random(rand_gen); + while (cmp::AlmostZero(randX1)) { + randX1 = Random(rand_gen); + } + randX1 = math::sqrt(-TWO * math::log(randX1)); + randX2 = constant::TWO_PI * Random(rand_gen); + v[2] = randX1 * math::cos(randX2) * math::sqrt(temp); + } else { + // Juttner-Synge distribution using the Sobol method - relativistic + auto randu = ONE; + auto randeta = Random(rand_gen); + while (SQR(randeta) <= SQR(randu) + ONE) { + randX1 = Random(rand_gen) * Random(rand_gen) * + Random(rand_gen); + while (cmp::AlmostZero(randX1)) { + randX1 = Random(rand_gen) * Random(rand_gen) * + Random(rand_gen); + } + randu = -temp * math::log(randX1); + randX2 = Random(rand_gen); + while (cmp::AlmostZero(randX2)) { + randX2 = Random(rand_gen); + } + randeta = -temp * math::log(randX1 * randX2); + } + randX1 = Random(rand_gen); + randX2 = Random(rand_gen); + v[0] = randu * (TWO * randX1 - ONE); + v[2] = TWO * randu * math::sqrt(randX1 * (ONE - randX1)); + v[1] = v[2] * math::cos(constant::TWO_PI * randX2); + v[2] = v[2] * math::sin(constant::TWO_PI * randX2); + } + pool.free_state(rand_gen); + } + + template + Inline void 
SampleFromMaxwellian( + vec_t& v, + const random_number_pool_t& pool, + const real_t& temperature, + const real_t& boost_velocity = static_cast(0), + const in& boost_direction = in::x1, + bool flip_velocity = false) { + if (cmp::AlmostZero(temperature)) { v[0] = ZERO; v[1] = ZERO; v[2] = ZERO; + } else { + JuttnerSinge(v, temperature, pool); } - }; + if constexpr (CanBoost) { + // Boost a symmetric distribution to a relativistic speed using flipping + // method https://arxiv.org/pdf/1504.03910.pdf + // @note: boost only when using cartesian coordinates + if (not cmp::AlmostZero(boost_velocity)) { + const auto boost_dir = static_cast(boost_direction); + const auto boost_beta { boost_velocity / + math::sqrt(ONE + SQR(boost_velocity)) }; + const auto gamma { U2GAMMA(v[0], v[1], v[2]) }; + auto rand_gen = pool.get_state(); + if (-boost_beta * v[boost_dir] > gamma * Random(rand_gen)) { + v[boost_dir] = -v[boost_dir]; + } + pool.free_state(rand_gen); + v[boost_dir] = math::sqrt(ONE + SQR(boost_velocity)) * + (v[boost_dir] + boost_beta * gamma); + if (flip_velocity) { + v[0] = -v[0]; + v[1] = -v[1]; + v[2] = -v[2]; + } + } + } + } template struct Maxwellian : public EnergyDistribution { @@ -87,123 +224,212 @@ namespace arch { "Maxwellian: Temperature must be non-negative", HERE); raise::ErrorIf( - (not cmp::AlmostZero(boost_vel, ZERO)) && (M::CoordType != Coord::Cart), + (not cmp::AlmostZero_host(boost_vel, ZERO)) && (M::CoordType != Coord::Cart), "Maxwellian: Boosting is only supported in Cartesian coordinates", HERE); } - // Juttner-Synge distribution - Inline void JS(vec_t& v, const real_t& temp) const { - auto rand_gen = pool.get_state(); - real_t randX1, randX2; - if (temp < static_cast(0.5)) { - // Juttner-Synge distribution using the Box-Muller method - non-relativistic - randX1 = Random(rand_gen); - while (cmp::AlmostZero(randX1)) { - randX1 = Random(rand_gen); - } - randX1 = math::sqrt(-TWO * math::log(randX1)); - randX2 = constant::TWO_PI * Random(rand_gen); - 
v[0] = randX1 * math::cos(randX2) * math::sqrt(temp); + Inline void operator()(const coord_t&, + vec_t& v, + spidx_t sp = 0) const { + SampleFromMaxwellian(v, + pool, + temperature, + boost_velocity, + boost_direction, + not zero_current and + sp % 2 == 0); + } - randX1 = Random(rand_gen); - while (cmp::AlmostZero(randX1)) { - randX1 = Random(rand_gen); - } - randX1 = math::sqrt(-TWO * math::log(randX1)); - randX2 = constant::TWO_PI * Random(rand_gen); - v[1] = randX1 * math::cos(randX2) * math::sqrt(temp); + private: + random_number_pool_t pool; - randX1 = Random(rand_gen); - while (cmp::AlmostZero(randX1)) { - randX1 = Random(rand_gen); - } - randX1 = math::sqrt(-TWO * math::log(randX1)); - randX2 = constant::TWO_PI * Random(rand_gen); - v[2] = randX1 * math::cos(randX2) * math::sqrt(temp); - } else { - // Juttner-Synge distribution using the Sobol method - relativistic - auto randu = ONE; - auto randeta = Random(rand_gen); - while (SQR(randeta) <= SQR(randu) + ONE) { - randX1 = Random(rand_gen) * Random(rand_gen) * - Random(rand_gen); - while (cmp::AlmostZero(randX1)) { - randX1 = Random(rand_gen) * Random(rand_gen) * - Random(rand_gen); - } - randu = -temp * math::log(randX1); - randX2 = Random(rand_gen); - while (cmp::AlmostZero(randX2)) { - randX2 = Random(rand_gen); - } - randeta = -temp * math::log(randX1 * randX2); - } - randX1 = Random(rand_gen); - randX2 = Random(rand_gen); - v[0] = randu * (TWO * randX1 - ONE); - v[2] = TWO * randu * math::sqrt(randX1 * (ONE - randX1)); - v[1] = v[2] * math::cos(constant::TWO_PI * randX2); - v[2] = v[2] * math::sin(constant::TWO_PI * randX2); - } - pool.free_state(rand_gen); + const real_t temperature; + const real_t boost_velocity; + const in boost_direction; + const bool zero_current; + }; + + template + struct TwoTemperatureMaxwellian : public EnergyDistribution { + using EnergyDistribution::metric; + + TwoTemperatureMaxwellian(const M& metric, + random_number_pool_t& pool, + const std::pair& temperatures, + const 
std::pair& species, + real_t boost_vel = ZERO, + in boost_direction = in::x1, + bool zero_current = true) + : EnergyDistribution { metric } + , pool { pool } + , temperature_1 { temperatures.first } + , temperature_2 { temperatures.second } + , sp_1 { species.first } + , sp_2 { species.second } + , boost_velocity { boost_vel } + , boost_direction { boost_direction } + , zero_current { zero_current } { + raise::ErrorIf( + (temperature_1 < ZERO) or (temperature_2 < ZERO), + "TwoTemperatureMaxwellian: Temperature must be non-negative", + HERE); + raise::ErrorIf((not cmp::AlmostZero(boost_vel, ZERO)) && + (M::CoordType != Coord::Cart), + "TwoTemperatureMaxwellian: Boosting is only supported in " + "Cartesian coordinates", + HERE); } - // Boost a symmetric distribution to a relativistic speed using flipping - // method https://arxiv.org/pdf/1504.03910.pdf - Inline void boost(vec_t& v) const { - const auto boost_dir = static_cast(boost_direction); - const auto boost_beta { boost_velocity / - math::sqrt(ONE + SQR(boost_velocity)) }; - const auto gamma { U2GAMMA(v[0], v[1], v[2]) }; - auto rand_gen = pool.get_state(); - if (-boost_beta * v[boost_dir] > gamma * Random(rand_gen)) { - v[boost_dir] = -v[boost_dir]; - } - pool.free_state(rand_gen); - v[boost_dir] = math::sqrt(ONE + SQR(boost_velocity)) * - (v[boost_dir] + boost_beta * gamma); + Inline void operator()(const coord_t&, + vec_t& v, + spidx_t sp = 0) const { + SampleFromMaxwellian( + v, + pool, + (sp == sp_1) ? 
temperature_1 : temperature_2, + boost_velocity, + boost_direction, + not zero_current and sp == sp_1); } - Inline void operator()(const coord_t& x_Code, - vec_t& v, - unsigned short s = 0) const override { - if (cmp::AlmostZero(temperature)) { - v[0] = ZERO; - v[1] = ZERO; - v[2] = ZERO; - } else { - JS(v, temperature); - } - if constexpr (S == SimEngine::GRPIC) { - // convert from the tetrad basis to covariant - vec_t v_Hat; - v_Hat[0] = v[0]; - v_Hat[1] = v[1]; - v_Hat[2] = v[2]; - metric.template transform(x_Code, v_Hat, v); + private: + random_number_pool_t pool; + + const real_t temperature_1, temperature_2; + const spidx_t sp_1, sp_2; + const real_t boost_velocity; + const in boost_direction; + const bool zero_current; + }; + + namespace experimental { + + template + struct Maxwellian : public EnergyDistribution { + using EnergyDistribution::metric; + + Maxwellian(const M& metric, + random_number_pool_t& pool, + real_t temperature, + const std::vector& drift_four_vel = { ZERO, ZERO, ZERO }) + : EnergyDistribution { metric } + , pool { pool } + , temperature { temperature } { + raise::ErrorIf(drift_four_vel.size() != 3, + "Maxwellian: Drift velocity must be a 3D vector", + HERE); + raise::ErrorIf(temperature < ZERO, + "Maxwellian: Temperature must be non-negative", + HERE); + if constexpr (M::CoordType == Coord::Cart) { + drift_4vel = NORM(drift_four_vel[0], drift_four_vel[1], drift_four_vel[2]); + if (cmp::AlmostZero_host(drift_4vel)) { + drift_dir = 0; + } else { + drift_3vel = drift_4vel / math::sqrt(ONE + SQR(drift_4vel)); + drift_dir_x1 = drift_four_vel[0] / drift_4vel; + drift_dir_x2 = drift_four_vel[1] / drift_4vel; + drift_dir_x3 = drift_four_vel[2] / drift_4vel; + + // assume drift is in an arbitrary direction + drift_dir = 4; + // check whether drift is in one of principal directions + for (auto d { 0u }; d < 3u; ++d) { + const auto dprev = (d + 2) % 3; + const auto dnext = (d + 1) % 3; + if (cmp::AlmostZero_host(drift_four_vel[dprev]) and + 
cmp::AlmostZero_host(drift_four_vel[dnext])) { + drift_dir = SIGN(drift_four_vel[d]) * (d + 1); + break; + } + } + } + raise::ErrorIf(drift_dir > 3 and drift_dir != 4, + "Maxwellian: Incorrect drift direction", + HERE); + raise::ErrorIf( + drift_dir != 0 and (M::CoordType != Coord::Cart), + "Maxwellian: Boosting is only supported in Cartesian coordinates", + HERE); + } } - if constexpr (M::CoordType == Coord::Cart) { - // boost only when using cartesian coordinates - if (not cmp::AlmostZero(boost_velocity)) { - boost(v); - if (not zero_current and s % 2 == 0) { - v[0] = -v[0]; - v[1] = -v[1]; - v[2] = -v[2]; + + Inline void operator()(const coord_t& x_Code, + vec_t& v, + spidx_t = 0) const { + if (cmp::AlmostZero(temperature)) { + v[0] = ZERO; + v[1] = ZERO; + v[2] = ZERO; + } else { + JuttnerSinge(v, temperature, pool); + } + // @note: boost only when using cartesian coordinates + if constexpr (M::CoordType == Coord::Cart) { + if (drift_dir != 0) { + // Boost an isotropic Maxwellian with a drift velocity using + // flipping method https://arxiv.org/pdf/1504.03910.pdf + // 1. apply drift in X1 direction + const auto gamma { U2GAMMA(v[0], v[1], v[2]) }; + auto rand_gen = pool.get_state(); + if (-drift_3vel * v[0] > gamma * Random(rand_gen)) { + v[0] = -v[0]; + } + pool.free_state(rand_gen); + v[0] = math::sqrt(ONE + SQR(drift_4vel)) * (v[0] + drift_3vel * gamma); + // 2. rotate to desired orientation + if (drift_dir == -1) { + v[0] = -v[0]; + } else if (drift_dir == 2 || drift_dir == -2) { + const auto tmp = v[1]; + v[1] = drift_dir > 0 ? v[0] : -v[0]; + v[0] = tmp; + } else if (drift_dir == 3 || drift_dir == -3) { + const auto tmp = v[2]; + v[2] = drift_dir > 0 ? 
v[0] : -v[0]; + v[0] = tmp; + } else if (drift_dir == 4) { + vec_t v_old; + v_old[0] = v[0]; + v_old[1] = v[1]; + v_old[2] = v[2]; + + v[0] = v_old[0] * drift_dir_x1 - v_old[1] * drift_dir_x2 - + v_old[2] * drift_dir_x3; + v[1] = (v_old[0] * drift_dir_x2 * (drift_dir_x1 + ONE) + + v_old[1] * + (SQR(drift_dir_x1) + drift_dir_x1 + SQR(drift_dir_x3)) - + v_old[2] * drift_dir_x2 * drift_dir_x3) / + (drift_dir_x1 + ONE); + v[2] = (v_old[0] * drift_dir_x3 * (drift_dir_x1 + ONE) - + v_old[1] * drift_dir_x2 * drift_dir_x3 - + v_old[2] * (-drift_dir_x1 + SQR(drift_dir_x3) - ONE)) / + (drift_dir_x1 + ONE); + } } } } - } - private: - random_number_pool_t pool; + private: + random_number_pool_t pool; - const real_t temperature; - const real_t boost_velocity; - const in boost_direction; - const bool zero_current; - }; + const real_t temperature; + + real_t drift_3vel { ZERO }, drift_4vel { ZERO }; + // components of the unit vector in the direction of the drift + real_t drift_dir_x1 { ZERO }, drift_dir_x2 { ZERO }, drift_dir_x3 { ZERO }; + + // values of boost_dir: + // 4 -> arbitrary direction + // 0 -> no drift + // +/- 1 -> +/- x1 + // +/- 2 -> +/- x2 + // +/- 3 -> +/- x3 + short drift_dir { 0 }; + }; + + } // namespace experimental } // namespace arch diff --git a/src/archetypes/field_setter.h b/src/archetypes/field_setter.h index 281c28df6..5c5c4dbe4 100644 --- a/src/archetypes/field_setter.h +++ b/src/archetypes/field_setter.h @@ -170,70 +170,30 @@ namespace arch { const real_t x2_H { metric.template convert<2, Crd::Cd, Crd::Ph>( i2_ + HALF) }; { // dx1 - vec_t d_PU { finit.dx1({ x1_H, x2_0 }), - finit.dx2({ x1_H, x2_0 }), - finit.dx3({ x1_H, x2_0 }) }; - vec_t d_U { ZERO }; - metric.template transform({ i1_ + HALF, i2_ }, - d_PU, - d_U); - EM(i1, i2, em::dx1) = d_U[0]; + EM(i1, i2, em::dx1) = finit.dx1({ x1_H, x2_0 }); } { // dx2 - vec_t d_PU { finit.dx1({ x1_0, x2_H }), - finit.dx2({ x1_0, x2_H }), - finit.dx3({ x1_0, x2_H }) }; - vec_t d_U { ZERO }; - metric.template 
transform({ i1_, i2_ + HALF }, - d_PU, - d_U); - EM(i1, i2, em::dx2) = d_U[1]; + EM(i1, i2, em::dx2) = finit.dx2({ x1_0, x2_H }); } { // dx3 - vec_t d_PU { finit.dx1({ x1_0, x2_0 }), - finit.dx2({ x1_0, x2_0 }), - finit.dx3({ x1_0, x2_0 }) }; - vec_t d_U { ZERO }; - metric.template transform({ i1_, i2_ }, d_PU, d_U); - EM(i1, i2, em::dx3) = d_U[2]; + EM(i1, i2, em::dx3) = finit.dx3({ x1_0, x2_0 }); } } if constexpr (defines_bx1 && defines_bx2 && defines_bx3) { const real_t x1_0 { metric.template convert<1, Crd::Cd, Crd::Ph>(i1_) }; const real_t x1_H { metric.template convert<1, Crd::Cd, Crd::Ph>( i1_ + HALF) }; - const real_t x2_0 { metric.template convert<1, Crd::Cd, Crd::Ph>(i2_) }; - const real_t x2_H { metric.template convert<1, Crd::Cd, Crd::Ph>( + const real_t x2_0 { metric.template convert<2, Crd::Cd, Crd::Ph>(i2_) }; + const real_t x2_H { metric.template convert<2, Crd::Cd, Crd::Ph>( i2_ + HALF) }; { // bx1 - vec_t b_PU { finit.dx1({ x1_0, x2_H }), - finit.dx2({ x1_0, x2_H }), - finit.dx3({ x1_0, x2_H }) }; - vec_t b_U { ZERO }; - metric.template transform({ i1_, i2_ + HALF }, - b_PU, - b_U); - EM(i1, i2, em::bx1) = b_U[0]; + EM(i1, i2, em::bx1) = finit.bx1({ x1_0, x2_H }); } { // bx2 - vec_t b_PU { finit.dx1({ x1_H, x2_0 }), - finit.dx2({ x1_H, x2_0 }), - finit.dx3({ x1_H, x2_0 }) }; - vec_t b_U { ZERO }; - metric.template transform({ i1_ + HALF, i2_ }, - b_PU, - b_U); - EM(i1, i2, em::bx2) = b_U[1]; + EM(i1, i2, em::bx2) = finit.bx2({ x1_H, x2_0 }); } { // bx3 - vec_t b_PU { finit.dx1({ x1_H, x2_H }), - finit.dx2({ x1_H, x2_H }), - finit.dx3({ x1_H, x2_H }) }; - vec_t b_U { ZERO }; - metric.template transform({ i1_ + HALF, i2_ + HALF }, - b_PU, - b_U); - EM(i1, i2, em::bx3) = b_U[2]; + EM(i1, i2, em::bx3) = finit.bx3({ x1_H, x2_H }); } } } else { diff --git a/src/archetypes/particle_injector.h b/src/archetypes/particle_injector.h index 4e75003bb..6313031d1 100644 --- a/src/archetypes/particle_injector.h +++ b/src/archetypes/particle_injector.h @@ -28,19 
+28,97 @@ #include "framework/domain/metadomain.h" #include "kernels/injectors.hpp" +#include "kernels/particle_moments.hpp" +#include "kernels/utils.hpp" #include +#if defined(MPI_ENABLED) + #include +#endif + #include +#include #include #include namespace arch { using namespace ntt; - using spidx_t = unsigned short; + + template + struct BaseInjector { + virtual auto DeduceRegion(const Domain& domain, + const boundaries_t& box) const + -> std::tuple, array_t> { + if (not domain.mesh.Intersects(box)) { + return { false, array_t {}, array_t {} }; + } + coord_t xCorner_min_Ph { ZERO }; + coord_t xCorner_max_Ph { ZERO }; + coord_t xCorner_min_Cd { ZERO }; + coord_t xCorner_max_Cd { ZERO }; + + for (auto d { 0u }; d < M::Dim; ++d) { + const auto local_xi_min = domain.mesh.extent(static_cast(d)).first; + const auto local_xi_max = domain.mesh.extent(static_cast(d)).second; + const auto extent_min = std::min(std::max(local_xi_min, box[d].first), + local_xi_max); + const auto extent_max = std::max(std::min(local_xi_max, box[d].second), + local_xi_min); + xCorner_min_Ph[d] = extent_min; + xCorner_max_Ph[d] = extent_max; + } + domain.mesh.metric.template convert(xCorner_min_Ph, + xCorner_min_Cd); + domain.mesh.metric.template convert(xCorner_max_Ph, + xCorner_max_Cd); + + array_t xi_min { "xi_min", M::Dim }, xi_max { "xi_max", M::Dim }; + + auto xi_min_h = Kokkos::create_mirror_view(xi_min); + auto xi_max_h = Kokkos::create_mirror_view(xi_max); + for (auto d { 0u }; d < M::Dim; ++d) { + xi_min_h(d) = xCorner_min_Cd[d]; + xi_max_h(d) = xCorner_max_Cd[d]; + } + Kokkos::deep_copy(xi_min, xi_min_h); + Kokkos::deep_copy(xi_max, xi_max_h); + + return { true, xi_min, xi_max }; + } + + virtual auto ComputeNumInject(const SimulationParams& params, + const Domain& domain, + real_t number_density, + const boundaries_t& box) const + -> std::tuple, array_t> { + const auto result = DeduceRegion(domain, box); + if (not std::get<0>(result)) { + return { false, (npart_t)0, array_t {}, 
array_t {} }; + } + const auto xi_min = std::get<1>(result); + const auto xi_max = std::get<2>(result); + auto xi_min_h = Kokkos::create_mirror_view(xi_min); + auto xi_max_h = Kokkos::create_mirror_view(xi_max); + Kokkos::deep_copy(xi_min_h, xi_min); + Kokkos::deep_copy(xi_max_h, xi_max); + + long double num_cells { 1.0 }; + for (auto d { 0u }; d < M::Dim; ++d) { + num_cells *= static_cast(xi_max_h(d)) - + static_cast(xi_min_h(d)); + } + + const auto ppc0 = params.template get("particles.ppc0"); + const auto nparticles = static_cast( + (long double)(ppc0 * number_density * 0.5) * num_cells); + + return { true, nparticles, xi_min, xi_max }; + } + }; template class ED> - struct UniformInjector { + struct UniformInjector : BaseInjector { using energy_dist_t = ED; static_assert(M::is_metric, "M must be a metric class"); static_assert(energy_dist_t::is_energy_dist, @@ -60,6 +138,122 @@ namespace arch { ~UniformInjector() = default; }; + template class ED> + struct KeepConstantInjector : UniformInjector { + using energy_dist_t = ED; + using UniformInjector::D; + using UniformInjector::C; + + const idx_t density_buff_idx; + boundaries_t probe_box; + + KeepConstantInjector(const energy_dist_t& energy_dist, + const std::pair& species, + idx_t density_buff_idx, + boundaries_t box = {}) + : UniformInjector { energy_dist, species } + , density_buff_idx { density_buff_idx } { + for (auto d { 0u }; d < M::Dim; ++d) { + if (d < box.size()) { + probe_box.push_back({ box[d].first, box[d].second }); + } else { + probe_box.push_back(Range::All); + } + } + } + + ~KeepConstantInjector() = default; + + auto ComputeAvgDensity(const SimulationParams& params, + const Domain& domain) const -> real_t { + const auto result = this->DeduceRegion(domain, probe_box); + const auto should_probe = std::get<0>(result); + if (not should_probe) { + return ZERO; + } + const auto xi_min_arr = std::get<1>(result); + const auto xi_max_arr = std::get<2>(result); + + tuple_t i_min { 0 }; + tuple_t i_max { 0 
}; + + auto xi_min_h = Kokkos::create_mirror_view(xi_min_arr); + auto xi_max_h = Kokkos::create_mirror_view(xi_max_arr); + Kokkos::deep_copy(xi_min_h, xi_min_arr); + Kokkos::deep_copy(xi_max_h, xi_max_arr); + + ncells_t num_cells = 1u; + for (auto d { 0u }; d < M::Dim; ++d) { + i_min[d] = std::floor(xi_min_h(d)) + N_GHOSTS; + i_max[d] = std::ceil(xi_max_h(d)) + N_GHOSTS; + num_cells *= (i_max[d] - i_min[d]); + } + + real_t dens { ZERO }; + if (should_probe) { + Kokkos::parallel_reduce( + "AvgDensity", + CreateRangePolicy(i_min, i_max), + kernel::ComputeSum_kernel(domain.fields.buff, density_buff_idx), + dens); + } +#if defined(MPI_ENABLED) + real_t tot_dens { ZERO }; + ncells_t tot_num_cells { 0 }; + MPI_Allreduce(&dens, &tot_dens, 1, mpi::get_type(), MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&num_cells, + &tot_num_cells, + 1, + mpi::get_type(), + MPI_SUM, + MPI_COMM_WORLD); + dens = tot_dens; + num_cells = tot_num_cells; +#endif + if (num_cells > 0) { + return dens / (real_t)(num_cells); + } else { + return ZERO; + } + } + + auto ComputeNumInject(const SimulationParams& params, + const Domain& domain, + real_t number_density, + const boundaries_t& box) const + -> std::tuple, array_t> override { + const auto computed_avg_density = ComputeAvgDensity(params, domain); + + const auto result = this->DeduceRegion(domain, box); + if (not std::get<0>(result)) { + return { false, (npart_t)0, array_t {}, array_t {} }; + } + + const auto xi_min = std::get<1>(result); + const auto xi_max = std::get<2>(result); + auto xi_min_h = Kokkos::create_mirror_view(xi_min); + auto xi_max_h = Kokkos::create_mirror_view(xi_max); + Kokkos::deep_copy(xi_min_h, xi_min); + Kokkos::deep_copy(xi_max_h, xi_max); + + long double num_cells { 1.0 }; + for (auto d { 0u }; d < M::Dim; ++d) { + num_cells *= static_cast(xi_max_h(d)) - + static_cast(xi_min_h(d)); + } + + const auto ppc0 = params.template get("particles.ppc0"); + npart_t nparticles { 0u }; + if (number_density > computed_avg_density) { + 
nparticles = static_cast( + (long double)(ppc0 * (number_density - computed_avg_density) * 0.5) * + num_cells); + } + + return { nparticles != 0u, nparticles, xi_min, xi_max }; + } + }; + template @@ -107,7 +301,7 @@ namespace arch { if constexpr ((O == in::x1) or (O == in::x2 and (M::Dim == Dim::_2D or M::Dim == Dim::_3D)) or (O == in::x3 and M::Dim == Dim::_3D)) { - const auto xi = x_Ph[static_cast(O)]; + const auto xi = x_Ph[static_cast(O)]; if constexpr (P) { // + direction if (xi < xsurf - ds or xi >= xsurf) { @@ -142,7 +336,7 @@ namespace arch { }; using energy_dist_t = Maxwellian; - using spatial_dist_t = ReplenishDist; + using spatial_dist_t = Replenish; static_assert(M::is_metric, "M must be a metric class"); static constexpr bool is_nonuniform_injector { true }; static constexpr Dimension D { M::Dim }; @@ -170,22 +364,86 @@ namespace arch { ~AtmosphereInjector() = default; }; + template + struct MovingInjector { + struct TargetDensityProfile { + const real_t nmax, xinj, xdrift; + + TargetDensityProfile(real_t xinj, real_t xdrift, real_t nmax) + : xinj { xinj } + , xdrift { xdrift } + , nmax { nmax } {} + + Inline auto operator()(const coord_t& x_Ph) const -> real_t { + if constexpr ((O == in::x1) or + (O == in::x2 and (M::Dim == Dim::_2D or M::Dim == Dim::_3D)) or + (O == in::x3 and M::Dim == Dim::_3D)) { + const auto xi = x_Ph[static_cast(O)]; + // + direction + if (xi < xdrift or xi >= xinj) { + return ZERO; + } else { + if constexpr (M::CoordType == Coord::Cart) { + return nmax; + } else { + raise::KernelError( + HERE, + "Moving injector in +x cannot be applied for non-cartesian"); + return ZERO; + } + } + } else { + raise::KernelError(HERE, "Wrong direction"); + return ZERO; + } + } + }; + + using energy_dist_t = Maxwellian; + using spatial_dist_t = Replenish; + static_assert(M::is_metric, "M must be a metric class"); + static constexpr bool is_nonuniform_injector { true }; + static constexpr Dimension D { M::Dim }; + static constexpr Coord C { 
M::CoordType }; + + const energy_dist_t energy_dist; + const TargetDensityProfile target_density; + const spatial_dist_t spatial_dist; + const std::pair species; + + MovingInjector(const M& metric, + const ndfield_t& density, + const energy_dist_t& energy_dist, + real_t xinj, + real_t xdrift, + real_t nmax, + const std::pair& species) + : energy_dist { energy_dist } + , target_density { xinj, xdrift, nmax } + , spatial_dist { metric, density, 0, target_density, nmax } + , species { species } {} + + ~MovingInjector() = default; + }; + /** * @brief Injects uniform number density of particles everywhere in the domain * @param domain Domain object * @param injector Uniform injector object * @param number_density Total number density (in units of n0) * @param use_weights Use weights + * @param box Region to inject the particles in global coords * @tparam S Simulation engine type * @tparam M Metric type * @tparam I Injector type */ template - inline void InjectUniform(const SimulationParams& params, - Domain& domain, - const I& injector, - real_t number_density, - bool use_weights = false) { + inline void InjectUniform(const SimulationParams& params, + Domain& domain, + const I& injector, + real_t number_density, + bool use_weights = false, + const boundaries_t& box = {}) { static_assert(M::is_metric, "M must be a metric class"); static_assert(I::is_uniform_injector, "I must be a uniform injector class"); raise::ErrorIf((M::CoordType != Coord::Cart) && (not use_weights), @@ -205,17 +463,24 @@ namespace arch { } { - auto ppc0 = params.template get("particles.ppc0"); - array_t ni { "ni", M::Dim }; - auto ni_h = Kokkos::create_mirror_view(ni); - std::size_t ncells = 1; - for (auto d = 0; d < M::Dim; ++d) { - ni_h(d) = domain.mesh.n_active()[d]; - ncells *= domain.mesh.n_active()[d]; + boundaries_t nonempty_box; + for (auto d { 0u }; d < M::Dim; ++d) { + if (d < box.size()) { + nonempty_box.push_back({ box[d].first, box[d].second }); + } else { + 
nonempty_box.push_back(Range::All); + } + } + const auto result = injector.ComputeNumInject(params, + domain, + number_density, + nonempty_box); + if (not std::get<0>(result)) { + return; } - Kokkos::deep_copy(ni, ni_h); - const auto nparticles = static_cast( - (long double)(ppc0 * number_density * 0.5) * (long double)(ncells)); + const auto nparticles = std::get<1>(result); + const auto xi_min = std::get<2>(result); + const auto xi_max = std::get<3>(result); Kokkos::parallel_for( "InjectUniform", @@ -228,7 +493,8 @@ namespace arch { domain.species[injector.species.first - 1].npart(), domain.species[injector.species.second - 1].npart(), domain.mesh.metric, - ni, + xi_min, + xi_max, injector.energy_dist, ONE / params.template get("scales.V0"), domain.random_pool)); @@ -239,6 +505,124 @@ namespace arch { } } + namespace experimental { + + template + class ED1, + template + class ED2> + struct UniformInjector : BaseInjector { + using energy_dist_1_t = ED1; + using energy_dist_2_t = ED2; + static_assert(M::is_metric, "M must be a metric class"); + static_assert(energy_dist_1_t::is_energy_dist, + "ED1 must be an energy distribution class"); + static_assert(energy_dist_2_t::is_energy_dist, + "ED2 must be an energy distribution class"); + static constexpr bool is_uniform_injector { true }; + static constexpr Dimension D { M::Dim }; + static constexpr Coord C { M::CoordType }; + + const energy_dist_1_t energy_dist_1; + const energy_dist_2_t energy_dist_2; + const std::pair species; + + UniformInjector(const energy_dist_1_t& energy_dist_1, + const energy_dist_2_t& energy_dist_2, + const std::pair& species) + : energy_dist_1 { energy_dist_1 } + , energy_dist_2 { energy_dist_2 } + , species { species } {} + + ~UniformInjector() = default; + }; + + /** + * @brief Injects uniform number density of particles everywhere in the domain + * @param domain Domain object + * @param injector Uniform injector object + * @param number_density Total number density (in units of n0) + * 
@param use_weights Use weights + * @param box Region to inject the particles in global coords + * @tparam S Simulation engine type + * @tparam M Metric type + * @tparam I Injector type + */ + template + inline void InjectUniform(const SimulationParams& params, + Domain& domain, + const I& injector, + real_t number_density, + bool use_weights = false, + const boundaries_t& box = {}) { + static_assert(M::is_metric, "M must be a metric class"); + static_assert(I::is_uniform_injector, "I must be a uniform injector class"); + raise::ErrorIf((M::CoordType != Coord::Cart) && (not use_weights), + "Weights must be used for non-Cartesian coordinates", + HERE); + raise::ErrorIf((M::CoordType == Coord::Cart) && use_weights, + "Weights should not be used for Cartesian coordinates", + HERE); + raise::ErrorIf( + params.template get("particles.use_weights") != use_weights, + "Weights must be enabled from the input file to use them in " + "the injector", + HERE); + if (domain.species[injector.species.first - 1].charge() + + domain.species[injector.species.second - 1].charge() != + 0.0f) { + raise::Warning("Total charge of the injected species is non-zero", HERE); + } + + { + boundaries_t nonempty_box; + for (auto d { 0u }; d < M::Dim; ++d) { + if (d < box.size()) { + nonempty_box.push_back({ box[d].first, box[d].second }); + } else { + nonempty_box.push_back(Range::All); + } + } + const auto result = injector.ComputeNumInject(params, + domain, + number_density, + nonempty_box); + if (not std::get<0>(result)) { + return; + } + const auto nparticles = std::get<1>(result); + const auto xi_min = std::get<2>(result); + const auto xi_max = std::get<3>(result); + + Kokkos::parallel_for( + "InjectUniform", + nparticles, + kernel::experimental:: + UniformInjector_kernel( + injector.species.first, + injector.species.second, + domain.species[injector.species.first - 1], + domain.species[injector.species.second - 1], + domain.species[injector.species.first - 1].npart(), + 
domain.species[injector.species.second - 1].npart(), + domain.mesh.metric, + xi_min, + xi_max, + injector.energy_dist_1, + injector.energy_dist_2, + ONE / params.template get("scales.V0"), + domain.random_pool)); + domain.species[injector.species.first - 1].set_npart( + domain.species[injector.species.first - 1].npart() + nparticles); + domain.species[injector.species.second - 1].set_npart( + domain.species[injector.species.second - 1].npart() + nparticles); + } + } + + } // namespace experimental + /** * @brief Injects particles from a globally-defined map * @note very inefficient, should only be used for debug purposes @@ -279,12 +663,12 @@ namespace arch { * @param box Region to inject the particles in */ template - inline void InjectNonUniform(const SimulationParams& params, - Domain& domain, - const I& injector, - real_t number_density, - bool use_weights = false, - boundaries_t box = {}) { + inline void InjectNonUniform(const SimulationParams& params, + Domain& domain, + const I& injector, + real_t number_density, + bool use_weights = false, + const boundaries_t& box = {}) { static_assert(M::is_metric, "M must be a metric class"); static_assert(I::is_nonuniform_injector, "I must be a nonuniform injector class"); @@ -320,7 +704,7 @@ namespace arch { incl_ghosts.push_back({ false, false }); } const auto extent = domain.mesh.ExtentToRange(box, incl_ghosts); - tuple_t x_min { 0 }, x_max { 0 }; + tuple_t x_min { 0 }, x_max { 0 }; for (auto d = 0; d < M::Dim; ++d) { x_min[d] = extent[d].first; x_max[d] = extent[d].second; diff --git a/src/archetypes/spatial_dist.h b/src/archetypes/spatial_dist.h index 6c19d44d0..55c84ddf2 100644 --- a/src/archetypes/spatial_dist.h +++ b/src/archetypes/spatial_dist.h @@ -3,8 +3,8 @@ * @brief Spatial distribution class passed to injectors * @implements * - arch::SpatialDistribution<> - * - arch::UniformDist<> : arch::SpatialDistribution<> - * - arch::ReplenishDist<> : arch::SpatialDistribution<> + * - arch::Uniform<> : 
arch::SpatialDistribution<> + * - arch::Replenish<> : arch::SpatialDistribution<> * @namespace * - arch:: * @note @@ -32,57 +32,53 @@ namespace arch { SpatialDistribution(const M& metric) : metric { metric } {} - Inline virtual auto operator()(const coord_t&) const -> real_t { - return ONE; - } - protected: const M metric; }; template - struct UniformDist : public SpatialDistribution { - UniformDist(const M& metric) : SpatialDistribution { metric } {} + struct Uniform : public SpatialDistribution { + Uniform(const M& metric) : SpatialDistribution { metric } {} - Inline auto operator()(const coord_t&) const -> real_t override { + Inline auto operator()(const coord_t&) const -> real_t { return ONE; } }; template - struct ReplenishDist : public SpatialDistribution { + struct Replenish : public SpatialDistribution { using SpatialDistribution::metric; const ndfield_t density; - const unsigned short idx; + const idx_t idx; const T target_density; const real_t target_max_density; - ReplenishDist(const M& metric, - const ndfield_t& density, - unsigned short idx, - const T& target_density, - real_t target_max_density) + Replenish(const M& metric, + const ndfield_t& density, + idx_t idx, + const T& target_density, + real_t target_max_density) : SpatialDistribution { metric } , density { density } , idx { idx } , target_density { target_density } , target_max_density { target_max_density } {} - Inline auto operator()(const coord_t& x_Ph) const -> real_t override { + Inline auto operator()(const coord_t& x_Ph) const -> real_t { coord_t x_Cd { ZERO }; metric.template convert(x_Ph, x_Cd); real_t dens { ZERO }; if constexpr (M::Dim == Dim::_1D) { - dens = density(static_cast(x_Cd[0]) + N_GHOSTS, idx); + dens = density(static_cast(x_Cd[0]) + N_GHOSTS, idx); } else if constexpr (M::Dim == Dim::_2D) { - dens = density(static_cast(x_Cd[0]) + N_GHOSTS, - static_cast(x_Cd[1]) + N_GHOSTS, + dens = density(static_cast(x_Cd[0]) + N_GHOSTS, + static_cast(x_Cd[1]) + N_GHOSTS, idx); } else 
if constexpr (M::Dim == Dim::_3D) { - dens = density(static_cast(x_Cd[0]) + N_GHOSTS, - static_cast(x_Cd[1]) + N_GHOSTS, - static_cast(x_Cd[2]) + N_GHOSTS, + dens = density(static_cast(x_Cd[0]) + N_GHOSTS, + static_cast(x_Cd[1]) + N_GHOSTS, + static_cast(x_Cd[2]) + N_GHOSTS, idx); } else { raise::KernelError(HERE, "Invalid dimension"); diff --git a/src/archetypes/tests/CMakeLists.txt b/src/archetypes/tests/CMakeLists.txt index ceee1edc9..9419847c5 100644 --- a/src/archetypes/tests/CMakeLists.txt +++ b/src/archetypes/tests/CMakeLists.txt @@ -1,9 +1,12 @@ +# cmake-lint: disable=C0103,C0111 # ------------------------------ # @brief: Generates tests for the `ntt_archetypes` module +# # @uses: -# - kokkos [required] -# - plog [required] -# - mpi [optional] +# +# * kokkos [required] +# * plog [required] +# * mpi [optional] # ------------------------------ set(SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../) @@ -22,4 +25,5 @@ endfunction() gen_test(energy_dist) gen_test(spatial_dist) -gen_test(field_setter) \ No newline at end of file +gen_test(field_setter) +gen_test(powerlaw) diff --git a/src/archetypes/tests/energy_dist.cpp b/src/archetypes/tests/energy_dist.cpp index bad1d0eb9..0d3fc8023 100644 --- a/src/archetypes/tests/energy_dist.cpp +++ b/src/archetypes/tests/energy_dist.cpp @@ -27,7 +27,7 @@ struct Caller { Inline void operator()(index_t) const { vec_t vp { ZERO }; coord_t xp { ZERO }; - for (unsigned short d = 0; d < D; ++d) { + for (dim_t d { 0u }; d < D; ++d) { xp[d] = 5.0; } dist(xp, vp); @@ -54,13 +54,13 @@ void testEnergyDist(const std::vector& res, if constexpr (M::Dim == Dim::_2D) { extent = { ext[0], - {ZERO, constant::PI} + { ZERO, constant::PI } }; } else if constexpr (M::Dim == Dim::_3D) { extent = { ext[0], - {ZERO, constant::PI}, - {ZERO, constant::TWO_PI} + { ZERO, constant::PI }, + { ZERO, constant::TWO_PI } }; } } diff --git a/src/archetypes/tests/powerlaw.cpp b/src/archetypes/tests/powerlaw.cpp new file mode 100644 index 000000000..3cb76763f --- 
/dev/null +++ b/src/archetypes/tests/powerlaw.cpp @@ -0,0 +1,173 @@ +#include "enums.h" +#include "global.h" + +#include "utils/error.h" + +#include "metrics/kerr_schild.h" +#include "metrics/kerr_schild_0.h" +#include "metrics/minkowski.h" +#include "metrics/qkerr_schild.h" +#include "metrics/qspherical.h" +#include "metrics/spherical.h" + +#include "archetypes/energy_dist.h" + +#include + +#include + +using namespace ntt; +using namespace metric; +using namespace arch; + +template +struct Caller { + static constexpr auto D = M::Dim; + + Caller(const M& metric, const EnrgDist& dist) + : metric { metric } + , dist { dist } {} + + Inline void operator()(index_t) const { + vec_t vp { ZERO }; + coord_t xp { ZERO }; + for (dim_t d { 0u }; d < D; ++d) { + xp[d] = 2.0; + } + dist(xp, vp); + if (not Kokkos::isfinite(vp[0]) or not Kokkos::isfinite(vp[1]) or + not Kokkos::isfinite(vp[2])) { + raise::KernelError(HERE, "Non-finite velocity generated"); + } + if constexpr (S == SimEngine::SRPIC) { + const auto gamma = math::sqrt(ONE + SQR(vp[0]) + SQR(vp[1]) + SQR(vp[2])); + if (gamma < 10 or gamma > 1000) { + raise::KernelError(HERE, "Gamma out of bounds"); + } + } else { + vec_t vd { ZERO }; + vec_t vu { ZERO }; + metric.template transform(xp, vp, vd); + metric.template transform(xp, vp, vu); + const auto gamma = math::sqrt( + ONE + vu[0] * vd[0] + vu[1] * vd[1] + vu[2] * vd[2]); + if (gamma < 10 or gamma > 1000) { + raise::KernelError(HERE, "Gamma out of bounds"); + } + } + } + +private: + M metric; + EnrgDist dist; +}; + +template +void testEnergyDist(const std::vector& res, + const boundaries_t& ext, + const std::map& params = {}) { + raise::ErrorIf(res.size() != M::Dim, "res.size() != M::Dim", HERE); + + boundaries_t extent; + if constexpr (M::CoordType == Coord::Cart) { + extent = ext; + } else { + if constexpr (M::Dim == Dim::_2D) { + extent = { + ext[0], + { ZERO, constant::PI } + }; + } else if constexpr (M::Dim == Dim::_3D) { + extent = { + ext[0], + { ZERO, 
constant::PI }, + { ZERO, constant::TWO_PI } + }; + } + } + raise::ErrorIf(extent.size() != M::Dim, "extent.size() != M::Dim", HERE); + + M metric { res, extent, params }; + + random_number_pool_t pool { constant::RandomSeed }; + Powerlaw plaw { metric, + pool, + static_cast(10), + static_cast(1000), + static_cast(-2.5) }; + Kokkos::parallel_for("Powerlaw", 100, Caller, S, M>(metric, plaw)); +} + +auto main(int argc, char* argv[]) -> int { + Kokkos::initialize(argc, argv); + + try { + using namespace ntt; + testEnergyDist>( + { + 10 + }, + { { 0.0, 55.0 } }); + + testEnergyDist>( + { + 10, + 10 + }, + { { 0.0, 55.0 }, { 0.0, 55.0 } }); + + testEnergyDist>( + { + 10, + 10, + 10 + }, + { { 0.0, 55.0 }, { 0.0, 55.0 }, { 0.0, 55.0 } }); + + testEnergyDist>( + { + 10, + 10 + }, + { { 1.0, 100.0 } }); + + testEnergyDist>( + { + 10, + 10 + }, + { { 1.0, 100.0 } }, + { { "r0", 0.0 }, { "h", 0.25 } }); + + testEnergyDist>( + { + 10, + 10 + }, + { { 1.0, 100.0 } }, + { { "a", 0.9 } }); + + testEnergyDist>( + { + 10, + 10 + }, + { { 1.0, 100.0 } }, + { { "r0", 0.0 }, { "h", 0.25 }, { "a", 0.9 } }); + + testEnergyDist>( + { + 10, + 10 + }, + { { 1.0, 100.0 } }, + { { "a", 0.9 } }); + + } catch (std::exception& e) { + std::cerr << e.what() << std::endl; + Kokkos::finalize(); + return 1; + } + Kokkos::finalize(); + return 0; +} diff --git a/src/archetypes/tests/spatial_dist.cpp b/src/archetypes/tests/spatial_dist.cpp index 5ab64a156..51f1703a9 100644 --- a/src/archetypes/tests/spatial_dist.cpp +++ b/src/archetypes/tests/spatial_dist.cpp @@ -76,11 +76,11 @@ struct RadialDist : public SpatialDistribution { RadialDist(const M& metric) : SpatialDistribution { metric } {} - Inline auto operator()(const coord_t& x_Code) const -> real_t override { + Inline auto operator()(const coord_t& x_Code) const -> real_t { coord_t x_Sph { ZERO }; metric.template convert(x_Code, x_Sph); auto r { ZERO }; - for (unsigned short d = 0; d < M::Dim; ++d) { + for (dim_t d { 0u }; d < M::Dim; ++d) { r += 
SQR(x_Sph[d]); } return math::sqrt(r); @@ -91,14 +91,14 @@ auto main(int argc, char* argv[]) -> int { Kokkos::initialize(argc, argv); try { Minkowski m1 { - { 10, 10}, - {{ -10.0, 55.0 }, { -10.0, 55.0 }} + { 10, 10 }, + { { -10.0, 55.0 }, { -10.0, 55.0 } } }; RadialDist> r1 { m1 }; Minkowski m2 { - { 10, 10, 30}, - {{ -1.0, 1.0 }, { -1.0, 1.0 }, { -3.0, 3.0 }} + { 10, 10, 30 }, + { { -1.0, 1.0 }, { -1.0, 1.0 }, { -3.0, 3.0 } } }; RadialDist> r2 { m2 }; @@ -123,4 +123,4 @@ auto main(int argc, char* argv[]) -> int { } Kokkos::finalize(); return 0; -} \ No newline at end of file +} diff --git a/src/checkpoint/CMakeLists.txt b/src/checkpoint/CMakeLists.txt new file mode 100644 index 000000000..096aad690 --- /dev/null +++ b/src/checkpoint/CMakeLists.txt @@ -0,0 +1,37 @@ +# cmake-lint: disable=C0103 +# ------------------------------ +# @defines: ntt_checkpoint [STATIC/SHARED] +# +# @sources: +# +# * writer.cpp +# * reader.cpp +# +# @includes: +# +# * ../ +# +# @depends: +# +# * ntt_global [required] +# +# @uses: +# +# * kokkos [required] +# * ADIOS2 [required] +# * mpi [optional] +# ------------------------------ + +set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(SOURCES ${SRC_DIR}/writer.cpp ${SRC_DIR}/reader.cpp) +add_library(ntt_checkpoint ${SOURCES}) + +set(libs ntt_global) +add_dependencies(ntt_checkpoint ${libs}) +target_link_libraries(ntt_checkpoint PUBLIC ${libs}) +target_link_libraries(ntt_checkpoint PRIVATE stdc++fs) + +target_include_directories( + ntt_checkpoint + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../ + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../) diff --git a/src/checkpoint/reader.cpp b/src/checkpoint/reader.cpp new file mode 100644 index 000000000..d973b0ddd --- /dev/null +++ b/src/checkpoint/reader.cpp @@ -0,0 +1,161 @@ +#include "checkpoint/reader.h" + +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/error.h" +#include "utils/formatting.h" +#include "utils/log.h" + +#include +#include + +#if defined(MPI_ENABLED) + #include +#endif 
+ +#include +#include +#include + +namespace checkpoint { + + template + void ReadFields(adios2::IO& io, + adios2::Engine& reader, + const std::string& field, + const adios2::Box& range, + ndfield_t& array) { + logger::Checkpoint(fmt::format("Reading field: %s", field.c_str()), HERE); + auto field_var = io.InquireVariable(field); + if (field_var) { + field_var.SetSelection(range); + + auto array_h = Kokkos::create_mirror_view(array); + reader.Get(field_var, array_h.data(), adios2::Mode::Sync); + Kokkos::deep_copy(array, array_h); + } else { + raise::Error(fmt::format("Field variable: %s not found", field.c_str()), + HERE); + } + } + + auto ReadParticleCount(adios2::IO& io, + adios2::Engine& reader, + spidx_t s, + std::size_t local_dom, + std::size_t ndomains) -> std::pair { + logger::Checkpoint(fmt::format("Reading particle count for: %d", s + 1), HERE); + auto npart_var = io.InquireVariable(fmt::format("s%d_npart", s + 1)); + if (npart_var) { + raise::ErrorIf(npart_var.Shape()[0] != ndomains, + "npart_var.Shape()[0] != ndomains", + HERE); + raise::ErrorIf(npart_var.Shape().size() != 1, + "npart_var.Shape().size() != 1", + HERE); + npart_var.SetSelection(adios2::Box({ local_dom }, { 1 })); + npart_t npart; + reader.Get(npart_var, &npart, adios2::Mode::Sync); + const auto loc_npart = npart; +#if !defined(MPI_ENABLED) + npart_t offset_npart = 0; +#else + std::vector glob_nparts(ndomains); + MPI_Allgather(&loc_npart, + 1, + mpi::get_type(), + glob_nparts.data(), + 1, + mpi::get_type(), + MPI_COMM_WORLD); + npart_t offset_npart = 0; + for (auto d { 0u }; d < local_dom; ++d) { + offset_npart += glob_nparts[d]; + } +#endif + return { loc_npart, offset_npart }; + } else { + raise::Error("npart_var is not found", HERE); + return { 0, 0 }; + } + } + + template + void ReadParticleData(adios2::IO& io, + adios2::Engine& reader, + const std::string& quantity, + spidx_t s, + array_t& array, + npart_t count, + npart_t offset) { + logger::Checkpoint( + fmt::format("Reading 
quantity: s%d_%s", s + 1, quantity.c_str()), + HERE); + auto var = io.InquireVariable( + fmt::format("s%d_%s", s + 1, quantity.c_str())); + if (var) { + var.SetSelection(adios2::Box({ offset }, { count })); + const auto slice = range_tuple_t(0, count); + auto array_h = Kokkos::create_mirror_view(array); + reader.Get(var, Kokkos::subview(array_h, slice).data(), adios2::Mode::Sync); + Kokkos::deep_copy(Kokkos::subview(array, slice), + Kokkos::subview(array_h, slice)); + } else { + raise::Error( + fmt::format("Variable: s%d_%s not found", s + 1, quantity.c_str()), + HERE); + } + } + + void ReadParticlePayloads(adios2::IO& io, + adios2::Engine& reader, + spidx_t s, + array_t& array, + std::size_t nplds, + npart_t count, + npart_t offset) { + logger::Checkpoint(fmt::format("Reading quantity: s%d_plds", s + 1), HERE); + auto var = io.InquireVariable(fmt::format("s%d_plds", s + 1)); + if (var) { + var.SetSelection(adios2::Box({ offset, 0 }, { count, nplds })); + const auto slice = range_tuple_t(0, count); + auto array_h = Kokkos::create_mirror_view(array); + reader.Get(var, + Kokkos::subview(array_h, slice, range_tuple_t(0, nplds)).data(), + adios2::Mode::Sync); + Kokkos::deep_copy(array, array_h); + } else { + raise::Error(fmt::format("Variable: s%d_plds not found", s + 1), HERE); + } + } + +#define CHECKPOINT_FIELDS(D, N) \ + template void ReadFields(adios2::IO&, \ + adios2::Engine&, \ + const std::string&, \ + const adios2::Box&, \ + ndfield_t&); + CHECKPOINT_FIELDS(Dim::_1D, 3) + CHECKPOINT_FIELDS(Dim::_2D, 3) + CHECKPOINT_FIELDS(Dim::_3D, 3) + CHECKPOINT_FIELDS(Dim::_1D, 6) + CHECKPOINT_FIELDS(Dim::_2D, 6) + CHECKPOINT_FIELDS(Dim::_3D, 6) +#undef CHECKPOINT_FIELDS + +#define CHECKPOINT_PARTICLE_DATA(T) \ + template void ReadParticleData(adios2::IO&, \ + adios2::Engine&, \ + const std::string&, \ + spidx_t, \ + array_t&, \ + npart_t, \ + npart_t); + CHECKPOINT_PARTICLE_DATA(int) + CHECKPOINT_PARTICLE_DATA(float) + CHECKPOINT_PARTICLE_DATA(double) + 
CHECKPOINT_PARTICLE_DATA(short) +#undef CHECKPOINT_PARTICLE_DATA + +} // namespace checkpoint diff --git a/src/checkpoint/reader.h b/src/checkpoint/reader.h new file mode 100644 index 000000000..7939ba82b --- /dev/null +++ b/src/checkpoint/reader.h @@ -0,0 +1,58 @@ +/** + * @file checkpoint/reader.h + * @brief Function for reading field & particle data from checkpoint files + * @implements + * - checkpoint::ReadFields -> void + * - checkpoint::ReadParticleData -> void + * - checkpoint::ReadParticleCount -> std::pair + * @cpp: + * - reader.cpp + * @namespaces: + * - checkpoint:: + */ + +#ifndef CHECKPOINT_READER_H +#define CHECKPOINT_READER_H + +#include "arch/kokkos_aliases.h" + +#include + +#include +#include + +namespace checkpoint { + + template + void ReadFields(adios2::IO&, + adios2::Engine&, + const std::string&, + const adios2::Box&, + ndfield_t&); + + auto ReadParticleCount(adios2::IO&, + adios2::Engine&, + spidx_t, + std::size_t, + std::size_t) -> std::pair; + + template + void ReadParticleData(adios2::IO&, + adios2::Engine&, + const std::string&, + spidx_t, + array_t&, + npart_t, + npart_t); + + void ReadParticlePayloads(adios2::IO&, + adios2::Engine&, + spidx_t, + array_t&, + std::size_t, + npart_t, + npart_t); + +} // namespace checkpoint + +#endif // CHECKPOINT_READER_H diff --git a/src/checkpoint/tests/CMakeLists.txt b/src/checkpoint/tests/CMakeLists.txt new file mode 100644 index 000000000..cbfd63aa9 --- /dev/null +++ b/src/checkpoint/tests/CMakeLists.txt @@ -0,0 +1,30 @@ +# cmake-lint: disable=C0103,C0111 +# ------------------------------ +# @brief: Generates tests for the `ntt_checkpoint` module +# +# @uses: +# +# * kokkos [required] +# * adios2 [required] +# * mpi [optional] +# ------------------------------ + +set(SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../) + +function(gen_test title) + set(exec test-output-${title}.xc) + set(src ${title}.cpp) + add_executable(${exec} ${src}) + + set(libs ntt_checkpoint ntt_global) + add_dependencies(${exec} 
${libs}) + target_link_libraries(${exec} PRIVATE ${libs} stdc++fs) + + add_test(NAME "CHECKPOINT::${title}" COMMAND "${exec}") +endfunction() + +if(NOT ${mpi}) + gen_test(checkpoint-nompi) +else() + gen_test(checkpoint-mpi) +endif() diff --git a/src/checkpoint/tests/checkpoint-mpi.cpp b/src/checkpoint/tests/checkpoint-mpi.cpp new file mode 100644 index 000000000..2372d81bc --- /dev/null +++ b/src/checkpoint/tests/checkpoint-mpi.cpp @@ -0,0 +1,273 @@ +#include "enums.h" +#include "global.h" + +#include "utils/comparators.h" + +#include "checkpoint/reader.h" +#include "checkpoint/writer.h" + +#include +#include +#include +#include + +#include +#include +#include + +using namespace ntt; +using namespace checkpoint; + +void cleanup() { + namespace fs = std::filesystem; + fs::path temp_path { "chck" }; + fs::remove_all(temp_path); +} + +auto main(int argc, char* argv[]) -> int { + Kokkos::initialize(argc, argv); + MPI_Init(&argc, &argv); + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + + try { + // assuming 4 ranks + // |------|------| + // | 2 | 3 | + // |------|------| + // | | | + // | 0 | 1 | + // |------|------| + const std::size_t g_nx1 = 20; + const std::size_t g_nx2 = 15; + const std::size_t g_nx1_gh = g_nx1 + 4 * N_GHOSTS; + const std::size_t g_nx2_gh = g_nx2 + 4 * N_GHOSTS; + + const std::size_t l_nx1 = 10; + const std::size_t l_nx2 = (rank < 2) ? 10 : 5; + + const std::size_t l_nx1_gh = l_nx1 + 2 * N_GHOSTS; + const std::size_t l_nx2_gh = l_nx2 + 2 * N_GHOSTS; + + const std::size_t l_corner_x1 = (rank % 2 == 0) ? 0 : l_nx1_gh; + const std::size_t l_corner_x2 = (rank < 2) ? 
0 : l_nx2_gh; + + const std::size_t i1min = N_GHOSTS; + const std::size_t i2min = N_GHOSTS; + const std::size_t i1max = l_nx1 + N_GHOSTS; + const std::size_t i2max = l_nx2 + N_GHOSTS; + + const std::size_t npart1 = (rank % 2 + rank) * 23 + 100; + const std::size_t npart2 = (rank % 2 + rank) * 37 + 100; + + std::size_t npart1_offset = 0; + std::size_t npart2_offset = 0; + + std::size_t npart1_globtot = 0; + std::size_t npart2_globtot = 0; + + for (auto r = 0; r < rank - 1; ++r) { + npart1_offset += (r % 2 + r) * 23 + 100; + npart2_offset += (r % 2 + r) * 37 + 100; + } + + for (auto r = 0; r < size; ++r) { + npart1_globtot += (r % 2 + r) * 23 + 100; + npart2_globtot += (r % 2 + r) * 37 + 100; + } + + // init data + ndfield_t field1 { "fld1", l_nx1_gh, l_nx2_gh }; + ndfield_t field2 { "fld2", l_nx1_gh, l_nx2_gh }; + + array_t i1 { "i_1", npart1 }; + array_t u1 { "u_1", npart1 }; + array_t i2 { "i_2", npart2 }; + array_t u2 { "u_2", npart2 }; + array_t plds1 { "plds_1", npart1, 3 }; + + { + // fill data + Kokkos::parallel_for( + "fillFlds", + CreateRangePolicy({ i1min, i2min }, { i1max, i2max }), + Lambda(index_t i1, index_t i2) { + field1(i1, i2, 0) = static_cast(i1 + i2); + field1(i1, i2, 1) = static_cast(i1 * i2); + field1(i1, i2, 2) = static_cast(i1 / i2); + field1(i1, i2, 3) = static_cast(i1 - i2); + field1(i1, i2, 4) = static_cast(i2 / i1); + field1(i1, i2, 5) = static_cast(i1); + field2(i1, i2, 0) = static_cast(-(i1 + i2)); + field2(i1, i2, 1) = static_cast(-(i1 * i2)); + field2(i1, i2, 2) = static_cast(-(i1 / i2)); + field2(i1, i2, 3) = static_cast(-(i1 - i2)); + field2(i1, i2, 4) = static_cast(-(i2 / i1)); + field2(i1, i2, 5) = static_cast(-i1); + }); + Kokkos::parallel_for( + "fillPrtl1", + npart1, + Lambda(index_t p) { + u1(p) = static_cast(p); + i1(p) = static_cast(p); + plds1(p, 0) = static_cast(p); + plds1(p, 1) = static_cast(p * p); + plds1(p, 2) = static_cast(p * p * p); + }); + Kokkos::parallel_for( + "fillPrtl2", + npart2, + Lambda(index_t p) { + 
u2(p) = -static_cast(p); + i2(p) = -static_cast(p); + }); + } + + adios2::ADIOS adios; + const path_t checkpoint_path { "chck" }; + + { + // write checkpoint + Writer writer; + writer.init(&adios, checkpoint_path, 0, 0.0, 1); + + writer.defineFieldVariables(SimEngine::GRPIC, + { g_nx1_gh, g_nx2_gh }, + { l_corner_x1, l_corner_x2 }, + { l_nx1_gh, l_nx2_gh }); + + writer.defineParticleVariables(Coord::Sph, Dim::_2D, 2, { 3, 0 }); + + writer.beginSaving(0, 0.0); + + writer.saveField("em", field1); + writer.saveField("em0", field2); + + writer.savePerDomainVariable("s1_npart", 1, 0, npart1); + writer.savePerDomainVariable("s2_npart", 1, 0, npart2); + + writer.saveParticleQuantity("s1_i1", + npart1_globtot, + npart1_offset, + npart1, + i1); + writer.saveParticleQuantity("s1_ux1", + npart1_globtot, + npart1_offset, + npart1, + u1); + writer.saveParticleQuantity("s2_i1", + npart2_globtot, + npart2_offset, + npart2, + i2); + writer.saveParticleQuantity("s2_ux1", + npart2_globtot, + npart2_offset, + npart2, + u2); + + writer.saveParticlePayloads("s1_plds", + 3, + npart1_globtot, + npart1_offset, + npart1, + plds1); + + writer.endSaving(); + } + + { + // read checkpoint + ndfield_t field1_read { "fld1_read", l_nx1_gh, l_nx2_gh }; + ndfield_t field2_read { "fld2_read", l_nx1_gh, l_nx2_gh }; + + array_t i1_read { "i_1", npart1 }; + array_t u1_read { "u_1", npart1 }; + array_t i2_read { "i_2", npart2 }; + array_t u2_read { "u_2", npart2 }; + array_t plds1_read { "plds_1", npart1, 3 }; + + adios2::IO io = adios.DeclareIO("checkpointRead"); + adios2::Engine reader = io.Open(checkpoint_path / "step-00000000.bp", + adios2::Mode::Read); + reader.BeginStep(); + + auto fieldRange = adios2::Box({ l_corner_x1, l_corner_x2, 0 }, + { l_nx1_gh, l_nx2_gh, 6 }); + ReadFields(io, reader, "em", fieldRange, field1_read); + ReadFields(io, reader, "em0", fieldRange, field2_read); + + auto [nprtl1, noff1] = ReadParticleCount(io, reader, 0, rank, size); + auto [nprtl2, noff2] = 
ReadParticleCount(io, reader, 1, rank, size); + + ReadParticleData(io, reader, "ux1", 0, u1_read, nprtl1, noff1); + ReadParticleData(io, reader, "ux1", 1, u2_read, nprtl2, noff2); + ReadParticleData(io, reader, "i1", 0, i1_read, nprtl1, noff1); + ReadParticleData(io, reader, "i1", 1, i2_read, nprtl2, noff2); + ReadParticlePayloads(io, reader, 0, plds1_read, 3, nprtl1, noff1); + + reader.EndStep(); + reader.Close(); + + // check the validity + Kokkos::parallel_for( + "checkFields", + CreateRangePolicy({ 0, 0 }, { l_nx1_gh, l_nx2_gh }), + Lambda(index_t i1, index_t i2) { + for (int i = 0; i < 6; ++i) { + if (not cmp::AlmostEqual(field1(i1, i2, i), field1_read(i1, i2, i))) { + raise::KernelError(HERE, "Field1 read failed"); + } + if (not cmp::AlmostEqual(field2(i1, i2, i), field2_read(i1, i2, i))) { + raise::KernelError(HERE, "Field2 read failed"); + } + } + }); + + raise::ErrorIf(npart1 != nprtl1, "Particle count 1 mismatch", HERE); + raise::ErrorIf(npart2 != nprtl2, "Particle count 2 mismatch", HERE); + raise::ErrorIf(noff1 != npart1_offset, "Particle offset 1 mismatch", HERE); + raise::ErrorIf(noff2 != npart2_offset, "Particle offset 2 mismatch", HERE); + + Kokkos::parallel_for( + "checkPrtl1", + nprtl1, + Lambda(index_t p) { + if (not cmp::AlmostEqual(u1(p), u1_read(p))) { + raise::KernelError(HERE, "u1 read failed"); + } + if (i1(p) != i1_read(p)) { + raise::KernelError(HERE, "i1 read failed"); + } + for (auto l = 0; l < 3; ++l) { + if (not cmp::AlmostEqual(plds1(p, l), plds1_read(p, l))) { + raise::KernelError(HERE, "plds1 read failed"); + } + } + }); + Kokkos::parallel_for( + "checkPrtl2", + nprtl2, + Lambda(index_t p) { + if (not cmp::AlmostEqual(u2(p), u2_read(p))) { + raise::KernelError(HERE, "u2 read failed"); + } + if (i2(p) != i2_read(p)) { + raise::KernelError(HERE, "i2 read failed"); + } + }); + } + + } catch (std::exception& e) { + std::cerr << e.what() << std::endl; + cleanup(); + Kokkos::finalize(); + return 1; + } + cleanup(); + Kokkos::finalize(); 
+ return 0; +} diff --git a/src/checkpoint/tests/checkpoint-nompi.cpp b/src/checkpoint/tests/checkpoint-nompi.cpp new file mode 100644 index 000000000..132a3679a --- /dev/null +++ b/src/checkpoint/tests/checkpoint-nompi.cpp @@ -0,0 +1,208 @@ +#include "enums.h" +#include "global.h" + +#include "utils/comparators.h" + +#include "checkpoint/reader.h" +#include "checkpoint/writer.h" + +#include +#include +#include + +#include +#include + +using namespace ntt; +using namespace checkpoint; + +void cleanup() { + namespace fs = std::filesystem; + fs::path temp_path { "chck" }; + fs::remove_all(temp_path); +} + +auto main(int argc, char* argv[]) -> int { + Kokkos::initialize(argc, argv); + + try { + constexpr auto nx1 = 10; + constexpr auto nx1_gh = nx1 + 2 * N_GHOSTS; + constexpr auto nx2 = 13; + constexpr auto nx2_gh = nx2 + 2 * N_GHOSTS; + constexpr auto nx3 = 9; + constexpr auto nx3_gh = nx3 + 2 * N_GHOSTS; + constexpr auto i1min = N_GHOSTS; + constexpr auto i2min = N_GHOSTS; + constexpr auto i3min = N_GHOSTS; + constexpr auto i1max = nx1 + N_GHOSTS; + constexpr auto i2max = nx2 + N_GHOSTS; + constexpr auto i3max = nx3 + N_GHOSTS; + constexpr auto npart1 = 100; + constexpr auto npart2 = 100; + + // init data + ndfield_t field1 { "fld1", nx1_gh, nx2_gh, nx3_gh }; + ndfield_t field2 { "fld2", nx1_gh, nx2_gh, nx3_gh }; + + array_t i1 { "i_1", npart1 }; + array_t u1 { "u_1", npart1 }; + array_t i2 { "i_2", npart2 }; + array_t u2 { "u_2", npart2 }; + + { + // fill data + Kokkos::parallel_for( + "fillFlds", + CreateRangePolicy({ i1min, i2min, i3min }, + { i1max, i2max, i3max }), + Lambda(index_t i1, index_t i2, index_t i3) { + const auto i1_ = static_cast(i1); + const auto i2_ = static_cast(i2); + const auto i3_ = static_cast(i3); + field1(i1, i2, i3, 0) = i1_ + i2_ + i3_; + field1(i1, i2, i3, 1) = i1_ * i2_ / i3_; + field1(i1, i2, i3, 2) = i1_ / i2_ * i3_; + field1(i1, i2, i3, 3) = i1_ + i2_ - i3_; + field1(i1, i2, i3, 4) = i1_ * i2_ + i3_; + field1(i1, i2, i3, 5) = i1_ / 
i2_ - i3_; + field2(i1, i2, i3, 0) = -(i1_ + i2_ + i3_); + field2(i1, i2, i3, 1) = -(i1_ * i2_ / i3_); + field2(i1, i2, i3, 2) = -(i1_ / i2_ * i3_); + field2(i1, i2, i3, 3) = -(i1_ + i2_ - i3_); + field2(i1, i2, i3, 4) = -(i1_ * i2_ + i3_); + field2(i1, i2, i3, 5) = -(i1_ / i2_ - i3_); + }); + Kokkos::parallel_for( + "fillPrtl1", + npart1, + Lambda(index_t p) { + u1(p) = static_cast(p); + i1(p) = static_cast(p); + }); + Kokkos::parallel_for( + "fillPrtl2", + npart2, + Lambda(index_t p) { + u2(p) = -static_cast(p); + i2(p) = -static_cast(p); + }); + } + + adios2::ADIOS adios; + const path_t checkpoint_path { "chck" }; + + { + // write checkpoint + Writer writer {}; + writer.init(&adios, checkpoint_path, 0, 0.0, 1); + + writer.defineFieldVariables(SimEngine::GRPIC, + { nx1_gh, nx2_gh, nx3_gh }, + { 0, 0, 0 }, + { nx1_gh, nx2_gh, nx3_gh }); + writer.defineParticleVariables(Coord::Sph, Dim::_3D, 2, { 0, 2 }); + + writer.beginSaving(0, 0.0); + + writer.saveField("em", field1); + writer.saveField("em0", field2); + + writer.savePerDomainVariable("s1_npart", 1, 0, npart1); + writer.savePerDomainVariable("s2_npart", 1, 0, npart2); + + writer.saveParticleQuantity("s1_i1", npart1, 0, npart1, i1); + writer.saveParticleQuantity("s1_ux1", npart1, 0, npart1, u1); + writer.saveParticleQuantity("s2_i1", npart2, 0, npart2, i2); + writer.saveParticleQuantity("s2_ux1", npart2, 0, npart2, u2); + + writer.endSaving(); + } + + { + // read checkpoint + ndfield_t field1_read { "fld1_read", nx1_gh, nx2_gh, nx3_gh }; + ndfield_t field2_read { "fld2_read", nx1_gh, nx2_gh, nx3_gh }; + + array_t i1_read { "i_1", npart1 }; + array_t u1_read { "u_1", npart1 }; + array_t i2_read { "i_2", npart2 }; + array_t u2_read { "u_2", npart2 }; + + adios2::IO io = adios.DeclareIO("checkpointRead"); + adios2::Engine reader = io.Open(checkpoint_path / "step-00000000.bp", + adios2::Mode::Read); + reader.BeginStep(); + + auto fieldRange = adios2::Box({ 0, 0, 0, 0 }, + { nx1_gh, nx2_gh, nx3_gh, 6 }); + 
ReadFields(io, reader, "em", fieldRange, field1_read); + ReadFields(io, reader, "em0", fieldRange, field2_read); + + auto [nprtl1, noff1] = ReadParticleCount(io, reader, 0, 0, 1); + auto [nprtl2, noff2] = ReadParticleCount(io, reader, 1, 0, 1); + + ReadParticleData(io, reader, "ux1", 0, u1_read, nprtl1, noff1); + ReadParticleData(io, reader, "ux1", 1, u2_read, nprtl2, noff2); + ReadParticleData(io, reader, "i1", 0, i1_read, nprtl1, noff1); + ReadParticleData(io, reader, "i1", 1, i2_read, nprtl2, noff2); + + reader.EndStep(); + reader.Close(); + + // check the validity + Kokkos::parallel_for( + "checkFields", + CreateRangePolicy({ 0, 0, 0 }, { nx1_gh, nx2_gh, nx3_gh }), + Lambda(index_t i1, index_t i2, index_t i3) { + for (int i = 0; i < 6; ++i) { + if (not cmp::AlmostEqual(field1(i1, i2, i3, i), + field1_read(i1, i2, i3, i))) { + raise::KernelError(HERE, "Field1 read failed"); + } + if (not cmp::AlmostEqual(field2(i1, i2, i3, i), + field2_read(i1, i2, i3, i))) { + raise::KernelError(HERE, "Field2 read failed"); + } + } + }); + + raise::ErrorIf(npart1 != nprtl1, "Particle count 1 mismatch", HERE); + raise::ErrorIf(npart2 != nprtl2, "Particle count 2 mismatch", HERE); + raise::ErrorIf(noff1 != 0, "Particle offset 1 mismatch", HERE); + raise::ErrorIf(noff2 != 0, "Particle offset 2 mismatch", HERE); + + Kokkos::parallel_for( + "checkPrtl1", + npart1, + Lambda(index_t p) { + if (not cmp::AlmostEqual(u1(p), u1_read(p))) { + raise::KernelError(HERE, "u1 read failed"); + } + if (i1(p) != i1_read(p)) { + raise::KernelError(HERE, "i1 read failed"); + } + }); + Kokkos::parallel_for( + "checkPrtl2", + npart2, + Lambda(index_t p) { + if (not cmp::AlmostEqual(u2(p), u2_read(p))) { + raise::KernelError(HERE, "u2 read failed"); + } + if (i2(p) != i2_read(p)) { + raise::KernelError(HERE, "i2 read failed"); + } + }); + } + + } catch (std::exception& e) { + std::cerr << e.what() << std::endl; + cleanup(); + Kokkos::finalize(); + return 1; + } + cleanup(); + Kokkos::finalize(); + 
return 0; +} diff --git a/src/checkpoint/writer.cpp b/src/checkpoint/writer.cpp new file mode 100644 index 000000000..b766ddfbd --- /dev/null +++ b/src/checkpoint/writer.cpp @@ -0,0 +1,302 @@ +#include "checkpoint/writer.h" + +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/error.h" +#include "utils/formatting.h" +#include "utils/log.h" + +#include "framework/parameters.h" + +#include +#include + +#include +#include +#include +#include + +namespace checkpoint { + + void Writer::init(adios2::ADIOS* ptr_adios, + const path_t& checkpoint_root, + timestep_t interval, + simtime_t interval_time, + int keep, + const std::string& walltime) { + m_keep = keep; + m_checkpoint_root = checkpoint_root; + m_enabled = keep != 0; + if (not m_enabled) { + return; + } + m_tracker.init("checkpoint", interval, interval_time, walltime); + p_adios = ptr_adios; + raise::ErrorIf(p_adios == nullptr, "ADIOS pointer is null", HERE); + + m_io = p_adios->DeclareIO("Entity::Checkpoint"); + m_io.SetEngine("BPFile"); + + m_io.DefineVariable("Step"); + m_io.DefineVariable("Time"); + m_io.DefineAttribute("NGhosts", ntt::N_GHOSTS); + + CallOnce( + [](auto&& checkpoint_root) { + if (!std::filesystem::exists(checkpoint_root)) { + std::filesystem::create_directory(checkpoint_root); + } + }, + m_checkpoint_root); + } + + void Writer::defineFieldVariables(const ntt::SimEngine& S, + const std::vector& glob_shape, + const std::vector& loc_corner, + const std::vector& loc_shape) { + auto gs6 = std::vector(glob_shape.begin(), glob_shape.end()); + auto lc6 = std::vector(loc_corner.begin(), loc_corner.end()); + auto ls6 = std::vector(loc_shape.begin(), loc_shape.end()); + gs6.push_back(6); + lc6.push_back(0); + ls6.push_back(6); + + m_io.DefineVariable("em", gs6, lc6, ls6); + if (S == ntt::SimEngine::GRPIC) { + m_io.DefineVariable("em0", gs6, lc6, ls6); + auto gs3 = std::vector(glob_shape.begin(), glob_shape.end()); + auto lc3 = std::vector(loc_corner.begin(), loc_corner.end()); + auto 
ls3 = std::vector(loc_shape.begin(), loc_shape.end()); + gs3.push_back(3); + lc3.push_back(0); + ls3.push_back(3); + m_io.DefineVariable("cur0", gs3, lc3, ls3); + } + } + + void Writer::defineParticleVariables(const ntt::Coord& C, + Dimension dim, + std::size_t nspec, + const std::vector& nplds) { + raise::ErrorIf(nplds.size() != nspec, + "Number of payloads does not match the number of species", + HERE); + for (auto s { 0u }; s < nspec; ++s) { + m_io.DefineVariable(fmt::format("s%d_npart", s + 1), + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + + for (auto d { 0u }; d < dim; ++d) { + m_io.DefineVariable(fmt::format("s%d_i%d", s + 1, d + 1), + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + m_io.DefineVariable(fmt::format("s%d_dx%d", s + 1, d + 1), + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + m_io.DefineVariable(fmt::format("s%d_i%d_prev", s + 1, d + 1), + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + m_io.DefineVariable(fmt::format("s%d_dx%d_prev", s + 1, d + 1), + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + } + + if (dim == Dim::_2D and C != ntt::Coord::Cart) { + m_io.DefineVariable(fmt::format("s%d_phi", s + 1), + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + } + + for (auto d { 0u }; d < 3; ++d) { + m_io.DefineVariable(fmt::format("s%d_ux%d", s + 1, d + 1), + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + } + + m_io.DefineVariable(fmt::format("s%d_tag", s + 1), + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + m_io.DefineVariable(fmt::format("s%d_weight", s + 1), + { adios2::UnknownDim }, + { adios2::UnknownDim }, + { adios2::UnknownDim }); + if (nplds[s] > 0) { + m_io.DefineVariable(fmt::format("s%d_plds", s + 1), + { adios2::UnknownDim, nplds[s] }, + { adios2::UnknownDim, 0 }, + { adios2::UnknownDim, 
nplds[s] }); + } + } + } + + auto Writer::shouldSave(timestep_t step, simtime_t time) -> bool { + return m_enabled and m_tracker.shouldWrite(step, time); + } + + void Writer::beginSaving(timestep_t step, simtime_t time) { + raise::ErrorIf(!m_enabled, "Checkpoint is not enabled", HERE); + raise::ErrorIf(p_adios == nullptr, "ADIOS pointer is null", HERE); + if (m_writing_mode) { + raise::Fatal("Already writing", HERE); + } + m_writing_mode = true; + try { + const auto filename = m_checkpoint_root / fmt::format("step-%08lu.bp", step); + const auto metafilename = m_checkpoint_root / + fmt::format("meta-%08lu.toml", step); + m_writer = m_io.Open(filename, adios2::Mode::Write); + m_written.push_back({ filename, metafilename }); + logger::Checkpoint(fmt::format("Writing checkpoint to %s and %s", + filename.c_str(), + metafilename.c_str()), + HERE); + } catch (std::exception& e) { + raise::Fatal(e.what(), HERE); + } + + m_writer.BeginStep(); + m_writer.Put(m_io.InquireVariable("Step"), &step); + m_writer.Put(m_io.InquireVariable("Time"), &time); + } + + void Writer::endSaving() { + raise::ErrorIf(p_adios == nullptr, "ADIOS pointer is null", HERE); + if (!m_writing_mode) { + raise::Fatal("Not writing", HERE); + } + m_writing_mode = false; + m_writer.EndStep(); + m_writer.Close(); + + // optionally remove the oldest checkpoint + CallOnce([&]() { + if (m_keep > 0 and m_written.size() > (std::size_t)m_keep) { + const auto oldest = m_written.front(); + if (std::filesystem::exists(oldest.first) and + std::filesystem::exists(oldest.second)) { + std::filesystem::remove_all(oldest.first); + std::filesystem::remove(oldest.second); + m_written.erase(m_written.begin()); + } else { + raise::Warning("Checkpoint file does not exist for some reason", HERE); + } + } + }); + } + + template + void Writer::savePerDomainVariable(const std::string& varname, + std::size_t total, + std::size_t offset, + T data) { + auto var = m_io.InquireVariable(varname); + var.SetShape({ total }); + 
var.SetSelection(adios2::Box({ offset }, { 1 })); + m_writer.Put(var, &data); + } + + void Writer::saveAttrs(const ntt::SimulationParams& params, simtime_t time) { + CallOnce([&]() { + std::ofstream metadata; + if (m_written.empty()) { + raise::Fatal("No checkpoint file to save metadata", HERE); + } + metadata.open(m_written.back().second.c_str()); + metadata << "[metadata]\n" + << " time = " << time << "\n\n" + << params.data() << std::endl; + metadata.close(); + }); + } + + template + void Writer::saveField(const std::string& fieldname, + const ndfield_t& field) { + auto field_h = Kokkos::create_mirror_view(field); + Kokkos::deep_copy(field_h, field); + m_writer.Put(m_io.InquireVariable(fieldname), + field_h.data(), + adios2::Mode::Sync); + } + + template + void Writer::saveParticleQuantity(const std::string& quantity, + npart_t glob_total, + npart_t loc_offset, + npart_t loc_size, + const array_t& data) { + const auto slice = range_tuple_t(0, loc_size); + auto var = m_io.InquireVariable(quantity); + + var.SetShape({ glob_total }); + var.SetSelection(adios2::Box({ loc_offset }, { loc_size })); + + auto data_h = Kokkos::create_mirror_view(data); + Kokkos::deep_copy(data_h, data); + auto data_sub = Kokkos::subview(data_h, slice); + m_writer.Put(var, data_sub.data(), adios2::Mode::Sync); + } + + void Writer::saveParticlePayloads(const std::string& quantity, + std::size_t nplds, + npart_t glob_total, + npart_t loc_offset, + npart_t loc_size, + const array_t& data) { + const auto slice = range_tuple_t(0, loc_size); + auto var = m_io.InquireVariable(quantity); + + var.SetShape({ glob_total, nplds }); + var.SetSelection( + adios2::Box({ loc_offset, 0 }, { loc_size, nplds })); + + auto data_h = Kokkos::create_mirror_view(data); + Kokkos::deep_copy(data_h, data); + auto data_sub = Kokkos::subview(data_h, slice, range_tuple_t(0, nplds)); + m_writer.Put(var, data_sub.data(), adios2::Mode::Sync); + } + +#define CHECKPOINT_PERDOMAIN_VARIABLE(T) \ + template void 
Writer::savePerDomainVariable(const std::string&, \ + std::size_t, \ + std::size_t, \ + T); + CHECKPOINT_PERDOMAIN_VARIABLE(int) + CHECKPOINT_PERDOMAIN_VARIABLE(float) + CHECKPOINT_PERDOMAIN_VARIABLE(double) + CHECKPOINT_PERDOMAIN_VARIABLE(npart_t) +#undef CHECKPOINT_PERDOMAIN_VARIABLE + +#define CHECKPOINT_FIELD(D, N) \ + template void Writer::saveField(const std::string&, \ + const ndfield_t&); + CHECKPOINT_FIELD(Dim::_1D, 3) + CHECKPOINT_FIELD(Dim::_1D, 6) + CHECKPOINT_FIELD(Dim::_2D, 3) + CHECKPOINT_FIELD(Dim::_2D, 6) + CHECKPOINT_FIELD(Dim::_3D, 3) + CHECKPOINT_FIELD(Dim::_3D, 6) +#undef CHECKPOINT_FIELD + +#define CHECKPOINT_PARTICLE_QUANTITY(T) \ + template void Writer::saveParticleQuantity(const std::string&, \ + npart_t, \ + npart_t, \ + npart_t, \ + const array_t&); + CHECKPOINT_PARTICLE_QUANTITY(int) + CHECKPOINT_PARTICLE_QUANTITY(float) + CHECKPOINT_PARTICLE_QUANTITY(double) + CHECKPOINT_PARTICLE_QUANTITY(short) +#undef CHECKPOINT_PARTICLE_QUANTITY + +} // namespace checkpoint diff --git a/src/checkpoint/writer.h b/src/checkpoint/writer.h new file mode 100644 index 000000000..6f8bc8cb5 --- /dev/null +++ b/src/checkpoint/writer.h @@ -0,0 +1,103 @@ +/** + * @file checkpoint/writer.h + * @brief Class that dumps checkpoints + * @implements + * - checkpoint::Writer + * @cpp: + * - writer.cpp + * @namespaces: + * - checkpoint:: + */ + +#ifndef CHECKPOINT_WRITER_H +#define CHECKPOINT_WRITER_H + +#include "enums.h" +#include "global.h" + +#include "utils/tools.h" + +#include "framework/parameters.h" + +#include + +#include +#include +#include + +namespace checkpoint { + + class Writer { + adios2::ADIOS* p_adios { nullptr }; + + adios2::IO m_io; + adios2::Engine m_writer; + + tools::Tracker m_tracker {}; + + bool m_writing_mode { false }; + + std::vector> m_written; + + int m_keep; + bool m_enabled; + path_t m_checkpoint_root; + + public: + Writer() {} + + ~Writer() = default; + + void init(adios2::ADIOS*, + const path_t&, + timestep_t, + simtime_t, + int, + 
const std::string& = ""); + + auto shouldSave(timestep_t, simtime_t) -> bool; + + void beginSaving(timestep_t, simtime_t); + void endSaving(); + + void saveAttrs(const ntt::SimulationParams&, simtime_t); + + template + void savePerDomainVariable(const std::string&, std::size_t, std::size_t, T); + + template + void saveField(const std::string&, const ndfield_t&); + + template + void saveParticleQuantity(const std::string&, + npart_t, + npart_t, + npart_t, + const array_t&); + + void saveParticlePayloads(const std::string&, + std::size_t, + npart_t, + npart_t, + npart_t, + const array_t&); + + void defineFieldVariables(const ntt::SimEngine&, + const std::vector&, + const std::vector&, + const std::vector&); + + void defineParticleVariables(const ntt::Coord&, + Dimension, + std::size_t, + const std::vector&); + + [[nodiscard]] + auto enabled() const -> bool { + return m_enabled; + } + }; + +} // namespace checkpoint + +#endif // CHECKPOINT_WRITER_H diff --git a/src/engines/CMakeLists.txt b/src/engines/CMakeLists.txt index 2cc61a265..4cef18630 100644 --- a/src/engines/CMakeLists.txt +++ b/src/engines/CMakeLists.txt @@ -1,46 +1,51 @@ +# cmake-lint: disable=C0103 # ------------------------------ # @defines: ntt_engines [STATIC/SHARED] +# # @sources: -# - engine_printer.cpp -# - engine_init.cpp -# - engine_run.cpp -# - engine_step_report.cpp +# +# * engine_printer.cpp +# * engine_init.cpp +# * engine_run.cpp +# # @includes: -# - ../ +# +# * ../ +# # @depends: -# - ntt_global [required] -# - ntt_framework [required] -# - ntt_metrics [required] -# - ntt_kernels [required] -# - ntt_archetypes [required] -# - ntt_pgen [required] -# - ntt_output [optional] +# +# * ntt_global [required] +# * ntt_framework [required] +# * ntt_metrics [required] +# * ntt_kernels [required] +# * ntt_archetypes [required] +# * ntt_pgen [required] +# * ntt_output [optional] +# # @uses: -# - kokkos [required] -# - plog [required] -# - adios2 [optional] -# - hdf5 [optional] -# - mpi [optional] +# +# * 
kokkos [required] +# * plog [required] +# * adios2 [optional] +# * hdf5 [optional] +# * mpi [optional] # ------------------------------ set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(SOURCES - ${SRC_DIR}/engine_printer.cpp - ${SRC_DIR}/engine_init.cpp - ${SRC_DIR}/engine_run.cpp - ${SRC_DIR}/engine_step_report.cpp -) +set(SOURCES ${SRC_DIR}/engine_printer.cpp ${SRC_DIR}/engine_init.cpp + ${SRC_DIR}/engine_run.cpp) add_library(ntt_engines ${SOURCES}) -set(libs ntt_global ntt_framework ntt_metrics ntt_archetypes ntt_kernels ntt_pgen) +set(libs ntt_global ntt_framework ntt_metrics ntt_archetypes ntt_kernels + ntt_pgen) if(${output}) - list(APPEND libs ntt_output hdf5::hdf5) + list(APPEND libs ntt_output) endif() add_dependencies(ntt_engines ${libs}) target_link_libraries(ntt_engines PUBLIC ${libs}) target_compile_definitions(ntt_engines PRIVATE PGEN=\"${PGEN}\") -target_include_directories(ntt_engines +target_include_directories( + ntt_engines PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../ - INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../ -) + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../) diff --git a/src/engines/engine.hpp b/src/engines/engine.hpp index a57b52d19..17103f1de 100644 --- a/src/engines/engine.hpp +++ b/src/engines/engine.hpp @@ -23,8 +23,8 @@ #include "arch/traits.h" #include "utils/error.h" -#include "utils/progressbar.h" #include "utils/timer.h" +#include "utils/toml.h" #include "framework/containers/species.h" #include "framework/domain/metadomain.h" @@ -34,6 +34,15 @@ #include +#if defined(OUTPUT_ENABLED) + #include + #include +#endif // OUTPUT_ENABLED + +#if defined(MPI_ENABLED) + #include +#endif // MPI_ENABLED + #include #include @@ -45,15 +54,26 @@ namespace ntt { static_assert(user::PGen::is_pgen, "unrecognized problem generator"); protected: - SimulationParams& m_params; - Metadomain m_metadomain; - user::PGen m_pgen; +#if defined(OUTPUT_ENABLED) + #if defined(MPI_ENABLED) + adios2::ADIOS m_adios { MPI_COMM_WORLD }; + #else + adios2::ADIOS m_adios; + #endif 
+#endif + + SimulationParams m_params; + Metadomain m_metadomain; + user::PGen m_pgen; - const long double runtime; - const real_t dt; - const std::size_t max_steps; - long double time { 0.0 }; - std::size_t step { 0 }; + const bool is_resuming; + const simtime_t runtime; + const real_t dt; + const timestep_t max_steps; + const timestep_t start_step; + const simtime_t start_time; + simtime_t time; + timestep_t step; public: static constexpr bool pgen_is_ok { @@ -65,60 +85,37 @@ namespace ntt { static constexpr Dimension D { M::Dim }; static constexpr bool is_engine { true }; -#if defined(OUTPUT_ENABLED) - Engine(SimulationParams& params) - : m_params { params } - , m_metadomain { params.get("simulation.domain.number"), - params.get>( - "simulation.domain.decomposition"), - params.get>("grid.resolution"), - params.get>("grid.extent"), - params.get>( - "grid.boundaries.fields"), - params.get>( - "grid.boundaries.particles"), - params.get>( - "grid.metric.params"), - params.get>( - "particles.species"), - params.template get("output.format") } - , m_pgen { m_params, m_metadomain } - , runtime { params.get("simulation.runtime") } - , dt { params.get("algorithms.timestep.dt") } - , max_steps { static_cast(runtime / dt) } - -#else // not OUTPUT_ENABLED - Engine(SimulationParams& params) + Engine(const SimulationParams& params) : m_params { params } - , m_metadomain { params.get("simulation.domain.number"), - params.get>( + , m_metadomain { m_params.get("simulation.domain.number"), + m_params.get>( "simulation.domain.decomposition"), - params.get>("grid.resolution"), - params.get>("grid.extent"), - params.get>( + m_params.get>("grid.resolution"), + m_params.get>("grid.extent"), + m_params.get>( "grid.boundaries.fields"), - params.get>( + m_params.get>( "grid.boundaries.particles"), - params.get>( + m_params.get>( "grid.metric.params"), - params.get>( + m_params.get>( "particles.species") } , m_pgen { m_params, m_metadomain } - , runtime { params.get("simulation.runtime") 
} - , dt { params.get("algorithms.timestep.dt") } - , max_steps { static_cast(runtime / dt) } -#endif - { - + , is_resuming { m_params.get("checkpoint.is_resuming") } + , runtime { m_params.get("simulation.runtime") } + , dt { m_params.get("algorithms.timestep.dt") } + , max_steps { static_cast(runtime / dt) } + , start_step { m_params.get("checkpoint.start_step") } + , start_time { m_params.get("checkpoint.start_time") } + , time { start_time } + , step { start_step } { raise::ErrorIf(not pgen_is_ok, "Problem generator is not compatible with the picked engine/metric/dimension", HERE); - print_report(); } ~Engine() = default; void init(); void print_report() const; - void print_step_report(timer::Timers&, pbar::DurationHistory&, bool, bool) const; virtual void step_forward(timer::Timers&, Domain&) = 0; diff --git a/src/engines/engine_init.cpp b/src/engines/engine_init.cpp index abb8754d9..7a963b615 100644 --- a/src/engines/engine_init.cpp +++ b/src/engines/engine_init.cpp @@ -11,40 +11,64 @@ #include "metrics/spherical.h" #include "archetypes/field_setter.h" + #include "engines/engine.hpp" #include +#include + namespace ntt { template void Engine::init() { if constexpr (pgen_is_ok) { + m_metadomain.InitStatsWriter(m_params, is_resuming); #if defined(OUTPUT_ENABLED) - m_metadomain.InitWriter(m_params); + m_metadomain.InitWriter(&m_adios, m_params); + m_metadomain.InitCheckpointWriter(&m_adios, m_params); #endif logger::Checkpoint("Initializing Engine", HERE); - if constexpr ( - traits::has_member>::value) { - logger::Checkpoint("Initializing fields from problem generator", HERE); - m_metadomain.runOnLocalDomains([&](auto& loc_dom) { - Kokkos::parallel_for( - "InitFields", - loc_dom.mesh.rangeActiveCells(), - arch::SetEMFields_kernel { - loc_dom.fields.em, - m_pgen.init_flds, - loc_dom.mesh.metric }); - }); - } - if constexpr ( - traits::has_member>::value) { - logger::Checkpoint("Initializing particles from problem generator", HERE); - 
m_metadomain.runOnLocalDomains([&](auto& loc_dom) { - m_pgen.InitPrtls(loc_dom); - }); + if (not is_resuming) { + // start a new simulation with initial conditions + logger::Checkpoint("Loading initial conditions", HERE); + if constexpr ( + traits::has_member>::value) { + logger::Checkpoint("Initializing fields from problem generator", HERE); + m_metadomain.runOnLocalDomains([&](auto& loc_dom) { + Kokkos::parallel_for( + "InitFields", + loc_dom.mesh.rangeActiveCells(), + arch::SetEMFields_kernel { + loc_dom.fields.em, + m_pgen.init_flds, + loc_dom.mesh.metric }); + }); + } + if constexpr ( + traits::has_member>::value) { + logger::Checkpoint("Initializing particles from problem generator", HERE); + m_metadomain.runOnLocalDomains([&](auto& loc_dom) { + m_pgen.InitPrtls(loc_dom); + }); + } + } else { +#if defined(OUTPUT_ENABLED) + // read simulation data from the checkpoint + raise::ErrorIf( + m_params.template get("checkpoint.start_step") == 0, + "Resuming simulation from a checkpoint requires a valid start_step", + HERE); + logger::Checkpoint("Resuming simulation from a checkpoint", HERE); + m_metadomain.ContinueFromCheckpoint(&m_adios, m_params); +#else + raise::Error( + "Resuming simulation from a checkpoint requires -D output=ON", + HERE); +#endif } } + print_report(); } template class Engine>; @@ -55,4 +79,5 @@ namespace ntt { template class Engine>; template class Engine>; template class Engine>; -} // namespace ntt \ No newline at end of file + +} // namespace ntt diff --git a/src/engines/engine_printer.cpp b/src/engines/engine_printer.cpp index 90dec3326..20f5e81ba 100644 --- a/src/engines/engine_printer.cpp +++ b/src/engines/engine_printer.cpp @@ -17,10 +17,11 @@ #if defined(CUDA_ENABLED) #include +#elif defined(HIP_ENABLED) + #include #endif #if defined(OUTPUT_ENABLED) - #include #include #endif @@ -38,7 +39,7 @@ namespace ntt { color::BRIGHT_BLACK, fmt::repeat("═", 58).c_str(), color::RESET); - for (std::size_t i { 0 }; i < lines.size(); ++i) { + for 
(auto i { 0u }; i < lines.size(); ++i) { report += fmt::format("%sβ•‘%s %s%s%s%s%sβ•‘%s\n", color::BRIGHT_BLACK, color::RESET, @@ -105,13 +106,13 @@ namespace ntt { color::RESET); } - auto bytes_to_human_readable(std::size_t bytes) - -> std::pair { + auto bytes_to_human_readable( + std::size_t bytes) -> std::pair { const std::vector units { "B", "KB", "MB", "GB", "TB" }; - std::size_t unit_idx = 0; - auto size = static_cast(bytes); - while ((size >= 1024) && (unit_idx < units.size() - 1)) { - size /= 1024; + idx_t unit_idx = 0; + auto size = static_cast(bytes); + while ((size >= 1024.0) and (unit_idx < units.size() - 1)) { + size /= 1024.0; ++unit_idx; } return { size, units[unit_idx] }; @@ -178,8 +179,16 @@ namespace ntt { const auto minor { cuda_v % 1000 / 10 }; const auto patch { cuda_v % 10 }; const auto cuda_version = fmt::format("%d.%d.%d", major, minor, patch); -#else // not CUDA_ENABLED - const std::string cuda_version = "OFF"; +#elif defined(HIP_ENABLED) + int hip_v; + auto status = hipDriverGetVersion(&hip_v); + raise::ErrorIf(status != hipSuccess, + "hipDriverGetVersion failed with error code %d", + HERE); + const auto major { hip_v / 10000000 }; + const auto minor { (hip_v % 10000000) / 100000 }; + const auto patch { hip_v % 100000 }; + const auto hip_version = fmt::format("%d.%d.%d", major, minor, patch); #endif const auto kokkos_version = fmt::format("%d.%d.%d", @@ -188,18 +197,11 @@ namespace ntt { KOKKOS_VERSION % 100); #if defined(OUTPUT_ENABLED) - unsigned h5_major, h5_minor, h5_release; - H5get_libversion(&h5_major, &h5_minor, &h5_release); - const std::string hdf5_version = fmt::format("%d.%d.%d", - h5_major, - h5_minor, - h5_release); const std::string adios2_version = fmt::format("%d.%d.%d", ADIOS2_VERSION / 10000, ADIOS2_VERSION / 100 % 100, ADIOS2_VERSION % 100); #else // not OUTPUT_ENABLED - const std::string hdf5_version = "OFF"; const std::string adios2_version = "OFF"; #endif @@ -212,17 +214,89 @@ namespace ntt { report += "\n\n"; 
add_header(report, { entity_version }, { color::BRIGHT_GREEN }); report += "\n"; + + /* + * Backend + */ add_category(report, 4, "Backend"); add_param(report, 4, "Build hash", "%s", hash.c_str()); add_param(report, 4, "CXX", "%s [%s]", ccx.c_str(), cpp_standard.c_str()); +#if defined(CUDA_ENABLED) add_param(report, 4, "CUDA", "%s", cuda_version.c_str()); +#elif defined(HIP_VERSION) + add_param(report, 4, "HIP", "%s", hip_version.c_str()); +#endif add_param(report, 4, "MPI", "%s", mpi_version.c_str()); - add_param(report, 4, "HDF5", "%s", hdf5_version.c_str()); +#if defined(MPI_ENABLED) && defined(DEVICE_ENABLED) + #if defined(GPU_AWARE_MPI) + const std::string gpu_aware_mpi = "ON"; + #else + const std::string gpu_aware_mpi = "OFF"; + #endif + add_param(report, 4, "GPU-aware MPI", "%s", gpu_aware_mpi.c_str()); +#endif add_param(report, 4, "Kokkos", "%s", kokkos_version.c_str()); add_param(report, 4, "ADIOS2", "%s", adios2_version.c_str()); add_param(report, 4, "Precision", "%s", precision); add_param(report, 4, "Debug", "%s", dbg.c_str()); report += "\n"; + + /* + * Compilation flags + */ + add_category(report, 4, "Compilation flags"); +#if defined(SINGLE_PRECISION) + add_param(report, 4, "SINGLE_PRECISION", "%s", "ON"); +#else + add_param(report, 4, "SINGLE_PRECISION", "%s", "OFF"); +#endif + +#if defined(OUTPUT_ENABLED) + add_param(report, 4, "OUTPUT_ENABLED", "%s", "ON"); +#else + add_param(report, 4, "OUTPUT_ENABLED", "%s", "OFF"); +#endif + +#if defined(DEBUG) + add_param(report, 4, "DEBUG", "%s", "ON"); +#else + add_param(report, 4, "DEBUG", "%s", "OFF"); +#endif + +#if defined(CUDA_ENABLED) + add_param(report, 4, "CUDA_ENABLED", "%s", "ON"); +#else + add_param(report, 4, "CUDA_ENABLED", "%s", "OFF"); +#endif + +#if defined(HIP_ENABLED) + add_param(report, 4, "HIP_ENABLED", "%s", "ON"); +#else + add_param(report, 4, "HIP_ENABLED", "%s", "OFF"); +#endif + +#if defined(DEVICE_ENABLED) + add_param(report, 4, "DEVICE_ENABLED", "%s", "ON"); +#else + 
add_param(report, 4, "DEVICE_ENABLED", "%s", "OFF"); +#endif + +#if defined(MPI_ENABLED) + add_param(report, 4, "MPI_ENABLED", "%s", "ON"); +#else + add_param(report, 4, "MPI_ENABLED", "%s", "OFF"); +#endif + +#if defined(GPU_AWARE_MPI) + add_param(report, 4, "GPU_AWARE_MPI", "%s", "ON"); +#else + add_param(report, 4, "GPU_AWARE_MPI", "%s", "OFF"); +#endif + report += "\n"; + + /* + * Simulation configs + */ add_category(report, 4, "Configuration"); add_param(report, 4, @@ -233,15 +307,14 @@ namespace ntt { add_param(report, 4, "Engine", "%s", SimEngine(S).to_string()); add_param(report, 4, "Metric", "%s", Metric(M::MetricType).to_string()); add_param(report, 4, "Timestep [dt]", "%.3e", dt); - add_param(report, 4, "Runtime", "%.3Le [%d steps]", runtime, max_steps); + add_param(report, 4, "Runtime", "%.3e [%d steps]", runtime, max_steps); report += "\n"; add_category(report, 4, "Global domain"); - add_param( - report, - 4, - "Resolution", - "%s", - params.template stringize("grid.resolution").c_str()); + add_param(report, + 4, + "Resolution", + "%s", + params.template stringize("grid.resolution").c_str()); add_param(report, 4, "Extent", @@ -344,7 +417,7 @@ namespace ntt { for (unsigned int idx { 0 }; idx < m_metadomain.ndomains(); ++idx) { auto is_local = false; - for (const auto& lidx : m_metadomain.local_subdomain_indices()) { + for (const auto& lidx : m_metadomain.l_subdomain_indices()) { is_local |= (idx == lidx); } if (is_local) { @@ -392,7 +465,7 @@ namespace ntt { add_subcategory(report, 6, "Memory footprint"); auto flds_footprint = domain.fields.memory_footprint(); auto [flds_size, flds_unit] = bytes_to_human_readable(flds_footprint); - add_param(report, 8, "Fields", "%.2Lf %s", flds_size, flds_unit.c_str()); + add_param(report, 8, "Fields", "%.2f %s", flds_size, flds_unit.c_str()); if (domain.species.size() > 0) { add_subcategory(report, 8, "Particles"); } @@ -401,7 +474,7 @@ namespace ntt { species.index(), species.label().c_str()); auto [size, unit] = 
bytes_to_human_readable(species.memory_footprint()); - add_param(report, 10, str.c_str(), "%.2Lf %s", size, unit.c_str()); + add_param(report, 10, str.c_str(), "%.2f %s", size, unit.c_str()); } report.pop_back(); if (idx == m_metadomain.ndomains() - 1) { @@ -415,13 +488,13 @@ namespace ntt { } } - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; + template void Engine>::print_report() const; + template void Engine>::print_report() const; + template void Engine>::print_report() const; + template void Engine>::print_report() const; + template void Engine>::print_report() const; + template void Engine>::print_report() const; + template void Engine>::print_report() const; + template void Engine>::print_report() const; } // namespace ntt diff --git a/src/engines/engine_run.cpp b/src/engines/engine_run.cpp index 4485e0e40..65c87e939 100644 --- a/src/engines/engine_run.cpp +++ b/src/engines/engine_run.cpp @@ -1,6 +1,7 @@ #include "enums.h" #include "arch/traits.h" +#include "utils/diag.h" #include "metrics/kerr_schild.h" #include "metrics/kerr_schild_0.h" @@ -25,19 +26,20 @@ namespace ntt { "CurrentFiltering", "CurrentDeposit", "ParticlePusher", "FieldBoundaries", "ParticleBoundaries", "Communications", - "Injector", "Sorting", - "Custom", "Output" }, + "Injector", "Custom", + "PrtlClear", "Output", + "Checkpoint" }, []() { Kokkos::fence(); }, m_params.get("diagnostics.blocking_timers") }; - const auto diag_interval = m_params.get( + const auto diag_interval = m_params.get( "diagnostics.interval"); - auto time_history = pbar::DurationHistory { 1000 }; - const auto sort_interval = m_params.template get( - "particles.sort_interval"); + auto time_history = pbar::DurationHistory { 1000 }; + const auto clear_interval = m_params.template get( + "particles.clear_interval"); // main algorithm loop while (step < max_steps) { 
@@ -54,50 +56,100 @@ namespace ntt { }); timers.stop("Custom"); } - auto print_sorting = (sort_interval > 0 and step % sort_interval == 0); + auto print_prtl_clear = (clear_interval > 0 and + step % clear_interval == 0 and step > 0); - // advance time & timestep - ++step; + // advance time & step time += dt; + ++step; - auto print_output = false; + auto print_output = false; + auto print_checkpoint = false; #if defined(OUTPUT_ENABLED) timers.start("Output"); if constexpr ( traits::has_method::value) { auto lambda_custom_field_output = [&](const std::string& name, ndfield_t& buff, - std::size_t idx, + index_t idx, + timestep_t step, + simtime_t time, const Domain& dom) { - m_pgen.CustomFieldOutput(name, buff, idx, dom); + m_pgen.CustomFieldOutput(name, buff, idx, step, time, dom); + }; + print_output &= m_metadomain.Write(m_params, + step, + step - 1, + time, + time - dt, + lambda_custom_field_output); + } else { + print_output &= m_metadomain.Write(m_params, step, step - 1, time, time - dt); + } + if constexpr ( + traits::has_method::value) { + auto lambda_custom_stat = [&](const std::string& name, + timestep_t step, + simtime_t time, + const Domain& dom) -> real_t { + return m_pgen.CustomStat(name, step, time, dom); }; - print_output = m_metadomain.Write(m_params, - step, - time, - lambda_custom_field_output); + print_output &= m_metadomain.WriteStats(m_params, + step, + step - 1, + time, + time - dt, + lambda_custom_stat); } else { - print_output = m_metadomain.Write(m_params, step, time); + print_output &= m_metadomain.WriteStats(m_params, + step, + step - 1, + time, + time - dt); } timers.stop("Output"); + + timers.start("Checkpoint"); + print_checkpoint = m_metadomain.WriteCheckpoint(m_params, + step, + step - 1, + time, + time - dt); + timers.stop("Checkpoint"); #endif // advance time_history time_history.tick(); - // print final timestep report + // print timestep report if (diag_interval > 0 and step % diag_interval == 0) { - print_step_report(timers, 
time_history, print_output, print_sorting); + diag::printDiagnostics( + step - 1, + max_steps, + time - dt, + dt, + timers, + time_history, + m_metadomain.l_ncells(), + m_metadomain.species_labels(), + m_metadomain.l_npart_perspec(), + m_metadomain.l_maxnpart_perspec(), + print_prtl_clear, + print_output, + print_checkpoint, + m_params.get("diagnostics.colored_stdout")); } timers.resetAll(); } } } - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; + template void Engine>::run(); + template void Engine>::run(); + template void Engine>::run(); + template void Engine>::run(); + template void Engine>::run(); + template void Engine>::run(); + template void Engine>::run(); + template void Engine>::run(); + } // namespace ntt diff --git a/src/engines/engine_step_report.cpp b/src/engines/engine_step_report.cpp deleted file mode 100644 index f2a35bb82..000000000 --- a/src/engines/engine_step_report.cpp +++ /dev/null @@ -1,293 +0,0 @@ -#include "enums.h" -#include "global.h" - -#include "arch/mpi_aliases.h" -#include "utils/colors.h" -#include "utils/formatting.h" -#include "utils/progressbar.h" -#include "utils/timer.h" - -#include "metrics/kerr_schild.h" -#include "metrics/kerr_schild_0.h" -#include "metrics/minkowski.h" -#include "metrics/qkerr_schild.h" -#include "metrics/qspherical.h" -#include "metrics/spherical.h" - -#include "engines/engine.hpp" - -#include -#include - -namespace ntt { - namespace {} // namespace - - template - void print_particles(const Metadomain&, - unsigned short, - DiagFlags, - std::ostream& = std::cout); - - template - void Engine::print_step_report(timer::Timers& timers, - pbar::DurationHistory& time_history, - bool print_output, - bool print_sorting) const { - DiagFlags diag_flags = Diag::Default; - TimerFlags timer_flags = Timer::Default; - if (not 
m_params.get("diagnostics.colored_stdout")) { - diag_flags ^= Diag::Colorful; - timer_flags ^= Timer::Colorful; - } - if (m_params.get("particles.nspec") == 0) { - diag_flags ^= Diag::Species; - } - if (print_output) { - timer_flags |= Timer::PrintOutput; - } - if (print_sorting) { - timer_flags |= Timer::PrintSorting; - } - CallOnce( - [diag_flags](auto& time, auto& step, auto& max_steps, auto& dt) { - const auto c_bgreen = color::get_color("bgreen", - diag_flags & Diag::Colorful); - const auto c_bblack = color::get_color("bblack", - diag_flags & Diag::Colorful); - const auto c_reset = color::get_color("reset", diag_flags & Diag::Colorful); - std::cout << fmt::format("Step:%s %-8d%s %s[of %d]%s\n", - c_bgreen.c_str(), - step, - c_reset.c_str(), - c_bblack.c_str(), - max_steps, - c_reset.c_str()); - std::cout << fmt::format("Time:%s %-8.4f%s %s[Ξ”t = %.4f]%s\n", - c_bgreen.c_str(), - (double)time, - c_reset.c_str(), - c_bblack.c_str(), - (double)dt, - c_reset.c_str()) - << std::endl; - }, - time, - step, - max_steps, - dt); - if (diag_flags & Diag::Timers) { - timers.printAll(timer_flags, std::cout); - } - CallOnce([]() { - std::cout << std::endl; - }); - if (diag_flags & Diag::Species) { - CallOnce([diag_flags]() { - std::cout << color::get_color("bblack", diag_flags & Diag::Colorful); -#if defined(MPI_ENABLED) - std::cout << "Particle count:" << std::setw(22) << std::right << "[TOT]" - << std::setw(20) << std::right << "[MIN (%)]" << std::setw(20) - << std::right << "[MAX (%)]"; -#else - std::cout << "Particle count:" << std::setw(25) << std::right - << "[TOT (%)]"; -#endif - std::cout << color::get_color("reset", diag_flags & Diag::Colorful) - << std::endl; - }); - for (std::size_t sp { 0 }; sp < m_metadomain.species_params().size(); ++sp) { - print_particles(m_metadomain, sp, diag_flags, std::cout); - } - CallOnce([]() { - std::cout << std::endl; - }); - } - if (diag_flags & Diag::Progress) { - pbar::ProgressBar(time_history, step, max_steps, diag_flags, 
std::cout); - } - CallOnce([]() { - std::cout << std::setw(80) << std::setfill('.') << "" << std::endl - << std::endl; - }); - } - - template - void print_particles(const Metadomain& md, - unsigned short sp, - DiagFlags flags, - std::ostream& os) { - - static_assert(M::is_metric, "template arg for Engine class has to be a metric"); - std::size_t npart { 0 }; - std::size_t maxnpart { 0 }; - std::string species_label; - int species_index; - // sum npart & maxnpart over all subdomains on the current rank - md.runOnLocalDomainsConst( - [&npart, &maxnpart, &species_label, &species_index, sp](auto& dom) { - npart += dom.species[sp].npart(); - maxnpart += dom.species[sp].maxnpart(); - species_label = dom.species[sp].label(); - species_index = dom.species[sp].index(); - }); -#if defined(MPI_ENABLED) - int rank, size; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &size); - std::vector mpi_npart(size, 0); - std::vector mpi_maxnpart(size, 0); - MPI_Gather(&npart, - 1, - mpi::get_type(), - mpi_npart.data(), - 1, - mpi::get_type(), - MPI_ROOT_RANK, - MPI_COMM_WORLD); - MPI_Gather(&maxnpart, - 1, - mpi::get_type(), - mpi_maxnpart.data(), - 1, - mpi::get_type(), - MPI_ROOT_RANK, - MPI_COMM_WORLD); - if (rank != MPI_ROOT_RANK) { - return; - } - auto tot_npart = std::accumulate(mpi_npart.begin(), mpi_npart.end(), 0); - std::size_t npart_max = *std::max_element(mpi_npart.begin(), mpi_npart.end()); - std::size_t npart_min = *std::min_element(mpi_npart.begin(), mpi_npart.end()); - std::vector mpi_load(size, 0.0); - for (auto r { 0 }; r < size; ++r) { - mpi_load[r] = 100.0 * (double)(mpi_npart[r]) / (double)(mpi_maxnpart[r]); - } - double load_max = *std::max_element(mpi_load.begin(), mpi_load.end()); - double load_min = *std::min_element(mpi_load.begin(), mpi_load.end()); - auto npart_min_str = npart_min > 9999 - ? fmt::format("%.2Le", (long double)npart_min) - : std::to_string(npart_min); - auto tot_npart_str = tot_npart > 9999 - ? 
fmt::format("%.2Le", (long double)tot_npart) - : std::to_string(tot_npart); - auto npart_max_str = npart_max > 9999 - ? fmt::format("%.2Le", (long double)npart_max) - : std::to_string(npart_max); - os << " species " << fmt::format("%2d", species_index) << " (" - << species_label << ")"; - - const auto c_bblack = color::get_color("bblack", flags & Diag::Colorful); - const auto c_red = color::get_color("red", flags & Diag::Colorful); - const auto c_yellow = color::get_color("yellow", flags & Diag::Colorful); - const auto c_green = color::get_color("green", flags & Diag::Colorful); - const auto c_reset = color::get_color("reset", flags & Diag::Colorful); - auto c_loadmin = (load_min > 80) ? c_red - : ((load_min > 50) ? c_yellow : c_green); - auto c_loadmax = (load_max > 80) ? c_red - : ((load_max > 50) ? c_yellow : c_green); - const auto raw1 = fmt::format("%s (%4.1f%%)", npart_min_str.c_str(), load_min); - const auto raw2 = fmt::format("%s (%4.1f%%)", npart_max_str.c_str(), load_max); - os << c_bblack - << fmt::pad(tot_npart_str, 20, '.', false).substr(0, 20 - tot_npart_str.size()) - << c_reset << tot_npart_str; - os << fmt::pad(raw1, 20, ' ', false).substr(0, 20 - raw1.size()) - << fmt::format("%s (%s%4.1f%%%s)", - npart_min_str.c_str(), - c_loadmin.c_str(), - load_min, - c_reset.c_str()); - os << fmt::pad(raw2, 20, ' ', false).substr(0, 20 - raw2.size()) - << fmt::format("%s (%s%4.1f%%%s)", - npart_max_str.c_str(), - c_loadmax.c_str(), - load_max, - c_reset.c_str()); -#else // not MPI_ENABLED - auto load = 100.0 * (double)(npart) / (double)(maxnpart); - auto npart_str = npart > 9999 ? 
fmt::format("%.2Le", (long double)npart) - : std::to_string(npart); - const auto c_bblack = color::get_color("bblack", flags & Diag::Colorful); - const auto c_red = color::get_color("red", flags & Diag::Colorful); - const auto c_yellow = color::get_color("yellow", flags & Diag::Colorful); - const auto c_green = color::get_color("green", flags & Diag::Colorful); - const auto c_reset = color::get_color("reset", flags & Diag::Colorful); - const auto c_load = (load > 80) - ? c_red.c_str() - : ((load > 50) ? c_yellow.c_str() : c_green.c_str()); - os << " species " << species_index << " (" << species_label << ")"; - const auto raw = fmt::format("%s (%4.1f%%)", npart_str.c_str(), load); - os << c_bblack << fmt::pad(raw, 24, '.').substr(0, 24 - raw.size()) << c_reset; - os << fmt::format("%s (%s%4.1f%%%s)", - npart_str.c_str(), - c_load, - load, - c_reset.c_str()); -#endif - os << std::endl; - } - - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; - template class Engine>; -} // namespace ntt - -// template -// auto Simulation::PrintDiagnostics(const std::size_t& step, -// const real_t& time, -// const timer::Timers& timers, -// std::vector& tstep_durations, -// const DiagFlags diag_flags, -// std::ostream& os) -> void { -// if (tstep_durations.size() > m_params.diagMaxnForPbar()) { -// tstep_durations.erase(tstep_durations.begin()); -// } -// tstep_durations.push_back(timers.get("Total")); -// if (step % m_params.diagInterval() == 0) { -// auto& mblock = this->meshblock; -// const auto title { -// fmt::format("Time = %f : step = %d : Ξ”t = %f", time, step, mblock.timestep()) -// }; -// PrintOnce( -// [](std::ostream& os, std::string title) { -// os << title << std::endl; -// }, -// os, -// title); -// if (diag_flags & DiagFlags_Timers) { -// timers.printAll("", timer::TimerFlags_Default, os); -// } -// if (diag_flags & DiagFlags_Species) { -// 
auto header = fmt::format("%s %27s", "[SPECIES]", "[TOT]"); -// #if defined(MPI_ENABLED) -// header += fmt::format("%17s %s", "[MIN (%) :", "MAX (%)]"); -// #endif -// PrintOnce( -// [](std::ostream& os, std::string header) { -// os << header << std::endl; -// }, -// os, -// header); -// for (const auto& species : meshblock.particles) { -// species.PrintParticleCounts(os); -// } -// } -// if (diag_flags & DiagFlags_Progress) { -// PrintOnce( -// [](std::ostream& os) { -// os << std::setw(65) << std::setfill('-') << "" << std::endl; -// }, -// os); -// ProgressBar(tstep_durations, time, m_params.totalRuntime(), os); -// } -// PrintOnce( -// [](std::ostream& os) { -// os << std::setw(65) << std::setfill('=') << "" << std::endl; -// }, -// os); -// } -// } diff --git a/src/engines/grpic.hpp b/src/engines/grpic.hpp index 148c1c5c5..35e544a99 100644 --- a/src/engines/grpic.hpp +++ b/src/engines/grpic.hpp @@ -4,7 +4,7 @@ * @implements * - ntt::GRPICEngine<> : ntt::Engine<> * @cpp: - * - srpic.cpp + * - grpic.cpp * @namespaces: * - ntt:: */ @@ -13,31 +13,1133 @@ #define ENGINES_GRPIC_GRPIC_H #include "enums.h" +#include "global.h" +#include "arch/kokkos_aliases.h" +#include "utils/log.h" +#include "utils/numeric.h" #include "utils/timer.h" +#include "utils/toml.h" #include "framework/domain/domain.h" +#include "framework/parameters.h" #include "engines/engine.hpp" +#include "kernels/ampere_gr.hpp" +#include "kernels/aux_fields_gr.hpp" +#include "kernels/currents_deposit.hpp" +#include "kernels/digital_filter.hpp" +#include "kernels/faraday_gr.hpp" +#include "kernels/fields_bcs.hpp" +#include "kernels/particle_pusher_gr.hpp" +#include "pgen.hpp" + +#include +#include + +#include +#include namespace ntt { + enum class gr_getE { + D0_B, + D_B0 + }; + enum class gr_getH { + D_B0, + D0_B0 + }; + enum class gr_faraday { + aux, + main + }; + enum class gr_ampere { + init, + aux, + main + }; + enum class gr_bc { + main, + aux, + curr + }; + template class GRPICEngine : public 
Engine { - using Engine::m_params; - using Engine::m_metadomain; + using base_t = Engine; + using pgen_t = user::PGen; + using domain_t = Domain; + // constexprs + using base_t::pgen_is_ok; + // contents + using base_t::m_metadomain; + using base_t::m_params; + using base_t::m_pgen; + // methods + using base_t::init; + // variables + using base_t::dt; + using base_t::max_steps; + using base_t::runtime; + using base_t::step; + using base_t::time; public: - static constexpr auto S { SimEngine::SRPIC }; + static constexpr auto S { SimEngine::GRPIC }; - GRPICEngine(SimulationParams& params) - : Engine { params } {} + GRPICEngine(SimulationParams& params) : base_t { params } {} ~GRPICEngine() = default; - void step_forward(timer::Timers&, Domain&) override {} - }; + void step_forward(timer::Timers& timers, domain_t& dom) override { + const auto fieldsolver_enabled = m_params.template get( + "algorithms.toggles.fieldsolver"); + const auto deposit_enabled = m_params.template get( + "algorithms.toggles.deposit"); + const auto clear_interval = m_params.template get( + "particles.clear_interval"); + + if (step == 0) { + if (fieldsolver_enabled) { + // communicate fields and apply BCs on the first timestep + /** + * Initially: em0::B -- + * em0::D -- + * em::B at -1/2 + * em::D at -1/2 + * + * cur0::J -- + * cur::J -- + * + * aux::E -- + * aux::H -- + * + * x_prtl at -1/2 + * u_prtl at -1/2 + */ + + /** + * em0::D, em::D, em0::B, em::B <- boundary conditions + */ + m_metadomain.CommunicateFields(dom, + Comm::B | Comm::B0 | Comm::D | Comm::D0); + FieldBoundaries(dom, BC::B | BC::D, gr_bc::main); + + /** + * em0::B <- em::B + * em0::D <- em::D + * + * Now: em0::B & em0::D at -1/2 + */ + CopyFields(dom); + + /** + * aux::E <- alpha * em::D + beta x em0::B + * aux::H <- alpha * em::B0 - beta x em::D + * + * Now: aux::E & aux::H at -1/2 + */ + ComputeAuxE(dom, gr_getE::D_B0); + ComputeAuxH(dom, gr_getH::D_B0); + + /** + * aux::E, aux::H <- boundary conditions + */ + 
m_metadomain.CommunicateFields(dom, Comm::H | Comm::E); + FieldBoundaries(dom, BC::H | BC::E, gr_bc::aux); + + /** + * em0::B <- (em0::B) <- -curl aux::E + * + * Now: em0::B at 0 + */ + Faraday(dom, gr_faraday::aux, HALF); + + /** + * em0::B, em::B <- boundary conditions + */ + m_metadomain.CommunicateFields(dom, Comm::B | Comm::B0); + FieldBoundaries(dom, BC::B, gr_bc::main); + + /** + * em::D <- (em0::D) <- curl aux::H + * + * Now: em::D at 0 + */ + Ampere(dom, gr_ampere::init, HALF); + + /** + * em0::D, em::D <- boundary conditions + */ + m_metadomain.CommunicateFields(dom, Comm::D | Comm::D0); + FieldBoundaries(dom, BC::D, gr_bc::main); + + /** + * aux::E <- alpha * em::D + beta x em0::B + * aux::H <- alpha * em0::B - beta x em::D + * + * Now: aux::E & aux::H at 0 + */ + ComputeAuxE(dom, gr_getE::D_B0); + ComputeAuxH(dom, gr_getH::D_B0); + + /** + * aux::E, aux::H <- boundary conditions + */ + m_metadomain.CommunicateFields(dom, Comm::H | Comm::E); + FieldBoundaries(dom, BC::H | BC::E, gr_bc::aux); + + // !ADD: GR -- particles? 
+ + /** + * em0::B <- (em::B) <- -curl aux::E + * + * Now: em0::B at 1/2 + */ + Faraday(dom, gr_faraday::main, ONE); + /** + * em0::B, em::B <- boundary conditions + */ + m_metadomain.CommunicateFields(dom, Comm::B | Comm::B0); + FieldBoundaries(dom, BC::B, gr_bc::main); + + /** + * em0::D <- (em0::D) <- curl aux::H + * + * Now: em0::D at 1/2 + */ + Ampere(dom, gr_ampere::aux, ONE); + /** + * em0::D, em::D <- boundary conditions + */ + m_metadomain.CommunicateFields(dom, Comm::D | Comm::D0); + FieldBoundaries(dom, BC::D, gr_bc::main); + + /** + * aux::H <- alpha * em0::B - beta x em0::D + * + * Now: aux::H at 1/2 + */ + ComputeAuxH(dom, gr_getH::D0_B0); + /** + * aux::H <- boundary conditions + */ + m_metadomain.CommunicateFields(dom, Comm::H); + FieldBoundaries(dom, BC::H, gr_bc::aux); + + /** + * em0::D <- (em::D) <- curl aux::H + * + * Now: em0::D at 1 + * em::D at 0 + */ + Ampere(dom, gr_ampere::main, ONE); + /** + * em0::D, em::D <- boundary conditions + */ + m_metadomain.CommunicateFields(dom, Comm::D | Comm::D0); + FieldBoundaries(dom, BC::D, gr_bc::main); + + /** + * em::D <-> em0::D + * em::B <-> em0::B + * em::J <-> em0::J + */ + SwapFields(dom); + /** + * Finally: em0::B at -1/2 + * em0::D at 0 + * em::B at 1/2 + * em::D at 1 + * + * cur0::J -- + * cur::J -- + * + * aux::E -- + * aux::H -- + * + * x_prtl at 1 + * u_prtl at 1/2 + */ + } else { + /** + * em0::B <- em::B + * em0::D <- em::D + * + * Now: em0::B & em0::D at -1/2 + */ + CopyFields(dom); + } + } + + /** + * Initially: em0::B at n-3/2 + * em0::D at n-1 + * em::B at n-1/2 + * em::D at n + * + * cur0::J -- + * cur::J at n-1/2 + * + * aux::E -- + * aux::H -- + * + * x_prtl at n + * u_prtl at n-1/2 + */ + + if (fieldsolver_enabled) { + timers.start("FieldSolver"); + /** + * em0::D <- (em0::D + em::D) / 2 + * em0::B <- (em0::B + em::B) / 2 + * + * Now: em0::D at n-1/2 + * em0::B at n-1 + */ + TimeAverageDB(dom); + /** + * aux::E <- alpha * em0::D + beta x em::B + * + * Now: aux::E at n-1/2 + */ + 
ComputeAuxE(dom, gr_getE::D0_B); + timers.stop("FieldSolver"); + + timers.start("Communications"); + m_metadomain.CommunicateFields(dom, Comm::E); + timers.stop("Communications"); + timers.start("FieldBoundaries"); + /** + * aux::E <- boundary conditions + */ + FieldBoundaries(dom, BC::E, gr_bc::aux); + timers.stop("FieldBoundaries"); + + timers.start("FieldSolver"); + /** + * em0::B <- (em0::B) <- -curl aux::E + * + * Now: em0::B at n + */ + Faraday(dom, gr_faraday::aux, ONE); + timers.stop("FieldSolver"); + + timers.start("Communications"); + m_metadomain.CommunicateFields(dom, Comm::B | Comm::B0); + timers.stop("Communications"); + /** + * em0::B, em::B <- boundary conditions + */ + timers.start("FieldBoundaries"); + FieldBoundaries(dom, BC::B, gr_bc::main); + timers.stop("FieldBoundaries"); + + timers.start("FieldSolver"); + /** + * aux::H <- alpha * em0::B - beta x em::D + * + * Now: aux::H at n + */ + ComputeAuxH(dom, gr_getH::D_B0); + timers.stop("FieldSolver"); + + timers.start("Communications"); + m_metadomain.CommunicateFields(dom, Comm::H); + timers.stop("Communications"); + timers.start("FieldBoundaries"); + /** + * aux::H <- boundary conditions + */ + FieldBoundaries(dom, BC::H, gr_bc::aux); + timers.stop("FieldBoundaries"); + } + + { + /** + * x_prtl, u_prtl <- em::D, em0::B + * + * Now: x_prtl at n + 1, u_prtl at n + 1/2 + */ + timers.start("ParticlePusher"); + ParticlePush(dom); + timers.stop("ParticlePusher"); + + /** + * cur0::J <- current deposition + * + * Now: cur0::J at n+1/2 + */ + if (deposit_enabled) { + timers.start("CurrentDeposit"); + Kokkos::deep_copy(dom.fields.cur0, ZERO); + CurrentsDeposit(dom); + timers.stop("CurrentDeposit"); + + timers.start("Communications"); + m_metadomain.SynchronizeFields(dom, Comm::J); + m_metadomain.CommunicateFields(dom, Comm::J); + timers.stop("Communications"); + + timers.start("FieldBoundaries"); + FieldBoundaries(dom, BC::J, gr_bc::curr); + timers.stop("FieldBoundaries"); + + 
timers.start("CurrentFiltering"); + CurrentsFilter(dom); + timers.stop("CurrentFiltering"); + } + + timers.start("Communications"); + m_metadomain.CommunicateParticles(dom); + timers.stop("Communications"); + } + + if (fieldsolver_enabled) { + timers.start("FieldSolver"); + if (deposit_enabled) { + /** + * cur::J <- (cur0::J + cur::J) / 2 + * + * Now: cur::J at n + */ + TimeAverageJ(dom); + } + /** + * aux::E <- alpha * em::D + beta x em0::B + * + * Now: aux::E at n + */ + ComputeAuxE(dom, gr_getE::D_B0); + timers.stop("FieldSolver"); + timers.start("Communications"); + m_metadomain.CommunicateFields(dom, Comm::E); + timers.stop("Communications"); + timers.start("FieldBoundaries"); + /** + * aux::E <- boundary conditions + */ + FieldBoundaries(dom, BC::E, gr_bc::aux); + timers.stop("FieldBoundaries"); + + timers.start("FieldSolver"); + /** + * em0::B <- (em::B) <- -curl aux::E + * + * Now: em0::B at n+1/2 + * em::B at n-1/2 + */ + Faraday(dom, gr_faraday::main, ONE); + timers.stop("FieldSolver"); + + /** + * em0::B, em::B <- boundary conditions + */ + timers.start("Communications"); + m_metadomain.CommunicateFields(dom, Comm::B | Comm::B0); + timers.stop("Communications"); + timers.start("FieldBoundaries"); + FieldBoundaries(dom, BC::B, gr_bc::main); + timers.stop("FieldBoundaries"); + + timers.start("FieldSolver"); + /** + * em0::D <- (em0::D) <- curl aux::H + * + * Now: em0::D at n+1/2 + */ + Ampere(dom, gr_ampere::aux, ONE); + timers.stop("FieldSolver"); + + if (deposit_enabled) { + timers.start("FieldSolver"); + /** + * em0::D <- (em0::D) <- cur::J + * + * Now: em0::D at n+1/2 + */ + AmpereCurrents(dom, gr_ampere::aux); + timers.stop("FieldSolver"); + } + + /** + * em0::D, em::D <- boundary conditions + */ + timers.start("Communications"); + m_metadomain.CommunicateFields(dom, Comm::D | Comm::D0); + timers.stop("Communications"); + timers.start("FieldBoundaries"); + FieldBoundaries(dom, BC::D, gr_bc::main); + timers.stop("FieldBoundaries"); + 
timers.start("FieldSolver"); + /** + * aux::H <- alpha * em0::B - beta x em0::D + * + * Now: aux::H at n+1/2 + */ + ComputeAuxH(dom, gr_getH::D0_B0); + timers.stop("FieldSolver"); + + timers.start("Communications"); + m_metadomain.CommunicateFields(dom, Comm::H); + timers.stop("Communications"); + timers.start("FieldBoundaries"); + /** + * aux::H <- boundary conditions + */ + FieldBoundaries(dom, BC::H, gr_bc::aux); + timers.stop("FieldBoundaries"); + + timers.start("FieldSolver"); + /** + * em0::D <- (em::D) <- curl aux::H + * + * Now: em0::D at n+1 + * em::D at n + */ + Ampere(dom, gr_ampere::main, ONE); + timers.stop("FieldSolver"); + + if (deposit_enabled) { + timers.start("FieldSolver"); + /** + * em0::D <- (em0::D) <- cur0::J + * + * Now: em0::D at n+1 + */ + AmpereCurrents(dom, gr_ampere::main); + timers.stop("FieldSolver"); + } + timers.start("FieldSolver"); + /** + * em::D <-> em0::D + * em::B <-> em0::B + * cur::J <-> cur0::J + */ + SwapFields(dom); + timers.stop("FieldSolver"); + + /** + * em0::D, em::D <- boundary conditions + */ + timers.start("Communications"); + m_metadomain.CommunicateFields(dom, Comm::D | Comm::D0); + timers.stop("Communications"); + timers.start("FieldBoundaries"); + FieldBoundaries(dom, BC::D, gr_bc::main); + timers.stop("FieldBoundaries"); + } + + if (clear_interval > 0 and step % clear_interval == 0 and step > 0) { + timers.start("PrtlClear"); + m_metadomain.RemoveDeadParticles(dom); + timers.stop("PrtlClear"); + } + + /** + * Finally: em0::B at n-1/2 + * em0::D at n + * em::B at n+1/2 + * em::D at n+1 + * + * cur0::J (at n) + * cur::J at n+1/2 + * + * aux::E (at n+1/2) + * aux::H (at n) + * + * x_prtl at n+1 + * u_prtl at n+1/2 + */ + } + + /* algorithm substeps --------------------------------------------------- */ + void FieldBoundaries(domain_t& domain, BCTags tags, const gr_bc& g) { + if (g == gr_bc::main) { + for (auto& direction : dir::Directions::orth) { + if (m_metadomain.mesh().flds_bc_in(direction) == FldsBC::MATCH) 
{ + MatchFieldsIn(direction, domain, tags, g); + } else if (domain.mesh.flds_bc_in(direction) == FldsBC::AXIS) { + AxisFieldsIn(direction, domain, tags); + } else if (m_metadomain.mesh().flds_bc_in(direction) == FldsBC::CUSTOM) { + CustomFieldsIn(direction, domain, tags, g); + } else if (domain.mesh.flds_bc_in(direction) == FldsBC::HORIZON) { + HorizonFieldsIn(direction, domain, tags, g); + } + } // loop over directions + } else if (g == gr_bc::aux) { + for (auto& direction : dir::Directions::orth) { + if (domain.mesh.flds_bc_in(direction) == FldsBC::HORIZON) { + HorizonFieldsIn(direction, domain, tags, g); + } + } + } else if (g == gr_bc::curr) { + for (auto& direction : dir::Directions::orth) { + if (domain.mesh.prtl_bc_in(direction) == PrtlBC::ABSORB) { + MatchFieldsIn(direction, domain, tags, g); + } + } + } + } + + void MatchFieldsIn(dir::direction_t direction, + domain_t& domain, + BCTags tags, + const gr_bc& g) { + /** + * match boundaries + */ + const auto ds_array = m_params.template get>( + "grid.boundaries.match.ds"); + const auto dim = direction.get_dim(); + real_t xg_min, xg_max, xg_edge; + auto sign = direction.get_sign(); + + raise::ErrorIf(((dim != in::x1) or (sign < 0)) and (g == gr_bc::curr), + "Absorption of currents only possible in +x1 (+r)", + HERE); + + real_t ds; + if (sign > 0) { // + direction + ds = ds_array[(short)dim].second; + xg_max = m_metadomain.mesh().extent(dim).second; + xg_min = xg_max - ds; + xg_edge = xg_max; + } else { // - direction + ds = ds_array[(short)dim].first; + xg_min = m_metadomain.mesh().extent(dim).first; + xg_max = xg_min + ds; + xg_edge = xg_min; + } + boundaries_t box; + boundaries_t incl_ghosts; + for (unsigned short d { 0 }; d < M::Dim; ++d) { + if (d == static_cast(dim)) { + box.push_back({ xg_min, xg_max }); + incl_ghosts.push_back({ false, true }); + } else { + box.push_back(Range::All); + incl_ghosts.push_back({ true, true }); + } + } + if (not domain.mesh.Intersects(box)) { + return; + } + const auto 
intersect_range = domain.mesh.ExtentToRange(box, incl_ghosts); + tuple_t range_min { 0 }; + tuple_t range_max { 0 }; + + for (unsigned short d { 0 }; d < M::Dim; ++d) { + range_min[d] = intersect_range[d].first; + range_max[d] = intersect_range[d].second; + } + if (dim == in::x1) { + if (g != gr_bc::curr) { + Kokkos::parallel_for( + "MatchBoundaries", + CreateRangePolicy(range_min, range_max), + kernel::bc::MatchBoundaries_kernel( + domain.fields.em, + m_pgen.init_flds, + domain.mesh.metric, + xg_edge, + ds, + tags, + domain.mesh.flds_bc())); + Kokkos::parallel_for( + "MatchBoundaries", + CreateRangePolicy(range_min, range_max), + kernel::bc::MatchBoundaries_kernel( + domain.fields.em0, + m_pgen.init_flds, + domain.mesh.metric, + xg_edge, + ds, + tags, + domain.mesh.flds_bc())); + } else { + Kokkos::parallel_for( + "AbsorbCurrents", + CreateRangePolicy(range_min, range_max), + kernel::bc::gr::AbsorbCurrents_kernel(domain.fields.cur0, + domain.mesh.metric, + xg_edge, + ds)); + } + } else { + raise::Error("Invalid dimension", HERE); + } + } + + void HorizonFieldsIn(dir::direction_t direction, + domain_t& domain, + BCTags tags, + const gr_bc& g) { + /** + * open boundaries + */ + raise::ErrorIf(M::CoordType == Coord::Cart, + "Invalid coordinate type for horizon BCs", + HERE); + raise::ErrorIf(direction.get_dim() != in::x1, + "Invalid horizon direction, should be x1", + HERE); + const auto i1_min = domain.mesh.i_min(in::x1); + auto range = CreateRangePolicy({ domain.mesh.i_min(in::x2) }, + { domain.mesh.i_max(in::x2) + 1 }); + const auto nfilter = m_params.template get( + "algorithms.current_filters"); + if (g == gr_bc::main) { + Kokkos::parallel_for( + "OpenBCFields", + range, + kernel::bc::gr::HorizonBoundaries_kernel(domain.fields.em, + i1_min, + tags, + nfilter)); + Kokkos::parallel_for( + "OpenBCFields", + range, + kernel::bc::gr::HorizonBoundaries_kernel(domain.fields.em0, + i1_min, + tags, + nfilter)); + } + } + + void AxisFieldsIn(dir::direction_t direction, + 
domain_t& domain, + BCTags tags) { + /** + * axis boundaries + */ + raise::ErrorIf(M::CoordType == Coord::Cart, + "Invalid coordinate type for axis BCs", + HERE); + raise::ErrorIf(direction.get_dim() != in::x2, + "Invalid axis direction, should be x2", + HERE); + const auto i2_min = domain.mesh.i_min(in::x2); + const auto i2_max = domain.mesh.i_max(in::x2); + if (direction.get_sign() < 0) { + Kokkos::parallel_for( + "AxisBCFields", + domain.mesh.n_all(in::x1), + kernel::bc::AxisBoundaries_kernel(domain.fields.em, + i2_min, + tags)); + Kokkos::parallel_for( + "AxisBCFields", + domain.mesh.n_all(in::x1), + kernel::bc::AxisBoundaries_kernel(domain.fields.em0, + i2_min, + tags)); + } else { + Kokkos::parallel_for( + "AxisBCFields", + domain.mesh.n_all(in::x1), + kernel::bc::AxisBoundaries_kernel(domain.fields.em, + i2_max, + tags)); + Kokkos::parallel_for( + "AxisBCFields", + domain.mesh.n_all(in::x1), + kernel::bc::AxisBoundaries_kernel(domain.fields.em0, + i2_max, + tags)); + } + } + + void CustomFieldsIn(dir::direction_t direction, + domain_t& domain, + BCTags tags, + const gr_bc& g) { + (void)direction; + (void)domain; + (void)tags; + (void)g; + raise::Error("Custom boundaries not implemented", HERE); + // if constexpr ( + // traits::has_member::value) { + // const auto [box, custom_fields] = m_pgen.CustomFields(time); + // if (domain.mesh.Intersects(box)) { + // } + // + // } else { + // raise::Error("Custom boundaries not implemented", HERE); + // } + } + + /** + * @brief Swaps em and em0 fields, cur and cur0 currents. 
+ */ + void SwapFields(domain_t& domain) { + std::swap(domain.fields.em, domain.fields.em0); + std::swap(domain.fields.cur, domain.fields.cur0); + } + + /** + * @brief Copies em fields into em0 + */ + void CopyFields(domain_t& domain) { + Kokkos::deep_copy(domain.fields.em0, domain.fields.em); + } + + void ComputeAuxE(domain_t& domain, const gr_getE& g) { + auto range = range_with_axis_BCs(domain); + if (g == gr_getE::D0_B) { + Kokkos::parallel_for( + "ComputeAuxE", + range, + kernel::gr::ComputeAuxE_kernel(domain.fields.em0, // D + domain.fields.em, // B + domain.fields.aux, // E + domain.mesh.metric)); + } else if (g == gr_getE::D_B0) { + Kokkos::parallel_for("ComputeAuxE", + range, + kernel::gr::ComputeAuxE_kernel(domain.fields.em, + domain.fields.em0, + domain.fields.aux, + domain.mesh.metric)); + } else { + raise::Error("Wrong option for `g`", HERE); + } + } + + void ComputeAuxH(domain_t& domain, const gr_getH& g) { + auto range = range_with_axis_BCs(domain); + if (g == gr_getH::D_B0) { + Kokkos::parallel_for( + "ComputeAuxH", + range, + kernel::gr::ComputeAuxH_kernel(domain.fields.em, // D + domain.fields.em0, // B + domain.fields.aux, // H + domain.mesh.metric)); + } else if (g == gr_getH::D0_B0) { + Kokkos::parallel_for("ComputeAuxH", + range, + kernel::gr::ComputeAuxH_kernel(domain.fields.em0, + domain.fields.em0, + domain.fields.aux, + domain.mesh.metric)); + } else { + raise::Error("Wrong option for `g`", HERE); + } + } + + auto range_with_axis_BCs(const domain_t& domain) -> range_t { + auto range = domain.mesh.rangeActiveCells(); + /** + * @brief taking one extra cell in the x1 and x2 directions if AXIS BCs + */ + if constexpr (M::Dim == Dim::_2D) { + if (domain.mesh.flds_bc_in({ 0, +1 }) == FldsBC::AXIS) { + range = CreateRangePolicy( + { domain.mesh.i_min(in::x1) - 1, domain.mesh.i_min(in::x2) }, + { domain.mesh.i_max(in::x1), domain.mesh.i_max(in::x2) + 1 }); + } else { + range = CreateRangePolicy( + { domain.mesh.i_min(in::x1) - 1, 
domain.mesh.i_min(in::x2) }, + { domain.mesh.i_max(in::x1), domain.mesh.i_max(in::x2) }); + } + } else if constexpr (M::Dim == Dim::_3D) { + raise::Error("Invalid dimension", HERE); + } + return range; + } + + void Faraday(domain_t& domain, const gr_faraday& g, real_t fraction = ONE) { + logger::Checkpoint("Launching Faraday kernel", HERE); + const auto dT = fraction * + m_params.template get( + "algorithms.timestep.correction") * + dt; + if (g == gr_faraday::aux) { + Kokkos::parallel_for( + "Faraday", + domain.mesh.rangeActiveCells(), + kernel::gr::Faraday_kernel(domain.fields.em0, // Bin + domain.fields.em0, // Bout + domain.fields.aux, // E + domain.mesh.metric, + dT, + domain.mesh.n_active(in::x2), + domain.mesh.flds_bc())); + } else if (g == gr_faraday::main) { + Kokkos::parallel_for( + "Faraday", + domain.mesh.rangeActiveCells(), + kernel::gr::Faraday_kernel(domain.fields.em, + domain.fields.em0, + domain.fields.aux, + domain.mesh.metric, + dT, + domain.mesh.n_active(in::x2), + domain.mesh.flds_bc())); + + } else { + raise::Error("Wrong option for `g`", HERE); + } + } + + void Ampere(domain_t& domain, const gr_ampere& g, real_t fraction = ONE) { + logger::Checkpoint("Launching Ampere kernel", HERE); + const auto dT = fraction * + m_params.template get( + "algorithms.timestep.correction") * + dt; + auto range = CreateRangePolicy( + { domain.mesh.i_min(in::x1), domain.mesh.i_min(in::x2) }, + { domain.mesh.i_max(in::x1), domain.mesh.i_max(in::x2) + 1 }); + const auto ni2 = domain.mesh.n_active(in::x2); + + if (g == gr_ampere::aux) { + // First push, updates D0 with J. + Kokkos::parallel_for("Ampere-1", + range, + kernel::gr::Ampere_kernel(domain.fields.em0, // Din + domain.fields.em0, // Dout + domain.fields.aux, + domain.mesh.metric, + dT, + ni2, + domain.mesh.flds_bc())); + } else if (g == gr_ampere::main) { + // Second push, updates D with J0 but assigns it to D0. 
+ Kokkos::parallel_for("Ampere-2", + range, + kernel::gr::Ampere_kernel(domain.fields.em, + domain.fields.em0, + domain.fields.aux, + domain.mesh.metric, + dT, + ni2, + domain.mesh.flds_bc())); + } else if (g == gr_ampere::init) { + // Second push, updates D with J0 and assigns it to D. + Kokkos::parallel_for("Ampere-3", + range, + kernel::gr::Ampere_kernel(domain.fields.em, + domain.fields.em, + domain.fields.aux, + domain.mesh.metric, + dT, + ni2, + domain.mesh.flds_bc())); + } else { + raise::Error("Wrong option for `g`", HERE); + } + } + + void AmpereCurrents(domain_t& domain, const gr_ampere& g) { + logger::Checkpoint("Launching Ampere kernel for adding currents", HERE); + const auto q0 = m_params.template get("scales.q0"); + const auto B0 = m_params.template get("scales.B0"); + const auto coeff = -dt * q0 / B0; + auto range = CreateRangePolicy( + { domain.mesh.i_min(in::x1), domain.mesh.i_min(in::x2) }, + { domain.mesh.i_max(in::x1), domain.mesh.i_max(in::x2) + 1 }); + const auto ni2 = domain.mesh.n_active(in::x2); + + if (g == gr_ampere::aux) { + // Updates D0 with J: D0(n-1/2) -> (J(n)) -> D0(n+1/2) + Kokkos::parallel_for( + "AmpereCurrentsAux", + range, + kernel::gr::CurrentsAmpere_kernel(domain.fields.em0, + domain.fields.cur, + domain.mesh.metric, + coeff, + ni2, + domain.mesh.flds_bc())); + } else if (g == gr_ampere::main) { + // Updates D0 with J0: D0(n) -> (J0(n+1/2)) -> D0(n+1) + Kokkos::parallel_for( + "AmpereCurrentsMain", + range, + kernel::gr::CurrentsAmpere_kernel(domain.fields.em0, + domain.fields.cur0, + domain.mesh.metric, + coeff, + ni2, + domain.mesh.flds_bc())); + } else { + raise::Error("Wrong option for `g`", HERE); + } + } + + void TimeAverageDB(domain_t& domain) { + Kokkos::parallel_for("TimeAverageDB", + domain.mesh.rangeActiveCells(), + kernel::gr::TimeAverageDB_kernel(domain.fields.em, + domain.fields.em0, + domain.mesh.metric)); + } + + void TimeAverageJ(domain_t& domain) { + Kokkos::parallel_for("TimeAverageJ", + 
domain.mesh.rangeActiveCells(), + kernel::gr::TimeAverageJ_kernel(domain.fields.cur, + domain.fields.cur0, + domain.mesh.metric)); + } + + void CurrentsDeposit(domain_t& domain) { + auto scatter_cur0 = Kokkos::Experimental::create_scatter_view( + domain.fields.cur0); + for (auto& species : domain.species) { + logger::Checkpoint( + fmt::format("Launching currents deposit kernel for %d [%s] : %lu %f", + species.index(), + species.label().c_str(), + species.npart(), + (double)species.charge()), + HERE); + if (species.npart() == 0 || cmp::AlmostZero(species.charge())) { + continue; + } + Kokkos::parallel_for("CurrentsDeposit", + species.rangeActiveParticles(), + kernel::DepositCurrents_kernel( + scatter_cur0, + species.i1, + species.i2, + species.i3, + species.i1_prev, + species.i2_prev, + species.i3_prev, + species.dx1, + species.dx2, + species.dx3, + species.dx1_prev, + species.dx2_prev, + species.dx3_prev, + species.ux1, + species.ux2, + species.ux3, + species.phi, + species.weight, + species.tag, + domain.mesh.metric, + (real_t)(species.charge()), + dt)); + } + Kokkos::Experimental::contribute(domain.fields.cur0, scatter_cur0); + } + + void CurrentsFilter(domain_t& domain) { + logger::Checkpoint("Launching currents filtering kernels", HERE); + auto range = CreateRangePolicy( + { domain.mesh.i_min(in::x1), domain.mesh.i_min(in::x2) }, + { domain.mesh.i_max(in::x1), domain.mesh.i_max(in::x2) + 1 }); + const auto nfilter = m_params.template get( + "algorithms.current_filters"); + tuple_t size; + size[0] = domain.mesh.n_active(in::x1); + size[1] = domain.mesh.n_active(in::x2); + + // !TODO: this needs to be done more efficiently + for (unsigned short i = 0; i < nfilter; ++i) { + Kokkos::deep_copy(domain.fields.buff, domain.fields.cur0); + Kokkos::parallel_for("CurrentsFilter", + range, + kernel::DigitalFilter_kernel( + domain.fields.cur0, + domain.fields.buff, + size, + domain.mesh.flds_bc())); + m_metadomain.CommunicateFields(domain, Comm::J); // J0 + } + } + + void 
ParticlePush(domain_t& domain) { + for (auto& species : domain.species) { + species.set_unsorted(); + logger::Checkpoint( + fmt::format("Launching particle pusher kernel for %d [%s] : %lu", + species.index(), + species.label().c_str(), + species.npart()), + HERE); + if (species.npart() == 0) { + continue; + } + const auto q_ovr_m = species.mass() > ZERO + ? species.charge() / species.mass() + : ZERO; + // coeff = q / m (dt / 2) omegaB0 + const auto coeff = q_ovr_m * HALF * dt * + m_params.template get( + "algorithms.timestep.correction") * + m_params.template get("scales.omegaB0"); + const auto eps = m_params.template get( + "algorithms.gr.pusher_eps"); + const auto niter = m_params.template get( + "algorithms.gr.pusher_niter"); + // clang-format off + if (species.pusher() == PrtlPusher::PHOTON) { + auto range_policy = Kokkos::RangePolicy( + 0, + species.npart()); + + Kokkos::parallel_for( + "ParticlePusher", + range_policy, + kernel::gr::Pusher_kernel( + domain.fields.em, + domain.fields.em0, + species.i1, species.i2, species.i3, + species.i1_prev, species.i2_prev, species.i3_prev, + species.dx1, species.dx2, species.dx3, + species.dx1_prev, species.dx2_prev, species.dx3_prev, + species.ux1, species.ux2, species.ux3, + species.phi, species.tag, + domain.mesh.metric, + coeff, dt, + domain.mesh.n_active(in::x1), + domain.mesh.n_active(in::x2), + domain.mesh.n_active(in::x3), + eps, niter, + domain.mesh.prtl_bc() + )); + } else if (species.pusher() == PrtlPusher::BORIS) { + auto range_policy = Kokkos::RangePolicy( + 0, + species.npart()); + Kokkos::parallel_for( + "ParticlePusher", + range_policy, + kernel::gr::Pusher_kernel( + domain.fields.em, + domain.fields.em0, + species.i1, species.i2, species.i3, + species.i1_prev, species.i2_prev, species.i3_prev, + species.dx1, species.dx2, species.dx3, + species.dx1_prev, species.dx2_prev, species.dx3_prev, + species.ux1, species.ux2, species.ux3, + species.phi, species.tag, + domain.mesh.metric, + coeff, dt, + 
domain.mesh.n_active(in::x1), + domain.mesh.n_active(in::x2), + domain.mesh.n_active(in::x3), + eps, niter, + domain.mesh.prtl_bc() + )); + } else if (species.pusher() == PrtlPusher::NONE) { + // do nothing + } else { + raise::Error("not implemented", HERE); + } + // clang-format on + } + } + }; } // namespace ntt #endif // ENGINES_GRPIC_GRPIC_H diff --git a/src/engines/srpic.hpp b/src/engines/srpic.hpp index bddc557c9..6b6a52039 100644 --- a/src/engines/srpic.hpp +++ b/src/engines/srpic.hpp @@ -21,6 +21,7 @@ #include "utils/log.h" #include "utils/numeric.h" #include "utils/timer.h" +#include "utils/toml.h" #include "archetypes/particle_injector.h" #include "framework/domain/domain.h" @@ -41,7 +42,6 @@ #include #include -#include #include namespace ntt { @@ -70,7 +70,7 @@ namespace ntt { public: static constexpr auto S { SimEngine::SRPIC }; - SRPICEngine(SimulationParams& params) : base_t { params } {} + SRPICEngine(const SimulationParams& params) : base_t { params } {} ~SRPICEngine() = default; @@ -79,8 +79,8 @@ namespace ntt { "algorithms.toggles.fieldsolver"); const auto deposit_enabled = m_params.template get( "algorithms.toggles.deposit"); - const auto sort_interval = m_params.template get( - "particles.sort_interval"); + const auto clear_interval = m_params.template get( + "particles.clear_interval"); if (step == 0) { // communicate fields and apply BCs on the first timestep @@ -101,6 +101,7 @@ namespace ntt { timers.start("FieldBoundaries"); FieldBoundaries(dom, BC::B); timers.stop("FieldBoundaries"); + Kokkos::fence(); } { @@ -125,9 +126,7 @@ namespace ntt { } timers.start("Communications"); - if ((sort_interval > 0) and (step % sort_interval == 0)) { - m_metadomain.CommunicateParticles(dom, &timers); - } + m_metadomain.CommunicateParticles(dom); timers.stop("Communications"); } @@ -168,6 +167,12 @@ namespace ntt { ParticleInjector(dom); timers.stop("Injector"); } + + if (clear_interval > 0 and step % clear_interval == 0 and step > 0) { + 
timers.start("PrtlClear"); + m_metadomain.RemoveDeadParticles(dom); + timers.stop("PrtlClear"); + } } /* algorithm substeps --------------------------------------------------- */ @@ -274,6 +279,9 @@ namespace ntt { } } for (auto& species : domain.species) { + if ((species.pusher() == PrtlPusher::NONE) or (species.npart() == 0)) { + continue; + } species.set_unsorted(); logger::Checkpoint( fmt::format("Launching particle pusher kernel for %d [%s] : %lu", @@ -281,9 +289,6 @@ namespace ntt { species.label().c_str(), species.npart()), HERE); - if (species.npart() == 0) { - continue; - } const auto q_ovr_m = species.mass() > ZERO ? species.charge() / species.mass() : ZERO; @@ -476,6 +481,10 @@ namespace ntt { auto scatter_cur = Kokkos::Experimental::create_scatter_view( domain.fields.cur); for (auto& species : domain.species) { + if ((species.pusher() == PrtlPusher::NONE) or (species.npart() == 0) or + cmp::AlmostZero_host(species.charge())) { + continue; + } logger::Checkpoint( fmt::format("Launching currents deposit kernel for %d [%s] : %lu %f", species.index(), @@ -483,56 +492,61 @@ namespace ntt { species.npart(), (double)species.charge()), HERE); - if (species.npart() == 0 || cmp::AlmostZero(species.charge())) { - continue; - } + // clang-format off Kokkos::parallel_for("CurrentsDeposit", species.rangeActiveParticles(), kernel::DepositCurrents_kernel( scatter_cur, - species.i1, - species.i2, - species.i3, - species.i1_prev, - species.i2_prev, - species.i3_prev, - species.dx1, - species.dx2, - species.dx3, - species.dx1_prev, - species.dx2_prev, - species.dx3_prev, - species.ux1, - species.ux2, - species.ux3, - species.phi, - species.weight, - species.tag, + species.i1, species.i2, species.i3, + species.i1_prev, species.i2_prev, species.i3_prev, + species.dx1, species.dx2, species.dx3, + species.dx1_prev, species.dx2_prev, species.dx3_prev, + species.ux1, species.ux2, species.ux3, + species.phi, species.weight, species.tag, domain.mesh.metric, - 
(real_t)(species.charge()), - dt)); + (real_t)(species.charge()), dt)); + // clang-format on } Kokkos::Experimental::contribute(domain.fields.cur, scatter_cur); } void CurrentsAmpere(domain_t& domain) { logger::Checkpoint("Launching Ampere kernel for adding currents", HERE); - const auto q0 = m_params.template get("scales.q0"); - const auto n0 = m_params.template get("scales.n0"); - const auto B0 = m_params.template get("scales.B0"); - const auto coeff = -dt * q0 * n0 / B0; + const auto q0 = m_params.template get("scales.q0"); + const auto n0 = m_params.template get("scales.n0"); + const auto B0 = m_params.template get("scales.B0"); if constexpr (M::CoordType == Coord::Cart) { // minkowski case - const auto V0 = m_params.template get("scales.V0"); - - Kokkos::parallel_for( - "Ampere", - domain.mesh.rangeActiveCells(), - kernel::mink::CurrentsAmpere_kernel(domain.fields.em, - domain.fields.cur, - coeff / V0, - ONE / n0)); + const auto V0 = m_params.template get("scales.V0"); + const auto ppc0 = m_params.template get("particles.ppc0"); + const auto coeff = -dt * q0 / (B0 * V0); + if constexpr ( + traits::has_member::value) { + const std::vector xmin { domain.mesh.extent(in::x1).first, + domain.mesh.extent(in::x2).first, + domain.mesh.extent(in::x3).first }; + const auto ext_current = m_pgen.ext_current; + const auto dx = domain.mesh.metric.template sqrt_h_<1, 1>({}); + // clang-format off + Kokkos::parallel_for( + "Ampere", + domain.mesh.rangeActiveCells(), + kernel::mink::CurrentsAmpere_kernel( + domain.fields.em, domain.fields.cur, + coeff, ppc0, ext_current, xmin, dx)); + // clang-format on + } else { + Kokkos::parallel_for( + "Ampere", + domain.mesh.rangeActiveCells(), + kernel::mink::CurrentsAmpere_kernel(domain.fields.em, + domain.fields.cur, + coeff, + ppc0)); + } } else { + // non-minkowski + const auto coeff = -dt * q0 * n0 / B0; auto range = range_with_axis_BCs(domain); const auto ni2 = domain.mesh.n_active(in::x2); Kokkos::parallel_for( @@ -553,7 +567,7 @@ 
namespace ntt { auto range = range_with_axis_BCs(domain); const auto nfilter = m_params.template get( "algorithms.current_filters"); - tuple_t size; + tuple_t size; if constexpr (M::Dim == Dim::_1D || M::Dim == Dim::_2D || M::Dim == Dim::_3D) { size[0] = domain.mesh.n_active(in::x1); } @@ -564,7 +578,7 @@ namespace ntt { size[2] = domain.mesh.n_active(in::x3); } // !TODO: this needs to be done more efficiently - for (unsigned short i = 0; i < nfilter; ++i) { + for (auto i { 0u }; i < nfilter; ++i) { Kokkos::deep_copy(domain.fields.buff, domain.fields.cur); Kokkos::parallel_for("CurrentsFilter", range, @@ -579,17 +593,21 @@ namespace ntt { void FieldBoundaries(domain_t& domain, BCTags tags) { for (auto& direction : dir::Directions::orth) { - if (m_metadomain.mesh().flds_bc_in(direction) == FldsBC::ABSORB) { - AbsorbFieldsIn(direction, domain, tags); + if (m_metadomain.mesh().flds_bc_in(direction) == FldsBC::MATCH) { + MatchFieldsIn(direction, domain, tags); } else if (m_metadomain.mesh().flds_bc_in(direction) == FldsBC::AXIS) { if (domain.mesh.flds_bc_in(direction) == FldsBC::AXIS) { AxisFieldsIn(direction, domain, tags); } } else if (m_metadomain.mesh().flds_bc_in(direction) == FldsBC::ATMOSPHERE) { AtmosphereFieldsIn(direction, domain, tags); + } else if (m_metadomain.mesh().flds_bc_in(direction) == FldsBC::FIXED) { + if (domain.mesh.flds_bc_in(direction) == FldsBC::FIXED) { + FixedFieldsIn(direction, domain, tags); + } } else if (m_metadomain.mesh().flds_bc_in(direction) == FldsBC::CONDUCTOR) { if (domain.mesh.flds_bc_in(direction) == FldsBC::CONDUCTOR) { - ConductorFieldsIn(direction, domain, tags); + PerfectConductorFieldsIn(direction, domain, tags); } } else if (m_metadomain.mesh().flds_bc_in(direction) == FldsBC::CUSTOM) { if (domain.mesh.flds_bc_in(direction) == FldsBC::CUSTOM) { @@ -601,30 +619,33 @@ namespace ntt { } // loop over directions } - void AbsorbFieldsIn(dir::direction_t direction, - domain_t& domain, - BCTags tags) { + void 
MatchFieldsIn(dir::direction_t direction, + domain_t& domain, + BCTags tags) { /** - * absorbing boundaries + * matching boundaries */ - const auto ds = m_params.template get( - "grid.boundaries.absorb.ds"); + const auto ds_array = m_params.template get>( + "grid.boundaries.match.ds"); const auto dim = direction.get_dim(); real_t xg_min, xg_max, xg_edge; auto sign = direction.get_sign(); + real_t ds; if (sign > 0) { // + direction + ds = ds_array[(short)dim].second; xg_max = m_metadomain.mesh().extent(dim).second; xg_min = xg_max - ds; xg_edge = xg_max; } else { // - direction + ds = ds_array[(short)dim].first; xg_min = m_metadomain.mesh().extent(dim).first; xg_max = xg_min + ds; xg_edge = xg_min; } boundaries_t box; boundaries_t incl_ghosts; - for (unsigned short d { 0 }; d < M::Dim; ++d) { - if (d == static_cast(dim)) { + for (dim_t d { 0 }; d < M::Dim; ++d) { + if (d == static_cast(dim)) { box.push_back({ xg_min, xg_max }); if (sign > 0) { incl_ghosts.push_back({ false, true }); @@ -640,48 +661,100 @@ namespace ntt { return; } const auto intersect_range = domain.mesh.ExtentToRange(box, incl_ghosts); - tuple_t range_min { 0 }; - tuple_t range_max { 0 }; + tuple_t range_min { 0 }; + tuple_t range_max { 0 }; - for (unsigned short d { 0 }; d < M::Dim; ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { range_min[d] = intersect_range[d].first; range_max[d] = intersect_range[d].second; } + if (dim == in::x1) { - Kokkos::parallel_for( - "AbsorbFields", - CreateRangePolicy(range_min, range_max), - kernel::AbsorbBoundaries_kernel(domain.fields.em, - domain.mesh.metric, - xg_edge, - ds, - tags)); + if constexpr ( + traits::has_member::value) { + auto match_fields = m_pgen.MatchFields(time); + call_match_fields(domain.fields.em, + domain.mesh.flds_bc(), + match_fields, + domain.mesh.metric, + xg_edge, + ds, + tags, + range_min, + range_max); + } else if constexpr ( + traits::has_member::value) { + auto match_fields = m_pgen.MatchFieldsInX1(time); + 
call_match_fields(domain.fields.em, + domain.mesh.flds_bc(), + match_fields, + domain.mesh.metric, + xg_edge, + ds, + tags, + range_min, + range_max); + } } else if (dim == in::x2) { if constexpr (M::Dim == Dim::_2D or M::Dim == Dim::_3D) { - Kokkos::parallel_for( - "AbsorbFields", - CreateRangePolicy(range_min, range_max), - kernel::AbsorbBoundaries_kernel(domain.fields.em, - domain.mesh.metric, - xg_edge, - ds, - tags)); + if constexpr ( + traits::has_member::value) { + auto match_fields = m_pgen.MatchFields(time); + call_match_fields(domain.fields.em, + domain.mesh.flds_bc(), + match_fields, + domain.mesh.metric, + xg_edge, + ds, + tags, + range_min, + range_max); + } else if constexpr ( + traits::has_member::value) { + auto match_fields = m_pgen.MatchFieldsInX2(time); + call_match_fields(domain.fields.em, + domain.mesh.flds_bc(), + match_fields, + domain.mesh.metric, + xg_edge, + ds, + tags, + range_min, + range_max); + } } else { raise::Error("Invalid dimension", HERE); } } else if (dim == in::x3) { if constexpr (M::Dim == Dim::_3D) { - Kokkos::parallel_for( - "AbsorbFields", - CreateRangePolicy(range_min, range_max), - kernel::AbsorbBoundaries_kernel(domain.fields.em, - domain.mesh.metric, - xg_edge, - ds, - tags)); - } else { - raise::Error("Invalid dimension", HERE); + if constexpr ( + traits::has_member::value) { + auto match_fields = m_pgen.MatchFields(time); + call_match_fields(domain.fields.em, + domain.mesh.flds_bc(), + match_fields, + domain.mesh.metric, + xg_edge, + ds, + tags, + range_min, + range_max); + } else if constexpr ( + traits::has_member::value) { + auto match_fields = m_pgen.MatchFieldsInX3(time); + call_match_fields(domain.fields.em, + domain.mesh.flds_bc(), + match_fields, + domain.mesh.metric, + xg_edge, + ds, + tags, + range_min, + range_max); + } } + } else { + raise::Error("Invalid dimension", HERE); } } @@ -691,24 +764,261 @@ namespace ntt { /** * axis boundaries */ - raise::ErrorIf(M::CoordType == Coord::Cart, - "Invalid 
coordinate type for axis BCs", + if constexpr (M::CoordType != Coord::Cart) { + raise::ErrorIf(direction.get_dim() != in::x2, + "Invalid axis direction, should be x2", + HERE); + const auto i2_min = domain.mesh.i_min(in::x2); + const auto i2_max = domain.mesh.i_max(in::x2); + if (direction.get_sign() < 0) { + Kokkos::parallel_for( + "AxisBCFields", + domain.mesh.n_all(in::x1), + kernel::bc::AxisBoundaries_kernel(domain.fields.em, + i2_min, + tags)); + } else { + Kokkos::parallel_for( + "AxisBCFields", + domain.mesh.n_all(in::x1), + kernel::bc::AxisBoundaries_kernel(domain.fields.em, + i2_max, + tags)); + } + } else { + (void)direction; + (void)domain; + (void)tags; + raise::Error("Invalid coordinate type for axis BCs", HERE); + } + } + + void FixedFieldsIn(dir::direction_t direction, + domain_t& domain, + BCTags tags) { + /** + * fixed field boundaries + */ + const auto sign = direction.get_sign(); + const auto dim = direction.get_dim(); + raise::ErrorIf(dim != in::x1 and M::CoordType != Coord::Cart, + "Fixed BCs only implemented for x1 in " + "non-cartesian coordinates", HERE); - raise::ErrorIf(direction.get_dim() != in::x2, - "Invalid axis direction, should be x2", + em normal_b_comp, tang_e_comp1, tang_e_comp2; + if (dim == in::x1) { + normal_b_comp = em::bx1; + tang_e_comp1 = em::ex2; + tang_e_comp2 = em::ex3; + } else if (dim == in::x2) { + normal_b_comp = em::bx2; + tang_e_comp1 = em::ex1; + tang_e_comp2 = em::ex3; + } else if (dim == in::x3) { + normal_b_comp = em::bx3; + tang_e_comp1 = em::ex1; + tang_e_comp2 = em::ex2; + } else { + raise::Error("Invalid dimension", HERE); + } + std::vector xi_min, xi_max; + const std::vector all_dirs { in::x1, in::x2, in::x3 }; + for (dim_t d { 0u }; d < M::Dim; ++d) { + const auto dd = all_dirs[d]; + if (dim == dd) { + if (sign > 0) { // + direction + xi_min.push_back(domain.mesh.n_all(dd) - N_GHOSTS); + xi_max.push_back(domain.mesh.n_all(dd)); + } else { // - direction + xi_min.push_back(0); + xi_max.push_back(N_GHOSTS); 
+ } + } else { + xi_min.push_back(0); + xi_max.push_back(domain.mesh.n_all(dd)); + } + } + raise::ErrorIf(xi_min.size() != xi_max.size() or + xi_min.size() != static_cast(M::Dim), + "Invalid range size", HERE); - const auto i2_min = domain.mesh.i_min(in::x2); - const auto i2_max = domain.mesh.i_max(in::x2); - if (direction.get_sign() < 0) { - Kokkos::parallel_for( - "AxisBCFields", - domain.mesh.n_all(in::x1), - kernel::AxisBoundaries_kernel(domain.fields.em, i2_min, tags)); + std::vector comps; + if (tags & BC::E) { + comps.push_back(tang_e_comp1); + comps.push_back(tang_e_comp2); + } + if (tags & BC::B) { + comps.push_back(normal_b_comp); + } + if constexpr (traits::has_member::value) { + raise::Error("Non-const fixed fields not implemented", HERE); + } else if constexpr ( + traits::has_member::value) { + for (const auto& comp : comps) { + auto value = ZERO; + bool shouldset = false; + if constexpr ( + traits::has_member::value) { + // if fix field function present, read from it + const auto newset = m_pgen.FixFieldsConst( + (bc_in)(sign * ((short)dim + 1)), + (em)comp); + value = newset.first; + shouldset = newset.second; + } + if (shouldset) { + if constexpr (M::Dim == Dim::_1D) { + Kokkos::deep_copy( + Kokkos::subview(domain.fields.em, + std::make_pair(xi_min[0], xi_max[0]), + comp), + value); + } else if constexpr (M::Dim == Dim::_2D) { + Kokkos::deep_copy( + Kokkos::subview(domain.fields.em, + std::make_pair(xi_min[0], xi_max[0]), + std::make_pair(xi_min[1], xi_max[1]), + comp), + value); + } else if constexpr (M::Dim == Dim::_3D) { + Kokkos::deep_copy( + Kokkos::subview(domain.fields.em, + std::make_pair(xi_min[0], xi_max[0]), + std::make_pair(xi_min[1], xi_max[1]), + std::make_pair(xi_min[2], xi_max[2]), + comp), + value); + } else { + raise::Error("Invalid dimension", HERE); + } + } + } } else { - Kokkos::parallel_for( - "AxisBCFields", - domain.mesh.n_all(in::x1), - kernel::AxisBoundaries_kernel(domain.fields.em, i2_max, tags)); + (void)direction; + 
(void)domain; + (void)tags; + raise::Error("Fixed fields not present (both const and non-const)", HERE); + } + } + + void PerfectConductorFieldsIn(dir::direction_t direction, + domain_t& domain, + BCTags tags) { + /** + * perfect conductor field boundaries + */ + if constexpr (M::CoordType != Coord::Cart) { + (void)direction; + (void)domain; + (void)tags; + raise::Error( + "Perfect conductor BCs only applicable to cartesian coordinates", + HERE); + } else { + const auto sign = direction.get_sign(); + const auto dim = direction.get_dim(); + + std::vector xi_min, xi_max; + + const std::vector all_dirs { in::x1, in::x2, in::x3 }; + + for (auto d { 0u }; d < M::Dim; ++d) { + const auto dd = all_dirs[d]; + if (dim == dd) { + xi_min.push_back(0); + xi_max.push_back((sign < 0) ? (N_GHOSTS + 1) : N_GHOSTS); + } else { + xi_min.push_back(0); + xi_max.push_back(domain.mesh.n_all(dd)); + } + } + raise::ErrorIf(xi_min.size() != xi_max.size() or + xi_min.size() != static_cast(M::Dim), + "Invalid range size", + HERE); + + range_t range; + if constexpr (M::Dim == Dim::_1D) { + range = CreateRangePolicy({ xi_min[0] }, { xi_max[0] }); + } else if constexpr (M::Dim == Dim::_2D) { + range = CreateRangePolicy({ xi_min[0], xi_min[1] }, + { xi_max[0], xi_max[1] }); + } else if constexpr (M::Dim == Dim::_3D) { + range = CreateRangePolicy({ xi_min[0], xi_min[1], xi_min[2] }, + { xi_max[0], xi_max[1], xi_max[2] }); + } else { + raise::Error("Invalid dimension", HERE); + } + std::size_t i_edge; + if (sign > 0) { + i_edge = domain.mesh.i_max(dim); + } else { + i_edge = domain.mesh.i_min(dim); + } + + if (dim == in::x1) { + if (sign > 0) { + Kokkos::parallel_for( + "ConductorFields", + range, + kernel::bc::ConductorBoundaries_kernel( + domain.fields.em, + i_edge, + tags)); + } else { + Kokkos::parallel_for( + "ConductorFields", + range, + kernel::bc::ConductorBoundaries_kernel( + domain.fields.em, + i_edge, + tags)); + } + } else if (dim == in::x2) { + if constexpr (M::Dim == Dim::_2D or 
M::Dim == Dim::_3D) { + if (sign > 0) { + Kokkos::parallel_for( + "ConductorFields", + range, + kernel::bc::ConductorBoundaries_kernel( + domain.fields.em, + i_edge, + tags)); + } else { + Kokkos::parallel_for( + "ConductorFields", + range, + kernel::bc::ConductorBoundaries_kernel( + domain.fields.em, + i_edge, + tags)); + } + } else { + raise::Error("Invalid dimension", HERE); + } + } else { + if constexpr (M::Dim == Dim::_3D) { + if (sign > 0) { + Kokkos::parallel_for( + "ConductorFields", + range, + kernel::bc::ConductorBoundaries_kernel( + domain.fields.em, + i_edge, + tags)); + } else { + Kokkos::parallel_for( + "ConductorFields", + range, + kernel::bc::ConductorBoundaries_kernel( + domain.fields.em, + i_edge, + tags)); + } + } else { + raise::Error("Invalid dimension", HERE); + } + } } } @@ -716,14 +1026,14 @@ namespace ntt { domain_t& domain, BCTags tags) { /** - * atmosphere boundaries + * atmosphere field boundaries */ - if constexpr (traits::has_member::value) { + if constexpr (traits::has_member::value) { const auto [sign, dim, xg_min, xg_max] = get_atm_extent(direction); - const auto dd = static_cast(dim); + const auto dd = static_cast(dim); boundaries_t box; boundaries_t incl_ghosts; - for (unsigned short d { 0 }; d < M::Dim; ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { if (d == dd) { box.push_back({ xg_min, xg_max }); if (sign > 0) { @@ -743,11 +1053,11 @@ namespace ntt { tuple_t range_min { 0 }; tuple_t range_max { 0 }; - for (unsigned short d { 0 }; d < M::Dim; ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { range_min[d] = intersect_range[d].first; range_max[d] = intersect_range[d].second; } - auto field_driver = m_pgen.FieldDriver(time); + auto atm_fields = m_pgen.AtmFields(time); std::size_t il_edge; if (sign > 0) { il_edge = range_min[dd] - N_GHOSTS; @@ -760,9 +1070,9 @@ namespace ntt { Kokkos::parallel_for( "AtmosphereBCFields", range, - kernel::AtmosphereBoundaries_kernel( + kernel::bc::EnforcedBoundaries_kernel( domain.fields.em, - 
field_driver, + atm_fields, domain.mesh.metric, il_edge, tags)); @@ -770,9 +1080,9 @@ namespace ntt { Kokkos::parallel_for( "AtmosphereBCFields", range, - kernel::AtmosphereBoundaries_kernel( + kernel::bc::EnforcedBoundaries_kernel( domain.fields.em, - field_driver, + atm_fields, domain.mesh.metric, il_edge, tags)); @@ -783,9 +1093,9 @@ namespace ntt { Kokkos::parallel_for( "AtmosphereBCFields", range, - kernel::AtmosphereBoundaries_kernel( + kernel::bc::EnforcedBoundaries_kernel( domain.fields.em, - field_driver, + atm_fields, domain.mesh.metric, il_edge, tags)); @@ -793,9 +1103,9 @@ namespace ntt { Kokkos::parallel_for( "AtmosphereBCFields", range, - kernel::AtmosphereBoundaries_kernel( + kernel::bc::EnforcedBoundaries_kernel( domain.fields.em, - field_driver, + atm_fields, domain.mesh.metric, il_edge, tags)); @@ -809,9 +1119,9 @@ namespace ntt { Kokkos::parallel_for( "AtmosphereBCFields", range, - kernel::AtmosphereBoundaries_kernel( + kernel::bc::EnforcedBoundaries_kernel( domain.fields.em, - field_driver, + atm_fields, domain.mesh.metric, il_edge, tags)); @@ -819,9 +1129,9 @@ namespace ntt { Kokkos::parallel_for( "AtmosphereBCFields", range, - kernel::AtmosphereBoundaries_kernel( + kernel::bc::EnforcedBoundaries_kernel( domain.fields.em, - field_driver, + atm_fields, domain.mesh.metric, il_edge, tags)); @@ -833,80 +1143,10 @@ namespace ntt { raise::Error("Invalid dimension", HERE); } } else { - raise::Error("Field driver not implemented in PGEN for atmosphere BCs", - HERE); - } - } - - void ConductorFieldsIn(dir::direction_t direction, - domain_t& domain, - BCTags tags) { - const auto sign = direction.get_sign(); - const auto dim = direction.get_dim(); - raise::ErrorIf( - dim != in::x1 and M::CoordType != Coord::Cart, - "Conductor BCs only implemented for x1 in non-cartesian coordinates", - HERE); - em normal_b_comp, tang_e_comp1, tang_e_comp2; - if (dim == in::x1) { - normal_b_comp = em::bx1; - tang_e_comp1 = em::ex2; - tang_e_comp2 = em::ex3; - } else if 
(dim == in::x2) { - normal_b_comp = em::bx2; - tang_e_comp1 = em::ex1; - tang_e_comp2 = em::ex3; - } else if (dim == in::x3) { - normal_b_comp = em::bx3; - tang_e_comp1 = em::ex1; - tang_e_comp2 = em::ex2; - } else { - raise::Error("Invalid dimension", HERE); - } - std::vector xi_min, xi_max; - const std::vector all_dirs { in::x1, in::x2, in::x3 }; - for (unsigned short d { 0 }; d < static_cast(M::Dim); ++d) { - const auto dd = all_dirs[d]; - if (dim == dd) { - if (sign > 0) { // + direction - xi_min.push_back(domain.mesh.n_all(dd) - N_GHOSTS); - xi_max.push_back(domain.mesh.n_all(dd)); - } else { // - direction - xi_min.push_back(0); - xi_max.push_back(N_GHOSTS); - } - } else { - xi_min.push_back(0); - xi_max.push_back(domain.mesh.n_all(dd)); - } - } - raise::ErrorIf(xi_min.size() != xi_max.size() or - xi_min.size() != static_cast(M::Dim), - "Invalid range size", - HERE); - for (const unsigned short comp : - { normal_b_comp, tang_e_comp1, tang_e_comp2 }) { - if constexpr (M::Dim == Dim::_1D) { - Kokkos::deep_copy(Kokkos::subview(domain.fields.em, - std::make_pair(xi_min[0], xi_max[0]), - comp), - ZERO); - } else if constexpr (M::Dim == Dim::_2D) { - Kokkos::deep_copy(Kokkos::subview(domain.fields.em, - std::make_pair(xi_min[0], xi_max[0]), - std::make_pair(xi_min[1], xi_max[1]), - comp), - ZERO); - } else if constexpr (M::Dim == Dim::_3D) { - Kokkos::deep_copy(Kokkos::subview(domain.fields.em, - std::make_pair(xi_min[0], xi_max[0]), - std::make_pair(xi_min[1], xi_max[1]), - std::make_pair(xi_min[2], xi_max[2]), - comp), - ZERO); - } else { - raise::Error("Invalid dimension", HERE); - } + (void)direction; + (void)domain; + (void)tags; + raise::Error("Atm fields not implemented in PGEN for atmosphere BCs", HERE); } } @@ -940,9 +1180,8 @@ namespace ntt { "grid.boundaries.atmosphere.temperature"); const auto height = m_params.template get( "grid.boundaries.atmosphere.height"); - const auto species = - m_params.template get>( - "grid.boundaries.atmosphere.species"); + 
const auto species = m_params.template get>( + "grid.boundaries.atmosphere.species"); const auto nmax = m_params.template get( "grid.boundaries.atmosphere.density"); @@ -975,7 +1214,7 @@ namespace ntt { } } else { for (const auto& sp : - std::vector({ species.first, species.second })) { + std::vector { species.first, species.second }) { auto& prtl_spec = domain.species[sp - 1]; if (prtl_spec.npart() == 0) { continue; @@ -1153,8 +1392,8 @@ namespace ntt { "possible only in -x1 (@ rmin)", HERE); } - real_t xg_min { ZERO }, xg_max { ZERO }; - std::size_t ig_min, ig_max; + real_t xg_min { ZERO }, xg_max { ZERO }; + ncells_t ig_min, ig_max; if (sign > 0) { // + direction ig_min = m_metadomain.mesh().n_active(dim) - buffer_ncells; ig_max = m_metadomain.mesh().n_active(dim); @@ -1217,6 +1456,28 @@ namespace ntt { } return range; } + + template + void call_match_fields(ndfield_t& fields, + const boundaries_t& boundaries, + const T& match_fields, + const M& metric, + real_t xg_edge, + real_t ds, + BCTags tags, + tuple_t& range_min, + tuple_t& range_max) { + Kokkos::parallel_for( + "MatchFields", + CreateRangePolicy(range_min, range_max), + kernel::bc::MatchBoundaries_kernel(fields, + match_fields, + metric, + xg_edge, + ds, + tags, + boundaries)); + } }; } // namespace ntt diff --git a/src/entity.cpp b/src/entity.cpp index 272635d68..79b2f1335 100644 --- a/src/entity.cpp +++ b/src/entity.cpp @@ -114,4 +114,4 @@ auto main(int argc, char* argv[]) -> int { } return 0; -} +} \ No newline at end of file diff --git a/src/framework/CMakeLists.txt b/src/framework/CMakeLists.txt index c7c3ba8a1..b74d11bec 100644 --- a/src/framework/CMakeLists.txt +++ b/src/framework/CMakeLists.txt @@ -1,52 +1,64 @@ +# cmake-lint: disable=C0103 # ------------------------------ # @defines: ntt_framework [STATIC/SHARED] +# # @sources: -# - parameters.cpp -# - simulation.cpp -# - domain/grid.cpp -# - domain/metadomain.cpp -# - domain/communications.cpp -# - containers/particles.cpp -# - 
containers/fields.cpp -# - domain/output.cpp +# +# * parameters.cpp +# * simulation.cpp +# * domain/grid.cpp +# * domain/metadomain.cpp +# * domain/communications.cpp +# * domain/checkpoint.cpp +# * containers/particles.cpp +# * containers/fields.cpp +# * domain/stats.cpp +# * domain/output.cpp +# # @includes: -# - ../ +# +# * ../ +# # @depends: -# - ntt_global [required] -# - ntt_metrics [required] -# - ntt_kernels [required] -# - ntt_output [optional] +# +# * ntt_global [required] +# * ntt_metrics [required] +# * ntt_kernels [required] +# * ntt_output [required] +# # @uses: -# - kokkos [required] -# - plog [required] -# - toml11 [required] -# - ADIOS2 [optional] -# - mpi [optional] +# +# * kokkos [required] +# * plog [required] +# * ADIOS2 [optional] +# * mpi [optional] # ------------------------------ set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(SOURCES - ${SRC_DIR}/parameters.cpp - ${SRC_DIR}/simulation.cpp - ${SRC_DIR}/domain/grid.cpp - ${SRC_DIR}/domain/metadomain.cpp - ${SRC_DIR}/domain/communications.cpp - ${SRC_DIR}/containers/particles.cpp - ${SRC_DIR}/containers/fields.cpp -) -if (${output}) +set(SOURCES + ${SRC_DIR}/parameters.cpp + ${SRC_DIR}/simulation.cpp + ${SRC_DIR}/domain/grid.cpp + ${SRC_DIR}/domain/metadomain.cpp + ${SRC_DIR}/domain/communications.cpp + ${SRC_DIR}/domain/stats.cpp + ${SRC_DIR}/containers/particles.cpp + ${SRC_DIR}/containers/fields.cpp) +if(${output}) list(APPEND SOURCES ${SRC_DIR}/domain/output.cpp) + list(APPEND SOURCES ${SRC_DIR}/domain/checkpoint.cpp) endif() add_library(ntt_framework ${SOURCES}) -set(libs ntt_global ntt_metrics ntt_kernels) +set(libs ntt_global ntt_metrics ntt_kernels ntt_output) if(${output}) - list(APPEND libs ntt_output) + list(APPEND libs ntt_checkpoint) endif() add_dependencies(ntt_framework ${libs}) target_link_libraries(ntt_framework PUBLIC ${libs}) +target_link_libraries(ntt_framework PRIVATE stdc++fs) -target_include_directories(ntt_framework +target_include_directories( + ntt_framework PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR}/../ - INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../ -) \ No newline at end of file + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../) diff --git a/src/framework/containers/fields.cpp b/src/framework/containers/fields.cpp index a62886b06..7202ff282 100644 --- a/src/framework/containers/fields.cpp +++ b/src/framework/containers/fields.cpp @@ -8,8 +8,8 @@ namespace ntt { template - Fields::Fields(const std::vector& res) { - std::size_t nx1, nx2, nx3; + Fields::Fields(const std::vector& res) { + ncells_t nx1, nx2, nx3; nx1 = res[0] + 2 * N_GHOSTS; if constexpr ((D == Dim::_3D) || (D == Dim::_2D)) { nx2 = res[1] + 2 * N_GHOSTS; @@ -52,4 +52,4 @@ namespace ntt { template struct Fields; template struct Fields; -} // namespace ntt \ No newline at end of file +} // namespace ntt diff --git a/src/framework/containers/fields.h b/src/framework/containers/fields.h index d0bd7d020..ee9d656d6 100644 --- a/src/framework/containers/fields.h +++ b/src/framework/containers/fields.h @@ -109,7 +109,7 @@ namespace ntt { */ Fields() {} - Fields(const std::vector& res); + Fields(const std::vector& res); Fields(Fields&& other) noexcept : em { std::move(other.em) } diff --git a/src/framework/containers/particles.cpp b/src/framework/containers/particles.cpp index f0c64c4ee..d2db9c491 100644 --- a/src/framework/containers/particles.cpp +++ b/src/framework/containers/particles.cpp @@ -4,184 +4,226 @@ #include "global.h" #include "arch/kokkos_aliases.h" -#include "utils/sorting.h" #include "framework/containers/species.h" #include #include +#include #include #include namespace ntt { template - Particles::Particles(unsigned short index, + Particles::Particles(spidx_t index, const std::string& label, float m, float ch, - std::size_t maxnpart, + npart_t maxnpart, const PrtlPusher& pusher, bool use_gca, const Cooling& cooling, unsigned short npld) : ParticleSpecies(index, label, m, ch, maxnpart, pusher, use_gca, cooling, npld) { - i1 = array_t { label + "_i1", maxnpart }; - i1_h = 
Kokkos::create_mirror_view(i1); - dx1 = array_t { label + "_dx1", maxnpart }; - dx1_h = Kokkos::create_mirror_view(dx1); - - i1_prev = array_t { label + "_i1_prev", maxnpart }; - dx1_prev = array_t { label + "_dx1_prev", maxnpart }; - - ux1 = array_t { label + "_ux1", maxnpart }; - ux1_h = Kokkos::create_mirror_view(ux1); - ux2 = array_t { label + "_ux2", maxnpart }; - ux2_h = Kokkos::create_mirror_view(ux2); - ux3 = array_t { label + "_ux3", maxnpart }; - ux3_h = Kokkos::create_mirror_view(ux3); - - weight = array_t { label + "_w", maxnpart }; - weight_h = Kokkos::create_mirror_view(weight); - - tag = array_t { label + "_tag", maxnpart }; - tag_h = Kokkos::create_mirror_view(tag); - - for (unsigned short n { 0 }; n < npld; ++n) { - pld.push_back(array_t("pld", maxnpart)); - pld_h.push_back(Kokkos::create_mirror_view(pld[n])); - } - if constexpr ((D == Dim::_2D) || (D == Dim::_3D)) { - i2 = array_t { label + "_i2", maxnpart }; - i2_h = Kokkos::create_mirror_view(i2); - dx2 = array_t { label + "_dx2", maxnpart }; - dx2_h = Kokkos::create_mirror_view(dx2); + if constexpr (D == Dim::_1D or D == Dim::_2D or D == Dim::_3D) { + i1 = array_t { label + "_i1", maxnpart }; + dx1 = array_t { label + "_dx1", maxnpart }; + i1_prev = array_t { label + "_i1_prev", maxnpart }; + dx1_prev = array_t { label + "_dx1_prev", maxnpart }; + } + if constexpr (D == Dim::_2D or D == Dim::_3D) { + i2 = array_t { label + "_i2", maxnpart }; + dx2 = array_t { label + "_dx2", maxnpart }; i2_prev = array_t { label + "_i2_prev", maxnpart }; dx2_prev = array_t { label + "_dx2_prev", maxnpart }; } - if ((D == Dim::_2D) && (C != Coord::Cart)) { - phi = array_t { label + "_phi", maxnpart }; - phi_h = Kokkos::create_mirror_view(phi); - } if constexpr (D == Dim::_3D) { - i3 = array_t { label + "_i3", maxnpart }; - i3_h = Kokkos::create_mirror_view(i3); - dx3 = array_t { label + "_dx3", maxnpart }; - dx3_h = Kokkos::create_mirror_view(dx3); - + i3 = array_t { label + "_i3", maxnpart }; + dx3 = array_t { 
label + "_dx3", maxnpart }; i3_prev = array_t { label + "_i3_prev", maxnpart }; dx3_prev = array_t { label + "_dx3_prev", maxnpart }; } + + ux1 = array_t { label + "_ux1", maxnpart }; + ux2 = array_t { label + "_ux2", maxnpart }; + ux3 = array_t { label + "_ux3", maxnpart }; + + weight = array_t { label + "_w", maxnpart }; + + tag = array_t { label + "_tag", maxnpart }; + + if (npld > 0) { + pld = array_t { label + "_pld", maxnpart, npld }; + } + + if ((D == Dim::_2D) && (C != Coord::Cart)) { + phi = array_t { label + "_phi", maxnpart }; + } } template - auto Particles::npart_per_tag() const -> std::vector { - auto this_tag = tag; - array_t npart_tag("npart_tags", ntags()); - - auto npart_tag_scatter = Kokkos::Experimental::create_scatter_view(npart_tag); + auto Particles::NpartsPerTagAndOffsets() const + -> std::pair, array_t> { + auto this_tag = tag; + const auto num_tags = ntags(); + array_t npptag { "nparts_per_tag", ntags() }; + + // count # of particles per each tag + auto npptag_scat = Kokkos::Experimental::create_scatter_view(npptag); Kokkos::parallel_for( "NpartPerTag", - npart(), + rangeActiveParticles(), Lambda(index_t p) { - auto npart_tag_scatter_access = npart_tag_scatter.access(); - npart_tag_scatter_access((int)(this_tag(p))) += 1; + auto npptag_acc = npptag_scat.access(); + if (this_tag(p) < 0 || this_tag(p) >= static_cast(num_tags)) { + raise::KernelError(HERE, "Invalid tag value"); + } + npptag_acc(this_tag(p)) += 1; }); - Kokkos::Experimental::contribute(npart_tag, npart_tag_scatter); + Kokkos::Experimental::contribute(npptag, npptag_scat); + + // copy the count to a vector on the host + auto npptag_h = Kokkos::create_mirror_view(npptag); + Kokkos::deep_copy(npptag_h, npptag); + std::vector npptag_vec(num_tags); + for (auto t { 0u }; t < num_tags; ++t) { + npptag_vec[t] = npptag_h(t); + } - auto npart_tag_host = Kokkos::create_mirror_view(npart_tag); - Kokkos::deep_copy(npart_tag_host, npart_tag); + // count the offsets on the host and copy to 
device + array_t tag_offsets("tag_offsets", num_tags - 3); + auto tag_offsets_h = Kokkos::create_mirror_view(tag_offsets); - std::vector npart_tag_vec; - for (std::size_t t { 0 }; t < ntags(); ++t) { - npart_tag_vec.push_back(npart_tag_host(t)); + tag_offsets_h(0) = npptag_vec[2]; // offset for tag = 3 + for (auto t { 1u }; t < num_tags - 3; ++t) { + tag_offsets_h(t) = npptag_vec[t + 2] + tag_offsets_h(t - 1); } - return npart_tag_vec; + Kokkos::deep_copy(tag_offsets, tag_offsets_h); + + return { npptag_vec, tag_offsets }; + } + + template + void RemoveDeadInArray(array_t& arr, const array_t& indices_alive) { + npart_t n_alive = indices_alive.extent(0); + auto buffer = Kokkos::View("buffer", n_alive); + Kokkos::parallel_for( + "PopulateBufferAlive", + n_alive, + Lambda(index_t p) { buffer(p) = arr(indices_alive(p)); }); + + Kokkos::deep_copy( + Kokkos::subview(arr, std::make_pair(static_cast(0), n_alive)), + buffer); + } + + template + void RemoveDeadInArray(array_t& arr, const array_t& indices_alive) { + npart_t n_alive = indices_alive.extent(0); + auto buffer = array_t { "buffer", n_alive, arr.extent(1) }; + Kokkos::parallel_for( + "PopulateBufferAlive", + CreateRangePolicy({ 0, 0 }, { n_alive, arr.extent(1) }), + Lambda(index_t p, index_t l) { buffer(p, l) = arr(indices_alive(p), l); }); + + Kokkos::deep_copy( + Kokkos::subview(arr, + std::make_pair(static_cast(0), n_alive), + Kokkos::ALL), + buffer); } template - auto Particles::SortByTags() -> std::vector { - if (npart() == 0 || is_sorted()) { - return npart_per_tag(); - } - using KeyType = array_t; - using BinOp = sort::BinTag; - BinOp bin_op(ntags()); - auto slice = range_tuple_t(0, npart()); - Kokkos::BinSort Sorter(Kokkos::subview(tag, slice), bin_op, false); - Sorter.create_permute_vector(); - - Sorter.sort(Kokkos::subview(i1, slice)); - Sorter.sort(Kokkos::subview(dx1, slice)); - Sorter.sort(Kokkos::subview(i1_prev, slice)); - Sorter.sort(Kokkos::subview(dx1_prev, slice)); - 
Sorter.sort(Kokkos::subview(ux1, slice)); - Sorter.sort(Kokkos::subview(ux2, slice)); - Sorter.sort(Kokkos::subview(ux3, slice)); - - Sorter.sort(Kokkos::subview(tag, slice)); - Sorter.sort(Kokkos::subview(weight, slice)); - - for (unsigned short n { 0 }; n < npld(); ++n) { - Sorter.sort(Kokkos::subview(pld[n], slice)); - } + void Particles::RemoveDead() { + npart_t n_alive = 0, n_dead = 0; + auto& this_tag = tag; + + Kokkos::parallel_reduce( + "CountDeadAlive", + rangeActiveParticles(), + Lambda(index_t p, npart_t & nalive, npart_t & ndead) { + nalive += (this_tag(p) == ParticleTag::alive); + ndead += (this_tag(p) == ParticleTag::dead); + if (this_tag(p) != ParticleTag::alive and this_tag(p) != ParticleTag::dead) { + raise::KernelError(HERE, "wrong particle tag"); + } + }, + n_alive, + n_dead); + + array_t indices_alive { "indices_alive", n_alive }; + array_t alive_counter { "counter_alive", 1 }; - if constexpr ((D == Dim::_2D) || (D == Dim::_3D)) { - Sorter.sort(Kokkos::subview(i2, slice)); - Sorter.sort(Kokkos::subview(dx2, slice)); + Kokkos::parallel_for( + "AliveIndices", + rangeActiveParticles(), + Lambda(index_t p) { + if (this_tag(p) == ParticleTag::alive) { + const auto idx = Kokkos::atomic_fetch_add(&alive_counter(0), 1); + indices_alive(idx) = p; + } + }); - Sorter.sort(Kokkos::subview(i2_prev, slice)); - Sorter.sort(Kokkos::subview(dx2_prev, slice)); + { + auto alive_counter_h = Kokkos::create_mirror_view(alive_counter); + Kokkos::deep_copy(alive_counter_h, alive_counter); + raise::ErrorIf(alive_counter_h(0) != n_alive, + "error in finding alive particle indices", + HERE); } - if constexpr (D == Dim::_3D) { - Sorter.sort(Kokkos::subview(i3, slice)); - Sorter.sort(Kokkos::subview(dx3, slice)); - Sorter.sort(Kokkos::subview(i3_prev, slice)); - Sorter.sort(Kokkos::subview(dx3_prev, slice)); + if constexpr (D == Dim::_1D or D == Dim::_2D or D == Dim::_3D) { + RemoveDeadInArray(i1, indices_alive); + RemoveDeadInArray(i1_prev, indices_alive); + 
RemoveDeadInArray(dx1, indices_alive); + RemoveDeadInArray(dx1_prev, indices_alive); } - if ((D == Dim::_2D) && (C != Coord::Cart)) { - Sorter.sort(Kokkos::subview(phi, slice)); + if constexpr (D == Dim::_2D or D == Dim::_3D) { + RemoveDeadInArray(i2, indices_alive); + RemoveDeadInArray(i2_prev, indices_alive); + RemoveDeadInArray(dx2, indices_alive); + RemoveDeadInArray(dx2_prev, indices_alive); } - const auto np_per_tag = npart_per_tag(); - set_npart(np_per_tag[(short)(ParticleTag::alive)]); + if constexpr (D == Dim::_3D) { + RemoveDeadInArray(i3, indices_alive); + RemoveDeadInArray(i3_prev, indices_alive); + RemoveDeadInArray(dx3, indices_alive); + RemoveDeadInArray(dx3_prev, indices_alive); + } - m_is_sorted = true; - return np_per_tag; - } + RemoveDeadInArray(ux1, indices_alive); + RemoveDeadInArray(ux2, indices_alive); + RemoveDeadInArray(ux3, indices_alive); + RemoveDeadInArray(weight, indices_alive); - template - void Particles::SyncHostDevice() { - Kokkos::deep_copy(i1_h, i1); - Kokkos::deep_copy(dx1_h, dx1); - Kokkos::deep_copy(ux1_h, ux1); - Kokkos::deep_copy(ux2_h, ux2); - Kokkos::deep_copy(ux3_h, ux3); - - Kokkos::deep_copy(tag_h, tag); - Kokkos::deep_copy(weight_h, weight); - - for (auto n { 0 }; n < npld(); ++n) { - Kokkos::deep_copy(pld_h[n], pld[n]); + if constexpr (D == Dim::_2D && C != Coord::Cart) { + RemoveDeadInArray(phi, indices_alive); } - if constexpr ((D == Dim::_2D) || (D == Dim::_3D)) { - Kokkos::deep_copy(i2_h, i2); - Kokkos::deep_copy(dx2_h, dx2); - } - if constexpr (D == Dim::_3D) { - Kokkos::deep_copy(i3_h, i3); - Kokkos::deep_copy(dx3_h, dx3); + if (npld() > 0) { + RemoveDeadInArray(pld, indices_alive); } - if ((D == Dim::_2D) && (C != Coord::Cart)) { - Kokkos::deep_copy(phi_h, phi); - } + Kokkos::Experimental::fill( + "TagAliveParticles", + Kokkos::DefaultExecutionSpace(), + Kokkos::subview(this_tag, std::make_pair(static_cast(0), n_alive)), + ParticleTag::alive); + + Kokkos::Experimental::fill( + "TagDeadParticles", + 
Kokkos::DefaultExecutionSpace(), + Kokkos::subview(this_tag, std::make_pair(n_alive, n_alive + n_dead)), + ParticleTag::dead); + + set_npart(n_alive); + m_is_sorted = true; } template struct Particles; diff --git a/src/framework/containers/particles.h b/src/framework/containers/particles.h index b4831b64a..8ff74be33 100644 --- a/src/framework/containers/particles.h +++ b/src/framework/containers/particles.h @@ -37,8 +37,8 @@ namespace ntt { struct Particles : public ParticleSpecies { private: // Number of currently active (used) particles - std::size_t m_npart { 0 }; - bool m_is_sorted { false }; + npart_t m_npart { 0 }; + bool m_is_sorted { false }; #if !defined(MPI_ENABLED) const std::size_t m_ntags { 2 }; @@ -48,31 +48,22 @@ namespace ntt { public: // Cell indices of the current particle - array_t i1, i2, i3; + array_t i1, i2, i3; // Displacement of a particle within the cell - array_t dx1, dx2, dx3; + array_t dx1, dx2, dx3; // Three spatial components of the covariant 4-velocity (physical units) - array_t ux1, ux2, ux3; + array_t ux1, ux2, ux3; // Particle weights. 
- array_t weight; + array_t weight; // Previous timestep coordinates - array_t i1_prev, i2_prev, i3_prev; - array_t dx1_prev, dx2_prev, dx3_prev; + array_t i1_prev, i2_prev, i3_prev; + array_t dx1_prev, dx2_prev, dx3_prev; // Array to tag the particles - array_t tag; - // Array to store the particle load - std::vector> pld; + array_t tag; + // Array to store the particle payloads + array_t pld; // phi coordinate (for axisymmetry) - array_t phi; - - // host mirrors - array_mirror_t i1_h, i2_h, i3_h; - array_mirror_t dx1_h, dx2_h, dx3_h; - array_mirror_t ux1_h, ux2_h, ux3_h; - array_mirror_t weight_h; - array_mirror_t phi_h; - array_mirror_t tag_h; - std::vector> pld_h; + array_t phi; // for empty allocation Particles() {} @@ -89,11 +80,11 @@ namespace ntt { * @param cooling The cooling mechanism assigned for the species * @param npld The number of payloads for the species */ - Particles(unsigned short index, + Particles(spidx_t index, const std::string& label, float m, float ch, - std::size_t maxnpart, + npart_t maxnpart, const PrtlPusher& pusher, bool use_gca, const Cooling& cooling, @@ -125,7 +116,7 @@ namespace ntt { * @returns A 1D Kokkos range policy of size of `npart` */ inline auto rangeActiveParticles() const -> range_t { - return CreateRangePolicy({ 0 }, { npart() }); + return CreateParticleRangePolicy(0u, npart()); } /** @@ -133,7 +124,7 @@ namespace ntt { * @returns A 1D Kokkos range policy of size of `npart` */ inline auto rangeAllParticles() const -> range_t { - return CreateRangePolicy({ 0 }, { maxnpart() }); + return CreateParticleRangePolicy(0u, maxnpart()); } /* getters -------------------------------------------------------------- */ @@ -141,7 +132,7 @@ namespace ntt { * @brief Get the number of active particles */ [[nodiscard]] - auto npart() const -> std::size_t { + auto npart() const -> npart_t { return m_npart; } @@ -178,26 +169,33 @@ namespace ntt { footprint += sizeof(prtldx_t) * dx2_prev.extent(0); footprint += sizeof(prtldx_t) * 
dx3_prev.extent(0); footprint += sizeof(short) * tag.extent(0); - for (auto& p : pld) { - footprint += sizeof(real_t) * p.extent(0); - } - footprint += sizeof(real_t) * phi.extent(0); + footprint += sizeof(real_t) * pld.extent(0) * pld.extent(1); + footprint += sizeof(real_t) * phi.extent(0); return footprint; } /** * @brief Count the number of particles with a specific tag. - * @return The vector of counts for each tag. + * @return The vector of counts for each tag + offsets + * @note For instance, given the counts: 0 -> n0, 1 -> n1, 2 -> n2, 3 -> n3, + * ... it returns: + * ... [n0, n1, n2, n3, ...] of size ntags + * ... [n2, n2 + n3, n2 + n3 + n4, ...] of size ntags - 3 + * ... so in buffer array: + * ... tag=2 particles are offset by 0 + * ... tag=3 particles are offset by n2 + * ... tag=4 particles are offset by n2 + n3 + * ... etc. */ - [[nodiscard]] - auto npart_per_tag() const -> std::vector; + auto NpartsPerTagAndOffsets() const + -> std::pair, array_t>; /* setters -------------------------------------------------------------- */ /** * @brief Set the number of particles - * @param npart The number of particles as a std::size_t + * @param npart The number of particles as a npart_t */ - void set_npart(std::size_t n) { + void set_npart(npart_t n) { raise::ErrorIf( n > maxnpart(), fmt::format( @@ -213,15 +211,16 @@ namespace ntt { } /** - * @brief Sort particles by their tags. - * @return The vector of counts per each tag. + * @brief Move dead particles to the end of arrays */ - auto SortByTags() -> std::vector; + void RemoveDead(); /** * @brief Copy particle data from device to host. 
*/ void SyncHostDevice(); + + // void PrintTags(); }; } // namespace ntt diff --git a/src/framework/containers/species.h b/src/framework/containers/species.h index 1f52733aa..ada0282e2 100644 --- a/src/framework/containers/species.h +++ b/src/framework/containers/species.h @@ -20,15 +20,15 @@ namespace ntt { class ParticleSpecies { protected: // Species index - const unsigned short m_index; + const spidx_t m_index; // Species label - const std::string m_label; + const std::string m_label; // Species mass in units of m0 - const float m_mass; + const float m_mass; // Species charge in units of q0 - const float m_charge; + const float m_charge; // Max number of allocated particles for the species - std::size_t m_maxnpart; + npart_t m_maxnpart; // Pusher assigned for the species const PrtlPusher m_pusher; @@ -44,7 +44,7 @@ namespace ntt { public: ParticleSpecies() - : m_index { 0 } + : m_index { 0u } , m_label { "" } , m_mass { 0.0 } , m_charge { 0.0 } @@ -64,11 +64,11 @@ namespace ntt { * @param maxnpart The maximum number of allocated particles for the species. * @param pusher The pusher assigned for the species. 
*/ - ParticleSpecies(unsigned short index, + ParticleSpecies(spidx_t index, const std::string& label, float m, float ch, - std::size_t maxnpart, + npart_t maxnpart, const PrtlPusher& pusher, bool use_gca, const Cooling& cooling, @@ -91,7 +91,7 @@ namespace ntt { ~ParticleSpecies() = default; [[nodiscard]] - auto index() const -> unsigned short { + auto index() const -> spidx_t { return m_index; } @@ -111,7 +111,7 @@ namespace ntt { } [[nodiscard]] - auto maxnpart() const -> std::size_t { + auto maxnpart() const -> npart_t { return m_maxnpart; } diff --git a/src/framework/domain/checkpoint.cpp b/src/framework/domain/checkpoint.cpp new file mode 100644 index 000000000..e0e34f993 --- /dev/null +++ b/src/framework/domain/checkpoint.cpp @@ -0,0 +1,505 @@ +#include "enums.h" +#include "global.h" + +#include "utils/error.h" +#include "utils/formatting.h" +#include "utils/log.h" + +#include "metrics/kerr_schild.h" +#include "metrics/kerr_schild_0.h" +#include "metrics/minkowski.h" +#include "metrics/qkerr_schild.h" +#include "metrics/qspherical.h" +#include "metrics/spherical.h" + +#include "checkpoint/reader.h" +#include "checkpoint/writer.h" +#include "framework/domain/metadomain.h" +#include "framework/parameters.h" + +namespace ntt { + + template + void Metadomain::InitCheckpointWriter(adios2::ADIOS* ptr_adios, + const SimulationParams& params) { + raise::ErrorIf(ptr_adios == nullptr, "adios == nullptr", HERE); + raise::ErrorIf( + l_subdomain_indices().size() != 1, + "Checkpoint writing for now is only supported for one subdomain per rank", + HERE); + auto local_domain = subdomain_ptr(l_subdomain_indices()[0]); + raise::ErrorIf(local_domain->is_placeholder(), + "local_domain is a placeholder", + HERE); + + std::vector glob_shape_with_ghosts, off_ncells_with_ghosts; + for (auto d { 0u }; d < M::Dim; ++d) { + off_ncells_with_ghosts.push_back( + local_domain->offset_ncells()[d] + + 2 * N_GHOSTS * local_domain->offset_ndomains()[d]); + glob_shape_with_ghosts.push_back( + 
mesh().n_active()[d] + 2 * N_GHOSTS * ndomains_per_dim()[d]); + } + auto loc_shape_with_ghosts = local_domain->mesh.n_all(); + + std::vector nplds; + for (auto s { 0u }; s < local_domain->species.size(); ++s) { + nplds.push_back(local_domain->species[s].npld()); + } + + const path_t checkpoint_root = params.template get( + "checkpoint.write_path"); + + g_checkpoint_writer.init( + ptr_adios, + checkpoint_root, + params.template get("checkpoint.interval"), + params.template get("checkpoint.interval_time"), + params.template get("checkpoint.keep"), + params.template get("checkpoint.walltime")); + if (g_checkpoint_writer.enabled()) { + g_checkpoint_writer.defineFieldVariables(S, + glob_shape_with_ghosts, + off_ncells_with_ghosts, + loc_shape_with_ghosts); + g_checkpoint_writer.defineParticleVariables(M::CoordType, + M::Dim, + local_domain->species.size(), + nplds); + } + } + + template + auto Metadomain::WriteCheckpoint(const SimulationParams& params, + timestep_t current_step, + timestep_t finished_step, + simtime_t current_time, + simtime_t finished_time) -> bool { + raise::ErrorIf( + l_subdomain_indices().size() != 1, + "Checkpointing for now is only supported for one subdomain per rank", + HERE); + if (not g_checkpoint_writer.shouldSave(finished_step, finished_time) or + finished_step <= 1) { + return false; + } + auto local_domain = subdomain_ptr(l_subdomain_indices()[0]); + raise::ErrorIf(local_domain->is_placeholder(), + "local_domain is a placeholder", + HERE); + logger::Checkpoint("Writing checkpoint", HERE); + g_checkpoint_writer.beginSaving(current_step, current_time); + { + g_checkpoint_writer.saveAttrs(params, current_time); + g_checkpoint_writer.saveField("em", local_domain->fields.em); + if constexpr (S == SimEngine::GRPIC) { + g_checkpoint_writer.saveField("em0", local_domain->fields.em0); + g_checkpoint_writer.saveField("cur0", local_domain->fields.cur0); + } + std::size_t dom_offset = 0, dom_tot = 1; +#if defined(MPI_ENABLED) + dom_offset = 
g_mpi_rank; + dom_tot = g_mpi_size; +#endif // MPI_ENABLED + + for (auto s { 0u }; s < local_domain->species.size(); ++s) { + auto npart = local_domain->species[s].npart(); + npart_t offset = 0; + auto glob_tot = npart; +#if defined(MPI_ENABLED) + auto glob_npart = std::vector(g_ndomains); + MPI_Allgather(&npart, + 1, + mpi::get_type(), + glob_npart.data(), + 1, + mpi::get_type(), + MPI_COMM_WORLD); + glob_tot = 0; + for (auto r = 0; r < g_mpi_size; ++r) { + if (r < g_mpi_rank) { + offset += glob_npart[r]; + } + glob_tot += glob_npart[r]; + } +#endif // MPI_ENABLED + g_checkpoint_writer.savePerDomainVariable( + fmt::format("s%d_npart", s + 1), + dom_tot, + dom_offset, + npart); + if constexpr (M::Dim == Dim::_1D or M::Dim == Dim::_2D or + M::Dim == Dim::_3D) { + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_i1", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].i1); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_dx1", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].dx1); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_i1_prev", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].i1_prev); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_dx1_prev", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].dx1_prev); + } + if constexpr (M::Dim == Dim::_2D or M::Dim == Dim::_3D) { + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_i2", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].i2); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_dx2", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].dx2); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_i2_prev", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].i2_prev); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_dx2_prev", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].dx2_prev); + } 
+ if constexpr (M::Dim == Dim::_3D) { + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_i3", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].i3); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_dx3", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].dx3); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_i3_prev", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].i3_prev); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_dx3_prev", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].dx3_prev); + } + if constexpr (M::Dim == Dim::_2D and M::CoordType != Coord::Cart) { + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_phi", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].phi); + } + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_ux1", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].ux1); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_ux2", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].ux2); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_ux3", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].ux3); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_tag", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].tag); + g_checkpoint_writer.saveParticleQuantity( + fmt::format("s%d_weight", s + 1), + glob_tot, + offset, + npart, + local_domain->species[s].weight); + + auto nplds = local_domain->species[s].npld(); + if (nplds > 0) { + g_checkpoint_writer.saveParticlePayloads(fmt::format("s%d_plds", s + 1), + nplds, + glob_tot, + offset, + npart, + local_domain->species[s].pld); + } + } + } + g_checkpoint_writer.endSaving(); + logger::Checkpoint("Checkpoint written", HERE); + return true; + } + + template + void Metadomain::ContinueFromCheckpoint(adios2::ADIOS* ptr_adios, + const SimulationParams& params) { + 
raise::ErrorIf(ptr_adios == nullptr, "adios == nullptr", HERE); + const path_t checkpoint_root = params.template get( + "checkpoint.read_path"); + const auto fname = checkpoint_root / + fmt::format("step-%08lu.bp", + params.template get( + "checkpoint.start_step")); + logger::Checkpoint(fmt::format("Reading checkpoint from %s", fname.c_str()), + HERE); + + adios2::IO io = ptr_adios->DeclareIO("Entity::CheckpointRead"); + io.SetEngine("BPFile"); +#if !defined(MPI_ENABLED) + adios2::Engine reader = io.Open(fname, adios2::Mode::Read); +#else + adios2::Engine reader = io.Open(fname, adios2::Mode::Read, MPI_COMM_SELF); +#endif + + reader.BeginStep(); + for (auto& ldidx : l_subdomain_indices()) { + auto& domain = g_subdomains[ldidx]; + adios2::Box range; + for (auto d { 0u }; d < M::Dim; ++d) { + range.first.push_back(domain.offset_ncells()[d] + + 2 * N_GHOSTS * domain.offset_ndomains()[d]); + range.second.push_back(domain.mesh.n_all()[d]); + } + range.first.push_back(0); + range.second.push_back(6); + checkpoint::ReadFields(io, reader, "em", range, domain.fields.em); + if constexpr (S == ntt::SimEngine::GRPIC) { + checkpoint::ReadFields(io, + reader, + "em0", + range, + domain.fields.em0); + adios2::Box range3; + for (auto d { 0u }; d < M::Dim; ++d) { + range3.first.push_back(domain.offset_ncells()[d] + + 2 * N_GHOSTS * domain.offset_ndomains()[d]); + range3.second.push_back(domain.mesh.n_all()[d]); + } + range3.first.push_back(0); + range3.second.push_back(3); + checkpoint::ReadFields(io, + reader, + "cur0", + range3, + domain.fields.cur0); + } + for (auto s { 0u }; s < domain.species.size(); ++s) { + const auto [loc_npart, offset_npart] = + checkpoint::ReadParticleCount(io, reader, s, ldidx, ndomains()); + raise::ErrorIf(loc_npart > domain.species[s].maxnpart(), + "loc_npart > domain.species[s].maxnpart()", + HERE); + if (loc_npart == 0) { + continue; + } + if constexpr (M::Dim == Dim::_1D or M::Dim == Dim::_2D or + M::Dim == Dim::_3D) { + 
checkpoint::ReadParticleData(io, + reader, + "i1", + s, + domain.species[s].i1, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "dx1", + s, + domain.species[s].dx1, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "i1_prev", + s, + domain.species[s].i1_prev, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "dx1_prev", + s, + domain.species[s].dx1_prev, + loc_npart, + offset_npart); + } + if constexpr (M::Dim == Dim::_2D or M::Dim == Dim::_3D) { + checkpoint::ReadParticleData(io, + reader, + "i2", + s, + domain.species[s].i2, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "dx2", + s, + domain.species[s].dx2, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "i2_prev", + s, + domain.species[s].i2_prev, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "dx2_prev", + s, + domain.species[s].dx2_prev, + loc_npart, + offset_npart); + } + if constexpr (M::Dim == Dim::_3D) { + checkpoint::ReadParticleData(io, + reader, + "i3", + s, + domain.species[s].i3, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "dx3", + s, + domain.species[s].dx3, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "i3_prev", + s, + domain.species[s].i3_prev, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "dx3_prev", + s, + domain.species[s].dx3_prev, + loc_npart, + offset_npart); + } + if constexpr (M::Dim == Dim::_2D and M::CoordType != Coord::Cart) { + checkpoint::ReadParticleData(io, + reader, + "phi", + s, + domain.species[s].phi, + loc_npart, + offset_npart); + } + checkpoint::ReadParticleData(io, + reader, + "ux1", + s, + domain.species[s].ux1, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "ux2", + s, + domain.species[s].ux2, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "ux3", + s, 
+ domain.species[s].ux3, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "tag", + s, + domain.species[s].tag, + loc_npart, + offset_npart); + checkpoint::ReadParticleData(io, + reader, + "weight", + s, + domain.species[s].weight, + loc_npart, + offset_npart); + + const auto nplds = domain.species[s].npld(); + if (nplds > 0) { + checkpoint::ReadParticlePayloads(io, + reader, + s, + domain.species[s].pld, + nplds, + loc_npart, + offset_npart); + } + domain.species[s].set_npart(loc_npart); + } // species loop + + } // local subdomain loop + + reader.EndStep(); + reader.Close(); + logger::Checkpoint( + fmt::format("Checkpoint reading done from %s", fname.c_str()), + HERE); + } + +#define METADOMAIN_CHECKPOINTS(S, M) \ + template void Metadomain::InitCheckpointWriter(adios2::ADIOS*, \ + const SimulationParams&); \ + template auto Metadomain::WriteCheckpoint(const SimulationParams&, \ + timestep_t, \ + timestep_t, \ + simtime_t, \ + simtime_t) -> bool; \ + template void Metadomain::ContinueFromCheckpoint(adios2::ADIOS*, \ + const SimulationParams&); + METADOMAIN_CHECKPOINTS(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_CHECKPOINTS(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_CHECKPOINTS(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_CHECKPOINTS(SimEngine::SRPIC, metric::Spherical) + METADOMAIN_CHECKPOINTS(SimEngine::SRPIC, metric::QSpherical) + METADOMAIN_CHECKPOINTS(SimEngine::GRPIC, metric::KerrSchild) + METADOMAIN_CHECKPOINTS(SimEngine::GRPIC, metric::QKerrSchild) + METADOMAIN_CHECKPOINTS(SimEngine::GRPIC, metric::KerrSchild0) +#undef METADOMAIN_CHECKPOINTS + +} // namespace ntt diff --git a/src/framework/domain/comm_mpi.hpp b/src/framework/domain/comm_mpi.hpp index 63dd8271a..e0d0cb4b2 100644 --- a/src/framework/domain/comm_mpi.hpp +++ b/src/framework/domain/comm_mpi.hpp @@ -14,20 +14,233 @@ #include "enums.h" #include "global.h" +#include "arch/directions.h" #include "arch/kokkos_aliases.h" #include "arch/mpi_aliases.h" 
+#include "arch/mpi_tags.h" #include "utils/error.h" #include "framework/containers/particles.h" +#include "kernels/comm.hpp" + #include #include +#include #include namespace comm { using namespace ntt; + namespace flds { + template + void send_recv(ndarray_t& send_arr, + ndarray_t& recv_arr, + int send_rank, + int recv_rank, + ncells_t nsend, + ncells_t nrecv) { +#if !defined(DEVICE_ENABLED) || defined(GPU_AWARE_MPI) + MPI_Sendrecv(send_arr.data(), + nsend, + mpi::get_type(), + send_rank, + 0, + recv_arr.data(), + nrecv, + mpi::get_type(), + recv_rank, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); +#else + auto send_arr_h = Kokkos::create_mirror_view(send_arr); + auto recv_arr_h = Kokkos::create_mirror_view(recv_arr); + Kokkos::deep_copy(send_arr_h, send_arr); + MPI_Sendrecv(send_arr_h.data(), + nsend, + mpi::get_type(), + send_rank, + 0, + recv_arr_h.data(), + nrecv, + mpi::get_type(), + recv_rank, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + Kokkos::deep_copy(recv_arr, recv_arr_h); +#endif + } + + template + void send(ndarray_t& send_arr, int send_rank, ncells_t nsend) { +#if !defined(DEVICE_ENABLED) || defined(GPU_AWARE_MPI) + MPI_Send(send_arr.data(), nsend, mpi::get_type(), send_rank, 0, MPI_COMM_WORLD); +#else + auto send_arr_h = Kokkos::create_mirror_view(send_arr); + Kokkos::deep_copy(send_arr_h, send_arr); + MPI_Send(send_arr_h.data(), + nsend, + mpi::get_type(), + send_rank, + 0, + MPI_COMM_WORLD); +#endif + } + + template + void recv(ndarray_t& recv_arr, int recv_rank, ncells_t nrecv) { +#if !defined(DEVICE_ENABLED) || defined(GPU_AWARE_MPI) + MPI_Recv(recv_arr.data(), + nrecv, + mpi::get_type(), + recv_rank, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); +#else + auto recv_arr_h = Kokkos::create_mirror_view(recv_arr); + MPI_Recv(recv_arr_h.data(), + nrecv, + mpi::get_type(), + recv_rank, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + Kokkos::deep_copy(recv_arr, recv_arr_h); +#endif + } + + template + void communicate(ndarray_t& send_arr, + ndarray_t& 
recv_arr, + int send_rank, + int recv_rank, + ncells_t nsend, + ncells_t nrecv) { + if (send_rank >= 0 and recv_rank >= 0 and nsend > 0 and nrecv > 0) { + send_recv(send_arr, recv_arr, send_rank, recv_rank, nsend, nrecv); + } else if (send_rank >= 0 and nsend > 0) { + send(send_arr, send_rank, nsend); + } else if (recv_rank >= 0 and nrecv > 0) { + recv(recv_arr, recv_rank, nrecv); + } + } + + } // namespace flds + + namespace prtls { + template + void send_recv(array_t& send_arr, + array_t& recv_arr, + int send_rank, + int recv_rank, + npart_t nsend, + npart_t nrecv, + npart_t offset) { +#if !defined(DEVICE_ENABLED) || defined(GPU_AWARE_MPI) + MPI_Sendrecv(send_arr.data(), + nsend, + mpi::get_type(), + send_rank, + 0, + recv_arr.data() + offset, + nrecv, + mpi::get_type(), + recv_rank, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); +#else + const auto slice = std::make_pair(offset, offset + nrecv); + + auto send_arr_h = Kokkos::create_mirror_view(send_arr); + auto recv_arr_h = Kokkos::create_mirror_view( + Kokkos::subview(recv_arr, slice)); + Kokkos::deep_copy(send_arr_h, send_arr); + MPI_Sendrecv(send_arr_h.data(), + nsend, + mpi::get_type(), + send_rank, + 0, + recv_arr_h.data(), + nrecv, + mpi::get_type(), + recv_rank, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + Kokkos::deep_copy(Kokkos::subview(recv_arr, slice), recv_arr_h); +#endif + } + + template + void send(array_t& send_arr, int send_rank, npart_t nsend) { +#if !defined(DEVICE_ENABLED) || defined(GPU_AWARE_MPI) + MPI_Send(send_arr.data(), nsend, mpi::get_type(), send_rank, 0, MPI_COMM_WORLD); +#else + auto send_arr_h = Kokkos::create_mirror_view(send_arr); + Kokkos::deep_copy(send_arr_h, send_arr); + MPI_Send(send_arr_h.data(), nsend, mpi::get_type(), send_rank, 0, MPI_COMM_WORLD); +#endif + } + + template + void recv(array_t& recv_arr, int recv_rank, npart_t nrecv, npart_t offset) { +#if !defined(DEVICE_ENABLED) || defined(GPU_AWARE_MPI) + MPI_Recv(recv_arr.data() + offset, + nrecv, + mpi::get_type(), + 
recv_rank, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); +#else + const auto slice = std::make_pair(offset, offset + nrecv); + + auto recv_arr_h = Kokkos::create_mirror_view( + Kokkos::subview(recv_arr, slice)); + MPI_Recv(recv_arr_h.data(), + nrecv, + mpi::get_type(), + recv_rank, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + Kokkos::deep_copy(Kokkos::subview(recv_arr, slice), recv_arr_h); +#endif + } + + template + void communicate(array_t& send_arr, + array_t& recv_arr, + int send_rank, + int recv_rank, + npart_t nsend, + npart_t nrecv, + npart_t offset) { + if (send_rank >= 0 && recv_rank >= 0) { + raise::ErrorIf( + nrecv + offset > recv_arr.extent(0), + "recv_arr is not large enough to hold the received particles", + HERE); + send_recv(send_arr, recv_arr, send_rank, recv_rank, nsend, nrecv, offset); + } else if (send_rank >= 0) { + send(send_arr, send_rank, nsend); + } else if (recv_rank >= 0) { + raise::ErrorIf( + nrecv + offset > recv_arr.extent(0), + "recv_arr is not large enough to hold the received particles", + HERE); + recv(recv_arr, recv_rank, nrecv, offset); + } else { + raise::Error("CommunicateParticles called with negative ranks", HERE); + } + } + } // namespace prtls + template inline void CommunicateField(unsigned int idx, ndfield_t& fld, @@ -52,10 +265,11 @@ namespace comm { (recv_rank == rank && recv_idx != idx), "Multiple-domain single-rank communication not yet implemented", HERE); - if ((send_idx == idx) and (recv_idx == idx)) { // trivial copy if sending to self and receiving from self + if (not additive) { + // simply filling the ghost cells if constexpr (D == Dim::_1D) { Kokkos::deep_copy(Kokkos::subview(fld, recv_slice[0], comps), @@ -65,6 +279,7 @@ namespace comm { Kokkos::subview(fld, recv_slice[0], recv_slice[1], comps), Kokkos::subview(fld, send_slice[0], send_slice[1], comps)); } else if constexpr (D == Dim::_3D) { + Kokkos::deep_copy( Kokkos::subview(fld, recv_slice[0], recv_slice[1], recv_slice[2], comps), Kokkos::subview(fld, 
send_slice[0], send_slice[1], send_slice[2], comps)); @@ -76,7 +291,7 @@ namespace comm { (long int)(send_slice[0].first); Kokkos::parallel_for( "CommunicateField-extract", - Kokkos::MDRangePolicy, AccelExeSpace>( + Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { recv_slice[0].first, comps.first }, { recv_slice[0].second, comps.second }), Lambda(index_t i1, index_t ci) { @@ -89,7 +304,7 @@ namespace comm { (long int)(send_slice[1].first); Kokkos::parallel_for( "CommunicateField-extract", - Kokkos::MDRangePolicy, AccelExeSpace>( + Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { recv_slice[0].first, recv_slice[1].first, comps.first }, { recv_slice[0].second, recv_slice[1].second, comps.second }), Lambda(index_t i1, index_t i2, index_t ci) { @@ -104,7 +319,7 @@ namespace comm { (long int)(send_slice[2].first); Kokkos::parallel_for( "CommunicateField-extract", - Kokkos::MDRangePolicy, AccelExeSpace>( + Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { recv_slice[0].first, recv_slice[1].first, recv_slice[2].first, @@ -122,9 +337,9 @@ namespace comm { } } } else { - std::size_t nsend { comps.second - comps.first }, + ncells_t nsend { comps.second - comps.first }, nrecv { comps.second - comps.first }; - ndarray_t(D) + 1> send_fld, recv_fld; + ndarray_t(D) + 1> send_fld, recv_fld; for (short d { 0 }; d < (short)D; ++d) { if (send_rank >= 0) { @@ -177,39 +392,16 @@ namespace comm { comps.second - comps.first); } } - if (send_rank >= 0 && recv_rank >= 0) { - MPI_Sendrecv(send_fld.data(), - nsend, - mpi::get_type(), - send_rank, - 0, - recv_fld.data(), - nrecv, - mpi::get_type(), - recv_rank, - 0, - MPI_COMM_WORLD, - MPI_STATUS_IGNORE); - } else if (send_rank >= 0) { - MPI_Send(send_fld.data(), - nsend, - mpi::get_type(), - send_rank, - 0, - MPI_COMM_WORLD); - } else if (recv_rank >= 0) { - MPI_Recv(recv_fld.data(), - nrecv, - mpi::get_type(), - recv_rank, - 0, - MPI_COMM_WORLD, - MPI_STATUS_IGNORE); - } else { - raise::Error("CommunicateField called 
with negative ranks", HERE); - } + + flds::communicate(D) + 1>(send_fld, + recv_fld, + send_rank, + recv_rank, + nsend, + nrecv); + if (recv_rank >= 0) { - // !TODO: perhaps directly recv to the fld? + if (not additive) { if constexpr (D == Dim::_1D) { Kokkos::deep_copy(Kokkos::subview(fld, recv_slice[0], comps), recv_fld); @@ -228,7 +420,7 @@ namespace comm { const auto offset_c = comps.first; Kokkos::parallel_for( "CommunicateField-extract", - Kokkos::MDRangePolicy, AccelExeSpace>( + Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { recv_slice[0].first, comps.first }, { recv_slice[0].second, comps.second }), Lambda(index_t i1, index_t ci) { @@ -240,7 +432,7 @@ namespace comm { const auto offset_c = comps.first; Kokkos::parallel_for( "CommunicateField-extract", - Kokkos::MDRangePolicy, AccelExeSpace>( + Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { recv_slice[0].first, recv_slice[1].first, comps.first }, { recv_slice[0].second, recv_slice[1].second, comps.second }), Lambda(index_t i1, index_t i2, index_t ci) { @@ -255,7 +447,7 @@ namespace comm { const auto offset_c = comps.first; Kokkos::parallel_for( "CommunicateField-extract", - Kokkos::MDRangePolicy, AccelExeSpace>( + Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { recv_slice[0].first, recv_slice[1].first, recv_slice[2].first, @@ -276,69 +468,29 @@ namespace comm { } } - template - void CommunicateParticleQuantity(array_t& arr, - int send_rank, - int recv_rank, - const range_tuple_t& send_slice, - const range_tuple_t& recv_slice) { - const std::size_t send_count = send_slice.second - send_slice.first; - const std::size_t recv_count = recv_slice.second - recv_slice.first; - if ((send_rank >= 0) and (recv_rank >= 0) and (send_count > 0) and - (recv_count > 0)) { - MPI_Sendrecv(arr.data() + send_slice.first, - send_count, - mpi::get_type(), - send_rank, - 0, - arr.data() + recv_slice.first, - recv_count, - mpi::get_type(), - recv_rank, - 0, - MPI_COMM_WORLD, - MPI_STATUS_IGNORE); - } 
else if ((send_rank >= 0) and (send_count > 0)) { - MPI_Send(arr.data() + send_slice.first, - send_count, - mpi::get_type(), - send_rank, - 0, - MPI_COMM_WORLD); - } else if ((recv_rank >= 0) and (recv_count > 0)) { - MPI_Recv(arr.data() + recv_slice.first, - recv_count, - mpi::get_type(), - recv_rank, - 0, - MPI_COMM_WORLD, - MPI_STATUS_IGNORE); - } - } - - void ParticleSendRecvCount(int send_rank, - int recv_rank, - const std::size_t& send_count, - std::size_t& recv_count) { + void ParticleSendRecvCount(int send_rank, + int recv_rank, + npart_t send_count, + npart_t& recv_count) { if ((send_rank >= 0) && (recv_rank >= 0)) { MPI_Sendrecv(&send_count, 1, - mpi::get_type(), + mpi::get_type(), send_rank, 0, &recv_count, 1, - mpi::get_type(), + mpi::get_type(), recv_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); } else if (send_rank >= 0) { - MPI_Send(&send_count, 1, mpi::get_type(), send_rank, 0, MPI_COMM_WORLD); + MPI_Send(&send_count, 1, mpi::get_type(), send_rank, 0, MPI_COMM_WORLD); } else if (recv_rank >= 0) { MPI_Recv(&recv_count, 1, - mpi::get_type(), + mpi::get_type(), recv_rank, 0, MPI_COMM_WORLD, @@ -349,96 +501,142 @@ namespace comm { } template - auto CommunicateParticles(Particles& species, - int send_rank, - int recv_rank, - const range_tuple_t& send_slice, - std::size_t& index_last) -> std::size_t { - if ((send_rank < 0) && (recv_rank < 0)) { - raise::Error("No send or recv in CommunicateParticles", HERE); - } - std::size_t recv_count { 0 }; - ParticleSendRecvCount(send_rank, - recv_rank, - send_slice.second - send_slice.first, - recv_count); - - raise::FatalIf((index_last + recv_count) >= species.maxnpart(), - "Too many particles to receive (cannot fit into maxptl)", - HERE); - const auto recv_slice = range_tuple_t({ index_last, index_last + recv_count }); - - CommunicateParticleQuantity(species.i1, send_rank, recv_rank, send_slice, recv_slice); - CommunicateParticleQuantity(species.dx1, send_rank, recv_rank, send_slice, recv_slice); - 
CommunicateParticleQuantity(species.i1_prev, - send_rank, - recv_rank, - send_slice, - recv_slice); - CommunicateParticleQuantity(species.dx1_prev, - send_rank, - recv_rank, - send_slice, - recv_slice); - if constexpr (D == Dim::_2D || D == Dim::_3D) { - CommunicateParticleQuantity(species.i2, send_rank, recv_rank, send_slice, recv_slice); - CommunicateParticleQuantity(species.dx2, - send_rank, - recv_rank, - send_slice, - recv_slice); - CommunicateParticleQuantity(species.i2_prev, - send_rank, - recv_rank, - send_slice, - recv_slice); - CommunicateParticleQuantity(species.dx2_prev, - send_rank, - recv_rank, - send_slice, - recv_slice); - } - if constexpr (D == Dim::_3D) { - CommunicateParticleQuantity(species.i3, send_rank, recv_rank, send_slice, recv_slice); - CommunicateParticleQuantity(species.dx3, - send_rank, - recv_rank, - send_slice, - recv_slice); - CommunicateParticleQuantity(species.i3_prev, - send_rank, - recv_rank, - send_slice, - recv_slice); - CommunicateParticleQuantity(species.dx3_prev, - send_rank, - recv_rank, - send_slice, - recv_slice); - } - CommunicateParticleQuantity(species.ux1, send_rank, recv_rank, send_slice, recv_slice); - CommunicateParticleQuantity(species.ux2, send_rank, recv_rank, send_slice, recv_slice); - CommunicateParticleQuantity(species.ux3, send_rank, recv_rank, send_slice, recv_slice); - CommunicateParticleQuantity(species.weight, - send_rank, - recv_rank, - send_slice, - recv_slice); - if constexpr (D == Dim::_2D and C != Coord::Cart) { - CommunicateParticleQuantity(species.phi, - send_rank, - recv_rank, - send_slice, - recv_slice); + void CommunicateParticles(Particles& species, + const array_t& outgoing_indices, + const array_t& tag_offsets, + const std::vector& npptag_vec, + const std::vector& npptag_recv_vec, + const std::vector& send_ranks, + const std::vector& recv_ranks, + const dir::dirs_t& dirs_to_comm) { + // number of arrays of each type to send/recv + const unsigned short NREALS = 4 + static_cast( + D == 
Dim::_2D and C != Coord::Cart); + const unsigned short NINTS = 2 * static_cast(D); + const unsigned short NPRTLDX = 2 * static_cast(D); + const unsigned short NPLDS = species.npld(); + + // buffers to store recv data + const auto npart_dead = npptag_vec[ParticleTag::dead]; + const auto npart_recv = std::accumulate(npptag_recv_vec.begin(), + npptag_recv_vec.end(), + static_cast(0)); + array_t recv_buff_int { "recv_buff_int", npart_recv * NINTS }; + array_t recv_buff_real { "recv_buff_real", npart_recv * NREALS }; + array_t recv_buff_prtldx { "recv_buff_prtldx", npart_recv * NPRTLDX }; + array_t recv_buff_pld; + + if (NPLDS > 0) { + recv_buff_pld = array_t { "recv_buff_pld", npart_recv * NPLDS }; } - for (auto p { 0 }; p < species.npld(); ++p) { - CommunicateParticleQuantity(species.pld[p], - send_rank, - recv_rank, - send_slice, - recv_slice); + + auto iteration = 0; + auto current_received = 0; + + for (const auto& direction : dirs_to_comm) { + const auto send_rank = send_ranks[iteration]; + const auto recv_rank = recv_ranks[iteration]; + const auto tag_send = mpi::PrtlSendTag::dir2tag(direction); + const auto tag_recv = mpi::PrtlSendTag::dir2tag(-direction); + const auto npart_send_in = npptag_vec[tag_send]; + const auto npart_recv_in = npptag_recv_vec[tag_recv - 2]; + if (send_rank < 0 and recv_rank < 0) { + continue; + } + array_t send_buff_int { "send_buff_int", npart_send_in * NINTS }; + array_t send_buff_real { "send_buff_real", npart_send_in * NREALS }; + array_t send_buff_prtldx { "send_buff_prtldx", + npart_send_in * NPRTLDX }; + array_t send_buff_pld; + if (NPLDS > 0) { + send_buff_pld = array_t { "send_buff_pld", npart_send_in * NPLDS }; + } + + auto tag_offsets_h = Kokkos::create_mirror_view(tag_offsets); + Kokkos::deep_copy(tag_offsets_h, tag_offsets); + + npart_t idx_offset = npart_dead; + if (tag_send > 2) { + idx_offset += tag_offsets_h(tag_send - 3); + } + // clang-format off + Kokkos::parallel_for( + "PopulatePrtlSendBuffer", + npart_send_in, + 
kernel::comm::PopulatePrtlSendBuffer_kernel( + send_buff_int, send_buff_real, send_buff_prtldx, send_buff_pld, + NINTS, NREALS, NPRTLDX, NPLDS, idx_offset, + species.i1, species.i1_prev, species.dx1, species.dx1_prev, + species.i2, species.i2_prev, species.dx2, species.dx2_prev, + species.i3, species.i3_prev, species.dx3, species.dx3_prev, + species.ux1, species.ux2, species.ux3, + species.weight, species.phi, species.pld, species.tag, + outgoing_indices) + ); + // clang-format on + + const auto recv_offset_int = current_received * NINTS; + const auto recv_offset_real = current_received * NREALS; + const auto recv_offset_prtldx = current_received * NPRTLDX; + const auto recv_offset_pld = current_received * NPLDS; + + prtls::communicate(send_buff_int, + recv_buff_int, + send_rank, + recv_rank, + npart_send_in * NINTS, + npart_recv_in * NINTS, + recv_offset_int); + prtls::communicate(send_buff_real, + recv_buff_real, + send_rank, + recv_rank, + npart_send_in * NREALS, + npart_recv_in * NREALS, + recv_offset_real); + prtls::communicate(send_buff_prtldx, + recv_buff_prtldx, + send_rank, + recv_rank, + npart_send_in * NPRTLDX, + npart_recv_in * NPRTLDX, + recv_offset_prtldx); + if (NPLDS > 0) { + prtls::communicate(send_buff_pld, + recv_buff_pld, + send_rank, + recv_rank, + npart_send_in * NPLDS, + npart_recv_in * NPLDS, + recv_offset_pld); + } + current_received += npart_recv_in; + iteration++; + + } // end direction loop + + // clang-format off + Kokkos::parallel_for( + "PopulateFromRecvBuffer", + npart_recv, + kernel::comm::ExtractReceivedPrtls_kernel( + recv_buff_int, recv_buff_real, recv_buff_prtldx, recv_buff_pld, + NINTS, NREALS, NPRTLDX, NPLDS, + species.npart(), + species.i1, species.i1_prev, species.dx1, species.dx1_prev, + species.i2, species.i2_prev, species.dx2, species.dx2_prev, + species.i3, species.i3_prev, species.dx3, species.dx3_prev, + species.ux1, species.ux2, species.ux3, + species.weight, species.phi, species.pld, species.tag, + outgoing_indices) 
+ ); + // clang-format on + + const auto npart = species.npart(); + const auto npart_holes = outgoing_indices.extent(0); + if (npart_recv > npart_holes) { + species.set_npart(npart + npart_recv - npart_holes); } - return recv_count; } } // namespace comm diff --git a/src/framework/domain/comm_nompi.hpp b/src/framework/domain/comm_nompi.hpp index 197d336fa..b477ac176 100644 --- a/src/framework/domain/comm_nompi.hpp +++ b/src/framework/domain/comm_nompi.hpp @@ -70,7 +70,7 @@ namespace comm { (long int)(send_slice[0].first); Kokkos::parallel_for( "CommunicateField-extract", - Kokkos::MDRangePolicy, AccelExeSpace>( + Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { recv_slice[0].first, comps.first }, { recv_slice[0].second, comps.second }), Lambda(index_t i1, index_t ci) { @@ -83,7 +83,7 @@ namespace comm { (long int)(send_slice[1].first); Kokkos::parallel_for( "CommunicateField-extract", - Kokkos::MDRangePolicy, AccelExeSpace>( + Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { recv_slice[0].first, recv_slice[1].first, comps.first }, { recv_slice[0].second, recv_slice[1].second, comps.second }), Lambda(index_t i1, index_t i2, index_t ci) { @@ -98,7 +98,7 @@ namespace comm { (long int)(send_slice[2].first); Kokkos::parallel_for( "CommunicateField-extract", - Kokkos::MDRangePolicy, AccelExeSpace>( + Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { recv_slice[0].first, recv_slice[1].first, recv_slice[2].first, diff --git a/src/framework/domain/communications.cpp b/src/framework/domain/communications.cpp index 60524eedd..bf6eb3dd1 100644 --- a/src/framework/domain/communications.cpp +++ b/src/framework/domain/communications.cpp @@ -20,10 +20,13 @@ #include "arch/mpi_tags.h" #include "framework/domain/comm_mpi.hpp" + #include "kernels/comm.hpp" #else #include "framework/domain/comm_nompi.hpp" #endif +#include + #include #include @@ -33,10 +36,10 @@ namespace ntt { using comm_params_t = std::pair>; template - auto GetSendRecvRanks(Metadomain* 
metadomain, - Domain& domain, - dir::direction_t direction) - -> std::pair { + auto GetSendRecvRanks( + Metadomain* metadomain, + Domain& domain, + dir::direction_t direction) -> std::pair { Domain* send_to_nghbr_ptr = nullptr; Domain* recv_from_nghbr_ptr = nullptr; // set pointers to the correct send/recv domains @@ -86,8 +89,8 @@ namespace ntt { } else { // no communication necessary return { - {0, -1}, - {0, -1} + { 0, -1 }, + { 0, -1 } }; } #if defined(MPI_ENABLED) @@ -110,17 +113,17 @@ namespace ntt { (void)send_rank; (void)recv_rank; return { - {send_ind, send_rank}, - {recv_ind, recv_rank} + { send_ind, send_rank }, + { recv_ind, recv_rank } }; } template - auto GetSendRecvParams(Metadomain* metadomain, - Domain& domain, - dir::direction_t direction, - bool synchronize) - -> std::pair { + auto GetSendRecvParams( + Metadomain* metadomain, + Domain& domain, + dir::direction_t direction, + bool synchronize) -> std::pair { const auto [send_indrank, recv_indrank] = GetSendRecvRanks(metadomain, domain, direction); const auto [send_ind, send_rank] = send_indrank; @@ -129,15 +132,15 @@ namespace ntt { const auto is_receiving = (recv_rank >= 0); if (not(is_sending or is_receiving)) { return { - {{ 0, -1 }, {}}, - {{ 0, -1 }, {}} + { { 0, -1 }, {} }, + { { 0, -1 }, {} } }; } auto send_slice = std::vector {}; auto recv_slice = std::vector {}; const in components[] = { in::x1, in::x2, in::x3 }; // find the field components and indices to be sent/received - for (std::size_t d { 0 }; d < direction.size(); ++d) { + for (auto d { 0u }; d < direction.size(); ++d) { const auto c = components[d]; const auto dir = direction[d]; if (not synchronize) { @@ -196,20 +199,29 @@ namespace ntt { } return { - {{ send_ind, send_rank }, send_slice}, - {{ recv_ind, recv_rank }, recv_slice}, + { { send_ind, send_rank }, send_slice }, + { { recv_ind, recv_rank }, recv_slice }, }; } template void Metadomain::CommunicateFields(Domain& domain, CommTags tags) { - const auto comm_fields = (tags & 
Comm::E) || (tags & Comm::B) || - (tags & Comm::J) || (tags & Comm::D) || - (tags & Comm::D0) || (tags & Comm::B0); - const bool comm_em = (tags & Comm::E) || (tags & Comm::B) || (tags & Comm::D); - const bool comm_em0 = (tags & Comm::B0) || (tags & Comm::D0); + // const auto comm_fields = (tags & Comm::E) or (tags & Comm::B) or + // (tags & Comm::J) or (tags & Comm::D) or + // (tags & Comm::D0) or (tags & Comm::B0) or + // (tags & Comm::H); + const auto comm_em = ((S == SimEngine::SRPIC) and + ((tags & Comm::E) or (tags & Comm::B))) or + ((S == SimEngine::GRPIC) and + ((tags & Comm::D) or (tags & Comm::B))); + const bool comm_em0 = (S == SimEngine::GRPIC) and + ((tags & Comm::B0) or (tags & Comm::D0)); const bool comm_j = (tags & Comm::J); - raise::ErrorIf(not comm_fields, "CommunicateFields called with no task", HERE); + const bool comm_aux = (S == SimEngine::GRPIC) and + ((tags & Comm::E) or (tags & Comm::H)); + raise::ErrorIf(not(comm_em or comm_em0 or comm_j or comm_aux), + "CommunicateFields called with no task", + HERE); std::string comms = ""; if (tags & Comm::E) { @@ -224,6 +236,9 @@ namespace ntt { if (tags & Comm::D) { comms += "D "; } + if (tags & Comm::H) { + comms += "H "; + } if (tags & Comm::D0) { comms += "D0 "; } @@ -240,16 +255,17 @@ namespace ntt { auto comp_range_fld = range_tuple_t {}; auto comp_range_cur = range_tuple_t {}; if constexpr (S == SimEngine::GRPIC) { - if (((tags & Comm::D) && (tags & Comm::B)) || - ((tags & Comm::D0) && (tags & Comm::B0))) { + if (((tags & Comm::D) and (tags & Comm::B)) or + ((tags & Comm::D0) and (tags & Comm::B0)) or + ((tags & Comm::E) and (tags & Comm::H))) { comp_range_fld = range_tuple_t(em::dx1, em::bx3 + 1); - } else if ((tags & Comm::D) || (tags & Comm::D0)) { + } else if ((tags & Comm::D) or (tags & Comm::D0) or (tags & Comm::E)) { comp_range_fld = range_tuple_t(em::dx1, em::dx3 + 1); - } else if ((tags & Comm::B) || (tags & Comm::B0)) { + } else if ((tags & Comm::B) or (tags & Comm::B0) or (tags & 
Comm::H)) { comp_range_fld = range_tuple_t(em::bx1, em::bx3 + 1); } } else if constexpr (S == SimEngine::SRPIC) { - if ((tags & Comm::E) && (tags & Comm::B)) { + if ((tags & Comm::E) and (tags & Comm::B)) { comp_range_fld = range_tuple_t(em::ex1, em::bx3 + 1); } else if (tags & Comm::E) { comp_range_fld = range_tuple_t(em::ex1, em::ex3 + 1); @@ -287,6 +303,19 @@ namespace ntt { false); } if constexpr (S == SimEngine::GRPIC) { + if (comm_aux) { + comm::CommunicateField(domain.index(), + domain.fields.aux, + domain.fields.aux, + send_ind, + recv_ind, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_range_fld, + false); + } if (comm_em0) { comm::CommunicateField(domain.index(), domain.fields.em0, @@ -299,20 +328,59 @@ namespace ntt { recv_slice, comp_range_fld, false); + // @HACK_GR_1.2.0 -- this has to be done carefully + // comm::CommunicateField(domain.index(), + // domain.fields.aux, + // domain.fields.aux, + // send_ind, + // recv_ind, + // send_rank, + // recv_rank, + // send_slice, + // recv_slice, + // comp_range_fld, + // false); + } + if (comm_j) { + comm::CommunicateField(domain.index(), + domain.fields.cur0, + domain.fields.cur0, + send_ind, + recv_ind, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_range_cur, + false); + } + } else { + if (comm_em) { + comm::CommunicateField(domain.index(), + domain.fields.em, + domain.fields.em, + send_ind, + recv_ind, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_range_fld, + false); + } + if (comm_j) { + comm::CommunicateField(domain.index(), + domain.fields.cur, + domain.fields.cur, + send_ind, + recv_ind, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_range_cur, + false); } - } - if (comm_j) { - comm::CommunicateField(domain.index(), - domain.fields.cur, - domain.fields.cur, - send_ind, - recv_ind, - send_rank, - recv_rank, - send_slice, - recv_slice, - comp_range_cur, - false); } } } @@ -432,17 +500,31 @@ namespace ntt { continue; } if (comm_j) { - 
comm::CommunicateField(domain.index(), - domain.fields.cur, - domain.fields.buff, - send_ind, - recv_ind, - send_rank, - recv_rank, - send_slice, - recv_slice, - comp_range_cur, - synchronize); + if constexpr (S == SimEngine::GRPIC) { + comm::CommunicateField(domain.index(), + domain.fields.cur0, + domain.fields.buff, + send_ind, + recv_ind, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_range_cur, + synchronize); + } else { + comm::CommunicateField(domain.index(), + domain.fields.cur, + domain.fields.buff, + send_ind, + recv_ind, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_range_cur, + synchronize); + } } if (comm_bckp) { comm::CommunicateField(domain.index(), @@ -472,10 +554,17 @@ namespace ntt { } } if (comm_j) { - AddBufferedFields(domain.fields.cur, - domain.fields.buff, - domain.mesh.rangeActiveCells(), - comp_range_cur); + if constexpr (S == SimEngine::GRPIC) { + AddBufferedFields(domain.fields.cur0, + domain.fields.buff, + domain.mesh.rangeActiveCells(), + comp_range_cur); + } else { + AddBufferedFields(domain.fields.cur, + domain.fields.buff, + domain.mesh.rangeActiveCells(), + comp_range_cur); + } } if (comm_bckp) { AddBufferedFields(domain.fields.bckp, @@ -492,167 +581,169 @@ namespace ntt { } template - void Metadomain::CommunicateParticles(Domain& domain, - timer::Timers* timers) { - raise::ErrorIf(timers == nullptr, - "Timers not passed when Comm::Prtl called", - HERE); + void Metadomain::CommunicateParticles(Domain& domain) { +#if defined(MPI_ENABLED) logger::Checkpoint("Communicating particles\n", HERE); for (auto& species : domain.species) { - // at this point particles should already by tagged in the pusher - timers->start("Sorting"); - const auto npart_per_tag = species.SortByTags(); - timers->stop("Sorting"); -#if defined(MPI_ENABLED) - timers->start("Communications"); - // only necessary when MPI is enabled - /** - * index_last - * | - * alive new dead tag1 tag2 v dead - * [ 11111111 000000000 222222222 3333333 
.... nnnnnnn 00000000 ... ] - * ^ ^ - * | | - * tag_offset[tag1] -----+ +----- tag_offset[tag1] + npart_per_tag[tag1] - * "send_pmin" "send_pmax" (after last element) - */ - auto tag_offset { npart_per_tag }; - for (std::size_t i { 1 }; i < tag_offset.size(); ++i) { - tag_offset[i] += tag_offset[i - 1]; - } - for (std::size_t i { 0 }; i < tag_offset.size(); ++i) { - tag_offset[i] -= npart_per_tag[i]; - } - auto index_last = tag_offset[tag_offset.size() - 1] + - npart_per_tag[npart_per_tag.size() - 1]; - for (auto& direction : dir::Directions::all) { + const auto ntags = species.ntags(); + + // at this point particles should already be tagged in the pusher + auto [npptag_vec, tag_offsets] = species.NpartsPerTagAndOffsets(); + const auto npart_dead = npptag_vec[ParticleTag::dead]; + const auto npart_alive = npptag_vec[ParticleTag::alive]; + + const auto npart = species.npart(); + + // # of particles to receive per each tag (direction) + std::vector npptag_recv_vec(ntags - 2, 0); + // coordinate shifts per each direction + array_t shifts_in_x1 { "shifts_in_x1", ntags - 2 }; + array_t shifts_in_x2 { "shifts_in_x2", ntags - 2 }; + array_t shifts_in_x3 { "shifts_in_x3", ntags - 2 }; + auto shifts_in_x1_h = Kokkos::create_mirror_view(shifts_in_x1); + auto shifts_in_x2_h = Kokkos::create_mirror_view(shifts_in_x2); + auto shifts_in_x3_h = Kokkos::create_mirror_view(shifts_in_x3); + + // all directions requiring communication + dir::dirs_t dirs_to_comm; + + // ranks & indices of meshblock to send/recv from + std::vector send_ranks, send_inds; + std::vector recv_ranks, recv_inds; + + // total # of reaceived particles from all directions + npart_t npart_recv = 0u; + + for (const auto& direction : dir::Directions::all) { + // tags corresponding to the direction (both send & recv) + const auto tag_recv = mpi::PrtlSendTag::dir2tag(-direction); + const auto tag_send = mpi::PrtlSendTag::dir2tag(direction); + + // get indices & ranks of send/recv meshblocks const auto [send_params, 
recv_params] = GetSendRecvParams(this, domain, direction, true); const auto [send_indrank, send_slice] = send_params; const auto [recv_indrank, recv_slice] = recv_params; const auto [send_ind, send_rank] = send_indrank; const auto [recv_ind, recv_rank] = recv_indrank; - if (send_rank < 0 and recv_rank < 0) { + + // skip if no communication is necessary + const auto is_sending = (send_rank >= 0); + const auto is_receiving = (recv_rank >= 0); + if (not is_sending and not is_receiving) { continue; } - const auto send_dir_tag = mpi::PrtlSendTag::dir2tag(direction); - const auto nsend = npart_per_tag[send_dir_tag]; - const auto send_pmin = tag_offset[send_dir_tag]; - const auto send_pmax = tag_offset[send_dir_tag] + nsend; - const auto recv_count = comm::CommunicateParticles( - species, - send_rank, - recv_rank, - { send_pmin, send_pmax }, - index_last); - if (recv_count > 0) { - if constexpr (D == Dim::_1D) { - int shift_in_x1 { 0 }; - if ((-direction)[0] == -1) { - shift_in_x1 = -subdomain(recv_ind).mesh.n_active(in::x1); - } else if ((-direction)[0] == 1) { - shift_in_x1 = domain.mesh.n_active(in::x1); - } - auto& this_tag = species.tag; - auto& this_i1 = species.i1; - auto& this_i1_prev = species.i1_prev; - Kokkos::parallel_for( - "CommunicateParticles", - recv_count, - Lambda(index_t p) { - this_tag(index_last + p) = ParticleTag::alive; - this_i1(index_last + p) += shift_in_x1; - this_i1_prev(index_last + p) += shift_in_x1; - }); - } else if constexpr (D == Dim::_2D) { - int shift_in_x1 { 0 }, shift_in_x2 { 0 }; - if ((-direction)[0] == -1) { - shift_in_x1 = -subdomain(recv_ind).mesh.n_active(in::x1); - } else if ((-direction)[0] == 1) { - shift_in_x1 = domain.mesh.n_active()[0]; - } - if ((-direction)[1] == -1) { - shift_in_x2 = -subdomain(recv_ind).mesh.n_active(in::x2); - } else if ((-direction)[1] == 1) { - shift_in_x2 = domain.mesh.n_active(in::x2); - } - auto& this_tag = species.tag; - auto& this_i1 = species.i1; - auto& this_i2 = species.i2; - auto& 
this_i1_prev = species.i1_prev; - auto& this_i2_prev = species.i2_prev; - Kokkos::parallel_for( - "CommunicateParticles", - recv_count, - Lambda(index_t p) { - this_tag(index_last + p) = ParticleTag::alive; - this_i1(index_last + p) += shift_in_x1; - this_i2(index_last + p) += shift_in_x2; - this_i1_prev(index_last + p) += shift_in_x1; - this_i2_prev(index_last + p) += shift_in_x2; - }); - } else if constexpr (D == Dim::_3D) { - int shift_in_x1 { 0 }, shift_in_x2 { 0 }, shift_in_x3 { 0 }; - if ((-direction)[0] == -1) { - shift_in_x1 = -subdomain(recv_ind).mesh.n_active(in::x1); - } else if ((-direction)[0] == 1) { - shift_in_x1 = domain.mesh.n_active(in::x1); + dirs_to_comm.push_back(direction); + send_ranks.push_back(send_rank); + recv_ranks.push_back(recv_rank); + send_inds.push_back(send_ind); + recv_inds.push_back(recv_ind); + + // record the # of particles to-be-sent + const auto nsend = npptag_vec[tag_send]; + + // request the # of particles to-be-received ... + // ... and send the # of particles to-be-sent + npart_t nrecv = 0; + comm::ParticleSendRecvCount(send_rank, recv_rank, nsend, nrecv); + npart_recv += nrecv; + npptag_recv_vec[tag_recv - 2] = nrecv; + + raise::ErrorIf((npart + npart_recv) >= species.maxnpart(), + "Too many particles to receive (cannot fit into maxptl)", + HERE); + + // if sending, record displacements to apply before + // ... tag_send - 2: because we only shift tags > 2 (i.e. 
no dead/alive) + if (is_sending) { + if constexpr (D == Dim::_1D || D == Dim::_2D || D == Dim::_3D) { + if (direction[0] == -1) { + // sending backwards in x1 (add sx1 of target meshblock) + shifts_in_x1_h(tag_send - 2) = subdomain(send_ind).mesh.n_active( + in::x1); + } else if (direction[0] == 1) { + // sending forward in x1 (subtract sx1 of source meshblock) + shifts_in_x1_h(tag_send - 2) = -domain.mesh.n_active(in::x1); } - if ((-direction)[1] == -1) { - shift_in_x2 = -subdomain(recv_ind).mesh.n_active(in::x2); - } else if ((-direction)[1] == 1) { - shift_in_x2 = domain.mesh.n_active(in::x2); + } + if constexpr (D == Dim::_2D || D == Dim::_3D) { + if (direction[1] == -1) { + shifts_in_x2_h(tag_send - 2) = subdomain(send_ind).mesh.n_active( + in::x2); + } else if (direction[1] == 1) { + shifts_in_x2_h(tag_send - 2) = -domain.mesh.n_active(in::x2); } - if ((-direction)[2] == -1) { - shift_in_x3 = -subdomain(recv_ind).mesh.n_active(in::x3); - } else if ((-direction)[2] == 1) { - shift_in_x3 = domain.mesh.n_active(in::x3); + } + if constexpr (D == Dim::_3D) { + if (direction[2] == -1) { + shifts_in_x3_h(tag_send - 2) = subdomain(send_ind).mesh.n_active( + in::x3); + } else if (direction[2] == 1) { + shifts_in_x3_h(tag_send - 2) = -domain.mesh.n_active(in::x3); } - auto& this_tag = species.tag; - auto& this_i1 = species.i1; - auto& this_i2 = species.i2; - auto& this_i3 = species.i3; - auto& this_i1_prev = species.i1_prev; - auto& this_i2_prev = species.i2_prev; - auto& this_i3_prev = species.i3_prev; - Kokkos::parallel_for( - "CommunicateParticles", - recv_count, - Lambda(index_t p) { - this_tag(index_last + p) = ParticleTag::alive; - this_i1(index_last + p) += shift_in_x1; - this_i2(index_last + p) += shift_in_x2; - this_i3(index_last + p) += shift_in_x3; - this_i1_prev(index_last + p) += shift_in_x1; - this_i2_prev(index_last + p) += shift_in_x2; - this_i3_prev(index_last + p) += shift_in_x3; - }); } - index_last += recv_count; - species.set_npart(index_last); } + 
} // end directions loop - Kokkos::deep_copy( - Kokkos::subview(species.tag, std::make_pair(send_pmin, send_pmax)), - ParticleTag::dead); - } - timers->stop("Communications"); - // !TODO: maybe there is a way to not sort twice - timers->start("Sorting"); + Kokkos::deep_copy(shifts_in_x1, shifts_in_x1_h); + Kokkos::deep_copy(shifts_in_x2, shifts_in_x2_h); + Kokkos::deep_copy(shifts_in_x3, shifts_in_x3_h); + + array_t outgoing_indices { "outgoing_indices", npart - npart_alive }; + // clang-format off + Kokkos::parallel_for( + "PrepareOutgoingPrtls", + species.rangeActiveParticles(), + kernel::comm::PrepareOutgoingPrtls_kernel( + shifts_in_x1, shifts_in_x2, shifts_in_x3, + outgoing_indices, + npart, npart_alive, npart_dead, ntags, + species.i1, species.i1_prev, + species.i2, species.i2_prev, + species.i3, species.i3_prev, + species.tag, tag_offsets) + ); + // clang-format on + + comm::CommunicateParticles(species, + outgoing_indices, + tag_offsets, + npptag_vec, + npptag_recv_vec, + send_ranks, + recv_ranks, + dirs_to_comm); species.set_unsorted(); - species.SortByTags(); - timers->stop("Sorting"); + } // end species loop +#else + (void)domain; #endif + } + + template + void Metadomain::RemoveDeadParticles(Domain& domain) { + for (auto& species : domain.species) { + species.RemoveDead(); } } - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; +#define METADOMAIN_COMM(S, M) \ + template void Metadomain::CommunicateFields(Domain&, CommTags); \ + template void Metadomain::SynchronizeFields(Domain&, \ + CommTags, \ + const range_tuple_t&); \ + template void Metadomain::CommunicateParticles(Domain&); \ + template void Metadomain::RemoveDeadParticles(Domain&); + + METADOMAIN_COMM(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_COMM(SimEngine::SRPIC, metric::Minkowski) + 
METADOMAIN_COMM(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_COMM(SimEngine::SRPIC, metric::Spherical) + METADOMAIN_COMM(SimEngine::SRPIC, metric::QSpherical) + METADOMAIN_COMM(SimEngine::GRPIC, metric::KerrSchild) + METADOMAIN_COMM(SimEngine::GRPIC, metric::QKerrSchild) + METADOMAIN_COMM(SimEngine::GRPIC, metric::KerrSchild0) + +#undef METADOMAIN_COMM } // namespace ntt diff --git a/src/framework/domain/domain.h b/src/framework/domain/domain.h index 397907fef..b6cbb985a 100644 --- a/src/framework/domain/domain.h +++ b/src/framework/domain/domain.h @@ -65,7 +65,7 @@ namespace ntt { Mesh mesh; Fields fields; std::vector> species; - random_number_pool_t random_pool { constant::RandomSeed }; + random_number_pool_t random_pool; /** * @brief constructor for "empty" allocation of non-local domain placeholders @@ -73,28 +73,30 @@ namespace ntt { Domain(bool, unsigned int index, const std::vector& offset_ndomains, - const std::vector& offset_ncells, - const std::vector& ncells, + const std::vector& offset_ncells, + const std::vector& ncells, const boundaries_t& extent, const std::map& metric_params, const std::vector&) : mesh { ncells, extent, metric_params } , fields {} , species {} + , random_pool { constant::RandomSeed } , m_index { index } , m_offset_ndomains { offset_ndomains } , m_offset_ncells { offset_ncells } {} Domain(unsigned int index, const std::vector& offset_ndomains, - const std::vector& offset_ncells, - const std::vector& ncells, + const std::vector& offset_ncells, + const std::vector& ncells, const boundaries_t& extent, const std::map& metric_params, const std::vector& species_params) : mesh { ncells, extent, metric_params } , fields { ncells } , species { species_params.begin(), species_params.end() } + , random_pool { constant::RandomSeed + static_cast(index) } , m_index { index } , m_offset_ndomains { offset_ndomains } , m_offset_ncells { offset_ncells } {} @@ -122,7 +124,7 @@ namespace ntt { } [[nodiscard]] - auto offset_ncells() const -> 
std::vector { + auto offset_ncells() const -> std::vector { return m_offset_ncells; } @@ -144,8 +146,7 @@ namespace ntt { } /* setters -------------------------------------------------------------- */ - auto set_neighbor_idx(const dir::direction_t& dir, unsigned int idx) - -> void { + auto set_neighbor_idx(const dir::direction_t& dir, unsigned int idx) -> void { m_neighbor_idx[dir] = idx; } @@ -155,7 +156,7 @@ namespace ntt { // offset of the domain in # of domains std::vector m_offset_ndomains; // offset of the domain in cells (# of cells in each dimension) - std::vector m_offset_ncells; + std::vector m_offset_ncells; // neighboring domain indices dir::map_t m_neighbor_idx; // MPI rank of the domain (used only when MPI enabled) @@ -163,8 +164,8 @@ namespace ntt { }; template - inline auto operator<<(std::ostream& os, const Domain& domain) - -> std::ostream& { + inline auto operator<<(std::ostream& os, + const Domain& domain) -> std::ostream& { os << "Domain #" << domain.index(); #if defined(MPI_ENABLED) os << " [MPI rank: " << domain.mpi_rank() << "]"; @@ -183,23 +184,16 @@ namespace ntt { } os << "\n"; os << std::setw(19) << std::left << " physical extent: "; - for (auto dim = 0; dim < M::Dim; ++dim) { + for (auto dim { 0u }; dim < M::Dim; ++dim) { os << std::setw(15) << std::left << fmt::format("{%.2f; %.2f}", - domain.mesh.extent(dim).first, - domain.mesh.extent(dim).second); + domain.mesh.extent(static_cast(dim)).first, + domain.mesh.extent(static_cast(dim)).second); } os << "\n neighbors:\n"; for (auto& direction : dir::Directions::all) { - auto neighbor = domain.neighbor_in(direction); - os << " " << direction; - if (neighbor != nullptr) { - os << " -> #" << neighbor->index() << "\n"; - } else { - os << " -> " - << "N/A" - << "\n"; - } + auto neighbor_idx = domain.neighbor_idx_in(direction); + os << " " << direction << " -> #" << neighbor_idx << "\n"; } os << " field boundaries:\n"; for (auto& direction : dir::Directions::orth) { diff --git 
a/src/framework/domain/grid.cpp b/src/framework/domain/grid.cpp index 9302386e1..f087df6f5 100644 --- a/src/framework/domain/grid.cpp +++ b/src/framework/domain/grid.cpp @@ -50,8 +50,8 @@ namespace ntt { template auto Grid::rangeCells(const box_region_t& region) const -> range_t { - tuple_t imin, imax; - for (unsigned short i = 0; i < (unsigned short)D; i++) { + tuple_t imin, imax; + for (auto i { 0u }; i < D; i++) { switch (region[i]) { case CellLayer::allLayer: imin[i] = 0; @@ -85,12 +85,11 @@ namespace ntt { return CreateRangePolicy(imin, imax); } - // !TODO: too ugly, implement a better solution (combine with device) template - auto Grid::rangeCellsOnHost(const box_region_t& region) const - -> range_h_t { - tuple_t imin, imax; - for (unsigned short i = 0; i < (unsigned short)D; i++) { + auto Grid::rangeCellsOnHost( + const box_region_t& region) const -> range_h_t { + tuple_t imin, imax; + for (auto i { 0u }; i < D; i++) { switch (region[i]) { case CellLayer::allLayer: imin[i] = 0; @@ -164,10 +163,10 @@ namespace ntt { } template - auto Grid::rangeCells(const tuple_t, D>& ranges) const - -> range_t { - tuple_t imin, imax; - for (unsigned short i = 0; i < (unsigned short)D; i++) { + auto Grid::rangeCells( + const tuple_t, D>& ranges) const -> range_t { + tuple_t imin, imax; + for (auto i { 0u }; i < D; i++) { raise::ErrorIf((ranges[i][0] < -(int)N_GHOSTS) || (ranges[i][1] > (int)N_GHOSTS), "Invalid cell layer picked", diff --git a/src/framework/domain/grid.h b/src/framework/domain/grid.h index 97a939117..87b21b1f5 100644 --- a/src/framework/domain/grid.h +++ b/src/framework/domain/grid.h @@ -73,7 +73,7 @@ namespace ntt { template struct Grid { - Grid(const std::vector& res) : m_resolution { res } { + Grid(const std::vector& res) : m_resolution { res } { raise::ErrorIf(m_resolution.size() != D, "invalid dimension", HERE); } @@ -81,7 +81,7 @@ namespace ntt { /* getters -------------------------------------------------------------- */ [[nodiscard]] - auto i_min(in 
i) const -> std::size_t { + auto i_min(in i) const -> ncells_t { switch (i) { case in::x1: return (m_resolution.size() > 0) ? N_GHOSTS : 0; @@ -96,7 +96,7 @@ namespace ntt { } [[nodiscard]] - auto i_max(in i) const -> std::size_t { + auto i_max(in i) const -> ncells_t { switch (i) { case in::x1: return (m_resolution.size() > 0) ? (m_resolution[0] + N_GHOSTS) : 1; @@ -111,7 +111,7 @@ namespace ntt { } [[nodiscard]] - auto n_active(in i) const -> std::size_t { + auto n_active(in i) const -> ncells_t { switch (i) { case in::x1: return (m_resolution.size() > 0) ? m_resolution[0] : 1; @@ -126,12 +126,21 @@ namespace ntt { } [[nodiscard]] - auto n_active() const -> std::vector { + auto n_active() const -> std::vector { return m_resolution; } [[nodiscard]] - auto n_all(in i) const -> std::size_t { + auto num_active() const -> ncells_t { + ncells_t total_active = 1u; + for (const auto& res : m_resolution) { + total_active *= res; + } + return total_active; + } + + [[nodiscard]] + auto n_all(in i) const -> ncells_t { switch (i) { case in::x1: return (m_resolution.size() > 0) ? 
(m_resolution[0] + 2 * N_GHOSTS) : 1; @@ -146,14 +155,23 @@ namespace ntt { } [[nodiscard]] - auto n_all() const -> std::vector { - std::vector nall; - for (std::size_t i = 0; i < D; ++i) { + auto n_all() const -> std::vector { + std::vector nall; + for (auto i = 0u; i < D; ++i) { nall.push_back(m_resolution[i] + 2 * N_GHOSTS); } return nall; } + [[nodiscard]] + auto num_all() const -> ncells_t { + ncells_t total_all = 1u; + for (const auto& res : n_all()) { + total_all *= res; + } + return total_all; + } + /* Ranges in the device execution space --------------------------------- */ /** * @brief Loop over all active cells (disregard ghost cells) @@ -204,7 +222,7 @@ namespace ntt { auto rangeCellsOnHost(const box_region_t&) const -> range_h_t; protected: - std::vector m_resolution; + std::vector m_resolution; }; } // namespace ntt diff --git a/src/framework/domain/mesh.h b/src/framework/domain/mesh.h index b0bd1a567..afb095f44 100644 --- a/src/framework/domain/mesh.h +++ b/src/framework/domain/mesh.h @@ -38,14 +38,14 @@ namespace ntt { M metric; - Mesh(const std::vector& res, + Mesh(const std::vector& res, const boundaries_t& ext, const std::map& metric_params) : Grid { res } , metric { res, ext, metric_params } , m_extent { ext } {} - Mesh(const std::vector& res, + Mesh(const std::vector& res, const boundaries_t& ext, const std::map& metric_params, const boundaries_t& flds_bc, @@ -74,7 +74,7 @@ namespace ntt { * @note pass Range::All to select the entire dimension */ [[nodiscard]] - auto Intersection(boundaries_t box) -> boundaries_t { + auto Intersection(boundaries_t box) const -> boundaries_t { raise::ErrorIf(box.size() != M::Dim, "Invalid box dimension", HERE); boundaries_t intersection; auto d = 0; @@ -109,7 +109,7 @@ namespace ntt { * @note pass Range::All to select the entire dimension */ [[nodiscard]] - auto Intersects(boundaries_t box) -> bool { + auto Intersects(boundaries_t box) const -> bool { raise::ErrorIf(box.size() != M::Dim, "Invalid box dimension", 
HERE); const auto intersection = Intersection(box); for (const auto& i : intersection) { @@ -131,15 +131,16 @@ namespace ntt { * @note indices are already shifted by N_GHOSTS (i.e. they start at N_GHOSTS not 0) */ [[nodiscard]] - auto ExtentToRange(boundaries_t box, boundaries_t incl_ghosts) - -> boundaries_t { + auto ExtentToRange( + boundaries_t box, + boundaries_t incl_ghosts) const -> boundaries_t { raise::ErrorIf(box.size() != M::Dim, "Invalid box dimension", HERE); raise::ErrorIf(incl_ghosts.size() != M::Dim, "Invalid incl_ghosts dimension", HERE); - boundaries_t range; + boundaries_t range; if (not Intersects(box)) { - for (std::size_t i { 0 }; i < box.size(); ++i) { + for (auto i { 0u }; i < box.size(); ++i) { range.push_back({ 0, 0 }); } return range; @@ -184,9 +185,9 @@ namespace ntt { raise::Error("invalid dimension", HERE); throw; } - range.push_back({ static_cast(xi_min_Cd) + + range.push_back({ static_cast(xi_min_Cd) + (incl_ghosts[d].first ? 0 : N_GHOSTS), - static_cast(xi_max_Cd) + + static_cast(xi_max_Cd) + (incl_ghosts[d].second ? 
2 * N_GHOSTS : N_GHOSTS) }); } ++d; @@ -222,18 +223,18 @@ namespace ntt { auto flds_bc() const -> boundaries_t { if constexpr (D == Dim::_1D) { return { - {flds_bc_in({ -1 }), flds_bc_in({ -1 })} + { flds_bc_in({ -1 }), flds_bc_in({ -1 }) } }; } else if constexpr (D == Dim::_2D) { return { - {flds_bc_in({ -1, 0 }), flds_bc_in({ 1, 0 })}, - {flds_bc_in({ 0, -1 }), flds_bc_in({ 0, 1 })} + { flds_bc_in({ -1, 0 }), flds_bc_in({ 1, 0 }) }, + { flds_bc_in({ 0, -1 }), flds_bc_in({ 0, 1 }) } }; } else if constexpr (D == Dim::_3D) { return { - {flds_bc_in({ -1, 0, 0 }), flds_bc_in({ 1, 0, 0 })}, - {flds_bc_in({ 0, -1, 0 }), flds_bc_in({ 0, 1, 0 })}, - {flds_bc_in({ 0, 0, -1 }), flds_bc_in({ 0, 0, 1 })} + { flds_bc_in({ -1, 0, 0 }), flds_bc_in({ 1, 0, 0 }) }, + { flds_bc_in({ 0, -1, 0 }), flds_bc_in({ 0, 1, 0 }) }, + { flds_bc_in({ 0, 0, -1 }), flds_bc_in({ 0, 0, 1 }) } }; } else { raise::Error("invalid dimension", HERE); @@ -245,18 +246,18 @@ namespace ntt { auto prtl_bc() const -> boundaries_t { if constexpr (D == Dim::_1D) { return { - {prtl_bc_in({ -1 }), prtl_bc_in({ -1 })} + { prtl_bc_in({ -1 }), prtl_bc_in({ -1 }) } }; } else if constexpr (D == Dim::_2D) { return { - {prtl_bc_in({ -1, 0 }), prtl_bc_in({ 1, 0 })}, - {prtl_bc_in({ 0, -1 }), prtl_bc_in({ 0, 1 })} + { prtl_bc_in({ -1, 0 }), prtl_bc_in({ 1, 0 }) }, + { prtl_bc_in({ 0, -1 }), prtl_bc_in({ 0, 1 }) } }; } else if constexpr (D == Dim::_3D) { return { - {prtl_bc_in({ -1, 0, 0 }), prtl_bc_in({ 1, 0, 0 })}, - {prtl_bc_in({ 0, -1, 0 }), prtl_bc_in({ 0, 1, 0 })}, - {prtl_bc_in({ 0, 0, -1 }), prtl_bc_in({ 0, 0, 1 })} + { prtl_bc_in({ -1, 0, 0 }), prtl_bc_in({ 1, 0, 0 }) }, + { prtl_bc_in({ 0, -1, 0 }), prtl_bc_in({ 0, 1, 0 }) }, + { prtl_bc_in({ 0, 0, -1 }), prtl_bc_in({ 0, 0, 1 }) } }; } else { raise::Error("invalid dimension", HERE); diff --git a/src/framework/domain/metadomain.cpp b/src/framework/domain/metadomain.cpp index 26a4f3168..32275448e 100644 --- a/src/framework/domain/metadomain.cpp +++ 
b/src/framework/domain/metadomain.cpp @@ -23,7 +23,6 @@ #include #include -#include #include #include @@ -32,29 +31,23 @@ namespace ntt { template Metadomain::Metadomain(unsigned int global_ndomains, const std::vector& global_decomposition, - const std::vector& global_ncells, - const boundaries_t& global_extent, - const boundaries_t& global_flds_bc, - const boundaries_t& global_prtl_bc, + const std::vector& global_ncells, + const boundaries_t& global_extent, + const boundaries_t& global_flds_bc, + const boundaries_t& global_prtl_bc, const std::map& metric_params, - const std::vector& species_params -#if defined(OUTPUT_ENABLED) - , - const std::string& output_engine -#endif - ) + const std::vector& species_params) : g_ndomains { global_ndomains } , g_decomposition { global_decomposition } , g_mesh { global_ncells, global_extent, metric_params, global_flds_bc, global_prtl_bc } , g_metric_params { metric_params } - , g_species_params { species_params } -#if defined(OUTPUT_ENABLED) - , g_writer { output_engine } -#endif - { + , g_species_params { species_params } { #if defined(MPI_ENABLED) MPI_Comm_size(MPI_COMM_WORLD, &g_mpi_size); MPI_Comm_rank(MPI_COMM_WORLD, &g_mpi_rank); + raise::ErrorIf(global_ndomains != (unsigned int)g_mpi_size, + "Exactly 1 domain per MPI rank is allowed", + HERE); #endif initialValidityCheck(); @@ -110,13 +103,13 @@ namespace ntt { raise::ErrorIf(d_ncells.size() != (std::size_t)D, "Invalid number of dimensions received", HERE); - auto d_offset_ncells = std::vector> {}; + auto d_offset_ncells = std::vector> {}; auto d_offset_ndoms = std::vector> {}; for (auto& d : d_ncells) { g_ndomains_per_dim.push_back(d.size()); - auto offset_ncell = std::vector { 0 }; + auto offset_ncell = std::vector { 0 }; auto offset_ndom = std::vector { 0 }; - for (std::size_t i { 1 }; i < d.size(); ++i) { + for (auto i { 1u }; i < d.size(); ++i) { auto di = d[i - 1]; offset_ncell.push_back(offset_ncell.back() + di); offset_ndom.push_back(offset_ndom.back() + 1); @@ 
-127,8 +120,8 @@ namespace ntt { /* compute tensor products of the domain decompositions --------------- */ // works similar to np.meshgrid() - const auto domain_ncells = tools::TensorProduct(d_ncells); - const auto domain_offset_ncells = tools::TensorProduct( + const auto domain_ncells = tools::TensorProduct(d_ncells); + const auto domain_offset_ncells = tools::TensorProduct( d_offset_ncells); const auto domain_offset_ndoms = tools::TensorProduct( d_offset_ndoms); @@ -146,7 +139,7 @@ namespace ntt { boundaries_t l_extent; coord_t low_corner_Code { ZERO }, up_corner_Code { ZERO }; coord_t low_corner_Phys { ZERO }, up_corner_Phys { ZERO }; - for (unsigned short d { 0 }; d < (unsigned short)D; ++d) { + for (auto d { 0u }; d < D; d++) { low_corner_Code[d] = (real_t)l_offset_ncells[d]; up_corner_Code[d] = (real_t)(l_offset_ncells[d] + l_ncells[d]); } @@ -240,7 +233,6 @@ namespace ntt { template void Metadomain::redefineBoundaries() { - // !TODO: not setting CommBC for now for (unsigned int idx { 0 }; idx < g_ndomains; ++idx) { // offset of the subdomain[idx] auto& current_domain = g_subdomains[idx]; @@ -351,7 +343,7 @@ namespace ntt { } // check that local subdomains are contained in g_local_subdomain_indices auto contained_in_local = false; - for (const auto& gidx : g_local_subdomain_indices) { + for (const auto& gidx : l_subdomain_indices()) { contained_in_local |= (idx == gidx); } #if defined(MPI_ENABLED) @@ -375,6 +367,8 @@ namespace ntt { template void Metadomain::metricCompatibilityCheck() const { + const auto epsilon = std::numeric_limits::epsilon() * + static_cast(100.0); const auto dx_min = g_mesh.metric.dxMin(); auto dx_min_from_domains = std::numeric_limits::infinity(); for (unsigned int idx { 0 }; idx < g_ndomains; ++idx) { @@ -383,7 +377,7 @@ namespace ntt { dx_min_from_domains = std::min(dx_min_from_domains, current_dx_min); } raise::ErrorIf( - not cmp::AlmostEqual(dx_min, dx_min_from_domains), + not cmp::AlmostEqual_host(dx_min / dx_min_from_domains, 
ONE, epsilon), "dx_min is not the same across all domains: " + std::to_string(dx_min) + " " + std::to_string(dx_min_from_domains), HERE); @@ -398,13 +392,149 @@ namespace ntt { mpi::get_type(), MPI_COMM_WORLD); for (const auto& dx : dx_mins) { - raise::ErrorIf(!cmp::AlmostEqual(dx, dx_min), + raise::ErrorIf(not cmp::AlmostEqual_host(dx / dx_min, ONE, epsilon), "dx_min is not the same across all MPI ranks", HERE); } #endif } + template + void Metadomain::setFldsBC(const bc_in& dir, const FldsBC& new_bcs) { + if (dir == bc_in::Mx1) { + if constexpr (M::Dim == Dim::_1D) { + g_mesh.set_flds_bc({ -1 }, new_bcs); + } else if constexpr (M::Dim == Dim::_2D) { + g_mesh.set_flds_bc({ -1, 0 }, new_bcs); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_flds_bc({ -1, 0, 0 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Px1) { + if constexpr (M::Dim == Dim::_1D) { + g_mesh.set_flds_bc({ +1 }, new_bcs); + } else if constexpr (M::Dim == Dim::_2D) { + g_mesh.set_flds_bc({ +1, 0 }, new_bcs); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_flds_bc({ +1, 0, 0 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Mx2) { + if constexpr (M::Dim == Dim::_1D) { + raise::Error("Cannot set -x2 BCs for 1D", HERE); + } else if constexpr (M::Dim == Dim::_2D) { + g_mesh.set_flds_bc({ -1, 0 }, new_bcs); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_flds_bc({ -1, 0, 0 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Px2) { + if constexpr (M::Dim == Dim::_1D) { + raise::Error("Cannot set +x2 BCs for 1D", HERE); + } else if constexpr (M::Dim == Dim::_2D) { + g_mesh.set_flds_bc({ +1, 0 }, new_bcs); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_flds_bc({ +1, 0, 0 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Mx3) { + if constexpr (M::Dim == Dim::_1D) { + 
raise::Error("Cannot set -x3 BCs for 1D", HERE); + } else if constexpr (M::Dim == Dim::_2D) { + raise::Error("Cannot set -x3 BCs for 2D", HERE); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_flds_bc({ 0, 0, -1 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Px3) { + if constexpr (M::Dim == Dim::_1D) { + raise::Error("Cannot set +x3 BCs for 1D", HERE); + } else if constexpr (M::Dim == Dim::_2D) { + raise::Error("Cannot set +x3 BCs for 2D", HERE); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_flds_bc({ 0, 0, +1 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else { + raise::Error("Invalid direction", HERE); + } + redefineBoundaries(); + } + + template + void Metadomain::setPrtlBC(const bc_in& dir, const PrtlBC& new_bcs) { + if (dir == bc_in::Mx1) { + if constexpr (M::Dim == Dim::_1D) { + g_mesh.set_prtl_bc({ -1 }, new_bcs); + } else if constexpr (M::Dim == Dim::_2D) { + g_mesh.set_prtl_bc({ -1, 0 }, new_bcs); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_prtl_bc({ -1, 0, 0 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Px1) { + if constexpr (M::Dim == Dim::_1D) { + g_mesh.set_prtl_bc({ +1 }, new_bcs); + } else if constexpr (M::Dim == Dim::_2D) { + g_mesh.set_prtl_bc({ +1, 0 }, new_bcs); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_prtl_bc({ +1, 0, 0 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Mx2) { + if constexpr (M::Dim == Dim::_1D) { + raise::Error("Cannot set -x2 BCs for 1D", HERE); + } else if constexpr (M::Dim == Dim::_2D) { + g_mesh.set_prtl_bc({ -1, 0 }, new_bcs); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_prtl_bc({ -1, 0, 0 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Px2) { + if constexpr (M::Dim == Dim::_1D) { + raise::Error("Cannot set +x2 BCs for 1D", HERE); + 
} else if constexpr (M::Dim == Dim::_2D) { + g_mesh.set_prtl_bc({ +1, 0 }, new_bcs); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_prtl_bc({ +1, 0, 0 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Mx3) { + if constexpr (M::Dim == Dim::_1D) { + raise::Error("Cannot set -x3 BCs for 1D", HERE); + } else if constexpr (M::Dim == Dim::_2D) { + raise::Error("Cannot set -x3 BCs for 2D", HERE); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_prtl_bc({ 0, 0, -1 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else if (dir == bc_in::Px3) { + if constexpr (M::Dim == Dim::_1D) { + raise::Error("Cannot set +x3 BCs for 1D", HERE); + } else if constexpr (M::Dim == Dim::_2D) { + raise::Error("Cannot set +x3 BCs for 2D", HERE); + } else if constexpr (M::Dim == Dim::_3D) { + g_mesh.set_prtl_bc({ 0, 0, +1 }, new_bcs); + } else { + raise::Error("Invalid dimension", HERE); + } + } else { + raise::Error("Invalid direction", HERE); + } + redefineBoundaries(); + } + template struct Metadomain>; template struct Metadomain>; template struct Metadomain>; diff --git a/src/framework/domain/metadomain.h b/src/framework/domain/metadomain.h index fb81fcfca..7ddacffb3 100644 --- a/src/framework/domain/metadomain.h +++ b/src/framework/domain/metadomain.h @@ -19,20 +19,24 @@ #include "global.h" #include "arch/kokkos_aliases.h" -#include "utils/timer.h" #include "framework/containers/species.h" #include "framework/domain/domain.h" #include "framework/domain/mesh.h" #include "framework/parameters.h" +#include "output/stats.h" #if defined(MPI_ENABLED) #include #endif // MPI_ENABLED -#if defined OUTPUT_ENABLED +#if defined(OUTPUT_ENABLED) + #include "checkpoint/writer.h" #include "output/writer.h" -#endif + + #include + #include +#endif // OUTPUT_ENABLED #include #include @@ -70,21 +74,22 @@ namespace ntt { template void runOnLocalDomains(Func func, Args&&... 
args) { - for (auto& ldidx : g_local_subdomain_indices) { + for (auto& ldidx : l_subdomain_indices()) { func(g_subdomains[ldidx], std::forward(args)...); } } template void runOnLocalDomainsConst(Func func, Args&&... args) const { - for (auto& ldidx : g_local_subdomain_indices) { + for (auto& ldidx : l_subdomain_indices()) { func(g_subdomains[ldidx], std::forward(args)...); } } void CommunicateFields(Domain&, CommTags); void SynchronizeFields(Domain&, CommTags, const range_tuple_t& = { 0, 0 }); - void CommunicateParticles(Domain&, timer::Timers*); + void CommunicateParticles(Domain&); + void RemoveDeadParticles(Domain&); /** * @param global_ndomains total number of domains @@ -95,39 +100,57 @@ namespace ntt { * @param global_prtl_bc boundary conditions for particles * @param metric_params parameters for the metric * @param species_params parameters for the particle species - * @param output_params parameters for the output */ Metadomain(unsigned int, const std::vector&, - const std::vector&, + const std::vector&, const boundaries_t&, const boundaries_t&, const boundaries_t&, const std::map&, - const std::vector& -#if defined(OUTPUT_ENABLED) - , - const std::string& -#endif - ); + const std::vector&); + + Metadomain(const Metadomain&) = delete; + Metadomain& operator=(const Metadomain&) = delete; + + ~Metadomain() = default; #if defined(OUTPUT_ENABLED) - void InitWriter(const SimulationParams&); + void InitWriter(adios2::ADIOS*, const SimulationParams&); auto Write(const SimulationParams&, - std::size_t, - long double, + timestep_t, + timestep_t, + simtime_t, + simtime_t, std::function&, - std::size_t, - const Domain&)> = {}) -> bool; + index_t, + timestep_t, + simtime_t, + const Domain&)> = nullptr) -> bool; + void InitCheckpointWriter(adios2::ADIOS*, const SimulationParams&); + auto WriteCheckpoint(const SimulationParams&, + timestep_t, + timestep_t, + simtime_t, + simtime_t) -> bool; + + void ContinueFromCheckpoint(adios2::ADIOS*, const SimulationParams&); #endif 
- Metadomain(const Metadomain&) = delete; - Metadomain& operator=(const Metadomain&) = delete; - - ~Metadomain() = default; + void InitStatsWriter(const SimulationParams&, bool); + auto WriteStats( + const SimulationParams&, + timestep_t, + timestep_t, + simtime_t, + simtime_t, + std::function&)> = + nullptr) -> bool; /* setters -------------------------------------------------------------- */ + void setFldsBC(const bc_in&, const FldsBC&); + void setPrtlBC(const bc_in&, const PrtlBC&); /* getters -------------------------------------------------------------- */ [[nodiscard]] @@ -163,10 +186,60 @@ namespace ntt { } [[nodiscard]] - auto local_subdomain_indices() const -> std::vector { + auto l_subdomain_indices() const -> std::vector { return g_local_subdomain_indices; } + [[nodiscard]] + auto l_npart_perspec() const -> std::vector { + std::vector npart(g_species_params.size(), 0); + for (const auto& ldidx : l_subdomain_indices()) { + for (std::size_t i = 0; i < g_species_params.size(); ++i) { + npart[i] += g_subdomains[ldidx].species[i].npart(); + } + } + return npart; + } + + [[nodiscard]] + auto l_maxnpart_perspec() const -> std::vector { + std::vector maxnpart(g_species_params.size(), 0); + for (const auto& ldidx : l_subdomain_indices()) { + for (std::size_t i = 0; i < g_species_params.size(); ++i) { + maxnpart[i] += g_subdomains[ldidx].species[i].maxnpart(); + } + } + return maxnpart; + } + + [[nodiscard]] + auto l_npart() const -> std::size_t { + const auto npart = l_npart_perspec(); + return std::accumulate(npart.begin(), npart.end(), 0); + } + + [[nodiscard]] + auto l_ncells() const -> std::size_t { + std::size_t ncells_local = 0; + for (const auto& ldidx : l_subdomain_indices()) { + std::size_t ncells = 1; + for (const auto& n : g_subdomains[ldidx].mesh.n_all()) { + ncells *= n; + } + ncells_local += ncells; + } + return ncells_local; + } + + [[nodiscard]] + auto species_labels() const -> std::vector { + std::vector labels; + for (const auto& sp : 
g_species_params) { + labels.push_back(sp.label()); + } + return labels; + } + private: // domain information unsigned int g_ndomains; @@ -183,8 +256,14 @@ namespace ntt { const std::map g_metric_params; const std::vector g_species_params; + stats::Writer g_stats_writer; + #if defined(OUTPUT_ENABLED) - out::Writer g_writer; + out::Writer g_writer; + checkpoint::Writer g_checkpoint_writer; + #if defined(MPI_ENABLED) + void CommunicateVectorPotential(unsigned short); + #endif #endif #if defined(MPI_ENABLED) diff --git a/src/framework/domain/output.cpp b/src/framework/domain/output.cpp index be154ce16..960b2c713 100644 --- a/src/framework/domain/output.cpp +++ b/src/framework/domain/output.cpp @@ -18,6 +18,7 @@ #include "framework/domain/metadomain.h" #include "framework/parameters.h" +#include "kernels/divergences.hpp" #include "kernels/fields_to_phys.hpp" #include "kernels/particle_moments.hpp" #include "kernels/prtls_to_phys.hpp" @@ -37,16 +38,16 @@ namespace ntt { template - void Metadomain::InitWriter(const SimulationParams& params) { + void Metadomain::InitWriter(adios2::ADIOS* ptr_adios, + const SimulationParams& params) { raise::ErrorIf( - local_subdomain_indices().size() != 1, + l_subdomain_indices().size() != 1, "Output for now is only supported for one subdomain per rank", HERE); - auto local_domain = subdomain_ptr(local_subdomain_indices()[0]); + auto local_domain = subdomain_ptr(l_subdomain_indices()[0]); raise::ErrorIf(local_domain->is_placeholder(), "local_domain is a placeholder", HERE); - const auto incl_ghosts = params.template get("output.debug.ghosts"); auto glob_shape_with_ghosts = mesh().n_active(); @@ -61,9 +62,16 @@ namespace ntt { } } + g_writer.init(ptr_adios, + params.template get("output.format"), + params.template get("simulation.name"), + params.template get("output.separate_files")); g_writer.defineMeshLayout(glob_shape_with_ghosts, off_ncells_with_ghosts, loc_shape_with_ghosts, + { local_domain->index(), ndomains() }, + params.template 
get>( + "output.fields.downsampling"), incl_ghosts, M::CoordType); const auto fields_to_write = params.template get>( @@ -76,21 +84,27 @@ namespace ntt { custom_fields_to_write.begin(), custom_fields_to_write.end(), std::back_inserter(all_fields_to_write)); - const auto species_to_write = params.template get>( + const auto species_to_write = params.template get>( "output.particles.species"); g_writer.defineFieldOutputs(S, all_fields_to_write); - g_writer.defineParticleOutputs(M::PrtlDim, species_to_write); + + Dimension dim = M::PrtlDim; + if constexpr (M::CoordType != Coord::Cart) { + dim = Dim::_3D; + } + g_writer.defineParticleOutputs(dim, species_to_write); + // spectra write all particle species - std::vector spectra_species {}; + std::vector spectra_species {}; for (const auto& sp : species_params()) { spectra_species.push_back(sp.index()); } g_writer.defineSpectraOutputs(spectra_species); for (const auto& type : { "fields", "particles", "spectra" }) { g_writer.addTracker(type, - params.template get( + params.template get( "output." + std::string(type) + ".interval"), - params.template get( + params.template get( "output." 
+ std::string(type) + ".interval_time")); } g_writer.writeAttrs(params); @@ -100,11 +114,11 @@ namespace ntt { void ComputeMoments(const SimulationParams& params, const Mesh& mesh, const std::vector>& prtl_species, - const std::vector& species, + const std::vector& species, const std::vector& components, ndfield_t& buffer, - unsigned short buff_idx) { - std::vector specs = species; + idx_t buff_idx) { + std::vector specs = species; if (specs.size() == 0) { // if no species specified, take all massive species for (auto& sp : prtl_species) { @@ -113,6 +127,11 @@ namespace ntt { } } } + for (const auto& sp : specs) { + raise::ErrorIf((sp > prtl_species.size()) or (sp == 0), + "Invalid species index " + std::to_string(sp), + HERE); + } auto scatter_buff = Kokkos::Experimental::create_scatter_view(buffer); // some parameters @@ -147,7 +166,7 @@ namespace ntt { ndfield_t& fld_to, const range_tuple_t& from, const range_tuple_t& to) { - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { raise::ErrorIf(fld_from.extent(d) != fld_to.extent(d), "Fields have different sizes " + std::to_string(fld_from.extent(d)) + @@ -167,77 +186,249 @@ namespace ntt { } } + template + void ComputeVectorPotential(ndfield_t& buffer, + ndfield_t& EM, + unsigned short buff_idx, + const Mesh mesh) { + if constexpr (M::Dim == Dim::_2D) { + const auto metric = mesh.metric; + Kokkos::parallel_for( + "ComputeVectorPotential", + mesh.rangeActiveCells(), + Lambda(index_t i1, index_t i2) { + const real_t i1_ { COORD(i1) }; + const ncells_t k_min = 0; + const ncells_t k_max = (i2 - (N_GHOSTS)); + real_t A3 = ZERO; + for (auto k { k_min }; k <= k_max; ++k) { + real_t k_ = static_cast(k); + real_t sqrt_detH_ij1 { metric.sqrt_det_h({ i1_, k_ - HALF }) }; + real_t sqrt_detH_ij2 { metric.sqrt_det_h({ i1_, k_ + HALF }) }; + auto k1 { k + N_GHOSTS }; + A3 += HALF * (sqrt_detH_ij1 * EM(i1, k + N_GHOSTS - 1, em::bx1) + + sqrt_detH_ij2 * EM(i1, k + N_GHOSTS, em::bx1)); + } + buffer(i1, i2, 
buff_idx) = A3; + }); + + // @TODO: Implementation with team policies works on AMD, but not on NVIDIA GPUs + // + // using TeamPolicy = Kokkos::TeamPolicy; + // const auto nx1 = mesh.n_active(in::x1); + // const auto nx2 = mesh.n_active(in::x2); + // + // TeamPolicy policy(nx1, Kokkos::AUTO); + // + // Kokkos::parallel_for( + // "ComputeVectorPotential", + // policy, + // Lambda(const TeamPolicy::member_type& team_member) { + // index_t i1 = team_member.league_rank(); + // Kokkos::parallel_scan( + // Kokkos::TeamThreadRange(team_member, nx2), + // [=](index_t i2, real_t& update, const bool final_pass) { + // const auto i1_ { static_cast(i1) }; + // const auto i2_ { static_cast(i2) }; + // const real_t sqrt_detH_ijM { metric.sqrt_det_h({ i1_, i2_ - HALF }) }; + // const real_t sqrt_detH_ijP { metric.sqrt_det_h({ i1_, i2_ + HALF }) }; + // const auto input_val = + // HALF * + // (sqrt_detH_ijM * EM(i1 + N_GHOSTS, i2 + N_GHOSTS - 1, em::bx1) + + // sqrt_detH_ijP * EM(i1 + N_GHOSTS, i2 + N_GHOSTS, em::bx1)); + // if (final_pass) { + // buffer(i1 + N_GHOSTS, i2 + N_GHOSTS, buff_idx) = update; + // } + // update += input_val; + // }); + // }); + } else { + raise::KernelError( + HERE, + "ComputeVectorPotential: 2D implementation called for D != 2"); + } + } + +#if defined(MPI_ENABLED) && defined(OUTPUT_ENABLED) + template + void ExtractVectorPotential(ndfield_t& buffer, + array_t& aphi_r, + unsigned short buff_idx, + const Mesh mesh) { + Kokkos::parallel_for( + "AddVectorPotential", + mesh.rangeActiveCells(), + Lambda(index_t i1, index_t i2) { + buffer(i1, i2, buff_idx) += aphi_r(i1 - N_GHOSTS); + }); + } + + template + void Metadomain::CommunicateVectorPotential(unsigned short buff_idx) { + if constexpr (M::Dim == Dim::_2D) { + auto local_domain = subdomain_ptr(l_subdomain_indices()[0]); + const auto nx1 = local_domain->mesh.n_active(in::x1); + const auto nx2 = local_domain->mesh.n_active(in::x2); + + auto& buffer = local_domain->fields.bckp; + + const auto nranks_x1 = 
ndomains_per_dim()[0]; + const auto nranks_x2 = ndomains_per_dim()[1]; + + for (auto nr2 { 1u }; nr2 < nranks_x2; ++nr2) { + const auto rank_send_pre = (nr2 - 1u) * nranks_x1; + const auto rank_recv_pre = nr2 * nranks_x1; + for (auto nr1 { 0u }; nr1 < nranks_x1; ++nr1) { + const auto rank_send = rank_send_pre + nr1; + const auto rank_recv = rank_recv_pre + nr1; + if (local_domain->mpi_rank() == rank_send) { + array_t aphi_r { "Aphi_r", nx1 }; + Kokkos::deep_copy( + aphi_r, + Kokkos::subview(buffer, + std::make_pair(N_GHOSTS, N_GHOSTS + nx1), + N_GHOSTS + nx2 - 1, + buff_idx)); + MPI_Send(aphi_r.data(), + nx1, + mpi::get_type(), + rank_recv, + 0, + MPI_COMM_WORLD); + } else if (local_domain->mpi_rank() == rank_recv) { + array_t aphi_r { "Aphi_r", nx1 }; + MPI_Recv(aphi_r.data(), + nx1, + mpi::get_type(), + rank_send, + 0, + MPI_COMM_WORLD, + MPI_STATUS_IGNORE); + ExtractVectorPotential(buffer, aphi_r, buff_idx, local_domain->mesh); + } + } + } + } else { + raise::Error("CommunicateVectorPotential: comm vector potential only " + "possible for 2D", + HERE); + } + } +#endif + template auto Metadomain::Write( - const SimulationParams& params, - std::size_t step, - long double time, - std::function< - void(const std::string&, ndfield_t&, std::size_t, const Domain&)> - CustomFieldOutput) -> bool { + const SimulationParams& params, + timestep_t current_step, + timestep_t finished_step, + simtime_t current_time, + simtime_t finished_time, + std::function&, + index_t, + timestep_t, + simtime_t, + const Domain&)> CustomFieldOutput) -> bool { raise::ErrorIf( - local_subdomain_indices().size() != 1, + l_subdomain_indices().size() != 1, "Output for now is only supported for one subdomain per rank", HERE); const auto write_fields = params.template get( "output.fields.enable") and - g_writer.shouldWrite("fields", step, time); + g_writer.shouldWrite("fields", + finished_step, + finished_time); const auto write_particles = params.template get( "output.particles.enable") and - 
g_writer.shouldWrite("particles", step, time); + g_writer.shouldWrite("particles", + finished_step, + finished_time); const auto write_spectra = params.template get( "output.spectra.enable") and - g_writer.shouldWrite("spectra", step, time); - if (not(write_fields or write_particles or write_spectra)) { + g_writer.shouldWrite("spectra", + finished_step, + finished_time); + const auto extension = params.template get("output.format"); + if (not(write_fields or write_particles or write_spectra) and + extension != "disabled") { return false; } - auto local_domain = subdomain_ptr(local_subdomain_indices()[0]); + auto local_domain = subdomain_ptr(l_subdomain_indices()[0]); raise::ErrorIf(local_domain->is_placeholder(), "local_domain is a placeholder", HERE); logger::Checkpoint("Writing output", HERE); - g_writer.beginWriting(params.template get("simulation.name"), - step, - time); - if (write_fields) { + g_writer.beginWriting(WriteMode::Fields, current_step, current_time); const auto incl_ghosts = params.template get("output.debug.ghosts"); + const auto dwn = params.template get>( + "output.fields.downsampling"); + + auto off_ncells_with_ghosts = local_domain->offset_ncells(); + auto loc_shape_with_ghosts = local_domain->mesh.n_active(); + { // compute positions/sizes of meshblocks in cells in all dimensions + const auto off_ndomains = local_domain->offset_ndomains(); + if (incl_ghosts) { + for (auto d { 0 }; d <= M::Dim; ++d) { + off_ncells_with_ghosts[d] += 2 * N_GHOSTS * off_ndomains[d]; + loc_shape_with_ghosts[d] += 2 * N_GHOSTS; + } + } + } + for (auto dim { 0u }; dim < M::Dim; ++dim) { + const auto l_size = local_domain->mesh.n_active()[dim]; + const auto l_offset = local_domain->offset_ncells()[dim]; + const auto g_size = mesh().n_active()[dim]; + + const auto dwn_in_dim = dwn[dim]; + + const double n = l_size; + const double d = dwn_in_dim; + const double l = l_offset; + const double f = math::ceil(l / d) * d - l; + + const auto first_cell = static_cast(f); + 
const auto l_size_dwn = static_cast(math::ceil((n - f) / d)); + + const auto is_last = l_offset + l_size == g_size; + + const auto add_ghost = (incl_ghosts ? 2 * N_GHOSTS : 0); + const auto add_last = (is_last ? 1 : 0); + + array_t xc { "Xc", l_size_dwn + add_ghost }; + array_t xe { "Xe", l_size_dwn + add_ghost + add_last }; + + const auto offset = (incl_ghosts ? N_GHOSTS : 0); + const auto ncells = l_size_dwn; + + const auto& metric = local_domain->mesh.metric; - for (unsigned short dim = 0; dim < M::Dim; ++dim) { - const auto is_last = local_domain->offset_ncells()[dim] + - local_domain->mesh.n_active()[dim] == - mesh().n_active()[dim]; - array_t xc { "Xc", - local_domain->mesh.n_active()[dim] + - (incl_ghosts ? 2 * N_GHOSTS : 0) }; - array_t xe { "Xe", - local_domain->mesh.n_active()[dim] + - (incl_ghosts ? 2 * N_GHOSTS : 0) + - (is_last ? 1 : 0) }; - const auto offset = (incl_ghosts ? N_GHOSTS : 0); - const auto ncells = local_domain->mesh.n_active()[dim]; - const auto& metric = local_domain->mesh.metric; Kokkos::parallel_for( "GenerateMesh", ncells, - Lambda(index_t i) { + Lambda(index_t i_dwn) { + const auto i = first_cell + i_dwn * dwn_in_dim; const auto i_ = static_cast(i); coord_t x_Cd { ZERO }, x_Ph { ZERO }; x_Cd[dim] = i_ + HALF; + // TODO : change to convert by component metric.template convert(x_Cd, x_Ph); - xc(offset + i) = x_Ph[dim]; - x_Cd[dim] = i_; + xc(offset + i_dwn) = x_Ph[dim]; + x_Cd[dim] = i_; metric.template convert(x_Cd, x_Ph); - xe(offset + i) = x_Ph[dim]; - if (is_last && i == ncells - 1) { + xe(offset + i_dwn) = x_Ph[dim]; + if (is_last && i_dwn == ncells - 1) { x_Cd[dim] = i_ + ONE; metric.template convert(x_Cd, x_Ph); - xe(offset + i + 1) = x_Ph[dim]; + xe(offset + i_dwn + 1) = x_Ph[dim]; } }); - g_writer.writeMesh(dim, xc, xe); + g_writer.writeMesh( + dim, + xc, + xe, + { off_ncells_with_ghosts[dim], loc_shape_with_ghosts[dim] }); } - const auto output_asis = params.template get("output.debug.as_is"); // !TODO: this can probably be 
optimized to dump things at once for (auto& fld : g_writer.fieldWriters()) { @@ -250,7 +441,7 @@ namespace ntt { if (fld.is_moment()) { // output a particle distribution moment (single component) // this includes T, Rho, Charge, N, Nppc - const auto c = static_cast(addresses.back()); + const auto c = static_cast(addresses.back()); if (fld.id() == FldsID::T) { raise::ErrorIf(fld.comp.size() != 1, "Wrong # of components requested for T output", @@ -294,19 +485,58 @@ namespace ntt { {}, local_domain->fields.bckp, c); + } else if (fld.id() == FldsID::V) { + if constexpr (S != SimEngine::GRPIC) { + ComputeMoments(params, + local_domain->mesh, + local_domain->species, + fld.species, + fld.comp[0], + local_domain->fields.bckp, + c); + } else { + raise::Error("Bulk velocity not supported for GRPIC", HERE); + } } else { raise::Error("Wrong moment requested for output", HERE); } + } else if (fld.is_divergence()) { + // @TODO: is this correct for GR too? not em0? + const auto c = static_cast(addresses.back()); + Kokkos::parallel_for( + "ComputeDivergence", + local_domain->mesh.rangeActiveCells(), + kernel::ComputeDivergence_kernel(local_domain->mesh.metric, + local_domain->fields.em, + local_domain->fields.bckp, + c)); } else if (fld.is_custom()) { if (CustomFieldOutput) { CustomFieldOutput(fld.name().substr(1), local_domain->fields.bckp, addresses.back(), + finished_step, + finished_time, *local_domain); } else { raise::Error("Custom output requested but no function provided", HERE); } + } else if (fld.is_vpotential()) { + if constexpr (S == SimEngine::GRPIC && M::Dim == Dim::_2D) { + const auto c = static_cast(addresses.back()); + ComputeVectorPotential(local_domain->fields.bckp, + local_domain->fields.em, + c, + local_domain->mesh); +#if defined(MPI_ENABLED) + CommunicateVectorPotential(c); +#endif + } else { + raise::Error( + "Vector potential can only be computed for GRPIC in 2D", + HERE); + } } else { raise::Error("Wrong # of components requested for " 
"non-moment/non-custom output", @@ -322,17 +552,36 @@ namespace ntt { } if (fld.is_moment()) { for (auto i = 0; i < 3; ++i) { - const auto c = static_cast(addresses[i]); - raise::ErrorIf(fld.comp[i].size() != 2, - "Wrong # of components requested for moment", - HERE); - ComputeMoments(params, - local_domain->mesh, - local_domain->species, - fld.species, - fld.comp[i], - local_domain->fields.bckp, - c); + const auto c = static_cast(addresses[i]); + if (fld.id() == FldsID::T) { + raise::ErrorIf(fld.comp[i].size() != 2, + "Wrong # of components requested for moment", + HERE); + ComputeMoments(params, + local_domain->mesh, + local_domain->species, + fld.species, + fld.comp[i], + local_domain->fields.bckp, + c); + } else if (fld.id() == FldsID::V) { + raise::ErrorIf(fld.comp[i].size() != 1, + "Wrong # of components requested for 3vel", + HERE); + if constexpr (S == SimEngine::SRPIC) { + ComputeMoments(params, + local_domain->mesh, + local_domain->species, + fld.species, + fld.comp[i], + local_domain->fields.bckp, + c); + } else { + raise::Error("Bulk velocity not supported for GRPIC", HERE); + } + } else { + raise::Error("Wrong moment requested for output", HERE); + } } raise::ErrorIf(addresses[1] - addresses[0] != addresses[2] - addresses[1], @@ -341,6 +590,28 @@ namespace ntt { SynchronizeFields(*local_domain, Comm::Bckp, { addresses[0], addresses[2] + 1 }); + if constexpr (S == SimEngine::SRPIC) { + if (fld.id() == FldsID::V) { + // normalize 3vel * rho (combuted above) by rho + ComputeMoments(params, + local_domain->mesh, + local_domain->species, + fld.species, + {}, + local_domain->fields.bckp, + 0u); + SynchronizeFields(*local_domain, Comm::Bckp, { 0, 1 }); + Kokkos::parallel_for("NormalizeVectorByRho", + local_domain->mesh.rangeActiveCells(), + kernel::NormalizeVectorByRho_kernel( + local_domain->fields.bckp, + local_domain->fields.bckp, + 0, + addresses[0], + addresses[1], + addresses[2])); + } + } } else { // copy fields to bckp (:, 0, 1, 2) // if as-is 
specified ==> copy directly to 3, 4, 5 @@ -389,8 +660,8 @@ namespace ntt { if (not output_asis) { // copy fields from bckp(:, 0, 1, 2) -> bckp(:, 3, 4, 5) // converting to proper basis and properly interpolating - list_t comp_from = { 0, 1, 2 }; - list_t comp_to = { 3, 4, 5 }; + list_t comp_from = { 0, 1, 2 }; + list_t comp_to = { 3, 4, 5 }; DeepCopyFields(local_domain->fields.bckp, local_domain->fields.bckp, { 0, 3 }, @@ -413,7 +684,7 @@ namespace ntt { for (auto i = 0; i < 6; ++i) { names.push_back(fld.name(i)); addresses.push_back(i); - const auto c = static_cast(addresses.back()); + const auto c = static_cast(addresses.back()); raise::ErrorIf(fld.comp[i].size() != 2, "Wrong # of components requested for moment", HERE); @@ -433,24 +704,24 @@ namespace ntt { } g_writer.writeField(names, local_domain->fields.bckp, addresses); } + g_writer.endWriting(WriteMode::Fields); } // end shouldWrite("fields", step, time) if (write_particles) { - const auto prtl_stride = params.template get( + g_writer.beginWriting(WriteMode::Particles, current_step, current_time); + const auto prtl_stride = params.template get( "output.particles.stride"); for (const auto& prtl : g_writer.speciesWriters()) { auto& species = local_domain->species[prtl.species() - 1]; if (not species.is_sorted()) { - species.SortByTags(); + species.RemoveDead(); } - const std::size_t nout = species.npart() / prtl_stride; - array_t buff_x1, buff_x2, buff_x3; - array_t buff_ux1, buff_ux2, buff_ux3; - array_t buff_wei; - buff_wei = array_t { "w", nout }; - buff_ux1 = array_t { "u1", nout }; - buff_ux2 = array_t { "u2", nout }; - buff_ux3 = array_t { "u3", nout }; + const npart_t nout = species.npart() / prtl_stride; + array_t buff_x1, buff_x2, buff_x3; + array_t buff_ux1 { "u1", nout }; + array_t buff_ux2 { "ux2", nout }; + array_t buff_ux3 { "ux3", nout }; + array_t buff_wei { "w", nout }; if constexpr (M::Dim == Dim::_1D or M::Dim == Dim::_2D or M::Dim == Dim::_3D) { buff_x1 = array_t { "x1", nout }; @@ -478,16 
+749,16 @@ namespace ntt { local_domain->mesh.metric)); // clang-format on } - std::size_t offset = 0; - std::size_t glob_tot = nout; + npart_t offset = 0; + npart_t glob_tot = nout; #if defined(MPI_ENABLED) - auto glob_nout = std::vector(g_ndomains); + auto glob_nout = std::vector(g_ndomains); MPI_Allgather(&nout, 1, - mpi::get_type(), + mpi::get_type(), glob_nout.data(), 1, - mpi::get_type(), + mpi::get_type(), MPI_COMM_WORLD); glob_tot = 0; for (auto r = 0; r < g_mpi_size; ++r) { @@ -513,9 +784,11 @@ namespace ntt { g_writer.writeParticleQuantity(buff_x3, glob_tot, offset, prtl.name("X", 3)); } } + g_writer.endWriting(WriteMode::Particles); } // end shouldWrite("particles", step, time) if (write_spectra) { + g_writer.beginWriting(WriteMode::Spectra, current_step, current_time); const auto log_bins = params.template get( "output.spectra.log_bins"); const auto n_bins = params.template get( @@ -579,19 +852,48 @@ namespace ntt { g_writer.writeSpectrum(dn, spec.name()); } g_writer.writeSpectrumBins(energy, "sEbn"); + g_writer.endWriting(WriteMode::Spectra); } // end shouldWrite("spectra", step, time) - g_writer.endWriting(); return true; } - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; - template struct Metadomain>; +#define METADOMAIN_OUTPUT(S, M) \ + template void Metadomain::InitWriter(adios2::ADIOS*, \ + const SimulationParams&); \ + template auto Metadomain::Write( \ + const SimulationParams&, \ + timestep_t, \ + timestep_t, \ + simtime_t, \ + simtime_t, \ + std::function&, \ + index_t, \ + timestep_t, \ + simtime_t, \ + const Domain&)>) -> bool; + + METADOMAIN_OUTPUT(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_OUTPUT(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_OUTPUT(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_OUTPUT(SimEngine::SRPIC, metric::Spherical) + 
METADOMAIN_OUTPUT(SimEngine::SRPIC, metric::QSpherical) + METADOMAIN_OUTPUT(SimEngine::GRPIC, metric::KerrSchild) + METADOMAIN_OUTPUT(SimEngine::GRPIC, metric::QKerrSchild) + METADOMAIN_OUTPUT(SimEngine::GRPIC, metric::KerrSchild0) + +#undef METADOMAIN_OUTPUT + +#if defined(MPI_ENABLED) + #define COMMVECTORPOTENTIAL(S, M) \ + template void Metadomain::CommunicateVectorPotential(unsigned short); + + COMMVECTORPOTENTIAL(SimEngine::GRPIC, metric::KerrSchild) + COMMVECTORPOTENTIAL(SimEngine::GRPIC, metric::QKerrSchild) + COMMVECTORPOTENTIAL(SimEngine::GRPIC, metric::KerrSchild0) + + #undef COMMVECTORPOTENTIAL +#endif } // namespace ntt diff --git a/src/framework/domain/stats.cpp b/src/framework/domain/stats.cpp new file mode 100644 index 000000000..6c5bb0ffa --- /dev/null +++ b/src/framework/domain/stats.cpp @@ -0,0 +1,296 @@ +#include "enums.h" +#include "global.h" + +#include "utils/comparators.h" +#include "utils/error.h" +#include "utils/log.h" +#include "utils/numeric.h" + +#include "metrics/kerr_schild.h" +#include "metrics/kerr_schild_0.h" +#include "metrics/minkowski.h" +#include "metrics/qkerr_schild.h" +#include "metrics/qspherical.h" +#include "metrics/spherical.h" + +#include "framework/containers/particles.h" +#include "framework/domain/domain.h" +#include "framework/domain/metadomain.h" +#include "framework/parameters.h" + +#include "kernels/reduced_stats.hpp" + +#include +#include +#include + +#include + +namespace ntt { + + template + void Metadomain::InitStatsWriter(const SimulationParams& params, + bool is_resuming) { + raise::ErrorIf( + l_subdomain_indices().size() != 1, + "StatsWriter for now is only supported for one subdomain per rank", + HERE); + auto local_domain = subdomain_ptr(l_subdomain_indices()[0]); + raise::ErrorIf(local_domain->is_placeholder(), + "local_domain is a placeholder", + HERE); + const auto filename = params.template get("simulation.name") + + "_stats.csv"; + const auto enable_stats = params.template 
get("output.stats.enable"); + if (enable_stats and (not is_resuming)) { + CallOnce( + [](auto& filename) { + if (std::filesystem::exists(filename)) { + std::filesystem::remove(filename); + } + }, + filename); + } + const auto stats_to_write = params.template get>( + "output.stats.quantities"); + const auto custom_stats_to_write = params.template get>( + "output.stats.custom"); + g_stats_writer.init( + params.template get("output.stats.interval"), + params.template get("output.stats.interval_time")); + g_stats_writer.defineStatsFilename(filename); + g_stats_writer.defineStatsOutputs(stats_to_write, false); + g_stats_writer.defineStatsOutputs(custom_stats_to_write, true); + + if (not std::filesystem::exists(filename)) { + g_stats_writer.writeHeader(); + } + } + + template + auto ComputeMoments(const SimulationParams& params, + const Mesh& mesh, + const std::vector>& prtl_species, + const std::vector& species, + const std::vector& components) -> real_t { + std::vector specs = species; + if (specs.size() == 0) { + // if no species specified, take all massive species + for (auto& sp : prtl_species) { + if (sp.mass() > 0) { + specs.push_back(sp.index()); + } + } + } + for (const auto& sp : specs) { + raise::ErrorIf((sp > prtl_species.size()) or (sp == 0), + "Invalid species index " + std::to_string(sp), + HERE); + } + // some parameters + const auto use_weights = params.template get("particles.use_weights"); + + real_t buffer = static_cast(0); + for (const auto& sp : specs) { + auto& prtl_spec = prtl_species[sp - 1]; + if (P == StatsID::Charge and cmp::AlmostZero_host(prtl_spec.charge())) { + continue; + } + if (P == StatsID::Rho and cmp::AlmostZero_host(prtl_spec.mass())) { + continue; + } + Kokkos::parallel_reduce( + "ComputeMoments", + prtl_spec.rangeActiveParticles(), + // clang-format off + kernel::ReducedParticleMoments_kernel(components, + prtl_spec.i1, prtl_spec.i2, prtl_spec.i3, + prtl_spec.dx1, prtl_spec.dx2, prtl_spec.dx3, + prtl_spec.ux1, prtl_spec.ux2, 
prtl_spec.ux3, + prtl_spec.phi, prtl_spec.weight, prtl_spec.tag, + prtl_spec.mass(), prtl_spec.charge(), + use_weights, mesh.metric), + // clang-format on + buffer); + } + return buffer; + } + + template + auto ReduceFields(Domain* domain, + const M& global_metric, + const std::vector& components) -> real_t { + auto buffer { ZERO }; + if constexpr (F == StatsID::JdotE) { + if (components.size() == 0) { + Kokkos::parallel_reduce( + "ReduceFields", + domain->mesh.rangeActiveCells(), + kernel::ReducedFields_kernel(domain->fields.em, + domain->fields.cur, + domain->mesh.metric), + buffer); + } else { + raise::Error("Components not supported for JdotE", HERE); + } + } else if constexpr ( + (S == SimEngine::SRPIC) and + (F == StatsID::B2 or F == StatsID::E2 or F == StatsID::ExB)) { + raise::ErrorIf(components.size() != 1, + "Components must be of size 1 for B2, E2 or ExB stats", + HERE); + const auto comp = components[0]; + if (comp == 1) { + Kokkos::parallel_reduce( + "ReduceFields", + domain->mesh.rangeActiveCells(), + kernel::ReducedFields_kernel(domain->fields.em, + domain->fields.cur, + domain->mesh.metric), + buffer); + } else if (comp == 2) { + Kokkos::parallel_reduce( + "ReduceFields", + domain->mesh.rangeActiveCells(), + kernel::ReducedFields_kernel(domain->fields.em, + domain->fields.cur, + domain->mesh.metric), + buffer); + } else if (comp == 3) { + Kokkos::parallel_reduce( + "ReduceFields", + domain->mesh.rangeActiveCells(), + kernel::ReducedFields_kernel(domain->fields.em, + domain->fields.cur, + domain->mesh.metric), + buffer); + } else { + raise::Error( + "Invalid component for B2, E2 or ExB stats: " + std::to_string(comp), + HERE); + } + } else { + raise::Error("ReduceFields not implemented for this stats ID + SimEngine " + "combination", + HERE); + } + + return buffer / global_metric.totVolume(); + } + + template + auto Metadomain::WriteStats( + const SimulationParams& params, + timestep_t current_step, + timestep_t finished_step, + simtime_t 
current_time, + simtime_t finished_time, + std::function&)> + CustomStat) -> bool { + if (not(params.template get("output.stats.enable") and + g_stats_writer.shouldWrite(finished_step, finished_time))) { + return false; + } + auto local_domain = subdomain_ptr(l_subdomain_indices()[0]); + logger::Checkpoint("Writing stats", HERE); + g_stats_writer.write(current_step); + g_stats_writer.write(current_time); + for (const auto& stat : g_stats_writer.statsWriters()) { + if (stat.id() == StatsID::Custom) { + if (CustomStat != nullptr) { + g_stats_writer.write( + CustomStat(stat.name(), finished_step, finished_time, *local_domain)); + } else { + raise::Error("Custom output requested but no function provided", HERE); + } + } else if (stat.id() == StatsID::N) { + g_stats_writer.write(ComputeMoments(params, + local_domain->mesh, + local_domain->species, + stat.species, + {})); + } else if (stat.id() == StatsID::Npart) { + g_stats_writer.write( + ComputeMoments(params, + local_domain->mesh, + local_domain->species, + stat.species, + {})); + } else if (stat.id() == StatsID::Rho) { + g_stats_writer.write( + ComputeMoments(params, + local_domain->mesh, + local_domain->species, + stat.species, + {})); + } else if (stat.id() == StatsID::Charge) { + g_stats_writer.write( + ComputeMoments(params, + local_domain->mesh, + local_domain->species, + stat.species, + {})); + } else if (stat.id() == StatsID::T) { + for (const auto& comp : stat.comp) { + g_stats_writer.write( + ComputeMoments(params, + local_domain->mesh, + local_domain->species, + stat.species, + comp)); + } + } else if (stat.id() == StatsID::JdotE) { + g_stats_writer.write( + ReduceFields(local_domain, g_mesh.metric, {})); + } else if (S == SimEngine::SRPIC) { + if (stat.id() == StatsID::E2) { + for (const auto& comp : stat.comp) { + g_stats_writer.write( + ReduceFields(local_domain, g_mesh.metric, comp)); + } + } else if (stat.id() == StatsID::B2) { + for (const auto& comp : stat.comp) { + g_stats_writer.write( + 
ReduceFields(local_domain, g_mesh.metric, comp)); + } + } else if (stat.id() == StatsID::ExB) { + for (const auto& comp : stat.comp) { + g_stats_writer.write( + ReduceFields(local_domain, g_mesh.metric, comp)); + } + } else { + raise::Error("Unrecognized stats ID " + stat.name(), HERE); + } + } else { + raise::Error("StatsID not implemented for particular SimEngine: " + + std::to_string(static_cast(S)), + HERE); + } + } + g_stats_writer.endWriting(); + return true; + } + +#define METADOMAIN_STATS(S, M) \ + template void Metadomain::InitStatsWriter(const SimulationParams&, bool); \ + template auto Metadomain::WriteStats( \ + const SimulationParams&, \ + timestep_t, \ + timestep_t, \ + simtime_t, \ + simtime_t, \ + std::function< \ + real_t(const std::string&, timestep_t, simtime_t, const Domain&)>) -> bool; + + METADOMAIN_STATS(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_STATS(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_STATS(SimEngine::SRPIC, metric::Minkowski) + METADOMAIN_STATS(SimEngine::SRPIC, metric::Spherical) + METADOMAIN_STATS(SimEngine::SRPIC, metric::QSpherical) + METADOMAIN_STATS(SimEngine::GRPIC, metric::KerrSchild) + METADOMAIN_STATS(SimEngine::GRPIC, metric::QKerrSchild) + METADOMAIN_STATS(SimEngine::GRPIC, metric::KerrSchild0) + +#undef METADOMAIN_STATS + +} // namespace ntt diff --git a/src/framework/parameters.cpp b/src/framework/parameters.cpp index 20c7e83d8..079ad615e 100644 --- a/src/framework/parameters.cpp +++ b/src/framework/parameters.cpp @@ -8,6 +8,7 @@ #include "utils/formatting.h" #include "utils/log.h" #include "utils/numeric.h" +#include "utils/toml.h" #include "metrics/kerr_schild.h" #include "metrics/kerr_schild_0.h" @@ -18,8 +19,6 @@ #include "framework/containers/species.h" -#include - #if defined(MPI_ENABLED) #include #endif @@ -32,29 +31,29 @@ namespace ntt { template - auto get_dx0_V0(const std::vector& resolution, - const boundaries_t& extent, - const std::map& params) - -> std::pair { + auto get_dx0_V0( + const 
std::vector& resolution, + const boundaries_t& extent, + const std::map& params) -> std::pair { const auto metric = M(resolution, extent, params); const auto dx0 = metric.dxMin(); coord_t x_corner { ZERO }; - for (unsigned short d { 0 }; d < (unsigned short)(M::Dim); ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { x_corner[d] = HALF; } const auto V0 = metric.sqrt_det_h(x_corner); return { dx0, V0 }; } - SimulationParams::SimulationParams(const toml::value& raw_data) { + /* + * . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + * Parameters that must not be changed during the checkpoint restart + * . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + */ + void SimulationParams::setImmutableParams(const toml::value& toml_data) { /* [simulation] --------------------------------------------------------- */ - set("simulation.name", toml::find(raw_data, "simulation", "name")); - set("simulation.runtime", - toml::find(raw_data, "simulation", "runtime")); - - const auto engine = fmt::toLower( - toml::find(raw_data, "simulation", "engine")); - const auto engine_enum = SimEngine::pick(engine.c_str()); + const auto engine_enum = SimEngine::pick( + fmt::toLower(toml::find(toml_data, "simulation", "engine")).c_str()); set("simulation.engine", engine_enum); int default_ndomains = 1; @@ -63,7 +62,7 @@ namespace ntt { "MPI_Comm_size failed", HERE); #endif - const auto ndoms = toml::find_or(raw_data, + const auto ndoms = toml::find_or(toml_data, "simulation", "domain", "number", @@ -71,7 +70,7 @@ namespace ntt { set("simulation.domain.number", (unsigned int)ndoms); auto decomposition = toml::find_or>( - raw_data, + toml_data, "simulation", "domain", "decomposition", @@ -79,9 +78,9 @@ namespace ntt { promiseToDefine("simulation.domain.decomposition"); /* [grid] --------------------------------------------------------------- */ - const auto res = toml::find>(raw_data, - "grid", - "resolution"); + const auto res = toml::find>(toml_data, + 
"grid", + "resolution"); raise::ErrorIf(res.size() < 1 || res.size() > 3, "invalid `grid.resolution`", HERE); @@ -98,7 +97,7 @@ namespace ntt { HERE); set("simulation.domain.decomposition", decomposition); - auto extent = toml::find>>(raw_data, + auto extent = toml::find>>(toml_data, "grid", "extent"); raise::ErrorIf(extent.size() < 1 || extent.size() > 3, @@ -107,17 +106,18 @@ namespace ntt { promiseToDefine("grid.extent"); /* [grid.metric] -------------------------------------------------------- */ - const auto metric = fmt::toLower( - toml::find(raw_data, "grid", "metric", "metric")); - const auto metric_enum = Metric::pick(metric.c_str()); + const auto metric_enum = Metric::pick( + fmt::toLower(toml::find(toml_data, "grid", "metric", "metric")) + .c_str()); promiseToDefine("grid.metric.metric"); std::string coord; - if (metric == "minkowski") { + if (metric_enum == Metric::Minkowski) { raise::ErrorIf(engine_enum != SimEngine::SRPIC, "minkowski metric is only supported for SRPIC", HERE); coord = "cart"; - } else if (metric[0] == 'q') { + } else if (metric_enum == Metric::QKerr_Schild or + metric_enum == Metric::QSpherical) { // quasi-spherical geometry raise::ErrorIf(dim == Dim::_1D, "not enough dimensions for qspherical geometry", @@ -127,9 +127,9 @@ namespace ntt { HERE); coord = "qsph"; set("grid.metric.qsph_r0", - toml::find_or(raw_data, "grid", "metric", "qsph_r0", defaults::qsph::r0)); + toml::find_or(toml_data, "grid", "metric", "qsph_r0", defaults::qsph::r0)); set("grid.metric.qsph_h", - toml::find_or(raw_data, "grid", "metric", "qsph_h", defaults::qsph::h)); + toml::find_or(toml_data, "grid", "metric", "qsph_h", defaults::qsph::h)); } else { // spherical geometry raise::ErrorIf(dim == Dim::_1D, @@ -142,7 +142,7 @@ namespace ntt { } if ((engine_enum == SimEngine::GRPIC) && (metric_enum != Metric::Kerr_Schild_0)) { - const auto ks_a = toml::find_or(raw_data, + const auto ks_a = toml::find_or(toml_data, "grid", "metric", "ks_a", @@ -153,9 +153,193 @@ 
namespace ntt { const auto coord_enum = Coord::pick(coord.c_str()); set("grid.metric.coord", coord_enum); + /* [scales] ------------------------------------------------------------- */ + const auto larmor0 = toml::find(toml_data, "scales", "larmor0"); + const auto skindepth0 = toml::find(toml_data, "scales", "skindepth0"); + raise::ErrorIf(larmor0 <= ZERO || skindepth0 <= ZERO, + "larmor0 and skindepth0 must be positive", + HERE); + set("scales.larmor0", larmor0); + set("scales.skindepth0", skindepth0); + promiseToDefine("scales.dx0"); + promiseToDefine("scales.V0"); + promiseToDefine("scales.n0"); + promiseToDefine("scales.q0"); + set("scales.sigma0", SQR(skindepth0 / larmor0)); + set("scales.B0", ONE / larmor0); + set("scales.omegaB0", ONE / larmor0); + + /* [particles] ---------------------------------------------------------- */ + const auto ppc0 = toml::find(toml_data, "particles", "ppc0"); + set("particles.ppc0", ppc0); + raise::ErrorIf(ppc0 <= 0.0, "ppc0 must be positive", HERE); + set("particles.use_weights", + toml::find_or(toml_data, "particles", "use_weights", false)); + + /* [particles.species] -------------------------------------------------- */ + std::vector species; + const auto species_tab = toml::find_or(toml_data, + "particles", + "species", + toml::array {}); + set("particles.nspec", species_tab.size()); + + spidx_t idx = 1; + for (const auto& sp : species_tab) { + const auto label = toml::find_or(sp, + "label", + "s" + std::to_string(idx)); + const auto mass = toml::find(sp, "mass"); + const auto charge = toml::find(sp, "charge"); + raise::ErrorIf((charge != 0.0f) && (mass == 0.0f), + "mass of the charged species must be non-zero", + HERE); + const auto is_massless = (mass == 0.0f) && (charge == 0.0f); + const auto def_pusher = (is_massless ? 
defaults::ph_pusher + : defaults::em_pusher); + const auto maxnpart_real = toml::find(sp, "maxnpart"); + const auto maxnpart = static_cast(maxnpart_real); + auto pusher = toml::find_or(sp, "pusher", std::string(def_pusher)); + const auto npayloads = toml::find_or(sp, + "n_payloads", + static_cast(0)); + const auto cooling = toml::find_or(sp, "cooling", std::string("None")); + raise::ErrorIf((fmt::toLower(cooling) != "none") && is_massless, + "cooling is only applicable to massive particles", + HERE); + raise::ErrorIf((fmt::toLower(pusher) == "photon") && !is_massless, + "photon pusher is only applicable to massless particles", + HERE); + bool use_gca = false; + if (pusher.find(',') != std::string::npos) { + raise::ErrorIf(fmt::toLower(pusher.substr(pusher.find(',') + 1, + pusher.size())) != "gca", + "invalid pusher syntax", + HERE); + use_gca = true; + pusher = pusher.substr(0, pusher.find(',')); + } + const auto pusher_enum = PrtlPusher::pick(pusher.c_str()); + const auto cooling_enum = Cooling::pick(cooling.c_str()); + if (use_gca) { + raise::ErrorIf(engine_enum != SimEngine::SRPIC, + "GCA pushers are only supported for SRPIC", + HERE); + promiseToDefine("algorithms.gca.e_ovr_b_max"); + promiseToDefine("algorithms.gca.larmor_max"); + } + if (cooling_enum == Cooling::SYNCHROTRON) { + raise::ErrorIf(engine_enum != SimEngine::SRPIC, + "Synchrotron cooling is only supported for SRPIC", + HERE); + promiseToDefine("algorithms.synchrotron.gamma_rad"); + } + + species.emplace_back(ParticleSpecies(idx, + label, + mass, + charge, + maxnpart, + pusher_enum, + use_gca, + cooling_enum, + npayloads)); + idx += 1; + } + set("particles.species", species); + + /* inferred variables --------------------------------------------------- */ + // extent + if (extent.size() > dim) { + extent.erase(extent.begin() + (std::size_t)(dim), extent.end()); + } + raise::ErrorIf(extent[0].size() != 2, "invalid `grid.extent[0]`", HERE); + if (coord_enum != Coord::Cart) { + 
raise::ErrorIf(extent.size() > 1, + "invalid `grid.extent` for non-cartesian geometry", + HERE); + extent.push_back({ ZERO, constant::PI }); + if (dim == Dim::_3D) { + extent.push_back({ ZERO, TWO * constant::PI }); + } + } + raise::ErrorIf(extent.size() != dim, "invalid inferred `grid.extent`", HERE); + boundaries_t extent_pairwise; + for (auto d { 0u }; d < (dim_t)dim; ++d) { + raise::ErrorIf(extent[d].size() != 2, + fmt::format("invalid inferred `grid.extent[%d]`", d), + HERE); + extent_pairwise.push_back({ extent[d][0], extent[d][1] }); + } + set("grid.extent", extent_pairwise); + + // metric, dx0, V0, n0, q0 + { + boundaries_t ext; + for (const auto& e : extent) { + ext.push_back({ e[0], e[1] }); + } + std::map params; + if (coord_enum == Coord::Qsph) { + params["r0"] = get("grid.metric.qsph_r0"); + params["h"] = get("grid.metric.qsph_h"); + } + if ((engine_enum == SimEngine::GRPIC) && + (metric_enum != Metric::Kerr_Schild_0)) { + params["a"] = get("grid.metric.ks_a"); + } + set("grid.metric.params", params); + + std::pair dx0_V0; + if (metric_enum == Metric::Minkowski) { + if (dim == Dim::_1D) { + dx0_V0 = get_dx0_V0>(res, ext, params); + } else if (dim == Dim::_2D) { + dx0_V0 = get_dx0_V0>(res, ext, params); + } else { + dx0_V0 = get_dx0_V0>(res, ext, params); + } + } else if (metric_enum == Metric::Spherical) { + dx0_V0 = get_dx0_V0>(res, ext, params); + } else if (metric_enum == Metric::QSpherical) { + dx0_V0 = get_dx0_V0>(res, ext, params); + } else if (metric_enum == Metric::Kerr_Schild) { + dx0_V0 = get_dx0_V0>(res, ext, params); + } else if (metric_enum == Metric::Kerr_Schild_0) { + dx0_V0 = get_dx0_V0>(res, ext, params); + } else if (metric_enum == Metric::QKerr_Schild) { + dx0_V0 = get_dx0_V0>(res, ext, params); + } + auto [dx0, V0] = dx0_V0; + set("scales.dx0", dx0); + set("scales.V0", V0); + set("scales.n0", ppc0 / V0); + set("scales.q0", V0 / (ppc0 * SQR(skindepth0))); + + set("grid.metric.metric", metric_enum); + } + } + + /* + * . . . . . . . . 
. . . . . . . . . . . . . . . . . . . . . . . . . + * Parameters that may be changed during the checkpoint restart + * . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + */ + void SimulationParams::setMutableParams(const toml::value& toml_data) { + const auto engine_enum = get("simulation.engine"); + const auto coord_enum = get("grid.metric.coord"); + const auto dim = get("grid.dim"); + const auto extent_pairwise = get>("grid.extent"); + + /* [simulation] --------------------------------------------------------- */ + set("simulation.name", + toml::find(toml_data, "simulation", "name")); + set("simulation.runtime", + toml::find(toml_data, "simulation", "runtime")); + /* [grid.boundaraies] --------------------------------------------------- */ auto flds_bc = toml::find>>( - raw_data, + toml_data, "grid", "boundaries", "fields"); @@ -167,9 +351,8 @@ namespace ntt { auto atm_defined = false; for (const auto& bcs : flds_bc) { for (const auto& bc : bcs) { - if (fmt::toLower(bc) == "absorb") { - promiseToDefine("grid.boundaries.absorb.ds"); - promiseToDefine("grid.boundaries.absorb.coeff"); + if (fmt::toLower(bc) == "match") { + promiseToDefine("grid.boundaries.match.ds"); } if (fmt::toLower(bc) == "atmosphere") { raise::ErrorIf(atm_defined, @@ -188,7 +371,7 @@ namespace ntt { } auto prtl_bc = toml::find>>( - raw_data, + toml_data, "grid", "boundaries", "particles"); @@ -202,7 +385,6 @@ namespace ntt { for (const auto& bc : bcs) { if (fmt::toLower(bc) == "absorb") { promiseToDefine("grid.boundaries.absorb.ds"); - promiseToDefine("grid.boundaries.absorb.coeff"); } if (fmt::toLower(bc) == "atmosphere") { raise::ErrorIf(atm_defined, @@ -220,41 +402,26 @@ namespace ntt { } } - /* [scales] ------------------------------------------------------------- */ - const auto larmor0 = toml::find(raw_data, "scales", "larmor0"); - const auto skindepth0 = toml::find(raw_data, "scales", "skindepth0"); - raise::ErrorIf(larmor0 <= ZERO || skindepth0 <= ZERO, - "larmor0 and 
skindepth0 must be positive", - HERE); - set("scales.larmor0", larmor0); - set("scales.skindepth0", skindepth0); - promiseToDefine("scales.dx0"); - promiseToDefine("scales.V0"); - promiseToDefine("scales.n0"); - promiseToDefine("scales.q0"); - set("scales.sigma0", SQR(skindepth0 / larmor0)); - set("scales.B0", ONE / larmor0); - set("scales.omegaB0", ONE / larmor0); - /* [algorithms] --------------------------------------------------------- */ set("algorithms.current_filters", - toml::find_or(raw_data, + toml::find_or(toml_data, "algorithms", "current_filters", defaults::current_filters)); /* [algorithms.toggles] ------------------------------------------------- */ set("algorithms.toggles.fieldsolver", - toml::find_or(raw_data, "algorithms", "toggles", "fieldsolver", true)); + toml::find_or(toml_data, "algorithms", "toggles", "fieldsolver", true)); set("algorithms.toggles.deposit", - toml::find_or(raw_data, "algorithms", "toggles", "deposit", true)); + toml::find_or(toml_data, "algorithms", "toggles", "deposit", true)); /* [algorithms.timestep] ------------------------------------------------ */ set("algorithms.timestep.CFL", - toml::find_or(raw_data, "algorithms", "timestep", "CFL", defaults::cfl)); - promiseToDefine("algorithms.timestep.dt"); + toml::find_or(toml_data, "algorithms", "timestep", "CFL", defaults::cfl)); + set("algorithms.timestep.dt", + get("algorithms.timestep.CFL") * get("scales.dx0")); set("algorithms.timestep.correction", - toml::find_or(raw_data, + toml::find_or(toml_data, "algorithms", "timestep", "correction", @@ -263,132 +430,52 @@ namespace ntt { /* [algorithms.gr] ------------------------------------------------------ */ if (engine_enum == SimEngine::GRPIC) { set("algorithms.gr.pusher_eps", - toml::find_or(raw_data, + toml::find_or(toml_data, "algorithms", "gr", "pusher_eps", defaults::gr::pusher_eps)); set("algorithms.gr.pusher_niter", - toml::find_or(raw_data, + toml::find_or(toml_data, "algorithms", "gr", "pusher_niter", 
defaults::gr::pusher_niter)); } - /* [particles] ---------------------------------------------------------- */ - const auto ppc0 = toml::find(raw_data, "particles", "ppc0"); - set("particles.ppc0", ppc0); - raise::ErrorIf(ppc0 <= 0.0, "ppc0 must be positive", HERE); - set("particles.use_weights", - toml::find_or(raw_data, "particles", "use_weights", false)); - -#if defined(MPI_ENABLED) - const std::size_t sort_interval = 1; -#else - const std::size_t sort_interval = toml::find_or(raw_data, - "particles", - "sort_interval", - defaults::sort_interval); -#endif - set("particles.sort_interval", sort_interval); - - /* [particles.species] -------------------------------------------------- */ - std::vector species; - const auto species_tab = toml::find_or(raw_data, - "particles", - "species", - toml::array {}); - set("particles.nspec", species_tab.size()); - - unsigned short idx = 1; - for (const auto& sp : species_tab) { - const auto label = toml::find_or(sp, - "label", - "s" + std::to_string(idx)); - const auto mass = toml::find(sp, "mass"); - const auto charge = toml::find(sp, "charge"); - raise::ErrorIf((charge != 0.0f) && (mass == 0.0f), - "mass of the charged species must be non-zero", - HERE); - const auto is_massless = (mass == 0.0f) && (charge == 0.0f); - const auto def_pusher = (is_massless ? 
defaults::ph_pusher - : defaults::em_pusher); - const auto maxnpart_real = toml::find(sp, "maxnpart"); - const auto maxnpart = static_cast(maxnpart_real); - auto pusher = toml::find_or(sp, "pusher", std::string(def_pusher)); - const auto npayloads = toml::find_or(sp, - "n_payloads", - static_cast(0)); - const auto cooling = toml::find_or(sp, "cooling", std::string("None")); - raise::ErrorIf((fmt::toLower(cooling) != "none") && is_massless, - "cooling is only applicable to massive particles", - HERE); - raise::ErrorIf((fmt::toLower(pusher) == "photon") && !is_massless, - "photon pusher is only applicable to massless particles", - HERE); - bool use_gca = false; - if (pusher.find(',') != std::string::npos) { - raise::ErrorIf(fmt::toLower(pusher.substr(pusher.find(',') + 1, - pusher.size())) != "gca", - "invalid pusher syntax", - HERE); - use_gca = true; - pusher = pusher.substr(0, pusher.find(',')); - } - const auto pusher_enum = PrtlPusher::pick(pusher.c_str()); - const auto cooling_enum = Cooling::pick(cooling.c_str()); - if (use_gca) { - raise::ErrorIf(engine_enum != SimEngine::SRPIC, - "GCA pushers are only supported for SRPIC", - HERE); - promiseToDefine("algorithms.gca.e_ovr_b_max"); - promiseToDefine("algorithms.gca.larmor_max"); - } - if (cooling_enum == Cooling::SYNCHROTRON) { - raise::ErrorIf(engine_enum != SimEngine::SRPIC, - "Synchrotron cooling is only supported for SRPIC", - HERE); - promiseToDefine("algorithms.synchrotron.gamma_rad"); - } - - species.emplace_back(ParticleSpecies(idx, - label, - mass, - charge, - maxnpart, - pusher_enum, - use_gca, - cooling_enum, - npayloads)); - idx += 1; - } - set("particles.species", species); + set("particles.clear_interval", + toml::find_or(toml_data, "particles", "clear_interval", defaults::clear_interval)); /* [output] ------------------------------------------------------------- */ // fields set("output.format", - toml::find_or(raw_data, "output", "format", defaults::output::format)); + toml::find_or(toml_data, 
"output", "format", defaults::output::format)); set("output.interval", - toml::find_or(raw_data, "output", "interval", defaults::output::interval)); + toml::find_or(toml_data, "output", "interval", defaults::output::interval)); set("output.interval_time", - toml::find_or(raw_data, "output", "interval_time", -1.0)); + toml::find_or(toml_data, "output", "interval_time", -1.0)); + set("output.separate_files", + toml::find_or(toml_data, "output", "separate_files", true)); + + promiseToDefine("output.fields.enable"); promiseToDefine("output.fields.interval"); promiseToDefine("output.fields.interval_time"); - promiseToDefine("output.fields.enable"); + promiseToDefine("output.particles.enable"); promiseToDefine("output.particles.interval"); promiseToDefine("output.particles.interval_time"); - promiseToDefine("output.particles.enable"); + promiseToDefine("output.spectra.enable"); promiseToDefine("output.spectra.interval"); promiseToDefine("output.spectra.interval_time"); - promiseToDefine("output.spectra.enable"); + promiseToDefine("output.stats.enable"); + promiseToDefine("output.stats.interval"); + promiseToDefine("output.stats.interval_time"); - const auto flds_out = toml::find_or(raw_data, + const auto flds_out = toml::find_or(toml_data, "output", "fields", "quantities", std::vector {}); - const auto custom_flds_out = toml::find_or(raw_data, + const auto custom_flds_out = toml::find_or(toml_data, "output", "fields", "custom", @@ -399,28 +486,58 @@ namespace ntt { set("output.fields.quantities", flds_out); set("output.fields.custom", custom_flds_out); set("output.fields.mom_smooth", - toml::find_or(raw_data, + toml::find_or(toml_data, "output", "fields", "mom_smooth", defaults::output::mom_smooth)); - set("output.fields.stride", - toml::find_or(raw_data, "output", "fields", "stride", defaults::output::flds_stride)); + std::vector field_dwn; + try { + auto field_dwn_ = toml::find>(toml_data, + "output", + "fields", + "downsampling"); + for (auto i = 0u; i < 
field_dwn_.size(); ++i) { + field_dwn.push_back(field_dwn_[i]); + } + } catch (...) { + try { + auto field_dwn_ = toml::find(toml_data, + "output", + "fields", + "downsampling"); + for (auto i = 0u; i < dim; ++i) { + field_dwn.push_back(field_dwn_); + } + } catch (...) { + for (auto i = 0u; i < dim; ++i) { + field_dwn.push_back(1u); + } + } + } + raise::ErrorIf(field_dwn.size() > 3, "invalid `output.fields.downsampling`", HERE); + if (field_dwn.size() > dim) { + field_dwn.erase(field_dwn.begin() + (std::size_t)(dim), field_dwn.end()); + } + for (const auto& dwn : field_dwn) { + raise::ErrorIf(dwn == 0, "downsampling factor must be nonzero", HERE); + } + set("output.fields.downsampling", field_dwn); // particles - auto prtl_out = toml::find_or(raw_data, - "output", - "particles", - "species", - std::vector {}); - if (prtl_out.size() == 0) { - for (unsigned short i = 0; i < species.size(); ++i) { - prtl_out.push_back(i + 1); - } + auto all_specs = std::vector {}; + const auto nspec = get("particles.nspec"); + for (auto i = 0u; i < nspec; ++i) { + all_specs.push_back(static_cast(i + 1)); } + const auto prtl_out = toml::find_or(toml_data, + "output", + "particles", + "species", + all_specs); set("output.particles.species", prtl_out); set("output.particles.stride", - toml::find_or(raw_data, + toml::find_or(toml_data, "output", "particles", "stride", @@ -428,37 +545,55 @@ namespace ntt { // spectra set("output.spectra.e_min", - toml::find_or(raw_data, "output", "spectra", "e_min", defaults::output::spec_emin)); + toml::find_or(toml_data, "output", "spectra", "e_min", defaults::output::spec_emin)); set("output.spectra.e_max", - toml::find_or(raw_data, "output", "spectra", "e_max", defaults::output::spec_emax)); + toml::find_or(toml_data, "output", "spectra", "e_max", defaults::output::spec_emax)); set("output.spectra.log_bins", - toml::find_or(raw_data, + toml::find_or(toml_data, "output", "spectra", "log_bins", defaults::output::spec_log)); set("output.spectra.n_bins", - 
toml::find_or(raw_data, "output", "spectra", "n_bins", defaults::output::spec_nbins)); + toml::find_or(toml_data, + "output", + "spectra", + "n_bins", + defaults::output::spec_nbins)); + + // stats + set("output.stats.quantities", + toml::find_or(toml_data, + "output", + "stats", + "quantities", + defaults::output::stats_quantities)); + set("output.stats.custom", + toml::find_or(toml_data, + "output", + "stats", + "custom", + std::vector {})); // intervals - for (const auto& type : { "fields", "particles", "spectra" }) { - const auto q_int = toml::find_or(raw_data, - "output", - std::string(type), - "interval", - 0); - const auto q_int_time = toml::find_or(raw_data, - "output", - std::string(type), - "interval_time", - -1.0); + for (const auto& type : { "fields", "particles", "spectra", "stats" }) { + const auto q_int = toml::find_or(toml_data, + "output", + std::string(type), + "interval", + 0); + const auto q_int_time = toml::find_or(toml_data, + "output", + std::string(type), + "interval_time", + -1.0); set("output." + std::string(type) + ".enable", - toml::find_or(raw_data, "output", std::string(type), "enable", true)); - if (q_int == 0 && q_int_time == -1.0) { + toml::find_or(toml_data, "output", std::string(type), "enable", true)); + if ((q_int == 0) and (q_int_time == -1.0)) { set("output." + std::string(type) + ".interval", - get("output.interval")); + get("output.interval")); set("output." + std::string(type) + ".interval_time", - get("output.interval_time")); + get("output.interval_time")); } else { set("output." + std::string(type) + ".interval", q_int); set("output." 
+ std::string(type) + ".interval_time", q_int_time); @@ -467,43 +602,62 @@ namespace ntt { /* [output.debug] ------------------------------------------------------- */ set("output.debug.as_is", - toml::find_or(raw_data, "output", "debug", "as_is", false)); - set("output.debug.ghosts", - toml::find_or(raw_data, "output", "debug", "ghosts", false)); + toml::find_or(toml_data, "output", "debug", "as_is", false)); + const auto output_ghosts = toml::find_or(toml_data, + "output", + "debug", + "ghosts", + false); + set("output.debug.ghosts", output_ghosts); + if (output_ghosts) { + for (const auto& dwn : field_dwn) { + raise::ErrorIf( + dwn != 1, + "full resolution required when outputting with ghost cells", + HERE); + } + } + + /* [checkpoint] --------------------------------------------------------- */ + set("checkpoint.interval", + toml::find_or(toml_data, + "checkpoint", + "interval", + defaults::checkpoint::interval)); + set("checkpoint.interval_time", + toml::find_or(toml_data, "checkpoint", "interval_time", -1.0)); + set("checkpoint.keep", + toml::find_or(toml_data, "checkpoint", "keep", defaults::checkpoint::keep)); + auto walltime_str = toml::find_or(toml_data, + "checkpoint", + "walltime", + defaults::checkpoint::walltime); + if (walltime_str.empty()) { + walltime_str = defaults::checkpoint::walltime; + } + set("checkpoint.walltime", walltime_str); + + const auto checkpoint_write_path = toml::find_or( + toml_data, + "checkpoint", + "write_path", + fmt::format(defaults::checkpoint::write_path.c_str(), + get("simulation.name").c_str())); + set("checkpoint.write_path", checkpoint_write_path); + set("checkpoint.read_path", + toml::find_or(toml_data, "checkpoint", "read_path", checkpoint_write_path)); /* [diagnostics] -------------------------------------------------------- */ set("diagnostics.interval", - toml::find_or(raw_data, "diagnostics", "interval", defaults::diag::interval)); + toml::find_or(toml_data, "diagnostics", "interval", defaults::diag::interval)); 
set("diagnostics.blocking_timers", - toml::find_or(raw_data, "diagnostics", "blocking_timers", false)); + toml::find_or(toml_data, "diagnostics", "blocking_timers", false)); set("diagnostics.colored_stdout", - toml::find_or(raw_data, "diagnostics", "colored_stdout", false)); + toml::find_or(toml_data, "diagnostics", "colored_stdout", false)); + set("diagnostics.log_level", + toml::find_or(toml_data, "diagnostics", "log_level", defaults::diag::log_level)); /* inferred variables --------------------------------------------------- */ - // extent - if (extent.size() > dim) { - extent.erase(extent.begin() + (std::size_t)(dim), extent.end()); - } - raise::ErrorIf(extent[0].size() != 2, "invalid `grid.extent[0]`", HERE); - if (coord_enum != Coord::Cart) { - raise::ErrorIf(extent.size() > 1, - "invalid `grid.extent` for non-cartesian geometry", - HERE); - extent.push_back({ ZERO, constant::PI }); - if (dim == Dim::_3D) { - extent.push_back({ ZERO, TWO * constant::PI }); - } - } - raise::ErrorIf(extent.size() != dim, "invalid inferred `grid.extent`", HERE); - boundaries_t extent_parwise; - for (unsigned short d = 0; d < (unsigned short)dim; ++d) { - raise::ErrorIf(extent[d].size() != 2, - fmt::format("invalid inferred `grid.extent[%d]`", d), - HERE); - extent_parwise.push_back({ extent[d][0], extent[d][1] }); - } - set("grid.extent", extent_parwise); - // fields/particle boundaries std::vector> flds_bc_enum; std::vector> prtl_bc_enum; @@ -514,7 +668,7 @@ namespace ntt { raise::ErrorIf(prtl_bc.size() != (std::size_t)dim, "invalid `grid.boundaries.particles`", HERE); - for (unsigned short d = 0; d < (unsigned short)dim; ++d) { + for (auto d { 0u }; d < (dim_t)dim; ++d) { flds_bc_enum.push_back({}); prtl_bc_enum.push_back({}); const auto fbc = flds_bc[d]; @@ -616,7 +770,7 @@ namespace ntt { HERE); boundaries_t flds_bc_pairwise; boundaries_t prtl_bc_pairwise; - for (unsigned short d = 0; d < (unsigned short)dim; ++d) { + for (auto d { 0u }; d < (dim_t)dim; ++d) { 
raise::ErrorIf( flds_bc_enum[d].size() != 2, fmt::format("invalid inferred `grid.boundaries.fields[%d]`", d), @@ -631,58 +785,109 @@ namespace ntt { set("grid.boundaries.fields", flds_bc_pairwise); set("grid.boundaries.particles", prtl_bc_pairwise); + if (isPromised("grid.boundaries.match.ds")) { + if (coord_enum == Coord::Cart) { + auto min_extent = std::numeric_limits::max(); + for (const auto& e : extent_pairwise) { + min_extent = std::min(min_extent, e.second - e.first); + } + const auto default_ds = min_extent * defaults::bc::match::ds_frac; + boundaries_t ds_array; + try { + auto ds = toml::find(toml_data, "grid", "boundaries", "match", "ds"); + for (auto d = 0u; d < dim; ++d) { + ds_array.push_back({ ds, ds }); + } + } catch (...) { + try { + const auto ds = toml::find>>( + toml_data, + "grid", + "boundaries", + "match", + "ds"); + raise::ErrorIf(ds.size() != dim, + "invalid # in `grid.boundaries.match.ds`", + HERE); + for (auto d = 0u; d < dim; ++d) { + if (ds[d].size() == 1) { + ds_array.push_back({ ds[d][0], ds[d][0] }); + } else if (ds[d].size() == 2) { + ds_array.push_back({ ds[d][0], ds[d][1] }); + } else if (ds[d].size() == 0) { + ds_array.push_back({}); + } else { + raise::Error("invalid `grid.boundaries.match.ds`", HERE); + } + } + } catch (...) 
{ + for (auto d = 0u; d < dim; ++d) { + ds_array.push_back({ default_ds, default_ds }); + } + } + } + set("grid.boundaries.match.ds", ds_array); + } else { + auto r_extent = extent_pairwise[0].second - extent_pairwise[0].first; + const auto ds = toml::find_or( + toml_data, + "grid", + "boundaries", + "match", + "ds", + r_extent * defaults::bc::match::ds_frac); + boundaries_t ds_array { + { ds, ds } + }; + set("grid.boundaries.match.ds", ds_array); + } + } + if (isPromised("grid.boundaries.absorb.ds")) { if (coord_enum == Coord::Cart) { auto min_extent = std::numeric_limits::max(); - for (const auto& e : extent) { - min_extent = std::min(min_extent, e[1] - e[0]); + for (const auto& e : extent_pairwise) { + min_extent = std::min(min_extent, e.second - e.first); } set("grid.boundaries.absorb.ds", - toml::find_or(raw_data, + toml::find_or(toml_data, "grid", "boundaries", "absorb", "ds", min_extent * defaults::bc::absorb::ds_frac)); } else { - auto r_extent = extent[0][1] - extent[0][0]; + auto r_extent = extent_pairwise[0].second - extent_pairwise[0].first; set("grid.boundaries.absorb.ds", - toml::find_or(raw_data, + toml::find_or(toml_data, "grid", "boundaries", "absorb", "ds", r_extent * defaults::bc::absorb::ds_frac)); } - set("grid.boundaries.absorb.coeff", - toml::find_or(raw_data, - "grid", - "boundaries", - "absorb", - "coeff", - defaults::bc::absorb::coeff)); } if (isPromised("grid.boundaries.atmosphere.temperature")) { - const auto atm_T = toml::find(raw_data, + const auto atm_T = toml::find(toml_data, "grid", "boundaries", "atmosphere", "temperature"); - const auto atm_h = toml::find(raw_data, + const auto atm_h = toml::find(toml_data, "grid", "boundaries", "atmosphere", "height"); set("grid.boundaries.atmosphere.temperature", atm_T); set("grid.boundaries.atmosphere.density", - toml::find(raw_data, "grid", "boundaries", "atmosphere", "density")); + toml::find(toml_data, "grid", "boundaries", "atmosphere", "density")); set("grid.boundaries.atmosphere.ds", - 
toml::find_or(raw_data, "grid", "boundaries", "atmosphere", "ds", ZERO)); + toml::find_or(toml_data, "grid", "boundaries", "atmosphere", "ds", ZERO)); set("grid.boundaries.atmosphere.height", atm_h); set("grid.boundaries.atmosphere.g", atm_T / atm_h); - const auto atm_species = toml::find>( - raw_data, + const auto atm_species = toml::find>( + toml_data, "grid", "boundaries", "atmosphere", @@ -693,78 +898,34 @@ namespace ntt { // gca if (isPromised("algorithms.gca.e_ovr_b_max")) { set("algorithms.gca.e_ovr_b_max", - toml::find_or(raw_data, + toml::find_or(toml_data, "algorithms", "gca", "e_ovr_b_max", defaults::gca::EovrB_max)); set("algorithms.gca.larmor_max", - toml::find_or(raw_data, "algorithms", "gca", "larmor_max", ZERO)); + toml::find_or(toml_data, "algorithms", "gca", "larmor_max", ZERO)); } // cooling if (isPromised("algorithms.synchrotron.gamma_rad")) { set("algorithms.synchrotron.gamma_rad", - toml::find_or(raw_data, + toml::find_or(toml_data, "algorithms", "synchrotron", "gamma_rad", defaults::synchrotron::gamma_rad)); } - // metric, dx0, V0, n0, q0 - { - boundaries_t ext; - for (const auto& e : extent) { - ext.push_back({ e[0], e[1] }); - } - std::map params; - if (coord_enum == Coord::Qsph) { - params["r0"] = get("grid.metric.qsph_r0"); - params["h"] = get("grid.metric.qsph_h"); - } - if ((engine_enum == SimEngine::GRPIC) && - (metric_enum != Metric::Kerr_Schild_0)) { - params["a"] = get("grid.metric.ks_a"); - } - set("grid.metric.params", params); - - std::pair dx0_V0; - if (metric_enum == Metric::Minkowski) { - if (dim == Dim::_1D) { - dx0_V0 = get_dx0_V0>(res, ext, params); - } else if (dim == Dim::_2D) { - dx0_V0 = get_dx0_V0>(res, ext, params); - } else { - dx0_V0 = get_dx0_V0>(res, ext, params); - } - } else if (metric_enum == Metric::Spherical) { - dx0_V0 = get_dx0_V0>(res, ext, params); - } else if (metric_enum == Metric::QSpherical) { - dx0_V0 = get_dx0_V0>(res, ext, params); - } else if (metric_enum == Metric::Kerr_Schild) { - dx0_V0 = 
get_dx0_V0>(res, ext, params); - } else if (metric_enum == Metric::Kerr_Schild_0) { - dx0_V0 = get_dx0_V0>(res, ext, params); - } else if (metric_enum == Metric::QKerr_Schild) { - dx0_V0 = get_dx0_V0>(res, ext, params); - } - auto [dx0, V0] = dx0_V0; - set("scales.dx0", dx0); - set("scales.V0", V0); - set("scales.n0", ppc0 / V0); - set("scales.q0", V0 / (ppc0 * SQR(skindepth0))); - - set("grid.metric.metric", metric_enum); - set("algorithms.timestep.dt", get("algorithms.timestep.CFL") * dx0); + // @TODO: disabling stats for non-Cartesian + if (coord_enum != Coord::Cart) { + set("output.stats.enable", false); } + } - raise::ErrorIf(!promisesFulfilled(), - "Have not defined all the necessary variables", - HERE); - + void SimulationParams::setSetupParams(const toml::value& toml_data) { /* [setup] -------------------------------------------------------------- */ - const auto& setup = toml::find_or(raw_data, "setup", toml::table {}); + const auto setup = toml::find_or(toml_data, "setup", toml::table {}); for (const auto& [key, val] : setup) { if (val.is_boolean()) { set("setup." 
+ key, (bool)(val.as_boolean())); @@ -812,4 +973,18 @@ namespace ntt { } } } + + void SimulationParams::setCheckpointParams(bool is_resuming, + timestep_t start_step, + simtime_t start_time) { + set("checkpoint.is_resuming", is_resuming); + set("checkpoint.start_step", start_step); + set("checkpoint.start_time", start_time); + } + + void SimulationParams::checkPromises() const { + raise::ErrorIf(!promisesFulfilled(), + "Have not defined all the necessary variables", + HERE); + } } // namespace ntt diff --git a/src/framework/parameters.h b/src/framework/parameters.h index 301f7053f..723e860db 100644 --- a/src/framework/parameters.h +++ b/src/framework/parameters.h @@ -18,22 +18,42 @@ #define FRAMEWORK_PARAMETERS_H #include "utils/param_container.h" - -#include +#include "utils/toml.h" namespace ntt { struct SimulationParams : public prm::Parameters { - SimulationParams() = default; - SimulationParams(const toml::value&); + + SimulationParams() {} + + SimulationParams(const SimulationParams&) = default; SimulationParams& operator=(const SimulationParams& other) { vars = std::move(other.vars); promises = std::move(other.promises); + raw_data = std::move(other.raw_data); return *this; } ~SimulationParams() = default; + + void setImmutableParams(const toml::value&); + void setMutableParams(const toml::value&); + void setCheckpointParams(bool, timestep_t, simtime_t); + void setSetupParams(const toml::value&); + void checkPromises() const; + + [[nodiscard]] + auto data() const -> const toml::value& { + return raw_data; + } + + void setRawData(const toml::value& data) { + raw_data = data; + } + + private: + toml::value raw_data; }; } // namespace ntt diff --git a/src/framework/simulation.cpp b/src/framework/simulation.cpp index b913379b4..6735eda79 100644 --- a/src/framework/simulation.cpp +++ b/src/framework/simulation.cpp @@ -1,40 +1,113 @@ #include "framework/simulation.h" #include "defaults.h" +#include "enums.h" #include "global.h" #include "utils/cargs.h" #include 
"utils/error.h" #include "utils/formatting.h" +#include "utils/log.h" #include "utils/plog.h" +#include "utils/toml.h" -#include "framework/parameters.h" - -#include - +#include #include namespace ntt { Simulation::Simulation(int argc, char* argv[]) { - GlobalInitialize(argc, argv); - cargs::CommandLineArguments cl_args; cl_args.readCommandLineArguments(argc, argv); const auto inputfname = static_cast( cl_args.getArgument("-input", defaults::input_filename)); - const auto outputdir = static_cast( - cl_args.getArgument("-output", defaults::output_path)); - const auto inputdata = toml::parse(inputfname); - const auto sim_name = toml::find(inputdata, "simulation", "name"); - logger::initPlog(sim_name); + const bool is_resuming = (cl_args.isSpecified("-continue") or + cl_args.isSpecified("-restart") or + cl_args.isSpecified("-resume") or + cl_args.isSpecified("-checkpoint")); + GlobalInitialize(argc, argv); + + const auto raw_params = toml::parse(inputfname); + const auto sim_name = toml::find(raw_params, "simulation", "name"); + const auto log_level = toml::find_or(raw_params, + "diagnostics", + "log_level", + defaults::diag::log_level); + logger::initPlog(sim_name, + log_level); + + m_requested_engine = SimEngine::pick( + fmt::toLower(toml::find(raw_params, "simulation", "engine")).c_str()); + m_requested_metric = Metric::pick( + fmt::toLower(toml::find(raw_params, "grid", "metric", "metric")) + .c_str()); + + const auto res = toml::find>(raw_params, + "grid", + "resolution"); + raise::ErrorIf(res.size() < 1 || res.size() > 3, + "invalid `grid.resolution`", + HERE); + m_requested_dimension = static_cast(res.size()); + + m_params.setRawData(raw_params); + timestep_t checkpoint_step = 0; - params = SimulationParams(inputdata); + if (is_resuming) { + logger::Checkpoint("Reading params from a checkpoint", HERE); + const auto checkpoint_write_path = toml::find_or( + raw_params, + "checkpoint", + "write_path", + fmt::format(defaults::checkpoint::write_path.c_str(), 
sim_name.c_str())); + const path_t checkpoint_read_path = toml::find_or( + raw_params, + "checkpoint", + "read_path", + checkpoint_write_path); + if (not std::filesystem::exists(checkpoint_read_path)) { + raise::Fatal("No checkpoints found", HERE); + } + for (const auto& entry : + std::filesystem::directory_iterator(checkpoint_read_path)) { + const auto fname = entry.path().filename().string(); + if (fname.find("step-") == 0) { + const timestep_t step = std::stoi(fname.substr(5, fname.size() - 5 - 3)); + if (step > checkpoint_step) { + checkpoint_step = step; + } + } + } + path_t checkpoint_metafname = checkpoint_read_path / + fmt::format("meta-%08lu.toml", checkpoint_step); + if (not std::filesystem::exists(checkpoint_metafname)) { + raise::Fatal( + fmt::format("metainformation for %lu not found", checkpoint_step), + HERE); + checkpoint_metafname = inputfname; + } + logger::Checkpoint(fmt::format("Using %08lu", checkpoint_step), HERE); + const auto raw_checkpoint_params = toml::parse(checkpoint_metafname); + const auto start_time = toml::find(raw_checkpoint_params, + "metadata", + "time"); + m_params.setImmutableParams(raw_checkpoint_params); + m_params.setMutableParams(raw_params); + m_params.setCheckpointParams(true, checkpoint_step, start_time); + m_params.setSetupParams(raw_checkpoint_params); + } else { + logger::Checkpoint("Defining new params", HERE); + m_params.setImmutableParams(raw_params); + m_params.setMutableParams(raw_params); + m_params.setCheckpointParams(false, 0, 0.0); + m_params.setSetupParams(raw_params); + } + m_params.checkPromises(); } Simulation::~Simulation() { GlobalFinalize(); } -} // namespace ntt \ No newline at end of file +} // namespace ntt diff --git a/src/framework/simulation.h b/src/framework/simulation.h index 2f4d3d321..33750030f 100644 --- a/src/framework/simulation.h +++ b/src/framework/simulation.h @@ -18,16 +18,18 @@ #include "arch/traits.h" #include "utils/error.h" +#include "utils/toml.h" #include "framework/parameters.h" 
-#include -#include - namespace ntt { class Simulation { - SimulationParams params; + SimulationParams m_params; + + Dimension m_requested_dimension; + SimEngine m_requested_engine { SimEngine::INVALID }; + Metric m_requested_metric { Metric::INVALID }; public: Simulation(int argc, char* argv[]); @@ -41,7 +43,7 @@ namespace ntt { static_assert(traits::has_method::value, "Engine must contain a ::run() method"); try { - engine_t engine { params }; + engine_t engine { m_params }; engine.run(); } catch (const std::exception& e) { raise::Fatal(e.what(), HERE); @@ -50,17 +52,17 @@ namespace ntt { [[nodiscard]] inline auto requested_dimension() const -> Dimension { - return params.get("grid.dim"); + return m_requested_dimension; } [[nodiscard]] inline auto requested_engine() const -> SimEngine { - return params.get("simulation.engine"); + return m_requested_engine; } [[nodiscard]] inline auto requested_metric() const -> Metric { - return params.get("grid.metric.metric"); + return m_requested_metric; } }; diff --git a/src/framework/tests/CMakeLists.txt b/src/framework/tests/CMakeLists.txt index c09d4ecc0..92e327d80 100644 --- a/src/framework/tests/CMakeLists.txt +++ b/src/framework/tests/CMakeLists.txt @@ -1,19 +1,23 @@ +# cmake-lint: disable=C0103,C0111 # ------------------------------ # @brief: Generates tests for the `ntt_framework` module +# # @uses: -# - kokkos [required] -# - plog [required] -# - toml11 [required] -# - mpi [optional] -# - adios2 [optional] +# +# * kokkos [required] +# * plog [required] +# * mpi [optional] +# * adios2 [optional] +# # !TODO: -# - add tests for mesh separately -# - add test for 3D metadomain +# +# * add tests for mesh separately +# * add test for 3D metadomain # ------------------------------ set(SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../) -function(gen_test title) +function(gen_test title is_parallel) set(exec test-framework-${title}.xc) set(src ${title}.cpp) add_executable(${exec} ${src}) @@ -22,24 +26,30 @@ function(gen_test title) 
add_dependencies(${exec} ${libs}) target_link_libraries(${exec} PRIVATE ${libs}) - add_test(NAME "FRAMEWORK::${title}" COMMAND "${exec}") + if(${is_parallel}) + add_test(NAME "FRAMEWORK::${title}" + COMMAND "${MPIEXEC_EXECUTABLE}" "${MPIEXEC_NUMPROC_FLAG}" "4" + "${exec}") + else() + add_test(NAME "FRAMEWORK::${title}" COMMAND "${exec}") + endif() endfunction() -if (${mpi}) -gen_test(comm_mpi) +if(${mpi}) + gen_test(comm_mpi true) else() -gen_test(parameters) -gen_test(particles) -gen_test(fields) -gen_test(grid_mesh) -if (${DEBUG}) - gen_test(metadomain) + gen_test(parameters false) + gen_test(particles false) + gen_test(fields false) + gen_test(grid_mesh false) + if(${DEBUG}) + gen_test(metadomain false) + endif() + gen_test(comm_nompi false) endif() -gen_test(comm_nompi) -endif() - # this test is only run manually to ensure ... # ... command line args are working properly ... # ... and that the logging is done correctly -# gen_test(simulation) +# +# gen_test(simulation) diff --git a/src/framework/tests/comm_mpi.cpp b/src/framework/tests/comm_mpi.cpp index 2f65defd6..487976f73 100644 --- a/src/framework/tests/comm_mpi.cpp +++ b/src/framework/tests/comm_mpi.cpp @@ -5,6 +5,11 @@ #include "arch/directions.h" #include "arch/kokkos_aliases.h" +#include "utils/error.h" +#include "utils/numeric.h" + +#include +#include #include #include @@ -13,49 +18,227 @@ using namespace ntt; auto main(int argc, char* argv[]) -> int { Kokkos::initialize(argc, argv); + MPI_Init(&argc, &argv); try { - const std::size_t nx1 = 15, nx2 = 15; - ndfield_t fld_b1 { "fld", nx1 + 2 * N_GHOSTS, nx2 + 2 * N_GHOSTS }; - ndfield_t fld_b2 { "fld", nx1 + 2 * N_GHOSTS, nx2 + 2 * N_GHOSTS }; + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + + const ncells_t nx1 = 11, nx2 = 15; + ndfield_t fld { "fld", nx1 + 2 * N_GHOSTS, nx2 + 2 * N_GHOSTS }; Kokkos::parallel_for( "Fill", CreateRangePolicy({ 0, 0 }, { nx1 + 2 * N_GHOSTS, nx2 + 2 * N_GHOSTS }), 
Lambda(index_t i1, index_t i2) { - if ((i1 >= 2 * N_GHOSTS) and (i1 < nx1) and (i2 >= 2 * N_GHOSTS) and - (i2 < nx2)) { - fld_b1(i1, i2, 0) = 4.0; - fld_b1(i1, i2, 1) = 12.0; - fld_b1(i1, i2, 2) = 20.0; - fld_b2(i1, i2, 0) = 4.0; - fld_b2(i1, i2, 1) = 12.0; - fld_b2(i1, i2, 2) = 20.0; - } else if ( - ((i1 < 2 * N_GHOSTS or i1 >= nx1) and (i2 >= 2 * N_GHOSTS and i2 < nx2)) or - ((i2 < 2 * N_GHOSTS or i2 >= nx2) and (i1 >= 2 * N_GHOSTS and i1 < nx1))) { - fld_b1(i1, i2, 0) = 2.0; - fld_b1(i1, i2, 1) = 6.0; - fld_b1(i1, i2, 2) = 10.0; - fld_b2(i1, i2, 0) = 2.0; - fld_b2(i1, i2, 1) = 6.0; - fld_b2(i1, i2, 2) = 10.0; - } else { - fld_b1(i1, i2, 0) = 1.0; - fld_b1(i1, i2, 1) = 3.0; - fld_b1(i1, i2, 2) = 5.0; - fld_b2(i1, i2, 0) = 1.0; - fld_b2(i1, i2, 1) = 3.0; - fld_b2(i1, i2, 2) = 5.0; + if ((i1 >= N_GHOSTS) and (i1 < N_GHOSTS + nx1) and (i2 >= N_GHOSTS) and + (i2 < N_GHOSTS + nx2)) { + fld(i1, i2, 0) = static_cast(rank + 1) + 4.0; + fld(i1, i2, 1) = static_cast(rank + 1) + 12.0; + fld(i1, i2, 2) = static_cast(rank + 1) + 20.0; + } + }); + + { + // send right, recv left + const int send_idx = (rank + 1) % size; + const int recv_idx = (rank - 1 + size) % size; + const unsigned int send_rank = (unsigned int)send_idx; + const unsigned int recv_rank = (unsigned int)recv_idx; + + const std::vector send_slice { + { nx1, nx1 + N_GHOSTS }, + { N_GHOSTS, nx2 + N_GHOSTS } + }; + const std::vector recv_slice { + { 0, N_GHOSTS }, + { N_GHOSTS, nx2 + N_GHOSTS } + }; + const range_tuple_t comp_slice { 0, 3 }; + comm::CommunicateField((unsigned int)(rank), + fld, + fld, + send_idx, + recv_idx, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_slice, + false); + } + { + // recv right, send left + const int send_idx = (rank - 1 + size) % size; + const int recv_idx = (rank + 1) % size; + const unsigned int send_rank = (unsigned int)send_idx; + const unsigned int recv_rank = (unsigned int)recv_idx; + + const std::vector send_slice { + { N_GHOSTS, N_GHOSTS + 2 }, + { N_GHOSTS, 
nx2 + N_GHOSTS } + }; + const std::vector recv_slice { + { nx1 + N_GHOSTS, nx1 + 2 * N_GHOSTS }, + { N_GHOSTS, nx2 + N_GHOSTS } + }; + const range_tuple_t comp_slice { 0, 3 }; + comm::CommunicateField((unsigned int)(rank), + fld, + fld, + send_idx, + recv_idx, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_slice, + false); + } + + { + const auto left_expect = static_cast((rank - 1 + size) % size + 1); + const auto right_expect = static_cast((rank + 1) % size + 1); + + Kokkos::parallel_for( + "Check", + CreateRangePolicy({ N_GHOSTS }, { nx2 + N_GHOSTS }), + Lambda(index_t i2) { + for (auto i1 { 0u }; i1 < N_GHOSTS; ++i1) { + if (fld(i1, i2, 0) != left_expect + 4.0) { + raise::KernelError(HERE, "Left boundary not correct for #0"); + } + if (fld(i1, i2, 1) != left_expect + 12.0) { + raise::KernelError(HERE, "Left boundary not correct for #1"); + } + if (fld(i1, i2, 2) != left_expect + 20.0) { + raise::KernelError(HERE, "Left boundary not correct for #2"); + } + } + for (auto i1 { nx1 + N_GHOSTS }; i1 < nx1 + 2 * N_GHOSTS; ++i1) { + if (fld(i1, i2, 0) != right_expect + 4.0) { + raise::KernelError(HERE, "Right boundary not correct for #0"); + } + if (fld(i1, i2, 1) != right_expect + 12.0) { + raise::KernelError(HERE, "Right boundary not correct for #1"); + } + if (fld(i1, i2, 2) != right_expect + 20.0) { + raise::KernelError(HERE, "Right boundary not correct for #2"); + } + } + }); + } + + Kokkos::parallel_for( + "Carve", + CreateRangePolicy({ 0, 0 }, + { nx1 + 2 * N_GHOSTS, nx2 + 2 * N_GHOSTS }), + Lambda(index_t i1, index_t i2) { + if (((i1 >= N_GHOSTS) and (i1 < 2 * N_GHOSTS)) or + ((i1 >= nx1) and (i1 < nx1 + N_GHOSTS))) { + fld(i1, i2, 0) = ZERO; + fld(i1, i2, 1) = ZERO; + fld(i1, i2, 2) = ZERO; } }); + + { + // send right, recv left + const int send_idx = (rank + 1) % size; + const int recv_idx = (rank - 1 + size) % size; + const unsigned int send_rank = (unsigned int)send_idx; + const unsigned int recv_rank = (unsigned int)recv_idx; + + const 
std::vector send_slice { + { nx1 + N_GHOSTS, nx1 + 2 * N_GHOSTS }, + { N_GHOSTS, nx2 + N_GHOSTS } + }; + const std::vector recv_slice { + { N_GHOSTS, 2 * N_GHOSTS }, + { N_GHOSTS, nx2 + N_GHOSTS } + }; + const range_tuple_t comp_slice { 0, 3 }; + comm::CommunicateField((unsigned int)(rank), + fld, + fld, + send_idx, + recv_idx, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_slice, + true); + } + { + // recv right, send left + const int send_idx = (rank - 1 + size) % size; + const int recv_idx = (rank + 1) % size; + const unsigned int send_rank = (unsigned int)send_idx; + const unsigned int recv_rank = (unsigned int)recv_idx; + + const std::vector send_slice { + { 0, N_GHOSTS }, + { N_GHOSTS, nx2 + N_GHOSTS } + }; + const std::vector recv_slice { + { nx1, nx1 + N_GHOSTS }, + { N_GHOSTS, nx2 + N_GHOSTS } + }; + const range_tuple_t comp_slice { 0, 3 }; + comm::CommunicateField((unsigned int)(rank), + fld, + fld, + send_idx, + recv_idx, + send_rank, + recv_rank, + send_slice, + recv_slice, + comp_slice, + true); + } + + { + const auto expect = static_cast(rank + 1); + Kokkos::parallel_for( + "Check", + CreateRangePolicy({ N_GHOSTS }, { nx2 + N_GHOSTS }), + Lambda(index_t i2) { + for (auto i1 { N_GHOSTS }; i1 < 2 * N_GHOSTS; ++i1) { + if (fld(i1, i2, 0) != expect + 4.0) { + raise::KernelError(HERE, "Left boundary not correct for #0"); + } + if (fld(i1, i2, 1) != expect + 12.0) { + raise::KernelError(HERE, "Left boundary not correct for #1"); + } + if (fld(i1, i2, 2) != expect + 20.0) { + raise::KernelError(HERE, "Left boundary not correct for #2"); + } + } + for (auto i1 { nx1 }; i1 < nx1 + N_GHOSTS; ++i1) { + if (fld(i1, i2, 0) != expect + 4.0) { + raise::KernelError(HERE, "Right boundary not correct for #0"); + } + if (fld(i1, i2, 1) != expect + 12.0) { + raise::KernelError(HERE, "Right boundary not correct for #1"); + } + if (fld(i1, i2, 2) != expect + 20.0) { + raise::KernelError(HERE, "Right boundary not correct for #2"); + } + } + }); + } } catch 
(std::exception& e) { std::cerr << "Exception: " << e.what() << std::endl; + MPI_Finalize(); Kokkos::finalize(); return 1; } + MPI_Finalize(); Kokkos::finalize(); return 0; -} \ No newline at end of file +} diff --git a/src/framework/tests/comm_nompi.cpp b/src/framework/tests/comm_nompi.cpp index 05d54d589..c7646ef03 100644 --- a/src/framework/tests/comm_nompi.cpp +++ b/src/framework/tests/comm_nompi.cpp @@ -45,12 +45,12 @@ auto main(int argc, char* argv[]) -> int { Kokkos::deep_copy(buff, ZERO); const auto send_slice = std::vector { - {nx1 + N_GHOSTS, nx1 + 2 * N_GHOSTS}, - {nx2 + N_GHOSTS, nx2 + 2 * N_GHOSTS} + { nx1 + N_GHOSTS, nx1 + 2 * N_GHOSTS }, + { nx2 + N_GHOSTS, nx2 + 2 * N_GHOSTS } }; const auto recv_slice = std::vector { - {N_GHOSTS, 2 * N_GHOSTS}, - {N_GHOSTS, 2 * N_GHOSTS} + { N_GHOSTS, 2 * N_GHOSTS }, + { N_GHOSTS, 2 * N_GHOSTS } }; const auto comp_slice = range_tuple_t(cur::jx1, cur::jx3 + 1); diff --git a/src/framework/tests/grid_mesh.cpp b/src/framework/tests/grid_mesh.cpp index 4dea275ce..952d9874d 100644 --- a/src/framework/tests/grid_mesh.cpp +++ b/src/framework/tests/grid_mesh.cpp @@ -21,27 +21,26 @@ auto main(int argc, char* argv[]) -> int { using namespace metric; const auto res = std::vector { 10, 10, 10 }; const auto ext = boundaries_t { - {-1.0, 1.0}, - {-1.0, 1.0}, - {-1.0, 1.0} + { -1.0, 1.0 }, + { -1.0, 1.0 }, + { -1.0, 1.0 } }; auto mesh = Mesh>(res, ext, {}); for (const auto& d : { in::x1, in::x2, in::x3 }) { raise::ErrorIf(mesh.i_min(d) != N_GHOSTS, "i_min != N_GHOSTS", HERE); - raise::ErrorIf(mesh.i_max(d) != res[(unsigned short)d] + N_GHOSTS, + raise::ErrorIf(mesh.i_max(d) != res[(dim_t)d] + N_GHOSTS, "i_max != res+N_GHOSTS", HERE); - raise::ErrorIf(mesh.n_active(d) != res[(unsigned short)d], - "n_active != res", - HERE); - raise::ErrorIf(mesh.n_all(d) != res[(unsigned short)d] + 2 * N_GHOSTS, + raise::ErrorIf(mesh.n_active(d) != res[(dim_t)d], "n_active != res", HERE); + raise::ErrorIf(mesh.n_all(d) != res[(dim_t)d] + 2 * 
N_GHOSTS, "n_all != res+2*N_GHOSTS", HERE); - raise::ErrorIf(mesh.extent(d) != ext[(unsigned short)d], "extent != ext", HERE); + raise::ErrorIf(mesh.extent(d) != ext[(dim_t)d], "extent != ext", HERE); } - raise::ErrorIf(not cmp::AlmostEqual(mesh.metric.dxMin(), (real_t)(0.2 / std::sqrt(3.0))), - "dxMin wrong", - HERE); + raise::ErrorIf( + not cmp::AlmostEqual(mesh.metric.dxMin(), (real_t)(0.2 / std::sqrt(3.0))), + "dxMin wrong", + HERE); } catch (const std::exception& e) { std::cerr << e.what() << std::endl; Kokkos::finalize(); diff --git a/src/framework/tests/metadomain.cpp b/src/framework/tests/metadomain.cpp index 8f5865499..829a2b82f 100644 --- a/src/framework/tests/metadomain.cpp +++ b/src/framework/tests/metadomain.cpp @@ -22,31 +22,31 @@ auto main(int argc, char* argv[]) -> int { using namespace ntt; using namespace metric; { - const std::vector res { 64, 32 }; - const boundaries_t extent { - {1.0, 10.0}, - {0.0, constant::PI} + const std::vector res { 64, 32 }; + const boundaries_t extent { + { 1.0, 10.0 }, + { 0.0, constant::PI } }; const boundaries_t fldsbc { - {FldsBC::ATMOSPHERE, FldsBC::ABSORB}, - { FldsBC::AXIS, FldsBC::AXIS} + { FldsBC::ATMOSPHERE, FldsBC::MATCH }, + { FldsBC::AXIS, FldsBC::AXIS } }; const boundaries_t prtlbc { - {PrtlBC::ATMOSPHERE, PrtlBC::ABSORB}, - { PrtlBC::AXIS, PrtlBC::AXIS} + { PrtlBC::ATMOSPHERE, PrtlBC::ABSORB }, + { PrtlBC::AXIS, PrtlBC::AXIS } }; const std::map params { - {"r0", -ONE}, - { "h", (real_t)0.25} + { "r0", -ONE }, + { "h", (real_t)0.25 } }; #if defined(OUTPUT_ENABLED) Metadomain> metadomain { - 4, { -1, -1 }, - res, extent, fldsbc, prtlbc, params, {}, "disabled" + 4u, { -1, -1 }, + res, extent, fldsbc, prtlbc, params, {} }; #else Metadomain> metadomain { - 4, { -1, -1 }, + 4u, { -1, -1 }, res, extent, fldsbc, prtlbc, params, {} }; #endif @@ -132,7 +132,7 @@ auto main(int argc, char* argv[]) -> int { raise::ErrorIf(self.offset_ndomains()[0] != 1, "Domain::offset_ndomains() failed", HERE); - 
raise::ErrorIf(self.mesh.flds_bc_in({ +1, 0 }) != FldsBC::ABSORB, + raise::ErrorIf(self.mesh.flds_bc_in({ +1, 0 }) != FldsBC::MATCH, "Mesh::flds_bc_in() failed", HERE); raise::ErrorIf(self.mesh.prtl_bc_in({ +1, 0 }) != PrtlBC::ABSORB, @@ -203,4 +203,4 @@ auto main(int argc, char* argv[]) -> int { } Kokkos::finalize(); return 0; -} \ No newline at end of file +} diff --git a/src/framework/tests/parameters.cpp b/src/framework/tests/parameters.cpp index 8f02d1750..07b2c11b3 100644 --- a/src/framework/tests/parameters.cpp +++ b/src/framework/tests/parameters.cpp @@ -5,11 +5,11 @@ #include "utils/comparators.h" #include "utils/error.h" +#include "utils/toml.h" #include "framework/containers/species.h" #include -#include #include #include @@ -29,12 +29,11 @@ const auto mink_1d = u8R"( metric = "minkowski" [grid.boundaries] - fields = [["PERIODIC"]] + fields = [["MATCH", "MATCH"]] particles = [["ABSORB", "ABSORB"]] - [grid.boundaries.absorb] - coeff = 10.0 - ds = 0.025 + [grid.boundaries.match] + ds = [[0.025, 0.1]] [scales] larmor0 = 0.1 @@ -48,7 +47,7 @@ const auto mink_1d = u8R"( [particles] ppc0 = 10.0 - sort_interval = 100 + clear_interval = 100 [[particles.species]] label = "e-" @@ -73,13 +72,18 @@ const auto mink_1d = u8R"( mystr = "hi" [output] - fields = ["Rho", "J", "B"] - particles = ["X", "U"] format = "hdf5" - mom_smooth = 2 - fields_stride = 1 - prtl_stride = 100 - interval_time = 0.01 + + [output.fields] + quantities = ["Rho", "J", "B"] + mom_smooth = 2 + downsampling = [4, 5] + interval = 100 + + [output.particles] + species = [1, 2] + stride = 100 + interval_time = 0.01 )"_toml; const auto sph_2d = u8R"( @@ -96,12 +100,9 @@ const auto sph_2d = u8R"( metric = "spherical" [grid.boundaries] - fields = [["ATMOSPHERE", "ABSORB"]] + fields = [["ATMOSPHERE", "MATCH"]] particles = [["ATMOSPHERE", "ABSORB"]] - [grid.boundaries.absorb] - coeff = 10.0 - [grid.boundaries.atmosphere] temperature = 0.1 density = 1.0 @@ -129,7 +130,7 @@ const auto sph_2d = u8R"( 
[particles] ppc0 = 25.0 use_weights = true - sort_interval = 50 + clear_interval = 50 [[particles.species]] @@ -175,7 +176,7 @@ const auto qks_2d = u8R"( ks_a = 0.99 [grid.boundaries] - fields = [["ABSORB"]] + fields = [["MATCH"]] particles = [["ABSORB"]] [scales] @@ -194,7 +195,7 @@ const auto qks_2d = u8R"( [particles] ppc0 = 4.0 - sort_interval = 100 + clear_interval = 100 [[particles.species]] label = "e-" @@ -242,7 +243,11 @@ auto main(int argc, char* argv[]) -> int { using namespace ntt; { - const auto params_mink_1d = SimulationParams(mink_1d); + auto params_mink_1d = SimulationParams(); + params_mink_1d.setImmutableParams(mink_1d); + params_mink_1d.setMutableParams(mink_1d); + params_mink_1d.setSetupParams(mink_1d); + params_mink_1d.checkPromises(); assert_equal(params_mink_1d.get("grid.metric.metric"), Metric::Minkowski, @@ -260,7 +265,7 @@ auto main(int argc, char* argv[]) -> int { (real_t)0.0078125, "scales.V0"); boundaries_t fbc = { - {FldsBC::PERIODIC, FldsBC::PERIODIC} + { FldsBC::MATCH, FldsBC::MATCH } }; assert_equal( params_mink_1d.get>("grid.boundaries.fields")[0].first, @@ -274,6 +279,14 @@ auto main(int argc, char* argv[]) -> int { params_mink_1d.get>("grid.boundaries.fields").size(), fbc.size(), "grid.boundaries.fields.size()"); + assert_equal( + params_mink_1d.get>("grid.boundaries.match.ds")[0].first, + (real_t)0.025, + "grid.boundaries.match.ds[0].first"); + assert_equal( + params_mink_1d.get>("grid.boundaries.match.ds")[0].second, + (real_t)0.1, + "grid.boundaries.match.ds[0].first"); const auto species = params_mink_1d.get>( "particles.species"); @@ -311,10 +324,21 @@ auto main(int argc, char* argv[]) -> int { assert_equal(params_mink_1d.get("setup.mystr"), "hi", "setup.mystr"); + + const auto output_stride = params_mink_1d.get>( + "output.fields.downsampling"); + assert_equal(output_stride.size(), + 1, + "output.fields.downsampling.size()"); + assert_equal(output_stride[0], 4, "output.fields.downsampling[0]"); } { - const auto 
params_sph_2d = SimulationParams(sph_2d); + auto params_sph_2d = SimulationParams(); + params_sph_2d.setImmutableParams(sph_2d); + params_sph_2d.setMutableParams(sph_2d); + params_sph_2d.setSetupParams(sph_2d); + params_sph_2d.checkPromises(); assert_equal(params_sph_2d.get("grid.metric.metric"), Metric::Spherical, @@ -325,8 +349,8 @@ auto main(int argc, char* argv[]) -> int { "simulation.engine"); boundaries_t fbc = { - {FldsBC::ATMOSPHERE, FldsBC::ABSORB}, - { FldsBC::AXIS, FldsBC::AXIS} + { FldsBC::ATMOSPHERE, FldsBC::MATCH }, + { FldsBC::AXIS, FldsBC::AXIS } }; assert_equal(params_sph_2d.get("scales.B0"), @@ -361,16 +385,11 @@ auto main(int argc, char* argv[]) -> int { fbc.size(), "grid.boundaries.fields.size()"); - // absorb coeffs + // match coeffs assert_equal( - params_sph_2d.get("grid.boundaries.absorb.ds"), - (real_t)(defaults::bc::absorb::ds_frac * 19.0), - "grid.boundaries.absorb.ds"); - - assert_equal( - params_sph_2d.get("grid.boundaries.absorb.coeff"), - (real_t)10.0, - "grid.boundaries.absorb.coeff"); + params_sph_2d.get>("grid.boundaries.match.ds")[0].second, + (real_t)(defaults::bc::match::ds_frac * 19.0), + "grid.boundaries.match.ds"); assert_equal(params_sph_2d.get("particles.use_weights"), true, @@ -427,7 +446,11 @@ auto main(int argc, char* argv[]) -> int { } { - const auto params_qks_2d = SimulationParams(qks_2d); + auto params_qks_2d = SimulationParams(); + params_qks_2d.setImmutableParams(qks_2d); + params_qks_2d.setMutableParams(qks_2d); + params_qks_2d.setSetupParams(qks_2d); + params_qks_2d.checkPromises(); assert_equal(params_qks_2d.get("grid.metric.metric"), Metric::QKerr_Schild, @@ -456,9 +479,9 @@ auto main(int argc, char* argv[]) -> int { "grid.metric.ks_rh"); const auto expect = std::map { - {"r0", 0.0}, - { "h", 0.25}, - { "a", 0.99} + { "r0", 0.0 }, + { "h", 0.25 }, + { "a", 0.99 } }; auto read = params_qks_2d.get>( "grid.metric.params"); @@ -477,8 +500,8 @@ auto main(int argc, char* argv[]) -> int { 
"algorithms.gr.pusher_niter"); boundaries_t pbc = { - {PrtlBC::HORIZON, PrtlBC::ABSORB}, - { PrtlBC::AXIS, PrtlBC::AXIS} + { PrtlBC::HORIZON, PrtlBC::ABSORB }, + { PrtlBC::AXIS, PrtlBC::AXIS } }; assert_equal(params_qks_2d.get("scales.B0"), @@ -513,16 +536,11 @@ auto main(int argc, char* argv[]) -> int { pbc.size(), "grid.boundaries.particles.size()"); - // absorb coeffs + // match coeffs assert_equal( - params_qks_2d.get("grid.boundaries.absorb.ds"), - (real_t)(defaults::bc::absorb::ds_frac * (100.0 - 0.8)), - "grid.boundaries.absorb.ds"); - - assert_equal( - params_qks_2d.get("grid.boundaries.absorb.coeff"), - defaults::bc::absorb::coeff, - "grid.boundaries.absorb.coeff"); + params_qks_2d.get>("grid.boundaries.match.ds")[0].second, + (real_t)(defaults::bc::match::ds_frac * (100.0 - 0.8)), + "grid.boundaries.match.ds"); const auto species = params_qks_2d.get>( "particles.species"); @@ -555,86 +573,3 @@ auto main(int argc, char* argv[]) -> int { return 0; } - -// const auto mink_1d = R"( -// [simulation] -// name = "" -// engine = "" -// runtime = "" - -// [grid] -// resolution = "" -// extent = "" - -// [grid.metric] -// metric = "" -// qsph_r0 = "" -// qsph_h = "" -// ks_a = "" - -// [grid.boundaries] -// fields = "" -// particles = "" -// absorb_d = "" -// absorb_coeff = "" - -// [scales] -// larmor0 = "" -// skindepth0 = "" - -// [algorithms] -// current_filters = "" - -// [algorithms.toggles] -// fieldsolver = "" -// deposit = "" - -// [algorithms.timestep] -// CFL = "" -// correction = "" - -// [algorithms.gr] -// pusher_eps = "" -// pusher_niter = "" - -// [algorithms.gca] -// e_ovr_b_max = "" -// larmor_max = "" - -// [algorithms.synchrotron] -// gamma_rad = "" - -// [particles] -// ppc0 = "" -// use_weights = "" -// sort_interval = "" - -// [[particles.species]] -// label = "" -// mass = "" -// charge = "" -// maxnpart = "" -// pusher = "" -// n_payloads = "" -// cooling = "" -// [setup] - -// [output] -// fields = "" -// particles = "" -// format = "" -// 
mom_smooth = "" -// fields_stride = "" -// prtl_stride = "" -// interval = "" -// interval_time = "" - -// [output.debug] -// as_is = "" -// ghosts = "" - -// [diagnostics] -// interval = "" -// log_level = "" -// blocking_timers = "" -// )"_toml; diff --git a/src/framework/tests/particles.cpp b/src/framework/tests/particles.cpp index dabcc062f..6c4c227b5 100644 --- a/src/framework/tests/particles.cpp +++ b/src/framework/tests/particles.cpp @@ -46,9 +46,9 @@ void testParticles(const int& index, raise::ErrorIf(p.tag.extent(0) != maxnpart, "tag incorrectly allocated", HERE); raise::ErrorIf(p.weight.extent(0) != maxnpart, "weight incorrectly allocated", HERE); - raise::ErrorIf(p.pld.size() != npld, "Number of payloads mismatch", HERE); - for (unsigned short n { 0 }; n < npld; ++n) { - raise::ErrorIf(p.pld[n].extent(0) != maxnpart, "pld incorrectly allocated", HERE); + if (npld > 0) { + raise::ErrorIf(p.pld.extent(0) != maxnpart, "pld incorrectly allocated", HERE); + raise::ErrorIf(p.pld.extent(1) != npld, "pld incorrectly allocated", HERE); } if constexpr ((D == Dim::_2D) || (D == Dim::_3D)) { @@ -117,7 +117,8 @@ auto main(int argc, char** argv) -> int { 0.0, 100, PrtlPusher::PHOTON, - Cooling::NONE); + Cooling::NONE, + 5); testParticles(4, "e+", 1.0, @@ -131,7 +132,8 @@ auto main(int argc, char** argv) -> int { 1.0, 100, PrtlPusher::BORIS, - Cooling::NONE); + Cooling::NONE, + 1); } catch (const std::exception& e) { std::cerr << "Error: " << e.what() << std::endl; Kokkos::finalize(); @@ -139,4 +141,4 @@ auto main(int argc, char** argv) -> int { } Kokkos::finalize(); return 0; -} \ No newline at end of file +} diff --git a/src/global/CMakeLists.txt b/src/global/CMakeLists.txt index d00689283..7546b6a98 100644 --- a/src/global/CMakeLists.txt +++ b/src/global/CMakeLists.txt @@ -1,26 +1,39 @@ +# cmake-lint: disable=C0103 # ------------------------------ # @defines: ntt_global [STATIC/SHARED] +# # @sources: -# - global.cpp -# - arch/kokkos_aliases.cpp -# - utils/cargs.cpp 
+# +# * global.cpp +# * arch/kokkos_aliases.cpp +# * utils/cargs.cpp +# * utils/param_container.cpp +# * utils/timer.cpp +# * utils/diag.cpp +# * utils/progressbar.cpp +# # @includes: -# - ./ +# +# * ./ +# # @uses: -# - kokkos [required] -# - plog [required] -# - mpi [optional] +# +# * kokkos [required] +# * plog [required] +# * mpi [optional] # ------------------------------ set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(SOURCES - ${SRC_DIR}/global.cpp - ${SRC_DIR}/arch/kokkos_aliases.cpp - ${SRC_DIR}/utils/cargs.cpp -) +set(SOURCES + ${SRC_DIR}/global.cpp ${SRC_DIR}/arch/kokkos_aliases.cpp + ${SRC_DIR}/utils/cargs.cpp ${SRC_DIR}/utils/timer.cpp + ${SRC_DIR}/utils/diag.cpp ${SRC_DIR}/utils/progressbar.cpp) +if(${output}) + list(APPEND SOURCES ${SRC_DIR}/utils/param_container.cpp) +endif() add_library(ntt_global ${SOURCES}) -target_include_directories(ntt_global +target_include_directories( + ntt_global PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - INTERFACE ${CMAKE_CURRENT_SOURCE_DIR} -) -target_link_libraries(ntt_global PRIVATE stdc++fs) \ No newline at end of file + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(ntt_global PRIVATE stdc++fs) diff --git a/src/global/arch/directions.h b/src/global/arch/directions.h index 19cf182d6..850bc130d 100644 --- a/src/global/arch/directions.h +++ b/src/global/arch/directions.h @@ -50,14 +50,14 @@ namespace dir { auto operator-() const -> direction_t { auto result = direction_t {}; - for (std::size_t i = 0; i < (short)D; ++i) { + for (auto i { 0u }; i < D; ++i) { result[i] = -(*this)[i]; } return result; } auto operator==(const direction_t& other) const -> bool { - for (std::size_t i = 0; i < (short)D; ++i) { + for (auto i { 0u }; i < D; ++i) { if ((*this)[i] != other[i]) { return false; } @@ -79,7 +79,7 @@ namespace dir { */ auto get_assoc_orth() const -> std::vector { auto result = std::vector {}; - for (std::size_t i = 0; i < this->size(); ++i) { + for (auto i = 0u; i < this->size(); ++i) { if ((*this)[i] != 0) { 
direction_t dir; dir[i] = (*this)[i]; @@ -91,7 +91,7 @@ namespace dir { auto get_sign() const -> short { short sign = 0; - for (std::size_t i = 0; i < this->size(); ++i) { + for (auto i = 0u; i < this->size(); ++i) { if ((*this)[i] != 0) { raise::ErrorIf(sign != 0, "Undefined signature for non-orth direction", @@ -105,7 +105,7 @@ namespace dir { auto get_dim() const -> in { short dir = -1; - for (std::size_t i = 0; i < this->size(); ++i) { + for (auto i = 0u; i < this->size(); ++i) { if ((*this)[i] != 0) { raise::ErrorIf(dir > 0, "Undefined dim for non-orth direction", HERE); dir = i; @@ -132,8 +132,8 @@ namespace dir { using dirs_t = std::vector>; template - inline auto operator<<(std::ostream& os, const direction_t& dir) - -> std::ostream& { + inline auto operator<<(std::ostream& os, + const direction_t& dir) -> std::ostream& { for (auto& d : dir) { os << std::setw(2) << std::left; if (d > 0) { @@ -175,81 +175,81 @@ namespace dir { template <> struct Directions { inline static const dirs_t all = { - {-1, -1}, - {-1, 0}, - {-1, 1}, - { 0, -1}, - { 0, 1}, - { 1, -1}, - { 1, 0}, - { 1, 1} + { -1, -1 }, + { -1, 0 }, + { -1, 1 }, + { 0, -1 }, + { 0, 1 }, + { 1, -1 }, + { 1, 0 }, + { 1, 1 } }; inline static const dirs_t orth = { - {-1, 0}, - { 0, -1}, - { 0, 1}, - { 1, 0} + { -1, 0 }, + { 0, -1 }, + { 0, 1 }, + { 1, 0 } }; inline static const dirs_t unique = { - { 0, 1}, - { 1, 1}, - { 1, 0}, - {-1, 1} + { 0, 1 }, + { 1, 1 }, + { 1, 0 }, + { -1, 1 } }; }; template <> struct Directions { inline static const dirs_t all = { - {-1, -1, -1}, - {-1, -1, 0}, - {-1, -1, 1}, - {-1, 0, -1}, - {-1, 0, 0}, - {-1, 0, 1}, - {-1, 1, -1}, - {-1, 1, 0}, - {-1, 1, 1}, - { 0, -1, -1}, - { 0, -1, 0}, - { 0, -1, 1}, - { 0, 0, -1}, - { 0, 0, 1}, - { 0, 1, -1}, - { 0, 1, 0}, - { 0, 1, 1}, - { 1, -1, -1}, - { 1, -1, 0}, - { 1, -1, 1}, - { 1, 0, -1}, - { 1, 0, 0}, - { 1, 0, 1}, - { 1, 1, -1}, - { 1, 1, 0}, - { 1, 1, 1} + { -1, -1, -1 }, + { -1, -1, 0 }, + { -1, -1, 1 }, + { -1, 0, -1 }, + { 
-1, 0, 0 }, + { -1, 0, 1 }, + { -1, 1, -1 }, + { -1, 1, 0 }, + { -1, 1, 1 }, + { 0, -1, -1 }, + { 0, -1, 0 }, + { 0, -1, 1 }, + { 0, 0, -1 }, + { 0, 0, 1 }, + { 0, 1, -1 }, + { 0, 1, 0 }, + { 0, 1, 1 }, + { 1, -1, -1 }, + { 1, -1, 0 }, + { 1, -1, 1 }, + { 1, 0, -1 }, + { 1, 0, 0 }, + { 1, 0, 1 }, + { 1, 1, -1 }, + { 1, 1, 0 }, + { 1, 1, 1 } }; inline static const dirs_t orth = { - {-1, 0, 0}, - { 0, -1, 0}, - { 0, 0, -1}, - { 0, 0, 1}, - { 0, 1, 0}, - { 1, 0, 0} + { -1, 0, 0 }, + { 0, -1, 0 }, + { 0, 0, -1 }, + { 0, 0, 1 }, + { 0, 1, 0 }, + { 1, 0, 0 } }; inline static const dirs_t unique = { - { 0, 0, 1}, - { 0, 1, 0}, - { 1, 0, 0}, - { 1, 1, 0}, - {-1, 1, 0}, - { 0, 1, 1}, - { 0, -1, 1}, - { 1, 0, 1}, - {-1, 0, 1}, - { 1, 1, 1}, - {-1, 1, 1}, - { 1, -1, 1}, - { 1, 1, -1} + { 0, 0, 1 }, + { 0, 1, 0 }, + { 1, 0, 0 }, + { 1, 1, 0 }, + { -1, 1, 0 }, + { 0, 1, 1 }, + { 0, -1, 1 }, + { 1, 0, 1 }, + { -1, 0, 1 }, + { 1, 1, 1 }, + { -1, 1, 1 }, + { 1, -1, 1 }, + { 1, 1, -1 } }; }; diff --git a/src/global/arch/kokkos_aliases.cpp b/src/global/arch/kokkos_aliases.cpp index 4311a40bd..e81d41280 100644 --- a/src/global/arch/kokkos_aliases.cpp +++ b/src/global/arch/kokkos_aliases.cpp @@ -4,83 +4,89 @@ #include +auto CreateParticleRangePolicy(npart_t p1, npart_t p2) -> range_t { + return Kokkos::RangePolicy(p1, p2); +} + template <> -auto CreateRangePolicy(const tuple_t& i1, - const tuple_t& i2) - -> range_t { +auto CreateRangePolicy( + const tuple_t& i1, + const tuple_t& i2) -> range_t { index_t i1min = i1[0]; index_t i1max = i2[0]; - return Kokkos::RangePolicy(i1min, i1max); + return Kokkos::RangePolicy(i1min, i1max); } template <> -auto CreateRangePolicy(const tuple_t& i1, - const tuple_t& i2) - -> range_t { +auto CreateRangePolicy( + const tuple_t& i1, + const tuple_t& i2) -> range_t { index_t i1min = i1[0]; index_t i1max = i2[0]; index_t i2min = i1[1]; index_t i2max = i2[1]; - return Kokkos::MDRangePolicy, AccelExeSpace>({ i1min, i2min }, - { i1max, i2max }); + return 
Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( + { i1min, i2min }, + { i1max, i2max }); } template <> -auto CreateRangePolicy(const tuple_t& i1, - const tuple_t& i2) - -> range_t { +auto CreateRangePolicy( + const tuple_t& i1, + const tuple_t& i2) -> range_t { index_t i1min = i1[0]; index_t i1max = i2[0]; index_t i2min = i1[1]; index_t i2max = i2[1]; index_t i3min = i1[2]; index_t i3max = i2[2]; - return Kokkos::MDRangePolicy, AccelExeSpace>( + return Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>( { i1min, i2min, i3min }, { i1max, i2max, i3max }); } template <> -auto CreateRangePolicyOnHost(const tuple_t& i1, - const tuple_t& i2) - -> range_h_t { +auto CreateRangePolicyOnHost( + const tuple_t& i1, + const tuple_t& i2) -> range_h_t { index_t i1min = i1[0]; index_t i1max = i2[0]; - return Kokkos::RangePolicy(i1min, i1max); + return Kokkos::RangePolicy(i1min, i1max); } template <> -auto CreateRangePolicyOnHost(const tuple_t& i1, - const tuple_t& i2) - -> range_h_t { +auto CreateRangePolicyOnHost( + const tuple_t& i1, + const tuple_t& i2) -> range_h_t { index_t i1min = i1[0]; index_t i1max = i2[0]; index_t i2min = i1[1]; index_t i2max = i2[1]; - return Kokkos::MDRangePolicy, HostExeSpace>({ i1min, i2min }, - { i1max, i2max }); + return Kokkos::MDRangePolicy, Kokkos::DefaultHostExecutionSpace>( + { i1min, i2min }, + { i1max, i2max }); } template <> -auto CreateRangePolicyOnHost(const tuple_t& i1, - const tuple_t& i2) - -> range_h_t { +auto CreateRangePolicyOnHost( + const tuple_t& i1, + const tuple_t& i2) -> range_h_t { index_t i1min = i1[0]; index_t i1max = i2[0]; index_t i2min = i1[1]; index_t i2max = i2[1]; index_t i3min = i1[2]; index_t i3max = i2[2]; - return Kokkos::MDRangePolicy, HostExeSpace>( + return Kokkos::MDRangePolicy, Kokkos::DefaultHostExecutionSpace>( { i1min, i2min, i3min }, { i1max, i2max, i3max }); } -// auto WaitAndSynchronize(bool debug_only) -> void { -// if (debug_only) { -// #ifndef DEBUG -// return; -// #endif -// } -// 
Kokkos::fence(); -// } \ No newline at end of file +auto WaitAndSynchronize(bool debug_only) -> void { + if (debug_only) { +#ifndef DEBUG + return; +#endif + } + Kokkos::fence(); +} diff --git a/src/global/arch/kokkos_aliases.h b/src/global/arch/kokkos_aliases.h index f9aac9685..adb0b6451 100644 --- a/src/global/arch/kokkos_aliases.h +++ b/src/global/arch/kokkos_aliases.h @@ -34,7 +34,7 @@ namespace math = Kokkos; template -using array_t = Kokkos::View; +using array_t = Kokkos::View; // Array mirror alias of arbitrary type template @@ -174,17 +174,17 @@ namespace kokkos_aliases_hidden { template <> struct range_impl { - using type = Kokkos::RangePolicy; + using type = Kokkos::RangePolicy; }; template <> struct range_impl { - using type = Kokkos::MDRangePolicy, AccelExeSpace>; + using type = Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>; }; template <> struct range_impl { - using type = Kokkos::MDRangePolicy, AccelExeSpace>; + using type = Kokkos::MDRangePolicy, Kokkos::DefaultExecutionSpace>; }; } // namespace kokkos_aliases_hidden @@ -201,17 +201,17 @@ namespace kokkos_aliases_hidden { template <> struct range_h_impl { - using type = Kokkos::RangePolicy; + using type = Kokkos::RangePolicy; }; template <> struct range_h_impl { - using type = Kokkos::MDRangePolicy, HostExeSpace>; + using type = Kokkos::MDRangePolicy, Kokkos::DefaultHostExecutionSpace>; }; template <> struct range_h_impl { - using type = Kokkos::MDRangePolicy, HostExeSpace>; + using type = Kokkos::MDRangePolicy, Kokkos::DefaultHostExecutionSpace>; }; } // namespace kokkos_aliases_hidden @@ -219,31 +219,38 @@ namespace kokkos_aliases_hidden { template using range_h_t = typename kokkos_aliases_hidden::range_h_impl::type; +/** + * @brief Function template for generating 1D Kokkos range policy for particles. + * @param p1 `npart_t`: min. + * @param p2 `npart_t`: max. 
+ */ +auto CreateParticleRangePolicy(npart_t, npart_t) -> range_t; + /** * @brief Function template for generating ND Kokkos range policy. * @tparam D Dimension - * @param i1 array of size D `std::size_t`: { min }. - * @param i2 array of size D `std::size_t`: { max }. + * @param i1 array of size D `ncells_t`: { min }. + * @param i2 array of size D `ncells_t`: { max }. * @returns Kokkos::RangePolicy or Kokkos::MDRangePolicy in the accelerator execution space. */ template -auto CreateRangePolicy(const tuple_t&, - const tuple_t&) -> range_t; +auto CreateRangePolicy(const tuple_t&, + const tuple_t&) -> range_t; /** * @brief Function template for generating ND Kokkos range policy on the host. * @tparam D Dimension - * @param i1 array of size D `std::size_t`: { min }. - * @param i2 array of size D `std::size_t`: { max }. + * @param i1 array of size D `ncells_t`: { min }. + * @param i2 array of size D `ncells_t`: { max }. * @returns Kokkos::RangePolicy or Kokkos::MDRangePolicy in the host execution space. 
*/ template -auto CreateRangePolicyOnHost(const tuple_t&, - const tuple_t&) -> range_h_t; +auto CreateRangePolicyOnHost(const tuple_t&, + const tuple_t&) -> range_h_t; // Random number pool/generator type alias -using random_number_pool_t = Kokkos::Random_XorShift1024_Pool; -using random_generator_t = typename random_number_pool_t::generator_type; +using random_number_pool_t = Kokkos::Random_XorShift1024_Pool; +using random_generator_t = typename random_number_pool_t::generator_type; // Random number generator functions template diff --git a/src/global/arch/mpi_aliases.h b/src/global/arch/mpi_aliases.h index 9669d6210..1f9a87f7b 100644 --- a/src/global/arch/mpi_aliases.h +++ b/src/global/arch/mpi_aliases.h @@ -14,6 +14,7 @@ #ifndef GLOBAL_ARCH_MPI_ALIASES_H #define GLOBAL_ARCH_MPI_ALIASES_H +#include #include #if defined(MPI_ENABLED) @@ -103,4 +104,4 @@ namespace mpi { #endif // MPI_ENABLED -#endif // GLOBAL_ARCH_MPI_ALIASES_H \ No newline at end of file +#endif // GLOBAL_ARCH_MPI_ALIASES_H diff --git a/src/global/arch/mpi_tags.h b/src/global/arch/mpi_tags.h index 0916542d4..aaf38a8f4 100644 --- a/src/global/arch/mpi_tags.h +++ b/src/global/arch/mpi_tags.h @@ -7,6 +7,8 @@ * @namespaces: * - mpi:: */ +#ifndef GLOBAL_ARCH_MPI_TAGS_H +#define GLOBAL_ARCH_MPI_TAGS_H #include "global.h" @@ -188,8 +190,13 @@ namespace mpi { tag; } - Inline auto SendTag(short tag, bool im1, bool ip1, bool jm1, bool jp1, bool km1, bool kp1) - -> short { + Inline auto SendTag(short tag, + bool im1, + bool ip1, + bool jm1, + bool jp1, + bool km1, + bool kp1) -> short { return ((im1 && jm1 && km1) * (PrtlSendTag::im1_jm1_km1 - 1) + (im1 && jm1 && kp1) * (PrtlSendTag::im1_jm1_kp1 - 1) + (im1 && jp1 && km1) * (PrtlSendTag::im1_jp1_km1 - 1) + @@ -226,3 +233,5 @@ namespace mpi { tag; } } // namespace mpi + +#endif // GLOBAL_ARCH_MPI_TAGS_H diff --git a/src/global/arch/traits.h b/src/global/arch/traits.h index e915bdf1a..d9e5e9310 100644 --- a/src/global/arch/traits.h +++ 
b/src/global/arch/traits.h @@ -10,7 +10,12 @@ * - traits::run_t, traits::to_string_t * - traits::pgen::init_flds_t * - traits::pgen::ext_force_t - * - traits::pgen::field_driver_t + * - traits::pgen::ext_current_t + * - traits::pgen::atm_fields_t + * - traits::pgen::match_fields_const_t + * - traits::pgen::match_fields_t + * - traits::pgen::fix_fields_const_t + * - traits::pgen::fix_fields_t * - traits::pgen::init_prtls_t * - traits::pgen::custom_fields_t * - traits::pgen::custom_field_output_t @@ -94,7 +99,43 @@ namespace traits { using ext_force_t = decltype(&T::ext_force); template - using field_driver_t = decltype(&T::FieldDriver); + using ext_current_t = decltype(&T::ext_current); + + template + using atm_fields_t = decltype(&T::AtmFields); + + template + using match_fields_t = decltype(&T::MatchFields); + + template + using match_fields_in_x1_t = decltype(&T::MatchFieldsInX1); + + template + using match_fields_in_x2_t = decltype(&T::MatchFieldsInX2); + + template + using match_fields_in_x3_t = decltype(&T::MatchFieldsInX3); + + template + using match_fields_const_t = decltype(&T::MatchFieldsConst); + + template + using fix_fields_t = decltype(&T::FixFields); + + template + using fix_fields_const_t = decltype(&T::FixFieldsConst); + + template + using perfect_conductor_fields_t = decltype(&T::PerfectConductorFields); + + template + using perfect_conductor_fields_const_t = decltype(&T::PerfectConductorFieldsConst); + + template + using perfect_conductor_currents_t = decltype(&T::PerfectConductorCurrents); + + template + using perfect_conductor_currents_const_t = decltype(&T::PerfectConductorCurrentsConst); template using custom_fields_t = decltype(&T::CustomFields); @@ -104,6 +145,9 @@ namespace traits { template using custom_field_output_t = decltype(&T::CustomFieldOutput); + + template + using custom_stat_t = decltype(&T::CustomStat); } // namespace pgen // for pgen extforce diff --git a/src/global/defaults.h b/src/global/defaults.h index d238a3492..9513493b1 
100644 --- a/src/global/defaults.h +++ b/src/global/defaults.h @@ -16,16 +16,15 @@ namespace ntt::defaults { constexpr std::string_view input_filename = "input"; - constexpr std::string_view output_path = "output"; const real_t correction = 1.0; const real_t cfl = 0.95; const unsigned short current_filters = 0; - const std::string em_pusher = "Boris"; - const std::string ph_pusher = "Photon"; - const std::size_t sort_interval = 100; + const std::string em_pusher = "Boris"; + const std::string ph_pusher = "Photon"; + const timestep_t clear_interval = 100; namespace qsph { const real_t r0 = 0.0; @@ -42,26 +41,41 @@ namespace ntt::defaults { } // namespace gr namespace bc { + namespace match { + const real_t ds_frac = 0.01; + } // namespace match + namespace absorb { const real_t ds_frac = 0.01; - const real_t coeff = 1.0; } // namespace absorb - } // namespace bc + } // namespace bc namespace output { - const std::string format = "hdf5"; - const std::size_t interval = 100; - const unsigned short mom_smooth = 0; - const unsigned short flds_stride = 1; - const std::size_t prtl_stride = 100; - const real_t spec_emin = 1e-3; - const real_t spec_emax = 1e3; - const bool spec_log = true; - const std::size_t spec_nbins = 200; + const std::string format = "hdf5"; + const timestep_t interval = 100; + const unsigned short mom_smooth = 0; + const npart_t prtl_stride = 100; + const real_t spec_emin = 1e-3; + const real_t spec_emax = 1e3; + const bool spec_log = true; + const std::size_t spec_nbins = 200; + const std::vector stats_quantities = { "B^2", + "E^2", + "ExB", + "Rho", + "T00" }; } // namespace output + namespace checkpoint { + const timestep_t interval = 1000; + const int keep = 2; + const std::string walltime = "00:00:00"; + const std::string write_path = "%s.ckpt"; + } // namespace checkpoint + namespace diag { - const std::size_t interval = 1; + const timestep_t interval = 1; + const std::string log_level = "VERBOSE"; } // namespace diag namespace gca { diff --git 
a/src/global/enums.h b/src/global/enums.h index 57822dec4..08130a2c8 100644 --- a/src/global/enums.h +++ b/src/global/enums.h @@ -8,12 +8,14 @@ * - enum ntt::SimEngine // SRPIC, GRPIC * - enum ntt::PrtlBC // periodic, absorb, atmosphere, custom, * reflect, horizon, axis, sync - * - enum ntt::FldsBC // periodic, absorb, atmosphere, custom, - * conductor, horizon, axis, sync + * - enum ntt::FldsBC // periodic, match, fixed, atmosphere, + * custom, horizon, axis, conductor, sync * - enum ntt::PrtlPusher // boris, vay, photon, none * - enum ntt::Cooling // synchrotron, none * - enum ntt::FldsID // e, dive, d, divd, b, h, j, - * a, t, rho, charge, n, nppc, custom + * a, t, rho, charge, n, nppc, v, custom + * - enum ntt::StatsID // b^2, e^2, exb, j.e, t, rho, + * charge, n, npart * @namespaces: * - ntt:: * @note Enums of the same type can be compared with each other and with strings @@ -56,7 +58,7 @@ namespace ntt { const char* const* arr_c, const std::size_t n, const char* elem) -> T { - for (std::size_t i = 0; i < n; ++i) { + for (auto i = 0u; i < n; ++i) { if (strcmp(arr_c[i], elem) == 0) { return (T)(arr[i]); } @@ -70,7 +72,7 @@ namespace ntt { constexpr auto baseContains(const char* const* arr_c, const std::size_t n, const char* elem) -> bool { - for (std::size_t i = 0; i < n; ++i) { + for (auto i = 0u; i < n; ++i) { if (strcmp(arr_c[i], elem) == 0) { return true; } @@ -215,23 +217,26 @@ namespace ntt { enum type : uint8_t { INVALID = 0, PERIODIC = 1, - ABSORB = 2, - ATMOSPHERE = 3, - CUSTOM = 4, - CONDUCTOR = 5, + MATCH = 2, + FIXED = 3, + ATMOSPHERE = 4, + CUSTOM = 5, HORIZON = 6, AXIS = 7, - SYNC = 8, // <- SYNC means synchronization with other domains + CONDUCTOR = 8, + SYNC = 9 // <- SYNC means synchronization with other domains }; constexpr FldsBC(uint8_t c) : enums_hidden::BaseEnum { c } {} - static constexpr type variants[] = { PERIODIC, ABSORB, ATMOSPHERE, CUSTOM, - CONDUCTOR, HORIZON, AXIS, SYNC }; - static constexpr const char* lookup[] = { "periodic", 
"absorb", - "atmosphere", "custom", - "conductor", "horizon", - "axis", "sync" }; + static constexpr type variants[] = { + PERIODIC, MATCH, FIXED, ATMOSPHERE, CUSTOM, + HORIZON, AXIS, CONDUCTOR, SYNC, + }; + static constexpr const char* lookup[] = { + "periodic", "match", "fixed", "atmosphere", "custom", + "horizon", "axis", "conductor", "sync" + }; static constexpr std::size_t total = sizeof(variants) / sizeof(variants[0]); }; @@ -288,17 +293,46 @@ namespace ntt { Charge = 11, N = 12, Nppc = 13, - Custom = 14, + V = 14, + Custom = 15, }; constexpr FldsID(uint8_t c) : enums_hidden::BaseEnum { c } {} - static constexpr type variants[] = { E, divE, D, divD, B, H, J, - A, T, Rho, Charge, N, Nppc, Custom }; - static constexpr const char* lookup[] = { "e", "dive", "d", "divd", - "b", "h", "j", "a", - "t", "rho", "charge", "n", - "nppc", "custom" }; + static constexpr type variants[] = { E, divE, D, divD, B, + H, J, A, T, Rho, + Charge, N, Nppc, V, Custom }; + static constexpr const char* lookup[] = { "e", "dive", "d", "divd", + "b", "h", "j", "a", + "t", "rho", "charge", "n", + "nppc", "v", "custom" }; + static constexpr std::size_t total = sizeof(variants) / sizeof(variants[0]); + }; + + struct StatsID : public enums_hidden::BaseEnum { + static constexpr const char* label = "out_stats"; + + enum type : uint8_t { + INVALID = 0, + B2 = 1, + E2 = 2, + ExB = 3, + JdotE = 4, + T = 5, + Rho = 6, + Charge = 7, + N = 8, + Npart = 9, + Custom = 10, + }; + + constexpr StatsID(uint8_t c) : enums_hidden::BaseEnum { c } {} + + static constexpr type variants[] = { B2, E2, ExB, JdotE, T, + Rho, Charge, N, Npart, Custom }; + static constexpr const char* lookup[] = { "b^2", "e^2", "exb", "j.e", + "t", "rho", "charge", "n", + "npart", "custom" }; static constexpr std::size_t total = sizeof(variants) / sizeof(variants[0]); }; diff --git a/src/global/global.h b/src/global/global.h index ca067d547..adffcf6e9 100644 --- a/src/global/global.h +++ b/src/global/global.h @@ -13,6 +13,10 @@ * - 
enum PrepareOutput * - enum CellLayer // allLayer, activeLayer, minGhostLayer, * minActiveLayer, maxActiveLayer, maxGhostLayer + * - enum Idx // U, D, T, XYZ, Sph, PU, PD + * - enum Crd // Cd, Ph, XYZ, Sph + * - enum in // x1, x2, x3 + * - enum bc_in // Px1, Mx1, Px2, Mx2, Px3, Mx3 * - type box_region_t * - files::LogFile, files::ErrFile, files::InfoFile * - type prtldx_t @@ -88,6 +92,8 @@ #ifndef GLOBAL_GLOBAL_H #define GLOBAL_GLOBAL_H +#include +#include #include #include #include @@ -107,7 +113,7 @@ namespace files { namespace ntt { - inline constexpr unsigned int N_GHOSTS = 2; + inline constexpr std::size_t N_GHOSTS = 2; // Coordinate shift to account for ghost cells #define COORD(I) \ (static_cast(static_cast((I)) - static_cast(N_GHOSTS))) @@ -184,6 +190,15 @@ enum class in : unsigned short { x3 = 2, }; +enum class bc_in : short { + Mx1 = -1, + Px1 = 1, + Mx2 = -2, + Px2 = 2, + Mx3 = -3, + Px3 = 3, +}; + template using box_region_t = CellLayer[D]; @@ -204,18 +219,15 @@ typedef int PrepareOutputFlags; namespace Timer { enum TimerFlags_ { - None = 0, - PrintRelative = 1 << 0, - PrintUnits = 1 << 1, - PrintIndents = 1 << 2, - PrintTotal = 1 << 3, - PrintTitle = 1 << 4, - AutoConvert = 1 << 5, - Colorful = 1 << 6, - PrintOutput = 1 << 7, - PrintSorting = 1 << 8, - Default = PrintRelative | PrintUnits | PrintIndents | PrintTotal | - PrintTitle | AutoConvert | Colorful, + None = 0, + PrintTotal = 1 << 0, + PrintTitle = 1 << 1, + AutoConvert = 1 << 2, + PrintOutput = 1 << 3, + PrintPrtlClear = 1 << 4, + PrintCheckpoint = 1 << 5, + PrintNormed = 1 << 6, + Default = PrintNormed | PrintTotal | PrintTitle | AutoConvert, }; } // namespace Timer @@ -252,6 +264,18 @@ namespace Comm { typedef int CommTags; +namespace WriteMode { + enum WriteModeTags_ { + None = 0, + Fields = 1 << 0, + Particles = 1 << 1, + Spectra = 1 << 2, + Stats = 1 << 3, + }; +} // namespace WriteMode + +typedef int WriteModeTags; + namespace BC { enum BCTags_ { None = 0, @@ -264,9 +288,17 @@ namespace 
BC { Dx1 = 1 << 0, Dx2 = 1 << 1, Dx3 = 1 << 2, + Hx1 = 1 << 3, + Hx2 = 1 << 4, + Hx3 = 1 << 5, + Jx1 = 1 << 6, + Jx2 = 1 << 7, + Jx3 = 1 << 8, B = Bx1 | Bx2 | Bx3, E = Ex1 | Ex2 | Ex3, D = Dx1 | Dx2 | Dx3, + H = Hx1 | Hx2 | Hx3, + J = Jx1 | Jx2 | Jx3, }; } // namespace BC @@ -318,10 +350,26 @@ using coord_t = tuple_t; template using vec_t = tuple_t; +// time/duration +using duration_t = double; +using simtime_t = double; +using timestep_t = std::size_t; +using ncells_t = std::size_t; +using npart_t = unsigned long int; + +// walltime +using timestamp_t = std::chrono::time_point; + +// index/number using index_t = const std::size_t; using idx_t = unsigned short; +using spidx_t = unsigned short; +using dim_t = unsigned short; + +// utility +using path_t = std::filesystem::path; -using range_tuple_t = std::pair; +using range_tuple_t = std::pair; template using boundaries_t = std::vector>; diff --git a/src/global/tests/CMakeLists.txt b/src/global/tests/CMakeLists.txt index e9e5de687..c206f85b0 100644 --- a/src/global/tests/CMakeLists.txt +++ b/src/global/tests/CMakeLists.txt @@ -1,11 +1,16 @@ +# cmake-lint: disable=C0103,C0111 # ------------------------------ # @brief: Generates tests for the `ntt_global` module +# # @uses: -# - kokkos [required] -# - plog [required] -# - mpi [optional] +# +# * kokkos [required] +# * plog [required] +# * mpi [optional] +# # !TODO: -# - add optional tests for the `mpi_aliases.h` +# +# * add optional tests for the `mpi_aliases.h` # ------------------------------ set(SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../) diff --git a/src/global/tests/enums.cpp b/src/global/tests/enums.cpp index 1fc57398f..f653f4727 100644 --- a/src/global/tests/enums.cpp +++ b/src/global/tests/enums.cpp @@ -61,14 +61,18 @@ auto main() -> int { enum_str_t all_simulation_engines = { "srpic", "grpic" }; enum_str_t all_particle_bcs = { "periodic", "absorb", "atmosphere", "custom", "reflect", "horizon", "axis", "sync" }; - enum_str_t all_fields_bcs = { "periodic", 
"absorb", "atmosphere", "custom", - "horizon", "conductor", "axis", "sync" }; + enum_str_t all_fields_bcs = { "periodic", "match", "fixed", + "atmosphere", "custom", "horizon", + "axis", "conductor", "sync" }; enum_str_t all_particle_pushers = { "boris", "vay", "photon", "none" }; enum_str_t all_coolings = { "synchrotron", "none" }; - enum_str_t all_out_flds = { "e", "dive", "d", "divd", "b", - "h", "j", "a", "t", "rho", - "charge", "n", "nppc", "custom" }; + enum_str_t all_out_flds = { "e", "dive", "d", "divd", "b", + "h", "j", "a", "t", "rho", + "charge", "n", "nppc", "v", "custom" }; + + enum_str_t all_out_stats = { "b^2", "e^2", "exb", "j.e", "t", + "rho", "charge", "n", "npart", "custom" }; checkEnum(all_coords); checkEnum(all_metrics); @@ -78,6 +82,7 @@ auto main() -> int { checkEnum(all_particle_pushers); checkEnum(all_coolings); checkEnum(all_out_flds); + checkEnum(all_out_stats); return 0; } diff --git a/src/global/tests/kokkos_aliases.cpp b/src/global/tests/kokkos_aliases.cpp index 56a17c50f..909b6b30c 100644 --- a/src/global/tests/kokkos_aliases.cpp +++ b/src/global/tests/kokkos_aliases.cpp @@ -3,6 +3,7 @@ #include "global.h" #include +#include #include #include @@ -44,8 +45,7 @@ auto main(int argc, char* argv[]) -> int { { // scatter arrays & ranges array_t a { "a", 100 }; - scatter_array_t a_scatter = Kokkos::Experimental::create_scatter_view( - a); + auto a_scatter = Kokkos::Experimental::create_scatter_view(a); Kokkos::parallel_for( // range_t({ 0 }, { 100 }), CreateRangePolicy({ 0 }, { 100 }), @@ -87,4 +87,4 @@ auto main(int argc, char* argv[]) -> int { Kokkos::finalize(); return 0; -} \ No newline at end of file +} diff --git a/src/global/tests/param_container.cpp b/src/global/tests/param_container.cpp index 31a8ca437..60c5495c5 100644 --- a/src/global/tests/param_container.cpp +++ b/src/global/tests/param_container.cpp @@ -25,8 +25,8 @@ auto main() -> int { const auto nonexist_vec = std::vector { 1, 2, 3 }; const auto flds_bc_vec = std::vector { 
FldsBC::AXIS, FldsBC::PERIODIC }; const auto prtl_bc_vec = boundaries_t { - {PrtlBC::REFLECT, PrtlBC::PERIODIC}, - {PrtlBC::REFLECT, PrtlBC::REFLECT} + { PrtlBC::REFLECT, PrtlBC::PERIODIC }, + { PrtlBC::REFLECT, PrtlBC::REFLECT } }; p.set("a", 1); diff --git a/src/global/utils/cargs.cpp b/src/global/utils/cargs.cpp index 8f641214e..57b79f33b 100644 --- a/src/global/utils/cargs.cpp +++ b/src/global/utils/cargs.cpp @@ -18,8 +18,8 @@ namespace cargs { _initialized = true; } - auto CommandLineArguments::getArgument(std::string_view key, std::string_view def) - -> std::string_view { + auto CommandLineArguments::getArgument(std::string_view key, + std::string_view def) -> std::string_view { if (!_initialized) { throw std::runtime_error( "# Error: command line arguments have not been parsed."); diff --git a/src/global/utils/cargs.h b/src/global/utils/cargs.h index 7c02146b7..530969912 100644 --- a/src/global/utils/cargs.h +++ b/src/global/utils/cargs.h @@ -25,8 +25,7 @@ namespace cargs { public: void readCommandLineArguments(int argc, char* argv[]); [[nodiscard]] - auto getArgument(std::string_view key, std::string_view def) - -> std::string_view; + auto getArgument(std::string_view key, std::string_view def) -> std::string_view; [[nodiscard]] auto getArgument(std::string_view key) -> std::string_view; auto isSpecified(std::string_view key) -> bool; diff --git a/src/global/utils/colors.h b/src/global/utils/colors.h index 512ad81c7..b997317cb 100644 --- a/src/global/utils/colors.h +++ b/src/global/utils/colors.h @@ -53,7 +53,7 @@ namespace color { return msg_nocol; } - inline auto get_color(const std::string& s, bool eight_bit) -> std::string { + inline auto get_color(const std::string& s, bool eight_bit = true) -> std::string { if (not eight_bit) { return ""; } else { @@ -124,23 +124,23 @@ namespace color { c_bmagenta = c_bcyan = c_bwhite = ""; } return { - { "reset", c_reset}, - { "black", c_black}, - { "red", c_red}, - { "green", c_green}, - { "yellow", c_yellow}, - { 
"blue", c_blue}, - { "magenta", c_magenta}, - { "cyan", c_cyan}, - { "white", c_white}, - { "bblack", c_bblack}, - { "bred", c_bred}, - { "bgreen", c_bgreen}, - { "byellow", c_byellow}, - { "bblue", c_bblue}, - {"bmagenta", c_bmagenta}, - { "bcyan", c_bcyan}, - { "bwhite", c_bwhite} + { "reset", c_reset }, + { "black", c_black }, + { "red", c_red }, + { "green", c_green }, + { "yellow", c_yellow }, + { "blue", c_blue }, + { "magenta", c_magenta }, + { "cyan", c_cyan }, + { "white", c_white }, + { "bblack", c_bblack }, + { "bred", c_bred }, + { "bgreen", c_bgreen }, + { "byellow", c_byellow }, + { "bblue", c_bblue }, + { "bmagenta", c_bmagenta }, + { "bcyan", c_bcyan }, + { "bwhite", c_bwhite } }; } } // namespace color diff --git a/src/global/utils/comparators.h b/src/global/utils/comparators.h index d86fe868c..a12d55e73 100644 --- a/src/global/utils/comparators.h +++ b/src/global/utils/comparators.h @@ -27,31 +27,31 @@ namespace cmp { template - Inline auto AlmostEqual(T a, T b, T eps = Kokkos::Experimental::epsilon::value) - -> bool { + Inline auto AlmostEqual(T a, + T b, + T eps = Kokkos::Experimental::epsilon::value) -> bool { static_assert(std::is_floating_point_v, "T must be a floating point type"); return (a == b) || (math::abs(a - b) <= math::min(math::abs(a), math::abs(b)) * eps); } template - Inline auto AlmostZero(T a, T eps = Kokkos::Experimental::epsilon::value) - -> bool { + Inline auto AlmostZero(T a, T eps = Kokkos::Experimental::epsilon::value) -> bool { static_assert(std::is_floating_point_v, "T must be a floating point type"); return math::abs(a) <= eps; } template - inline auto AlmostEqual_host(T a, T b, T eps = std::numeric_limits::epsilon()) - -> bool { + inline auto AlmostEqual_host(T a, + T b, + T eps = std::numeric_limits::epsilon()) -> bool { static_assert(std::is_floating_point_v, "T must be a floating point type"); return (a == b) || (std::abs(a - b) <= std::min(std::abs(a), std::abs(b)) * eps); } template - inline auto AlmostZero_host(T 
a, T eps = std::numeric_limits::epsilon()) - -> bool { + inline auto AlmostZero_host(T a, T eps = std::numeric_limits::epsilon()) -> bool { static_assert(std::is_floating_point_v, "T must be a floating point type"); return std::abs(a) <= eps; } diff --git a/src/global/utils/diag.cpp b/src/global/utils/diag.cpp new file mode 100644 index 000000000..f6f615587 --- /dev/null +++ b/src/global/utils/diag.cpp @@ -0,0 +1,247 @@ +#include "utils/diag.h" + +#include "global.h" + +#include "utils/colors.h" +#include "utils/formatting.h" +#include "utils/progressbar.h" +#include "utils/timer.h" + +#if defined(MPI_ENABLED) + #include "arch/mpi_aliases.h" + + #include +#endif // MPI_ENABLED + +#include +#include +#include +#include +#include +#include + +namespace diag { + auto npart_stats( + npart_t npart, + npart_t maxnpart) -> std::vector> { + auto stats = std::vector>(); +#if !defined(MPI_ENABLED) + stats.push_back( + { npart, + static_cast( + 100.0f * static_cast(npart) / static_cast(maxnpart)) }); +#else + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + std::vector mpi_npart(size, 0); + std::vector mpi_maxnpart(size, 0); + MPI_Gather(&npart, + 1, + mpi::get_type(), + mpi_npart.data(), + 1, + mpi::get_type(), + MPI_ROOT_RANK, + MPI_COMM_WORLD); + MPI_Gather(&maxnpart, + 1, + mpi::get_type(), + mpi_maxnpart.data(), + 1, + mpi::get_type(), + MPI_ROOT_RANK, + MPI_COMM_WORLD); + if (rank != MPI_ROOT_RANK) { + return stats; + } + auto tot_npart = std::accumulate(mpi_npart.begin(), mpi_npart.end(), 0); + const auto max_idx = std::distance( + mpi_npart.begin(), + std::max_element(mpi_npart.begin(), mpi_npart.end())); + const auto min_idx = std::distance( + mpi_npart.begin(), + std::min_element(mpi_npart.begin(), mpi_npart.end())); + stats.push_back({ tot_npart, 0u }); + stats.push_back({ mpi_npart[min_idx], + static_cast( + 100.0f * static_cast(mpi_npart[min_idx]) / + static_cast(mpi_maxnpart[min_idx])) }); + stats.push_back({ 
mpi_npart[max_idx], + static_cast( + 100.0f * static_cast(mpi_npart[max_idx]) / + static_cast(mpi_maxnpart[max_idx])) }); +#endif + return stats; + } + + void printDiagnostics(timestep_t step, + timestep_t tot_steps, + simtime_t time, + simtime_t dt, + timer::Timers& timers, + pbar::DurationHistory& time_history, + ncells_t ncells, + const std::vector& species_labels, + const std::vector& species_npart, + const std::vector& species_maxnpart, + bool print_prtl_clear, + bool print_output, + bool print_checkpoint, + bool print_colors) { + DiagFlags diag_flags = Diag::Default; + TimerFlags timer_flags = Timer::Default; + if (not print_colors) { + diag_flags ^= Diag::Colorful; + } + if (species_labels.size() == 0) { + diag_flags ^= Diag::Species; + } + if (print_prtl_clear) { + timer_flags |= Timer::PrintPrtlClear; + } + if (print_output) { + timer_flags |= Timer::PrintOutput; + } + if (print_checkpoint) { + timer_flags |= Timer::PrintCheckpoint; + } + + std::stringstream ss; + + const auto c_red = color::get_color("red"); + const auto c_yellow = color::get_color("yellow"); + const auto c_green = color::get_color("green"); + const auto c_bgreen = color::get_color("bgreen"); + const auto c_bblack = color::get_color("bblack"); + const auto c_reset = color::get_color("reset"); + + // basic info + CallOnce([&]() { + ss << fmt::alignedTable( + { "Step:", fmt::format("%lu", step), fmt::format("[of %lu]", tot_steps) }, + { c_reset, c_bgreen, c_bblack }, + { 0, -6, -32 }, + { ' ', ' ', '.' }, + c_bblack, + c_reset); + + ss << fmt::alignedTable( + { "Time:", fmt::format("%.4f", time), fmt::format("[Ξ”t = %.4f]", dt) }, + { c_reset, c_bgreen, c_bblack }, + { 0, -6, -32 }, + { ' ', ' ', '.' 
}, + c_bblack, + c_reset); + }); + + // substep timers + if (diag_flags & Diag::Timers) { + const auto total_npart = std::accumulate(species_npart.begin(), + species_npart.end(), + 0); + const auto timer_diag = timers.printAll(timer_flags, total_npart, ncells); + CallOnce([&]() { + ss << std::endl << timer_diag << std::endl; + }); + } + + // particle counts + if (diag_flags & Diag::Species) { +#if defined(MPI_ENABLED) + CallOnce([&]() { + ss << fmt::alignedTable( + { "[PARTICLE SPECIES]", "[TOTAL]", "[% MIN", "MAX]", "[MIN", "MAX]" }, + { c_bblack, c_bblack, c_bblack, c_bblack, c_bblack, c_bblack }, + { 0, 37, 45, -48, 63, -66 }, + { ' ', ' ', ' ', ':', ' ', ':' }, + c_bblack, + c_reset); + }); +#else + CallOnce([&]() { + ss << fmt::alignedTable({ "[PARTICLE SPECIES]", "[TOTAL]", "[% TOT]" }, + { c_bblack, c_bblack, c_bblack }, + { 0, 37, 45 }, + { ' ', ' ', ' ' }, + c_bblack, + c_reset); + }); +#endif + for (auto i = 0u; i < species_labels.size(); ++i) { + const auto part_stats = npart_stats(species_npart[i], species_maxnpart[i]); + if (part_stats.size() == 0) { + continue; + } + const auto tot_npart = part_stats[0].first; +#if defined(MPI_ENABLED) + const auto min_npart = part_stats[1].first; + const auto min_pct = part_stats[1].second; + const auto max_npart = part_stats[2].first; + const auto max_pct = part_stats[2].second; + ss << fmt::alignedTable( + { + fmt::format("species %2lu (%s)", i, species_labels[i].c_str()), + tot_npart > 9999 ? fmt::format("%.2Le", (long double)tot_npart) + : std::to_string(tot_npart), + std::to_string(min_pct) + "%", + std::to_string(max_pct) + "%", + min_npart > 9999 ? fmt::format("%.2Le", (long double)min_npart) + : std::to_string(min_npart), + max_npart > 9999 ? fmt::format("%.2Le", (long double)max_npart) + : std::to_string(max_npart), + }, + { + c_reset, + c_reset, + (min_pct > 80) ? c_red : ((min_pct > 50) ? c_yellow : c_green), + (max_pct > 80) ? c_red : ((max_pct > 50) ? 
c_yellow : c_green), + c_reset, + c_reset, + }, + { -2, 37, 45, -48, 63, -66 }, + { ' ', '.', ' ', ':', ' ', ':' }, + c_bblack, + c_reset); +#else + const auto tot_pct = part_stats[0].second; + ss << fmt::alignedTable( + { + fmt::format("species %2lu (%s)", i, species_labels[i].c_str()), + tot_npart > 9999 ? fmt::format("%.2Le", (long double)tot_npart) + : std::to_string(tot_npart), + std::to_string(tot_pct) + "%", + }, + { + c_reset, + c_reset, + (tot_pct > 80) ? c_red : ((tot_pct > 50) ? c_yellow : c_green), + }, + { -2, 37, 45 }, + { ' ', '.', ' ' }, + c_bblack, + c_reset); +#endif + } + CallOnce([&]() { + ss << std::endl; + }); + } + + // progress bar + if (diag_flags & Diag::Progress) { + const auto progbar = pbar::ProgressBar(time_history, step, tot_steps, diag_flags); + CallOnce([&]() { + ss << progbar; + }); + } + + // separator + CallOnce([&]() { + ss << std::setw(80) << std::setfill('.') << "" << std::endl << std::endl; + }); + + std::cout << ((diag_flags & Diag::Colorful) ? ss.str() + : color::strip(ss.str())); + } +} // namespace diag diff --git a/src/global/utils/diag.h b/src/global/utils/diag.h new file mode 100644 index 000000000..6d3c5937d --- /dev/null +++ b/src/global/utils/diag.h @@ -0,0 +1,59 @@ +/** + * @file utils/diag.h + * @brief Routines for diagnostics output at every step + * @implements + * - diag::printDiagnostics -> void + * @cpp: + * - diag.cpp + * @namespces: + * - diag:: + * @macros: + * - MPI_ENABLED + */ + +#ifndef GLOBAL_UTILS_DIAG_H +#define GLOBAL_UTILS_DIAG_H + +#include "utils/progressbar.h" +#include "utils/timer.h" + +#include +#include + +namespace diag { + + /** + * @brief Print diagnostics to the console + * @param step + * @param tot_steps + * @param time + * @param dt + * @param timers + * @param duration_history + * @param ncells (total) + * @param species_labels (vector of particle labels) + * @param npart (per each species) + * @param maxnpart (per each species) + * @param prtlclear (if true, dead particles were 
removed) + * @param output (if true, output was written) + * @param checkpoint (if true, checkpoint was written) + * @param colorful_print (if true, print with colors) + */ + void printDiagnostics(timestep_t, + timestep_t, + simtime_t, + simtime_t, + timer::Timers&, + pbar::DurationHistory&, + ncells_t, + const std::vector&, + const std::vector&, + const std::vector&, + bool, + bool, + bool, + bool); + +} // namespace diag + +#endif // GLOBAL_UTILS_DIAG_H diff --git a/src/global/utils/error.h b/src/global/utils/error.h index 9d5afed29..df23bce62 100644 --- a/src/global/utils/error.h +++ b/src/global/utils/error.h @@ -34,6 +34,7 @@ namespace raise { using namespace files; + [[noreturn]] inline void Error(const std::string& msg, const std::string& file, const std::string& func, @@ -100,12 +101,12 @@ namespace raise { const char* func, int line, const char* msg) { - printf("\n%s:%d @ %s\nError: %s", file, line, func, msg); + Kokkos::printf("\n%s:%d @ %s\nError: %s", file, line, func, msg); Kokkos::abort("kernel error"); } Inline void KernelNotImplementedError(const char* file, const char* func, int line) { - printf("\n%s:%d @ %s\n", file, line, func); + Kokkos::printf("\n%s:%d @ %s\n", file, line, func); Kokkos::abort("kernel not implemented"); } diff --git a/src/global/utils/formatting.h b/src/global/utils/formatting.h index 85f4e4b43..46c25c633 100644 --- a/src/global/utils/formatting.h +++ b/src/global/utils/formatting.h @@ -8,6 +8,8 @@ * - fmt::splitString -> std::vector * - fmt::repeat -> std::string * - fmt::formatVector -> std::string + * - fmt::strlen_utf8 -> std::size_t + * - fmt::alignedTable -> std::string * @namespaces: * - fmt:: */ @@ -51,8 +53,10 @@ namespace fmt { * @param c Character to pad with * @param right Pad on the right */ - inline auto pad(const std::string& str, std::size_t n, char c, bool right = false) - -> std::string { + inline auto pad(const std::string& str, + std::size_t n, + char c, + bool right = false) -> std::string { if (n <= 
str.size()) { return str; } @@ -111,8 +115,8 @@ namespace fmt { * @param delim Delimiter * @return Vector of strings */ - inline auto splitString(const std::string& str, const std::string& delim) - -> std::vector { + inline auto splitString(const std::string& str, + const std::string& delim) -> std::vector { std::regex regexz(delim); return { std::sregex_token_iterator(str.begin(), str.end(), regexz, -1), std::sregex_token_iterator() }; @@ -132,6 +136,66 @@ namespace fmt { return result; } + inline auto repeat(char s, std::size_t n) -> std::string { + return repeat(std::string(1, s), n); + } + + /** + * @brief Calculate the length of a UTF-8 string + * @param str UTF-8 string + */ + inline auto strlenUTF8(const std::string& str) -> std::size_t { + std::size_t length = 0; + for (char c : str) { + if ((c & 0xC0) != 0x80) { + ++length; + } + } + return length; + } + + /** + * @brief Create a table with aligned columns and custom colors & separators + * @param columns Vector of column strings + * @param colors Vector of colors + * @param anchors Vector of column anchors (position of edge, negative means left-align) + * @param fillers Vector of separators + * @param c_bblack Black color + * @param c_reset Reset color + */ + inline auto alignedTable(const std::vector& columns, + const std::vector& colors, + const std::vector& anchors, + const std::vector& fillers, + const std::string& c_bblack, + const std::string& c_reset) -> std::string { + std::string result { c_reset }; + std::size_t cntr { 0 }; + for (auto i { 0u }; i < columns.size(); ++i) { + const auto anch { static_cast(anchors[i] < 0 ? 
-anchors[i] + : anchors[i]) }; + const auto leftalign { anchors[i] <= 0 }; + const auto cmn { columns[i] }; + const auto cmn_len { strlenUTF8(cmn) }; + std::string left { c_bblack }; + if (leftalign) { + if (fillers[i] == ':') { + left += " :"; + left += repeat(' ', anch - cntr - 2); + } else { + left += repeat(fillers[i], anch - cntr); + } + cntr += anch - cntr; + } else { + left += repeat(fillers[i], anch - cntr - cmn_len); + cntr += anch - cntr - cmn_len; + } + result += left + colors[i] + cmn + c_reset; + cntr += cmn_len; + } + return result + c_reset + "\n"; + } + } // namespace fmt #endif // GLOBAL_UTILS_FORMATTING_H diff --git a/src/global/utils/log.h b/src/global/utils/log.h index ac5bc4059..2434414a4 100644 --- a/src/global/utils/log.h +++ b/src/global/utils/log.h @@ -34,6 +34,8 @@ #endif namespace raise { + using namespace files; + inline void Warning(const std::string& msg, const std::string& file, const std::string& func, diff --git a/src/global/utils/numeric.h b/src/global/utils/numeric.h index 0b09f6c11..59cec5ba5 100644 --- a/src/global/utils/numeric.h +++ b/src/global/utils/numeric.h @@ -36,6 +36,7 @@ inline constexpr float TWO = 2.0f; inline constexpr float THREE = 3.0f; inline constexpr float FOUR = 4.0f; inline constexpr float FIVE = 5.0f; +inline constexpr float TWELVE = 12.0f; inline constexpr float ZERO = 0.0f; inline constexpr float HALF = 0.5f; inline constexpr float INV_2 = 0.5f; @@ -50,6 +51,7 @@ inline constexpr double TWO = 2.0; inline constexpr double THREE = 3.0; inline constexpr double FOUR = 4.0; inline constexpr double FIVE = 5.0; +inline constexpr double TWELVE = 12.0; inline constexpr double ZERO = 0.0; inline constexpr double HALF = 0.5; inline constexpr double INV_2 = 0.5; @@ -78,17 +80,23 @@ inline constexpr double INV_64 = 0.015625; #define CROSS_x3(ax1, ax2, ax3, bx1, bx2, bx3) ((ax1) * (bx2) - (ax2) * (bx1)) namespace constant { - inline constexpr std::uint64_t RandomSeed = 0x123456789abcdef0; - inline constexpr double 
HALF_PI = 1.57079632679489661923; - inline constexpr double PI = 3.14159265358979323846; - inline constexpr double INV_PI = 0.31830988618379067154; - inline constexpr double PI_SQR = 9.86960440108935861882; - inline constexpr double INV_PI_SQR = 0.10132118364233777144; - inline constexpr double TWO_PI = 6.28318530717958647692; - inline constexpr double E = 2.71828182845904523536; - inline constexpr double SQRT2 = 1.41421356237309504880; - inline constexpr double INV_SQRT2 = 0.70710678118654752440; - inline constexpr double SQRT3 = 1.73205080756887729352; + inline constexpr std::uint64_t RandomSeed = 0x123456789abcdef0; + inline constexpr double HALF_PI = 1.57079632679489661923; + inline constexpr double PI = 3.14159265358979323846; + inline constexpr double INV_PI = 0.31830988618379067154; + inline constexpr double PI_SQR = 9.86960440108935861882; + inline constexpr double INV_PI_SQR = 0.10132118364233777144; + inline constexpr double TWO_PI = 6.28318530717958647692; + inline constexpr double E = 2.71828182845904523536; + inline constexpr double SQRT2 = 1.41421356237309504880; + inline constexpr double INV_SQRT2 = 0.70710678118654752440; + inline constexpr double SQRT3 = 1.73205080756887729352; + inline constexpr double SMALL_ANGLE = 1e-3; + inline constexpr double SMALL_ANGLE_GR = 1e-5; } // namespace constant +namespace convert { + inline constexpr double deg2rad = constant::PI / 180.0; +} // namespace convert + #endif // GLOBAL_UTILS_NUMERIC_H diff --git a/src/global/utils/param_container.cpp b/src/global/utils/param_container.cpp new file mode 100644 index 000000000..126a64e4e --- /dev/null +++ b/src/global/utils/param_container.cpp @@ -0,0 +1,349 @@ +#if defined(OUTPUT_ENABLED) + #include "utils/param_container.h" + + #include "enums.h" + #include "global.h" + + #include + #include + + #include + #include + #include + #include + #include + #include + #include + #include + +namespace prm { + template + struct has_to_string : std::false_type {}; + + template + 
struct has_to_string().to_string())>> + : std::true_type {}; + + template + auto write(adios2::IO& io, const std::string& name, T var) -> + typename std::enable_if::value, void>::type { + io.DefineAttribute(name, std::string(var.to_string())); + } + + template + auto write(adios2::IO& io, const std::string& name, T var) -> decltype(void(T()), + void()) { + io.DefineAttribute(name, var); + } + + template <> + void write(adios2::IO& io, const std::string& name, bool var) { + io.DefineAttribute(name, var ? 1 : 0); + } + + template <> + void write(adios2::IO& io, const std::string& name, Dimension var) { + io.DefineAttribute(name, (unsigned short)var); + } + + template + auto write_pair(adios2::IO& io, const std::string& name, std::pair var) -> + typename std::enable_if::value, void>::type { + std::vector var_str; + var_str.push_back(var.first.to_string()); + var_str.push_back(var.second.to_string()); + io.DefineAttribute(name, var_str.data(), var_str.size()); + } + + template + auto write_pair(adios2::IO& io, + const std::string& name, + std::pair var) -> decltype(void(T()), void()) { + std::vector var_vec; + var_vec.push_back(var.first); + var_vec.push_back(var.second); + io.DefineAttribute(name, var_vec.data(), var_vec.size()); + } + + template + auto write_vec(adios2::IO& io, const std::string& name, std::vector var) -> + typename std::enable_if::value, void>::type { + std::vector var_str; + for (const auto& v : var) { + var_str.push_back(v.to_string()); + } + io.DefineAttribute(name, var_str.data(), var_str.size()); + } + + template + auto write_vec(adios2::IO& io, + const std::string& name, + std::vector var) -> decltype(void(T()), void()) { + io.DefineAttribute(name, var.data(), var.size()); + } + + template + auto write_vec_pair(adios2::IO& io, + const std::string& name, + std::vector> var) -> + typename std::enable_if::value, void>::type { + std::vector var_str; + for (const auto& v : var) { + var_str.push_back(v.first.to_string()); + 
var_str.push_back(v.second.to_string()); + } + io.DefineAttribute(name, var_str.data(), var_str.size()); + } + + template + auto write_vec_pair(adios2::IO& io, + const std::string& name, + std::vector> var) -> decltype(void(T()), + void()) { + std::vector var_vec; + for (const auto& v : var) { + var_vec.push_back(v.first); + var_vec.push_back(v.second); + } + io.DefineAttribute(name, var_vec.data(), var_vec.size()); + } + + template + auto write_vec_vec(adios2::IO& io, + const std::string& name, + std::vector> var) -> + typename std::enable_if::value, void>::type { + std::vector var_str; + for (const auto& vec : var) { + for (const auto& v : vec) { + var_str.push_back(v.to_string()); + } + } + io.DefineAttribute(name, var_str.data(), var_str.size()); + } + + template + auto write_vec_vec(adios2::IO& io, + const std::string& name, + std::vector> var) -> decltype(void(T()), + void()) { + std::vector var_vec; + for (const auto& vec : var) { + for (const auto& v : vec) { + var_vec.push_back(v); + } + } + io.DefineAttribute(name, var_vec.data(), var_vec.size()); + } + + template + auto write_dict(adios2::IO& io, + const std::string& name, + std::map var) -> + typename std::enable_if::value, void>::type { + for (const auto& [key, v] : var) { + io.DefineAttribute(name + "_" + key, v.to_string()); + } + } + + template + auto write_dict(adios2::IO& io, + const std::string& name, + std::map var) -> decltype(void(T()), void()) { + for (const auto& [key, v] : var) { + io.DefineAttribute(name + "_" + key, v); + } + } + + std::map> + write_functions; + + template + void register_write_function() { + write_functions[std::type_index(typeid(T))] = + [](adios2::IO& io, const std::string& name, std::any a) { + write(io, name, std::any_cast(a)); + }; + } + + template + void register_write_function_for_pair() { + write_functions[std::type_index(typeid(std::pair))] = + [](adios2::IO& io, const std::string& name, std::any a) { + write_pair(io, name, std::any_cast>(a)); + }; + } + + 
template + void register_write_function_for_vector() { + write_functions[std::type_index(typeid(std::vector))] = + [](adios2::IO& io, const std::string& name, std::any a) { + write_vec(io, name, std::any_cast>(a)); + }; + } + + template + void register_write_function_for_vector_of_pair() { + write_functions[std::type_index(typeid(std::vector>))] = + [](adios2::IO& io, const std::string& name, std::any a) { + write_vec_pair(io, name, std::any_cast>>(a)); + }; + } + + template + void register_write_function_for_vector_of_vector() { + write_functions[std::type_index(typeid(std::vector>))] = + [](adios2::IO& io, const std::string& name, std::any a) { + write_vec_vec(io, name, std::any_cast>>(a)); + }; + } + + template + void register_write_function_for_dict() { + write_functions[std::type_index(typeid(std::map))] = + [](adios2::IO& io, const std::string& name, std::any a) { + write_dict(io, name, std::any_cast>(a)); + }; + } + + void write_any(adios2::IO& io, const std::string& name, std::any a) { + auto it = write_functions.find(a.type()); + if (it != write_functions.end()) { + it->second(io, name, a); + } else { + throw std::runtime_error("No write function registered for this type"); + } + } + + void Parameters::write(adios2::IO& io) const { + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + register_write_function(); + + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + 
register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + register_write_function_for_pair(); + + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + register_write_function_for_vector(); + + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + 
register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + register_write_function_for_vector_of_pair(); + + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + register_write_function_for_vector_of_vector(); + + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + register_write_function_for_dict(); + + for (auto& [key, value] : allVars()) { + // @TODO: add particles.species support in attrs + if (key == "particles.species") { + continue; + } + try { + write_any(io, key, value); + } catch (const std::exception& e) { + raise::Warning( + fmt::format("Failed to write parameter '%s', skipping. 
Error msg: %s", + key.c_str(), + e.what()), + HERE); + continue; + } + } + } + +} // namespace prm + +#endif // OUTPUT_ENABLED diff --git a/src/global/utils/param_container.h b/src/global/utils/param_container.h index dccf91d09..679b80552 100644 --- a/src/global/utils/param_container.h +++ b/src/global/utils/param_container.h @@ -16,6 +16,11 @@ #include "utils/formatting.h" #include "utils/log.h" +#if defined(OUTPUT_ENABLED) + #include + #include +#endif + #include #include #include @@ -172,6 +177,10 @@ namespace prm { } return result.str(); } + +#if defined(OUTPUT_ENABLED) + void write(adios2::IO& io) const; +#endif }; } // namespace prm diff --git a/src/global/utils/plog.h b/src/global/utils/plog.h index 03dc19319..7713a3728 100644 --- a/src/global/utils/plog.h +++ b/src/global/utils/plog.h @@ -13,6 +13,8 @@ #ifndef GLOBAL_UTILS_PLOG_H #define GLOBAL_UTILS_PLOG_H +#include "utils/formatting.h" + #include #include #include @@ -57,7 +59,7 @@ namespace plog { namespace logger { template - inline void initPlog(const std::string& fname) { + inline void initPlog(const std::string& fname, const std::string& log_level) { // setup logging const auto logfile_name = fname + ".log"; const auto infofile_name = fname + ".info"; @@ -77,7 +79,13 @@ namespace logger { infofile_name.c_str()); static plog::RollingFileAppender errfileAppender( errfile_name.c_str()); - plog::init(plog::verbose, &logfileAppender); + auto log_severity = plog::verbose; + if (fmt::toLower(log_level) == "warning") { + log_severity = plog::warning; + } else if (fmt::toLower(log_level) == "error") { + log_severity = plog::error; + } + plog::init(log_severity, &logfileAppender); plog::init(plog::verbose, &infofileAppender); plog::init(plog::verbose, &errfileAppender); @@ -93,4 +101,4 @@ namespace logger { } // namespace logger -#endif // GLOBAL_UTILS_PLOG_H \ No newline at end of file +#endif // GLOBAL_UTILS_PLOG_H diff --git a/src/global/utils/progressbar.cpp b/src/global/utils/progressbar.cpp new file mode 
100644 index 000000000..eaa8118fc --- /dev/null +++ b/src/global/utils/progressbar.cpp @@ -0,0 +1,123 @@ +#include "utils/progressbar.h" + +#include "global.h" + +#include "utils/error.h" +#include "utils/formatting.h" + +#if defined(MPI_ENABLED) + #include "arch/mpi_aliases.h" + + #include +#endif // MPI_ENABLED + +#include +#include +#include +#include +#include +#include +#include + +namespace pbar { + + auto normalize_duration_fmt( + duration_t t, + const std::string& u) -> std::pair { + const std::vector> units { + { "Β΅s", 1e0 }, + { "ms", 1e3 }, + { "s", 1e6 }, + { "min", 6e7 }, + { "hr", 3.6e9 } + }; + auto it = std::find_if(units.begin(), units.end(), [&u](const auto& pr) { + return pr.first == u; + }); + int u_idx = (it != units.end()) ? std::distance(units.begin(), it) : -1; + raise::ErrorIf(u_idx < 0, "Invalid unit", HERE); + int shift = 0; + if (t < 1) { + shift = -1; + } else if (1e3 <= t && t < 1e6) { + shift = 1; + } else if (1e6 <= t && t < 6e7) { + shift += 2; + } else if (6e7 <= t && t < 3.6e9) { + shift += 3; + } else if (3.6e9 <= t) { + shift += 4; + } + auto newu_idx = std::min(std::max(0, u_idx + shift), + static_cast(units.size())); + return { t * (units[u_idx].second / units[newu_idx].second), + units[newu_idx].first }; + } + + auto to_human_readable(duration_t t, const std::string& u) -> std::string { + const auto [tt, tu] = normalize_duration_fmt(t, u); + const auto t1 = static_cast(tt); + const auto t2 = tt - static_cast(t1); + const auto [tt2, tu2] = normalize_duration_fmt(t2, tu); + return fmt::format("%d%s %d%s", t1, tu.c_str(), static_cast(tt2), tu2.c_str()); + } + + auto ProgressBar(const DurationHistory& history, + timestep_t step, + timestep_t max_steps, + DiagFlags& flags) -> std::string { + auto avg_duration = history.average(); + +#if defined(MPI_ENABLED) + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + std::vector mpi_avg_durations(size, 0.0); + MPI_Gather(&avg_duration, + 1, 
+ mpi::get_type(), + mpi_avg_durations.data(), + 1, + mpi::get_type(), + MPI_ROOT_RANK, + MPI_COMM_WORLD); + if (rank != MPI_ROOT_RANK) { + return ""; + } + avg_duration = *std::max_element(mpi_avg_durations.begin(), + mpi_avg_durations.end()); +#endif + + const auto avg = to_human_readable(avg_duration, "Β΅s"); + const auto elapsed = to_human_readable(history.elapsed(), "Β΅s"); + const auto remain = to_human_readable( + static_cast(max_steps - step) * avg_duration, + "Β΅s"); + + const auto pct = static_cast(step) / + static_cast(max_steps); + const int nfilled = std::min(static_cast(pct * params::width), + params::width); + const int nempty = params::width - nfilled; + const auto c_bmagenta = color::get_color("bmagenta", flags & Diag::Colorful); + const auto c_reset = color::get_color("reset", flags & Diag::Colorful); + + std::stringstream ss; + + ss << "Timestep duration: " << c_bmagenta << avg << c_reset << std::endl; + ss << "Remaining time: " << c_bmagenta << remain << c_reset << std::endl; + ss << "Elapsed time: " << c_bmagenta << elapsed << c_reset << std::endl; + ss << params::start; + for (auto i { 0 }; i < nfilled; ++i) { + ss << params::fill; + } + for (auto i { 0 }; i < nempty; ++i) { + ss << params::empty; + } + ss << params::end << " " << std::fixed << std::setprecision(2) + << std::setfill(' ') << std::setw(6) << std::right << pct * 100.0 << "%\n"; + + return ss.str(); + } + +} // namespace pbar diff --git a/src/global/utils/progressbar.h b/src/global/utils/progressbar.h index ccbc6215e..588413cb4 100644 --- a/src/global/utils/progressbar.h +++ b/src/global/utils/progressbar.h @@ -3,6 +3,8 @@ * @brief Progress bar for logging the simulation progress * @implements * - pbar::ProgressBar -> void + * @cpp: + * - progressbar.cpp * @namespaces: * - pbar:: * @macros: @@ -16,22 +18,15 @@ #include "utils/colors.h" #include "utils/error.h" +#include "utils/formatting.h" #include -#include -#include #include #include #include #include #include -#if 
defined(MPI_ENABLED) - #include "arch/mpi_aliases.h" - - #include -#endif // MPI_ENABLED - namespace pbar { namespace params { inline constexpr int width { 70 }; @@ -43,7 +38,7 @@ namespace pbar { class DurationHistory { std::size_t capacity; - std::vector durations; + std::vector durations; const std::chrono::time_point start; std::chrono::time_point prev_start; @@ -65,112 +60,32 @@ namespace pbar { prev_start = now; } - auto average() const -> long double { + auto average() const -> duration_t { if (durations.size() > 0) { return std::accumulate(durations.begin(), durations.end(), 0.0) / - static_cast(durations.size()); + static_cast(durations.size()); } else { return 0.0; } } - auto elapsed() const -> long double { + auto elapsed() const -> duration_t { return std::chrono::duration_cast( std::chrono::system_clock::now() - start) .count(); } }; - inline auto normalize_duration_fmt(long double t, const std::string& u) - -> std::pair { - const std::vector> units { - {"Β΅s", 1e0}, - { "ms", 1e3}, - { "s", 1e6}, - {"min", 6e7}, - { "hr", 3.6e9} - }; - auto it = std::find_if(units.begin(), units.end(), [&u](const auto& pr) { - return pr.first == u; - }); - int u_idx = (it != units.end()) ? 
std::distance(units.begin(), it) : -1; - raise::ErrorIf(u_idx < 0, "Invalid unit", HERE); - int shift = 0; - if (t < 1e-2) { - shift = -1; - } else if (1e3 <= t && t < 1e6) { - shift = 1; - } else if (1e6 <= t && t < 6e7) { - shift += 2; - } else if (6e7 <= t && t < 3.6e9) { - shift += 3; - } else if (3.6e9 <= t) { - shift += 4; - } - auto newu_idx = std::min(std::max(0, u_idx + shift), - static_cast(units.size())); - return { t * (units[u_idx].second / units[newu_idx].second), - units[newu_idx].first }; - } - - inline void ProgressBar(const DurationHistory& history, - std::size_t step, - std::size_t max_steps, - DiagFlags& flags, - std::ostream& os = std::cout) { - auto avg_duration = history.average(); + auto normalize_duration_fmt( + duration_t t, + const std::string& u) -> std::pair; -#if defined(MPI_ENABLED) - int rank, size; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &size); - std::vector mpi_avg_durations(size, 0.0); - MPI_Gather(&avg_duration, - 1, - mpi::get_type(), - mpi_avg_durations.data(), - 1, - mpi::get_type(), - MPI_ROOT_RANK, - MPI_COMM_WORLD); - if (rank != MPI_ROOT_RANK) { - return; - } - avg_duration = *std::max_element(mpi_avg_durations.begin(), - mpi_avg_durations.end()); -#endif - auto [avg_reduced, avg_units] = normalize_duration_fmt(avg_duration, "Β΅s"); - - const auto remain_nsteps = max_steps - step; - auto [remain_time, remain_units] = normalize_duration_fmt( - static_cast(remain_nsteps) * avg_duration, - "Β΅s"); - auto [elapsed_time, - elapsed_units] = normalize_duration_fmt(history.elapsed(), "Β΅s"); + auto to_human_readable(duration_t t, const std::string& u) -> std::string; - const auto pct = static_cast(step) / - static_cast(max_steps); - const int nfilled = std::min(static_cast(pct * params::width), - params::width); - const int nempty = params::width - nfilled; - const auto c_bmagenta = color::get_color("bmagenta", flags & Diag::Colorful); - const auto c_reset = color::get_color("reset", flags & 
Diag::Colorful); - os << "Average timestep: " << c_bmagenta << avg_reduced << " " << avg_units - << c_reset << std::endl; - os << "Remaining time: " << c_bmagenta << remain_time << " " << remain_units - << c_reset << std::endl; - os << "Elapsed time: " << c_bmagenta << elapsed_time << " " << elapsed_units - << c_reset << std::endl; - os << params::start; - for (auto i { 0 }; i < nfilled; ++i) { - os << params::fill; - } - for (auto i { 0 }; i < nempty; ++i) { - os << params::empty; - } - os << params::end << " " << std::fixed << std::setprecision(2) - << std::setfill(' ') << std::setw(6) << std::right << pct * 100.0 << "%\n"; - } + auto ProgressBar(const DurationHistory& history, + timestep_t step, + timestep_t max_steps, + DiagFlags& flags) -> std::string; } // namespace pbar diff --git a/src/global/utils/timer.cpp b/src/global/utils/timer.cpp new file mode 100644 index 000000000..19c7c9147 --- /dev/null +++ b/src/global/utils/timer.cpp @@ -0,0 +1,286 @@ +#include "utils/timer.h" + +#include "global.h" + +#include "utils/colors.h" +#include "utils/formatting.h" + +#if defined(MPI_ENABLED) + #include "arch/mpi_aliases.h" + + #include +#endif // MPI_ENABLED + +#include +#include +#include +#include +#include + +namespace timer { + + auto Timers::gather(const std::vector& ignore_in_tot, + npart_t npart, + ncells_t ncells) const + -> std::map> { + auto timer_stats = std::map< + std::string, + std::tuple> {}; +#if defined(MPI_ENABLED) + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + std::map> all_timers {}; + + // accumulate timers from MPI blocks + for (auto& [name, timer] : m_timers) { + all_timers.insert({ name, std::vector(size, 0.0) }); + MPI_Gather(&timer.second, + 1, + mpi::get_type(), + all_timers[name].data(), + 1, + mpi::get_type(), + MPI_ROOT_RANK, + MPI_COMM_WORLD); + } + // accumulate nparts and ncells from MPI blocks + auto all_nparts = std::vector(size, 0); + auto all_ncells = std::vector(size, 0); + 
MPI_Gather(&npart, + 1, + mpi::get_type(), + all_nparts.data(), + 1, + mpi::get_type(), + MPI_ROOT_RANK, + MPI_COMM_WORLD); + MPI_Gather(&ncells, + 1, + mpi::get_type(), + all_ncells.data(), + 1, + mpi::get_type(), + MPI_ROOT_RANK, + MPI_COMM_WORLD); + if (rank != MPI_ROOT_RANK) { + return {}; + } + std::vector all_totals(size, 0.0); + for (auto i { 0 }; i < size; ++i) { + for (auto& [name, timer] : m_timers) { + if (std::find(ignore_in_tot.begin(), ignore_in_tot.end(), name) == + ignore_in_tot.end()) { + all_totals[i] += all_timers[name][i]; + } + } + } + for (auto& [name, timer] : m_timers) { + const auto max_time = *std::max_element(all_timers[name].begin(), + all_timers[name].end()); + const auto max_idx = std::distance( + all_timers[name].begin(), + std::max_element(all_timers[name].begin(), all_timers[name].end())); + + const auto per_npart = all_nparts[max_idx] > 0 + ? max_time / + static_cast(all_nparts[max_idx]) + : 0.0; + const auto per_ncells = all_ncells[max_idx] > 0 + ? max_time / + static_cast(all_ncells[max_idx]) + : 0.0; + const auto pcent = static_cast( + (max_time / all_totals[max_idx]) * 100.0); + timer_stats.insert( + { name, + std::make_tuple(max_time, + per_npart, + per_ncells, + pcent, + tools::ArrayImbalance(all_timers[name])) }); + } + const auto max_tot = *std::max_element(all_totals.begin(), all_totals.end()); + const auto tot_imb = tools::ArrayImbalance(all_totals); + timer_stats.insert( + { "Total", std::make_tuple(max_tot, 0.0, 0.0, 100u, tot_imb) }); +#else + duration_t local_tot = 0.0; + for (auto& [name, timer] : m_timers) { + if (std::find(ignore_in_tot.begin(), ignore_in_tot.end(), name) == + ignore_in_tot.end()) { + local_tot += timer.second; + } + } + for (auto& [name, timer] : m_timers) { + const auto pcent = static_cast( + (timer.second / local_tot) * 100.0); + timer_stats.insert( + { name, + std::make_tuple( + timer.second, + npart > 0 ? 
timer.second / static_cast(npart) : 0.0, + timer.second / static_cast(ncells), + pcent, + 0u) }); + } + timer_stats.insert({ "Total", std::make_tuple(local_tot, 0.0, 0.0, 100u, 0u) }); +#endif + return timer_stats; + } + + auto Timers::printAll(TimerFlags flags, + npart_t npart, + ncells_t ncells) const -> std::string { + const std::vector extras { "PrtlClear", "Output", "Checkpoint" }; + const auto stats = gather(extras, npart, ncells); + if (stats.empty()) { + return ""; + } + +#if defined(MPI_ENABLED) + const auto multi_rank = true; +#else + const auto multi_rank = false; +#endif + + std::stringstream ss; + + const auto c_bblack = color::get_color("bblack"); + const auto c_reset = color::get_color("reset"); + const auto c_byellow = color::get_color("byellow"); + const auto c_blue = color::get_color("blue"); + const auto c_red = color::get_color("red"); + const auto c_yellow = color::get_color("yellow"); + const auto c_green = color::get_color("green"); + + if (multi_rank and flags & Timer::PrintTitle) { + ss << fmt::alignedTable( + { "[SUBSTEP]", "[MAX DURATION]", "[% TOT", "VAR]", "[per PRTL", "CELL]" }, + { c_bblack, c_bblack, c_bblack, c_bblack, c_bblack, c_bblack }, + { 0, 37, 45, -48, 63, -66 }, + { ' ', '.', ' ', ':', ' ', ':' }, + c_bblack, + c_reset); + } else { + ss << fmt::alignedTable( + { "[SUBSTEP]", "[DURATION]", "[% TOT]", "[per PRTL", "CELL]" }, + { c_bblack, c_bblack, c_bblack, c_bblack, c_bblack }, + { 0, 37, 45, 55, -58 }, + { ' ', '.', ' ', ' ', ':' }, + c_bblack, + c_reset); + } + + for (auto& [name, timers] : m_timers) { + if (std::find(extras.begin(), extras.end(), name) != extras.end()) { + continue; + } + std::string units = "Β΅s", units_npart = "Β΅s", units_ncells = "Β΅s"; + auto time = std::get<0>(stats.at(name)); + auto per_npart = std::get<1>(stats.at(name)); + auto per_ncells = std::get<2>(stats.at(name)); + const auto tot_pct = std::get<3>(stats.at(name)); + const auto var_pct = std::get<4>(stats.at(name)); + if (flags & 
Timer::AutoConvert) { + convertTime(time, units); + convertTime(per_npart, units_npart); + convertTime(per_ncells, units_ncells); + } + + if (multi_rank) { + ss << fmt::alignedTable( + { name, + fmt::format("%.2f", time) + " " + units, + std::to_string(tot_pct) + "%", + std::to_string(var_pct) + "%", + fmt::format("%.2f", per_npart) + " " + units_npart, + fmt::format("%.2f", per_ncells) + " " + units_ncells }, + { c_reset, + c_yellow, + ((tot_pct > 60) ? c_red : ((tot_pct > 40) ? c_yellow : c_green)), + ((var_pct > 50) ? c_red : ((var_pct > 30) ? c_yellow : c_green)), + c_yellow, + c_yellow }, + { -2, 37, 45, -48, 63, -66 }, + { ' ', '.', ' ', ':', ' ', ':' }, + c_bblack, + c_reset); + } else { + ss << fmt::alignedTable( + { name, + fmt::format("%.2f", time) + " " + units, + std::to_string(tot_pct) + "%", + fmt::format("%.2f", per_npart) + " " + units_npart, + fmt::format("%.2f", per_ncells) + " " + units_ncells }, + { c_reset, + c_yellow, + ((tot_pct > 60) ? c_red : ((tot_pct > 40) ? c_yellow : c_green)), + c_yellow, + c_yellow }, + { -2, 37, 45, 55, -58 }, + { ' ', '.', ' ', ' ', ':' }, + c_bblack, + c_reset); + } + } + + // total + if (flags & Timer::PrintTotal) { + std::string units = "Β΅s"; + auto time = std::get<0>(stats.at("Total")); + const auto var_pct = std::get<4>(stats.at("Total")); + if (flags & Timer::AutoConvert) { + convertTime(time, units); + } + if (multi_rank) { + ss << fmt::alignedTable( + { "Total", + fmt::format("%.2f", time) + " " + units, + std::to_string(var_pct) + "%" }, + { c_reset, + c_blue, + ((var_pct > 50) ? c_red : ((var_pct > 30) ? 
c_yellow : c_green)) }, + { 0, 37, -48 }, + { ' ', ' ', ' ' }, + c_bblack, + c_reset); + } else { + ss << fmt::alignedTable({ "Total", fmt::format("%.2f", time) + " " + units }, + { c_reset, c_blue }, + { 0, 37 }, + { ' ', ' ' }, + c_bblack, + c_reset); + } + } + + // print extra timers for output/checkpoint/prtlClear + const std::vector extras_f { Timer::PrintPrtlClear, + Timer::PrintOutput, + Timer::PrintCheckpoint }; + for (auto i { 0u }; i < extras.size(); ++i) { + const auto name = extras[i]; + const auto active = flags & extras_f[i]; + std::string units = "Β΅s"; + auto time = std::get<0>(stats.at(name)); + const auto tot_pct = std::get<3>(stats.at(name)); + if (flags & Timer::AutoConvert) { + convertTime(time, units); + } + ss << fmt::alignedTable({ name, + fmt::format("%.2f", time) + " " + units, + std::to_string(tot_pct) + "%" }, + { (active ? c_reset : c_bblack), + (active ? c_byellow : c_bblack), + (active ? c_byellow : c_bblack) }, + { -2, 37, 45 }, + { ' ', '.', ' ' }, + c_bblack, + c_reset); + } + return ss.str(); + } + +} // namespace timer diff --git a/src/global/utils/timer.h b/src/global/utils/timer.h index 79f325f0e..8abe8ab0a 100644 --- a/src/global/utils/timer.h +++ b/src/global/utils/timer.h @@ -4,6 +4,8 @@ * @implements * - timer::Timers * - enum timer::TimerFlags + * @cpp: + * - timer.cpp * @namespces: * - timer:: * @macros: @@ -21,21 +23,26 @@ #include "utils/error.h" #include "utils/formatting.h" #include "utils/numeric.h" +#include "utils/tools.h" + +#if defined(MPI_ENABLED) + #include "arch/mpi_aliases.h" + + #include +#endif // MPI_ENABLED #include #include -#include -#include #include #include #include +#include #include #include namespace timer { - using timestamp = std::chrono::time_point; - inline void convertTime(long double& value, std::string& units) { + inline void convertTime(duration_t& value, std::string& units) { if (value > 1e6) { value /= 1e6; units = " s"; @@ -49,10 +56,10 @@ namespace timer { } class Timers { - std::map> 
m_timers; - std::vector m_names; - const bool m_blocking; - const std::function m_synchronize; + std::map> m_timers; + std::vector m_names; + const bool m_blocking; + const std::function m_synchronize; public: Timers(std::initializer_list names, @@ -66,7 +73,7 @@ namespace timer { for (const auto& name : names) { m_timers.insert({ name, - {std::chrono::system_clock::now(), 0.0} + { std::chrono::system_clock::now(), 0.0 } }); m_names.push_back(name); } @@ -80,7 +87,9 @@ namespace timer { void stop(const std::string& name) { if (m_blocking) { - m_synchronize(); + if (m_synchronize != nullptr) { + m_synchronize(); + } #if defined(MPI_ENABLED) MPI_Barrier(MPI_COMM_WORLD); #endif @@ -103,9 +112,9 @@ namespace timer { } [[nodiscard]] - auto get(const std::string& name) const -> long double { + auto get(const std::string& name) const -> duration_t { if (name == "Total") { - long double total = 0.0; + duration_t total = 0.0; for (auto& timer : m_timers) { total += timer.second.second; } @@ -116,174 +125,29 @@ namespace timer { } } - void printAll(const TimerFlags flags = Timer::Default, - std::ostream& os = std::cout) const { - std::string header = fmt::format("%s %27s", "[SUBSTEP]", "[DURATION]"); - - const auto c_bblack = color::get_color("bblack", flags & Timer::Colorful); - const auto c_reset = color::get_color("reset", flags & Timer::Colorful); - const auto c_byellow = color::get_color("byellow", flags & Timer::Colorful); - const auto c_blue = color::get_color("blue", flags & Timer::Colorful); + /** + * @brief Gather all timers from all ranks + * @param ignore_in_tot: vector of timer names to ignore in computing the total + * @return map: + * key: timer name + * value: vector of numbers + * - max duration across ranks + * - max duration per particle + * - max duration per cell + * - max duration as % of total on that rank + * - imbalance % of the given timer + */ + [[nodiscard]] + auto gather(const std::vector& ignore_in_tot, + npart_t npart, + ncells_t ncells) const 
+ -> std::map>; - if (flags & Timer::PrintRelative) { - header += " [% TOT]"; - } -#if defined(MPI_ENABLED) - header += " [MIN : MAX]"; -#endif - header = c_bblack + header + c_reset; - CallOnce( - [](std::ostream& os, std::string header) { - os << header << std::endl; - }, - os, - header); -#if defined(MPI_ENABLED) - int rank, size; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &size); - std::map> mpi_timers {}; - // accumulate timers from MPI blocks - for (auto& [name, timer] : m_timers) { - mpi_timers[name] = std::vector(size, 0.0); - MPI_Gather(&timer.second, - 1, - mpi::get_type(), - mpi_timers[name].data(), - 1, - mpi::get_type(), - MPI_ROOT_RANK, - MPI_COMM_WORLD); - } - if (rank != MPI_ROOT_RANK) { - return; - } - long double total = 0.0; - for (auto& [name, timer] : m_timers) { - auto timers = mpi_timers[name]; - long double tot = std::accumulate(timers.begin(), timers.end(), 0.0); - if (name != "Output") { - total += tot; - } - } - for (auto& [name, timers] : mpi_timers) { - // compute min, max, mean - long double min_time = *std::min_element(timers.begin(), timers.end()); - long double max_time = *std::max_element(timers.begin(), timers.end()); - long double mean_time = std::accumulate(timers.begin(), timers.end(), 0.0) / - size; - std::string mean_units = "Β΅s"; - const auto min_pct = mean_time > ZERO - ? (int)(((mean_time - min_time) / mean_time) * 100.0) - : 0; - const auto max_pct = mean_time > ZERO - ? (int)(((max_time - mean_time) / mean_time) * 100.0) - : 0; - const auto tot_pct = (cmp::AlmostZero_host(total) - ? 0 - : (mean_time * size / total) * 100.0); - if (flags & Timer::AutoConvert) { - convertTime(mean_time, mean_units); - } - if (flags & Timer::PrintIndents) { - os << " "; - } - os << ((name != "Sorting" or flags & Timer::PrintSorting) ? 
c_reset - : c_bblack) - << name << c_reset << c_bblack - << fmt::pad(name, 20, '.', true).substr(name.size(), 20); - os << std::setw(17) << std::right << std::setfill('.') - << fmt::format("%s%.2Lf", - (name != "Sorting" or flags & Timer::PrintSorting) - ? c_byellow.c_str() - : c_bblack.c_str(), - mean_time); - if (flags & Timer::PrintUnits) { - os << " " << mean_units << " "; - } - if (flags & Timer::PrintRelative) { - os << " " << std::setw(5) << std::right << std::setfill(' ') - << std::fixed << std::setprecision(2) << tot_pct << "%"; - } - os << fmt::format("%+7s : %-7s", - fmt::format("-%d%%", min_pct).c_str(), - fmt::format("+%d%%", max_pct).c_str()); - os << c_reset << std::endl; - } - total /= size; -#else // not MPI_ENABLED - long double total = 0.0; - for (auto& [name, timer] : m_timers) { - if (name != "Output") { - total += timer.second; - } - } - for (auto& [name, timer] : m_timers) { - if (name == "Output") { - continue; - } - std::string units = "Β΅s"; - auto value = timer.second; - if (flags & Timer::AutoConvert) { - convertTime(value, units); - } - if (flags & Timer::PrintIndents) { - os << " "; - } - os << ((name != "Sorting" or flags & Timer::PrintSorting) ? c_reset - : c_bblack) - << name << c_bblack - << fmt::pad(name, 20, '.', true).substr(name.size(), 20); - os << std::setw(17) << std::right << std::setfill('.') - << fmt::format("%s%.2Lf", - (name != "Sorting" or flags & Timer::PrintSorting) - ? c_byellow.c_str() - : c_bblack.c_str(), - value); - if (flags & Timer::PrintUnits) { - os << " " << units; - } - if (flags & Timer::PrintRelative) { - os << " " << std::setw(7) << std::right << std::setfill(' ') - << std::fixed << std::setprecision(2) - << (cmp::AlmostZero_host(total) ? 
0 : (timer.second / total) * 100.0); - } - os << c_reset << std::endl; - } -#endif // MPI_ENABLED - if (flags & Timer::PrintTotal) { - std::string units = "Β΅s"; - auto value = total; - if (flags & Timer::AutoConvert) { - convertTime(value, units); - } - os << c_bblack << std::setw(22) << std::left << std::setfill(' ') - << "Total" << c_reset; - os << c_blue << std::setw(12) << std::right << std::setfill(' ') << value; - if (flags & Timer::PrintUnits) { - os << " " << units; - } - os << c_reset << std::endl; - } - { - std::string units = "Β΅s"; - auto value = get("Output"); - if (flags & Timer::AutoConvert) { - convertTime(value, units); - } - os << ((flags & Timer::PrintOutput) ? c_reset : c_bblack) << "Output" - << c_bblack << fmt::pad("Output", 22, '.', true).substr(6, 22); - os << std::setw(17) << std::right << std::setfill('.') - << fmt::format("%s%.2Lf", - (flags & Timer::PrintOutput) ? c_byellow.c_str() - : c_bblack.c_str(), - value); - if (flags & Timer::PrintUnits) { - os << " " << units; - } - os << c_reset << std::endl; - } - } + [[nodiscard]] + auto printAll(TimerFlags flags = Timer::Default, + npart_t npart = 0, + ncells_t ncells = 0) const -> std::string; }; } // namespace timer diff --git a/src/global/utils/toml.h b/src/global/utils/toml.h new file mode 100644 index 000000000..5d9981e9f --- /dev/null +++ b/src/global/utils/toml.h @@ -0,0 +1,17938 @@ +#ifndef TOML11_VERSION_HPP +#define TOML11_VERSION_HPP + +#define TOML11_VERSION_MAJOR 4 +#define TOML11_VERSION_MINOR 1 +#define TOML11_VERSION_PATCH 0 + +#ifndef __cplusplus + #error "__cplusplus is not defined" +#endif + +// Since MSVC does not define `__cplusplus` correctly unless you pass +// `/Zc:__cplusplus` when compiling, the workaround macros are added. +// +// The value of `__cplusplus` macro is defined in the C++ standard spec, but +// MSVC ignores the value, maybe because of backward compatibility. 
Instead, +// MSVC defines _MSVC_LANG that has the same value as __cplusplus defined in +// the C++ standard. So we check if _MSVC_LANG is defined before using `__cplusplus`. +// +// FYI: https://docs.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170 +// https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros?view=msvc-170 +// + +#if defined(_MSVC_LANG) && defined(_MSC_VER) && 190024210 <= _MSC_FULL_VER + #define TOML11_CPLUSPLUS_STANDARD_VERSION _MSVC_LANG +#else + #define TOML11_CPLUSPLUS_STANDARD_VERSION __cplusplus +#endif + +#if TOML11_CPLUSPLUS_STANDARD_VERSION < 201103L + #error "toml11 requires C++11 or later." +#endif + +#if !defined(__has_include) + #define __has_include(x) 0 +#endif + +#if !defined(__has_cpp_attribute) + #define __has_cpp_attribute(x) 0 +#endif + +#if !defined(__has_builtin) + #define __has_builtin(x) 0 +#endif + +// hard to remember + +#ifndef TOML11_CXX14_VALUE + #define TOML11_CXX14_VALUE 201402L +#endif // TOML11_CXX14_VALUE + +#ifndef TOML11_CXX17_VALUE + #define TOML11_CXX17_VALUE 201703L +#endif // TOML11_CXX17_VALUE + +#ifndef TOML11_CXX20_VALUE + #define TOML11_CXX20_VALUE 202002L +#endif // TOML11_CXX20_VALUE + +#if defined(__cpp_char8_t) + #if __cpp_char8_t >= 201811L + #define TOML11_HAS_CHAR8_T 1 + #endif +#endif + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX17_VALUE + #if __has_include() + #define TOML11_HAS_STRING_VIEW 1 + #endif +#endif + +#ifndef TOML11_DISABLE_STD_FILESYSTEM + #if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX17_VALUE + #if __has_include() + #define TOML11_HAS_FILESYSTEM 1 + #endif + #endif +#endif + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX17_VALUE + #if __has_include() + #define TOML11_HAS_OPTIONAL 1 + #endif +#endif + +#if defined(TOML11_COMPILE_SOURCES) + #define TOML11_INLINE +#else + #define TOML11_INLINE inline +#endif + +namespace toml { + + inline const char* license_notice() noexcept { + return R"(The MIT License (MIT) + +Copyright (c) 2017-now 
Toru Niina + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE.)"; + } + +} // namespace toml +#endif // TOML11_VERSION_HPP +#ifndef TOML11_FORMAT_HPP +#define TOML11_FORMAT_HPP + +#ifndef TOML11_FORMAT_FWD_HPP + #define TOML11_FORMAT_FWD_HPP + + #include + #include + #include + #include + #include + +namespace toml { + + // toml types with serialization info + + enum class indent_char : std::uint8_t { + space, // use space + tab, // use tab + none // no indent + }; + + std::ostream& operator<<(std::ostream& os, const indent_char& c); + std::string to_string(const indent_char c); + + // ---------------------------------------------------------------------------- + // boolean + + struct boolean_format_info { + // nothing, for now + }; + + inline bool operator==(const boolean_format_info&, + const boolean_format_info&) noexcept { + return true; + } + + inline bool operator!=(const boolean_format_info&, + const boolean_format_info&) noexcept { + return false; + } + + // 
---------------------------------------------------------------------------- + // integer + + enum class integer_format : std::uint8_t { + dec = 0, + bin = 1, + oct = 2, + hex = 3, + }; + + std::ostream& operator<<(std::ostream& os, const integer_format f); + std::string to_string(const integer_format); + + struct integer_format_info { + integer_format fmt = integer_format::dec; + bool uppercase = true; // hex with uppercase + std::size_t width = 0; // minimal width (may exceed) + std::size_t spacer = 0; // position of `_` (if 0, no spacer) + std::string suffix = ""; // _suffix (library extension) + }; + + bool operator==(const integer_format_info&, const integer_format_info&) noexcept; + bool operator!=(const integer_format_info&, const integer_format_info&) noexcept; + + // ---------------------------------------------------------------------------- + // floating + + enum class floating_format : std::uint8_t { + defaultfloat = 0, + fixed = 1, // does not include exponential part + scientific = 2, // always include exponential part + hex = 3 // hexfloat extension + }; + + std::ostream& operator<<(std::ostream& os, const floating_format f); + std::string to_string(const floating_format); + + struct floating_format_info { + floating_format fmt = floating_format::defaultfloat; + std::size_t prec = 0; // precision (if 0, use the default) + std::string suffix = ""; // 1.0e+2_suffix (library extension) + }; + + bool operator==(const floating_format_info&, const floating_format_info&) noexcept; + bool operator!=(const floating_format_info&, const floating_format_info&) noexcept; + + // ---------------------------------------------------------------------------- + // string + + enum class string_format : std::uint8_t { + basic = 0, + literal = 1, + multiline_basic = 2, + multiline_literal = 3 + }; + + std::ostream& operator<<(std::ostream& os, const string_format f); + std::string to_string(const string_format); + + struct string_format_info { + string_format fmt = 
string_format::basic; + bool start_with_newline = false; + }; + + bool operator==(const string_format_info&, const string_format_info&) noexcept; + bool operator!=(const string_format_info&, const string_format_info&) noexcept; + + // ---------------------------------------------------------------------------- + // datetime + + enum class datetime_delimiter_kind : std::uint8_t { + upper_T = 0, + lower_t = 1, + space = 2, + }; + std::ostream& operator<<(std::ostream& os, const datetime_delimiter_kind d); + std::string to_string(const datetime_delimiter_kind); + + struct offset_datetime_format_info { + datetime_delimiter_kind delimiter = datetime_delimiter_kind::upper_T; + bool has_seconds = true; + std::size_t subsecond_precision = 6; // [us] + }; + + bool operator==(const offset_datetime_format_info&, + const offset_datetime_format_info&) noexcept; + bool operator!=(const offset_datetime_format_info&, + const offset_datetime_format_info&) noexcept; + + struct local_datetime_format_info { + datetime_delimiter_kind delimiter = datetime_delimiter_kind::upper_T; + bool has_seconds = true; + std::size_t subsecond_precision = 6; // [us] + }; + + bool operator==(const local_datetime_format_info&, + const local_datetime_format_info&) noexcept; + bool operator!=(const local_datetime_format_info&, + const local_datetime_format_info&) noexcept; + + struct local_date_format_info { + // nothing, for now + }; + + bool operator==(const local_date_format_info&, + const local_date_format_info&) noexcept; + bool operator!=(const local_date_format_info&, + const local_date_format_info&) noexcept; + + struct local_time_format_info { + bool has_seconds = true; + std::size_t subsecond_precision = 6; // [us] + }; + + bool operator==(const local_time_format_info&, + const local_time_format_info&) noexcept; + bool operator!=(const local_time_format_info&, + const local_time_format_info&) noexcept; + + // ---------------------------------------------------------------------------- + // 
array + + enum class array_format : std::uint8_t { + default_format = 0, + oneline = 1, + multiline = 2, + array_of_tables = 3 // [[format.in.this.way]] + }; + + std::ostream& operator<<(std::ostream& os, const array_format f); + std::string to_string(const array_format); + + struct array_format_info { + array_format fmt = array_format::default_format; + indent_char indent_type = indent_char::space; + std::int32_t body_indent = 4; // indent in case of multiline + std::int32_t closing_indent = 0; // indent of `]` + }; + + bool operator==(const array_format_info&, const array_format_info&) noexcept; + bool operator!=(const array_format_info&, const array_format_info&) noexcept; + + // ---------------------------------------------------------------------------- + // table + + enum class table_format : std::uint8_t { + multiline = 0, // [foo] \n bar = "baz" + oneline = 1, // foo = {bar = "baz"} + dotted = 2, // foo.bar = "baz" + multiline_oneline = 3, // foo = { \n bar = "baz" \n } + implicit = 4 // [x] defined by [x.y.z]. skip in serializer. 
+ }; + + std::ostream& operator<<(std::ostream& os, const table_format f); + std::string to_string(const table_format); + + struct table_format_info { + table_format fmt = table_format::multiline; + indent_char indent_type = indent_char::space; + std::int32_t body_indent = 0; // indent of values + std::int32_t name_indent = 0; // indent of [table] + std::int32_t closing_indent = 0; // in case of {inline-table} + }; + + bool operator==(const table_format_info&, const table_format_info&) noexcept; + bool operator!=(const table_format_info&, const table_format_info&) noexcept; + + // ---------------------------------------------------------------------------- + // wrapper + + namespace detail { + template + struct value_with_format { + using value_type = T; + using format_type = F; + + value_with_format() = default; + ~value_with_format() = default; + value_with_format(const value_with_format&) = default; + value_with_format(value_with_format&&) = default; + value_with_format& operator=(const value_with_format&) = default; + value_with_format& operator=(value_with_format&&) = default; + + value_with_format(value_type v, format_type f) + : value { std::move(v) } + , format { std::move(f) } {} + + template + value_with_format(value_with_format other) + : value { std::move(other.value) } + , format { std::move(other.format) } {} + + value_type value; + format_type format; + }; + } // namespace detail + +} // namespace toml +#endif // TOML11_FORMAT_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_FORMAT_IMPL_HPP + #define TOML11_FORMAT_IMPL_HPP + + #include + #include + +namespace toml { + + // toml types with serialization info + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, const indent_char& c) { + switch (c) { + case indent_char::space: { + os << "space"; + break; + } + case indent_char::tab: { + os << "tab"; + break; + } + case indent_char::none: { + os << "none"; + break; + } + default: { + os << "unknown indent char: " << 
static_cast(c); + } + } + return os; + } + + TOML11_INLINE std::string to_string(const indent_char c) { + std::ostringstream oss; + oss << c; + return oss.str(); + } + + // ---------------------------------------------------------------------------- + // boolean + + // ---------------------------------------------------------------------------- + // integer + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, const integer_format f) { + switch (f) { + case integer_format::dec: { + os << "dec"; + break; + } + case integer_format::bin: { + os << "bin"; + break; + } + case integer_format::oct: { + os << "oct"; + break; + } + case integer_format::hex: { + os << "hex"; + break; + } + default: { + os << "unknown integer_format: " << static_cast(f); + break; + } + } + return os; + } + + TOML11_INLINE std::string to_string(const integer_format c) { + std::ostringstream oss; + oss << c; + return oss.str(); + } + + TOML11_INLINE bool operator==(const integer_format_info& lhs, + const integer_format_info& rhs) noexcept { + return lhs.fmt == rhs.fmt && lhs.uppercase == rhs.uppercase && + lhs.width == rhs.width && lhs.spacer == rhs.spacer && + lhs.suffix == rhs.suffix; + } + + TOML11_INLINE bool operator!=(const integer_format_info& lhs, + const integer_format_info& rhs) noexcept { + return !(lhs == rhs); + } + + // ---------------------------------------------------------------------------- + // floating + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, const floating_format f) { + switch (f) { + case floating_format::defaultfloat: { + os << "defaultfloat"; + break; + } + case floating_format::fixed: { + os << "fixed"; + break; + } + case floating_format::scientific: { + os << "scientific"; + break; + } + case floating_format::hex: { + os << "hex"; + break; + } + default: { + os << "unknown floating_format: " << static_cast(f); + break; + } + } + return os; + } + + TOML11_INLINE std::string to_string(const floating_format c) { + std::ostringstream oss; + oss 
<< c; + return oss.str(); + } + + TOML11_INLINE bool operator==(const floating_format_info& lhs, + const floating_format_info& rhs) noexcept { + return lhs.fmt == rhs.fmt && lhs.prec == rhs.prec && lhs.suffix == rhs.suffix; + } + + TOML11_INLINE bool operator!=(const floating_format_info& lhs, + const floating_format_info& rhs) noexcept { + return !(lhs == rhs); + } + + // ---------------------------------------------------------------------------- + // string + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, const string_format f) { + switch (f) { + case string_format::basic: { + os << "basic"; + break; + } + case string_format::literal: { + os << "literal"; + break; + } + case string_format::multiline_basic: { + os << "multiline_basic"; + break; + } + case string_format::multiline_literal: { + os << "multiline_literal"; + break; + } + default: { + os << "unknown string_format: " << static_cast(f); + break; + } + } + return os; + } + + TOML11_INLINE std::string to_string(const string_format c) { + std::ostringstream oss; + oss << c; + return oss.str(); + } + + TOML11_INLINE bool operator==(const string_format_info& lhs, + const string_format_info& rhs) noexcept { + return lhs.fmt == rhs.fmt && lhs.start_with_newline == rhs.start_with_newline; + } + + TOML11_INLINE bool operator!=(const string_format_info& lhs, + const string_format_info& rhs) noexcept { + return !(lhs == rhs); + } + + // ---------------------------------------------------------------------------- + // datetime + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, + const datetime_delimiter_kind d) { + switch (d) { + case datetime_delimiter_kind::upper_T: { + os << "upper_T, "; + break; + } + case datetime_delimiter_kind::lower_t: { + os << "lower_t, "; + break; + } + case datetime_delimiter_kind::space: { + os << "space, "; + break; + } + default: { + os << "unknown datetime delimiter: " << static_cast(d); + break; + } + } + return os; + } + + TOML11_INLINE std::string 
to_string(const datetime_delimiter_kind c) { + std::ostringstream oss; + oss << c; + return oss.str(); + } + + TOML11_INLINE bool operator==(const offset_datetime_format_info& lhs, + const offset_datetime_format_info& rhs) noexcept { + return lhs.delimiter == rhs.delimiter && lhs.has_seconds == rhs.has_seconds && + lhs.subsecond_precision == rhs.subsecond_precision; + } + + TOML11_INLINE bool operator!=(const offset_datetime_format_info& lhs, + const offset_datetime_format_info& rhs) noexcept { + return !(lhs == rhs); + } + + TOML11_INLINE bool operator==(const local_datetime_format_info& lhs, + const local_datetime_format_info& rhs) noexcept { + return lhs.delimiter == rhs.delimiter && lhs.has_seconds == rhs.has_seconds && + lhs.subsecond_precision == rhs.subsecond_precision; + } + + TOML11_INLINE bool operator!=(const local_datetime_format_info& lhs, + const local_datetime_format_info& rhs) noexcept { + return !(lhs == rhs); + } + + TOML11_INLINE bool operator==(const local_date_format_info&, + const local_date_format_info&) noexcept { + return true; + } + + TOML11_INLINE bool operator!=(const local_date_format_info& lhs, + const local_date_format_info& rhs) noexcept { + return !(lhs == rhs); + } + + TOML11_INLINE bool operator==(const local_time_format_info& lhs, + const local_time_format_info& rhs) noexcept { + return lhs.has_seconds == rhs.has_seconds && + lhs.subsecond_precision == rhs.subsecond_precision; + } + + TOML11_INLINE bool operator!=(const local_time_format_info& lhs, + const local_time_format_info& rhs) noexcept { + return !(lhs == rhs); + } + + // ---------------------------------------------------------------------------- + // array + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, const array_format f) { + switch (f) { + case array_format::default_format: { + os << "default_format"; + break; + } + case array_format::oneline: { + os << "oneline"; + break; + } + case array_format::multiline: { + os << "multiline"; + break; + } + case 
array_format::array_of_tables: { + os << "array_of_tables"; + break; + } + default: { + os << "unknown array_format: " << static_cast(f); + break; + } + } + return os; + } + + TOML11_INLINE std::string to_string(const array_format c) { + std::ostringstream oss; + oss << c; + return oss.str(); + } + + TOML11_INLINE bool operator==(const array_format_info& lhs, + const array_format_info& rhs) noexcept { + return lhs.fmt == rhs.fmt && lhs.indent_type == rhs.indent_type && + lhs.body_indent == rhs.body_indent && + lhs.closing_indent == rhs.closing_indent; + } + + TOML11_INLINE bool operator!=(const array_format_info& lhs, + const array_format_info& rhs) noexcept { + return !(lhs == rhs); + } + + // ---------------------------------------------------------------------------- + // table + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, const table_format f) { + switch (f) { + case table_format::multiline: { + os << "multiline"; + break; + } + case table_format::oneline: { + os << "oneline"; + break; + } + case table_format::dotted: { + os << "dotted"; + break; + } + case table_format::multiline_oneline: { + os << "multiline_oneline"; + break; + } + case table_format::implicit: { + os << "implicit"; + break; + } + default: { + os << "unknown table_format: " << static_cast(f); + break; + } + } + return os; + } + + TOML11_INLINE std::string to_string(const table_format c) { + std::ostringstream oss; + oss << c; + return oss.str(); + } + + TOML11_INLINE bool operator==(const table_format_info& lhs, + const table_format_info& rhs) noexcept { + return lhs.fmt == rhs.fmt && lhs.indent_type == rhs.indent_type && + lhs.body_indent == rhs.body_indent && + lhs.name_indent == rhs.name_indent && + lhs.closing_indent == rhs.closing_indent; + } + + TOML11_INLINE bool operator!=(const table_format_info& lhs, + const table_format_info& rhs) noexcept { + return !(lhs == rhs); + } + +} // namespace toml + #endif // TOML11_FORMAT_IMPL_HPP +#endif + +#endif // TOML11_FORMAT_HPP 
+#ifndef TOML11_DATETIME_HPP +#define TOML11_DATETIME_HPP + +#ifndef TOML11_DATETIME_FWD_HPP + #define TOML11_DATETIME_FWD_HPP + + #include + #include + #include + #include + #include + #include + +namespace toml { + + enum class month_t : std::uint8_t { + Jan = 0, + Feb = 1, + Mar = 2, + Apr = 3, + May = 4, + Jun = 5, + Jul = 6, + Aug = 7, + Sep = 8, + Oct = 9, + Nov = 10, + Dec = 11 + }; + + // ---------------------------------------------------------------------------- + + struct local_date { + std::int16_t year { 0 }; // A.D. (like, 2018) + std::uint8_t month { 0 }; // [0, 11] + std::uint8_t day { 0 }; // [1, 31] + + local_date(int y, month_t m, int d) + : year { static_cast(y) } + , month { static_cast(m) } + , day { static_cast(d) } {} + + explicit local_date(const std::tm& t) + : year { static_cast(t.tm_year + 1900) } + , month { static_cast(t.tm_mon) } + , day { static_cast(t.tm_mday) } {} + + explicit local_date(const std::chrono::system_clock::time_point& tp); + explicit local_date(const std::time_t t); + + operator std::chrono::system_clock::time_point() const; + operator std::time_t() const; + + local_date() = default; + ~local_date() = default; + local_date(const local_date&) = default; + local_date(local_date&&) = default; + local_date& operator=(const local_date&) = default; + local_date& operator=(local_date&&) = default; + }; + + bool operator==(const local_date& lhs, const local_date& rhs); + bool operator!=(const local_date& lhs, const local_date& rhs); + bool operator<(const local_date& lhs, const local_date& rhs); + bool operator<=(const local_date& lhs, const local_date& rhs); + bool operator>(const local_date& lhs, const local_date& rhs); + bool operator>=(const local_date& lhs, const local_date& rhs); + + std::ostream& operator<<(std::ostream& os, const local_date& date); + std::string to_string(const local_date& date); + + // ----------------------------------------------------------------------------- + + struct local_time { + std::uint8_t 
hour { 0 }; // [0, 23] + std::uint8_t minute { 0 }; // [0, 59] + std::uint8_t second { 0 }; // [0, 60] + std::uint16_t millisecond { 0 }; // [0, 999] + std::uint16_t microsecond { 0 }; // [0, 999] + std::uint16_t nanosecond { 0 }; // [0, 999] + + local_time(int h, int m, int s, int ms = 0, int us = 0, int ns = 0) + : hour { static_cast(h) } + , minute { static_cast(m) } + , second { static_cast(s) } + , millisecond { static_cast(ms) } + , microsecond { static_cast(us) } + , nanosecond { static_cast(ns) } {} + + explicit local_time(const std::tm& t) + : hour { static_cast(t.tm_hour) } + , minute { static_cast(t.tm_min) } + , second { static_cast(t.tm_sec) } + , millisecond { 0 } + , microsecond { 0 } + , nanosecond { 0 } {} + + template + explicit local_time(const std::chrono::duration& t) { + const auto h = std::chrono::duration_cast(t); + this->hour = static_cast(h.count()); + const auto t2 = t - h; + const auto m = std::chrono::duration_cast(t2); + this->minute = static_cast(m.count()); + const auto t3 = t2 - m; + const auto s = std::chrono::duration_cast(t3); + this->second = static_cast(s.count()); + const auto t4 = t3 - s; + const auto ms = std::chrono::duration_cast(t4); + this->millisecond = static_cast(ms.count()); + const auto t5 = t4 - ms; + const auto us = std::chrono::duration_cast(t5); + this->microsecond = static_cast(us.count()); + const auto t6 = t5 - us; + const auto ns = std::chrono::duration_cast(t6); + this->nanosecond = static_cast(ns.count()); + } + + operator std::chrono::nanoseconds() const; + + local_time() = default; + ~local_time() = default; + local_time(const local_time&) = default; + local_time(local_time&&) = default; + local_time& operator=(const local_time&) = default; + local_time& operator=(local_time&&) = default; + }; + + bool operator==(const local_time& lhs, const local_time& rhs); + bool operator!=(const local_time& lhs, const local_time& rhs); + bool operator<(const local_time& lhs, const local_time& rhs); + bool 
operator<=(const local_time& lhs, const local_time& rhs); + bool operator>(const local_time& lhs, const local_time& rhs); + bool operator>=(const local_time& lhs, const local_time& rhs); + + std::ostream& operator<<(std::ostream& os, const local_time& time); + std::string to_string(const local_time& time); + + // ---------------------------------------------------------------------------- + + struct time_offset { + std::int8_t hour { 0 }; // [-12, 12] + std::int8_t minute { 0 }; // [-59, 59] + + time_offset(int h, int m) + : hour { static_cast(h) } + , minute { static_cast(m) } {} + + operator std::chrono::minutes() const; + + time_offset() = default; + ~time_offset() = default; + time_offset(const time_offset&) = default; + time_offset(time_offset&&) = default; + time_offset& operator=(const time_offset&) = default; + time_offset& operator=(time_offset&&) = default; + }; + + bool operator==(const time_offset& lhs, const time_offset& rhs); + bool operator!=(const time_offset& lhs, const time_offset& rhs); + bool operator<(const time_offset& lhs, const time_offset& rhs); + bool operator<=(const time_offset& lhs, const time_offset& rhs); + bool operator>(const time_offset& lhs, const time_offset& rhs); + bool operator>=(const time_offset& lhs, const time_offset& rhs); + + std::ostream& operator<<(std::ostream& os, const time_offset& offset); + + std::string to_string(const time_offset& offset); + + // ----------------------------------------------------------------------------- + + struct local_datetime { + local_date date {}; + local_time time {}; + + local_datetime(local_date d, local_time t) : date { d }, time { t } {} + + explicit local_datetime(const std::tm& t) : date { t }, time { t } {} + + explicit local_datetime(const std::chrono::system_clock::time_point& tp); + explicit local_datetime(const std::time_t t); + + operator std::chrono::system_clock::time_point() const; + operator std::time_t() const; + + local_datetime() = default; + ~local_datetime() = 
default; + local_datetime(const local_datetime&) = default; + local_datetime(local_datetime&&) = default; + local_datetime& operator=(const local_datetime&) = default; + local_datetime& operator=(local_datetime&&) = default; + }; + + bool operator==(const local_datetime& lhs, const local_datetime& rhs); + bool operator!=(const local_datetime& lhs, const local_datetime& rhs); + bool operator<(const local_datetime& lhs, const local_datetime& rhs); + bool operator<=(const local_datetime& lhs, const local_datetime& rhs); + bool operator>(const local_datetime& lhs, const local_datetime& rhs); + bool operator>=(const local_datetime& lhs, const local_datetime& rhs); + + std::ostream& operator<<(std::ostream& os, const local_datetime& dt); + + std::string to_string(const local_datetime& dt); + + // ----------------------------------------------------------------------------- + + struct offset_datetime { + local_date date {}; + local_time time {}; + time_offset offset {}; + + offset_datetime(local_date d, local_time t, time_offset o) + : date { d } + , time { t } + , offset { o } {} + + offset_datetime(const local_datetime& dt, time_offset o) + : date { dt.date } + , time { dt.time } + , offset { o } {} + + // use the current local timezone offset + explicit offset_datetime(const local_datetime& ld); + explicit offset_datetime(const std::chrono::system_clock::time_point& tp); + explicit offset_datetime(const std::time_t& t); + explicit offset_datetime(const std::tm& t); + + operator std::chrono::system_clock::time_point() const; + + operator std::time_t() const; + + offset_datetime() = default; + ~offset_datetime() = default; + offset_datetime(const offset_datetime&) = default; + offset_datetime(offset_datetime&&) = default; + offset_datetime& operator=(const offset_datetime&) = default; + offset_datetime& operator=(offset_datetime&&) = default; + + private: + static time_offset get_local_offset(const std::time_t* tp); + }; + + bool operator==(const offset_datetime& lhs, 
const offset_datetime& rhs); + bool operator!=(const offset_datetime& lhs, const offset_datetime& rhs); + bool operator<(const offset_datetime& lhs, const offset_datetime& rhs); + bool operator<=(const offset_datetime& lhs, const offset_datetime& rhs); + bool operator>(const offset_datetime& lhs, const offset_datetime& rhs); + bool operator>=(const offset_datetime& lhs, const offset_datetime& rhs); + + std::ostream& operator<<(std::ostream& os, const offset_datetime& dt); + + std::string to_string(const offset_datetime& dt); + +} // namespace toml +#endif // TOML11_DATETIME_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_DATETIME_IMPL_HPP + #define TOML11_DATETIME_IMPL_HPP + + #include + #include + #include + #include + #include + #include + #include + +namespace toml { + + // To avoid non-threadsafe std::localtime. In C11 (not C++11!), localtime_s is + // provided in the absolutely same purpose, but C++11 is actually not + // compatible with C11. We need to dispatch the function depending on the OS. 
+ namespace detail { + // TODO: find more sophisticated way to handle this + #if defined(_MSC_VER) + TOML11_INLINE std::tm localtime_s(const std::time_t* src) { + std::tm dst; + const auto result = ::localtime_s(&dst, src); + if (result) { + throw std::runtime_error("localtime_s failed."); + } + return dst; + } + + TOML11_INLINE std::tm gmtime_s(const std::time_t* src) { + std::tm dst; + const auto result = ::gmtime_s(&dst, src); + if (result) { + throw std::runtime_error("gmtime_s failed."); + } + return dst; + } + #elif (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 1) || \ + defined(_XOPEN_SOURCE) || defined(_BSD_SOURCE) || \ + defined(_SVID_SOURCE) || defined(_POSIX_SOURCE) + TOML11_INLINE std::tm localtime_s(const std::time_t* src) { + std::tm dst; + const auto result = ::localtime_r(src, &dst); + if (!result) { + throw std::runtime_error("localtime_r failed."); + } + return dst; + } + + TOML11_INLINE std::tm gmtime_s(const std::time_t* src) { + std::tm dst; + const auto result = ::gmtime_r(src, &dst); + if (!result) { + throw std::runtime_error("gmtime_r failed."); + } + return dst; + } + #else // fallback. 
not threadsafe + TOML11_INLINE std::tm localtime_s(const std::time_t* src) { + const auto result = std::localtime(src); + if (!result) { + throw std::runtime_error("localtime failed."); + } + return *result; + } + + TOML11_INLINE std::tm gmtime_s(const std::time_t* src) { + const auto result = std::gmtime(src); + if (!result) { + throw std::runtime_error("gmtime failed."); + } + return *result; + } + #endif + } // namespace detail + + // ---------------------------------------------------------------------------- + + TOML11_INLINE local_date::local_date( + const std::chrono::system_clock::time_point& tp) { + const auto t = std::chrono::system_clock::to_time_t(tp); + const auto time = detail::localtime_s(&t); + *this = local_date(time); + } + + TOML11_INLINE local_date::local_date(const std::time_t t) + : local_date { std::chrono::system_clock::from_time_t(t) } {} + + TOML11_INLINE local_date::operator std::chrono::system_clock::time_point() const { + // std::mktime returns date as local time zone. 
no conversion needed + std::tm t; + t.tm_sec = 0; + t.tm_min = 0; + t.tm_hour = 0; + t.tm_mday = static_cast(this->day); + t.tm_mon = static_cast(this->month); + t.tm_year = static_cast(this->year) - 1900; + t.tm_wday = 0; // the value will be ignored + t.tm_yday = 0; // the value will be ignored + t.tm_isdst = -1; + return std::chrono::system_clock::from_time_t(std::mktime(&t)); + } + + TOML11_INLINE local_date::operator std::time_t() const { + return std::chrono::system_clock::to_time_t( + std::chrono::system_clock::time_point(*this)); + } + + TOML11_INLINE bool operator==(const local_date& lhs, const local_date& rhs) { + return std::make_tuple(lhs.year, lhs.month, lhs.day) == + std::make_tuple(rhs.year, rhs.month, rhs.day); + } + + TOML11_INLINE bool operator!=(const local_date& lhs, const local_date& rhs) { + return !(lhs == rhs); + } + + TOML11_INLINE bool operator<(const local_date& lhs, const local_date& rhs) { + return std::make_tuple(lhs.year, lhs.month, lhs.day) < + std::make_tuple(rhs.year, rhs.month, rhs.day); + } + + TOML11_INLINE bool operator<=(const local_date& lhs, const local_date& rhs) { + return (lhs < rhs) || (lhs == rhs); + } + + TOML11_INLINE bool operator>(const local_date& lhs, const local_date& rhs) { + return !(lhs <= rhs); + } + + TOML11_INLINE bool operator>=(const local_date& lhs, const local_date& rhs) { + return !(lhs < rhs); + } + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, const local_date& date) { + os << std::setfill('0') << std::setw(4) << static_cast(date.year) << '-'; + os << std::setfill('0') << std::setw(2) << static_cast(date.month) + 1 + << '-'; + os << std::setfill('0') << std::setw(2) << static_cast(date.day); + return os; + } + + TOML11_INLINE std::string to_string(const local_date& date) { + std::ostringstream oss; + oss.imbue(std::locale::classic()); + oss << date; + return oss.str(); + } + + // ----------------------------------------------------------------------------- + + TOML11_INLINE 
local_time::operator std::chrono::nanoseconds() const { + return std::chrono::nanoseconds(this->nanosecond) + + std::chrono::microseconds(this->microsecond) + + std::chrono::milliseconds(this->millisecond) + + std::chrono::seconds(this->second) + + std::chrono::minutes(this->minute) + std::chrono::hours(this->hour); + } + + TOML11_INLINE bool operator==(const local_time& lhs, const local_time& rhs) { + return std::make_tuple(lhs.hour, + lhs.minute, + lhs.second, + lhs.millisecond, + lhs.microsecond, + lhs.nanosecond) == std::make_tuple(rhs.hour, + rhs.minute, + rhs.second, + rhs.millisecond, + rhs.microsecond, + rhs.nanosecond); + } + + TOML11_INLINE bool operator!=(const local_time& lhs, const local_time& rhs) { + return !(lhs == rhs); + } + + TOML11_INLINE bool operator<(const local_time& lhs, const local_time& rhs) { + return std::make_tuple(lhs.hour, + lhs.minute, + lhs.second, + lhs.millisecond, + lhs.microsecond, + lhs.nanosecond) < std::make_tuple(rhs.hour, + rhs.minute, + rhs.second, + rhs.millisecond, + rhs.microsecond, + rhs.nanosecond); + } + + TOML11_INLINE bool operator<=(const local_time& lhs, const local_time& rhs) { + return (lhs < rhs) || (lhs == rhs); + } + + TOML11_INLINE bool operator>(const local_time& lhs, const local_time& rhs) { + return !(lhs <= rhs); + } + + TOML11_INLINE bool operator>=(const local_time& lhs, const local_time& rhs) { + return !(lhs < rhs); + } + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, const local_time& time) { + os << std::setfill('0') << std::setw(2) << static_cast(time.hour) << ':'; + os << std::setfill('0') << std::setw(2) << static_cast(time.minute) << ':'; + os << std::setfill('0') << std::setw(2) << static_cast(time.second); + if (time.millisecond != 0 || time.microsecond != 0 || time.nanosecond != 0) { + os << '.'; + os << std::setfill('0') << std::setw(3) + << static_cast(time.millisecond); + if (time.microsecond != 0 || time.nanosecond != 0) { + os << std::setfill('0') << std::setw(3) + << 
static_cast(time.microsecond); + if (time.nanosecond != 0) { + os << std::setfill('0') << std::setw(3) + << static_cast(time.nanosecond); + } + } + } + return os; + } + + TOML11_INLINE std::string to_string(const local_time& time) { + std::ostringstream oss; + oss.imbue(std::locale::classic()); + oss << time; + return oss.str(); + } + + // ---------------------------------------------------------------------------- + + TOML11_INLINE time_offset::operator std::chrono::minutes() const { + return std::chrono::minutes(this->minute) + std::chrono::hours(this->hour); + } + + TOML11_INLINE bool operator==(const time_offset& lhs, const time_offset& rhs) { + return std::make_tuple(lhs.hour, lhs.minute) == + std::make_tuple(rhs.hour, rhs.minute); + } + + TOML11_INLINE bool operator!=(const time_offset& lhs, const time_offset& rhs) { + return !(lhs == rhs); + } + + TOML11_INLINE bool operator<(const time_offset& lhs, const time_offset& rhs) { + return std::make_tuple(lhs.hour, lhs.minute) < + std::make_tuple(rhs.hour, rhs.minute); + } + + TOML11_INLINE bool operator<=(const time_offset& lhs, const time_offset& rhs) { + return (lhs < rhs) || (lhs == rhs); + } + + TOML11_INLINE bool operator>(const time_offset& lhs, const time_offset& rhs) { + return !(lhs <= rhs); + } + + TOML11_INLINE bool operator>=(const time_offset& lhs, const time_offset& rhs) { + return !(lhs < rhs); + } + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, + const time_offset& offset) { + if (offset.hour == 0 && offset.minute == 0) { + os << 'Z'; + return os; + } + int minute = static_cast(offset.hour) * 60 + offset.minute; + if (minute < 0) { + os << '-'; + minute = std::abs(minute); + } else { + os << '+'; + } + os << std::setfill('0') << std::setw(2) << minute / 60 << ':'; + os << std::setfill('0') << std::setw(2) << minute % 60; + return os; + } + + TOML11_INLINE std::string to_string(const time_offset& offset) { + std::ostringstream oss; + oss.imbue(std::locale::classic()); + oss << offset; 
+ return oss.str(); + } + + // ----------------------------------------------------------------------------- + + TOML11_INLINE local_datetime::local_datetime( + const std::chrono::system_clock::time_point& tp) { + const auto t = std::chrono::system_clock::to_time_t(tp); + std::tm ltime = detail::localtime_s(&t); + + this->date = local_date(ltime); + this->time = local_time(ltime); + + // std::tm lacks subsecond information, so diff between tp and tm + // can be used to get millisecond & microsecond information. + const auto t_diff = tp - std::chrono::system_clock::from_time_t( + std::mktime(<ime)); + this->time.millisecond = static_cast( + std::chrono::duration_cast(t_diff).count()); + this->time.microsecond = static_cast( + std::chrono::duration_cast(t_diff).count()); + this->time.nanosecond = static_cast( + std::chrono::duration_cast(t_diff).count()); + } + + TOML11_INLINE local_datetime::local_datetime(const std::time_t t) + : local_datetime { std::chrono::system_clock::from_time_t(t) } {} + + TOML11_INLINE local_datetime::operator std::chrono::system_clock::time_point() const { + using internal_duration = typename std::chrono::system_clock::time_point::duration; + + // Normally DST begins at A.M. 3 or 4. If we re-use conversion operator + // of local_date and local_time independently, the conversion fails if + // it is the day when DST begins or ends. Since local_date considers the + // time is 00:00 A.M. and local_time does not consider DST because it + // does not have any date information. We need to consider both date and + // time information at the same time to convert it correctly. 
+ + std::tm t; + t.tm_sec = static_cast(this->time.second); + t.tm_min = static_cast(this->time.minute); + t.tm_hour = static_cast(this->time.hour); + t.tm_mday = static_cast(this->date.day); + t.tm_mon = static_cast(this->date.month); + t.tm_year = static_cast(this->date.year) - 1900; + t.tm_wday = 0; // the value will be ignored + t.tm_yday = 0; // the value will be ignored + t.tm_isdst = -1; + + // std::mktime returns date as local time zone. no conversion needed + auto dt = std::chrono::system_clock::from_time_t(std::mktime(&t)); + dt += std::chrono::duration_cast( + std::chrono::milliseconds(this->time.millisecond) + + std::chrono::microseconds(this->time.microsecond) + + std::chrono::nanoseconds(this->time.nanosecond)); + return dt; + } + + TOML11_INLINE local_datetime::operator std::time_t() const { + return std::chrono::system_clock::to_time_t( + std::chrono::system_clock::time_point(*this)); + } + + TOML11_INLINE bool operator==(const local_datetime& lhs, + const local_datetime& rhs) { + return std::make_tuple(lhs.date, lhs.time) == + std::make_tuple(rhs.date, rhs.time); + } + + TOML11_INLINE bool operator!=(const local_datetime& lhs, + const local_datetime& rhs) { + return !(lhs == rhs); + } + + TOML11_INLINE bool operator<(const local_datetime& lhs, + const local_datetime& rhs) { + return std::make_tuple(lhs.date, lhs.time) < + std::make_tuple(rhs.date, rhs.time); + } + + TOML11_INLINE bool operator<=(const local_datetime& lhs, + const local_datetime& rhs) { + return (lhs < rhs) || (lhs == rhs); + } + + TOML11_INLINE bool operator>(const local_datetime& lhs, + const local_datetime& rhs) { + return !(lhs <= rhs); + } + + TOML11_INLINE bool operator>=(const local_datetime& lhs, + const local_datetime& rhs) { + return !(lhs < rhs); + } + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, + const local_datetime& dt) { + os << dt.date << 'T' << dt.time; + return os; + } + + TOML11_INLINE std::string to_string(const local_datetime& dt) { + 
std::ostringstream oss; + oss.imbue(std::locale::classic()); + oss << dt; + return oss.str(); + } + + // ----------------------------------------------------------------------------- + + TOML11_INLINE offset_datetime::offset_datetime(const local_datetime& ld) + : date { ld.date } + , time { ld.time } + , offset { get_local_offset(nullptr) } + // use the current local timezone offset + {} + + TOML11_INLINE offset_datetime::offset_datetime( + const std::chrono::system_clock::time_point& tp) + : offset { 0, 0 } // use gmtime + { + const auto timet = std::chrono::system_clock::to_time_t(tp); + const auto tm = detail::gmtime_s(&timet); + this->date = local_date(tm); + this->time = local_time(tm); + } + + TOML11_INLINE offset_datetime::offset_datetime(const std::time_t& t) + : offset { 0, 0 } // use gmtime + { + const auto tm = detail::gmtime_s(&t); + this->date = local_date(tm); + this->time = local_time(tm); + } + + TOML11_INLINE offset_datetime::offset_datetime(const std::tm& t) + : offset { 0, 0 } // assume gmtime + { + this->date = local_date(t); + this->time = local_time(t); + } + + TOML11_INLINE offset_datetime::operator std::chrono::system_clock::time_point() const { + // get date-time + using internal_duration = typename std::chrono::system_clock::time_point::duration; + + // first, convert it to local date-time information in the same way as + // local_datetime does. later we will use time_t to adjust time offset. 
+ std::tm t; + t.tm_sec = static_cast(this->time.second); + t.tm_min = static_cast(this->time.minute); + t.tm_hour = static_cast(this->time.hour); + t.tm_mday = static_cast(this->date.day); + t.tm_mon = static_cast(this->date.month); + t.tm_year = static_cast(this->date.year) - 1900; + t.tm_wday = 0; // the value will be ignored + t.tm_yday = 0; // the value will be ignored + t.tm_isdst = -1; + const std::time_t tp_loc = std::mktime(std::addressof(t)); + + auto tp = std::chrono::system_clock::from_time_t(tp_loc); + tp += std::chrono::duration_cast( + std::chrono::milliseconds(this->time.millisecond) + + std::chrono::microseconds(this->time.microsecond) + + std::chrono::nanoseconds(this->time.nanosecond)); + + // Since mktime uses local time zone, it should be corrected. + // `12:00:00+09:00` means `03:00:00Z`. So mktime returns `03:00:00Z` if + // we are in `+09:00` timezone. To represent `12:00:00Z` there, we need + // to add `+09:00` to `03:00:00Z`. + // Here, it uses the time_t converted from date-time info to handle + // daylight saving time. + const auto ofs = get_local_offset(std::addressof(tp_loc)); + tp += std::chrono::hours(ofs.hour); + tp += std::chrono::minutes(ofs.minute); + + // We got `12:00:00Z` by correcting local timezone applied by mktime. + // Then we will apply the offset. Let's say `12:00:00-08:00` is given. + // And now, we have `12:00:00Z`. `12:00:00-08:00` means `20:00:00Z`. + // So we need to subtract the offset. 
+ tp -= std::chrono::minutes(this->offset); + return tp; + } + + TOML11_INLINE offset_datetime::operator std::time_t() const { + return std::chrono::system_clock::to_time_t( + std::chrono::system_clock::time_point(*this)); + } + + TOML11_INLINE time_offset offset_datetime::get_local_offset(const std::time_t* tp) { + // get local timezone with the same date-time information as mktime + const auto t = detail::localtime_s(tp); + + std::array buf; + const auto result = std::strftime(buf.data(), 6, "%z", &t); // +hhmm\0 + if (result != 5) { + throw std::runtime_error("toml::offset_datetime: cannot obtain " + "timezone information of current env"); + } + const int ofs = std::atoi(buf.data()); + const int ofs_h = ofs / 100; + const int ofs_m = ofs - (ofs_h * 100); + return time_offset(ofs_h, ofs_m); + } + + TOML11_INLINE bool operator==(const offset_datetime& lhs, + const offset_datetime& rhs) { + return std::make_tuple(lhs.date, lhs.time, lhs.offset) == + std::make_tuple(rhs.date, rhs.time, rhs.offset); + } + + TOML11_INLINE bool operator!=(const offset_datetime& lhs, + const offset_datetime& rhs) { + return !(lhs == rhs); + } + + TOML11_INLINE bool operator<(const offset_datetime& lhs, + const offset_datetime& rhs) { + return std::make_tuple(lhs.date, lhs.time, lhs.offset) < + std::make_tuple(rhs.date, rhs.time, rhs.offset); + } + + TOML11_INLINE bool operator<=(const offset_datetime& lhs, + const offset_datetime& rhs) { + return (lhs < rhs) || (lhs == rhs); + } + + TOML11_INLINE bool operator>(const offset_datetime& lhs, + const offset_datetime& rhs) { + return !(lhs <= rhs); + } + + TOML11_INLINE bool operator>=(const offset_datetime& lhs, + const offset_datetime& rhs) { + return !(lhs < rhs); + } + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, + const offset_datetime& dt) { + os << dt.date << 'T' << dt.time << dt.offset; + return os; + } + + TOML11_INLINE std::string to_string(const offset_datetime& dt) { + std::ostringstream oss; + 
oss.imbue(std::locale::classic()); + oss << dt; + return oss.str(); + } + +} // namespace toml + #endif // TOML11_DATETIME_IMPL_HPP +#endif + +#endif // TOML11_DATETIME_HPP +#ifndef TOML11_COMPAT_HPP +#define TOML11_COMPAT_HPP + +#include +#include +#include +#include +#include +#include + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX20_VALUE + #if __has_include() + #include + #endif +#endif + +#include + +// ---------------------------------------------------------------------------- + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX14_VALUE + #if __has_cpp_attribute(deprecated) + #define TOML11_HAS_ATTR_DEPRECATED 1 + #endif +#endif + +#if defined(TOML11_HAS_ATTR_DEPRECATED) + #define TOML11_DEPRECATED(msg) [[deprecated(msg)]] +#elif defined(__GNUC__) + #define TOML11_DEPRECATED(msg) __attribute__((deprecated(msg))) +#elif defined(_MSC_VER) + #define TOML11_DEPRECATED(msg) __declspec(deprecated(msg)) +#else + #define TOML11_DEPRECATED(msg) +#endif + +// ---------------------------------------------------------------------------- + +#if defined(__cpp_if_constexpr) + #if __cpp_if_constexpr >= 201606L + #define TOML11_HAS_CONSTEXPR_IF 1 + #endif +#endif + +#if defined(TOML11_HAS_CONSTEXPR_IF) + #define TOML11_CONSTEXPR_IF if constexpr +#else + #define TOML11_CONSTEXPR_IF if +#endif + +// ---------------------------------------------------------------------------- + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX14_VALUE + #if defined(__cpp_lib_make_unique) + #if __cpp_lib_make_unique >= 201304L + #define TOML11_HAS_STD_MAKE_UNIQUE 1 + #endif + #endif +#endif + +namespace toml { + namespace cxx { + +#if defined(TOML11_HAS_STD_MAKE_UNIQUE) + + using std::make_unique; + +#else + + template + std::unique_ptr make_unique(Ts&&... 
args) { + return std::unique_ptr(new T(std::forward(args)...)); + } + +#endif // TOML11_HAS_STD_MAKE_UNIQUE + + } // namespace cxx +} // namespace toml + +// --------------------------------------------------------------------------- + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX14_VALUE + #if defined(__cpp_lib_make_reverse_iterator) + #if __cpp_lib_make_reverse_iterator >= 201402L + #define TOML11_HAS_STD_MAKE_REVERSE_ITERATOR 1 + #endif + #endif +#endif + +namespace toml { + namespace cxx { +#if defined(TOML11_HAS_STD_MAKE_REVERSE_ITERATOR) + + using std::make_reverse_iterator; + +#else + + template + std::reverse_iterator make_reverse_iterator(Iterator iter) { + return std::reverse_iterator(iter); + } + +#endif // TOML11_HAS_STD_MAKE_REVERSE_ITERATOR + + } // namespace cxx +} // namespace toml + +// --------------------------------------------------------------------------- + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX20_VALUE + #if defined(__cpp_lib_clamp) + #if __cpp_lib_clamp >= 201603L + #define TOML11_HAS_STD_CLAMP 1 + #endif + #endif +#endif + +namespace toml { + namespace cxx { +#if defined(TOML11_HAS_STD_CLAMP) + + using std::clamp; + +#else + + template + T clamp(const T& x, const T& low, const T& high) noexcept { + assert(low <= high); + return (std::min)((std::max)(x, low), high); + } + +#endif // TOML11_HAS_STD_CLAMP + + } // namespace cxx +} // namespace toml + +// --------------------------------------------------------------------------- + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX20_VALUE + #if defined(__cpp_lib_bit_cast) + #if __cpp_lib_bit_cast >= 201806L + #define TOML11_HAS_STD_BIT_CAST 1 + #endif + #endif +#endif + +namespace toml { + namespace cxx { +#if defined(TOML11_HAS_STD_BIT_CAST) + + using std::bit_cast; + +#else + + template + U bit_cast(const T& x) noexcept { + static_assert(sizeof(T) == sizeof(U), ""); + static_assert(std::is_default_constructible::value, ""); + + U z; + 
std::memcpy(reinterpret_cast(std::addressof(z)), + reinterpret_cast(std::addressof(x)), + sizeof(T)); + + return z; + } + +#endif // TOML11_HAS_STD_BIT_CAST + + } // namespace cxx +} // namespace toml + +// --------------------------------------------------------------------------- +// C++20 remove_cvref_t + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX20_VALUE + #if defined(__cpp_lib_remove_cvref) + #if __cpp_lib_remove_cvref >= 201711L + #define TOML11_HAS_STD_REMOVE_CVREF 1 + #endif + #endif +#endif + +namespace toml { + namespace cxx { +#if defined(TOML11_HAS_STD_REMOVE_CVREF) + + using std::remove_cvref; + using std::remove_cvref_t; + +#else + + template + struct remove_cvref { + using type = typename std::remove_cv::type>::type; + }; + + template + using remove_cvref_t = typename remove_cvref::type; + +#endif // TOML11_HAS_STD_REMOVE_CVREF + + } // namespace cxx +} // namespace toml + +// --------------------------------------------------------------------------- +// C++17 and/or/not + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX17_VALUE + #if defined(__cpp_lib_logical_traits) + #if __cpp_lib_logical_traits >= 201510L + #define TOML11_HAS_STD_CONJUNCTION 1 + #endif + #endif +#endif + +namespace toml { + namespace cxx { +#if defined(TOML11_HAS_STD_CONJUNCTION) + + using std::conjunction; + using std::disjunction; + using std::negation; + +#else + + template + struct conjunction : std::true_type {}; + + template + struct conjunction : T {}; + + template + struct conjunction + : std::conditional(T::value), conjunction, T>::type { + }; + + template + struct disjunction : std::false_type {}; + + template + struct disjunction : T {}; + + template + struct disjunction + : std::conditional(T::value), T, disjunction>::type { + }; + + template + struct negation + : std::integral_constant(T::value)> {}; + +#endif // TOML11_HAS_STD_CONJUNCTION + + } // namespace cxx +} // namespace toml + +// 
--------------------------------------------------------------------------- +// C++14 index_sequence + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX14_VALUE + #if defined(__cpp_lib_integer_sequence) + #if __cpp_lib_integer_sequence >= 201304L + #define TOML11_HAS_STD_INTEGER_SEQUENCE 1 + #endif + #endif +#endif + +namespace toml { + namespace cxx { +#if defined(TOML11_HAS_STD_INTEGER_SEQUENCE) + + using std::index_sequence; + using std::make_index_sequence; + +#else + + template + struct index_sequence {}; + + template + struct double_index_sequence; + + template + struct double_index_sequence> { + using type = index_sequence; + }; + + template + struct double_index_sequence> { + using type = index_sequence; + }; + + template + struct index_sequence_maker { + using type = + typename double_index_sequence::type>::type; + }; + + template <> + struct index_sequence_maker<0> { + using type = index_sequence<>; + }; + + template + using make_index_sequence = typename index_sequence_maker::type; + +#endif // TOML11_HAS_STD_INTEGER_SEQUENCE + + } // namespace cxx +} // namespace toml + +// --------------------------------------------------------------------------- +// C++14 enable_if_t + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX14_VALUE + #if defined(__cpp_lib_transformation_trait_aliases) + #if __cpp_lib_transformation_trait_aliases >= 201304L + #define TOML11_HAS_STD_ENABLE_IF_T 1 + #endif + #endif +#endif + +namespace toml { + namespace cxx { +#if defined(TOML11_HAS_STD_ENABLE_IF_T) + + using std::enable_if_t; + +#else + + template + using enable_if_t = typename std::enable_if::type; + +#endif // TOML11_HAS_STD_ENABLE_IF_T + + } // namespace cxx +} // namespace toml + +// --------------------------------------------------------------------------- +// return_type_of_t + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX17_VALUE + #if defined(__cpp_lib_is_invocable) + #if __cpp_lib_is_invocable >= 201703 + #define TOML11_HAS_STD_INVOKE_RESULT 1 + 
#endif + #endif +#endif + +namespace toml { + namespace cxx { +#if defined(TOML11_HAS_STD_INVOKE_RESULT) + + template + using return_type_of_t = std::invoke_result_t; + +#else + + // result_of is deprecated after C++17 + template + using return_type_of_t = typename std::result_of::type; + +#endif // TOML11_HAS_STD_INVOKE_RESULT + + } // namespace cxx +} // namespace toml + +// ---------------------------------------------------------------------------- +// (subset of) source_location + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= 202002L + #if __has_include() + #define TOML11_HAS_STD_SOURCE_LOCATION + #endif // has_include +#endif // c++20 + +#if !defined(TOML11_HAS_STD_SOURCE_LOCATION) + #if defined(__GNUC__) && !defined(__clang__) + #if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX14_VALUE + #if __has_include() + #define TOML11_HAS_EXPERIMENTAL_SOURCE_LOCATION + #endif + #endif + #endif // GNU g++ +#endif // not TOML11_HAS_STD_SOURCE_LOCATION + +#if !defined(TOML11_HAS_STD_SOURCE_LOCATION) && \ + !defined(TOML11_HAS_EXPERIMENTAL_SOURCE_LOCATION) + #if defined(__GNUC__) && !defined(__clang__) + #if (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 9)) + #define TOML11_HAS_BUILTIN_FILE_LINE 1 + #define TOML11_BUILTIN_LINE_TYPE int + #endif + #elif defined(__clang__) // clang 9.0.0 implements builtin_FILE/LINE + #if __has_builtin(__builtin_FILE) && __has_builtin(__builtin_LINE) + #define TOML11_HAS_BUILTIN_FILE_LINE 1 + #define TOML11_BUILTIN_LINE_TYPE unsigned int + #endif + #elif defined(_MSVC_LANG) && defined(_MSC_VER) + #if _MSC_VER > 1926 + #define TOML11_HAS_BUILTIN_FILE_LINE 1 + #define TOML11_BUILTIN_LINE_TYPE int + #endif + #endif +#endif + +#if defined(TOML11_HAS_STD_SOURCE_LOCATION) + #include + +namespace toml { + namespace cxx { + using source_location = std::source_location; + + inline std::string to_string(const source_location& loc) { + return std::string(" at line ") + std::to_string(loc.line()) + + std::string(" in file ") + 
std::string(loc.file_name()); + } + } // namespace cxx +} // namespace toml +#elif defined(TOML11_HAS_EXPERIMENTAL_SOURCE_LOCATION) + #include + +namespace toml { + namespace cxx { + using source_location = std::experimental::source_location; + + inline std::string to_string(const source_location& loc) { + return std::string(" at line ") + std::to_string(loc.line()) + + std::string(" in file ") + std::string(loc.file_name()); + } + } // namespace cxx +} // namespace toml +#elif defined(TOML11_HAS_BUILTIN_FILE_LINE) +namespace toml { + namespace cxx { + struct source_location { + using line_type = TOML11_BUILTIN_LINE_TYPE; + + static source_location current(const line_type line = __builtin_LINE(), + const char* file = __builtin_FILE()) { + return source_location(line, file); + } + + source_location(const line_type line, const char* file) + : line_(line) + , file_name_(file) {} + + line_type line() const noexcept { + return line_; + } + + const char* file_name() const noexcept { + return file_name_; + } + + private: + line_type line_; + const char* file_name_; + }; + + inline std::string to_string(const source_location& loc) { + return std::string(" at line ") + std::to_string(loc.line()) + + std::string(" in file ") + std::string(loc.file_name()); + } + } // namespace cxx +} // namespace toml +#else // no builtin +namespace toml { + namespace cxx { + struct source_location { + static source_location current() { + return source_location {}; + } + }; + + inline std::string to_string(const source_location&) { + return std::string(""); + } + } // namespace cxx +} // namespace toml +#endif // TOML11_HAS_STD_SOURCE_LOCATION + +// ---------------------------------------------------------------------------- +// (subset of) optional + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX17_VALUE + #if __has_include() + #include + #endif // has_include(optional) +#endif // C++17 + +#if TOML11_CPLUSPLUS_STANDARD_VERSION >= TOML11_CXX17_VALUE + #if defined(__cpp_lib_optional) + 
#if __cpp_lib_optional >= 201606L + #define TOML11_HAS_STD_OPTIONAL 1 + #endif + #endif +#endif + +#if defined(TOML11_HAS_STD_OPTIONAL) + +namespace toml { + namespace cxx { + using std::optional; + + inline std::nullopt_t make_nullopt() { + return std::nullopt; + } + + template + std::basic_ostream& operator<<( + std::basic_ostream& os, + const std::nullopt_t&) { + os << "nullopt"; + return os; + } + + } // namespace cxx +} // namespace toml + +#else // TOML11_HAS_STD_OPTIONAL + +namespace toml { + namespace cxx { + + struct nullopt_t {}; + + inline nullopt_t make_nullopt() { + return nullopt_t {}; + } + + inline bool operator==(const nullopt_t&, const nullopt_t&) noexcept { + return true; + } + + inline bool operator!=(const nullopt_t&, const nullopt_t&) noexcept { + return false; + } + + inline bool operator<(const nullopt_t&, const nullopt_t&) noexcept { + return false; + } + + inline bool operator<=(const nullopt_t&, const nullopt_t&) noexcept { + return true; + } + + inline bool operator>(const nullopt_t&, const nullopt_t&) noexcept { + return false; + } + + inline bool operator>=(const nullopt_t&, const nullopt_t&) noexcept { + return true; + } + + template + std::basic_ostream& operator<<( + std::basic_ostream& os, + const nullopt_t&) { + os << "nullopt"; + return os; + } + + template + class optional { + public: + using value_type = T; + + public: + optional() noexcept : has_value_(false), null_('\0') {} + + optional(nullopt_t) noexcept : has_value_(false), null_('\0') {} + + optional(const T& x) : has_value_(true), value_(x) {} + + optional(T&& x) : has_value_(true), value_(std::move(x)) {} + + template ::value, std::nullptr_t> = nullptr> + explicit optional(U&& x) : has_value_(true) + , value_(std::forward(x)) {} + + optional(const optional& rhs) : has_value_(rhs.has_value_) { + if (rhs.has_value_) { + this->assigner(rhs.value_); + } + } + + optional(optional&& rhs) : has_value_(rhs.has_value_) { + if (this->has_value_) { + 
this->assigner(std::move(rhs.value_)); + } + } + + optional& operator=(const optional& rhs) { + if (this == std::addressof(rhs)) { + return *this; + } + + this->cleanup(); + this->has_value_ = rhs.has_value_; + if (this->has_value_) { + this->assigner(rhs.value_); + } + return *this; + } + + optional& operator=(optional&& rhs) { + if (this == std::addressof(rhs)) { + return *this; + } + + this->cleanup(); + this->has_value_ = rhs.has_value_; + if (this->has_value_) { + this->assigner(std::move(rhs.value_)); + } + return *this; + } + + template >, + std::is_constructible>::value, + std::nullptr_t> = nullptr> + explicit optional(const optional& rhs) + : has_value_(rhs.has_value_) + , null_('\0') { + if (rhs.has_value_) { + this->assigner(rhs.value_); + } + } + + template >, + std::is_constructible>::value, + std::nullptr_t> = nullptr> + explicit optional(optional&& rhs) + : has_value_(rhs.has_value_) + , null_('\0') { + if (this->has_value_) { + this->assigner(std::move(rhs.value_)); + } + } + + template >, + std::is_constructible>::value, + std::nullptr_t> = nullptr> + optional& operator=(const optional& rhs) { + if (this == std::addressof(rhs)) { + return *this; + } + + this->cleanup(); + this->has_value_ = rhs.has_value_; + if (this->has_value_) { + this->assigner(rhs.value_); + } + return *this; + } + + template >, + std::is_constructible>::value, + std::nullptr_t> = nullptr> + optional& operator=(optional&& rhs) { + if (this == std::addressof(rhs)) { + return *this; + } + + this->cleanup(); + this->has_value_ = rhs.has_value_; + if (this->has_value_) { + this->assigner(std::move(rhs.value_)); + } + return *this; + } + + ~optional() noexcept { + this->cleanup(); + } + + explicit operator bool() const noexcept { + return has_value_; + } + + bool has_value() const noexcept { + return has_value_; + } + + const value_type& value(source_location loc = source_location::current()) const { + if (!this->has_value_) { + throw std::runtime_error( + "optional::value(): 
bad_unwrap" + to_string(loc)); + } + return this->value_; + } + + value_type& value(source_location loc = source_location::current()) { + if (!this->has_value_) { + throw std::runtime_error( + "optional::value(): bad_unwrap" + to_string(loc)); + } + return this->value_; + } + + const value_type& value_or(const value_type& opt) const { + if (this->has_value_) { + return this->value_; + } else { + return opt; + } + } + + value_type& value_or(value_type& opt) { + if (this->has_value_) { + return this->value_; + } else { + return opt; + } + } + + private: + void cleanup() noexcept { + if (this->has_value_) { + value_.~T(); + } + } + + template + void assigner(U&& x) { + const auto tmp = ::new (std::addressof(this->value_)) + value_type(std::forward(x)); + assert(tmp == std::addressof(this->value_)); + (void)tmp; + } + + private: + bool has_value_; + + union { + char null_; + T value_; + }; + }; + } // namespace cxx +} // namespace toml +#endif // TOML11_HAS_STD_OPTIONAL + +#endif // TOML11_COMPAT_HPP +#ifndef TOML11_VALUE_T_HPP +#define TOML11_VALUE_T_HPP + +#ifndef TOML11_VALUE_T_FWD_HPP + #define TOML11_VALUE_T_FWD_HPP + + #include + #include + #include + #include + +namespace toml { + + // forward decl + template + class basic_value; + + // ---------------------------------------------------------------------------- + // enum representing toml types + + enum class value_t : std::uint8_t { + empty = 0, + boolean = 1, + integer = 2, + floating = 3, + string = 4, + offset_datetime = 5, + local_datetime = 6, + local_date = 7, + local_time = 8, + array = 9, + table = 10 + }; + + std::ostream& operator<<(std::ostream& os, value_t t); + std::string to_string(value_t t); + + // ---------------------------------------------------------------------------- + // meta functions for internal use + + namespace detail { + + template + using value_t_constant = std::integral_constant; + + template + struct type_to_enum : value_t_constant {}; + + template + struct type_to_enum + : 
value_t_constant {}; + + template + struct type_to_enum + : value_t_constant {}; + + template + struct type_to_enum + : value_t_constant {}; + + template + struct type_to_enum + : value_t_constant {}; + + template + struct type_to_enum + : value_t_constant {}; + + template + struct type_to_enum + : value_t_constant {}; + + template + struct type_to_enum + : value_t_constant {}; + + template + struct type_to_enum + : value_t_constant {}; + + template + struct type_to_enum + : value_t_constant {}; + + template + struct type_to_enum + : value_t_constant {}; + + template + struct enum_to_type { + using type = void; + }; + + template + struct enum_to_type { + using type = typename V::boolean_type; + }; + + template + struct enum_to_type { + using type = typename V::integer_type; + }; + + template + struct enum_to_type { + using type = typename V::floating_type; + }; + + template + struct enum_to_type { + using type = typename V::string_type; + }; + + template + struct enum_to_type { + using type = typename V::offset_datetime_type; + }; + + template + struct enum_to_type { + using type = typename V::local_datetime_type; + }; + + template + struct enum_to_type { + using type = typename V::local_date_type; + }; + + template + struct enum_to_type { + using type = typename V::local_time_type; + }; + + template + struct enum_to_type { + using type = typename V::array_type; + }; + + template + struct enum_to_type { + using type = typename V::table_type; + }; + + template + using enum_to_type_t = typename enum_to_type::type; + + template + struct enum_to_fmt_type { + using type = void; + }; + + template <> + struct enum_to_fmt_type { + using type = boolean_format_info; + }; + + template <> + struct enum_to_fmt_type { + using type = integer_format_info; + }; + + template <> + struct enum_to_fmt_type { + using type = floating_format_info; + }; + + template <> + struct enum_to_fmt_type { + using type = string_format_info; + }; + + template <> + struct enum_to_fmt_type { + using 
type = offset_datetime_format_info; + }; + + template <> + struct enum_to_fmt_type { + using type = local_datetime_format_info; + }; + + template <> + struct enum_to_fmt_type { + using type = local_date_format_info; + }; + + template <> + struct enum_to_fmt_type { + using type = local_time_format_info; + }; + + template <> + struct enum_to_fmt_type { + using type = array_format_info; + }; + + template <> + struct enum_to_fmt_type { + using type = table_format_info; + }; + + template + using enum_to_fmt_type_t = typename enum_to_fmt_type::type; + + template + struct is_exact_toml_type0 + : cxx::disjunction, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same> {}; + + template + struct is_exact_toml_type : is_exact_toml_type0, V> {}; + + template + struct is_not_toml_type : cxx::negation> {}; + + } // namespace detail +} // namespace toml +#endif // TOML11_VALUE_T_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_VALUE_T_IMPL_HPP + #define TOML11_VALUE_T_IMPL_HPP + + #include + #include + #include + +namespace toml { + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, value_t t) { + switch (t) { + case value_t::boolean: + os << "boolean"; + return os; + case value_t::integer: + os << "integer"; + return os; + case value_t::floating: + os << "floating"; + return os; + case value_t::string: + os << "string"; + return os; + case value_t::offset_datetime: + os << "offset_datetime"; + return os; + case value_t::local_datetime: + os << "local_datetime"; + return os; + case value_t::local_date: + os << "local_date"; + return os; + case value_t::local_time: + os << "local_time"; + return os; + case value_t::array: + os << "array"; + return os; + case value_t::table: + os << "table"; + return os; + case value_t::empty: + os << "empty"; + return os; + default: + os << "unknown"; + return os; + } + } + + TOML11_INLINE std::string to_string(value_t t) { + 
std::ostringstream oss; + oss << t; + return oss.str(); + } + +} // namespace toml + #endif // TOML11_VALUE_T_IMPL_HPP +#endif + +#endif // TOML11_VALUE_T_HPP +#ifndef TOML11_STORAGE_HPP +#define TOML11_STORAGE_HPP + +namespace toml { + namespace detail { + + // It owns a pointer to T. It does deep-copy when copied. + // This struct is introduced to implement a recursive type. + // + // `toml::value` contains `std::vector` to represent a toml + // array. But, in the definition of `toml::value`, `toml::value` is still + // incomplete. `std::vector` of an incomplete type is not allowed in C++11 + // (it is allowed after C++17). To avoid this, we need to use a pointer to + // `toml::value`, like `std::vector>`. Although + // `std::unique_ptr` is noncopyable, we want to make `toml::value` copyable. + // `storage` is introduced to resolve those problems. + template + struct storage { + using value_type = T; + + explicit storage(value_type v) + : ptr_(cxx::make_unique(std::move(v))) {} + + ~storage() = default; + + storage(const storage& rhs) : ptr_(cxx::make_unique(*rhs.ptr_)) {} + + storage& operator=(const storage& rhs) { + this->ptr_ = cxx::make_unique(*rhs.ptr_); + return *this; + } + + storage(storage&&) = default; + storage& operator=(storage&&) = default; + + bool is_ok() const noexcept { + return static_cast(ptr_); + } + + value_type& get() const noexcept { + return *ptr_; + } + + private: + std::unique_ptr ptr_; + }; + + } // namespace detail +} // namespace toml +#endif // TOML11_STORAGE_HPP +#ifndef TOML11_COMMENTS_HPP +#define TOML11_COMMENTS_HPP + +#ifndef TOML11_COMMENTS_FWD_HPP + #define TOML11_COMMENTS_FWD_HPP + + // to use __has_builtin + + #include + #include + #include + #include + #include + #include + #include + #include + #include + +// This file provides mainly two classes, `preserve_comments` and `discard_comments`. +// Those two are a container that have the same interface as `std::vector` +// but bahaves in the opposite way. 
`preserve_comments` is just the same as +// `std::vector` and each `std::string` corresponds to a comment line. +// Conversely, `discard_comments` discards all the strings and ignores everything +// assigned in it. `discard_comments` is always empty and you will encounter an +// error whenever you access to the element. +namespace toml { + class discard_comments; // forward decl + + class preserve_comments { + public: + // `container_type` is not provided in discard_comments. + // do not use this inner-type in a generic code. + using container_type = std::vector; + + using size_type = container_type::size_type; + using difference_type = container_type::difference_type; + using value_type = container_type::value_type; + using reference = container_type::reference; + using const_reference = container_type::const_reference; + using pointer = container_type::pointer; + using const_pointer = container_type::const_pointer; + using iterator = container_type::iterator; + using const_iterator = container_type::const_iterator; + using reverse_iterator = container_type::reverse_iterator; + using const_reverse_iterator = container_type::const_reverse_iterator; + + public: + preserve_comments() = default; + ~preserve_comments() = default; + preserve_comments(const preserve_comments&) = default; + preserve_comments(preserve_comments&&) = default; + preserve_comments& operator=(const preserve_comments&) = default; + preserve_comments& operator=(preserve_comments&&) = default; + + explicit preserve_comments(const std::vector& c) + : comments(c) {} + + explicit preserve_comments(std::vector&& c) + : comments(std::move(c)) {} + + preserve_comments& operator=(const std::vector& c) { + comments = c; + return *this; + } + + preserve_comments& operator=(std::vector&& c) { + comments = std::move(c); + return *this; + } + + explicit preserve_comments(const discard_comments&) {} + + explicit preserve_comments(size_type n) : comments(n) {} + + preserve_comments(size_type n, const 
std::string& x) : comments(n, x) {} + + preserve_comments(std::initializer_list x) : comments(x) {} + + template + preserve_comments(InputIterator first, InputIterator last) + : comments(first, last) {} + + template + void assign(InputIterator first, InputIterator last) { + comments.assign(first, last); + } + + void assign(std::initializer_list ini) { + comments.assign(ini); + } + + void assign(size_type n, const std::string& val) { + comments.assign(n, val); + } + + // Related to the issue #97. + // + // `std::vector::insert` and `std::vector::erase` in the STL implementation + // included in GCC 4.8.5 takes `std::vector::iterator` instead of + // `std::vector::const_iterator`. It causes compilation error in GCC 4.8.5. + #if defined(__GNUC__) && defined(__GNUC_MINOR__) && \ + defined(__GNUC_PATCHLEVEL__) && !defined(__clang__) + #if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) <= 40805 + #define TOML11_WORKAROUND_GCC_4_8_X_STANDARD_LIBRARY_IMPLEMENTATION + #endif + #endif + + #ifdef TOML11_WORKAROUND_GCC_4_8_X_STANDARD_LIBRARY_IMPLEMENTATION + iterator insert(iterator p, const std::string& x) { + return comments.insert(p, x); + } + + iterator insert(iterator p, std::string&& x) { + return comments.insert(p, std::move(x)); + } + + void insert(iterator p, size_type n, const std::string& x) { + return comments.insert(p, n, x); + } + + template + void insert(iterator p, InputIterator first, InputIterator last) { + return comments.insert(p, first, last); + } + + void insert(iterator p, std::initializer_list ini) { + return comments.insert(p, ini); + } + + template + iterator emplace(iterator p, Ts&&... 
args) { + return comments.emplace(p, std::forward(args)...); + } + + iterator erase(iterator pos) { + return comments.erase(pos); + } + + iterator erase(iterator first, iterator last) { + return comments.erase(first, last); + } + #else + iterator insert(const_iterator p, const std::string& x) { + return comments.insert(p, x); + } + + iterator insert(const_iterator p, std::string&& x) { + return comments.insert(p, std::move(x)); + } + + iterator insert(const_iterator p, size_type n, const std::string& x) { + return comments.insert(p, n, x); + } + + template + iterator insert(const_iterator p, InputIterator first, InputIterator last) { + return comments.insert(p, first, last); + } + + iterator insert(const_iterator p, std::initializer_list ini) { + return comments.insert(p, ini); + } + + template + iterator emplace(const_iterator p, Ts&&... args) { + return comments.emplace(p, std::forward(args)...); + } + + iterator erase(const_iterator pos) { + return comments.erase(pos); + } + + iterator erase(const_iterator first, const_iterator last) { + return comments.erase(first, last); + } + #endif + + void swap(preserve_comments& other) { + comments.swap(other.comments); + } + + void push_back(const std::string& v) { + comments.push_back(v); + } + + void push_back(std::string&& v) { + comments.push_back(std::move(v)); + } + + void pop_back() { + comments.pop_back(); + } + + template + void emplace_back(Ts&&... 
args) { + comments.emplace_back(std::forward(args)...); + } + + void clear() { + comments.clear(); + } + + size_type size() const noexcept { + return comments.size(); + } + + size_type max_size() const noexcept { + return comments.max_size(); + } + + size_type capacity() const noexcept { + return comments.capacity(); + } + + bool empty() const noexcept { + return comments.empty(); + } + + void reserve(size_type n) { + comments.reserve(n); + } + + void resize(size_type n) { + comments.resize(n); + } + + void resize(size_type n, const std::string& c) { + comments.resize(n, c); + } + + void shrink_to_fit() { + comments.shrink_to_fit(); + } + + reference operator[](const size_type n) noexcept { + return comments[n]; + } + + const_reference operator[](const size_type n) const noexcept { + return comments[n]; + } + + reference at(const size_type n) { + return comments.at(n); + } + + const_reference at(const size_type n) const { + return comments.at(n); + } + + reference front() noexcept { + return comments.front(); + } + + const_reference front() const noexcept { + return comments.front(); + } + + reference back() noexcept { + return comments.back(); + } + + const_reference back() const noexcept { + return comments.back(); + } + + pointer data() noexcept { + return comments.data(); + } + + const_pointer data() const noexcept { + return comments.data(); + } + + iterator begin() noexcept { + return comments.begin(); + } + + iterator end() noexcept { + return comments.end(); + } + + const_iterator begin() const noexcept { + return comments.begin(); + } + + const_iterator end() const noexcept { + return comments.end(); + } + + const_iterator cbegin() const noexcept { + return comments.cbegin(); + } + + const_iterator cend() const noexcept { + return comments.cend(); + } + + reverse_iterator rbegin() noexcept { + return comments.rbegin(); + } + + reverse_iterator rend() noexcept { + return comments.rend(); + } + + const_reverse_iterator rbegin() const noexcept { + return 
comments.rbegin(); + } + + const_reverse_iterator rend() const noexcept { + return comments.rend(); + } + + const_reverse_iterator crbegin() const noexcept { + return comments.crbegin(); + } + + const_reverse_iterator crend() const noexcept { + return comments.crend(); + } + + friend bool operator==(const preserve_comments&, const preserve_comments&); + friend bool operator!=(const preserve_comments&, const preserve_comments&); + friend bool operator<(const preserve_comments&, const preserve_comments&); + friend bool operator<=(const preserve_comments&, const preserve_comments&); + friend bool operator>(const preserve_comments&, const preserve_comments&); + friend bool operator>=(const preserve_comments&, const preserve_comments&); + + friend void swap(preserve_comments&, std::vector&); + friend void swap(std::vector&, preserve_comments&); + + private: + container_type comments; + }; + + bool operator==(const preserve_comments& lhs, const preserve_comments& rhs); + bool operator!=(const preserve_comments& lhs, const preserve_comments& rhs); + bool operator<(const preserve_comments& lhs, const preserve_comments& rhs); + bool operator<=(const preserve_comments& lhs, const preserve_comments& rhs); + bool operator>(const preserve_comments& lhs, const preserve_comments& rhs); + bool operator>=(const preserve_comments& lhs, const preserve_comments& rhs); + + void swap(preserve_comments& lhs, preserve_comments& rhs); + void swap(preserve_comments& lhs, std::vector& rhs); + void swap(std::vector& lhs, preserve_comments& rhs); + + std::ostream& operator<<(std::ostream& os, const preserve_comments& com); + + namespace detail { + + // To provide the same interface with `preserve_comments`, + // `discard_comments` should have an iterator. But it does not contain + // anything, so we need to add an iterator that points nothing. + // + // It always points null, so DO NOT unwrap this iterator. It always crashes + // your program. 
+ template + struct empty_iterator { + using value_type = T; + using reference_type = typename std::conditional::type; + using pointer_type = typename std::conditional::type; + using difference_type = std::ptrdiff_t; + using iterator_category = std::random_access_iterator_tag; + + empty_iterator() = default; + ~empty_iterator() = default; + empty_iterator(const empty_iterator&) = default; + empty_iterator(empty_iterator&&) = default; + empty_iterator& operator=(const empty_iterator&) = default; + empty_iterator& operator=(empty_iterator&&) = default; + + // DO NOT call these operators. + reference_type operator*() const noexcept { + std::terminate(); + } + + pointer_type operator->() const noexcept { + return nullptr; + } + + reference_type operator[](difference_type) const noexcept { + return this->operator*(); + } + + // These operators do nothing. + empty_iterator& operator++() noexcept { + return *this; + } + + empty_iterator operator++(int) noexcept { + return *this; + } + + empty_iterator& operator--() noexcept { + return *this; + } + + empty_iterator operator--(int) noexcept { + return *this; + } + + empty_iterator& operator+=(difference_type) noexcept { + return *this; + } + + empty_iterator& operator-=(difference_type) noexcept { + return *this; + } + + empty_iterator operator+(difference_type) const noexcept { + return *this; + } + + empty_iterator operator-(difference_type) const noexcept { + return *this; + } + }; + + template + bool operator==(const empty_iterator&, + const empty_iterator&) noexcept { + return true; + } + + template + bool operator!=(const empty_iterator&, + const empty_iterator&) noexcept { + return false; + } + + template + bool operator<(const empty_iterator&, + const empty_iterator&) noexcept { + return false; + } + + template + bool operator<=(const empty_iterator&, + const empty_iterator&) noexcept { + return true; + } + + template + bool operator>(const empty_iterator&, + const empty_iterator&) noexcept { + return false; + } + + 
template + bool operator>=(const empty_iterator&, + const empty_iterator&) noexcept { + return true; + } + + template + typename empty_iterator::difference_type operator-( + const empty_iterator&, + const empty_iterator&) noexcept { + return 0; + } + + template + empty_iterator operator+(typename empty_iterator::difference_type, + const empty_iterator& rhs) noexcept { + return rhs; + } + + template + empty_iterator operator+( + const empty_iterator& lhs, + typename empty_iterator::difference_type) noexcept { + return lhs; + } + + } // namespace detail + + // The default comment type. It discards all the comments. It requires only one + // byte to contain, so the memory footprint is smaller than preserve_comments. + // + // It just ignores `push_back`, `insert`, `erase`, and any other modifications. + // IT always returns size() == 0, the iterator taken by `begin()` is always the + // same as that of `end()`, and accessing through `operator[]` or iterators + // always causes a segmentation fault. DO NOT access to the element of this. + // + // Why this is chose as the default type is because the last version (2.x.y) + // does not contain any comments in a value. To minimize the impact on the + // efficiency, this is chosen as a default. + // + // To reduce the memory footprint, later we can try empty base optimization (EBO). 
+ class discard_comments { + public: + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using value_type = std::string; + using reference = std::string&; + using const_reference = const std::string&; + using pointer = std::string*; + using const_pointer = const std::string*; + using iterator = detail::empty_iterator; + using const_iterator = detail::empty_iterator; + using reverse_iterator = detail::empty_iterator; + using const_reverse_iterator = detail::empty_iterator; + + public: + discard_comments() = default; + ~discard_comments() = default; + discard_comments(const discard_comments&) = default; + discard_comments(discard_comments&&) = default; + discard_comments& operator=(const discard_comments&) = default; + discard_comments& operator=(discard_comments&&) = default; + + explicit discard_comments(const std::vector&) noexcept {} + + explicit discard_comments(std::vector&&) noexcept {} + + discard_comments& operator=(const std::vector&) noexcept { + return *this; + } + + discard_comments& operator=(std::vector&&) noexcept { + return *this; + } + + explicit discard_comments(const preserve_comments&) noexcept {} + + explicit discard_comments(size_type) noexcept {} + + discard_comments(size_type, const std::string&) noexcept {} + + discard_comments(std::initializer_list) noexcept {} + + template + discard_comments(InputIterator, InputIterator) noexcept {} + + template + void assign(InputIterator, InputIterator) noexcept {} + + void assign(std::initializer_list) noexcept {} + + void assign(size_type, const std::string&) noexcept {} + + iterator insert(const_iterator, const std::string&) { + return iterator {}; + } + + iterator insert(const_iterator, std::string&&) { + return iterator {}; + } + + iterator insert(const_iterator, size_type, const std::string&) { + return iterator {}; + } + + template + iterator insert(const_iterator, InputIterator, InputIterator) { + return iterator {}; + } + + iterator insert(const_iterator, 
std::initializer_list) { + return iterator {}; + } + + template + iterator emplace(const_iterator, Ts&&...) { + return iterator {}; + } + + iterator erase(const_iterator) { + return iterator {}; + } + + iterator erase(const_iterator, const_iterator) { + return iterator {}; + } + + void swap(discard_comments&) { + return; + } + + void push_back(const std::string&) { + return; + } + + void push_back(std::string&&) { + return; + } + + void pop_back() { + return; + } + + template + void emplace_back(Ts&&...) { + return; + } + + void clear() { + return; + } + + size_type size() const noexcept { + return 0; + } + + size_type max_size() const noexcept { + return 0; + } + + size_type capacity() const noexcept { + return 0; + } + + bool empty() const noexcept { + return true; + } + + void reserve(size_type) { + return; + } + + void resize(size_type) { + return; + } + + void resize(size_type, const std::string&) { + return; + } + + void shrink_to_fit() { + return; + } + + // DO NOT access to the element of this container. This container is always + // empty, so accessing through operator[], front/back, data causes address + // error. 
+ + reference operator[](const size_type) noexcept { + never_call("toml::discard_comment::operator[]"); + } + + const_reference operator[](const size_type) const noexcept { + never_call("toml::discard_comment::operator[]"); + } + + reference at(const size_type) { + throw std::out_of_range("toml::discard_comment is always empty."); + } + + const_reference at(const size_type) const { + throw std::out_of_range("toml::discard_comment is always empty."); + } + + reference front() noexcept { + never_call("toml::discard_comment::front"); + } + + const_reference front() const noexcept { + never_call("toml::discard_comment::front"); + } + + reference back() noexcept { + never_call("toml::discard_comment::back"); + } + + const_reference back() const noexcept { + never_call("toml::discard_comment::back"); + } + + pointer data() noexcept { + return nullptr; + } + + const_pointer data() const noexcept { + return nullptr; + } + + iterator begin() noexcept { + return iterator {}; + } + + iterator end() noexcept { + return iterator {}; + } + + const_iterator begin() const noexcept { + return const_iterator {}; + } + + const_iterator end() const noexcept { + return const_iterator {}; + } + + const_iterator cbegin() const noexcept { + return const_iterator {}; + } + + const_iterator cend() const noexcept { + return const_iterator {}; + } + + reverse_iterator rbegin() noexcept { + return iterator {}; + } + + reverse_iterator rend() noexcept { + return iterator {}; + } + + const_reverse_iterator rbegin() const noexcept { + return const_iterator {}; + } + + const_reverse_iterator rend() const noexcept { + return const_iterator {}; + } + + const_reverse_iterator crbegin() const noexcept { + return const_iterator {}; + } + + const_reverse_iterator crend() const noexcept { + return const_iterator {}; + } + + private: + [[noreturn]] + static void never_call(const char* const this_function) { + #if __has_builtin(__builtin_unreachable) + __builtin_unreachable(); + #endif + throw 
std::logic_error { this_function }; + } + }; + + inline bool operator==(const discard_comments&, const discard_comments&) noexcept { + return true; + } + + inline bool operator!=(const discard_comments&, const discard_comments&) noexcept { + return false; + } + + inline bool operator<(const discard_comments&, const discard_comments&) noexcept { + return false; + } + + inline bool operator<=(const discard_comments&, const discard_comments&) noexcept { + return true; + } + + inline bool operator>(const discard_comments&, const discard_comments&) noexcept { + return false; + } + + inline bool operator>=(const discard_comments&, const discard_comments&) noexcept { + return true; + } + + inline void swap(const discard_comments&, const discard_comments&) noexcept { + return; + } + + inline std::ostream& operator<<(std::ostream& os, const discard_comments&) { + return os; + } + +} // namespace toml +#endif // TOML11_COMMENTS_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_COMMENTS_IMPL_HPP + #define TOML11_COMMENTS_IMPL_HPP + +namespace toml { + + TOML11_INLINE bool operator==(const preserve_comments& lhs, + const preserve_comments& rhs) { + return lhs.comments == rhs.comments; + } + + TOML11_INLINE bool operator!=(const preserve_comments& lhs, + const preserve_comments& rhs) { + return lhs.comments != rhs.comments; + } + + TOML11_INLINE bool operator<(const preserve_comments& lhs, + const preserve_comments& rhs) { + return lhs.comments < rhs.comments; + } + + TOML11_INLINE bool operator<=(const preserve_comments& lhs, + const preserve_comments& rhs) { + return lhs.comments <= rhs.comments; + } + + TOML11_INLINE bool operator>(const preserve_comments& lhs, + const preserve_comments& rhs) { + return lhs.comments > rhs.comments; + } + + TOML11_INLINE bool operator>=(const preserve_comments& lhs, + const preserve_comments& rhs) { + return lhs.comments >= rhs.comments; + } + + TOML11_INLINE void swap(preserve_comments& lhs, preserve_comments& rhs) { + 
lhs.swap(rhs); + return; + } + + TOML11_INLINE void swap(preserve_comments& lhs, std::vector& rhs) { + lhs.comments.swap(rhs); + return; + } + + TOML11_INLINE void swap(std::vector& lhs, preserve_comments& rhs) { + lhs.swap(rhs.comments); + return; + } + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, + const preserve_comments& com) { + for (const auto& c : com) { + if (c.front() != '#') { + os << '#'; + } + os << c << '\n'; + } + return os; + } + +} // namespace toml + #endif // TOML11_COMMENTS_IMPL_HPP +#endif + +#endif // TOML11_COMMENTS_HPP +#ifndef TOML11_COLOR_HPP +#define TOML11_COLOR_HPP + +#ifndef TOML11_COLOR_FWD_HPP + #define TOML11_COLOR_FWD_HPP + + #include + + #ifdef TOML11_COLORIZE_ERROR_MESSAGE + #define TOML11_ERROR_MESSAGE_COLORIZED true + #else + #define TOML11_ERROR_MESSAGE_COLORIZED false + #endif + + #ifdef TOML11_USE_THREAD_LOCAL_COLORIZATION + #define TOML11_THREAD_LOCAL_COLORIZATION thread_local + #else + #define TOML11_THREAD_LOCAL_COLORIZATION + #endif + +namespace toml { + namespace color { + // put ANSI escape sequence to ostream + inline namespace ansi { + namespace detail { + + // Control color mode globally + class color_mode { + public: + void enable() noexcept { + should_color_ = true; + } + + void disable() noexcept { + should_color_ = false; + } + + bool should_color() const noexcept { + return should_color_; + } + + private: + bool should_color_ = TOML11_ERROR_MESSAGE_COLORIZED; + }; + + inline color_mode& color_status() noexcept { + static TOML11_THREAD_LOCAL_COLORIZATION color_mode status; + return status; + } + + } // namespace detail + + std::ostream& reset(std::ostream& os); + std::ostream& bold(std::ostream& os); + std::ostream& grey(std::ostream& os); + std::ostream& gray(std::ostream& os); + std::ostream& red(std::ostream& os); + std::ostream& green(std::ostream& os); + std::ostream& yellow(std::ostream& os); + std::ostream& blue(std::ostream& os); + std::ostream& magenta(std::ostream& os); + std::ostream& 
cyan(std::ostream& os); + std::ostream& white(std::ostream& os); + + } // namespace ansi + + inline void enable() { + return detail::color_status().enable(); + } + + inline void disable() { + return detail::color_status().disable(); + } + + inline bool should_color() { + return detail::color_status().should_color(); + } + + } // namespace color +} // namespace toml +#endif // TOML11_COLOR_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_COLOR_IMPL_HPP + #define TOML11_COLOR_IMPL_HPP + + #include + +namespace toml { + namespace color { + // put ANSI escape sequence to ostream + inline namespace ansi { + + TOML11_INLINE std::ostream& reset(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[00m"; + } + return os; + } + + TOML11_INLINE std::ostream& bold(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[01m"; + } + return os; + } + + TOML11_INLINE std::ostream& grey(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[30m"; + } + return os; + } + + TOML11_INLINE std::ostream& gray(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[30m"; + } + return os; + } + + TOML11_INLINE std::ostream& red(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[31m"; + } + return os; + } + + TOML11_INLINE std::ostream& green(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[32m"; + } + return os; + } + + TOML11_INLINE std::ostream& yellow(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[33m"; + } + return os; + } + + TOML11_INLINE std::ostream& blue(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[34m"; + } + return os; + } + + TOML11_INLINE std::ostream& magenta(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[35m"; + } + return os; + } + + TOML11_INLINE std::ostream& cyan(std::ostream& os) { + if 
(detail::color_status().should_color()) { + os << "\033[36m"; + } + return os; + } + + TOML11_INLINE std::ostream& white(std::ostream& os) { + if (detail::color_status().should_color()) { + os << "\033[37m"; + } + return os; + } + + } // namespace ansi + } // namespace color +} // namespace toml + #endif // TOML11_COLOR_IMPL_HPP +#endif + +#endif // TOML11_COLOR_HPP +#ifndef TOML11_SPEC_HPP +#define TOML11_SPEC_HPP + +#include +#include +#include + +namespace toml { + + struct semantic_version { + constexpr semantic_version(std::uint32_t mjr, + std::uint32_t mnr, + std::uint32_t p) noexcept + : major { mjr } + , minor { mnr } + , patch { p } {} + + std::uint32_t major; + std::uint32_t minor; + std::uint32_t patch; + }; + + constexpr inline semantic_version make_semver(std::uint32_t mjr, + std::uint32_t mnr, + std::uint32_t p) noexcept { + return semantic_version(mjr, mnr, p); + } + + constexpr inline bool operator==(const semantic_version& lhs, + const semantic_version& rhs) noexcept { + return lhs.major == rhs.major && lhs.minor == rhs.minor && + lhs.patch == rhs.patch; + } + + constexpr inline bool operator!=(const semantic_version& lhs, + const semantic_version& rhs) noexcept { + return !(lhs == rhs); + } + + constexpr inline bool operator<(const semantic_version& lhs, + const semantic_version& rhs) noexcept { + return lhs.major < rhs.major || + (lhs.major == rhs.major && lhs.minor < rhs.minor) || + (lhs.major == rhs.major && lhs.minor == rhs.minor && + lhs.patch < rhs.patch); + } + + constexpr inline bool operator>(const semantic_version& lhs, + const semantic_version& rhs) noexcept { + return rhs < lhs; + } + + constexpr inline bool operator<=(const semantic_version& lhs, + const semantic_version& rhs) noexcept { + return !(lhs > rhs); + } + + constexpr inline bool operator>=(const semantic_version& lhs, + const semantic_version& rhs) noexcept { + return !(lhs < rhs); + } + + inline std::ostream& operator<<(std::ostream& os, const semantic_version& v) { + os 
<< v.major << '.' << v.minor << '.' << v.patch; + return os; + } + + inline std::string to_string(const semantic_version& v) { + std::ostringstream oss; + oss << v; + return oss.str(); + } + + struct spec { + constexpr static spec default_version() noexcept { + return spec::v(1, 0, 0); + } + + constexpr static spec v(std::uint32_t mjr, + std::uint32_t mnr, + std::uint32_t p) noexcept { + return spec(make_semver(mjr, mnr, p)); + } + + constexpr explicit spec(const semantic_version& semver) noexcept + : version{semver}, + v1_1_0_allow_control_characters_in_comments {semantic_version{1, 1, 0} <= semver}, + v1_1_0_allow_newlines_in_inline_tables {semantic_version{1, 1, 0} <= semver}, + v1_1_0_allow_trailing_comma_in_inline_tables{semantic_version{1, 1, 0} <= semver}, + v1_1_0_allow_non_english_in_bare_keys {semantic_version{1, 1, 0} <= semver}, + v1_1_0_add_escape_sequence_e {semantic_version{1, 1, 0} <= semver}, + v1_1_0_add_escape_sequence_x {semantic_version{1, 1, 0} <= semver}, + v1_1_0_make_seconds_optional {semantic_version{1, 1, 0} <= semver}, + ext_hex_float {false}, + ext_num_suffix{false}, + ext_null_value{false} + {} + + semantic_version version; // toml version + + // diff from v1.0.0 -> v1.1.0 + bool v1_1_0_allow_control_characters_in_comments; + bool v1_1_0_allow_newlines_in_inline_tables; + bool v1_1_0_allow_trailing_comma_in_inline_tables; + bool v1_1_0_allow_non_english_in_bare_keys; + bool v1_1_0_add_escape_sequence_e; + bool v1_1_0_add_escape_sequence_x; + bool v1_1_0_make_seconds_optional; + + // library extensions + bool ext_hex_float; // allow hex float (in C++ style) + bool ext_num_suffix; // allow number suffix (in C++ style) + bool ext_null_value; // allow `null` as a value + }; + +} // namespace toml +#endif // TOML11_SPEC_HPP +#ifndef TOML11_ORDERED_MAP_HPP +#define TOML11_ORDERED_MAP_HPP + +#include +#include +#include +#include + +namespace toml { + + namespace detail { + template + struct ordered_map_ebo_container { + Cmp cmp_; // empty 
base optimization for empty Cmp type + }; + } // namespace detail + + template , + typename Allocator = std::allocator>> + class ordered_map : detail::ordered_map_ebo_container { + public: + using key_type = Key; + using mapped_type = Val; + using value_type = std::pair; + + using key_compare = Cmp; + using allocator_type = Allocator; + + using container_type = std::vector; + using reference = typename container_type::reference; + using pointer = typename container_type::pointer; + using const_reference = typename container_type::const_reference; + using const_pointer = typename container_type::const_pointer; + using iterator = typename container_type::iterator; + using const_iterator = typename container_type::const_iterator; + using size_type = typename container_type::size_type; + using difference_type = typename container_type::difference_type; + + private: + using ebo_base = detail::ordered_map_ebo_container; + + public: + ordered_map() = default; + ~ordered_map() = default; + ordered_map(const ordered_map&) = default; + ordered_map(ordered_map&&) = default; + ordered_map& operator=(const ordered_map&) = default; + ordered_map& operator=(ordered_map&&) = default; + + ordered_map(const ordered_map& other, const Allocator& alloc) + : container_(other.container_, alloc) {} + + ordered_map(ordered_map&& other, const Allocator& alloc) + : container_(std::move(other.container_), alloc) {} + + explicit ordered_map(const Cmp& cmp, const Allocator& alloc = Allocator()) + : ebo_base { cmp } + , container_(alloc) {} + + explicit ordered_map(const Allocator& alloc) : container_(alloc) {} + + template + ordered_map(InputIterator first, + InputIterator last, + const Cmp& cmp = Cmp(), + const Allocator& alloc = Allocator()) + : ebo_base { cmp } + , container_(first, last, alloc) {} + + template + ordered_map(InputIterator first, InputIterator last, const Allocator& alloc) + : container_(first, last, alloc) {} + + ordered_map(std::initializer_list v, + const Cmp& cmp = Cmp(), 
+ const Allocator& alloc = Allocator()) + : ebo_base { cmp } + , container_(std::move(v), alloc) {} + + ordered_map(std::initializer_list v, const Allocator& alloc) + : container_(std::move(v), alloc) {} + + ordered_map& operator=(std::initializer_list v) { + this->container_ = std::move(v); + return *this; + } + + iterator begin() noexcept { + return container_.begin(); + } + + iterator end() noexcept { + return container_.end(); + } + + const_iterator begin() const noexcept { + return container_.begin(); + } + + const_iterator end() const noexcept { + return container_.end(); + } + + const_iterator cbegin() const noexcept { + return container_.cbegin(); + } + + const_iterator cend() const noexcept { + return container_.cend(); + } + + bool empty() const noexcept { + return container_.empty(); + } + + std::size_t size() const noexcept { + return container_.size(); + } + + std::size_t max_size() const noexcept { + return container_.max_size(); + } + + void clear() { + container_.clear(); + } + + void push_back(const value_type& v) { + if (this->contains(v.first)) { + throw std::out_of_range("ordered_map: value already exists"); + } + container_.push_back(v); + } + + void push_back(value_type&& v) { + if (this->contains(v.first)) { + throw std::out_of_range("ordered_map: value already exists"); + } + container_.push_back(std::move(v)); + } + + void emplace_back(key_type k, mapped_type v) { + if (this->contains(k)) { + throw std::out_of_range("ordered_map: value already exists"); + } + container_.emplace_back(std::move(k), std::move(v)); + } + + void pop_back() { + container_.pop_back(); + } + + void insert(value_type kv) { + if (this->contains(kv.first)) { + throw std::out_of_range("ordered_map: value already exists"); + } + container_.push_back(std::move(kv)); + } + + void emplace(key_type k, mapped_type v) { + if (this->contains(k)) { + throw std::out_of_range("ordered_map: value already exists"); + } + container_.emplace_back(std::move(k), std::move(v)); + } + + 
std::size_t count(const key_type& key) const { + if (this->find(key) != this->end()) { + return 1; + } else { + return 0; + } + } + + bool contains(const key_type& key) const { + return this->find(key) != this->end(); + } + + iterator find(const key_type& key) noexcept { + return std::find_if(this->begin(), + this->end(), + [&key, this](const value_type& v) { + return this->cmp_(v.first, key); + }); + } + + const_iterator find(const key_type& key) const noexcept { + return std::find_if(this->begin(), + this->end(), + [&key, this](const value_type& v) { + return this->cmp_(v.first, key); + }); + } + + mapped_type& at(const key_type& k) { + const auto iter = this->find(k); + if (iter == this->end()) { + throw std::out_of_range("ordered_map: no such element"); + } + return iter->second; + } + + const mapped_type& at(const key_type& k) const { + const auto iter = this->find(k); + if (iter == this->end()) { + throw std::out_of_range("ordered_map: no such element"); + } + return iter->second; + } + + mapped_type& operator[](const key_type& k) { + const auto iter = this->find(k); + if (iter == this->end()) { + this->container_.emplace_back(k, mapped_type {}); + return this->container_.back().second; + } + return iter->second; + } + + const mapped_type& operator[](const key_type& k) const { + const auto iter = this->find(k); + if (iter == this->end()) { + throw std::out_of_range("ordered_map: no such element"); + } + return iter->second; + } + + key_compare key_comp() const { + return this->cmp_; + } + + void swap(ordered_map& other) { + container_.swap(other.container_); + } + + private: + container_type container_; + }; + + template + bool operator==(const ordered_map& lhs, + const ordered_map& rhs) { + return lhs.size() == rhs.size() && + std::equal(lhs.begin(), lhs.end(), rhs.begin()); + } + + template + bool operator!=(const ordered_map& lhs, + const ordered_map& rhs) { + return !(lhs == rhs); + } + + template + bool operator<(const ordered_map& lhs, + const 
ordered_map& rhs) { + return std::lexicographical_compare(lhs.begin(), + lhs.end(), + rhs.begin(), + rhs.end()); + } + + template + bool operator>(const ordered_map& lhs, + const ordered_map& rhs) { + return rhs < lhs; + } + + template + bool operator<=(const ordered_map& lhs, + const ordered_map& rhs) { + return !(lhs > rhs); + } + + template + bool operator>=(const ordered_map& lhs, + const ordered_map& rhs) { + return !(lhs < rhs); + } + + template + void swap(ordered_map& lhs, ordered_map& rhs) { + lhs.swap(rhs); + return; + } + +} // namespace toml +#endif // TOML11_ORDERED_MAP_HPP +#ifndef TOML11_INTO_HPP +#define TOML11_INTO_HPP + +namespace toml { + + template + struct into; + // { + // static toml::value into_toml(const T& user_defined_type) + // { + // // User-defined conversions ... + // } + // }; + +} // namespace toml +#endif // TOML11_INTO_HPP +#ifndef TOML11_FROM_HPP +#define TOML11_FROM_HPP + +namespace toml { + + template + struct from; + // { + // static T from_toml(const toml::value& v) + // { + // // User-defined conversions ... 
+ // } + // }; + +} // namespace toml +#endif // TOML11_FROM_HPP +#ifndef TOML11_TRAITS_HPP +#define TOML11_TRAITS_HPP + +#include +#include +#include +#include +#include +#include +#include + +#if defined(TOML11_HAS_STRING_VIEW) + #include +#endif + +namespace toml { + template + class basic_value; + + namespace detail { + // --------------------------------------------------------------------------- + // check whether type T is a kind of container/map class + + struct has_iterator_impl { + template + static std::true_type check(typename T::iterator*); + template + static std::false_type check(...); + }; + + struct has_value_type_impl { + template + static std::true_type check(typename T::value_type*); + template + static std::false_type check(...); + }; + + struct has_key_type_impl { + template + static std::true_type check(typename T::key_type*); + template + static std::false_type check(...); + }; + + struct has_mapped_type_impl { + template + static std::true_type check(typename T::mapped_type*); + template + static std::false_type check(...); + }; + + struct has_reserve_method_impl { + template + static std::false_type check(...); + template + static std::true_type check( + decltype(std::declval().reserve(std::declval()))*); + }; + + struct has_push_back_method_impl { + template + static std::false_type check(...); + template + static std::true_type check(decltype(std::declval().push_back( + std::declval()))*); + }; + + struct is_comparable_impl { + template + static std::false_type check(...); + template + static std::true_type check(decltype(std::declval() < std::declval())*); + }; + + struct has_from_toml_method_impl { + template + static std::true_type check(decltype(std::declval().from_toml( + std::declval<::toml::basic_value>()))*); + + template + static std::false_type check(...); + }; + + struct has_into_toml_method_impl { + template + static std::true_type check(decltype(std::declval().into_toml())*); + template + static std::false_type check(...); + 
}; + + struct has_template_into_toml_method_impl { + template + static std::true_type check( + decltype(std::declval().template into_toml())*); + template + static std::false_type check(...); + }; + + struct has_specialized_from_impl { + template + static std::false_type check(...); + template )> + static std::true_type check(::toml::from*); + }; + + struct has_specialized_into_impl { + template + static std::false_type check(...); + template )> + static std::true_type check(::toml::into*); + }; + +/// Intel C++ compiler can not use decltype in parent class declaration, here +/// is a hack to work around it. https://stackoverflow.com/a/23953090/4692076 +#ifdef __INTEL_COMPILER + #define decltype(...) std::enable_if::type +#endif + + template + struct has_iterator : decltype(has_iterator_impl::check(nullptr)) {}; + + template + struct has_value_type : decltype(has_value_type_impl::check(nullptr)) {}; + + template + struct has_key_type : decltype(has_key_type_impl::check(nullptr)) {}; + + template + struct has_mapped_type : decltype(has_mapped_type_impl::check(nullptr)) {}; + + template + struct has_reserve_method + : decltype(has_reserve_method_impl::check(nullptr)) {}; + + template + struct has_push_back_method + : decltype(has_push_back_method_impl::check(nullptr)) {}; + + template + struct is_comparable : decltype(is_comparable_impl::check(nullptr)) {}; + + template + struct has_from_toml_method + : decltype(has_from_toml_method_impl::check(nullptr)) {}; + + template + struct has_into_toml_method + : decltype(has_into_toml_method_impl::check(nullptr)) {}; + + template + struct has_template_into_toml_method + : decltype(has_template_into_toml_method_impl::check(nullptr)) { + }; + + template + struct has_specialized_from + : decltype(has_specialized_from_impl::check(nullptr)) {}; + + template + struct has_specialized_into + : decltype(has_specialized_into_impl::check(nullptr)) {}; + +#ifdef __INTEL_COMPILER + #undef decltype +#endif + + // 
--------------------------------------------------------------------------- + // type checkers + + template + struct is_std_pair_impl : std::false_type {}; + + template + struct is_std_pair_impl> : std::true_type {}; + + template + using is_std_pair = is_std_pair_impl>; + + template + struct is_std_tuple_impl : std::false_type {}; + + template + struct is_std_tuple_impl> : std::true_type {}; + + template + using is_std_tuple = is_std_tuple_impl>; + + template + struct is_std_array_impl : std::false_type {}; + + template + struct is_std_array_impl> : std::true_type {}; + + template + using is_std_array = is_std_array_impl>; + + template + struct is_std_forward_list_impl : std::false_type {}; + + template + struct is_std_forward_list_impl> : std::true_type {}; + + template + using is_std_forward_list = is_std_forward_list_impl>; + + template + struct is_std_basic_string_impl : std::false_type {}; + + template + struct is_std_basic_string_impl> : std::true_type { + }; + + template + using is_std_basic_string = is_std_basic_string_impl>; + + template + struct is_1byte_std_basic_string_impl : std::false_type {}; + + template + struct is_1byte_std_basic_string_impl> + : std::integral_constant {}; + + template + using is_1byte_std_basic_string = is_std_basic_string_impl>; + +#if defined(TOML11_HAS_STRING_VIEW) + template + struct is_std_basic_string_view_impl : std::false_type {}; + + template + struct is_std_basic_string_view_impl> + : std::true_type {}; + + template + using is_std_basic_string_view = + is_std_basic_string_view_impl>; + + template + struct is_string_view_of : std::false_type {}; + + template + struct is_string_view_of, std::basic_string> + : std::true_type {}; +#endif + + template + struct is_chrono_duration_impl : std::false_type {}; + + template + struct is_chrono_duration_impl> + : std::true_type {}; + + template + using is_chrono_duration = is_chrono_duration_impl>; + + template + struct is_map_impl + : cxx::conjunction< // map satisfies all the 
following conditions + has_iterator, // has T::iterator + has_value_type, // has T::value_type + has_key_type, // has T::key_type + has_mapped_type // has T::mapped_type + > {}; + + template + using is_map = is_map_impl>; + + template + struct is_container_impl + : cxx::conjunction>, // not a map + cxx::negation>, // not a std::string +#ifdef TOML11_HAS_STRING_VIEW + cxx::negation>, // not a std::string_view +#endif + has_iterator, // has T::iterator + has_value_type // has T::value_type + > { + }; + + template + using is_container = is_container_impl>; + + template + struct is_basic_value_impl : std::false_type {}; + + template + struct is_basic_value_impl<::toml::basic_value> : std::true_type {}; + + template + using is_basic_value = is_basic_value_impl>; + + } // namespace detail +} // namespace toml +#endif // TOML11_TRAITS_HPP +#ifndef TOML11_EXCEPTION_HPP +#define TOML11_EXCEPTION_HPP + +#include + +namespace toml { + + struct exception : public std::exception { + public: + virtual ~exception() noexcept override = default; + + virtual const char* what() const noexcept override { + return ""; + } + }; + +} // namespace toml +#endif // TOMl11_EXCEPTION_HPP +#ifndef TOML11_RESULT_HPP +#define TOML11_RESULT_HPP + +#include +#include +#include +#include +#include + +namespace toml { + + struct bad_result_access final : public ::toml::exception { + public: + explicit bad_result_access(std::string what_arg) + : what_(std::move(what_arg)) {} + + ~bad_result_access() noexcept override = default; + + const char* what() const noexcept override { + return what_.c_str(); + } + + private: + std::string what_; + }; + + // ----------------------------------------------------------------------------- + + template + struct success { + static_assert(!std::is_same::value, ""); + + using value_type = T; + + explicit success(value_type v) noexcept( + std::is_nothrow_move_constructible::value) + : value(std::move(v)) {} + + template , T>::value, + std::nullptr_t> = nullptr> + 
explicit success(U&& v) : value(std::forward(v)) {} + + template + explicit success(success v) : value(std::move(v.value)) {} + + ~success() = default; + success(const success&) = default; + success(success&&) = default; + success& operator=(const success&) = default; + success& operator=(success&&) = default; + + value_type& get() noexcept { + return value; + } + + const value_type& get() const noexcept { + return value; + } + + private: + value_type value; + }; + + template + struct success> { + static_assert(!std::is_same::value, ""); + + using value_type = T; + + explicit success(std::reference_wrapper v) noexcept + : value(std::move(v)) {} + + ~success() = default; + success(const success&) = default; + success(success&&) = default; + success& operator=(const success&) = default; + success& operator=(success&&) = default; + + value_type& get() noexcept { + return value.get(); + } + + const value_type& get() const noexcept { + return value.get(); + } + + private: + std::reference_wrapper value; + }; + + template + success::type> ok(T&& v) { + return success::type>(std::forward(v)); + } + + template + success ok(const char (&literal)[N]) { + return success(std::string(literal)); + } + + // ----------------------------------------------------------------------------- + + template + struct failure { + using value_type = T; + + explicit failure(value_type v) noexcept( + std::is_nothrow_move_constructible::value) + : value(std::move(v)) {} + + template , T>::value, + std::nullptr_t> = nullptr> + explicit failure(U&& v) : value(std::forward(v)) {} + + template + explicit failure(failure v) : value(std::move(v.value)) {} + + ~failure() = default; + failure(const failure&) = default; + failure(failure&&) = default; + failure& operator=(const failure&) = default; + failure& operator=(failure&&) = default; + + value_type& get() noexcept { + return value; + } + + const value_type& get() const noexcept { + return value; + } + + private: + value_type value; + }; + + 
template + struct failure> { + using value_type = T; + + explicit failure(std::reference_wrapper v) noexcept + : value(std::move(v)) {} + + ~failure() = default; + failure(const failure&) = default; + failure(failure&&) = default; + failure& operator=(const failure&) = default; + failure& operator=(failure&&) = default; + + value_type& get() noexcept { + return value.get(); + } + + const value_type& get() const noexcept { + return value.get(); + } + + private: + std::reference_wrapper value; + }; + + template + failure::type> err(T&& v) { + return failure::type>(std::forward(v)); + } + + template + failure err(const char (&literal)[N]) { + return failure(std::string(literal)); + } + + /* ============================================================================ + * _ _ + * _ _ ___ ____ _| | |_ + * | '_/ -_|_-< || | | _| + * |_| \___/__/\_,_|_|\__| + */ + + template + struct result { + using success_type = success; + using failure_type = failure; + using value_type = typename success_type::value_type; + using error_type = typename failure_type::value_type; + + result(success_type s) : is_ok_(true), succ_(std::move(s)) {} + + result(failure_type f) : is_ok_(false), fail_(std::move(f)) {} + + template < + typename U, + cxx::enable_if_t< + cxx::conjunction, value_type>>, + std::is_convertible, value_type>>::value, + std::nullptr_t> = nullptr> + result(success s) : is_ok_(true) + , succ_(std::move(s.value)) {} + + template < + typename U, + cxx::enable_if_t< + cxx::conjunction, error_type>>, + std::is_convertible, error_type>>::value, + std::nullptr_t> = nullptr> + result(failure f) : is_ok_(false) + , fail_(std::move(f.value)) {} + + result& operator=(success_type s) { + this->cleanup(); + this->is_ok_ = true; + auto tmp = ::new (std::addressof(this->succ_)) success_type(std::move(s)); + assert(tmp == std::addressof(this->succ_)); + (void)tmp; + return *this; + } + + result& operator=(failure_type f) { + this->cleanup(); + this->is_ok_ = false; + auto tmp = ::new 
(std::addressof(this->fail_)) failure_type(std::move(f)); + assert(tmp == std::addressof(this->fail_)); + (void)tmp; + return *this; + } + + template + result& operator=(success s) { + this->cleanup(); + this->is_ok_ = true; + auto tmp = ::new (std::addressof(this->succ_)) + success_type(std::move(s.value)); + assert(tmp == std::addressof(this->succ_)); + (void)tmp; + return *this; + } + + template + result& operator=(failure f) { + this->cleanup(); + this->is_ok_ = false; + auto tmp = ::new (std::addressof(this->fail_)) + failure_type(std::move(f.value)); + assert(tmp == std::addressof(this->fail_)); + (void)tmp; + return *this; + } + + ~result() noexcept { + this->cleanup(); + } + + result(const result& other) : is_ok_(other.is_ok()) { + if (other.is_ok()) { + auto tmp = ::new (std::addressof(this->succ_)) success_type(other.succ_); + assert(tmp == std::addressof(this->succ_)); + (void)tmp; + } else { + auto tmp = ::new (std::addressof(this->fail_)) failure_type(other.fail_); + assert(tmp == std::addressof(this->fail_)); + (void)tmp; + } + } + + result(result&& other) : is_ok_(other.is_ok()) { + if (other.is_ok()) { + auto tmp = ::new (std::addressof(this->succ_)) + success_type(std::move(other.succ_)); + assert(tmp == std::addressof(this->succ_)); + (void)tmp; + } else { + auto tmp = ::new (std::addressof(this->fail_)) + failure_type(std::move(other.fail_)); + assert(tmp == std::addressof(this->fail_)); + (void)tmp; + } + } + + result& operator=(const result& other) { + this->cleanup(); + if (other.is_ok()) { + auto tmp = ::new (std::addressof(this->succ_)) success_type(other.succ_); + assert(tmp == std::addressof(this->succ_)); + (void)tmp; + } else { + auto tmp = ::new (std::addressof(this->fail_)) failure_type(other.fail_); + assert(tmp == std::addressof(this->fail_)); + (void)tmp; + } + is_ok_ = other.is_ok(); + return *this; + } + + result& operator=(result&& other) { + this->cleanup(); + if (other.is_ok()) { + auto tmp = ::new (std::addressof(this->succ_)) 
+ success_type(std::move(other.succ_)); + assert(tmp == std::addressof(this->succ_)); + (void)tmp; + } else { + auto tmp = ::new (std::addressof(this->fail_)) + failure_type(std::move(other.fail_)); + assert(tmp == std::addressof(this->fail_)); + (void)tmp; + } + is_ok_ = other.is_ok(); + return *this; + } + + template < + typename U, + typename F, + cxx::enable_if_t< + cxx::conjunction, value_type>>, + cxx::negation, error_type>>, + std::is_convertible, value_type>, + std::is_convertible, error_type>>::value, + std::nullptr_t> = nullptr> + result(result other) : is_ok_(other.is_ok()) { + if (other.is_ok()) { + auto tmp = ::new (std::addressof(this->succ_)) + success_type(std::move(other.as_ok())); + assert(tmp == std::addressof(this->succ_)); + (void)tmp; + } else { + auto tmp = ::new (std::addressof(this->fail_)) + failure_type(std::move(other.as_err())); + assert(tmp == std::addressof(this->fail_)); + (void)tmp; + } + } + + template < + typename U, + typename F, + cxx::enable_if_t< + cxx::conjunction, value_type>>, + cxx::negation, error_type>>, + std::is_convertible, value_type>, + std::is_convertible, error_type>>::value, + std::nullptr_t> = nullptr> + result& operator=(result other) { + this->cleanup(); + if (other.is_ok()) { + auto tmp = ::new (std::addressof(this->succ_)) + success_type(std::move(other.as_ok())); + assert(tmp == std::addressof(this->succ_)); + (void)tmp; + } else { + auto tmp = ::new (std::addressof(this->fail_)) + failure_type(std::move(other.as_err())); + assert(tmp == std::addressof(this->fail_)); + (void)tmp; + } + is_ok_ = other.is_ok(); + return *this; + } + + bool is_ok() const noexcept { + return is_ok_; + } + + bool is_err() const noexcept { + return !is_ok_; + } + + explicit operator bool() const noexcept { + return is_ok_; + } + + value_type& unwrap(cxx::source_location loc = cxx::source_location::current()) { + if (this->is_err()) { + throw bad_result_access("toml::result: bad unwrap" + cxx::to_string(loc)); + } + return 
this->succ_.get(); + } + + const value_type& unwrap( + cxx::source_location loc = cxx::source_location::current()) const { + if (this->is_err()) { + throw bad_result_access("toml::result: bad unwrap" + cxx::to_string(loc)); + } + return this->succ_.get(); + } + + value_type& unwrap_or(value_type& opt) noexcept { + if (this->is_err()) { + return opt; + } + return this->succ_.get(); + } + + const value_type& unwrap_or(const value_type& opt) const noexcept { + if (this->is_err()) { + return opt; + } + return this->succ_.get(); + } + + error_type& unwrap_err( + cxx::source_location loc = cxx::source_location::current()) { + if (this->is_ok()) { + throw bad_result_access( + "toml::result: bad unwrap_err" + cxx::to_string(loc)); + } + return this->fail_.get(); + } + + const error_type& unwrap_err( + cxx::source_location loc = cxx::source_location::current()) const { + if (this->is_ok()) { + throw bad_result_access( + "toml::result: bad unwrap_err" + cxx::to_string(loc)); + } + return this->fail_.get(); + } + + value_type& as_ok() noexcept { + assert(this->is_ok()); + return this->succ_.get(); + } + + const value_type& as_ok() const noexcept { + assert(this->is_ok()); + return this->succ_.get(); + } + + error_type& as_err() noexcept { + assert(this->is_err()); + return this->fail_.get(); + } + + const error_type& as_err() const noexcept { + assert(this->is_err()); + return this->fail_.get(); + } + + private: + void cleanup() noexcept { +#if defined(__GNUC__) && !defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wduplicated-branches" +#endif + + if (this->is_ok_) { + this->succ_.~success_type(); + } else { + this->fail_.~failure_type(); + } + +#if defined(__GNUC__) && !defined(__clang__) + #pragma GCC diagnostic pop +#endif + return; + } + + private: + bool is_ok_; + + union { + success_type succ_; + failure_type fail_; + }; + }; + + // ---------------------------------------------------------------------------- + + namespace detail { + 
struct none_t {}; + + inline bool operator==(const none_t&, const none_t&) noexcept { + return true; + } + + inline bool operator!=(const none_t&, const none_t&) noexcept { + return false; + } + + inline bool operator<(const none_t&, const none_t&) noexcept { + return false; + } + + inline bool operator<=(const none_t&, const none_t&) noexcept { + return true; + } + + inline bool operator>(const none_t&, const none_t&) noexcept { + return false; + } + + inline bool operator>=(const none_t&, const none_t&) noexcept { + return true; + } + + inline std::ostream& operator<<(std::ostream& os, const none_t&) { + os << "none"; + return os; + } + } // namespace detail + + inline success ok() noexcept { + return success(detail::none_t {}); + } + + inline failure err() noexcept { + return failure(detail::none_t {}); + } + +} // namespace toml +#endif // TOML11_RESULT_HPP +#ifndef TOML11_UTILITY_HPP +#define TOML11_UTILITY_HPP + +#include +#include +#include +#include +#include + +namespace toml { + namespace detail { + + // to output character in an error message. 
+ inline std::string show_char(const int c) { + using char_type = unsigned char; + if (std::isgraph(c)) { + return std::string(1, static_cast(c)); + } else { + std::array buf; + buf.fill('\0'); + const auto r = std::snprintf(buf.data(), buf.size(), "0x%02x", c & 0xFF); + assert(r == static_cast(buf.size()) - 1); + (void)r; // Unused variable warning + auto in_hex = std::string(buf.data()); + switch (c) { + case char_type('\0'): { + in_hex += "(NUL)"; + break; + } + case char_type(' '): { + in_hex += "(SPACE)"; + break; + } + case char_type('\n'): { + in_hex += "(LINE FEED)"; + break; + } + case char_type('\r'): { + in_hex += "(CARRIAGE RETURN)"; + break; + } + case char_type('\t'): { + in_hex += "(TAB)"; + break; + } + case char_type('\v'): { + in_hex += "(VERTICAL TAB)"; + break; + } + case char_type('\f'): { + in_hex += "(FORM FEED)"; + break; + } + case char_type('\x1B'): { + in_hex += "(ESCAPE)"; + break; + } + default: + break; + } + return in_hex; + } + } + + // --------------------------------------------------------------------------- + + template + void try_reserve_impl(Container& container, std::size_t N, std::true_type) { + container.reserve(N); + return; + } + + template + void try_reserve_impl(Container&, std::size_t, std::false_type) noexcept { + return; + } + + template + void try_reserve(Container& container, std::size_t N) { + try_reserve_impl(container, N, has_reserve_method {}); + return; + } + + // --------------------------------------------------------------------------- + + template + result from_string(const std::string& str) { + T v; + std::istringstream iss(str); + iss >> v; + if (iss.fail()) { + return err(); + } + return ok(v); + } + + // --------------------------------------------------------------------------- + + // helper function to avoid std::string(0, 'c') or std::string(iter, iter) + template + std::string make_string(Iterator first, Iterator last) { + if (first == last) { + return ""; + } + return std::string(first, last); + } 
+ + inline std::string make_string(std::size_t len, char c) { + if (len == 0) { + return ""; + } + return std::string(len, c); + } + + // --------------------------------------------------------------------------- + + template + struct string_conv_impl { + static_assert(sizeof(Char) == sizeof(char), ""); + static_assert(sizeof(Char2) == sizeof(char), ""); + + static std::basic_string invoke( + std::basic_string s) { + std::basic_string retval; + std::transform(s.begin(), + s.end(), + std::back_inserter(retval), + [](const Char2 c) { + return static_cast(c); + }); + return retval; + } + + template + static std::basic_string invoke(const Char2 (&s)[N]) { + std::basic_string retval; + // "string literal" has null-char at the end. to skip it, we use prev. + std::transform(std::begin(s), + std::prev(std::end(s)), + std::back_inserter(retval), + [](const Char2 c) { + return static_cast(c); + }); + return retval; + } + }; + + template + struct string_conv_impl { + static_assert(sizeof(Char) == sizeof(char), ""); + + static std::basic_string invoke( + std::basic_string s) { + return s; + } + + template + static std::basic_string invoke(const Char (&s)[N]) { + return std::basic_string(s); + } + }; + + template + cxx::enable_if_t::value, S> string_conv( + std::basic_string s) { + using C = typename S::value_type; + using T = typename S::traits_type; + using A = typename S::allocator_type; + return string_conv_impl::invoke( + std::move(s)); + } + + template + cxx::enable_if_t::value, S> string_conv( + const char (&s)[N]) { + using C = typename S::value_type; + using T = typename S::traits_type; + using A = typename S::allocator_type; + using C2 = char; + using T2 = std::char_traits; + using A2 = std::allocator; + + return string_conv_impl::template invoke(s); + } + + } // namespace detail +} // namespace toml +#endif // TOML11_UTILITY_HPP +#ifndef TOML11_LOCATION_HPP +#define TOML11_LOCATION_HPP + +#ifndef TOML11_LOCATION_FWD_HPP + #define TOML11_LOCATION_FWD_HPP + + #include 
+ #include + #include + +namespace toml { + namespace detail { + + class region; // fwd decl + + // + // To represent where we are reading in the parse functions. + // Since it "points" somewhere in the input stream, the length is always 1. + // + class location { + public: + using char_type = unsigned char; // must be unsigned + using container_type = std::vector; + using difference_type = + typename container_type::difference_type; // to suppress sign-conversion warning + using source_ptr = std::shared_ptr; + + public: + location(source_ptr src, std::string src_name) + : source_(std::move(src)) + , source_name_(std::move(src_name)) + , location_(0) + , line_number_(1) {} + + location(const location&) = default; + location(location&&) = default; + location& operator=(const location&) = default; + location& operator=(location&&) = default; + ~location() = default; + + void advance(std::size_t n = 1) noexcept; + void retrace(std::size_t n = 1) noexcept; + + bool is_ok() const noexcept { + return static_cast(this->source_); + } + + bool eof() const noexcept; + char_type current() const; + + char_type peek(); + + std::size_t get_location() const noexcept { + return this->location_; + } + + void set_location(const std::size_t loc) noexcept; + + std::size_t line_number() const noexcept { + return this->line_number_; + } + + std::string get_line() const; + std::size_t column_number() const noexcept; + + const source_ptr& source() const noexcept { + return this->source_; + } + + const std::string& source_name() const noexcept { + return this->source_name_; + } + + private: + void advance_line_number(const std::size_t n); + void retrace_line_number(const std::size_t n); + + private: + friend region; + + private: + source_ptr source_; + std::string source_name_; + std::size_t location_; // std::vector<>::difference_type is signed + std::size_t line_number_; + }; + + bool operator==(const location& lhs, const location& rhs) noexcept; + bool operator!=(const location& lhs, 
const location& rhs); + + location prev(const location& loc); + location next(const location& loc); + location make_temporary_location(const std::string& str) noexcept; + + template + result find_if(const location& first, + const location& last, + const F& func) noexcept { + if (first.source() != last.source()) { + return err(); + } + if (first.get_location() >= last.get_location()) { + return err(); + } + + auto loc = first; + while (loc.get_location() != last.get_location()) { + if (func(loc.current())) { + return ok(loc); + } + loc.advance(); + } + return err(); + } + + template + result rfind_if(location first, + const location& last, + const F& func) { + if (first.source() != last.source()) { + return err(); + } + if (first.get_location() >= last.get_location()) { + return err(); + } + + auto loc = last; + while (loc.get_location() != first.get_location()) { + if (func(loc.current())) { + return ok(loc); + } + loc.retrace(); + } + if (func(first.current())) { + return ok(first); + } + return err(); + } + + result find(const location& first, + const location& last, + const location::char_type val); + result rfind(const location& first, + const location& last, + const location::char_type val); + + std::size_t count(const location& first, + const location& last, + const location::char_type& c); + + } // namespace detail +} // namespace toml +#endif // TOML11_LOCATION_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_LOCATION_IMPL_HPP + #define TOML11_LOCATION_IMPL_HPP + +namespace toml { + namespace detail { + + TOML11_INLINE void location::advance(std::size_t n) noexcept { + assert(this->is_ok()); + if (this->location_ + n < this->source_->size()) { + this->advance_line_number(n); + this->location_ += n; + } else { + this->advance_line_number(this->source_->size() - this->location_); + this->location_ = this->source_->size(); + } + } + + TOML11_INLINE void location::retrace(std::size_t n) noexcept { + assert(this->is_ok()); + if (this->location_ < 
n) { + this->location_ = 0; + this->line_number_ = 1; + } else { + this->retrace_line_number(n); + this->location_ -= n; + } + } + + TOML11_INLINE bool location::eof() const noexcept { + assert(this->is_ok()); + return this->location_ >= this->source_->size(); + } + + TOML11_INLINE location::char_type location::current() const { + assert(this->is_ok()); + if (this->eof()) { + return '\0'; + } + + assert(this->location_ < this->source_->size()); + return this->source_->at(this->location_); + } + + TOML11_INLINE location::char_type location::peek() { + assert(this->is_ok()); + if (this->location_ >= this->source_->size()) { + return '\0'; + } else { + return this->source_->at(this->location_ + 1); + } + } + + TOML11_INLINE void location::set_location(const std::size_t loc) noexcept { + if (this->location_ == loc) { + return; + } + + if (loc == 0) { + this->line_number_ = 1; + } else if (this->location_ < loc) { + const auto d = loc - this->location_; + this->advance_line_number(d); + } else { + const auto d = this->location_ - loc; + this->retrace_line_number(d); + } + this->location_ = loc; + } + + TOML11_INLINE std::string location::get_line() const { + assert(this->is_ok()); + const auto iter = std::next(this->source_->cbegin(), + static_cast(this->location_)); + const auto riter = cxx::make_reverse_iterator(iter); + + const auto prev = std::find(riter, this->source_->crend(), char_type('\n')); + const auto next = std::find(iter, this->source_->cend(), char_type('\n')); + + return make_string(std::next(prev.base()), next); + } + + TOML11_INLINE std::size_t location::column_number() const noexcept { + assert(this->is_ok()); + const auto iter = std::next(this->source_->cbegin(), + static_cast(this->location_)); + const auto riter = cxx::make_reverse_iterator(iter); + const auto prev = std::find(riter, this->source_->crend(), char_type('\n')); + + assert(prev.base() <= iter); + return static_cast(std::distance(prev.base(), iter) + 1); // 1-origin + } + + 
TOML11_INLINE void location::advance_line_number(const std::size_t n) { + assert(this->is_ok()); + assert(this->location_ + n <= this->source_->size()); + + const auto iter = this->source_->cbegin(); + this->line_number_ += static_cast(std::count( + std::next(iter, static_cast(this->location_)), + std::next(iter, static_cast(this->location_ + n)), + char_type('\n'))); + + return; + } + + TOML11_INLINE void location::retrace_line_number(const std::size_t n) { + assert(this->is_ok()); + assert(n <= this->location_); // loc - n >= 0 + + const auto iter = this->source_->cbegin(); + const auto dline_num = static_cast(std::count( + std::next(iter, static_cast(this->location_ - n)), + std::next(iter, static_cast(this->location_)), + char_type('\n'))); + + if (this->line_number_ <= dline_num) { + this->line_number_ = 1; + } else { + this->line_number_ -= dline_num; + } + return; + } + + TOML11_INLINE bool operator==(const location& lhs, const location& rhs) noexcept { + if (!lhs.is_ok() || !rhs.is_ok()) { + return (!lhs.is_ok()) && (!rhs.is_ok()); + } + return lhs.source() == rhs.source() && + lhs.source_name() == rhs.source_name() && + lhs.get_location() == rhs.get_location(); + } + + TOML11_INLINE bool operator!=(const location& lhs, const location& rhs) { + return !(lhs == rhs); + } + + TOML11_INLINE location prev(const location& loc) { + location p(loc); + p.retrace(1); + return p; + } + + TOML11_INLINE location next(const location& loc) { + location p(loc); + p.advance(1); + return p; + } + + TOML11_INLINE location make_temporary_location(const std::string& str) noexcept { + location::container_type cont(str.size()); + std::transform(str.begin(), + str.end(), + cont.begin(), + [](const std::string::value_type& c) { + return cxx::bit_cast(c); + }); + return location( + std::make_shared(std::move(cont)), + "internal temporary"); + } + + TOML11_INLINE result find(const location& first, + const location& last, + const location::char_type val) { + return find_if(first, 
last, [val](const location::char_type c) { + return c == val; + }); + } + + TOML11_INLINE result rfind(const location& first, + const location& last, + const location::char_type val) { + return rfind_if(first, last, [val](const location::char_type c) { + return c == val; + }); + } + + TOML11_INLINE std::size_t count(const location& first, + const location& last, + const location::char_type& c) { + if (first.source() != last.source()) { + return 0; + } + if (first.get_location() >= last.get_location()) { + return 0; + } + + auto loc = first; + std::size_t num = 0; + while (loc.get_location() != last.get_location()) { + if (loc.current() == c) { + num += 1; + } + loc.advance(); + } + return num; + } + + } // namespace detail +} // namespace toml + #endif // TOML11_LOCATION_HPP +#endif + +#endif // TOML11_LOCATION_HPP +#ifndef TOML11_REGION_HPP +#define TOML11_REGION_HPP + +#ifndef TOML11_REGION_FWD_HPP + #define TOML11_REGION_FWD_HPP + + #include + #include + #include + +namespace toml { + namespace detail { + + // + // To represent where is a toml::value defined, or where does an error occur. + // Stored in toml::value. source_location will be constructed based on this. + // + class region { + public: + using char_type = location::char_type; + using container_type = location::container_type; + using difference_type = location::difference_type; + using source_ptr = location::source_ptr; + + using iterator = typename container_type::iterator; + using const_iterator = typename container_type::const_iterator; + + public: + // a value that is constructed manually does not have input stream info + region() + : source_(nullptr) + , source_name_("") + , length_(0) + , first_line_(0) + , first_column_(0) + , last_line_(0) + , last_column_(0) {} + + // a value defined in [first, last). + // Those source must be the same. Instread, `region` does not make sense. 
+ region(const location& first, const location& last); + + // shorthand of [loc, loc+1) + explicit region(const location& loc); + + ~region() = default; + region(const region&) = default; + region(region&&) = default; + region& operator=(const region&) = default; + region& operator=(region&&) = default; + + bool is_ok() const noexcept { + return static_cast(this->source_); + } + + operator bool() const noexcept { + return this->is_ok(); + } + + std::size_t length() const noexcept { + return this->length_; + } + + std::size_t first_line_number() const noexcept { + return this->first_line_; + } + + std::size_t first_column_number() const noexcept { + return this->first_column_; + } + + std::size_t last_line_number() const noexcept { + return this->last_line_; + } + + std::size_t last_column_number() const noexcept { + return this->last_column_; + } + + char_type at(std::size_t i) const; + + const_iterator begin() const noexcept; + const_iterator end() const noexcept; + const_iterator cbegin() const noexcept; + const_iterator cend() const noexcept; + + std::string as_string() const; + std::vector as_lines() const; + + const source_ptr& source() const noexcept { + return this->source_; + } + + const std::string& source_name() const noexcept { + return this->source_name_; + } + + private: + source_ptr source_; + std::string source_name_; + std::size_t length_; + std::size_t first_; + std::size_t first_line_; + std::size_t first_column_; + std::size_t last_; + std::size_t last_line_; + std::size_t last_column_; + }; + + } // namespace detail +} // namespace toml +#endif // TOML11_REGION_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_REGION_IMPL_HPP + #define TOML11_REGION_IMPL_HPP + + #include + #include + #include + #include + #include + #include + +namespace toml { + namespace detail { + + // a value defined in [first, last). + // Those source must be the same. Instread, `region` does not make sense. 
+ TOML11_INLINE region::region(const location& first, const location& last) + : source_(first.source()) + , source_name_(first.source_name()) + , length_(last.get_location() - first.get_location()) + , first_(first.get_location()) + , first_line_(first.line_number()) + , first_column_(first.column_number()) + , last_(last.get_location()) + , last_line_(last.line_number()) + , last_column_(last.column_number()) { + assert(first.source() == last.source()); + assert(first.source_name() == last.source_name()); + } + + // shorthand of [loc, loc+1) + TOML11_INLINE region::region(const location& loc) + : source_(loc.source()) + , source_name_(loc.source_name()) + , length_(0) + , first_line_(0) + , first_column_(0) + , last_line_(0) + , last_column_(0) { + // if the file ends with LF, the resulting region points no char. + if (loc.eof()) { + if (loc.get_location() == 0) { + this->length_ = 0; + this->first_ = 0; + this->first_line_ = 0; + this->first_column_ = 0; + this->last_ = 0; + this->last_line_ = 0; + this->last_column_ = 0; + } else { + const auto first = prev(loc); + this->first_ = first.get_location(); + this->first_line_ = first.line_number(); + this->first_column_ = first.column_number(); + this->last_ = loc.get_location(); + this->last_line_ = loc.line_number(); + this->last_column_ = loc.column_number(); + this->length_ = 1; + } + } else { + this->first_ = loc.get_location(); + this->first_line_ = loc.line_number(); + this->first_column_ = loc.column_number(); + this->last_ = loc.get_location() + 1; + this->last_line_ = loc.line_number(); + this->last_column_ = loc.column_number() + 1; + this->length_ = 1; + } + } + + TOML11_INLINE region::char_type region::at(std::size_t i) const { + if (this->last_ <= this->first_ + i) { + throw std::out_of_range("range::at: index " + std::to_string(i) + + " exceeds length " + std::to_string(this->length_)); + } + const auto iter = std::next(this->source_->cbegin(), + static_cast(this->first_ + i)); + return *iter; + } + + 
TOML11_INLINE region::const_iterator region::begin() const noexcept { + return std::next(this->source_->cbegin(), + static_cast(this->first_)); + } + + TOML11_INLINE region::const_iterator region::end() const noexcept { + return std::next(this->source_->cbegin(), + static_cast(this->last_)); + } + + TOML11_INLINE region::const_iterator region::cbegin() const noexcept { + return std::next(this->source_->cbegin(), + static_cast(this->first_)); + } + + TOML11_INLINE region::const_iterator region::cend() const noexcept { + return std::next(this->source_->cbegin(), + static_cast(this->last_)); + } + + TOML11_INLINE std::string region::as_string() const { + if (this->is_ok()) { + const auto begin = std::next(this->source_->cbegin(), + static_cast(this->first_)); + const auto end = std::next(this->source_->cbegin(), + static_cast(this->last_)); + return ::toml::detail::make_string(begin, end); + } else { + return std::string(""); + } + } + + TOML11_INLINE std::vector region::as_lines() const { + assert(this->is_ok()); + if (this->length_ == 0) { + return std::vector { "" }; + } + + // Consider the following toml file + // ``` + // array = [ + // ] # comment + // ``` + // and the region represnets + // ``` + // [ + // ] + // ``` + // but we want to show the following. + // ``` + // array = [ + // ] # comment + // ``` + // So we need to find LFs before `begin` and after `end`. + // + // But, if region ends with LF, it should not include the next line. + // ``` + // a = 42 + // ^^^- with the last LF + // ``` + // So we start from `end-1` when looking for LF. + + const auto begin_idx = static_cast(this->first_); + const auto end_idx = static_cast(this->last_) - 1; + + // length_ != 0, so begin < end. 
then begin <= end-1 + assert(begin_idx <= end_idx); + + const auto begin = std::next(this->source_->cbegin(), begin_idx); + const auto end = std::next(this->source_->cbegin(), end_idx); + + const auto line_begin = std::find(cxx::make_reverse_iterator(begin), + this->source_->crend(), + char_type('\n')) + .base(); + const auto line_end = std::find(end, this->source_->cend(), char_type('\n')); + + const auto reg_lines = make_string(line_begin, line_end); + + if (reg_lines == "") // the region is an empty line that only contains LF + { + return std::vector { "" }; + } + + std::istringstream iss(reg_lines); + + std::vector lines; + std::string line; + while (std::getline(iss, line)) { + lines.push_back(line); + } + return lines; + } + + } // namespace detail +} // namespace toml + #endif // TOML11_REGION_IMPL_HPP +#endif + +#endif // TOML11_REGION_HPP +#ifndef TOML11_SOURCE_LOCATION_HPP +#define TOML11_SOURCE_LOCATION_HPP + +#ifndef TOML11_SOURCE_LOCATION_FWD_HPP + #define TOML11_SOURCE_LOCATION_FWD_HPP + + #include + #include + #include + +namespace toml { + + // A struct to contain location in a toml file. 
+ struct source_location { + public: + explicit source_location(const detail::region& r); + ~source_location() = default; + source_location(const source_location&) = default; + source_location(source_location&&) = default; + source_location& operator=(const source_location&) = default; + source_location& operator=(source_location&&) = default; + + bool is_ok() const noexcept { + return this->is_ok_; + } + + std::size_t length() const noexcept { + return this->length_; + } + + std::size_t first_line_number() const noexcept { + return this->first_line_; + } + + std::size_t first_column_number() const noexcept { + return this->first_column_; + } + + std::size_t last_line_number() const noexcept { + return this->last_line_; + } + + std::size_t last_column_number() const noexcept { + return this->last_column_; + } + + const std::string& file_name() const noexcept { + return this->file_name_; + } + + std::size_t num_lines() const noexcept { + return this->line_str_.size(); + } + + const std::string& first_line() const; + const std::string& last_line() const; + + const std::vector& lines() const noexcept { + return line_str_; + } + + private: + bool is_ok_; + std::size_t first_line_; + std::size_t first_column_; + std::size_t last_line_; + std::size_t last_column_; + std::size_t length_; + std::string file_name_; + std::vector line_str_; + }; + + namespace detail { + + std::size_t integer_width_base10(std::size_t i) noexcept; + + inline std::size_t line_width() noexcept { + return 0; + } + + template + std::size_t line_width(const source_location& loc, + const std::string& /*msg*/, + const Ts&... 
tail) noexcept { + return (std::max)(integer_width_base10(loc.last_line_number()), + line_width(tail...)); + } + + std::ostringstream& format_filename(std::ostringstream& oss, + const source_location& loc); + + std::ostringstream& format_empty_line(std::ostringstream& oss, + const std::size_t lnw); + + std::ostringstream& format_line(std::ostringstream& oss, + const std::size_t lnw, + const std::size_t linenum, + const std::string& line); + + std::ostringstream& format_underline(std::ostringstream& oss, + const std::size_t lnw, + const std::size_t col, + const std::size_t len, + const std::string& msg); + + std::string format_location_impl(const std::size_t lnw, + const std::string& prev_fname, + const source_location& loc, + const std::string& msg); + + inline std::string format_location_rec(const std::size_t, const std::string&) { + return ""; + } + + template + std::string format_location_rec(const std::size_t lnw, + const std::string& prev_fname, + const source_location& loc, + const std::string& msg, + const Ts&... tail) { + return format_location_impl(lnw, prev_fname, loc, msg) + + format_location_rec(lnw, loc.file_name(), tail...); + } + + } // namespace detail + + // format a location info without title + template + std::string format_location(const source_location& loc, + const std::string& msg, + const Ts&... 
tail) { + const auto lnw = detail::line_width(loc, msg, tail...); + + const std::string f(""); // at the 1st iteration, no prev_filename is given + return detail::format_location_rec(lnw, f, loc, msg, tail...); + } + +} // namespace toml +#endif // TOML11_SOURCE_LOCATION_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_SOURCE_LOCATION_IMPL_HPP + #define TOML11_SOURCE_LOCATION_IMPL_HPP + + #include + #include + #include + #include + #include + +namespace toml { + + TOML11_INLINE source_location::source_location(const detail::region& r) + : is_ok_(false) + , first_line_(1) + , first_column_(1) + , last_line_(1) + , last_column_(1) + , length_(0) + , file_name_("unknown file") { + if (r.is_ok()) { + this->is_ok_ = true; + this->file_name_ = r.source_name(); + this->first_line_ = r.first_line_number(); + this->first_column_ = r.first_column_number(); + this->last_line_ = r.last_line_number(); + this->last_column_ = r.last_column_number(); + this->length_ = r.length(); + this->line_str_ = r.as_lines(); + } + } + + TOML11_INLINE const std::string& source_location::first_line() const { + if (this->line_str_.size() == 0) { + throw std::out_of_range( + "toml::source_location::first_line: `lines` is empty"); + } + return this->line_str_.front(); + } + + TOML11_INLINE const std::string& source_location::last_line() const { + if (this->line_str_.size() == 0) { + throw std::out_of_range( + "toml::source_location::first_line: `lines` is empty"); + } + return this->line_str_.back(); + } + + namespace detail { + + TOML11_INLINE std::size_t integer_width_base10(std::size_t i) noexcept { + std::size_t width = 0; + while (i != 0) { + i /= 10; + width += 1; + } + return width; + } + + TOML11_INLINE std::ostringstream& format_filename(std::ostringstream& oss, + const source_location& loc) { + // --> example.toml + oss << color::bold << color::blue << " --> " << color::reset + << color::bold << loc.file_name() << '\n' + << color::reset; + return oss; + } + + 
TOML11_INLINE std::ostringstream& format_empty_line(std::ostringstream& oss, + const std::size_t lnw) { + // | + oss << detail::make_string(lnw + 1, ' ') << color::bold << color::blue + << " |\n" + << color::reset; + return oss; + } + + TOML11_INLINE std::ostringstream& format_line(std::ostringstream& oss, + const std::size_t lnw, + const std::size_t linenum, + const std::string& line) { + // 10 | key = "value" + oss << ' ' << color::bold << color::blue << std::setw(static_cast(lnw)) + << std::right << linenum << " | " << color::reset; + for (const char c : line) { + if (std::isgraph(c) || c == ' ') { + oss << c; + } else { + oss << show_char(c); + } + } + oss << '\n'; + return oss; + } + + TOML11_INLINE std::ostringstream& format_underline(std::ostringstream& oss, + const std::size_t lnw, + const std::size_t col, + const std::size_t len, + const std::string& msg) { + // | ^^^^^^^-- this part + oss << make_string(lnw + 1, ' ') << color::bold << color::blue << " | " + << color::reset; + + oss << make_string(col - 1 /*1-origin*/, ' ') << color::bold << color::red + << make_string(len, '^') << "-- " << color::reset << msg << '\n'; + + return oss; + } + + TOML11_INLINE std::string format_location_impl(const std::size_t lnw, + const std::string& prev_fname, + const source_location& loc, + const std::string& msg) { + std::ostringstream oss; + + if (loc.file_name() != prev_fname) { + format_filename(oss, loc); + if (!loc.lines().empty()) { + format_empty_line(oss, lnw); + } + } + + if (loc.lines().size() == 1) { + // when column points LF, it exceeds the size of the first line. 
+ std::size_t underline_limit = 1; + if (loc.first_line().size() < loc.first_column_number()) { + underline_limit = 1; + } else { + underline_limit = loc.first_line().size() - loc.first_column_number() + 1; + } + const auto underline_len = (std::min)(underline_limit, loc.length()); + + format_line(oss, lnw, loc.first_line_number(), loc.first_line()); + format_underline(oss, lnw, loc.first_column_number(), underline_len, msg); + } else if (loc.lines().size() == 2) { + const auto first_underline_len = loc.first_line().size() - + loc.first_column_number() + 1; + format_line(oss, lnw, loc.first_line_number(), loc.first_line()); + format_underline(oss, lnw, loc.first_column_number(), first_underline_len, ""); + + format_line(oss, lnw, loc.last_line_number(), loc.last_line()); + format_underline(oss, lnw, 1, loc.last_column_number(), msg); + } else if (loc.lines().size() > 2) { + const auto first_underline_len = loc.first_line().size() - + loc.first_column_number() + 1; + format_line(oss, lnw, loc.first_line_number(), loc.first_line()); + format_underline(oss, lnw, loc.first_column_number(), first_underline_len, "and"); + + if (loc.lines().size() == 3) { + format_line(oss, lnw, loc.first_line_number() + 1, loc.lines().at(1)); + format_underline(oss, lnw, 1, loc.lines().at(1).size(), "and"); + } else { + format_line(oss, lnw, loc.first_line_number() + 1, " ..."); + format_empty_line(oss, lnw); + } + format_line(oss, lnw, loc.last_line_number(), loc.last_line()); + format_underline(oss, lnw, 1, loc.last_column_number(), msg); + } + // if loc is empty, do nothing. + return oss.str(); + } + + } // namespace detail +} // namespace toml + #endif // TOML11_SOURCE_LOCATION_IMPL_HPP +#endif + +#endif // TOML11_SOURCE_LOCATION_HPP +#ifndef TOML11_ERROR_INFO_HPP +#define TOML11_ERROR_INFO_HPP + +#ifndef TOML11_ERROR_INFO_FWD_HPP + #define TOML11_ERROR_INFO_FWD_HPP + +namespace toml { + + // error info returned from parser. 
+ struct error_info { + error_info(std::string t, source_location l, std::string m, std::string s = "") + : title_(std::move(t)) + , locations_ { std::make_pair(std::move(l), std::move(m)) } + , suffix_(std::move(s)) {} + + error_info(std::string t, + std::vector> l, + std::string s = "") + : title_(std::move(t)) + , locations_(std::move(l)) + , suffix_(std::move(s)) {} + + const std::string& title() const noexcept { + return title_; + } + + std::string& title() noexcept { + return title_; + } + + const std::vector>& locations() const noexcept { + return locations_; + } + + void add_locations(source_location loc, std::string msg) noexcept { + locations_.emplace_back(std::move(loc), std::move(msg)); + } + + const std::string& suffix() const noexcept { + return suffix_; + } + + std::string& suffix() noexcept { + return suffix_; + } + + private: + std::string title_; + std::vector> locations_; + std::string suffix_; // hint or something like that + }; + + // forward decl + template + class basic_value; + + namespace detail { + inline error_info make_error_info_rec(error_info e) { + return e; + } + + inline error_info make_error_info_rec(error_info e, std::string s) { + e.suffix() = s; + return e; + } + + template + error_info make_error_info_rec(error_info e, + const basic_value& v, + std::string msg, + Ts&&... tail); + + template + error_info make_error_info_rec(error_info e, + source_location loc, + std::string msg, + Ts&&... tail) { + e.add_locations(std::move(loc), std::move(msg)); + return make_error_info_rec(std::move(e), std::forward(tail)...); + } + + } // namespace detail + + template + error_info make_error_info(std::string title, + source_location loc, + std::string msg, + Ts&&... 
tail) { + error_info ei(std::move(title), std::move(loc), std::move(msg)); + return detail::make_error_info_rec(ei, std::forward(tail)...); + } + + std::string format_error(const std::string& errkind, const error_info& err); + std::string format_error(const error_info& err); + + // for custom error message + template + std::string format_error(std::string title, + source_location loc, + std::string msg, + Ts&&... tail) { + return format_error("", + make_error_info(std::move(title), + std::move(loc), + std::move(msg), + std::forward(tail)...)); + } + + std::ostream& operator<<(std::ostream& os, const error_info& e); + +} // namespace toml +#endif // TOML11_ERROR_INFO_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_ERROR_INFO_IMPL_HPP + #define TOML11_ERROR_INFO_IMPL_HPP + + #include + +namespace toml { + + TOML11_INLINE std::string format_error(const std::string& errkind, + const error_info& err) { + std::string errmsg; + if (!errkind.empty()) { + errmsg = errkind; + errmsg += ' '; + } + errmsg += err.title(); + errmsg += '\n'; + + const auto lnw = [&err]() { + std::size_t width = 0; + for (const auto& l : err.locations()) { + width = (std::max)(detail::integer_width_base10(l.first.last_line_number()), + width); + } + return width; + }(); + + bool first = true; + std::string prev_fname; + for (const auto& lm : err.locations()) { + if (!first) { + std::ostringstream oss; + oss << detail::make_string(lnw + 1, ' ') << color::bold << color::blue + << " |" << color::reset << color::bold << " ...\n" + << color::reset; + oss << detail::make_string(lnw + 1, ' ') << color::bold << color::blue + << " |\n" + << color::reset; + errmsg += oss.str(); + } + + const auto& l = lm.first; + const auto& m = lm.second; + + errmsg += detail::format_location_impl(lnw, prev_fname, l, m); + + prev_fname = l.file_name(); + first = false; + } + + errmsg += err.suffix(); + + return errmsg; + } + + TOML11_INLINE std::string format_error(const error_info& err) { + 
std::ostringstream oss; + oss << color::red << color::bold << "[error]" << color::reset; + return format_error(oss.str(), err); + } + + TOML11_INLINE std::ostream& operator<<(std::ostream& os, const error_info& e) { + os << format_error(e); + return os; + } + +} // namespace toml + #endif // TOML11_ERROR_INFO_IMPL_HPP +#endif + +#endif // TOML11_ERROR_INFO_HPP +#ifndef TOML11_VALUE_HPP +#define TOML11_VALUE_HPP + +#ifdef TOML11_HAS_STRING_VIEW + #include +#endif + +#include + +namespace toml { + template + class basic_value; + + struct type_error final : public ::toml::exception { + public: + type_error(std::string what_arg, source_location loc) + : what_(std::move(what_arg)) + , loc_(std::move(loc)) {} + + ~type_error() noexcept override = default; + + const char* what() const noexcept override { + return what_.c_str(); + } + + const source_location& location() const noexcept { + return loc_; + } + + private: + std::string what_; + source_location loc_; + }; + + // only for internal use + namespace detail { + template + error_info make_type_error(const basic_value&, + const std::string&, + const value_t); + + template + error_info make_not_found_error(const basic_value&, + const std::string&, + const typename basic_value::key_type&); + + template + void change_region_of_value(basic_value&, const basic_value&); + + template + struct getter; + } // namespace detail + + template + class basic_value { + public: + using config_type = TypeConfig; + using key_type = typename config_type::string_type; + using value_type = basic_value; + using boolean_type = typename config_type::boolean_type; + using integer_type = typename config_type::integer_type; + using floating_type = typename config_type::floating_type; + using string_type = typename config_type::string_type; + using local_time_type = ::toml::local_time; + using local_date_type = ::toml::local_date; + using local_datetime_type = ::toml::local_datetime; + using offset_datetime_type = ::toml::offset_datetime; + using 
array_type = typename config_type::template array_type; + using table_type = typename config_type::template table_type; + using comment_type = typename config_type::comment_type; + using char_type = typename string_type::value_type; + + private: + using region_type = detail::region; + + public: + basic_value() noexcept + : type_(value_t::empty) + , empty_('\0') + , region_ {} + , comments_ {} {} + + ~basic_value() noexcept { + this->cleanup(); + } + + // copy/move constructor/assigner ===================================== {{{ + + basic_value(const basic_value& v) + : type_(v.type_) + , region_(v.region_) + , comments_(v.comments_) { + switch (this->type_) { + case value_t::boolean: + assigner(boolean_, v.boolean_); + break; + case value_t::integer: + assigner(integer_, v.integer_); + break; + case value_t::floating: + assigner(floating_, v.floating_); + break; + case value_t::string: + assigner(string_, v.string_); + break; + case value_t::offset_datetime: + assigner(offset_datetime_, v.offset_datetime_); + break; + case value_t::local_datetime: + assigner(local_datetime_, v.local_datetime_); + break; + case value_t::local_date: + assigner(local_date_, v.local_date_); + break; + case value_t::local_time: + assigner(local_time_, v.local_time_); + break; + case value_t::array: + assigner(array_, v.array_); + break; + case value_t::table: + assigner(table_, v.table_); + break; + default: + assigner(empty_, '\0'); + break; + } + } + + basic_value(basic_value&& v) + : type_(v.type()) + , region_(std::move(v.region_)) + , comments_(std::move(v.comments_)) { + switch (this->type_) { + case value_t::boolean: + assigner(boolean_, std::move(v.boolean_)); + break; + case value_t::integer: + assigner(integer_, std::move(v.integer_)); + break; + case value_t::floating: + assigner(floating_, std::move(v.floating_)); + break; + case value_t::string: + assigner(string_, std::move(v.string_)); + break; + case value_t::offset_datetime: + assigner(offset_datetime_, 
std::move(v.offset_datetime_)); + break; + case value_t::local_datetime: + assigner(local_datetime_, std::move(v.local_datetime_)); + break; + case value_t::local_date: + assigner(local_date_, std::move(v.local_date_)); + break; + case value_t::local_time: + assigner(local_time_, std::move(v.local_time_)); + break; + case value_t::array: + assigner(array_, std::move(v.array_)); + break; + case value_t::table: + assigner(table_, std::move(v.table_)); + break; + default: + assigner(empty_, '\0'); + break; + } + } + + basic_value& operator=(const basic_value& v) { + if (this == std::addressof(v)) { + return *this; + } + + this->cleanup(); + this->type_ = v.type_; + this->region_ = v.region_; + this->comments_ = v.comments_; + switch (this->type_) { + case value_t::boolean: + assigner(boolean_, v.boolean_); + break; + case value_t::integer: + assigner(integer_, v.integer_); + break; + case value_t::floating: + assigner(floating_, v.floating_); + break; + case value_t::string: + assigner(string_, v.string_); + break; + case value_t::offset_datetime: + assigner(offset_datetime_, v.offset_datetime_); + break; + case value_t::local_datetime: + assigner(local_datetime_, v.local_datetime_); + break; + case value_t::local_date: + assigner(local_date_, v.local_date_); + break; + case value_t::local_time: + assigner(local_time_, v.local_time_); + break; + case value_t::array: + assigner(array_, v.array_); + break; + case value_t::table: + assigner(table_, v.table_); + break; + default: + assigner(empty_, '\0'); + break; + } + return *this; + } + + basic_value& operator=(basic_value&& v) { + if (this == std::addressof(v)) { + return *this; + } + + this->cleanup(); + this->type_ = v.type_; + this->region_ = std::move(v.region_); + this->comments_ = std::move(v.comments_); + switch (this->type_) { + case value_t::boolean: + assigner(boolean_, std::move(v.boolean_)); + break; + case value_t::integer: + assigner(integer_, std::move(v.integer_)); + break; + case value_t::floating: + 
assigner(floating_, std::move(v.floating_)); + break; + case value_t::string: + assigner(string_, std::move(v.string_)); + break; + case value_t::offset_datetime: + assigner(offset_datetime_, std::move(v.offset_datetime_)); + break; + case value_t::local_datetime: + assigner(local_datetime_, std::move(v.local_datetime_)); + break; + case value_t::local_date: + assigner(local_date_, std::move(v.local_date_)); + break; + case value_t::local_time: + assigner(local_time_, std::move(v.local_time_)); + break; + case value_t::array: + assigner(array_, std::move(v.array_)); + break; + case value_t::table: + assigner(table_, std::move(v.table_)); + break; + default: + assigner(empty_, '\0'); + break; + } + return *this; + } + + // }}} + + // constructor to overwrite commnets ================================== {{{ + + basic_value(basic_value v, std::vector com) + : type_(v.type()) + , region_(std::move(v.region_)) + , comments_(std::move(com)) { + switch (this->type_) { + case value_t::boolean: + assigner(boolean_, std::move(v.boolean_)); + break; + case value_t::integer: + assigner(integer_, std::move(v.integer_)); + break; + case value_t::floating: + assigner(floating_, std::move(v.floating_)); + break; + case value_t::string: + assigner(string_, std::move(v.string_)); + break; + case value_t::offset_datetime: + assigner(offset_datetime_, std::move(v.offset_datetime_)); + break; + case value_t::local_datetime: + assigner(local_datetime_, std::move(v.local_datetime_)); + break; + case value_t::local_date: + assigner(local_date_, std::move(v.local_date_)); + break; + case value_t::local_time: + assigner(local_time_, std::move(v.local_time_)); + break; + case value_t::array: + assigner(array_, std::move(v.array_)); + break; + case value_t::table: + assigner(table_, std::move(v.table_)); + break; + default: + assigner(empty_, '\0'); + break; + } + } + + // }}} + + // conversion between different basic_values ========================== {{{ + + template + basic_value(basic_value 
other) + : type_(other.type_) + , region_(std::move(other.region_)) + , comments_(std::move(other.comments_)) { + switch (other.type_) { + // use auto-convert in constructor + case value_t::boolean: + assigner(boolean_, std::move(other.boolean_)); + break; + case value_t::integer: + assigner(integer_, std::move(other.integer_)); + break; + case value_t::floating: + assigner(floating_, std::move(other.floating_)); + break; + case value_t::string: + assigner(string_, std::move(other.string_)); + break; + case value_t::offset_datetime: + assigner(offset_datetime_, std::move(other.offset_datetime_)); + break; + case value_t::local_datetime: + assigner(local_datetime_, std::move(other.local_datetime_)); + break; + case value_t::local_date: + assigner(local_date_, std::move(other.local_date_)); + break; + case value_t::local_time: + assigner(local_time_, std::move(other.local_time_)); + break; + + // may have different container type + case value_t::array: { + array_type tmp(std::make_move_iterator(other.array_.value.get().begin()), + std::make_move_iterator(other.array_.value.get().end())); + assigner(array_, + array_storage(detail::storage(std::move(tmp)), + other.array_.format)); + break; + } + case value_t::table: { + table_type tmp(std::make_move_iterator(other.table_.value.get().begin()), + std::make_move_iterator(other.table_.value.get().end())); + assigner(table_, + table_storage(detail::storage(std::move(tmp)), + other.table_.format)); + break; + } + default: + break; + } + } + + template + basic_value(basic_value other, std::vector com) + : type_(other.type_) + , region_(std::move(other.region_)) + , comments_(std::move(com)) { + switch (other.type_) { + // use auto-convert in constructor + case value_t::boolean: + assigner(boolean_, std::move(other.boolean_)); + break; + case value_t::integer: + assigner(integer_, std::move(other.integer_)); + break; + case value_t::floating: + assigner(floating_, std::move(other.floating_)); + break; + case value_t::string: + 
assigner(string_, std::move(other.string_)); + break; + case value_t::offset_datetime: + assigner(offset_datetime_, std::move(other.offset_datetime_)); + break; + case value_t::local_datetime: + assigner(local_datetime_, std::move(other.local_datetime_)); + break; + case value_t::local_date: + assigner(local_date_, std::move(other.local_date_)); + break; + case value_t::local_time: + assigner(local_time_, std::move(other.local_time_)); + break; + + // may have different container type + case value_t::array: { + array_type tmp(std::make_move_iterator(other.array_.value.get().begin()), + std::make_move_iterator(other.array_.value.get().end())); + assigner(array_, + array_storage(detail::storage(std::move(tmp)), + other.array_.format)); + break; + } + case value_t::table: { + table_type tmp(std::make_move_iterator(other.table_.value.get().begin()), + std::make_move_iterator(other.table_.value.get().end())); + assigner(table_, + table_storage(detail::storage(std::move(tmp)), + other.table_.format)); + break; + } + default: + break; + } + } + + template + basic_value& operator=(basic_value other) { + this->cleanup(); + this->region_ = other.region_; + this->comments_ = comment_type(other.comments_); + this->type_ = other.type_; + switch (other.type_) { + // use auto-convert in constructor + case value_t::boolean: + assigner(boolean_, std::move(other.boolean_)); + break; + case value_t::integer: + assigner(integer_, std::move(other.integer_)); + break; + case value_t::floating: + assigner(floating_, std::move(other.floating_)); + break; + case value_t::string: + assigner(string_, std::move(other.string_)); + break; + case value_t::offset_datetime: + assigner(offset_datetime_, std::move(other.offset_datetime_)); + break; + case value_t::local_datetime: + assigner(local_datetime_, std::move(other.local_datetime_)); + break; + case value_t::local_date: + assigner(local_date_, std::move(other.local_date_)); + break; + case value_t::local_time: + assigner(local_time_, 
std::move(other.local_time_)); + break; + + // may have different container type + case value_t::array: { + array_type tmp(std::make_move_iterator(other.array_.value.get().begin()), + std::make_move_iterator(other.array_.value.get().end())); + assigner(array_, + array_storage(detail::storage(std::move(tmp)), + other.array_.format)); + break; + } + case value_t::table: { + table_type tmp(std::make_move_iterator(other.table_.value.get().begin()), + std::make_move_iterator(other.table_.value.get().end())); + assigner(table_, + table_storage(detail::storage(std::move(tmp)), + other.table_.format)); + break; + } + default: + break; + } + return *this; + } + + // }}} + + // constructor (boolean) ============================================== {{{ + + basic_value(boolean_type x) + : basic_value(x, + boolean_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(boolean_type x, boolean_format_info fmt) + : basic_value(x, fmt, std::vector {}, region_type {}) {} + + basic_value(boolean_type x, std::vector com) + : basic_value(x, boolean_format_info {}, std::move(com), region_type {}) {} + + basic_value(boolean_type x, boolean_format_info fmt, std::vector com) + : basic_value(x, fmt, std::move(com), region_type {}) {} + + basic_value(boolean_type x, + boolean_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::boolean) + , boolean_(boolean_storage(x, fmt)) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(boolean_type x) { + boolean_format_info fmt; + if (this->is_boolean()) { + fmt = this->as_boolean_fmt(); + } + this->cleanup(); + this->type_ = value_t::boolean; + this->region_ = region_type {}; + assigner(this->boolean_, boolean_storage(x, fmt)); + return *this; + } + + // }}} + + // constructor (integer) ============================================== {{{ + + basic_value(integer_type x) + : basic_value(std::move(x), + integer_format_info {}, + std::vector {}, + region_type {}) {} + + 
basic_value(integer_type x, integer_format_info fmt) + : basic_value(std::move(x), + std::move(fmt), + std::vector {}, + region_type {}) {} + + basic_value(integer_type x, std::vector com) + : basic_value(std::move(x), + integer_format_info {}, + std::move(com), + region_type {}) {} + + basic_value(integer_type x, integer_format_info fmt, std::vector com) + : basic_value(std::move(x), std::move(fmt), std::move(com), region_type {}) { + } + + basic_value(integer_type x, + integer_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::integer) + , integer_(integer_storage(std::move(x), std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(integer_type x) { + integer_format_info fmt; + if (this->is_integer()) { + fmt = this->as_integer_fmt(); + } + this->cleanup(); + this->type_ = value_t::integer; + this->region_ = region_type {}; + assigner(this->integer_, integer_storage(std::move(x), std::move(fmt))); + return *this; + } + + private: + template + using enable_if_integer_like_t = cxx::enable_if_t< + cxx::conjunction, boolean_type>>, + cxx::negation, integer_type>>, + std::is_integral>>::value, + std::nullptr_t>; + + public: + template = nullptr> + basic_value(T x) + : basic_value(std::move(x), + integer_format_info {}, + std::vector {}, + region_type {}) {} + + template = nullptr> + basic_value(T x, integer_format_info fmt) + : basic_value(std::move(x), + std::move(fmt), + std::vector {}, + region_type {}) {} + + template = nullptr> + basic_value(T x, std::vector com) + : basic_value(std::move(x), + integer_format_info {}, + std::move(com), + region_type {}) {} + + template = nullptr> + basic_value(T x, integer_format_info fmt, std::vector com) + : basic_value(std::move(x), std::move(fmt), std::move(com), region_type {}) { + } + + template = nullptr> + basic_value(T x, + integer_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::integer) + , 
integer_(integer_storage(std::move(x), std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + template = nullptr> + basic_value& operator=(T x) { + integer_format_info fmt; + if (this->is_integer()) { + fmt = this->as_integer_fmt(); + } + this->cleanup(); + this->type_ = value_t::integer; + this->region_ = region_type {}; + assigner(this->integer_, integer_storage(x, std::move(fmt))); + return *this; + } + + // }}} + + // constructor (floating) ============================================= {{{ + + basic_value(floating_type x) + : basic_value(std::move(x), + floating_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(floating_type x, floating_format_info fmt) + : basic_value(std::move(x), + std::move(fmt), + std::vector {}, + region_type {}) {} + + basic_value(floating_type x, std::vector com) + : basic_value(std::move(x), + floating_format_info {}, + std::move(com), + region_type {}) {} + + basic_value(floating_type x, + floating_format_info fmt, + std::vector com) + : basic_value(std::move(x), std::move(fmt), std::move(com), region_type {}) { + } + + basic_value(floating_type x, + floating_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::floating) + , floating_(floating_storage(std::move(x), std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(floating_type x) { + floating_format_info fmt; + if (this->is_floating()) { + fmt = this->as_floating_fmt(); + } + this->cleanup(); + this->type_ = value_t::floating; + this->region_ = region_type {}; + assigner(this->floating_, floating_storage(std::move(x), std::move(fmt))); + return *this; + } + + private: + template + using enable_if_floating_like_t = cxx::enable_if_t< + cxx::conjunction, floating_type>>, + std::is_floating_point>>::value, + std::nullptr_t>; + + public: + template = nullptr> + basic_value(T x) + : basic_value(x, + floating_format_info {}, + std::vector {}, + region_type {}) {} + + 
template = nullptr> + basic_value(T x, floating_format_info fmt) + : basic_value(x, std::move(fmt), std::vector {}, region_type {}) { + } + + template = nullptr> + basic_value(T x, std::vector com) + : basic_value(x, floating_format_info {}, std::move(com), region_type {}) {} + + template = nullptr> + basic_value(T x, floating_format_info fmt, std::vector com) + : basic_value(x, std::move(fmt), std::move(com), region_type {}) {} + + template = nullptr> + basic_value(T x, + floating_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::floating) + , floating_(floating_storage(x, std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + template = nullptr> + basic_value& operator=(T x) { + floating_format_info fmt; + if (this->is_floating()) { + fmt = this->as_floating_fmt(); + } + this->cleanup(); + this->type_ = value_t::floating; + this->region_ = region_type {}; + assigner(this->floating_, floating_storage(x, std::move(fmt))); + return *this; + } + + // }}} + + // constructor (string) =============================================== {{{ + + basic_value(string_type x) + : basic_value(std::move(x), + string_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(string_type x, string_format_info fmt) + : basic_value(std::move(x), + std::move(fmt), + std::vector {}, + region_type {}) {} + + basic_value(string_type x, std::vector com) + : basic_value(std::move(x), + string_format_info {}, + std::move(com), + region_type {}) {} + + basic_value(string_type x, string_format_info fmt, std::vector com) + : basic_value(std::move(x), std::move(fmt), std::move(com), region_type {}) { + } + + basic_value(string_type x, + string_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::string) + , string_(string_storage(std::move(x), std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(string_type x) { + string_format_info fmt; + if (this->is_string()) 
{ + fmt = this->as_string_fmt(); + } + this->cleanup(); + this->type_ = value_t::string; + this->region_ = region_type {}; + assigner(this->string_, string_storage(x, std::move(fmt))); + return *this; + } + + // "string literal" + + basic_value(const typename string_type::value_type* x) + : basic_value(x, + string_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(const typename string_type::value_type* x, string_format_info fmt) + : basic_value(x, std::move(fmt), std::vector {}, region_type {}) { + } + + basic_value(const typename string_type::value_type* x, + std::vector com) + : basic_value(x, string_format_info {}, std::move(com), region_type {}) {} + + basic_value(const typename string_type::value_type* x, + string_format_info fmt, + std::vector com) + : basic_value(x, std::move(fmt), std::move(com), region_type {}) {} + + basic_value(const typename string_type::value_type* x, + string_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::string) + , string_(string_storage(string_type(x), std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(const typename string_type::value_type* x) { + string_format_info fmt; + if (this->is_string()) { + fmt = this->as_string_fmt(); + } + this->cleanup(); + this->type_ = value_t::string; + this->region_ = region_type {}; + assigner(this->string_, string_storage(string_type(x), std::move(fmt))); + return *this; + } + +#if defined(TOML11_HAS_STRING_VIEW) + using string_view_type = std::basic_string_view; + + basic_value(string_view_type x) + : basic_value(x, + string_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(string_view_type x, string_format_info fmt) + : basic_value(x, std::move(fmt), std::vector {}, region_type {}) { + } + + basic_value(string_view_type x, std::vector com) + : basic_value(x, string_format_info {}, std::move(com), region_type {}) {} + + basic_value(string_view_type x, + string_format_info 
fmt, + std::vector com) + : basic_value(x, std::move(fmt), std::move(com), region_type {}) {} + + basic_value(string_view_type x, + string_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::string) + , string_(string_storage(string_type(x), std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(string_view_type x) { + string_format_info fmt; + if (this->is_string()) { + fmt = this->as_string_fmt(); + } + this->cleanup(); + this->type_ = value_t::string; + this->region_ = region_type {}; + assigner(this->string_, string_storage(string_type(x), std::move(fmt))); + return *this; + } + +#endif // TOML11_HAS_STRING_VIEW + + template , string_type>>, + detail::is_1byte_std_basic_string>::value, + std::nullptr_t> = nullptr> + basic_value(const T& x) + : basic_value(x, + string_format_info {}, + std::vector {}, + region_type {}) {} + + template , string_type>>, + detail::is_1byte_std_basic_string>::value, + std::nullptr_t> = nullptr> + basic_value(const T& x, string_format_info fmt) + : basic_value(x, std::move(fmt), std::vector {}, region_type {}) { + } + + template , string_type>>, + detail::is_1byte_std_basic_string>::value, + std::nullptr_t> = nullptr> + basic_value(const T& x, std::vector com) + : basic_value(x, string_format_info {}, std::move(com), region_type {}) {} + + template , string_type>>, + detail::is_1byte_std_basic_string>::value, + std::nullptr_t> = nullptr> + basic_value(const T& x, string_format_info fmt, std::vector com) + : basic_value(x, std::move(fmt), std::move(com), region_type {}) {} + + template , string_type>>, + detail::is_1byte_std_basic_string>::value, + std::nullptr_t> = nullptr> + basic_value(const T& x, + string_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::string) + , string_(string_storage(detail::string_conv(x), std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + template , string_type>>, + 
detail::is_1byte_std_basic_string>::value, + std::nullptr_t> = nullptr> + basic_value& operator=(const T& x) { + string_format_info fmt; + if (this->is_string()) { + fmt = this->as_string_fmt(); + } + this->cleanup(); + this->type_ = value_t::string; + this->region_ = region_type {}; + assigner(this->string_, + string_storage(detail::string_conv(x), std::move(fmt))); + return *this; + } + + // }}} + + // constructor (local_date) =========================================== {{{ + + basic_value(local_date_type x) + : basic_value(x, + local_date_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(local_date_type x, local_date_format_info fmt) + : basic_value(x, fmt, std::vector {}, region_type {}) {} + + basic_value(local_date_type x, std::vector com) + : basic_value(x, local_date_format_info {}, std::move(com), region_type {}) { + } + + basic_value(local_date_type x, + local_date_format_info fmt, + std::vector com) + : basic_value(x, fmt, std::move(com), region_type {}) {} + + basic_value(local_date_type x, + local_date_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::local_date) + , local_date_(local_date_storage(x, fmt)) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(local_date_type x) { + local_date_format_info fmt; + if (this->is_local_date()) { + fmt = this->as_local_date_fmt(); + } + this->cleanup(); + this->type_ = value_t::local_date; + this->region_ = region_type {}; + assigner(this->local_date_, local_date_storage(x, fmt)); + return *this; + } + + // }}} + + // constructor (local_time) =========================================== {{{ + + basic_value(local_time_type x) + : basic_value(x, + local_time_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(local_time_type x, local_time_format_info fmt) + : basic_value(x, fmt, std::vector {}, region_type {}) {} + + basic_value(local_time_type x, std::vector com) + : basic_value(x, local_time_format_info {}, 
std::move(com), region_type {}) { + } + + basic_value(local_time_type x, + local_time_format_info fmt, + std::vector com) + : basic_value(x, fmt, std::move(com), region_type {}) {} + + basic_value(local_time_type x, + local_time_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::local_time) + , local_time_(local_time_storage(x, fmt)) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(local_time_type x) { + local_time_format_info fmt; + if (this->is_local_time()) { + fmt = this->as_local_time_fmt(); + } + this->cleanup(); + this->type_ = value_t::local_time; + this->region_ = region_type {}; + assigner(this->local_time_, local_time_storage(x, fmt)); + return *this; + } + + template + basic_value(const std::chrono::duration& x) + : basic_value(local_time_type(x), + local_time_format_info {}, + std::vector {}, + region_type {}) {} + + template + basic_value(const std::chrono::duration& x, + local_time_format_info fmt) + : basic_value(local_time_type(x), + std::move(fmt), + std::vector {}, + region_type {}) {} + + template + basic_value(const std::chrono::duration& x, + std::vector com) + : basic_value(local_time_type(x), + local_time_format_info {}, + std::move(com), + region_type {}) {} + + template + basic_value(const std::chrono::duration& x, + local_time_format_info fmt, + std::vector com) + : basic_value(local_time_type(x), + std::move(fmt), + std::move(com), + region_type {}) {} + + template + basic_value(const std::chrono::duration& x, + local_time_format_info fmt, + std::vector com, + region_type reg) + : basic_value(local_time_type(x), + std::move(fmt), + std::move(com), + std::move(reg)) {} + + template + basic_value& operator=(const std::chrono::duration& x) { + local_time_format_info fmt; + if (this->is_local_time()) { + fmt = this->as_local_time_fmt(); + } + this->cleanup(); + this->type_ = value_t::local_time; + this->region_ = region_type {}; + assigner(this->local_time_, + 
local_time_storage(local_time_type(x), std::move(fmt))); + return *this; + } + + // }}} + + // constructor (local_datetime) =========================================== {{{ + + basic_value(local_datetime_type x) + : basic_value(x, + local_datetime_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(local_datetime_type x, local_datetime_format_info fmt) + : basic_value(x, fmt, std::vector {}, region_type {}) {} + + basic_value(local_datetime_type x, std::vector com) + : basic_value(x, local_datetime_format_info {}, std::move(com), region_type {}) { + } + + basic_value(local_datetime_type x, + local_datetime_format_info fmt, + std::vector com) + : basic_value(x, fmt, std::move(com), region_type {}) {} + + basic_value(local_datetime_type x, + local_datetime_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::local_datetime) + , local_datetime_(local_datetime_storage(x, fmt)) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(local_datetime_type x) { + local_datetime_format_info fmt; + if (this->is_local_datetime()) { + fmt = this->as_local_datetime_fmt(); + } + this->cleanup(); + this->type_ = value_t::local_datetime; + this->region_ = region_type {}; + assigner(this->local_datetime_, local_datetime_storage(x, fmt)); + return *this; + } + + // }}} + + // constructor (offset_datetime) =========================================== {{{ + + basic_value(offset_datetime_type x) + : basic_value(x, + offset_datetime_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(offset_datetime_type x, offset_datetime_format_info fmt) + : basic_value(x, fmt, std::vector {}, region_type {}) {} + + basic_value(offset_datetime_type x, std::vector com) + : basic_value(x, + offset_datetime_format_info {}, + std::move(com), + region_type {}) {} + + basic_value(offset_datetime_type x, + offset_datetime_format_info fmt, + std::vector com) + : basic_value(x, fmt, std::move(com), region_type {}) {} + + 
basic_value(offset_datetime_type x, + offset_datetime_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::offset_datetime) + , offset_datetime_(offset_datetime_storage(x, fmt)) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(offset_datetime_type x) { + offset_datetime_format_info fmt; + if (this->is_offset_datetime()) { + fmt = this->as_offset_datetime_fmt(); + } + this->cleanup(); + this->type_ = value_t::offset_datetime; + this->region_ = region_type {}; + assigner(this->offset_datetime_, offset_datetime_storage(x, fmt)); + return *this; + } + + // system_clock::time_point + + basic_value(std::chrono::system_clock::time_point x) + : basic_value(offset_datetime_type(x), + offset_datetime_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(std::chrono::system_clock::time_point x, + offset_datetime_format_info fmt) + : basic_value(offset_datetime_type(x), + fmt, + std::vector {}, + region_type {}) {} + + basic_value(std::chrono::system_clock::time_point x, + std::vector com) + : basic_value(offset_datetime_type(x), + offset_datetime_format_info {}, + std::move(com), + region_type {}) {} + + basic_value(std::chrono::system_clock::time_point x, + offset_datetime_format_info fmt, + std::vector com) + : basic_value(offset_datetime_type(x), fmt, std::move(com), region_type {}) { + } + + basic_value(std::chrono::system_clock::time_point x, + offset_datetime_format_info fmt, + std::vector com, + region_type reg) + : basic_value(offset_datetime_type(x), + std::move(fmt), + std::move(com), + std::move(reg)) {} + + basic_value& operator=(std::chrono::system_clock::time_point x) { + offset_datetime_format_info fmt; + if (this->is_offset_datetime()) { + fmt = this->as_offset_datetime_fmt(); + } + this->cleanup(); + this->type_ = value_t::offset_datetime; + this->region_ = region_type {}; + assigner(this->offset_datetime_, + offset_datetime_storage(offset_datetime_type(x), fmt)); + return *this; + } 
+ + // }}} + + // constructor (array) ================================================ {{{ + + basic_value(array_type x) + : basic_value(std::move(x), + array_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(array_type x, array_format_info fmt) + : basic_value(std::move(x), + std::move(fmt), + std::vector {}, + region_type {}) {} + + basic_value(array_type x, std::vector com) + : basic_value(std::move(x), + array_format_info {}, + std::move(com), + region_type {}) {} + + basic_value(array_type x, array_format_info fmt, std::vector com) + : basic_value(std::move(x), fmt, std::move(com), region_type {}) {} + + basic_value(array_type x, + array_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::array) + , array_(array_storage(detail::storage(std::move(x)), + std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(array_type x) { + array_format_info fmt; + if (this->is_array()) { + fmt = this->as_array_fmt(); + } + this->cleanup(); + this->type_ = value_t::array; + this->region_ = region_type {}; + assigner(this->array_, + array_storage(detail::storage(std::move(x)), + std::move(fmt))); + return *this; + } + + private: + template + using enable_if_array_like_t = cxx::enable_if_t< + cxx::conjunction, + cxx::negation>, + cxx::negation>, +#if defined(TOML11_HAS_STRING_VIEW) + cxx::negation>, +#endif + cxx::negation>, + cxx::negation>>::value, + std::nullptr_t>; + + public: + template = nullptr> + basic_value(T x) + : basic_value(std::move(x), + array_format_info {}, + std::vector {}, + region_type {}) {} + + template = nullptr> + basic_value(T x, array_format_info fmt) + : basic_value(std::move(x), + std::move(fmt), + std::vector {}, + region_type {}) {} + + template = nullptr> + basic_value(T x, std::vector com) + : basic_value(std::move(x), + array_format_info {}, + std::move(com), + region_type {}) {} + + template = nullptr> + basic_value(T x, array_format_info fmt, 
std::vector com) + : basic_value(std::move(x), fmt, std::move(com), region_type {}) {} + + template = nullptr> + basic_value(T x, array_format_info fmt, std::vector com, region_type reg) + : type_(value_t::array) + , array_(array_storage(detail::storage( + array_type(std::make_move_iterator(x.begin()), + std::make_move_iterator(x.end()))), + std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + template = nullptr> + basic_value& operator=(T x) { + array_format_info fmt; + if (this->is_array()) { + fmt = this->as_array_fmt(); + } + this->cleanup(); + this->type_ = value_t::array; + this->region_ = region_type {}; + + array_type a(std::make_move_iterator(x.begin()), + std::make_move_iterator(x.end())); + assigner(this->array_, + array_storage(detail::storage(std::move(a)), + std::move(fmt))); + return *this; + } + + // }}} + + // constructor (table) ================================================ {{{ + + basic_value(table_type x) + : basic_value(std::move(x), + table_format_info {}, + std::vector {}, + region_type {}) {} + + basic_value(table_type x, table_format_info fmt) + : basic_value(std::move(x), + std::move(fmt), + std::vector {}, + region_type {}) {} + + basic_value(table_type x, std::vector com) + : basic_value(std::move(x), + table_format_info {}, + std::move(com), + region_type {}) {} + + basic_value(table_type x, table_format_info fmt, std::vector com) + : basic_value(std::move(x), fmt, std::move(com), region_type {}) {} + + basic_value(table_type x, + table_format_info fmt, + std::vector com, + region_type reg) + : type_(value_t::table) + , table_(table_storage(detail::storage(std::move(x)), + std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + basic_value& operator=(table_type x) { + table_format_info fmt; + if (this->is_table()) { + fmt = this->as_table_fmt(); + } + this->cleanup(); + this->type_ = value_t::table; + this->region_ = region_type {}; + assigner(this->table_, + 
table_storage(detail::storage(std::move(x)), + std::move(fmt))); + return *this; + } + + // table-like + + private: + template + using enable_if_table_like_t = cxx::enable_if_t< + cxx::conjunction>, + detail::is_map, + cxx::negation>, + cxx::negation>>::value, + std::nullptr_t>; + + public: + template = nullptr> + basic_value(T x) + : basic_value(std::move(x), + table_format_info {}, + std::vector {}, + region_type {}) {} + + template = nullptr> + basic_value(T x, table_format_info fmt) + : basic_value(std::move(x), + std::move(fmt), + std::vector {}, + region_type {}) {} + + template = nullptr> + basic_value(T x, std::vector com) + : basic_value(std::move(x), + table_format_info {}, + std::move(com), + region_type {}) {} + + template = nullptr> + basic_value(T x, table_format_info fmt, std::vector com) + : basic_value(std::move(x), fmt, std::move(com), region_type {}) {} + + template = nullptr> + basic_value(T x, table_format_info fmt, std::vector com, region_type reg) + : type_(value_t::table) + , table_(table_storage(detail::storage( + table_type(std::make_move_iterator(x.begin()), + std::make_move_iterator(x.end()))), + std::move(fmt))) + , region_(std::move(reg)) + , comments_(std::move(com)) {} + + template = nullptr> + basic_value& operator=(T x) { + table_format_info fmt; + if (this->is_table()) { + fmt = this->as_table_fmt(); + } + this->cleanup(); + this->type_ = value_t::table; + this->region_ = region_type {}; + + table_type t(std::make_move_iterator(x.begin()), + std::make_move_iterator(x.end())); + assigner(this->table_, + table_storage(detail::storage(std::move(t)), + std::move(fmt))); + return *this; + } + + // }}} + + // constructor (user_defined) ========================================= {{{ + + template ::value, std::nullptr_t> = nullptr> + basic_value(const T& ud) + : basic_value( + into>::template into_toml(ud)) {} + + template ::value, std::nullptr_t> = nullptr> + basic_value(const T& ud, std::vector com) + : basic_value( + into>::template 
into_toml(ud), + std::move(com)) {} + + template ::value, std::nullptr_t> = nullptr> + basic_value& operator=(const T& ud) { + *this = into>::template into_toml(ud); + return *this; + } + + template , + cxx::negation>>::value, + std::nullptr_t> = nullptr> + basic_value(const T& ud) : basic_value(ud.into_toml()) {} + + template , + cxx::negation>>::value, + std::nullptr_t> = nullptr> + basic_value(const T& ud, std::vector com) + : basic_value(ud.into_toml(), std::move(com)) {} + + template , + cxx::negation>>::value, + std::nullptr_t> = nullptr> + basic_value& operator=(const T& ud) { + *this = ud.into_toml(); + return *this; + } + + template , + cxx::negation>>::value, + std::nullptr_t> = nullptr> + basic_value(const T& ud) + : basic_value(ud.template into_toml()) {} + + template , + cxx::negation>>::value, + std::nullptr_t> = nullptr> + basic_value(const T& ud, std::vector com) + : basic_value(ud.template into_toml(), std::move(com)) {} + + template , + cxx::negation>>::value, + std::nullptr_t> = nullptr> + basic_value& operator=(const T& ud) { + *this = ud.template into_toml(); + return *this; + } + + // }}} + + // empty value with region info ======================================= {{{ + + // mainly for `null` extension + basic_value(detail::none_t, region_type reg) noexcept + : type_(value_t::empty) + , empty_('\0') + , region_(std::move(reg)) + , comments_ {} {} + + // }}} + + // type checking ====================================================== {{{ + + template , value_type>::value, + std::nullptr_t> = nullptr> + bool is() const noexcept { + return detail::type_to_enum::value == this->type_; + } + + bool is(value_t t) const noexcept { + return t == this->type_; + } + + bool is_empty() const noexcept { + return this->is(value_t::empty); + } + + bool is_boolean() const noexcept { + return this->is(value_t::boolean); + } + + bool is_integer() const noexcept { + return this->is(value_t::integer); + } + + bool is_floating() const noexcept { + return 
this->is(value_t::floating); + } + + bool is_string() const noexcept { + return this->is(value_t::string); + } + + bool is_offset_datetime() const noexcept { + return this->is(value_t::offset_datetime); + } + + bool is_local_datetime() const noexcept { + return this->is(value_t::local_datetime); + } + + bool is_local_date() const noexcept { + return this->is(value_t::local_date); + } + + bool is_local_time() const noexcept { + return this->is(value_t::local_time); + } + + bool is_array() const noexcept { + return this->is(value_t::array); + } + + bool is_table() const noexcept { + return this->is(value_t::table); + } + + bool is_array_of_tables() const noexcept { + if (!this->is_array()) { + return false; + } + const auto& a = this->as_array(std::nothrow); // already checked. + + // when you define [[array.of.tables]], at least one empty table will be + // assigned. In case of array of inline tables, `array_of_tables = []`, + // there is no reason to consider this as an array of *tables*. + // So empty array is not an array-of-tables. + if (a.empty()) { + return false; + } + + // since toml v1.0.0 allows array of heterogeneous types, we need to + // check all the elements. if any of the elements is not a table, it + // is a heterogeneous array and cannot be expressed by `[[aot]]` form. 
+ for (const auto& e : a) { + if (!e.is_table()) { + return false; + } + } + return true; + } + + value_t type() const noexcept { + return type_; + } + + // }}} + + // as_xxx (noexcept) version ========================================== {{{ + + template + const detail::enum_to_type_t>& as( + const std::nothrow_t&) const noexcept { + return detail::getter::get_nothrow(*this); + } + + template + detail::enum_to_type_t>& as( + const std::nothrow_t&) noexcept { + return detail::getter::get_nothrow(*this); + } + + const boolean_type& as_boolean(const std::nothrow_t&) const noexcept { + return this->boolean_.value; + } + + const integer_type& as_integer(const std::nothrow_t&) const noexcept { + return this->integer_.value; + } + + const floating_type& as_floating(const std::nothrow_t&) const noexcept { + return this->floating_.value; + } + + const string_type& as_string(const std::nothrow_t&) const noexcept { + return this->string_.value; + } + + const offset_datetime_type& as_offset_datetime( + const std::nothrow_t&) const noexcept { + return this->offset_datetime_.value; + } + + const local_datetime_type& as_local_datetime( + const std::nothrow_t&) const noexcept { + return this->local_datetime_.value; + } + + const local_date_type& as_local_date(const std::nothrow_t&) const noexcept { + return this->local_date_.value; + } + + const local_time_type& as_local_time(const std::nothrow_t&) const noexcept { + return this->local_time_.value; + } + + const array_type& as_array(const std::nothrow_t&) const noexcept { + return this->array_.value.get(); + } + + const table_type& as_table(const std::nothrow_t&) const noexcept { + return this->table_.value.get(); + } + + boolean_type& as_boolean(const std::nothrow_t&) noexcept { + return this->boolean_.value; + } + + integer_type& as_integer(const std::nothrow_t&) noexcept { + return this->integer_.value; + } + + floating_type& as_floating(const std::nothrow_t&) noexcept { + return this->floating_.value; + } + + string_type& 
as_string(const std::nothrow_t&) noexcept { + return this->string_.value; + } + + offset_datetime_type& as_offset_datetime(const std::nothrow_t&) noexcept { + return this->offset_datetime_.value; + } + + local_datetime_type& as_local_datetime(const std::nothrow_t&) noexcept { + return this->local_datetime_.value; + } + + local_date_type& as_local_date(const std::nothrow_t&) noexcept { + return this->local_date_.value; + } + + local_time_type& as_local_time(const std::nothrow_t&) noexcept { + return this->local_time_.value; + } + + array_type& as_array(const std::nothrow_t&) noexcept { + return this->array_.value.get(); + } + + table_type& as_table(const std::nothrow_t&) noexcept { + return this->table_.value.get(); + } + + // }}} + + // as_xxx (throw) ===================================================== {{{ + + template + const detail::enum_to_type_t>& as() const { + return detail::getter::get(*this); + } + + template + detail::enum_to_type_t>& as() { + return detail::getter::get(*this); + } + + const boolean_type& as_boolean() const { + if (this->type_ != value_t::boolean) { + this->throw_bad_cast("toml::value::as_boolean()", value_t::boolean); + } + return this->boolean_.value; + } + + const integer_type& as_integer() const { + if (this->type_ != value_t::integer) { + this->throw_bad_cast("toml::value::as_integer()", value_t::integer); + } + return this->integer_.value; + } + + const floating_type& as_floating() const { + if (this->type_ != value_t::floating) { + this->throw_bad_cast("toml::value::as_floating()", value_t::floating); + } + return this->floating_.value; + } + + const string_type& as_string() const { + if (this->type_ != value_t::string) { + this->throw_bad_cast("toml::value::as_string()", value_t::string); + } + return this->string_.value; + } + + const offset_datetime_type& as_offset_datetime() const { + if (this->type_ != value_t::offset_datetime) { + this->throw_bad_cast("toml::value::as_offset_datetime()", + value_t::offset_datetime); + } + 
return this->offset_datetime_.value; + } + + const local_datetime_type& as_local_datetime() const { + if (this->type_ != value_t::local_datetime) { + this->throw_bad_cast("toml::value::as_local_datetime()", + value_t::local_datetime); + } + return this->local_datetime_.value; + } + + const local_date_type& as_local_date() const { + if (this->type_ != value_t::local_date) { + this->throw_bad_cast("toml::value::as_local_date()", value_t::local_date); + } + return this->local_date_.value; + } + + const local_time_type& as_local_time() const { + if (this->type_ != value_t::local_time) { + this->throw_bad_cast("toml::value::as_local_time()", value_t::local_time); + } + return this->local_time_.value; + } + + const array_type& as_array() const { + if (this->type_ != value_t::array) { + this->throw_bad_cast("toml::value::as_array()", value_t::array); + } + return this->array_.value.get(); + } + + const table_type& as_table() const { + if (this->type_ != value_t::table) { + this->throw_bad_cast("toml::value::as_table()", value_t::table); + } + return this->table_.value.get(); + } + + // ------------------------------------------------------------------------ + // nonconst reference + + boolean_type& as_boolean() { + if (this->type_ != value_t::boolean) { + this->throw_bad_cast("toml::value::as_boolean()", value_t::boolean); + } + return this->boolean_.value; + } + + integer_type& as_integer() { + if (this->type_ != value_t::integer) { + this->throw_bad_cast("toml::value::as_integer()", value_t::integer); + } + return this->integer_.value; + } + + floating_type& as_floating() { + if (this->type_ != value_t::floating) { + this->throw_bad_cast("toml::value::as_floating()", value_t::floating); + } + return this->floating_.value; + } + + string_type& as_string() { + if (this->type_ != value_t::string) { + this->throw_bad_cast("toml::value::as_string()", value_t::string); + } + return this->string_.value; + } + + offset_datetime_type& as_offset_datetime() { + if (this->type_ != 
value_t::offset_datetime) { + this->throw_bad_cast("toml::value::as_offset_datetime()", + value_t::offset_datetime); + } + return this->offset_datetime_.value; + } + + local_datetime_type& as_local_datetime() { + if (this->type_ != value_t::local_datetime) { + this->throw_bad_cast("toml::value::as_local_datetime()", + value_t::local_datetime); + } + return this->local_datetime_.value; + } + + local_date_type& as_local_date() { + if (this->type_ != value_t::local_date) { + this->throw_bad_cast("toml::value::as_local_date()", value_t::local_date); + } + return this->local_date_.value; + } + + local_time_type& as_local_time() { + if (this->type_ != value_t::local_time) { + this->throw_bad_cast("toml::value::as_local_time()", value_t::local_time); + } + return this->local_time_.value; + } + + array_type& as_array() { + if (this->type_ != value_t::array) { + this->throw_bad_cast("toml::value::as_array()", value_t::array); + } + return this->array_.value.get(); + } + + table_type& as_table() { + if (this->type_ != value_t::table) { + this->throw_bad_cast("toml::value::as_table()", value_t::table); + } + return this->table_.value.get(); + } + + // }}} + + // format accessors (noexcept) ======================================== {{{ + + template + const detail::enum_to_fmt_type_t& as_fmt(const std::nothrow_t&) const noexcept { + return detail::getter::get_fmt_nothrow(*this); + } + + template + detail::enum_to_fmt_type_t& as_fmt(const std::nothrow_t&) noexcept { + return detail::getter::get_fmt_nothrow(*this); + } + + boolean_format_info& as_boolean_fmt(const std::nothrow_t&) noexcept { + return this->boolean_.format; + } + + integer_format_info& as_integer_fmt(const std::nothrow_t&) noexcept { + return this->integer_.format; + } + + floating_format_info& as_floating_fmt(const std::nothrow_t&) noexcept { + return this->floating_.format; + } + + string_format_info& as_string_fmt(const std::nothrow_t&) noexcept { + return this->string_.format; + } + + 
offset_datetime_format_info& as_offset_datetime_fmt( + const std::nothrow_t&) noexcept { + return this->offset_datetime_.format; + } + + local_datetime_format_info& as_local_datetime_fmt(const std::nothrow_t&) noexcept { + return this->local_datetime_.format; + } + + local_date_format_info& as_local_date_fmt(const std::nothrow_t&) noexcept { + return this->local_date_.format; + } + + local_time_format_info& as_local_time_fmt(const std::nothrow_t&) noexcept { + return this->local_time_.format; + } + + array_format_info& as_array_fmt(const std::nothrow_t&) noexcept { + return this->array_.format; + } + + table_format_info& as_table_fmt(const std::nothrow_t&) noexcept { + return this->table_.format; + } + + const boolean_format_info& as_boolean_fmt(const std::nothrow_t&) const noexcept { + return this->boolean_.format; + } + + const integer_format_info& as_integer_fmt(const std::nothrow_t&) const noexcept { + return this->integer_.format; + } + + const floating_format_info& as_floating_fmt(const std::nothrow_t&) const noexcept { + return this->floating_.format; + } + + const string_format_info& as_string_fmt(const std::nothrow_t&) const noexcept { + return this->string_.format; + } + + const offset_datetime_format_info& as_offset_datetime_fmt( + const std::nothrow_t&) const noexcept { + return this->offset_datetime_.format; + } + + const local_datetime_format_info& as_local_datetime_fmt( + const std::nothrow_t&) const noexcept { + return this->local_datetime_.format; + } + + const local_date_format_info& as_local_date_fmt( + const std::nothrow_t&) const noexcept { + return this->local_date_.format; + } + + const local_time_format_info& as_local_time_fmt( + const std::nothrow_t&) const noexcept { + return this->local_time_.format; + } + + const array_format_info& as_array_fmt(const std::nothrow_t&) const noexcept { + return this->array_.format; + } + + const table_format_info& as_table_fmt(const std::nothrow_t&) const noexcept { + return this->table_.format; + } + + // 
}}} + + // format accessors (throw) =========================================== {{{ + + template + const detail::enum_to_fmt_type_t& as_fmt() const { + return detail::getter::get_fmt(*this); + } + + template + detail::enum_to_fmt_type_t& as_fmt() { + return detail::getter::get_fmt(*this); + } + + const boolean_format_info& as_boolean_fmt() const { + if (this->type_ != value_t::boolean) { + this->throw_bad_cast("toml::value::as_boolean_fmt()", value_t::boolean); + } + return this->boolean_.format; + } + + const integer_format_info& as_integer_fmt() const { + if (this->type_ != value_t::integer) { + this->throw_bad_cast("toml::value::as_integer_fmt()", value_t::integer); + } + return this->integer_.format; + } + + const floating_format_info& as_floating_fmt() const { + if (this->type_ != value_t::floating) { + this->throw_bad_cast("toml::value::as_floating_fmt()", value_t::floating); + } + return this->floating_.format; + } + + const string_format_info& as_string_fmt() const { + if (this->type_ != value_t::string) { + this->throw_bad_cast("toml::value::as_string_fmt()", value_t::string); + } + return this->string_.format; + } + + const offset_datetime_format_info& as_offset_datetime_fmt() const { + if (this->type_ != value_t::offset_datetime) { + this->throw_bad_cast("toml::value::as_offset_datetime_fmt()", + value_t::offset_datetime); + } + return this->offset_datetime_.format; + } + + const local_datetime_format_info& as_local_datetime_fmt() const { + if (this->type_ != value_t::local_datetime) { + this->throw_bad_cast("toml::value::as_local_datetime_fmt()", + value_t::local_datetime); + } + return this->local_datetime_.format; + } + + const local_date_format_info& as_local_date_fmt() const { + if (this->type_ != value_t::local_date) { + this->throw_bad_cast("toml::value::as_local_date_fmt()", + value_t::local_date); + } + return this->local_date_.format; + } + + const local_time_format_info& as_local_time_fmt() const { + if (this->type_ != value_t::local_time) { + 
this->throw_bad_cast("toml::value::as_local_time_fmt()", + value_t::local_time); + } + return this->local_time_.format; + } + + const array_format_info& as_array_fmt() const { + if (this->type_ != value_t::array) { + this->throw_bad_cast("toml::value::as_array_fmt()", value_t::array); + } + return this->array_.format; + } + + const table_format_info& as_table_fmt() const { + if (this->type_ != value_t::table) { + this->throw_bad_cast("toml::value::as_table_fmt()", value_t::table); + } + return this->table_.format; + } + + // ------------------------------------------------------------------------ + // nonconst reference + + boolean_format_info& as_boolean_fmt() { + if (this->type_ != value_t::boolean) { + this->throw_bad_cast("toml::value::as_boolean_fmt()", value_t::boolean); + } + return this->boolean_.format; + } + + integer_format_info& as_integer_fmt() { + if (this->type_ != value_t::integer) { + this->throw_bad_cast("toml::value::as_integer_fmt()", value_t::integer); + } + return this->integer_.format; + } + + floating_format_info& as_floating_fmt() { + if (this->type_ != value_t::floating) { + this->throw_bad_cast("toml::value::as_floating_fmt()", value_t::floating); + } + return this->floating_.format; + } + + string_format_info& as_string_fmt() { + if (this->type_ != value_t::string) { + this->throw_bad_cast("toml::value::as_string_fmt()", value_t::string); + } + return this->string_.format; + } + + offset_datetime_format_info& as_offset_datetime_fmt() { + if (this->type_ != value_t::offset_datetime) { + this->throw_bad_cast("toml::value::as_offset_datetime_fmt()", + value_t::offset_datetime); + } + return this->offset_datetime_.format; + } + + local_datetime_format_info& as_local_datetime_fmt() { + if (this->type_ != value_t::local_datetime) { + this->throw_bad_cast("toml::value::as_local_datetime_fmt()", + value_t::local_datetime); + } + return this->local_datetime_.format; + } + + local_date_format_info& as_local_date_fmt() { + if (this->type_ != 
value_t::local_date) { + this->throw_bad_cast("toml::value::as_local_date_fmt()", + value_t::local_date); + } + return this->local_date_.format; + } + + local_time_format_info& as_local_time_fmt() { + if (this->type_ != value_t::local_time) { + this->throw_bad_cast("toml::value::as_local_time_fmt()", + value_t::local_time); + } + return this->local_time_.format; + } + + array_format_info& as_array_fmt() { + if (this->type_ != value_t::array) { + this->throw_bad_cast("toml::value::as_array_fmt()", value_t::array); + } + return this->array_.format; + } + + table_format_info& as_table_fmt() { + if (this->type_ != value_t::table) { + this->throw_bad_cast("toml::value::as_table_fmt()", value_t::table); + } + return this->table_.format; + } + + // }}} + + // table accessors ==================================================== {{{ + + value_type& at(const key_type& k) { + if (!this->is_table()) { + this->throw_bad_cast("toml::value::at(key_type)", value_t::table); + } + auto& table = this->as_table(std::nothrow); + const auto found = table.find(k); + if (found == table.end()) { + this->throw_key_not_found_error("toml::value::at", k); + } + assert(found->first == k); + return found->second; + } + + const value_type& at(const key_type& k) const { + if (!this->is_table()) { + this->throw_bad_cast("toml::value::at(key_type)", value_t::table); + } + const auto& table = this->as_table(std::nothrow); + const auto found = table.find(k); + if (found == table.end()) { + this->throw_key_not_found_error("toml::value::at", k); + } + assert(found->first == k); + return found->second; + } + + value_type& operator[](const key_type& k) { + if (this->is_empty()) { + (*this) = table_type {}; + } else if (!this->is_table()) // initialized, but not a table + { + this->throw_bad_cast("toml::value::operator[](key_type)", value_t::table); + } + return (this->as_table(std::nothrow))[k]; + } + + std::size_t count(const key_type& k) const { + if (!this->is_table()) { + 
this->throw_bad_cast("toml::value::count(key_type)", value_t::table); + } + return this->as_table(std::nothrow).count(k); + } + + bool contains(const key_type& k) const { + if (!this->is_table()) { + this->throw_bad_cast("toml::value::contains(key_type)", value_t::table); + } + const auto& table = this->as_table(std::nothrow); + return table.find(k) != table.end(); + } + + // }}} + + // array accessors ==================================================== {{{ + + value_type& at(const std::size_t idx) { + if (!this->is_array()) { + this->throw_bad_cast("toml::value::at(idx)", value_t::array); + } + auto& ar = this->as_array(std::nothrow); + + if (ar.size() <= idx) { + std::ostringstream oss; + oss << "actual length (" << ar.size() + << ") is shorter than the specified index (" << idx << ")."; + throw std::out_of_range(format_error( + "toml::value::at(idx): no element corresponding to the index", + this->location(), + oss.str())); + } + return ar.at(idx); + } + + const value_type& at(const std::size_t idx) const { + if (!this->is_array()) { + this->throw_bad_cast("toml::value::at(idx)", value_t::array); + } + const auto& ar = this->as_array(std::nothrow); + + if (ar.size() <= idx) { + std::ostringstream oss; + oss << "actual length (" << ar.size() + << ") is shorter than the specified index (" << idx << ")."; + + throw std::out_of_range(format_error( + "toml::value::at(idx): no element corresponding to the index", + this->location(), + oss.str())); + } + return ar.at(idx); + } + + value_type& operator[](const std::size_t idx) noexcept { + // no check... + return this->as_array(std::nothrow)[idx]; + } + + const value_type& operator[](const std::size_t idx) const noexcept { + // no check... 
+ return this->as_array(std::nothrow)[idx]; + } + + void push_back(const value_type& x) { + if (!this->is_array()) { + this->throw_bad_cast("toml::value::push_back(idx)", value_t::array); + } + this->as_array(std::nothrow).push_back(x); + return; + } + + void push_back(value_type&& x) { + if (!this->is_array()) { + this->throw_bad_cast("toml::value::push_back(idx)", value_t::array); + } + this->as_array(std::nothrow).push_back(std::move(x)); + return; + } + + template + value_type& emplace_back(Ts&&... args) { + if (!this->is_array()) { + this->throw_bad_cast("toml::value::emplace_back(idx)", value_t::array); + } + auto& ar = this->as_array(std::nothrow); + ar.emplace_back(std::forward(args)...); + return ar.back(); + } + + std::size_t size() const { + switch (this->type_) { + case value_t::array: { + return this->as_array(std::nothrow).size(); + } + case value_t::table: { + return this->as_table(std::nothrow).size(); + } + case value_t::string: { + return this->as_string(std::nothrow).size(); + } + default: { + throw type_error( + format_error("toml::value::size(): bad_cast to container types", + this->location(), + "the actual type is " + to_string(this->type_)), + this->location()); + } + } + } + + // }}} + + source_location location() const { + return source_location(this->region_); + } + + const comment_type& comments() const noexcept { + return this->comments_; + } + + comment_type& comments() noexcept { + return this->comments_; + } + + private: + // private helper functions =========================================== {{{ + + void cleanup() noexcept { + switch (this->type_) { + case value_t::boolean: { + boolean_.~boolean_storage(); + break; + } + case value_t::integer: { + integer_.~integer_storage(); + break; + } + case value_t::floating: { + floating_.~floating_storage(); + break; + } + case value_t::string: { + string_.~string_storage(); + break; + } + case value_t::offset_datetime: { + offset_datetime_.~offset_datetime_storage(); + break; + } + case 
value_t::local_datetime: { + local_datetime_.~local_datetime_storage(); + break; + } + case value_t::local_date: { + local_date_.~local_date_storage(); + break; + } + case value_t::local_time: { + local_time_.~local_time_storage(); + break; + } + case value_t::array: { + array_.~array_storage(); + break; + } + case value_t::table: { + table_.~table_storage(); + break; + } + default: { + break; + } + } + this->type_ = value_t::empty; + return; + } + + template + static void assigner(T& dst, U&& v) { + const auto tmp = ::new (std::addressof(dst)) T(std::forward(v)); + assert(tmp == std::addressof(dst)); + (void)tmp; + } + + [[noreturn]] + void throw_bad_cast(const std::string& funcname, const value_t ty) const { + throw type_error(format_error(detail::make_type_error(*this, funcname, ty)), + this->location()); + } + + [[noreturn]] + void throw_key_not_found_error(const std::string& funcname, + const key_type& key) const { + throw std::out_of_range( + format_error(detail::make_not_found_error(*this, funcname, key))); + } + + template + friend void detail::change_region_of_value(basic_value&, + const basic_value&); + + template + friend class basic_value; + + // }}} + + private: + using boolean_storage = detail::value_with_format; + using integer_storage = detail::value_with_format; + using floating_storage = detail::value_with_format; + using string_storage = detail::value_with_format; + using offset_datetime_storage = + detail::value_with_format; + using local_datetime_storage = + detail::value_with_format; + using local_date_storage = + detail::value_with_format; + using local_time_storage = + detail::value_with_format; + using array_storage = + detail::value_with_format, array_format_info>; + using table_storage = + detail::value_with_format, table_format_info>; + + private: + value_t type_; + + union { + char empty_; // the smallest type + boolean_storage boolean_; + integer_storage integer_; + floating_storage floating_; + string_storage string_; + 
offset_datetime_storage offset_datetime_; + local_datetime_storage local_datetime_; + local_date_storage local_date_; + local_time_storage local_time_; + array_storage array_; + table_storage table_; + }; + + region_type region_; + comment_type comments_; + }; + + template + bool operator==(const basic_value& lhs, const basic_value& rhs) { + if (lhs.type() != rhs.type()) { + return false; + } + if (lhs.comments() != rhs.comments()) { + return false; + } + + switch (lhs.type()) { + case value_t::boolean: { + return lhs.as_boolean() == rhs.as_boolean(); + } + case value_t::integer: { + return lhs.as_integer() == rhs.as_integer(); + } + case value_t::floating: { + return lhs.as_floating() == rhs.as_floating(); + } + case value_t::string: { + return lhs.as_string() == rhs.as_string(); + } + case value_t::offset_datetime: { + return lhs.as_offset_datetime() == rhs.as_offset_datetime(); + } + case value_t::local_datetime: { + return lhs.as_local_datetime() == rhs.as_local_datetime(); + } + case value_t::local_date: { + return lhs.as_local_date() == rhs.as_local_date(); + } + case value_t::local_time: { + return lhs.as_local_time() == rhs.as_local_time(); + } + case value_t::array: { + return lhs.as_array() == rhs.as_array(); + } + case value_t::table: { + return lhs.as_table() == rhs.as_table(); + } + case value_t::empty: { + return true; + } + default: { + return false; + } + } + } + + template + bool operator!=(const basic_value& lhs, const basic_value& rhs) { + return !(lhs == rhs); + } + + template + cxx::enable_if_t< + cxx::conjunction::array_type>, + detail::is_comparable::table_type>>::value, + bool> + operator<(const basic_value& lhs, const basic_value& rhs) { + if (lhs.type() != rhs.type()) { + return (lhs.type() < rhs.type()); + } + switch (lhs.type()) { + case value_t::boolean: { + return lhs.as_boolean() < rhs.as_boolean() || + (lhs.as_boolean() == rhs.as_boolean() && + lhs.comments() < rhs.comments()); + } + case value_t::integer: { + return lhs.as_integer() 
< rhs.as_integer() || + (lhs.as_integer() == rhs.as_integer() && + lhs.comments() < rhs.comments()); + } + case value_t::floating: { + return lhs.as_floating() < rhs.as_floating() || + (lhs.as_floating() == rhs.as_floating() && + lhs.comments() < rhs.comments()); + } + case value_t::string: { + return lhs.as_string() < rhs.as_string() || + (lhs.as_string() == rhs.as_string() && + lhs.comments() < rhs.comments()); + } + case value_t::offset_datetime: { + return lhs.as_offset_datetime() < rhs.as_offset_datetime() || + (lhs.as_offset_datetime() == rhs.as_offset_datetime() && + lhs.comments() < rhs.comments()); + } + case value_t::local_datetime: { + return lhs.as_local_datetime() < rhs.as_local_datetime() || + (lhs.as_local_datetime() == rhs.as_local_datetime() && + lhs.comments() < rhs.comments()); + } + case value_t::local_date: { + return lhs.as_local_date() < rhs.as_local_date() || + (lhs.as_local_date() == rhs.as_local_date() && + lhs.comments() < rhs.comments()); + } + case value_t::local_time: { + return lhs.as_local_time() < rhs.as_local_time() || + (lhs.as_local_time() == rhs.as_local_time() && + lhs.comments() < rhs.comments()); + } + case value_t::array: { + return lhs.as_array() < rhs.as_array() || + (lhs.as_array() == rhs.as_array() && + lhs.comments() < rhs.comments()); + } + case value_t::table: { + return lhs.as_table() < rhs.as_table() || + (lhs.as_table() == rhs.as_table() && + lhs.comments() < rhs.comments()); + } + case value_t::empty: { + return lhs.comments() < rhs.comments(); + } + default: { + return lhs.comments() < rhs.comments(); + } + } + } + + template + cxx::enable_if_t< + cxx::conjunction::array_type>, + detail::is_comparable::table_type>>::value, + bool> + operator<=(const basic_value& lhs, const basic_value& rhs) { + return (lhs < rhs) || (lhs == rhs); + } + + template + cxx::enable_if_t< + cxx::conjunction::array_type>, + detail::is_comparable::table_type>>::value, + bool> + operator>(const basic_value& lhs, const basic_value& rhs) { 
+ return !(lhs <= rhs); + } + + template + cxx::enable_if_t< + cxx::conjunction::array_type>, + detail::is_comparable::table_type>>::value, + bool> + operator>=(const basic_value& lhs, const basic_value& rhs) { + return !(lhs < rhs); + } + + // error_info helper + namespace detail { + template + error_info make_error_info_rec(error_info e, + const basic_value& v, + std::string msg, + Ts&&... tail) { + return make_error_info_rec(std::move(e), + v.location(), + std::move(msg), + std::forward(tail)...); + } + } // namespace detail + + template + error_info make_error_info(std::string title, + const basic_value& v, + std::string msg, + Ts&&... tail) { + return make_error_info(std::move(title), + v.location(), + std::move(msg), + std::forward(tail)...); + } + + template + std::string format_error(std::string title, + const basic_value& v, + std::string msg, + Ts&&... tail) { + return format_error(std::move(title), + v.location(), + std::move(msg), + std::forward(tail)...); + } + + namespace detail { + + template + error_info make_type_error(const basic_value& v, + const std::string& fname, + const value_t ty) { + return make_error_info(fname + ": bad_cast to " + to_string(ty), + v.location(), + "the actual type is " + to_string(v.type())); + } + + template + error_info make_not_found_error(const basic_value& v, + const std::string& fname, + const typename basic_value::key_type& key) { + const auto loc = v.location(); + const std::string title = fname + ": key \"" + + string_conv(key) + "\" not found"; + + std::vector> locs; + if (!loc.is_ok()) { + return error_info(title, locs); + } + + if (loc.first_line_number() == 1 && loc.first_column_number() == 1 && + loc.length() == 1) { + // The top-level table has its region at the 0th character of the file. + // That means that, in the case when a key is not found in the top-level + // table, the error message points to the first character. 
If the file has + // the first table at the first line, the error message would be like this. + // ```console + // [error] key "a" not found + // --> example.toml + // | + // 1 | [table] + // | ^------ in this table + // ``` + // It actually points to the top-level table at the first character, not + // `[table]`. But it is too confusing. To avoid the confusion, the error + // message should explicitly say "key not found in the top-level table". + locs.emplace_back(v.location(), "at the top-level table"); + } else { + locs.emplace_back(v.location(), "in this table"); + } + return error_info(title, locs); + } + +#define TOML11_DETAIL_GENERATE_COMPTIME_GETTER(ty) \ + template \ + struct getter { \ + using value_type = basic_value; \ + using result_type = enum_to_type_t; \ + using format_type = enum_to_fmt_type_t; \ + \ + static result_type& get(value_type& v) { \ + return v.as_##ty(); \ + } \ + static result_type const& get(const value_type& v) { \ + return v.as_##ty(); \ + } \ + \ + static result_type& get_nothrow(value_type& v) noexcept { \ + return v.as_##ty(std::nothrow); \ + } \ + static result_type const& get_nothrow(const value_type& v) noexcept { \ + return v.as_##ty(std::nothrow); \ + } \ + \ + static format_type& get_fmt(value_type& v) { \ + return v.as_##ty##_fmt(); \ + } \ + static format_type const& get_fmt(const value_type& v) { \ + return v.as_##ty##_fmt(); \ + } \ + \ + static format_type& get_fmt_nothrow(value_type& v) noexcept { \ + return v.as_##ty##_fmt(std::nothrow); \ + } \ + static format_type const& get_fmt_nothrow(const value_type& v) noexcept { \ + return v.as_##ty##_fmt(std::nothrow); \ + } \ + }; + + TOML11_DETAIL_GENERATE_COMPTIME_GETTER(boolean) + TOML11_DETAIL_GENERATE_COMPTIME_GETTER(integer) + TOML11_DETAIL_GENERATE_COMPTIME_GETTER(floating) + TOML11_DETAIL_GENERATE_COMPTIME_GETTER(string) + TOML11_DETAIL_GENERATE_COMPTIME_GETTER(offset_datetime) + TOML11_DETAIL_GENERATE_COMPTIME_GETTER(local_datetime) + 
TOML11_DETAIL_GENERATE_COMPTIME_GETTER(local_date) + TOML11_DETAIL_GENERATE_COMPTIME_GETTER(local_time) + TOML11_DETAIL_GENERATE_COMPTIME_GETTER(array) + TOML11_DETAIL_GENERATE_COMPTIME_GETTER(table) + +#undef TOML11_DETAIL_GENERATE_COMPTIME_GETTER + + template + void change_region_of_value(basic_value& dst, const basic_value& src) { + dst.region_ = std::move(src.region_); + return; + } + + } // namespace detail +} // namespace toml +#endif // TOML11_VALUE_HPP +#ifndef TOML11_VISIT_HPP +#define TOML11_VISIT_HPP + +namespace toml { + + template + cxx::return_type_of_t::boolean_type&> visit( + Visitor&& visitor, + const basic_value& v) { + switch (v.type()) { + case value_t::boolean: { + return visitor(v.as_boolean()); + } + case value_t::integer: { + return visitor(v.as_integer()); + } + case value_t::floating: { + return visitor(v.as_floating()); + } + case value_t::string: { + return visitor(v.as_string()); + } + case value_t::offset_datetime: { + return visitor(v.as_offset_datetime()); + } + case value_t::local_datetime: { + return visitor(v.as_local_datetime()); + } + case value_t::local_date: { + return visitor(v.as_local_date()); + } + case value_t::local_time: { + return visitor(v.as_local_time()); + } + case value_t::array: { + return visitor(v.as_array()); + } + case value_t::table: { + return visitor(v.as_table()); + } + case value_t::empty: + break; + default: + break; + } + throw type_error(format_error("[error] toml::visit: toml::basic_value " + "does not have any valid type.", + v.location(), + "here"), + v.location()); + } + + template + cxx::return_type_of_t::boolean_type&> visit( + Visitor&& visitor, + basic_value& v) { + switch (v.type()) { + case value_t::boolean: { + return visitor(v.as_boolean()); + } + case value_t::integer: { + return visitor(v.as_integer()); + } + case value_t::floating: { + return visitor(v.as_floating()); + } + case value_t::string: { + return visitor(v.as_string()); + } + case value_t::offset_datetime: { + return 
visitor(v.as_offset_datetime()); + } + case value_t::local_datetime: { + return visitor(v.as_local_datetime()); + } + case value_t::local_date: { + return visitor(v.as_local_date()); + } + case value_t::local_time: { + return visitor(v.as_local_time()); + } + case value_t::array: { + return visitor(v.as_array()); + } + case value_t::table: { + return visitor(v.as_table()); + } + case value_t::empty: + break; + default: + break; + } + throw type_error(format_error("[error] toml::visit: toml::basic_value " + "does not have any valid type.", + v.location(), + "here"), + v.location()); + } + + template + cxx::return_type_of_t::boolean_type&&> visit( + Visitor&& visitor, + basic_value&& v) { + switch (v.type()) { + case value_t::boolean: { + return visitor(std::move(v.as_boolean())); + } + case value_t::integer: { + return visitor(std::move(v.as_integer())); + } + case value_t::floating: { + return visitor(std::move(v.as_floating())); + } + case value_t::string: { + return visitor(std::move(v.as_string())); + } + case value_t::offset_datetime: { + return visitor(std::move(v.as_offset_datetime())); + } + case value_t::local_datetime: { + return visitor(std::move(v.as_local_datetime())); + } + case value_t::local_date: { + return visitor(std::move(v.as_local_date())); + } + case value_t::local_time: { + return visitor(std::move(v.as_local_time())); + } + case value_t::array: { + return visitor(std::move(v.as_array())); + } + case value_t::table: { + return visitor(std::move(v.as_table())); + } + case value_t::empty: + break; + default: + break; + } + throw type_error(format_error("[error] toml::visit: toml::basic_value " + "does not have any valid type.", + v.location(), + "here"), + v.location()); + } + +} // namespace toml +#endif // TOML11_VISIT_HPP +#ifndef TOML11_TYPES_HPP +#define TOML11_TYPES_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace toml { + + // forward decl + template + class basic_value; + + // when you 
use a special integer type as toml::value::integer_type, parse must + // be able to read it. So, type_config has static member functions that read the + // integer_type as {dec, hex, oct, bin}-integer. But, in most cases, operator<< + // is enough. To make config easy, we provide the default read functions. + // + // Before this functions is called, syntax is checked and prefix(`0x` etc) and + // spacer(`_`) are removed. + + template + result read_dec_int(const std::string& str, + const source_location src) { + constexpr auto max_digits = std::numeric_limits::digits; + assert(!str.empty()); + + T val { 0 }; + std::istringstream iss(str); + iss >> val; + if (iss.fail()) { + return err(make_error_info("toml::parse_dec_integer: " + "too large integer: current max digits = 2^" + + std::to_string(max_digits), + std::move(src), + "must be < 2^" + std::to_string(max_digits))); + } + return ok(val); + } + + template + result read_hex_int(const std::string& str, + const source_location src) { + constexpr auto max_digits = std::numeric_limits::digits; + assert(!str.empty()); + + T val { 0 }; + std::istringstream iss(str); + iss >> std::hex >> val; + if (iss.fail()) { + return err(make_error_info("toml::parse_hex_integer: " + "too large integer: current max value = 2^" + + std::to_string(max_digits), + std::move(src), + "must be < 2^" + std::to_string(max_digits))); + } + return ok(val); + } + + template + result read_oct_int(const std::string& str, + const source_location src) { + constexpr auto max_digits = std::numeric_limits::digits; + assert(!str.empty()); + + T val { 0 }; + std::istringstream iss(str); + iss >> std::oct >> val; + if (iss.fail()) { + return err(make_error_info("toml::parse_oct_integer: " + "too large integer: current max value = 2^" + + std::to_string(max_digits), + std::move(src), + "must be < 2^" + std::to_string(max_digits))); + } + return ok(val); + } + + template + result read_bin_int(const std::string& str, + const source_location src) { + 
constexpr auto is_bounded = std::numeric_limits::is_bounded; + constexpr auto max_digits = std::numeric_limits::digits; + const auto max_value = (std::numeric_limits::max)(); + + T val { 0 }; + T base { 1 }; + for (auto i = str.rbegin(); i != str.rend(); ++i) { + const auto c = *i; + if (c == '1') { + val += base; + // prevent `base` from overflow + if (is_bounded && max_value / 2 < base && std::next(i) != str.rend()) { + base = 0; + } else { + base *= 2; + } + } else { + assert(c == '0'); + + if (is_bounded && max_value / 2 < base && std::next(i) != str.rend()) { + base = 0; + } else { + base *= 2; + } + } + } + if (base == 0) { + return err(make_error_info("toml::parse_bin_integer: " + "too large integer: current max value = 2^" + + std::to_string(max_digits), + std::move(src), + "must be < 2^" + std::to_string(max_digits))); + } + return ok(val); + } + + template + result read_int(const std::string& str, + const source_location src, + const std::uint8_t base) { + assert(base == 10 || base == 16 || base == 8 || base == 2); + switch (base) { + case 2: { + return read_bin_int(str, src); + } + case 8: { + return read_oct_int(str, src); + } + case 16: { + return read_hex_int(str, src); + } + default: { + assert(base == 10); + return read_dec_int(str, src); + } + } + } + + inline result read_hex_float(const std::string& str, + const source_location src, + float val) { +#if defined(_MSC_VER) && !defined(__clang__) + const auto res = ::sscanf_s(str.c_str(), "%a", std::addressof(val)); +#else + const auto res = std::sscanf(str.c_str(), "%a", std::addressof(val)); +#endif + if (res != 1) { + return err( + make_error_info("toml::parse_floating: " + "failed to read hexadecimal floating point value ", + std::move(src), + "here")); + } + return ok(val); + } + + inline result read_hex_float(const std::string& str, + const source_location src, + double val) { +#if defined(_MSC_VER) && !defined(__clang__) + const auto res = ::sscanf_s(str.c_str(), "%la", std::addressof(val)); 
+#else + const auto res = std::sscanf(str.c_str(), "%la", std::addressof(val)); +#endif + if (res != 1) { + return err( + make_error_info("toml::parse_floating: " + "failed to read hexadecimal floating point value ", + std::move(src), + "here")); + } + return ok(val); + } + + template + cxx::enable_if_t< + cxx::conjunction, double>>, + cxx::negation, float>>>::value, + result> + read_hex_float(const std::string&, const source_location src, T) { + return err(make_error_info( + "toml::parse_floating: failed to read " + "floating point value because of unknown type in type_config", + std::move(src), + "here")); + } + + template + result read_dec_float(const std::string& str, + const source_location src) { + T val; + std::istringstream iss(str); + iss >> val; + if (iss.fail()) { + return err( + make_error_info("toml::parse_floating: " + "failed to read floating point value from stream", + std::move(src), + "here")); + } + return ok(val); + } + + template + result read_float(const std::string& str, + const source_location src, + const bool is_hex) { + if (is_hex) { + return read_hex_float(str, src, T {}); + } else { + return read_dec_float(str, src); + } + } + + struct type_config { + using comment_type = preserve_comments; + + using boolean_type = bool; + using integer_type = std::int64_t; + using floating_type = double; + using string_type = std::string; + + template + using array_type = std::vector; + template + using table_type = std::unordered_map; + + static result parse_int(const std::string& str, + const source_location src, + const std::uint8_t base) { + return read_int(str, src, base); + } + + static result parse_float(const std::string& str, + const source_location src, + const bool is_hex) { + return read_float(str, src, is_hex); + } + }; + + using value = basic_value; + using table = typename value::table_type; + using array = typename value::array_type; + + struct ordered_type_config { + using comment_type = preserve_comments; + + using boolean_type = 
bool; + using integer_type = std::int64_t; + using floating_type = double; + using string_type = std::string; + + template + using array_type = std::vector; + template + using table_type = ordered_map; + + static result parse_int(const std::string& str, + const source_location src, + const std::uint8_t base) { + return read_int(str, src, base); + } + + static result parse_float(const std::string& str, + const source_location src, + const bool is_hex) { + return read_float(str, src, is_hex); + } + }; + + using ordered_value = basic_value; + using ordered_table = typename ordered_value::table_type; + using ordered_array = typename ordered_value::array_type; + + // ---------------------------------------------------------------------------- + // meta functions for internal use + + namespace detail { + + // ---------------------------------------------------------------------------- + // check if type T has all the needed member types + + struct has_comment_type_impl { + template + static std::true_type check(typename T::comment_type*); + template + static std::false_type check(...); + }; + + template + using has_comment_type = decltype(has_comment_type_impl::check(nullptr)); + + struct has_integer_type_impl { + template + static std::true_type check(typename T::integer_type*); + template + static std::false_type check(...); + }; + + template + using has_integer_type = decltype(has_integer_type_impl::check(nullptr)); + + struct has_floating_type_impl { + template + static std::true_type check(typename T::floating_type*); + template + static std::false_type check(...); + }; + + template + using has_floating_type = decltype(has_floating_type_impl::check(nullptr)); + + struct has_string_type_impl { + template + static std::true_type check(typename T::string_type*); + template + static std::false_type check(...); + }; + + template + using has_string_type = decltype(has_string_type_impl::check(nullptr)); + + struct has_array_type_impl { + template + static std::true_type 
check(typename T::template array_type*); + template + static std::false_type check(...); + }; + + template + using has_array_type = decltype(has_array_type_impl::check(nullptr)); + + struct has_table_type_impl { + template + static std::true_type check(typename T::template table_type*); + template + static std::false_type check(...); + }; + + template + using has_table_type = decltype(has_table_type_impl::check(nullptr)); + + struct has_parse_int_impl { + template + static std::true_type check(decltype(std::declval().parse_int( + std::declval(), + std::declval(), + std::declval()))*); + template + static std::false_type check(...); + }; + + template + using has_parse_int = decltype(has_parse_int_impl::check(nullptr)); + + struct has_parse_float_impl { + template + static std::true_type check(decltype(std::declval().parse_float( + std::declval(), + std::declval(), + std::declval()))*); + template + static std::false_type check(...); + }; + + template + using has_parse_float = decltype(has_parse_float_impl::check(nullptr)); + + template + using is_type_config = cxx::conjunction, + has_integer_type, + has_floating_type, + has_string_type, + has_array_type, + has_table_type, + has_parse_int, + has_parse_float>; + + } // namespace detail +} // namespace toml + +#if defined(TOML11_COMPILE_SOURCES) +namespace toml { + extern template class basic_value; + extern template class basic_value; +} // namespace toml +#endif // TOML11_COMPILE_SOURCES + +#endif // TOML11_TYPES_HPP +#ifndef TOML11_GET_HPP +#define TOML11_GET_HPP + +#include + +#if defined(TOML11_HAS_STRING_VIEW) + #include +#endif // string_view + +namespace toml { + + // ============================================================================ + // T is toml::value; identity transformation. 
+ + template + cxx::enable_if_t>::value, T>& get( + basic_value& v) { + return v; + } + + template + const cxx::enable_if_t>::value, T>& get( + const basic_value& v) { + return v; + } + + template + cxx::enable_if_t>::value, T> get( + basic_value&& v) { + return basic_value(std::move(v)); + } + + // ============================================================================ + // exact toml::* type + + template + cxx::enable_if_t>::value, T>& get( + basic_value& v) { + constexpr auto ty = detail::type_to_enum>::value; + return detail::getter::get(v); + } + + template + const cxx::enable_if_t>::value, T>& get( + const basic_value& v) { + constexpr auto ty = detail::type_to_enum>::value; + return detail::getter::get(v); + } + + template + cxx::enable_if_t>::value, T> get( + basic_value&& v) { + constexpr auto ty = detail::type_to_enum>::value; + return detail::getter::get(std::move(v)); + } + + // ============================================================================ + // T is toml::basic_value + + template + cxx::enable_if_t, + cxx::negation>>>::value, + T> + get(basic_value v) { + return T(std::move(v)); + } + + // ============================================================================ + // integer convertible from toml::value::integer_type + + template + cxx::enable_if_t, + cxx::negation>, + detail::is_not_toml_type>, + cxx::negation>, + cxx::negation>>::value, + T> + get(const basic_value& v) { + return static_cast(v.as_integer()); + } + + // ============================================================================ + // floating point convertible from toml::value::floating_type + + template + cxx::enable_if_t, + detail::is_not_toml_type>, + cxx::negation>, + cxx::negation>>::value, + T> + get(const basic_value& v) { + return static_cast(v.as_floating()); + } + + // ============================================================================ + // std::string with different char/trait/allocator + + template + cxx::enable_if_t>, + 
detail::is_1byte_std_basic_string>::value, + T> + get(const basic_value& v) { + return detail::string_conv>(v.as_string()); + } + + // ============================================================================ + // std::string_view + +#if defined(TOML11_HAS_STRING_VIEW) + + template + cxx::enable_if_t::string_type>::value, T> + get(const basic_value& v) { + return T(v.as_string()); + } + +#endif // string_view + + // ============================================================================ + // std::chrono::duration from toml::local_time + + template + cxx::enable_if_t::value, T> get( + const basic_value& v) { + return std::chrono::duration_cast( + std::chrono::nanoseconds(v.as_local_time())); + } + + // ============================================================================ + // std::chrono::system_clock::time_point from toml::datetime variants + + template + cxx::enable_if_t::value, T> get( + const basic_value& v) { + switch (v.type()) { + case value_t::local_date: { + return std::chrono::system_clock::time_point(v.as_local_date()); + } + case value_t::local_datetime: { + return std::chrono::system_clock::time_point(v.as_local_datetime()); + } + case value_t::offset_datetime: { + return std::chrono::system_clock::time_point(v.as_offset_datetime()); + } + default: { + const auto loc = v.location(); + throw type_error( + format_error("toml::get: " + "bad_cast to std::chrono::system_clock::time_point", + loc, + "the actual type is " + to_string(v.type())), + loc); + } + } + } + + // ============================================================================ + // forward declaration to use this recursively. ignore this and go ahead. 
+ + // array-like (w/ push_back) + template + cxx::enable_if_t< + cxx::conjunction< + detail::is_container, // T is a container + detail::has_push_back_method, // .push_back() works + detail::is_not_toml_type>, // but not toml::array + cxx::negation>, // but not std::basic_string +#if defined(TOML11_HAS_STRING_VIEW) + cxx::negation>, // but not std::basic_string_view +#endif + cxx::negation>, // no T.from_toml() + cxx::negation>, // no toml::from + cxx::negation&>>>::value, + T> + get(const basic_value&); + + // std::array + template + cxx::enable_if_t::value, T> get(const basic_value&); + + // std::forward_list + template + cxx::enable_if_t::value, T> get( + const basic_value&); + + // std::pair + template + cxx::enable_if_t::value, T> get(const basic_value&); + + // std::tuple + template + cxx::enable_if_t::value, T> get(const basic_value&); + + // std::map (key is convertible from toml::value::key_type) + template + cxx::enable_if_t< + cxx::conjunction, // T is map + detail::is_not_toml_type>, // but not toml::table + std::is_convertible::key_type, + typename T::key_type>, // keys are convertible + cxx::negation>, // no T.from_toml() + cxx::negation>, // no toml::from + cxx::negation&>>>::value, + T> + get(const basic_value& v); + + // std::map (key is not convertible from toml::value::key_type, + // but is a std::basic_string) + template + cxx::enable_if_t< + cxx::conjunction< + detail::is_map, // T is map + detail::is_not_toml_type>, // but not toml::table + cxx::negation::key_type, + typename T::key_type>>, // keys are NOT convertible + detail::is_1byte_std_basic_string, // is std::basic_string + cxx::negation>, // no T.from_toml() + cxx::negation>, // no toml::from + cxx::negation&>>>::value, + T> + get(const basic_value& v); + + // toml::from::from_toml(v) + template + cxx::enable_if_t::value, T> get( + const basic_value&); + + // has T.from_toml(v) but no from + template + cxx::enable_if_t, // has T.from_toml() + cxx::negation>, // no toml::from + 
std::is_default_constructible // T{} works + >::value, + T> + get(const basic_value&); + + // T(const toml::value&) and T is not toml::basic_value, + // and it does not have `from` nor `from_toml`. + template + cxx::enable_if_t&>, // has T(const basic_value&) + cxx::negation>, // but not basic_value itself + cxx::negation>, // no .from_toml() + cxx::negation> // no toml::from + >::value, + T> + get(const basic_value&); + + // ============================================================================ + // array-like types; most likely STL container, like std::vector, etc. + + template + cxx::enable_if_t< + cxx::conjunction< + detail::is_container, // T is a container + detail::has_push_back_method, // .push_back() works + detail::is_not_toml_type>, // but not toml::array + cxx::negation>, // but not std::basic_string +#if defined(TOML11_HAS_STRING_VIEW) + cxx::negation>, // but not std::basic_string_view +#endif + cxx::negation>, // no T.from_toml() + cxx::negation>, // no toml::from + cxx::negation&>>>::value, + T> + get(const basic_value& v) { + using value_type = typename T::value_type; + const auto& a = v.as_array(); + + T container; + detail::try_reserve(container, a.size()); // if T has .reserve(), call it + + for (const auto& elem : a) { + container.push_back(get(elem)); + } + return container; + } + + // ============================================================================ + // std::array + + template + cxx::enable_if_t::value, T> get(const basic_value& v) { + using value_type = typename T::value_type; + const auto& a = v.as_array(); + + T container; + if (a.size() != container.size()) { + const auto loc = v.location(); + throw std::out_of_range( + format_error("toml::get: while converting to an array: " + " array size is " + + std::to_string(container.size()) + " but there are " + + std::to_string(a.size()) + " elements in toml array.", + loc, + "here")); + } + for (std::size_t i = 0; i < a.size(); ++i) { + container.at(i) = ::toml::get(a.at(i)); + 
} + return container; + } + + // ============================================================================ + // std::forward_list + + template + cxx::enable_if_t::value, T> get( + const basic_value& v) { + using value_type = typename T::value_type; + + T container; + for (const auto& elem : v.as_array()) { + container.push_front(get(elem)); + } + container.reverse(); + return container; + } + + // ============================================================================ + // std::pair + + template + cxx::enable_if_t::value, T> get(const basic_value& v) { + using first_type = typename T::first_type; + using second_type = typename T::second_type; + + const auto& ar = v.as_array(); + if (ar.size() != 2) { + const auto loc = v.location(); + throw std::out_of_range( + format_error("toml::get: while converting std::pair: " + " but there are " + + std::to_string(ar.size()) + " > 2 elements in toml array.", + loc, + "here")); + } + return std::make_pair(::toml::get(ar.at(0)), + ::toml::get(ar.at(1))); + } + + // ============================================================================ + // std::tuple. + + namespace detail { + template + T get_tuple_impl(const Array& a, cxx::index_sequence) { + return std::make_tuple( + ::toml::get::type>(a.at(I))...); + } + } // namespace detail + + template + cxx::enable_if_t::value, T> get(const basic_value& v) { + const auto& ar = v.as_array(); + if (ar.size() != std::tuple_size::value) { + const auto loc = v.location(); + throw std::out_of_range(format_error( + "toml::get: while converting std::tuple: " + " there are " + + std::to_string(ar.size()) + " > " + + std::to_string(std::tuple_size::value) + " elements in toml array.", + loc, + "here")); + } + return detail::get_tuple_impl( + ar, + cxx::make_index_sequence::value> {}); + } + + // ============================================================================ + // map-like types; most likely STL map, like std::map or std::unordered_map. 
+ + // key is convertible from toml::value::key_type + template + cxx::enable_if_t< + cxx::conjunction, // T is map + detail::is_not_toml_type>, // but not toml::table + std::is_convertible::key_type, + typename T::key_type>, // keys are convertible + cxx::negation>, // no T.from_toml() + cxx::negation>, // no toml::from + cxx::negation&>>>::value, + T> + get(const basic_value& v) { + using key_type = typename T::key_type; + using mapped_type = typename T::mapped_type; + static_assert( + std::is_convertible::key_type, key_type>::value, + "toml::get only supports map type of which key_type is " + "convertible from toml::basic_value::key_type."); + + T m; + for (const auto& kv : v.as_table()) { + m.emplace(key_type(kv.first), get(kv.second)); + } + return m; + } + + // key is NOT convertible from toml::value::key_type but std::basic_string + template + cxx::enable_if_t< + cxx::conjunction< + detail::is_map, // T is map + detail::is_not_toml_type>, // but not toml::table + cxx::negation::key_type, + typename T::key_type>>, // keys are NOT convertible + detail::is_1byte_std_basic_string, // is std::basic_string + cxx::negation>, // no T.from_toml() + cxx::negation>, // no toml::from + cxx::negation&>>>::value, + T> + get(const basic_value& v) { + using key_type = typename T::key_type; + using mapped_type = typename T::mapped_type; + + T m; + for (const auto& kv : v.as_table()) { + m.emplace(detail::string_conv(kv.first), + get(kv.second)); + } + return m; + } + + // ============================================================================ + // user-defined, but convertible types. 
+ + // toml::from + template + cxx::enable_if_t::value, T> get( + const basic_value& v) { + return ::toml::from::from_toml(v); + } + + // has T.from_toml(v) but no from + template + cxx::enable_if_t, // has T.from_toml() + cxx::negation>, // no toml::from + std::is_default_constructible // T{} works + >::value, + T> + get(const basic_value& v) { + T ud; + ud.from_toml(v); + return ud; + } + + // T(const toml::value&) and T is not toml::basic_value, + // and it does not have `from` nor `from_toml`. + template + cxx::enable_if_t&>, // has T(const basic_value&) + cxx::negation>, // but not basic_value itself + cxx::negation>, // no .from_toml() + cxx::negation> // no toml::from + >::value, + T> + get(const basic_value& v) { + return T(v); + } + + // ============================================================================ + // get_or(value, fallback) + + template + const cxx::enable_if_t::value, basic_value>& get_or( + const basic_value& v, + const basic_value&) { + return v; + } + + template + cxx::enable_if_t::value, basic_value>& get_or( + basic_value& v, + basic_value&) { + return v; + } + + template + cxx::enable_if_t::value, basic_value> get_or( + basic_value&& v, + basic_value&&) { + return v; + } + + // ---------------------------------------------------------------------------- + // specialization for the exact toml types (return type becomes lvalue ref) + + template + const cxx::enable_if_t>::value, T>& + get_or(const basic_value& v, const T& opt) noexcept { + try { + return get>(v); + } catch (...) { + return opt; + } + } + + template + cxx::enable_if_t>, + detail::is_exact_toml_type>>::value, + T>& + get_or(basic_value& v, T& opt) noexcept { + try { + return get>(v); + } catch (...) { + return opt; + } + } + + template + cxx::enable_if_t, basic_value>::value, + cxx::remove_cvref_t> + get_or(basic_value&& v, T&& opt) noexcept { + try { + return get>(std::move(v)); + } catch (...) 
{ + return cxx::remove_cvref_t(std::forward(opt)); + } + } + + // ---------------------------------------------------------------------------- + // specialization for string literal + + // template + // typename basic_value::string_type + // get_or(const basic_value& v, + // const typename basic_value::string_type::value_type (&opt)[N]) + // { + // try + // { + // return v.as_string(); + // } + // catch(...) + // { + // return typename basic_value::string_type(opt); + // } + // } + // + // The above only matches to the literal, like `get_or(v, "foo");` but not + // ```cpp + // const auto opt = "foo"; + // const auto str = get_or(v, opt); + // ``` + // . And the latter causes an error. + // To match to both `"foo"` and `const auto opt = "foo"`, we take a pointer to + // a character here. + + template + typename basic_value::string_type get_or( + const basic_value& v, + const typename basic_value::string_type::value_type* opt) { + try { + return v.as_string(); + } catch (...) { + return typename basic_value::string_type(opt); + } + } + + // ---------------------------------------------------------------------------- + // others (require type conversion and return type cannot be lvalue reference) + + template + cxx::enable_if_t< + cxx::conjunction< + cxx::negation>, + cxx::negation>>, + cxx::negation, + const typename basic_value::string_type::value_type*>>>::value, + cxx::remove_cvref_t> + get_or(const basic_value& v, T&& opt) { + try { + return get>(v); + } catch (...) 
{ + return cxx::remove_cvref_t(std::forward(opt)); + } + } + +} // namespace toml +#endif // TOML11_GET_HPP +#ifndef TOML11_FIND_HPP +#define TOML11_FIND_HPP + +#include + +#if defined(TOML11_HAS_STRING_VIEW) + #include +#endif + +namespace toml { + + // ---------------------------------------------------------------------------- + // find(value, key); + + template + decltype(::toml::get(std::declval&>())) find( + const basic_value& v, + const typename basic_value::key_type& ky) { + return ::toml::get(v.at(ky)); + } + + template + decltype(::toml::get(std::declval&>())) find( + basic_value& v, + const typename basic_value::key_type& ky) { + return ::toml::get(v.at(ky)); + } + + template + decltype(::toml::get(std::declval&&>())) find( + basic_value&& v, + const typename basic_value::key_type& ky) { + return ::toml::get(std::move(v.at(ky))); + } + + // ---------------------------------------------------------------------------- + // find(value, idx) + + template + decltype(::toml::get(std::declval&>())) find( + const basic_value& v, + const std::size_t idx) { + return ::toml::get(v.at(idx)); + } + + template + decltype(::toml::get(std::declval&>())) find( + basic_value& v, + const std::size_t idx) { + return ::toml::get(v.at(idx)); + } + + template + decltype(::toml::get(std::declval&&>())) find( + basic_value&& v, + const std::size_t idx) { + return ::toml::get(std::move(v.at(idx))); + } + + // ---------------------------------------------------------------------------- + // find(value, key/idx), w/o conversion + + template + cxx::enable_if_t::value, basic_value>& find( + basic_value& v, + const typename basic_value::key_type& ky) { + return v.at(ky); + } + + template + const cxx::enable_if_t::value, basic_value>& find( + const basic_value& v, + const typename basic_value::key_type& ky) { + return v.at(ky); + } + + template + cxx::enable_if_t::value, basic_value> find( + basic_value&& v, + const typename basic_value::key_type& ky) { + return 
basic_value(std::move(v.at(ky))); + } + + template + cxx::enable_if_t::value, basic_value>& find( + basic_value& v, + const std::size_t idx) { + return v.at(idx); + } + + template + const cxx::enable_if_t::value, basic_value>& find( + const basic_value& v, + const std::size_t idx) { + return v.at(idx); + } + + template + cxx::enable_if_t::value, basic_value> find( + basic_value&& v, + const std::size_t idx) { + return basic_value(std::move(v.at(idx))); + } + + // -------------------------------------------------------------------------- + // toml::find(toml::value, toml::key, Ts&& ... keys) + + namespace detail { + + // It suppresses warnings by -Wsign-conversion when we pass integer literal + // to toml::find. integer literal `0` is deduced as an int, and will be + // converted to std::size_t. This causes sign-conversion. + + template + std::size_t key_cast(const std::size_t& v) noexcept { + return v; + } + + template + cxx::enable_if_t>::value, std::size_t> key_cast( + const T& v) noexcept { + return static_cast(v); + } + + // for string-like (string, string literal, string_view) + + template + const typename basic_value::key_type& key_cast( + const typename basic_value::key_type& v) noexcept { + return v; + } + + template + typename basic_value::key_type key_cast( + const typename basic_value::key_type::value_type* v) { + return typename basic_value::key_type(v); + } +#if defined(TOML11_HAS_STRING_VIEW) + template + typename basic_value::key_type key_cast(const std::string_view v) { + return typename basic_value::key_type(v); + } +#endif // string_view + + } // namespace detail + + // ---------------------------------------------------------------------------- + // find(v, keys...) + + template + const cxx::enable_if_t::value, basic_value>& + find(const basic_value& v, const K1& k1, const K2& k2, const Ks&... 
ks) { + return find(v.at(detail::key_cast(k1)), detail::key_cast(k2), ks...); + } + + template + cxx::enable_if_t::value, basic_value>& find( + basic_value& v, + const K1& k1, + const K2& k2, + const Ks&... ks) { + return find(v.at(detail::key_cast(k1)), detail::key_cast(k2), ks...); + } + + template + cxx::enable_if_t::value, basic_value> find( + basic_value&& v, + const K1& k1, + const K2& k2, + const Ks&... ks) { + return find(std::move(v.at(detail::key_cast(k1))), + detail::key_cast(k2), + ks...); + } + + // ---------------------------------------------------------------------------- + // find(v, keys...) + + template + decltype(::toml::get(std::declval&>())) find( + const basic_value& v, + const K1& k1, + const K2& k2, + const Ks&... ks) { + return find(v.at(detail::key_cast(k1)), detail::key_cast(k2), ks...); + } + + template + decltype(::toml::get(std::declval&>())) find( + basic_value& v, + const K1& k1, + const K2& k2, + const Ks&... ks) { + return find(v.at(detail::key_cast(k1)), detail::key_cast(k2), ks...); + } + + template + decltype(::toml::get(std::declval&&>())) find( + basic_value&& v, + const K1& k1, + const K2& k2, + const Ks&... ks) { + return find(std::move(v.at(detail::key_cast(k1))), + detail::key_cast(k2), + ks...); + } + + // =========================================================================== + // find_or(value, key, fallback) + + // --------------------------------------------------------------------------- + // find_or(v, key, other_v) + + template + cxx::enable_if_t::value, basic_value>& find_or( + basic_value& v, + const K& k, + basic_value& opt) noexcept { + try { + return ::toml::find(v, detail::key_cast(k)); + } catch (...) { + return opt; + } + } + + template + const cxx::enable_if_t::value, basic_value>& find_or( + const basic_value& v, + const K& k, + const basic_value& opt) noexcept { + try { + return ::toml::find(v, detail::key_cast(k)); + } catch (...) 
{ + return opt; + } + } + + template + cxx::enable_if_t::value, basic_value> find_or( + basic_value&& v, + const K& k, + basic_value&& opt) noexcept { + try { + return ::toml::find(v, detail::key_cast(k)); + } catch (...) { + return opt; + } + } + + // --------------------------------------------------------------------------- + // toml types (return type can be a reference) + + template + cxx::enable_if_t>::value, + const cxx::remove_cvref_t&> + find_or(const basic_value& v, const K& k, const T& opt) { + try { + return ::toml::get(v.at(detail::key_cast(k))); + } catch (...) { + return opt; + } + } + + template + cxx::enable_if_t>, + detail::is_exact_toml_type>>::value, + cxx::remove_cvref_t&> + find_or(basic_value& v, const K& k, T& opt) { + try { + return ::toml::get(v.at(detail::key_cast(k))); + } catch (...) { + return opt; + } + } + + template + cxx::enable_if_t>::value, + cxx::remove_cvref_t> + find_or(basic_value&& v, const K& k, T opt) { + try { + return ::toml::get(std::move(v.at(detail::key_cast(k)))); + } catch (...) { + return T(std::move(opt)); + } + } + + // --------------------------------------------------------------------------- + // string literal (deduced as std::string) + + // XXX to avoid confusion when T is explicitly specified in find_or(), + // we restrict the string type as std::string. + template + cxx::enable_if_t::value, std::string> find_or( + const basic_value& v, + const K& k, + const char* opt) { + try { + return ::toml::get(v.at(detail::key_cast(k))); + } catch (...) 
{ + return std::string(opt); + } + } + + // --------------------------------------------------------------------------- + // other types (requires type conversion and return type cannot be a reference) + + template + cxx::enable_if_t< + cxx::conjunction< + cxx::negation>>, + detail::is_not_toml_type, basic_value>, + cxx::negation, + const typename basic_value::string_type::value_type*>>>::value, + cxx::remove_cvref_t> + find_or(const basic_value& v, const K& ky, T opt) { + try { + return ::toml::get>(v.at(detail::key_cast(ky))); + } catch (...) { + return cxx::remove_cvref_t(std::move(opt)); + } + } + + // ---------------------------------------------------------------------------- + // recursive + + namespace detail { + + template + auto last_one(Ts&&... args) -> decltype(std::get( + std::forward_as_tuple(std::forward(args)...))) { + return std::get( + std::forward_as_tuple(std::forward(args)...)); + } + + } // namespace detail + + template + auto find_or(Value&& v, const K1& k1, const K2& k2, K3&& k3, Ks&&... keys) noexcept + -> cxx::enable_if_t< + detail::is_basic_value>::value, + decltype(find_or(v, k2, std::forward(k3), std::forward(keys)...))> { + try { + return find_or(v.at(k1), k2, std::forward(k3), std::forward(keys)...); + } catch (...) { + return detail::last_one(k3, keys...); + } + } + + template + T find_or(const basic_value& v, + const K1& k1, + const K2& k2, + const K3& k3, + const Ks&... keys) noexcept { + try { + return find_or(v.at(k1), k2, k3, keys...); + } catch (...) 
{ + return static_cast(detail::last_one(k3, keys...)); + } + } + +} // namespace toml +#endif // TOML11_FIND_HPP +#ifndef TOML11_CONVERSION_HPP +#define TOML11_CONVERSION_HPP + +#if defined(TOML11_HAS_OPTIONAL) + + #include + +namespace toml { + namespace detail { + + template + inline constexpr bool is_optional_v = false; + + template + inline constexpr bool is_optional_v> = true; + + template + void find_member_variable_from_value(T& obj, + const basic_value& v, + const char* var_name) { + if constexpr (is_optional_v) { + if (v.contains(var_name)) { + obj = toml::find(v, var_name); + } else { + obj = std::nullopt; + } + } else { + obj = toml::find(v, var_name); + } + } + + template + void assign_member_variable_to_value(const T& obj, + basic_value& v, + const char* var_name) { + if constexpr (is_optional_v) { + if (obj.has_value()) { + v[var_name] = obj.value(); + } + } else { + v[var_name] = obj; + } + } + + } // namespace detail +} // namespace toml + +#else + +namespace toml { + namespace detail { + + template + void find_member_variable_from_value(T& obj, + const basic_value& v, + const char* var_name) { + obj = toml::find(v, var_name); + } + + template + void assign_member_variable_to_value(const T& obj, + basic_value& v, + const char* var_name) { + v[var_name] = obj; + } + + } // namespace detail +} // namespace toml + +#endif // optional + +// use it in the following way. 
+// ```cpp +// namespace foo +// { +// struct Foo +// { +// std::string s; +// double d; +// int i; +// }; +// } // foo +// +// TOML11_DEFINE_CONVERSION_NON_INTRUSIVE(foo::Foo, s, d, i) +// ``` +// +// And then you can use `toml::get(v)` and `toml::find(file, "foo");` +// + +#define TOML11_STRINGIZE_AUX(x) #x +#define TOML11_STRINGIZE(x) TOML11_STRINGIZE_AUX(x) + +#define TOML11_CONCATENATE_AUX(x, y) x##y +#define TOML11_CONCATENATE(x, y) TOML11_CONCATENATE_AUX(x, y) + +// ============================================================================ +// TOML11_DEFINE_CONVERSION_NON_INTRUSIVE + +#ifndef TOML11_WITHOUT_DEFINE_NON_INTRUSIVE + + // ---------------------------------------------------------------------------- + // TOML11_ARGS_SIZE + + #define TOML11_INDEX_RSEQ() \ + 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, \ + 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + #define TOML11_ARGS_SIZE_IMPL(ARG1, \ + ARG2, \ + ARG3, \ + ARG4, \ + ARG5, \ + ARG6, \ + ARG7, \ + ARG8, \ + ARG9, \ + ARG10, \ + ARG11, \ + ARG12, \ + ARG13, \ + ARG14, \ + ARG15, \ + ARG16, \ + ARG17, \ + ARG18, \ + ARG19, \ + ARG20, \ + ARG21, \ + ARG22, \ + ARG23, \ + ARG24, \ + ARG25, \ + ARG26, \ + ARG27, \ + ARG28, \ + ARG29, \ + ARG30, \ + ARG31, \ + ARG32, \ + N, \ + ...) \ + N + #define TOML11_ARGS_SIZE_AUX(...) TOML11_ARGS_SIZE_IMPL(__VA_ARGS__) + #define TOML11_ARGS_SIZE(...) \ + TOML11_ARGS_SIZE_AUX(__VA_ARGS__, TOML11_INDEX_RSEQ()) + + // ---------------------------------------------------------------------------- + // TOML11_FOR_EACH_VA_ARGS + + #define TOML11_FOR_EACH_VA_ARGS_AUX_1(FUNCTOR, ARG1) FUNCTOR(ARG1) + #define TOML11_FOR_EACH_VA_ARGS_AUX_2(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_1(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_3(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_2(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_4(FUNCTOR, ARG1, ...) 
\ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_3(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_5(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_4(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_6(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_5(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_7(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_6(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_8(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_7(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_9(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_8(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_10(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_9(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_11(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_10(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_12(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_11(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_13(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_12(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_14(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_13(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_15(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_14(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_16(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_15(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_17(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_16(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_18(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_17(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_19(FUNCTOR, ARG1, ...) 
\ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_18(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_20(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_19(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_21(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_20(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_22(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_21(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_23(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_22(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_24(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_23(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_25(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_24(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_26(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_25(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_27(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_26(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_28(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_27(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_29(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_28(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_30(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_29(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_31(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_30(FUNCTOR, __VA_ARGS__) + #define TOML11_FOR_EACH_VA_ARGS_AUX_32(FUNCTOR, ARG1, ...) \ + FUNCTOR(ARG1) TOML11_FOR_EACH_VA_ARGS_AUX_31(FUNCTOR, __VA_ARGS__) + + #define TOML11_FOR_EACH_VA_ARGS(FUNCTOR, ...) 
\ + TOML11_CONCATENATE(TOML11_FOR_EACH_VA_ARGS_AUX_, \ + TOML11_ARGS_SIZE(__VA_ARGS__)) \ + (FUNCTOR, __VA_ARGS__) + + #define TOML11_FIND_MEMBER_VARIABLE_FROM_VALUE(VAR_NAME) \ + toml::detail::find_member_variable_from_value(obj.VAR_NAME, \ + v, \ + TOML11_STRINGIZE(VAR_NAME)); + + #define TOML11_ASSIGN_MEMBER_VARIABLE_TO_VALUE(VAR_NAME) \ + toml::detail::assign_member_variable_to_value(obj.VAR_NAME, \ + v, \ + TOML11_STRINGIZE(VAR_NAME)); + + #define TOML11_DEFINE_CONVERSION_NON_INTRUSIVE(NAME, ...) \ + namespace toml { \ + template <> \ + struct from { \ + template \ + static NAME from_toml(const basic_value& v) { \ + NAME obj; \ + TOML11_FOR_EACH_VA_ARGS(TOML11_FIND_MEMBER_VARIABLE_FROM_VALUE, \ + __VA_ARGS__) \ + return obj; \ + } \ + }; \ + template <> \ + struct into { \ + template \ + static basic_value into_toml(const NAME& obj) { \ + ::toml::basic_value v = typename ::toml::basic_value::table_type {}; \ + TOML11_FOR_EACH_VA_ARGS(TOML11_ASSIGN_MEMBER_VARIABLE_TO_VALUE, \ + __VA_ARGS__) \ + return v; \ + } \ + }; \ + } /* toml */ + +#endif // TOML11_WITHOUT_DEFINE_NON_INTRUSIVE + +#endif // TOML11_CONVERSION_HPP +#ifndef TOML11_CONTEXT_HPP +#define TOML11_CONTEXT_HPP + +#include + +namespace toml { + namespace detail { + + template + class context { + public: + explicit context(const spec& toml_spec) + : toml_spec_(toml_spec) + , errors_ {} {} + + bool has_error() const noexcept { + return !errors_.empty(); + } + + const std::vector& errors() const noexcept { + return errors_; + } + + semantic_version& toml_version() noexcept { + return toml_spec_.version; + } + + const semantic_version& toml_version() const noexcept { + return toml_spec_.version; + } + + spec& toml_spec() noexcept { + return toml_spec_; + } + + const spec& toml_spec() const noexcept { + return toml_spec_; + } + + void report_error(error_info err) { + this->errors_.push_back(std::move(err)); + } + + error_info pop_last_error() { + assert(!errors_.empty()); + auto e = 
std::move(errors_.back()); + errors_.pop_back(); + return e; + } + + private: + spec toml_spec_; + std::vector errors_; + }; + + } // namespace detail +} // namespace toml + +#if defined(TOML11_COMPILE_SOURCES) +namespace toml { + struct type_config; + struct ordered_type_config; + + namespace detail { + extern template class context<::toml::type_config>; + extern template class context<::toml::ordered_type_config>; + } // namespace detail +} // namespace toml +#endif // TOML11_COMPILE_SOURCES + +#endif // TOML11_CONTEXT_HPP +#ifndef TOML11_SCANNER_HPP +#define TOML11_SCANNER_HPP + +#ifndef TOML11_SCANNER_FWD_HPP + #define TOML11_SCANNER_FWD_HPP + + #include + #include + #include + #include + #include + #include + #include + +namespace toml { + namespace detail { + + class scanner_base { + public: + virtual ~scanner_base() = default; + virtual region scan(location& loc) const = 0; + virtual scanner_base* clone() const = 0; + + // returns expected character or set of characters or literal. + // to show the error location, it changes loc (in `sequence`, especially). 
+ virtual std::string expected_chars(location& loc) const = 0; + virtual std::string name() const = 0; + }; + + // make `scanner*` copyable + struct scanner_storage { + template >::value, + std::nullptr_t> = nullptr> + explicit scanner_storage(Scanner&& s) + : scanner_(cxx::make_unique>( + std::forward(s))) {} + + ~scanner_storage() = default; + + scanner_storage(const scanner_storage& other); + scanner_storage& operator=(const scanner_storage& other); + scanner_storage(scanner_storage&&) = default; + scanner_storage& operator=(scanner_storage&&) = default; + + bool is_ok() const noexcept { + return static_cast(scanner_); + } + + region scan(location& loc) const; + + std::string expected_chars(location& loc) const; + + scanner_base& get() const noexcept; + + std::string name() const; + + private: + std::unique_ptr scanner_; + }; + + // ---------------------------------------------------------------------------- + + class character final : public scanner_base { + public: + using char_type = location::char_type; + + public: + explicit character(const char_type c) noexcept : value_(c) {} + + ~character() override = default; + + region scan(location& loc) const override; + + std::string expected_chars(location&) const override; + + scanner_base* clone() const override; + + std::string name() const override; + + private: + char_type value_; + }; + + // ---------------------------------------------------------------------------- + + class character_either final : public scanner_base { + public: + using char_type = location::char_type; + + public: + explicit character_either(std::initializer_list cs) noexcept + : chars_(std::move(cs)) { + assert(!this->chars_.empty()); + } + + template + explicit character_either(const char (&cs)[N]) noexcept + : chars_(N - 1, '\0') { + static_assert(N >= 1, ""); + for (std::size_t i = 0; i + 1 < N; ++i) { + chars_.at(i) = char_type(cs[i]); + } + } + + ~character_either() override = default; + + region scan(location& loc) const override; 
+ + std::string expected_chars(location&) const override; + + scanner_base* clone() const override; + + void push_back(const char_type c); + + std::string name() const override; + + private: + std::vector chars_; + }; + + // ---------------------------------------------------------------------------- + + class character_in_range final : public scanner_base { + public: + using char_type = location::char_type; + + public: + explicit character_in_range(const char_type from, const char_type to) noexcept + : from_(from) + , to_(to) {} + + ~character_in_range() override = default; + + region scan(location& loc) const override; + + std::string expected_chars(location&) const override; + + scanner_base* clone() const override; + + std::string name() const override; + + private: + char_type from_; + char_type to_; + }; + + // ---------------------------------------------------------------------------- + + class literal final : public scanner_base { + public: + using char_type = location::char_type; + + public: + template + explicit literal(const char (&cs)[N]) noexcept + : value_(cs) + , size_(N - 1) // remove null character at the end + {} + + ~literal() override = default; + + region scan(location& loc) const override; + + std::string expected_chars(location&) const override; + + scanner_base* clone() const override; + + std::string name() const override; + + private: + const char* value_; + std::size_t size_; + }; + + // ---------------------------------------------------------------------------- + + class sequence final : public scanner_base { + public: + using char_type = location::char_type; + + public: + template + explicit sequence(Ts&&... 
args) { + push_back_all(std::forward(args)...); + } + + sequence(const sequence&) = default; + sequence(sequence&&) = default; + sequence& operator=(const sequence&) = default; + sequence& operator=(sequence&&) = default; + ~sequence() override = default; + + region scan(location& loc) const override; + + std::string expected_chars(location& loc) const override; + + scanner_base* clone() const override; + + template + void push_back(Scanner&& other_scanner) { + this->others_.emplace_back(std::forward(other_scanner)); + } + + std::string name() const override; + + private: + void push_back_all() { + return; + } + + template + void push_back_all(T&& head, Ts&&... args) { + others_.emplace_back(std::forward(head)); + push_back_all(std::forward(args)...); + return; + } + + private: + std::vector others_; + }; + + // ---------------------------------------------------------------------------- + + class either final : public scanner_base { + public: + using char_type = location::char_type; + + public: + template + explicit either(Ts&&... args) { + push_back_all(std::forward(args)...); + } + + either(const either&) = default; + either(either&&) = default; + either& operator=(const either&) = default; + either& operator=(either&&) = default; + ~either() override = default; + + region scan(location& loc) const override; + + std::string expected_chars(location& loc) const override; + + scanner_base* clone() const override; + + template + void push_back(Scanner&& other_scanner) { + this->others_.emplace_back(std::forward(other_scanner)); + } + + std::string name() const override; + + private: + void push_back_all() { + return; + } + + template + void push_back_all(T&& head, Ts&&... 
args) { + others_.emplace_back(std::forward(head)); + push_back_all(std::forward(args)...); + return; + } + + private: + std::vector others_; + }; + + // ---------------------------------------------------------------------------- + + class repeat_exact final : public scanner_base { + public: + using char_type = location::char_type; + + public: + template + repeat_exact(const std::size_t length, Scanner&& other) + : length_(length) + , other_(std::forward(other)) {} + + repeat_exact(const repeat_exact&) = default; + repeat_exact(repeat_exact&&) = default; + repeat_exact& operator=(const repeat_exact&) = default; + repeat_exact& operator=(repeat_exact&&) = default; + ~repeat_exact() override = default; + + region scan(location& loc) const override; + + std::string expected_chars(location& loc) const override; + + scanner_base* clone() const override; + + std::string name() const override; + + private: + std::size_t length_; + scanner_storage other_; + }; + + // ---------------------------------------------------------------------------- + + class repeat_at_least final : public scanner_base { + public: + using char_type = location::char_type; + + public: + template + repeat_at_least(const std::size_t length, Scanner&& s) + : length_(length) + , other_(std::forward(s)) {} + + repeat_at_least(const repeat_at_least&) = default; + repeat_at_least(repeat_at_least&&) = default; + repeat_at_least& operator=(const repeat_at_least&) = default; + repeat_at_least& operator=(repeat_at_least&&) = default; + ~repeat_at_least() override = default; + + region scan(location& loc) const override; + + std::string expected_chars(location& loc) const override; + + scanner_base* clone() const override; + + std::string name() const override; + + private: + std::size_t length_; + scanner_storage other_; + }; + + // ---------------------------------------------------------------------------- + + class maybe final : public scanner_base { + public: + using char_type = location::char_type; + + 
public: + template + explicit maybe(Scanner&& s) : other_(std::forward(s)) {} + + maybe(const maybe&) = default; + maybe(maybe&&) = default; + maybe& operator=(const maybe&) = default; + maybe& operator=(maybe&&) = default; + ~maybe() override = default; + + region scan(location& loc) const override; + + std::string expected_chars(location&) const override; + + scanner_base* clone() const override; + + std::string name() const override; + + private: + scanner_storage other_; + }; + + } // namespace detail +} // namespace toml +#endif // TOML11_SCANNER_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_SCANNER_IMPL_HPP + #define TOML11_SCANNER_IMPL_HPP + +namespace toml { + namespace detail { + + TOML11_INLINE scanner_storage::scanner_storage(const scanner_storage& other) + : scanner_(nullptr) { + if (other.is_ok()) { + scanner_.reset(other.get().clone()); + } + } + + TOML11_INLINE scanner_storage& scanner_storage::operator=( + const scanner_storage& other) { + if (this == std::addressof(other)) { + return *this; + } + if (other.is_ok()) { + scanner_.reset(other.get().clone()); + } + return *this; + } + + TOML11_INLINE region scanner_storage::scan(location& loc) const { + assert(this->is_ok()); + return this->scanner_->scan(loc); + } + + TOML11_INLINE std::string scanner_storage::expected_chars(location& loc) const { + assert(this->is_ok()); + return this->scanner_->expected_chars(loc); + } + + TOML11_INLINE scanner_base& scanner_storage::get() const noexcept { + assert(this->is_ok()); + return *scanner_; + } + + TOML11_INLINE std::string scanner_storage::name() const { + assert(this->is_ok()); + return this->scanner_->name(); + } + + // ---------------------------------------------------------------------------- + + TOML11_INLINE region character::scan(location& loc) const { + if (loc.eof()) { + return region {}; + } + + if (loc.current() == this->value_) { + const auto first = loc; + loc.advance(1); + return region(first, loc); + } + return region 
{}; + } + + TOML11_INLINE std::string character::expected_chars(location&) const { + return show_char(value_); + } + + TOML11_INLINE scanner_base* character::clone() const { + return new character(*this); + } + + TOML11_INLINE std::string character::name() const { + return "character{" + show_char(value_) + "}"; + } + + // ---------------------------------------------------------------------------- + + TOML11_INLINE region character_either::scan(location& loc) const { + if (loc.eof()) { + return region {}; + } + + for (const auto c : this->chars_) { + if (loc.current() == c) { + const auto first = loc; + loc.advance(1); + return region(first, loc); + } + } + return region {}; + } + + TOML11_INLINE std::string character_either::expected_chars(location&) const { + assert(!chars_.empty()); + + std::string expected; + if (chars_.size() == 1) { + expected += show_char(chars_.at(0)); + } else if (chars_.size() == 2) { + expected += show_char(chars_.at(0)) + " or " + show_char(chars_.at(1)); + } else { + for (std::size_t i = 0; i < chars_.size(); ++i) { + if (i != 0) { + expected += ", "; + } + if (i + 1 == chars_.size()) { + expected += "or "; + } + expected += show_char(chars_.at(i)); + } + } + return expected; + } + + TOML11_INLINE scanner_base* character_either::clone() const { + return new character_either(*this); + } + + TOML11_INLINE void character_either::push_back(const char_type c) { + chars_.push_back(c); + } + + TOML11_INLINE std::string character_either::name() const { + std::string n("character_either{"); + for (const auto c : this->chars_) { + n += show_char(c); + n += ", "; + } + if (!this->chars_.empty()) { + n.pop_back(); + n.pop_back(); + } + n += "}"; + return n; + } + + // ---------------------------------------------------------------------------- + // character_in_range + + TOML11_INLINE region character_in_range::scan(location& loc) const { + if (loc.eof()) { + return region {}; + } + + const auto curr = loc.current(); + if (this->from_ <= curr && 
curr <= this->to_) { + const auto first = loc; + loc.advance(1); + return region(first, loc); + } + return region {}; + } + + TOML11_INLINE std::string character_in_range::expected_chars(location&) const { + std::string expected("from `"); + expected += show_char(from_); + expected += "` to `"; + expected += show_char(to_); + expected += "`"; + return expected; + } + + TOML11_INLINE scanner_base* character_in_range::clone() const { + return new character_in_range(*this); + } + + TOML11_INLINE std::string character_in_range::name() const { + return "character_in_range{" + show_char(from_) + "," + show_char(to_) + "}"; + } + + // ---------------------------------------------------------------------------- + // literal + + TOML11_INLINE region literal::scan(location& loc) const { + const auto first = loc; + for (std::size_t i = 0; i < size_; ++i) { + if (loc.eof() || char_type(value_[i]) != loc.current()) { + loc = first; + return region {}; + } + loc.advance(1); + } + return region(first, loc); + } + + TOML11_INLINE std::string literal::expected_chars(location&) const { + return std::string(value_); + } + + TOML11_INLINE scanner_base* literal::clone() const { + return new literal(*this); + } + + TOML11_INLINE std::string literal::name() const { + return std::string("literal{") + std::string(value_, size_) + "}"; + } + + // ---------------------------------------------------------------------------- + // sequence + + TOML11_INLINE region sequence::scan(location& loc) const { + const auto first = loc; + for (const auto& other : others_) { + const auto reg = other.scan(loc); + if (!reg.is_ok()) { + loc = first; + return region {}; + } + } + return region(first, loc); + } + + TOML11_INLINE std::string sequence::expected_chars(location& loc) const { + const auto first = loc; + for (const auto& other : others_) { + const auto reg = other.scan(loc); + if (!reg.is_ok()) { + return other.expected_chars(loc); + } + } + assert(false); + return ""; // XXX + } + + TOML11_INLINE 
scanner_base* sequence::clone() const { + return new sequence(*this); + } + + TOML11_INLINE std::string sequence::name() const { + std::string n("sequence{"); + for (const auto& other : others_) { + n += other.name(); + n += ", "; + } + if (!this->others_.empty()) { + n.pop_back(); + n.pop_back(); + } + n += "}"; + return n; + } + + // ---------------------------------------------------------------------------- + // either + + TOML11_INLINE region either::scan(location& loc) const { + for (const auto& other : others_) { + const auto reg = other.scan(loc); + if (reg.is_ok()) { + return reg; + } + } + return region {}; + } + + TOML11_INLINE std::string either::expected_chars(location& loc) const { + assert(!others_.empty()); + + std::string expected = others_.at(0).expected_chars(loc); + if (others_.size() == 2) { + expected += " or "; + expected += others_.at(1).expected_chars(loc); + } else { + for (std::size_t i = 1; i < others_.size(); ++i) { + expected += ", "; + if (i + 1 == others_.size()) { + expected += "or "; + } + expected += others_.at(i).expected_chars(loc); + } + } + return expected; + } + + TOML11_INLINE scanner_base* either::clone() const { + return new either(*this); + } + + TOML11_INLINE std::string either::name() const { + std::string n("either{"); + for (const auto& other : others_) { + n += other.name(); + n += ", "; + } + if (!this->others_.empty()) { + n.pop_back(); + n.pop_back(); + } + n += "}"; + return n; + } + + // ---------------------------------------------------------------------------- + // repeat_exact + + TOML11_INLINE region repeat_exact::scan(location& loc) const { + const auto first = loc; + for (std::size_t i = 0; i < length_; ++i) { + const auto reg = other_.scan(loc); + if (!reg.is_ok()) { + loc = first; + return region {}; + } + } + return region(first, loc); + } + + TOML11_INLINE std::string repeat_exact::expected_chars(location& loc) const { + for (std::size_t i = 0; i < length_; ++i) { + const auto reg = other_.scan(loc); 
+ if (!reg.is_ok()) { + return other_.expected_chars(loc); + } + } + assert(false); + return ""; + } + + TOML11_INLINE scanner_base* repeat_exact::clone() const { + return new repeat_exact(*this); + } + + TOML11_INLINE std::string repeat_exact::name() const { + return "repeat_exact{" + std::to_string(length_) + ", " + other_.name() + "}"; + } + + // ---------------------------------------------------------------------------- + // repeat_at_least + + TOML11_INLINE region repeat_at_least::scan(location& loc) const { + const auto first = loc; + for (std::size_t i = 0; i < length_; ++i) { + const auto reg = other_.scan(loc); + if (!reg.is_ok()) { + loc = first; + return region {}; + } + } + while (!loc.eof()) { + const auto checkpoint = loc; + const auto reg = other_.scan(loc); + if (!reg.is_ok()) { + loc = checkpoint; + return region(first, loc); + } + } + return region(first, loc); + } + + TOML11_INLINE std::string repeat_at_least::expected_chars(location& loc) const { + for (std::size_t i = 0; i < length_; ++i) { + const auto reg = other_.scan(loc); + if (!reg.is_ok()) { + return other_.expected_chars(loc); + } + } + assert(false); + return ""; + } + + TOML11_INLINE scanner_base* repeat_at_least::clone() const { + return new repeat_at_least(*this); + } + + TOML11_INLINE std::string repeat_at_least::name() const { + return "repeat_at_least{" + std::to_string(length_) + ", " + + other_.name() + "}"; + } + + // ---------------------------------------------------------------------------- + // maybe + + TOML11_INLINE region maybe::scan(location& loc) const { + const auto first = loc; + const auto reg = other_.scan(loc); + if (!reg.is_ok()) { + loc = first; + } + return region(first, loc); + } + + TOML11_INLINE std::string maybe::expected_chars(location&) const { + return ""; + } + + TOML11_INLINE scanner_base* maybe::clone() const { + return new maybe(*this); + } + + TOML11_INLINE std::string maybe::name() const { + return "maybe{" + other_.name() + "}"; + } + + } // 
namespace detail +} // namespace toml + #endif // TOML11_SCANNER_IMPL_HPP +#endif + +#endif // TOML11_SCANNER_HPP +#ifndef TOML11_SYNTAX_HPP +#define TOML11_SYNTAX_HPP + +#ifndef TOML11_SYNTAX_FWD_HPP + #define TOML11_SYNTAX_FWD_HPP + +namespace toml { + namespace detail { + namespace syntax { + + using char_type = location::char_type; + + // =========================================================================== + // UTF-8 + + // avoid redundant representation and out-of-unicode sequence + + character_in_range utf8_1byte(const spec&); + sequence utf8_2bytes(const spec&); + sequence utf8_3bytes(const spec&); + sequence utf8_4bytes(const spec&); + + class non_ascii final : public scanner_base { + public: + using char_type = location::char_type; + + public: + explicit non_ascii(const spec& s) noexcept; + ~non_ascii() override = default; + + region scan(location& loc) const override { + return scanner_.scan(loc); + } + + std::string expected_chars(location&) const override { + return "non-ascii utf-8 bytes"; + } + + scanner_base* clone() const override { + return new non_ascii(*this); + } + + std::string name() const override { + return "non_ascii"; + } + + private: + either scanner_; + }; + + // =========================================================================== + // Whitespace + + character_either wschar(const spec&); + + repeat_at_least ws(const spec& s); + + // =========================================================================== + // Newline + + either newline(const spec&); + + // =========================================================================== + // Comments + + either allowed_comment_char(const spec& s); + + // XXX Note that it does not take newline + sequence comment(const spec& s); + + // =========================================================================== + // Boolean + + either boolean(const spec&); + + // =========================================================================== + // Integer + + class digit final : 
public scanner_base { + public: + using char_type = location::char_type; + + public: + explicit digit(const spec&) noexcept; + ~digit() override = default; + + region scan(location& loc) const override { + return scanner_.scan(loc); + } + + std::string expected_chars(location&) const override { + return "digit [0-9]"; + } + + scanner_base* clone() const override { + return new digit(*this); + } + + std::string name() const override { + return "digit"; + } + + private: + character_in_range scanner_; + }; + + class alpha final : public scanner_base { + public: + using char_type = location::char_type; + + public: + explicit alpha(const spec&) noexcept; + ~alpha() override = default; + + region scan(location& loc) const override { + return scanner_.scan(loc); + } + + std::string expected_chars(location&) const override { + return "alpha [a-zA-Z]"; + } + + scanner_base* clone() const override { + return new alpha(*this); + } + + std::string name() const override { + return "alpha"; + } + + private: + either scanner_; + }; + + class hexdig final : public scanner_base { + public: + using char_type = location::char_type; + + public: + explicit hexdig(const spec& s) noexcept; + ~hexdig() override = default; + + region scan(location& loc) const override { + return scanner_.scan(loc); + } + + std::string expected_chars(location&) const override { + return "hex [0-9a-fA-F]"; + } + + scanner_base* clone() const override { + return new hexdig(*this); + } + + std::string name() const override { + return "hexdig"; + } + + private: + either scanner_; + }; + + sequence num_suffix(const spec& s); + + sequence dec_int(const spec& s); + sequence hex_int(const spec& s); + sequence oct_int(const spec&); + sequence bin_int(const spec&); + either integer(const spec& s); + + // =========================================================================== + // Floating + + sequence zero_prefixable_int(const spec& s); + sequence fractional_part(const spec& s); + sequence exponent_part(const 
spec& s); + sequence hex_floating(const spec& s); + either floating(const spec& s); + + // =========================================================================== + // Datetime + + sequence local_date(const spec& s); + sequence local_time(const spec& s); + either time_offset(const spec& s); + sequence full_time(const spec& s); + character_either time_delim(const spec&); + sequence local_datetime(const spec& s); + sequence offset_datetime(const spec& s); + + // =========================================================================== + // String + + sequence escaped(const spec& s); + + either basic_char(const spec& s); + + sequence basic_string(const spec& s); + + // --------------------------------------------------------------------------- + // multiline string + + sequence escaped_newline(const spec& s); + sequence ml_basic_string(const spec& s); + + // --------------------------------------------------------------------------- + // literal string + + either literal_char(const spec& s); + sequence literal_string(const spec& s); + + sequence ml_literal_string(const spec& s); + + either string(const spec& s); + + // =========================================================================== + // Keys + + // to keep `expected_chars` simple + class non_ascii_key_char final : public scanner_base { + public: + using char_type = location::char_type; + + private: + using in_range = character_in_range; // make definition short + + public: + explicit non_ascii_key_char(const spec& s) noexcept; + ~non_ascii_key_char() override = default; + + region scan(location& loc) const override; + + std::string expected_chars(location&) const override { + return "bare key non-ASCII script"; + } + + scanner_base* clone() const override { + return new non_ascii_key_char(*this); + } + + std::string name() const override { + return "non-ASCII bare key"; + } + + private: + std::uint32_t read_utf8(location& loc) const; + }; + + repeat_at_least unquoted_key(const spec& s); + + either 
quoted_key(const spec& s); + + either simple_key(const spec& s); + + sequence dot_sep(const spec& s); + + sequence dotted_key(const spec& s); + + class key final : public scanner_base { + public: + using char_type = location::char_type; + + public: + explicit key(const spec& s) noexcept; + ~key() override = default; + + region scan(location& loc) const override { + return scanner_.scan(loc); + } + + std::string expected_chars(location&) const override { + return "basic key([a-zA-Z0-9_-]) or quoted key(\" or ')"; + } + + scanner_base* clone() const override { + return new key(*this); + } + + std::string name() const override { + return "key"; + } + + private: + either scanner_; + }; + + sequence keyval_sep(const spec& s); + + // =========================================================================== + // Table key + + sequence std_table(const spec& s); + + sequence array_table(const spec& s); + + // =========================================================================== + // extension: null + + literal null_value(const spec&); + + } // namespace syntax + } // namespace detail +} // namespace toml +#endif // TOML11_SYNTAX_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_SYNTAX_IMPL_HPP + #define TOML11_SYNTAX_IMPL_HPP + +namespace toml { + namespace detail { + namespace syntax { + + using char_type = location::char_type; + + // =========================================================================== + // UTF-8 + + // avoid redundant representation and out-of-unicode sequence + + TOML11_INLINE character_in_range utf8_1byte(const spec&) { + return character_in_range(0x00, 0x7F); + } + + TOML11_INLINE sequence utf8_2bytes(const spec&) { + return sequence(character_in_range(0xC2, 0xDF), + character_in_range(0x80, 0xBF)); + } + + TOML11_INLINE sequence utf8_3bytes(const spec&) { + return sequence( + /*1~2 bytes = */ either( + sequence(character(0xE0), character_in_range(0xA0, 0xBF)), + sequence(character_in_range(0xE1, 0xEC), 
character_in_range(0x80, 0xBF)), + sequence(character(0xED), character_in_range(0x80, 0x9F)), + sequence(character_in_range(0xEE, 0xEF), + character_in_range(0x80, 0xBF))), + /*3rd byte = */ character_in_range(0x80, 0xBF)); + } + + TOML11_INLINE sequence utf8_4bytes(const spec&) { + return sequence( + /*1~2 bytes = */ either( + sequence(character(0xF0), character_in_range(0x90, 0xBF)), + sequence(character_in_range(0xF1, 0xF3), character_in_range(0x80, 0xBF)), + sequence(character(0xF4), character_in_range(0x80, 0x8F))), + character_in_range(0x80, 0xBF), + character_in_range(0x80, 0xBF)); + } + + TOML11_INLINE non_ascii::non_ascii(const spec& s) noexcept + : scanner_(utf8_2bytes(s), utf8_3bytes(s), utf8_4bytes(s)) {} + + // =========================================================================== + // Whitespace + + TOML11_INLINE character_either wschar(const spec&) { + return character_either { char_type(' '), char_type('\t') }; + } + + TOML11_INLINE repeat_at_least ws(const spec& s) { + return repeat_at_least(0, wschar(s)); + } + + // =========================================================================== + // Newline + + TOML11_INLINE either newline(const spec&) { + return either(character(char_type('\n')), literal("\r\n")); + } + + // =========================================================================== + // Comments + + TOML11_INLINE either allowed_comment_char(const spec& s) { + if (s.v1_1_0_allow_control_characters_in_comments) { + return either(character_in_range(0x01, 0x09), + character_in_range(0x0E, 0x7F), + non_ascii(s)); + } else { + return either(character(0x09), + character_in_range(0x20, 0x7E), + non_ascii(s)); + } + } + + // XXX Note that it does not take newline + TOML11_INLINE sequence comment(const spec& s) { + return sequence(character(char_type('#')), + repeat_at_least(0, allowed_comment_char(s))); + } + + // =========================================================================== + // Boolean + + TOML11_INLINE either 
boolean(const spec&) { + return either(literal("true"), literal("false")); + } + + // =========================================================================== + // Integer + + TOML11_INLINE digit::digit(const spec&) noexcept + : scanner_(char_type('0'), char_type('9')) {} + + TOML11_INLINE alpha::alpha(const spec&) noexcept + : scanner_(character_in_range(char_type('a'), char_type('z')), + character_in_range(char_type('A'), char_type('Z'))) {} + + TOML11_INLINE hexdig::hexdig(const spec& s) noexcept + : scanner_(digit(s), + character_in_range(char_type('a'), char_type('f')), + character_in_range(char_type('A'), char_type('F'))) {} + + // non-digit-graph = ([a-zA-Z]|unicode mb char) + // graph = ([a-zA-Z0-9]|unicode mb char) + // suffix = _ non-digit-graph (graph | _graph) + TOML11_INLINE sequence num_suffix(const spec& s) { + const auto non_digit_graph = [&s]() { + return either(alpha(s), non_ascii(s)); + }; + const auto graph = [&s]() { + return either(alpha(s), digit(s), non_ascii(s)); + }; + + return sequence( + character(char_type('_')), + non_digit_graph(), + repeat_at_least( + 0, + either(sequence(character(char_type('_')), graph()), graph()))); + } + + TOML11_INLINE sequence dec_int(const spec& s) { + const auto digit19 = []() { + return character_in_range(char_type('1'), char_type('9')); + }; + return sequence( + maybe(character_either { char_type('-'), char_type('+') }), + either(sequence(digit19(), + repeat_at_least( + 1, + either(digit(s), + sequence(character(char_type('_')), digit(s))))), + digit(s))); + } + + TOML11_INLINE sequence hex_int(const spec& s) { + return sequence( + literal("0x"), + hexdig(s), + repeat_at_least( + 0, + either(hexdig(s), sequence(character(char_type('_')), hexdig(s))))); + } + + TOML11_INLINE sequence oct_int(const spec&) { + const auto digit07 = []() { + return character_in_range(char_type('0'), char_type('7')); + }; + return sequence( + literal("0o"), + digit07(), + repeat_at_least( + 0, + either(digit07(), 
sequence(character(char_type('_')), digit07())))); + } + + TOML11_INLINE sequence bin_int(const spec&) { + const auto digit01 = []() { + return character_either { char_type('0'), char_type('1') }; + }; + return sequence( + literal("0b"), + digit01(), + repeat_at_least( + 0, + either(digit01(), sequence(character(char_type('_')), digit01())))); + } + + TOML11_INLINE either integer(const spec& s) { + return either(hex_int(s), oct_int(s), bin_int(s), dec_int(s)); + } + + // =========================================================================== + // Floating + + TOML11_INLINE sequence zero_prefixable_int(const spec& s) { + return sequence( + digit(s), + repeat_at_least(0, either(digit(s), sequence(character('_'), digit(s))))); + } + + TOML11_INLINE sequence fractional_part(const spec& s) { + return sequence(character('.'), zero_prefixable_int(s)); + } + + TOML11_INLINE sequence exponent_part(const spec& s) { + return sequence(character_either { char_type('e'), char_type('E') }, + maybe(character_either { char_type('+'), char_type('-') }), + zero_prefixable_int(s)); + } + + TOML11_INLINE sequence hex_floating(const spec& s) { + // C99 hexfloat (%a) + // [+-]? 0x ( [0-9a-fA-F]*\.[0-9a-fA-F]+ | [0-9a-fA-F]+\.? ) [pP] [+-]? 
[0-9]+ + + // - 0x(int).(frac)p[+-](int) + // - 0x(int).p[+-](int) + // - 0x.(frac)p[+-](int) + // - 0x(int)p[+-](int) + + return sequence( + maybe(character_either { char_type('+'), char_type('-') }), + character('0'), + character_either { char_type('x'), char_type('X') }, + either(sequence(repeat_at_least(0, hexdig(s)), + character('.'), + repeat_at_least(1, hexdig(s))), + sequence(repeat_at_least(1, hexdig(s)), maybe(character('.')))), + character_either { char_type('p'), char_type('P') }, + maybe(character_either { char_type('+'), char_type('-') }), + repeat_at_least(1, character_in_range('0', '9'))); + } + + TOML11_INLINE either floating(const spec& s) { + return either( + sequence(dec_int(s), + either(exponent_part(s), + sequence(fractional_part(s), maybe(exponent_part(s))))), + sequence(maybe(character_either { char_type('-'), char_type('+') }), + either(literal("inf"), literal("nan")))); + } + + // =========================================================================== + // Datetime + + TOML11_INLINE sequence local_date(const spec& s) { + return sequence(repeat_exact(4, digit(s)), + character('-'), + repeat_exact(2, digit(s)), + character('-'), + repeat_exact(2, digit(s))); + } + + TOML11_INLINE sequence local_time(const spec& s) { + auto time = sequence(repeat_exact(2, digit(s)), + character(':'), + repeat_exact(2, digit(s))); + + if (s.v1_1_0_make_seconds_optional) { + time.push_back(maybe(sequence( + character(':'), + repeat_exact(2, digit(s)), + maybe(sequence(character('.'), repeat_at_least(1, digit(s))))))); + } else { + time.push_back(character(':')); + time.push_back(repeat_exact(2, digit(s))); + time.push_back( + maybe(sequence(character('.'), repeat_at_least(1, digit(s))))); + } + + return time; + } + + TOML11_INLINE either time_offset(const spec& s) { + return either(character_either { 'Z', 'z' }, + sequence(character_either { '+', '-' }, + repeat_exact(2, digit(s)), + character(':'), + repeat_exact(2, digit(s)))); + } + + TOML11_INLINE 
sequence full_time(const spec& s) { + return sequence(local_time(s), time_offset(s)); + } + + TOML11_INLINE character_either time_delim(const spec&) { + return character_either { 'T', 't', ' ' }; + } + + TOML11_INLINE sequence local_datetime(const spec& s) { + return sequence(local_date(s), time_delim(s), local_time(s)); + } + + TOML11_INLINE sequence offset_datetime(const spec& s) { + return sequence(local_date(s), time_delim(s), full_time(s)); + } + + // =========================================================================== + // String + + TOML11_INLINE sequence escaped(const spec& s) { + character_either escape_char { '\"', '\\', 'b', 'f', 'n', 'r', 't' }; + if (s.v1_1_0_add_escape_sequence_e) { + escape_char.push_back(char_type('e')); + } + + either escape_seq(std::move(escape_char), + sequence(character('u'), repeat_exact(4, hexdig(s))), + sequence(character('U'), repeat_exact(8, hexdig(s)))); + + if (s.v1_1_0_add_escape_sequence_x) { + escape_seq.push_back( + sequence(character('x'), repeat_exact(2, hexdig(s)))); + } + + return sequence(character('\\'), std::move(escape_seq)); + } + + TOML11_INLINE either basic_char(const spec& s) { + const auto basic_unescaped = [&s]() { + return either(wschar(s), + character(0x21), // 22 is " + character_in_range(0x23, 0x5B), // 5C is backslash + character_in_range(0x5D, 0x7E), // 7F is DEL + non_ascii(s)); + }; + return either(basic_unescaped(), escaped(s)); + } + + TOML11_INLINE sequence basic_string(const spec& s) { + return sequence(character('"'), + repeat_at_least(0, basic_char(s)), + character('"')); + } + + // --------------------------------------------------------------------------- + // multiline string + + TOML11_INLINE sequence escaped_newline(const spec& s) { + return sequence(character('\\'), + ws(s), + newline(s), + repeat_at_least(0, either(wschar(s), newline(s)))); + } + + TOML11_INLINE sequence ml_basic_string(const spec& s) { + const auto mlb_content = [&s]() { + return either(basic_char(s), 
newline(s), escaped_newline(s)); + }; + const auto mlb_quotes = []() { + return either(literal("\"\""), character('\"')); + }; + + return sequence( + literal("\"\"\""), + maybe(newline(s)), + repeat_at_least(0, mlb_content()), + repeat_at_least(0, + sequence(mlb_quotes(), repeat_at_least(1, mlb_content()))), + // XXX """ and mlb_quotes are intentionally reordered to avoid + // unexpected match of mlb_quotes + literal("\"\"\""), + maybe(mlb_quotes())); + } + + // --------------------------------------------------------------------------- + // literal string + + TOML11_INLINE either literal_char(const spec& s) { + return either(character(0x09), + character_in_range(0x20, 0x26), + character_in_range(0x28, 0x7E), + non_ascii(s)); + } + + TOML11_INLINE sequence literal_string(const spec& s) { + return sequence(character('\''), + repeat_at_least(0, literal_char(s)), + character('\'')); + } + + TOML11_INLINE sequence ml_literal_string(const spec& s) { + const auto mll_quotes = []() { + return either(literal("''"), character('\'')); + }; + const auto mll_content = [&s]() { + return either(literal_char(s), newline(s)); + }; + + return sequence( + literal("'''"), + maybe(newline(s)), + repeat_at_least(0, mll_content()), + repeat_at_least(0, + sequence(mll_quotes(), repeat_at_least(1, mll_content()))), + literal("'''"), + maybe(mll_quotes()) + // XXX ''' and mll_quotes are intentionally reordered to avoid + // unexpected match of mll_quotes + ); + } + + TOML11_INLINE either string(const spec& s) { + return either(ml_basic_string(s), + ml_literal_string(s), + basic_string(s), + literal_string(s)); + } + + // =========================================================================== + // Keys + + // to keep `expected_chars` simple + TOML11_INLINE non_ascii_key_char::non_ascii_key_char(const spec& s) noexcept { + assert(s.v1_1_0_allow_non_english_in_bare_keys); + (void)s; // for NDEBUG + } + + TOML11_INLINE std::uint32_t non_ascii_key_char::read_utf8(location& loc) const { + // 
U+0000 ... U+007F ; 0xxx_xxxx + // U+0080 ... U+07FF ; 110y_yyyx 10xx_xxxx; + // U+0800 ... U+FFFF ; 1110_yyyy 10yx_xxxx 10xx_xxxx + // U+010000 ... U+10FFFF; 1111_0yyy 10yy_xxxx 10xx_xxxx 10xx_xxxx + // + // Returns the decoded code point, or 0xFFFFFFFF when the bytes at `loc` are + // not acceptable (unknown lead byte, truncated input, or overlong form). + // `loc` is advanced past the bytes that were read even on failure. + + const unsigned char b1 = loc.current(); + loc.advance(1); + if (b1 < 0x80) { + return static_cast<std::uint32_t>(b1); + } else if ((b1 >> 5) == 6) // 0b110 == 6 + { + // the input must not end right after the lead byte. the longer forms + // below make the same end-of-input check before their last byte. + if (loc.eof()) { + return 0xFFFFFFFF; + } + const auto b2 = loc.current(); + loc.advance(1); + + const std::uint32_t c1 = b1 & ((1 << 5) - 1); + const std::uint32_t c2 = b2 & ((1 << 6) - 1); + const std::uint32_t codep = (c1 << 6) + c2; + + if (codep < 0x80) { + return 0xFFFFFFFF; + } + return codep; + } else if ((b1 >> 4) == 14) // 0b1110 == 14 + { + const auto b2 = loc.current(); + loc.advance(1); + if (loc.eof()) { + return 0xFFFFFFFF; + } + const auto b3 = loc.current(); + loc.advance(1); + + const std::uint32_t c1 = b1 & ((1 << 4) - 1); + const std::uint32_t c2 = b2 & ((1 << 6) - 1); + const std::uint32_t c3 = b3 & ((1 << 6) - 1); + + const std::uint32_t codep = (c1 << 12) + (c2 << 6) + c3; + if (codep < 0x800) { + return 0xFFFFFFFF; + } + return codep; + } else if ((b1 >> 3) == 30) // 0b11110 == 30 + { + const auto b2 = loc.current(); + loc.advance(1); + if (loc.eof()) { + return 0xFFFFFFFF; + } + const auto b3 = loc.current(); + loc.advance(1); + if (loc.eof()) { + return 0xFFFFFFFF; + } + const auto b4 = loc.current(); + loc.advance(1); + + const std::uint32_t c1 = b1 & ((1 << 3) - 1); + const std::uint32_t c2 = b2 & ((1 << 6) - 1); + const std::uint32_t c3 = b3 & ((1 << 6) - 1); + const std::uint32_t c4 = b4 & ((1 << 6) - 1); + const std::uint32_t codep = (c1 << 18) + (c2 << 12) + (c3 << 6) + c4; + + if (codep < 0x10000) { + return 0xFFFFFFFF; + } + return codep; + } else // not a Unicode codepoint in UTF-8 + { + return 0xFFFFFFFF; + } + } + + TOML11_INLINE region non_ascii_key_char::scan(location& loc) const { + if (loc.eof()) { + return region {}; + } + + const auto first = loc; + + const auto cp = read_utf8(loc); + + if (cp == 0xFFFFFFFF) { + // read_utf8 has already moved loc. rewind it so that a failed scan + // consumes nothing, like the other failure path of this function. + loc = first; +
return region {}; + } + + // ALPHA / DIGIT / %x2D / %x5F ; a-z A-Z 0-9 - _ + // / %xB2 / %xB3 / %xB9 / %xBC-BE ; superscript digits, fractions + // / %xC0-D6 / %xD8-F6 / %xF8-37D ; non-symbol chars in Latin block + // / %x37F-1FFF ; exclude GREEK QUESTION MARK, which + // is basically a semi-colon / %x200C-200D / %x203F-2040 ; from + // General Punctuation Block, include the two tie symbols and ZWNJ, ZWJ + // / %x2070-218F / %x2460-24FF ; include super-/subscripts, + // letterlike/numberlike forms, enclosed alphanumerics / %x2C00-2FEF / + // %x3001-D7FF ; skip arrows, math, box drawing etc, skip 2FF0-3000 + // ideographic up/down markers and spaces / %xF900-FDCF / %xFDF0-FFFD ; + // skip D800-DFFF surrogate block, E000-F8FF Private Use area, FDD0-FDEF + // intended for process-internal use (unicode) / %x10000-EFFFF ; all + // chars outside BMP range, excluding Private Use planes (F0000-10FFFF) + + if (cp == 0xB2 || cp == 0xB3 || cp == 0xB9 || + (0xBC <= cp && cp <= 0xBE) || (0xC0 <= cp && cp <= 0xD6) || + (0xD8 <= cp && cp <= 0xF6) || (0xF8 <= cp && cp <= 0x37D) || + (0x37F <= cp && cp <= 0x1FFF) || (0x200C <= cp && cp <= 0x200D) || + (0x203F <= cp && cp <= 0x2040) || (0x2070 <= cp && cp <= 0x218F) || + (0x2460 <= cp && cp <= 0x24FF) || (0x2C00 <= cp && cp <= 0x2FEF) || + (0x3001 <= cp && cp <= 0xD7FF) || (0xF900 <= cp && cp <= 0xFDCF) || + (0xFDF0 <= cp && cp <= 0xFFFD) || (0x10000 <= cp && cp <= 0xEFFFF)) { + return region(first, loc); + } + loc = first; + return region {}; + } + + TOML11_INLINE repeat_at_least unquoted_key(const spec& s) { + auto keychar = either(alpha(s), + digit(s), + character { 0x2D }, + character { 0x5F }); + + if (s.v1_1_0_allow_non_english_in_bare_keys) { + keychar.push_back(non_ascii_key_char(s)); + } + + return repeat_at_least(1, std::move(keychar)); + } + + TOML11_INLINE either quoted_key(const spec& s) { + return either(basic_string(s), literal_string(s)); + } + + TOML11_INLINE either simple_key(const spec& s) { + return 
either(unquoted_key(s), quoted_key(s)); + } + + TOML11_INLINE sequence dot_sep(const spec& s) { + return sequence(ws(s), character('.'), ws(s)); + } + + TOML11_INLINE sequence dotted_key(const spec& s) { + return sequence(simple_key(s), + repeat_at_least(1, sequence(dot_sep(s), simple_key(s)))); + } + + TOML11_INLINE key::key(const spec& s) noexcept + : scanner_(dotted_key(s), simple_key(s)) {} + + TOML11_INLINE sequence keyval_sep(const spec& s) { + return sequence(ws(s), character('='), ws(s)); + } + + // =========================================================================== + // Table key + + TOML11_INLINE sequence std_table(const spec& s) { + return sequence(character('['), ws(s), key(s), ws(s), character(']')); + } + + TOML11_INLINE sequence array_table(const spec& s) { + return sequence(literal("[["), ws(s), key(s), ws(s), literal("]]")); + } + + // =========================================================================== + // extension: null + + TOML11_INLINE literal null_value(const spec&) { + return literal("null"); + } + + } // namespace syntax + } // namespace detail +} // namespace toml + #endif // TOML11_SYNTAX_IMPL_HPP +#endif + +#endif // TOML11_SYNTAX_HPP +#ifndef TOML11_SKIP_HPP +#define TOML11_SKIP_HPP + +#include + +namespace toml { + namespace detail { + + template + bool skip_whitespace(location& loc, const context& ctx) { + return syntax::ws(ctx.toml_spec()).scan(loc).is_ok(); + } + + template + bool skip_empty_lines(location& loc, const context& ctx) { + return repeat_at_least(1, + sequence(syntax::ws(ctx.toml_spec()), + syntax::newline(ctx.toml_spec()))) + .scan(loc) + .is_ok(); + } + + // For error recovery. + // + // In case if a comment line contains an invalid character, we need to skip + // it to advance parsing. + template + void skip_comment_block(location& loc, const context& ctx) { + while (!loc.eof()) { + skip_whitespace(loc, ctx); + if (loc.current() == '#') { + while (!loc.eof()) { + // both CRLF and LF ends with LF. 
+ if (loc.current() == '\n') { + loc.advance(); + break; + } + } + } else if (syntax::newline(ctx.toml_spec()).scan(loc).is_ok()) { + ; // an empty line. skip this also + } else { + // the next token is neither a comment nor empty line. + return; + } + } + return; + } + + template + void skip_empty_or_comment_lines(location& loc, const context& ctx) { + const auto& spec = ctx.toml_spec(); + repeat_at_least(0, + sequence(syntax::ws(spec), + maybe(syntax::comment(spec)), + syntax::newline(spec))) + .scan(loc); + return; + } + + // For error recovery. + // + // Sometimes we need to skip a value and find a delimiter, like `,`, `]`, or `}`. + // To find delimiter, we need to skip delimiters in a string. + // Since we are skipping invalid value while error recovery, we don't need + // to check the syntax. Here we just skip string-like region until closing quote + // is found. + template + void skip_string_like(location& loc, const context&) { + // if """ is found, skip until the closing """ is found. + if (literal("\"\"\"").scan(loc).is_ok()) { + while (!loc.eof()) { + if (literal("\"\"\"").scan(loc).is_ok()) { + return; + } + loc.advance(); + } + } else if (literal("'''").scan(loc).is_ok()) { + while (!loc.eof()) { + if (literal("'''").scan(loc).is_ok()) { + return; + } + loc.advance(); + } + } + // if " is found, skip until the closing " or newline is found. 
+ else if (loc.current() == '"') { + while (!loc.eof()) { + loc.advance(); + if (loc.current() == '"' || loc.current() == '\n') { + loc.advance(); + return; + } + } + } else if (loc.current() == '\'') { + while (!loc.eof()) { + loc.advance(); + if (loc.current() == '\'' || loc.current() == '\n') { + loc.advance(); + return; + } + } + } + return; + } + + template + void skip_value(location& loc, const context& ctx); + template + void skip_array_like(location& loc, const context& ctx); + template + void skip_inline_table_like(location& loc, const context& ctx); + template + void skip_key_value_pair(location& loc, const context& ctx); + + template + result guess_value_type(const location& loc, + const context& ctx); + + template + void skip_array_like(location& loc, const context& ctx) { + const auto& spec = ctx.toml_spec(); + assert(loc.current() == '['); + loc.advance(); + + while (!loc.eof()) { + if (loc.current() == '\"' || loc.current() == '\'') { + skip_string_like(loc, ctx); + } else if (loc.current() == '#') { + skip_comment_block(loc, ctx); + } else if (loc.current() == '{') { + skip_inline_table_like(loc, ctx); + } else if (loc.current() == '[') { + const auto checkpoint = loc; + if (syntax::std_table(spec).scan(loc).is_ok() || + syntax::array_table(spec).scan(loc).is_ok()) { + loc = checkpoint; + break; + } + // if it is not a table-definition, then it is an array. + skip_array_like(loc, ctx); + } else if (loc.current() == '=') { + // key-value pair cannot be inside the array. + // guessing the error is "missing closing bracket `]`". + // find the previous key just before `=`. 
+ while (loc.get_location() != 0) { + loc.retrace(); + if (loc.current() == '\n') { + loc.advance(); + break; + } + } + break; + } else if (loc.current() == ']') { + break; // found closing bracket + } else { + loc.advance(); + } + } + return; + } + + template + void skip_inline_table_like(location& loc, const context& ctx) { + assert(loc.current() == '{'); + loc.advance(); + + const auto& spec = ctx.toml_spec(); + + while (!loc.eof()) { + if (loc.current() == '\n' && !spec.v1_1_0_allow_newlines_in_inline_tables) { + break; // missing closing `}`. + } else if (loc.current() == '\"' || loc.current() == '\'') { + skip_string_like(loc, ctx); + } else if (loc.current() == '#') { + skip_comment_block(loc, ctx); + if (!spec.v1_1_0_allow_newlines_in_inline_tables) { + // comment must end with newline. + break; // missing closing `}`. + } + } else if (loc.current() == '[') { + const auto checkpoint = loc; + if (syntax::std_table(spec).scan(loc).is_ok() || + syntax::array_table(spec).scan(loc).is_ok()) { + loc = checkpoint; + break; // missing closing `}`. + } + // if it is not a table-definition, then it is an array. + skip_array_like(loc, ctx); + } else if (loc.current() == '{') { + skip_inline_table_like(loc, ctx); + } else if (loc.current() == '}') { + // closing brace found. guessing the error is inside the table. + break; + } else { + // skip otherwise. + loc.advance(); + } + } + return; + } + + template + void skip_value(location& loc, const context& ctx) { + value_t ty = guess_value_type(loc, ctx).unwrap_or(value_t::empty); + if (ty == value_t::string) { + skip_string_like(loc, ctx); + } else if (ty == value_t::array) { + skip_array_like(loc, ctx); + } else if (ty == value_t::table) { + // In case of multiline tables, it may skip key-value pair but not the + // whole table. + skip_inline_table_like(loc, ctx); + } else // others are an "in-line" values. 
skip until the next line + { + while (!loc.eof()) { + if (loc.current() == '\n') { + break; + } else if (loc.current() == ',' || loc.current() == ']' || + loc.current() == '}') { + break; + } + loc.advance(); + } + } + return; + } + + template + void skip_key_value_pair(location& loc, const context& ctx) { + while (!loc.eof()) { + if (loc.current() == '=') { + skip_whitespace(loc, ctx); + skip_value(loc, ctx); + return; + } else if (loc.current() == '\n') { + // newline is found before finding `=`. assuming "missing `=`". + return; + } + loc.advance(); + } + return; + } + + template + void skip_until_next_table(location& loc, const context& ctx) { + const auto& spec = ctx.toml_spec(); + while (!loc.eof()) { + if (loc.current() == '\n') { + loc.advance(); + const auto line_begin = loc; + + skip_whitespace(loc, ctx); + if (syntax::std_table(spec).scan(loc).is_ok()) { + loc = line_begin; + return; + } + if (syntax::array_table(spec).scan(loc).is_ok()) { + loc = line_begin; + return; + } + } + loc.advance(); + } + } + + } // namespace detail +} // namespace toml + +#if defined(TOML11_COMPILE_SOURCES) +namespace toml { + struct type_config; + struct ordered_type_config; + + namespace detail { + extern template bool skip_whitespace(location& loc, + const context&); + extern template bool skip_empty_lines(location& loc, + const context&); + extern template void skip_comment_block( + location& loc, + const context&); + extern template void skip_empty_or_comment_lines( + location& loc, + const context&); + extern template void skip_string_like(location& loc, + const context&); + extern template void skip_array_like(location& loc, + const context&); + extern template void skip_inline_table_like( + location& loc, + const context&); + extern template void skip_value(location& loc, + const context&); + extern template void skip_key_value_pair( + location& loc, + const context&); + extern template void skip_until_next_table( + location& loc, + const context&); + + extern 
template bool skip_whitespace( + location& loc, + const context&); + extern template bool skip_empty_lines( + location& loc, + const context&); + extern template void skip_comment_block( + location& loc, + const context&); + extern template void skip_empty_or_comment_lines( + location& loc, + const context&); + extern template void skip_string_like( + location& loc, + const context&); + extern template void skip_array_like( + location& loc, + const context&); + extern template void skip_inline_table_like( + location& loc, + const context&); + extern template void skip_value( + location& loc, + const context&); + extern template void skip_key_value_pair( + location& loc, + const context&); + extern template void skip_until_next_table( + location& loc, + const context&); + + } // namespace detail +} // namespace toml +#endif // TOML11_COMPILE_SOURCES + +#endif // TOML11_SKIP_HPP +#ifndef TOML11_PARSER_HPP +#define TOML11_PARSER_HPP + +#include +#include +#include +#include + +#if defined(TOML11_HAS_FILESYSTEM) && TOML11_HAS_FILESYSTEM + #include +#endif + +namespace toml { + + struct syntax_error final : public ::toml::exception { + public: + syntax_error(std::string what_arg, std::vector err) + : what_(std::move(what_arg)) + , err_(std::move(err)) {} + + ~syntax_error() noexcept override = default; + + const char* what() const noexcept override { + return what_.c_str(); + } + + const std::vector& errors() const noexcept { + return err_; + } + + private: + std::string what_; + std::vector err_; + }; + + struct file_io_error final : public ::toml::exception { + public: + file_io_error(const std::string& msg, const std::string& fname) + : errno_(cxx::make_nullopt()) + , what_(msg + " \"" + fname + "\"") {} + + file_io_error(int errnum, const std::string& msg, const std::string& fname) + : errno_(errnum) + , what_(msg + " \"" + fname + "\": errno=" + std::to_string(errnum)) {} + + ~file_io_error() noexcept override = default; + + const char* what() const noexcept 
override { + return what_.c_str(); + } + + bool has_errno() const noexcept { + return errno_.has_value(); + } + + int get_errno() const noexcept { + return errno_.value_or(0); + } + + private: + cxx::optional errno_; + std::string what_; + }; + + namespace detail { + + /* ============================================================================ + * __ ___ _ __ _ __ ___ _ _ + * / _/ _ \ ' \| ' \/ _ \ ' \ + * \__\___/_|_|_|_|_|_\___/_||_| + */ + + template + error_info make_syntax_error(std::string title, + const S& scanner, + location loc, + std::string suffix = "") { + auto msg = std::string("expected ") + scanner.expected_chars(loc); + auto src = source_location(region(loc)); + return make_error_info(std::move(title), + std::move(src), + std::move(msg), + std::move(suffix)); + } + + /* ============================================================================ + * _ + * __ ___ _ __ _ __ ___ _ _| |_ + * / _/ _ \ ' \| ' \/ -_) ' \ _| + * \__\___/_|_|_|_|_|_\___|_||_\__| + */ + + template + result, error_info> parse_comment_line( + location& loc, + context& ctx) { + const auto& spec = ctx.toml_spec(); + const auto first = loc; + + skip_whitespace(loc, ctx); + + const auto com_reg = syntax::comment(spec).scan(loc); + if (com_reg.is_ok()) { + // once comment started, newline must follow (or reach EOF). 
+ if (!loc.eof() && !syntax::newline(spec).scan(loc).is_ok()) { + while (!loc.eof()) // skip until newline to continue parsing + { + loc.advance(); + if (loc.current() == '\n') { /*skip LF*/ + loc.advance(); + break; + } + } + return err(make_error_info("toml::parse_comment_line: " + "newline (LF / CRLF) or EOF is expected", + source_location(region(loc)), + "but got this", + "Hint: most of the control characters are " + "not allowed in comments")); + } + return ok(cxx::optional(com_reg.as_string())); + } else { + loc = first; // rollback whitespace to parse indent + return ok(cxx::optional(cxx::make_nullopt())); + } + } + + /* ============================================================================ + * ___ _ + * | _ ) ___ ___| |___ __ _ _ _ + * | _ \/ _ \/ _ \ / -_) _` | ' \ + * |___/\___/\___/_\___\__,_|_||_| + */ + + template + result, error_info> parse_boolean(location& loc, + const context& ctx) { + const auto& spec = ctx.toml_spec(); + + // ---------------------------------------------------------------------- + // check syntax + auto reg = syntax::boolean(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_boolean: " + "invalid boolean: boolean must be `true` or `false`, in lowercase. " + "string must be surrounded by `\"`", + syntax::boolean(spec), + loc)); + } + + // ---------------------------------------------------------------------- + // it matches. 
gen value + const auto str = reg.as_string(); + const auto val = [&str]() { + if (str == "true") { + return true; + } else { + assert(str == "false"); + return false; + } + }(); + + // ---------------------------------------------------------------------- + // no format info for boolean + boolean_format_info fmt; + + return ok(basic_value(val, std::move(fmt), {}, std::move(reg))); + } + + /* ============================================================================ + * ___ _ + * |_ _|_ _| |_ ___ __ _ ___ _ _ + * | || ' \ _/ -_) _` / -_) '_| + * |___|_||_\__\___\__, \___|_| + * |___/ + */ + + template + result, error_info> parse_bin_integer(location& loc, + const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + auto reg = syntax::bin_int(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_bin_integer: " + "invalid integer: bin_int must be like: 0b0101, 0b1111_0000", + syntax::bin_int(spec), + loc)); + } + + auto str = reg.as_string(); + + integer_format_info fmt; + fmt.fmt = integer_format::bin; + fmt.width = str.size() - 2 - + static_cast(std::count(str.begin(), str.end(), '_')); + + const auto first_underscore = std::find(str.rbegin(), str.rend(), '_'); + if (first_underscore != str.rend()) { + fmt.spacer = static_cast( + std::distance(str.rbegin(), first_underscore)); + } + + // skip prefix `0b` and zeros and underscores at the MSB + str.erase(str.begin(), std::find(std::next(str.begin(), 2), str.end(), '1')); + + // remove all `_` before calling TC::parse_int + str.erase(std::remove(str.begin(), str.end(), '_'), str.end()); + + // 0b0000_0000 becomes empty. 
+ if (str.empty()) { + str = "0"; + } + + const auto val = TC::parse_int(str, source_location(region(loc)), 2); + if (val.is_ok()) { + return ok(basic_value(val.as_ok(), std::move(fmt), {}, std::move(reg))); + } else { + loc = first; + return err(val.as_err()); + } + } + + // ---------------------------------------------------------------------------- + + template + result, error_info> parse_oct_integer(location& loc, + const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + auto reg = syntax::oct_int(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_oct_integer: " + "invalid integer: oct_int must be like: 0o775, 0o04_44", + syntax::oct_int(spec), + loc)); + } + + auto str = reg.as_string(); + + integer_format_info fmt; + fmt.fmt = integer_format::oct; + fmt.width = str.size() - 2 - + static_cast(std::count(str.begin(), str.end(), '_')); + + const auto first_underscore = std::find(str.rbegin(), str.rend(), '_'); + if (first_underscore != str.rend()) { + fmt.spacer = static_cast( + std::distance(str.rbegin(), first_underscore)); + } + + // skip prefix `0o` and zeros and underscores at the MSB + str.erase(str.begin(), + std::find_if(std::next(str.begin(), 2), str.end(), [](const char c) { + return c != '0' && c != '_'; + })); + + // remove all `_` before calling TC::parse_int + str.erase(std::remove(str.begin(), str.end(), '_'), str.end()); + + // 0o0000_0000 becomes empty. 
+ if (str.empty()) { + str = "0"; + } + + const auto val = TC::parse_int(str, source_location(region(loc)), 8); + if (val.is_ok()) { + return ok(basic_value(val.as_ok(), std::move(fmt), {}, std::move(reg))); + } else { + loc = first; + return err(val.as_err()); + } + } + + template + result, error_info> parse_hex_integer(location& loc, + const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + auto reg = syntax::hex_int(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_hex_integer: " + "invalid integer: hex_int must be like: 0xC0FFEE, 0xdead_beef", + syntax::hex_int(spec), + loc)); + } + + auto str = reg.as_string(); + + integer_format_info fmt; + fmt.fmt = integer_format::hex; + fmt.width = str.size() - 2 - + static_cast(std::count(str.begin(), str.end(), '_')); + + const auto first_underscore = std::find(str.rbegin(), str.rend(), '_'); + if (first_underscore != str.rend()) { + fmt.spacer = static_cast( + std::distance(str.rbegin(), first_underscore)); + } + + // skip prefix `0x` and zeros and underscores at the MSB + str.erase(str.begin(), + std::find_if(std::next(str.begin(), 2), str.end(), [](const char c) { + return c != '0' && c != '_'; + })); + + // remove all `_` before calling TC::parse_int + str.erase(std::remove(str.begin(), str.end(), '_'), str.end()); + + // 0x0000_0000 becomes empty. + if (str.empty()) { + str = "0"; + } + + // prefix zero and _ is removed. check if it uses upper/lower case. + // if both upper and lower case letters are found, set upper=true. 
+ const auto lower_not_found = std::find_if(str.begin(), str.end(), [](const char c) { + return std::islower(static_cast(c)) != 0; + }) == str.end(); + const auto upper_found = std::find_if(str.begin(), str.end(), [](const char c) { + return std::isupper(static_cast(c)) != 0; + }) != str.end(); + fmt.uppercase = lower_not_found || upper_found; + + const auto val = TC::parse_int(str, source_location(region(loc)), 16); + if (val.is_ok()) { + return ok(basic_value(val.as_ok(), std::move(fmt), {}, std::move(reg))); + } else { + loc = first; + return err(val.as_err()); + } + } + + template + result, error_info> parse_dec_integer(location& loc, + const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + // ---------------------------------------------------------------------- + // check syntax + auto reg = syntax::dec_int(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_dec_integer: " + "invalid integer: dec_int must be like: 42, 123_456_789", + syntax::dec_int(spec), + loc)); + } + + // ---------------------------------------------------------------------- + // it matches. 
gen value + auto str = reg.as_string(); + + integer_format_info fmt; + fmt.fmt = integer_format::dec; + fmt.width = str.size() - static_cast( + std::count(str.begin(), str.end(), '_')); + + const auto first_underscore = std::find(str.rbegin(), str.rend(), '_'); + if (first_underscore != str.rend()) { + fmt.spacer = static_cast( + std::distance(str.rbegin(), first_underscore)); + } + + // remove all `_` before calling TC::parse_int + str.erase(std::remove(str.begin(), str.end(), '_'), str.end()); + + auto src = source_location(region(loc)); + const auto val = TC::parse_int(str, src, 10); + if (val.is_err()) { + loc = first; + return err(val.as_err()); + } + + // ---------------------------------------------------------------------- + // parse suffix (extension) + + if (spec.ext_num_suffix && loc.current() == '_') { + const auto sfx_reg = syntax::num_suffix(spec).scan(loc); + if (!sfx_reg.is_ok()) { + loc = first; + return err(make_error_info( + "toml::parse_dec_integer: " + "invalid suffix: should be `_ non-digit-graph (graph | _graph)`", + source_location(region(loc)), + "here")); + } + auto sfx = sfx_reg.as_string(); + assert(!sfx.empty() && sfx.front() == '_'); + sfx.erase(sfx.begin()); // remove the first `_` + + fmt.suffix = sfx; + } + + return ok(basic_value(val.as_ok(), std::move(fmt), {}, std::move(reg))); + } + + template + result, error_info> parse_integer(location& loc, + const context& ctx) { + const auto first = loc; + + if (!loc.eof() && (loc.current() == '+' || loc.current() == '-')) { + // skip +/- to diagnose +0xDEADBEEF or -0b0011 (invalid). + // without this, +0xDEAD_BEEF will be parsed as a decimal int and + // unexpected "xDEAD_BEEF" will appear after integer "+0". + loc.advance(); + } + + if (!loc.eof() && loc.current() == '0') { + loc.advance(); + if (loc.eof()) { + // `[+-]?0`. parse as an decimal integer. 
+ loc = first; + return parse_dec_integer(loc, ctx); + } + + const auto prefix = loc.current(); + auto prefix_src = source_location(region(loc)); + + loc = first; + + if (prefix == 'b') { + return parse_bin_integer(loc, ctx); + } + if (prefix == 'o') { + return parse_oct_integer(loc, ctx); + } + if (prefix == 'x') { + return parse_hex_integer(loc, ctx); + } + + if (std::isdigit(prefix)) { + auto src = source_location(region(loc)); + return err( + make_error_info("toml::parse_integer: " + "leading zero in an decimal integer is not allowed", + std::move(src), + "leading zero")); + } + } + + loc = first; + return parse_dec_integer(loc, ctx); + } + + /* ============================================================================ + * ___ _ _ _ + * | __| |___ __ _| |_(_)_ _ __ _ + * | _|| / _ \/ _` | _| | ' \/ _` | + * |_| |_\___/\__,_|\__|_|_||_\__, | + * |___/ + */ + + template + result, error_info> parse_floating(location& loc, + const context& ctx) { + using floating_type = typename basic_value::floating_type; + + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + // ---------------------------------------------------------------------- + // check syntax + bool is_hex = false; + std::string str; + region reg; + if (spec.ext_hex_float && + sequence(character('0'), character('x')).scan(loc).is_ok()) { + loc = first; + is_hex = true; + + reg = syntax::hex_floating(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_floating: " + "invalid hex floating: float must be like: 0xABCp-3f", + syntax::floating(spec), + loc)); + } + str = reg.as_string(); + } else { + reg = syntax::floating(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_floating: " + "invalid floating: float must be like: -3.14159_26535, 6.022e+23, " + "inf, or nan (lowercase).", + syntax::floating(spec), + loc)); + } + str = reg.as_string(); + } + + // ---------------------------------------------------------------------- + 
// it matches. gen value + + floating_format_info fmt; + + if (is_hex) { + fmt.fmt = floating_format::hex; + } else { + // since we already checked that the string conforms the TOML standard. + if (std::find(str.begin(), str.end(), 'e') != str.end() || + std::find(str.begin(), str.end(), 'E') != str.end()) { + fmt.fmt = floating_format::scientific; // use exponent part + } else { + fmt.fmt = floating_format::fixed; // do not use exponent part + } + } + + str.erase(std::remove(str.begin(), str.end(), '_'), str.end()); + + floating_type val { 0 }; + + if (str == "inf" || str == "+inf") { + TOML11_CONSTEXPR_IF(std::numeric_limits::has_infinity) { + val = std::numeric_limits::infinity(); + } + else { + return err(make_error_info( + "toml::parse_floating: inf value found" + " but the current environment does not support inf. Please" + " make sure that the floating-point implementation conforms" + " IEEE 754/ISO 60559 international standard.", + source_location(region(loc)), + "floating_type: inf is not supported")); + } + } else if (str == "-inf") { + TOML11_CONSTEXPR_IF(std::numeric_limits::has_infinity) { + val = -std::numeric_limits::infinity(); + } + else { + return err(make_error_info( + "toml::parse_floating: inf value found" + " but the current environment does not support inf. Please" + " make sure that the floating-point implementation conforms" + " IEEE 754/ISO 60559 international standard.", + source_location(region(loc)), + "floating_type: inf is not supported")); + } + } else if (str == "nan" || str == "+nan") { + TOML11_CONSTEXPR_IF(std::numeric_limits::has_quiet_NaN) { + val = std::numeric_limits::quiet_NaN(); + } + else TOML11_CONSTEXPR_IF( + std::numeric_limits::has_signaling_NaN) { + val = std::numeric_limits::signaling_NaN(); + } + else { + return err(make_error_info( + "toml::parse_floating: NaN value found" + " but the current environment does not support NaN. 
Please" + " make sure that the floating-point implementation conforms" + " IEEE 754/ISO 60559 international standard.", + source_location(region(loc)), + "floating_type: NaN is not supported")); + } + } else if (str == "-nan") { + using std::copysign; + TOML11_CONSTEXPR_IF(std::numeric_limits::has_quiet_NaN) { + val = copysign(std::numeric_limits::quiet_NaN(), + floating_type(-1)); + } + else TOML11_CONSTEXPR_IF( + std::numeric_limits::has_signaling_NaN) { + val = copysign(std::numeric_limits::signaling_NaN(), + floating_type(-1)); + } + else { + return err(make_error_info( + "toml::parse_floating: NaN value found" + " but the current environment does not support NaN. Please" + " make sure that the floating-point implementation conforms" + " IEEE 754/ISO 60559 international standard.", + source_location(region(loc)), + "floating_type: NaN is not supported")); + } + } else { + // set precision + const auto has_sign = !str.empty() && + (str.front() == '+' || str.front() == '-'); + const auto decpoint = std::find(str.begin(), str.end(), '.'); + const auto exponent = std::find_if(str.begin(), str.end(), [](const char c) { + return c == 'e' || c == 'E'; + }); + if (decpoint != str.end() && exponent != str.end()) { + assert(decpoint < exponent); + } + + if (fmt.fmt == floating_format::scientific) { + // total width + fmt.prec = static_cast(std::distance(str.begin(), exponent)); + if (has_sign) { + fmt.prec -= 1; + } + if (decpoint != str.end()) { + fmt.prec -= 1; + } + } else if (fmt.fmt == floating_format::hex) { + fmt.prec = std::numeric_limits::max_digits10; + } else { + // width after decimal point + fmt.prec = static_cast( + std::distance(std::next(decpoint), exponent)); + } + + auto src = source_location(region(loc)); + const auto res = TC::parse_float(str, src, is_hex); + if (res.is_ok()) { + val = res.as_ok(); + } else { + return err(res.as_err()); + } + } + + // ---------------------------------------------------------------------- + // parse suffix (extension) 
+ + if (spec.ext_num_suffix && loc.current() == '_') { + const auto sfx_reg = syntax::num_suffix(spec).scan(loc); + if (!sfx_reg.is_ok()) { + auto src = source_location(region(loc)); + loc = first; + return err(make_error_info( + "toml::parse_floating: " + "invalid suffix: should be `_ non-digit-graph (graph | _graph)`", + std::move(src), + "here")); + } + auto sfx = sfx_reg.as_string(); + assert(!sfx.empty() && sfx.front() == '_'); + sfx.erase(sfx.begin()); // remove the first `_` + + fmt.suffix = sfx; + } + + return ok(basic_value(val, std::move(fmt), {}, std::move(reg))); + } + + /* ============================================================================ + * ___ _ _ _ + * | \ __ _| |_ ___| |_(_)_ __ ___ + * | |) / _` | _/ -_) _| | ' \/ -_) + * |___/\__,_|\__\___|\__|_|_|_|_\___| + */ + + // all the offset_datetime, local_datetime, local_date parses date part. + template + result, error_info> + parse_local_date_only(location& loc, const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + local_date_format_info fmt; + + // ---------------------------------------------------------------------- + // check syntax + auto reg = syntax::local_date(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_local_date: " + "invalid date: date must be like: 1234-05-06, yyyy-mm-dd.", + syntax::local_date(spec), + loc)); + } + + // ---------------------------------------------------------------------- + // it matches. 
gen value + const auto str = reg.as_string(); + + // 0123456789 + // yyyy-mm-dd + const auto year_r = from_string(str.substr(0, 4)); + const auto month_r = from_string(str.substr(5, 2)); + const auto day_r = from_string(str.substr(8, 2)); + + if (year_r.is_err()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_date: " + "failed to read year `" + + str.substr(0, 4) + "`", + std::move(src), + "here")); + } + if (month_r.is_err()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_date: " + "failed to read month `" + + str.substr(5, 2) + "`", + std::move(src), + "here")); + } + if (day_r.is_err()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_date: " + "failed to read day `" + + str.substr(8, 2) + "`", + std::move(src), + "here")); + } + + const auto year = year_r.unwrap(); + const auto month = month_r.unwrap(); + const auto day = day_r.unwrap(); + + { + // We briefly check whether the input date is valid or not. + // Actually, because of the historical reasons, there are several + // edge cases, such as 1582/10/5-1582/10/14 (only in several countries). + // But here, we do not care about it. + // It makes the code complicated and there is only low probability + // that such a specific date is needed in practice. If someone need to + // validate date accurately, that means that the one need a specialized + // library for their purpose in another layer. + + const bool is_leap = (year % 4 == 0) && + ((year % 100 != 0) || (year % 400 == 0)); + const auto max_day = [month, is_leap]() { + if (month == 2) { + return is_leap ? 
29 : 28; + } + if (month == 4 || month == 6 || month == 9 || month == 11) { + return 30; + } + return 31; + }(); + + if ((month < 1 || 12 < month) || (day < 1 || max_day < day)) { + auto src = source_location(region(first)); + return err( + make_error_info("toml::parse_local_date: invalid date.", + std::move(src), + "month must be 01-12, day must be any of " + "01-28,29,30,31 depending on the month/year.")); + } + } + + return ok( + std::make_tuple(local_date(year, static_cast(month - 1), day), + std::move(fmt), + std::move(reg))); + } + + template + result, error_info> parse_local_date(location& loc, + const context& ctx) { + auto val_fmt_reg = parse_local_date_only(loc, ctx); + if (val_fmt_reg.is_err()) { + return err(val_fmt_reg.unwrap_err()); + } + + auto val = std::move(std::get<0>(val_fmt_reg.unwrap())); + auto fmt = std::move(std::get<1>(val_fmt_reg.unwrap())); + auto reg = std::move(std::get<2>(val_fmt_reg.unwrap())); + + return ok( + basic_value(std::move(val), std::move(fmt), {}, std::move(reg))); + } + + // all the offset_datetime, local_datetime, local_time parses date part. + template + result, error_info> + parse_local_time_only(location& loc, const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + local_time_format_info fmt; + + // ---------------------------------------------------------------------- + // check syntax + auto reg = syntax::local_time(spec).scan(loc); + if (!reg.is_ok()) { + if (spec.v1_1_0_make_seconds_optional) { + return err(make_syntax_error( + "toml::parse_local_time: " + "invalid time: time must be HH:MM(:SS.sss) (seconds are optional)", + syntax::local_time(spec), + loc)); + } else { + return err( + make_syntax_error("toml::parse_local_time: " + "invalid time: time must be HH:MM:SS(.sss) " + "(subseconds are optional)", + syntax::local_time(spec), + loc)); + } + } + + // ---------------------------------------------------------------------- + // it matches. 
gen value + const auto str = reg.as_string(); + + // at least we have HH:MM. + // 01234 + // HH:MM + const auto hour_r = from_string(str.substr(0, 2)); + const auto minute_r = from_string(str.substr(3, 2)); + + if (hour_r.is_err()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_time: " + "failed to read hour `" + + str.substr(0, 2) + "`", + std::move(src), + "here")); + } + if (minute_r.is_err()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_time: " + "failed to read minute `" + + str.substr(3, 2) + "`", + std::move(src), + "here")); + } + + const auto hour = hour_r.unwrap(); + const auto minute = minute_r.unwrap(); + + if ((hour < 0 || 24 <= hour) || (minute < 0 || 60 <= minute)) { + auto src = source_location(region(first)); + return err( + make_error_info("toml::parse_local_time: invalid time.", + std::move(src), + "hour must be 00-23, minute must be 00-59.")); + } + + // ----------------------------------------------------------------------- + // we have hour and minute. + // Since toml v1.1.0, second and subsecond part becomes optional. + // Check the version and return if second does not exist. + + if (str.size() == 5 && spec.v1_1_0_make_seconds_optional) { + fmt.has_seconds = false; + fmt.subsecond_precision = 0; + return ok(std::make_tuple(local_time(hour, minute, 0), + std::move(fmt), + std::move(reg))); + } + assert(str.at(5) == ':'); + + // we have at least `:SS` part. `.subseconds` are optional. 
+ + // 0 1 + // 012345678901234 + // HH:MM:SS.subsec + const auto sec_r = from_string(str.substr(6, 2)); + if (sec_r.is_err()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_time: " + "failed to read second `" + + str.substr(6, 2) + "`", + std::move(src), + "here")); + } + const auto sec = sec_r.unwrap(); + + if (sec < 0 || 60 < sec) // :60 is allowed + { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_time: invalid time.", + std::move(src), + "second must be 00-60.")); + } + + if (str.size() == 8) { + fmt.has_seconds = true; + fmt.subsecond_precision = 0; + return ok(std::make_tuple(local_time(hour, minute, sec), + std::move(fmt), + std::move(reg))); + } + + assert(str.at(8) == '.'); + + auto secfrac = str.substr(9, str.size() - 9); + + fmt.has_seconds = true; + fmt.subsecond_precision = secfrac.size(); + + // right-pad with zeros so that the ms/us/ns groups below always exist. + while (secfrac.size() < 9) { + secfrac += '0'; + } + assert(9 <= secfrac.size()); + const auto ms_r = from_string(secfrac.substr(0, 3)); + const auto us_r = from_string(secfrac.substr(3, 3)); + const auto ns_r = from_string(secfrac.substr(6, 3)); + + // each message quotes the piece of `secfrac` that failed to parse + // (not a piece of `str`, which would show part of HH:MM:SS). + if (ms_r.is_err()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_time: " + "failed to read milliseconds `" + + secfrac.substr(0, 3) + "`", + std::move(src), + "here")); + } + if (us_r.is_err()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_time: " + "failed to read microseconds`" + + secfrac.substr(3, 3) + "`", + std::move(src), + "here")); + } + if (ns_r.is_err()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_local_time: " + "failed to read nanoseconds`" + + secfrac.substr(6, 3) + "`", + std::move(src), + "here")); + } + const auto ms = ms_r.unwrap(); + const auto us = us_r.unwrap(); + const auto ns = ns_r.unwrap(); + + return ok(std::make_tuple(local_time(hour, minute, sec, ms, us, ns), +
std::move(fmt), + std::move(reg))); + } + + template + result, error_info> parse_local_time(location& loc, + const context& ctx) { + const auto first = loc; + + auto val_fmt_reg = parse_local_time_only(loc, ctx); + if (val_fmt_reg.is_err()) { + return err(val_fmt_reg.unwrap_err()); + } + + auto val = std::move(std::get<0>(val_fmt_reg.unwrap())); + auto fmt = std::move(std::get<1>(val_fmt_reg.unwrap())); + auto reg = std::move(std::get<2>(val_fmt_reg.unwrap())); + + return ok( + basic_value(std::move(val), std::move(fmt), {}, std::move(reg))); + } + + template + result, error_info> parse_local_datetime(location& loc, + const context& ctx) { + using char_type = location::char_type; + + const auto first = loc; + + local_datetime_format_info fmt; + + // ---------------------------------------------------------------------- + + auto date_fmt_reg = parse_local_date_only(loc, ctx); + if (date_fmt_reg.is_err()) { + return err(date_fmt_reg.unwrap_err()); + } + + if (loc.current() == char_type('T')) { + loc.advance(); + fmt.delimiter = datetime_delimiter_kind::upper_T; + } else if (loc.current() == char_type('t')) { + loc.advance(); + fmt.delimiter = datetime_delimiter_kind::lower_t; + } else if (loc.current() == char_type(' ')) { + loc.advance(); + fmt.delimiter = datetime_delimiter_kind::space; + } else { + auto src = source_location(region(loc)); + return err( + make_error_info("toml::parse_local_datetime: " + "expect date-time delimiter `T`, `t` or ` `(space).", + std::move(src), + "here")); + } + + auto time_fmt_reg = parse_local_time_only(loc, ctx); + if (time_fmt_reg.is_err()) { + return err(time_fmt_reg.unwrap_err()); + } + + fmt.has_seconds = std::get<1>(time_fmt_reg.unwrap()).has_seconds; + fmt.subsecond_precision = std::get<1>(time_fmt_reg.unwrap()).subsecond_precision; + + // ---------------------------------------------------------------------- + + region reg(first, loc); + local_datetime val(std::get<0>(date_fmt_reg.unwrap()), + 
std::get<0>(time_fmt_reg.unwrap())); + + return ok(basic_value(val, std::move(fmt), {}, std::move(reg))); + } + + template + result, error_info> parse_offset_datetime( + location& loc, + const context& ctx) { + using char_type = location::char_type; + + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + offset_datetime_format_info fmt; + + // ---------------------------------------------------------------------- + // date part + + auto date_fmt_reg = parse_local_date_only(loc, ctx); + if (date_fmt_reg.is_err()) { + return err(date_fmt_reg.unwrap_err()); + } + + // ---------------------------------------------------------------------- + // delimiter + + if (loc.current() == char_type('T')) { + loc.advance(); + fmt.delimiter = datetime_delimiter_kind::upper_T; + } else if (loc.current() == char_type('t')) { + loc.advance(); + fmt.delimiter = datetime_delimiter_kind::lower_t; + } else if (loc.current() == char_type(' ')) { + loc.advance(); + fmt.delimiter = datetime_delimiter_kind::space; + } else { + auto src = source_location(region(loc)); + return err( + make_error_info("toml::parse_offset_datetime: " + "expect date-time delimiter `T` or ` `(space).", + std::move(src), + "here")); + } + + // ---------------------------------------------------------------------- + // time part + + auto time_fmt_reg = parse_local_time_only(loc, ctx); + if (time_fmt_reg.is_err()) { + return err(time_fmt_reg.unwrap_err()); + } + + fmt.has_seconds = std::get<1>(time_fmt_reg.unwrap()).has_seconds; + fmt.subsecond_precision = std::get<1>(time_fmt_reg.unwrap()).subsecond_precision; + + // ---------------------------------------------------------------------- + // offset part + + const auto ofs_reg = syntax::time_offset(spec).scan(loc); + if (!ofs_reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_offset_datetime: " + "invalid offset: offset must be like: Z, +01:00, or -10:00.", + syntax::time_offset(spec), + loc)); + } + + const auto ofs_str = ofs_reg.as_string(); 
+ + time_offset offset(0, 0); + + assert(ofs_str.size() != 0); + + if (ofs_str.at(0) == char_type('+') || ofs_str.at(0) == char_type('-')) { + const auto hour_r = from_string(ofs_str.substr(1, 2)); + const auto minute_r = from_string(ofs_str.substr(4, 2)); + if (hour_r.is_err()) { + auto src = source_location(region(loc)); + return err(make_error_info("toml::parse_offset_datetime: " + "Failed to read offset hour part", + std::move(src), + "here")); + } + if (minute_r.is_err()) { + auto src = source_location(region(loc)); + return err(make_error_info("toml::parse_offset_datetime: " + "Failed to read offset minute part", + std::move(src), + "here")); + } + const auto hour = hour_r.unwrap(); + const auto minute = minute_r.unwrap(); + + if (ofs_str.at(0) == '+') { + offset = time_offset(hour, minute); + } else { + offset = time_offset(-hour, -minute); + } + } else { + assert(ofs_str.at(0) == char_type('Z') || ofs_str.at(0) == char_type('z')); + } + + if (offset.hour < -24 || 24 < offset.hour || offset.minute < -60 || + 60 < offset.minute) { + return err( + make_error_info("toml::parse_offset_datetime: " + "too large offset: |hour| <= 24, |minute| <= 60", + source_location(region(first, loc)), + "here")); + } + + // ---------------------------------------------------------------------- + + region reg(first, loc); + offset_datetime val(local_datetime(std::get<0>(date_fmt_reg.unwrap()), + std::get<0>(time_fmt_reg.unwrap())), + offset); + + return ok(basic_value(val, std::move(fmt), {}, std::move(reg))); + } + + /* ============================================================================ + * ___ _ _ + * / __| |_ _ _(_)_ _ __ _ + * \__ \ _| '_| | ' \/ _` | + * |___/\__|_| |_|_||_\__, | + * |___/ + */ + + template + result::string_type, error_info> parse_utf8_codepoint( + const region& reg) { + using string_type = typename basic_value::string_type; + using char_type = typename string_type::value_type; + + // assert(reg.as_lines().size() == 1); // XXX heavy check + + const 
auto str = reg.as_string(); + assert(!str.empty()); + assert(str.front() == 'u' || str.front() == 'U' || str.front() == 'x'); + + std::uint_least32_t codepoint; + std::istringstream iss(str.substr(1)); + iss >> std::hex >> codepoint; + + const auto to_char = [](const std::uint_least32_t i) noexcept -> char_type { + const auto uc = static_cast(i & 0xFF); + return cxx::bit_cast(uc); + }; + + string_type character; + if (codepoint < 0x80) // U+0000 ... U+0079 ; just an ASCII. + { + character += static_cast(codepoint); + } else if (codepoint < 0x800) // U+0080 ... U+07FF + { + // 110yyyyx 10xxxxxx; 0x3f == 0b0011'1111 + character += to_char(0xC0 | (codepoint >> 6)); + character += to_char(0x80 | (codepoint & 0x3F)); + } else if (codepoint < 0x10000) // U+0800...U+FFFF + { + if (0xD800 <= codepoint && codepoint <= 0xDFFF) { + auto src = source_location(reg); + return err(make_error_info("toml::parse_utf8_codepoint: " + "[0xD800, 0xDFFF] is not a valid UTF-8", + std::move(src), + "here")); + } + assert(codepoint < 0xD800 || 0xDFFF < codepoint); + // 1110yyyy 10yxxxxx 10xxxxxx + character += to_char(0xE0 | (codepoint >> 12)); + character += to_char(0x80 | ((codepoint >> 6) & 0x3F)); + character += to_char(0x80 | ((codepoint) & 0x3F)); + } else if (codepoint < 0x110000) // U+010000 ... 
U+10FFFF + { + // 11110yyy 10yyxxxx 10xxxxxx 10xxxxxx + character += to_char(0xF0 | (codepoint >> 18)); + character += to_char(0x80 | ((codepoint >> 12) & 0x3F)); + character += to_char(0x80 | ((codepoint >> 6) & 0x3F)); + character += to_char(0x80 | ((codepoint) & 0x3F)); + } else // out of UTF-8 region + { + auto src = source_location(reg); + return err(make_error_info("toml::parse_utf8_codepoint: " + "input codepoint is too large.", + std::move(src), + "must be in range [0x00, 0x10FFFF]")); + } + return ok(character); + } + + template + result::string_type, error_info> parse_escape_sequence( + location& loc, + const context& ctx) { + using string_type = typename basic_value::string_type; + using char_type = typename string_type::value_type; + + const auto& spec = ctx.toml_spec(); + + assert(!loc.eof()); + assert(loc.current() == '\\'); + loc.advance(); // consume the first backslash + + string_type retval; + + if (loc.current() == '\\') { + retval += char_type('\\'); + loc.advance(); + } else if (loc.current() == '"') { + retval += char_type('\"'); + loc.advance(); + } else if (loc.current() == 'b') { + retval += char_type('\b'); + loc.advance(); + } else if (loc.current() == 'f') { + retval += char_type('\f'); + loc.advance(); + } else if (loc.current() == 'n') { + retval += char_type('\n'); + loc.advance(); + } else if (loc.current() == 'r') { + retval += char_type('\r'); + loc.advance(); + } else if (loc.current() == 't') { + retval += char_type('\t'); + loc.advance(); + } else if (spec.v1_1_0_add_escape_sequence_e && loc.current() == 'e') { + retval += char_type('\x1b'); + loc.advance(); + } else if (spec.v1_1_0_add_escape_sequence_x && loc.current() == 'x') { + auto scanner = sequence(character('x'), + repeat_exact(2, syntax::hexdig(spec))); + const auto reg = scanner.scan(loc); + if (!reg.is_ok()) { + auto src = source_location(region(loc)); + return err( + make_error_info("toml::parse_escape_sequence: " + "invalid token found in UTF-8 codepoint \\xhh", + 
std::move(src), + "here")); + } + const auto utf8 = parse_utf8_codepoint(reg); + if (utf8.is_err()) { + return err(utf8.as_err()); + } + retval += utf8.unwrap(); + } else if (loc.current() == 'u') { + auto scanner = sequence(character('u'), + repeat_exact(4, syntax::hexdig(spec))); + const auto reg = scanner.scan(loc); + if (!reg.is_ok()) { + auto src = source_location(region(loc)); + return err( + make_error_info("toml::parse_escape_sequence: " + "invalid token found in UTF-8 codepoint \\uhhhh", + std::move(src), + "here")); + } + const auto utf8 = parse_utf8_codepoint(reg); + if (utf8.is_err()) { + return err(utf8.as_err()); + } + retval += utf8.unwrap(); + } else if (loc.current() == 'U') { + auto scanner = sequence(character('U'), + repeat_exact(8, syntax::hexdig(spec))); + const auto reg = scanner.scan(loc); + if (!reg.is_ok()) { + auto src = source_location(region(loc)); + return err(make_error_info( + "toml::parse_escape_sequence: " + "invalid token found in UTF-8 codepoint \\Uhhhhhhhh", + std::move(src), + "here")); + } + const auto utf8 = parse_utf8_codepoint(reg); + if (utf8.is_err()) { + return err(utf8.as_err()); + } + retval += utf8.unwrap(); + } else { + auto src = source_location(region(loc)); + std::string escape_seqs = + "allowed escape seqs: \\\\, \\\", \\b, \\f, \\n, \\r, \\t"; + if (spec.v1_1_0_add_escape_sequence_e) { + escape_seqs += ", \\e"; + } + if (spec.v1_1_0_add_escape_sequence_x) { + escape_seqs += ", \\xhh"; + } + escape_seqs += ", \\uhhhh, or \\Uhhhhhhhh"; + + return err(make_error_info("toml::parse_escape_sequence: " + "unknown escape sequence.", + std::move(src), + escape_seqs)); + } + return ok(retval); + } + + template + result, error_info> parse_ml_basic_string( + location& loc, + const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + string_format_info fmt; + fmt.fmt = string_format::multiline_basic; + + auto reg = syntax::ml_basic_string(spec).scan(loc); + if (!reg.is_ok()) { + return 
err(make_syntax_error("toml::parse_ml_basic_string: " + "invalid string format", + syntax::ml_basic_string(spec), + loc)); + } + + // ---------------------------------------------------------------------- + // it matches. gen value + + auto str = reg.as_string(); + + // we already checked that it starts with """ and ends with """. + assert(str.substr(0, 3) == "\"\"\""); + str.erase(0, 3); + + assert(str.size() >= 3); + assert(str.substr(str.size() - 3, 3) == "\"\"\""); + str.erase(str.size() - 3, 3); + + // the first newline just after """ is trimmed + if (str.size() >= 1 && str.at(0) == '\n') { + str.erase(0, 1); + fmt.start_with_newline = true; + } else if (str.size() >= 2 && str.at(0) == '\r' && str.at(1) == '\n') { + str.erase(0, 2); + fmt.start_with_newline = true; + } + + using string_type = typename basic_value::string_type; + string_type val; + { + auto iter = str.cbegin(); + while (iter != str.cend()) { + if (*iter == '\\') // remove whitespaces around escaped-newline + { + // we assume that the string is not too long to copy + auto loc2 = make_temporary_location(make_string(iter, str.cend())); + if (syntax::escaped_newline(spec).scan(loc2).is_ok()) { + std::advance(iter, + loc2.get_location()); // skip escaped newline and indent + // now iter points non-WS char + assert(iter == str.end() || (*iter != ' ' && *iter != '\t')); + } else // normal escape seq. + { + auto esc = parse_escape_sequence(loc2, ctx); + + // syntax does not check its value. the unicode codepoint may be + // invalid, e.g. out-of-bound, [0xD800, 0xDFFF] + if (esc.is_err()) { + return err(esc.unwrap_err()); + } + + val += esc.unwrap(); + std::advance(iter, loc2.get_location()); + } + } else // we already checked the syntax. we don't need to check it again. 
+ { + val += static_cast(*iter); + ++iter; + } + } + } + + return ok( + basic_value(std::move(val), std::move(fmt), {}, std::move(reg))); + } + + template + result::string_type, region>, error_info> + parse_basic_string_only(location& loc, const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + auto reg = syntax::basic_string(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error("toml::parse_basic_string: " + "invalid string format", + syntax::basic_string(spec), + loc)); + } + + // ---------------------------------------------------------------------- + // it matches. gen value + + auto str = reg.as_string(); + + assert(str.back() == '\"'); + str.pop_back(); + assert(str.at(0) == '\"'); + str.erase(0, 1); + + using string_type = typename basic_value::string_type; + using char_type = typename string_type::value_type; + string_type val; + + { + auto iter = str.begin(); + while (iter != str.end()) { + if (*iter == '\\') { + auto loc2 = make_temporary_location(make_string(iter, str.end())); + + auto esc = parse_escape_sequence(loc2, ctx); + + // syntax does not check its value. the unicode codepoint may be + // invalid, e.g. out-of-bound, [0xD800, 0xDFFF] + if (esc.is_err()) { + return err(esc.unwrap_err()); + } + + val += esc.unwrap(); + std::advance(iter, loc2.get_location()); + } else { + val += char_type(*iter); // we already checked the syntax. 
+ ++iter; + } + } + } + return ok(std::make_pair(val, reg)); + } + + template + result, error_info> parse_basic_string(location& loc, + const context& ctx) { + const auto first = loc; + + string_format_info fmt; + fmt.fmt = string_format::basic; + + auto val_res = parse_basic_string_only(loc, ctx); + if (val_res.is_err()) { + return err(std::move(val_res.unwrap_err())); + } + auto val = std::move(val_res.unwrap().first); + auto reg = std::move(val_res.unwrap().second); + + return ok( + basic_value(std::move(val), std::move(fmt), {}, std::move(reg))); + } + + template + result, error_info> parse_ml_literal_string( + location& loc, + const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + string_format_info fmt; + fmt.fmt = string_format::multiline_literal; + + auto reg = syntax::ml_literal_string(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error("toml::parse_ml_literal_string: " + "invalid string format", + syntax::ml_literal_string(spec), + loc)); + } + + // ---------------------------------------------------------------------- + // it matches. 
gen value + + auto str = reg.as_string(); + + assert(str.substr(0, 3) == "'''"); + assert(str.substr(str.size() - 3, 3) == "'''"); + str.erase(0, 3); + str.erase(str.size() - 3, 3); + + // the first newline just after """ is trimmed + if (str.size() >= 1 && str.at(0) == '\n') { + str.erase(0, 1); + fmt.start_with_newline = true; + } else if (str.size() >= 2 && str.at(0) == '\r' && str.at(1) == '\n') { + str.erase(0, 2); + fmt.start_with_newline = true; + } + + using string_type = typename basic_value::string_type; + string_type val(str.begin(), str.end()); + + return ok( + basic_value(std::move(val), std::move(fmt), {}, std::move(reg))); + } + + template + result::string_type, region>, error_info> + parse_literal_string_only(location& loc, const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + auto reg = syntax::literal_string(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error("toml::parse_literal_string: " + "invalid string format", + syntax::literal_string(spec), + loc)); + } + + // ---------------------------------------------------------------------- + // it matches. 
gen value + + auto str = reg.as_string(); + + assert(str.back() == '\''); + str.pop_back(); + assert(str.at(0) == '\''); + str.erase(0, 1); + + using string_type = typename basic_value::string_type; + string_type val(str.begin(), str.end()); + + return ok(std::make_pair(std::move(val), std::move(reg))); + } + + template + result, error_info> parse_literal_string(location& loc, + const context& ctx) { + const auto first = loc; + + string_format_info fmt; + fmt.fmt = string_format::literal; + + auto val_res = parse_literal_string_only(loc, ctx); + if (val_res.is_err()) { + return err(std::move(val_res.unwrap_err())); + } + auto val = std::move(val_res.unwrap().first); + auto reg = std::move(val_res.unwrap().second); + + return ok( + basic_value(std::move(val), std::move(fmt), {}, std::move(reg))); + } + + template + result, error_info> parse_string(location& loc, + const context& ctx) { + const auto first = loc; + + if (!loc.eof() && loc.current() == '"') { + if (literal("\"\"\"").scan(loc).is_ok()) { + loc = first; + return parse_ml_basic_string(loc, ctx); + } else { + loc = first; + return parse_basic_string(loc, ctx); + } + } else if (!loc.eof() && loc.current() == '\'') { + if (literal("'''").scan(loc).is_ok()) { + loc = first; + return parse_ml_literal_string(loc, ctx); + } else { + loc = first; + return parse_literal_string(loc, ctx); + } + } else { + auto src = source_location(region(loc)); + return err(make_error_info("toml::parse_string: " + "not a string", + std::move(src), + "here")); + } + } + + template + result, error_info> parse_null(location& loc, + const context& ctx) { + const auto& spec = ctx.toml_spec(); + if (!spec.ext_null_value) { + return err( + make_error_info("toml::parse_null: " + "invalid spec: spec.ext_null_value must be true.", + source_location(region(loc)), + "here")); + } + + // ---------------------------------------------------------------------- + // check syntax + auto reg = syntax::null_value(spec).scan(loc); + if (!reg.is_ok()) 
{ + return err(make_syntax_error("toml::parse_null: " + "invalid null: null must be lowercase. ", + syntax::null_value(spec), + loc)); + } + + // ---------------------------------------------------------------------- + // it matches. gen value + + // ---------------------------------------------------------------------- + // no format info for boolean + + return ok(basic_value(detail::none_t {}, std::move(reg))); + } + + /* ============================================================================ + * _ __ + * | |/ /___ _ _ + * | ' + result::key_type, error_info> parse_simple_key( + location& loc, + const context& ctx) { + using key_type = typename basic_value::key_type; + const auto& spec = ctx.toml_spec(); + + if (loc.current() == '\"') { + auto str_res = parse_basic_string_only(loc, ctx); + if (str_res.is_ok()) { + return ok(std::move(str_res.unwrap().first)); + } else { + return err(std::move(str_res.unwrap_err())); + } + } else if (loc.current() == '\'') { + auto str_res = parse_literal_string_only(loc, ctx); + if (str_res.is_ok()) { + return ok(std::move(str_res.unwrap().first)); + } else { + return err(std::move(str_res.unwrap_err())); + } + } + + // bare key. 
+ + if (const auto bare = syntax::unquoted_key(spec).scan(loc)) { + return ok(string_conv(bare.as_string())); + } else { + std::string postfix; + if (spec.v1_1_0_allow_non_english_in_bare_keys) { + postfix = + "Hint: Not all Unicode characters are allowed as bare key.\n"; + } else { + postfix = "Hint: non-ASCII scripts are allowed in toml v1.1.0, but " + "not in v1.0.0.\n"; + } + return err(make_syntax_error( + "toml::parse_simple_key: " + "invalid key: key must be \"quoted\", 'quoted-literal', or bare key.", + syntax::unquoted_key(spec), + loc, + postfix)); + } + } + + // dotted key become vector of keys + template + result::key_type>, region>, error_info> + parse_key(location& loc, const context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + using key_type = typename basic_value::key_type; + std::vector keys; + while (!loc.eof()) { + auto key = parse_simple_key(loc, ctx); + if (!key.is_ok()) { + return err(key.unwrap_err()); + } + keys.push_back(std::move(key.unwrap())); + + auto reg = syntax::dot_sep(spec).scan(loc); + if (!reg.is_ok()) { + break; + } + } + if (keys.empty()) { + auto src = source_location(region(first)); + return err(make_error_info("toml::parse_key: expected a new key, " + "but got nothing", + std::move(src), + "reached EOF")); + } + + return ok(std::make_pair(std::move(keys), region(first, loc))); + } + + // ============================================================================ + + // forward-decl to implement parse_array and parse_table + template + result, error_info> parse_value(location&, context& ctx); + + template + result::key_type>, region>, + basic_value>, + error_info> + parse_key_value_pair(location& loc, context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + auto key_res = parse_key(loc, ctx); + if (key_res.is_err()) { + loc = first; + return err(key_res.unwrap_err()); + } + + if (!syntax::keyval_sep(spec).scan(loc).is_ok()) { + auto e = 
make_syntax_error("toml::parse_key_value_pair: " + "invalid key value separator `=`", + syntax::keyval_sep(spec), + loc); + loc = first; + return err(std::move(e)); + } + + auto v_res = parse_value(loc, ctx); + if (v_res.is_err()) { + // loc = first; + return err(v_res.unwrap_err()); + } + return ok( + std::make_pair(std::move(key_res.unwrap()), std::move(v_res.unwrap()))); + } + + /* ============================================================================ + * __ _ _ _ _ _ __ _ _ _ + * / _` | '_| '_/ _` | || | + * \__,_|_| |_| \__,_|\_, | + * |__/ + */ + + // array(and multiline inline table with `{` and `}`) has the following format. + // `[` + // (ws|newline|comment-line)? (value) (ws|newline|comment-line)? `,` + // (ws|newline|comment-line)? (value) (ws|newline|comment-line)? `,` + // ... + // (ws|newline|comment-line)? (value) (ws|newline|comment-line)? (`,`)? + // (ws|newline|comment-line)? `]` + // it skips (ws|newline|comment-line) and returns the token. + template + struct multiline_spacer { + using comment_type = typename TC::comment_type; + bool newline_found; + indent_char indent_type; + std::int32_t indent; + comment_type comments; + }; + + template + std::ostream& operator<<(std::ostream& os, const multiline_spacer& sp) { + os << "{newline=" << sp.newline_found << ", "; + os << "indent_type=" << sp.indent_type << ", "; + os << "indent=" << sp.indent << ", "; + os << "comments=" << sp.comments.size() << "}"; + return os; + } + + template + cxx::optional> skip_multiline_spacer( + location& loc, + context& ctx, + const bool newline_found = false) { + const auto& spec = ctx.toml_spec(); + + multiline_spacer spacer; + spacer.newline_found = newline_found; + spacer.indent_type = indent_char::none; + spacer.indent = 0; + spacer.comments.clear(); + + bool spacer_found = false; + while (!loc.eof()) { + if (auto comm = sequence(syntax::comment(spec), syntax::newline(spec)) + .scan(loc)) { + spacer.newline_found = true; + auto comment = comm.as_string(); + if 
(!comment.empty() && comment.back() == '\n') { + comment.pop_back(); + if (!comment.empty() && comment.back() == '\r') { + comment.pop_back(); + } + } + + spacer.comments.push_back(std::move(comment)); + spacer.indent_type = indent_char::none; + spacer.indent = 0; + spacer_found = true; + } else if (auto nl = syntax::newline(spec).scan(loc)) { + spacer.newline_found = true; + spacer.comments.clear(); + spacer.indent_type = indent_char::none; + spacer.indent = 0; + spacer_found = true; + } else if (auto sp = repeat_at_least( + 1, + character(cxx::bit_cast(' '))) + .scan(loc)) { + spacer.indent_type = indent_char::space; + spacer.indent = static_cast(sp.length()); + spacer_found = true; + } else if ( + auto tabs = repeat_at_least( + 1, + character(cxx::bit_cast('\t'))) + .scan(loc)) { + spacer.indent_type = indent_char::tab; + spacer.indent = static_cast(tabs.length()); + spacer_found = true; + } else { + break; // done + } + } + if (!spacer_found) { + return cxx::make_nullopt(); + } + return spacer; + } + + // not an [[array.of.tables]]. 
It parses ["this", "type"] + template + result, error_info> parse_array(location& loc, + context& ctx) { + const auto num_errors = ctx.errors().size(); + + const auto first = loc; + + if (loc.eof() || loc.current() != '[') { + auto src = source_location(region(loc)); + return err(make_error_info("toml::parse_array: " + "The next token is not an array", + std::move(src), + "here")); + } + loc.advance(); + + typename basic_value::array_type val; + + array_format_info fmt; + fmt.fmt = array_format::oneline; + fmt.indent_type = indent_char::none; + + auto spacer = skip_multiline_spacer(loc, ctx); + if (spacer.has_value() && spacer.value().newline_found) { + fmt.fmt = array_format::multiline; + } + + bool comma_found = true; + while (!loc.eof()) { + if (loc.current() == location::char_type(']')) { + if (spacer.has_value() && spacer.value().newline_found && + spacer.value().indent_type != indent_char::none) { + fmt.indent_type = spacer.value().indent_type; + fmt.closing_indent = spacer.value().indent; + } + break; + } + + if (!comma_found) { + auto src = source_location(region(loc)); + return err( + make_error_info("toml::parse_array: " + "expected value-separator `,` or closing `]`", + std::move(src), + "here")); + } + + if (spacer.has_value() && spacer.value().newline_found && + spacer.value().indent_type != indent_char::none) { + fmt.indent_type = spacer.value().indent_type; + fmt.body_indent = spacer.value().indent; + } + + if (auto elem_res = parse_value(loc, ctx)) { + auto elem = std::move(elem_res.unwrap()); + + if (spacer.has_value()) // copy previous comments to value + { + elem.comments() = std::move(spacer.value().comments); + } + + // parse spaces between a value and a comma + // array = [ + // 42 , # the answer + // ^^^^ + // 3.14 # pi + // , 2.71 ^^^^ + // ^^ + spacer = skip_multiline_spacer(loc, ctx); + if (spacer.has_value()) { + for (std::size_t i = 0; i < spacer.value().comments.size(); ++i) { + 
elem.comments().push_back(std::move(spacer.value().comments.at(i))); + } + if (spacer.value().newline_found) { + fmt.fmt = array_format::multiline; + } + } + + comma_found = character(',').scan(loc).is_ok(); + + // parse comment after a comma + // array = [ + // 42 , # the answer + // ^^^^^^^^^^^^ + // 3.14 # pi + // ^^^^ + // ] + auto com_res = parse_comment_line(loc, ctx); + if (com_res.is_err()) { + ctx.report_error(com_res.unwrap_err()); + } + + const bool comment_found = com_res.is_ok() && + com_res.unwrap().has_value(); + if (comment_found) { + fmt.fmt = array_format::multiline; + elem.comments().push_back(com_res.unwrap().value()); + } + if (comma_found) { + spacer = skip_multiline_spacer(loc, ctx, comment_found); + if (spacer.has_value() && spacer.value().newline_found) { + fmt.fmt = array_format::multiline; + } + } + val.push_back(std::move(elem)); + } else { + // if err, push error to ctx and try recovery. + ctx.report_error(std::move(elem_res.unwrap_err())); + + // if it looks like some value, then skip the value. + // otherwise, it may be a new key-value pair or a new table and + // the error is "missing closing ]". stop parsing. + + const auto before_skip = loc.get_location(); + skip_value(loc, ctx); + if (before_skip == loc.get_location()) // cannot skip! break... 
+ { + break; + } + } + } + + if (loc.current() != ']') { + auto src = source_location(region(loc)); + return err( + make_error_info("toml::parse_array: missing closing bracket `]`", + std::move(src), + "expected `]`, reached EOF")); + } else { + loc.advance(); + } + // any error reported from this function + if (num_errors != ctx.errors().size()) { + assert(ctx.has_error()); // already reported + return err(ctx.errors().back()); + } + + return ok( + basic_value(std::move(val), std::move(fmt), {}, region(first, loc))); + } + + /* ============================================================================ + * _ _ _ _ _ _ + * (_)_ _ | (_)_ _ ___ | |_ __ _| |__| |___ + * | | ' \| | | ' \/ -_) | _/ _` | '_ \ / -_) + * |_|_||_|_|_|_||_\___| \__\__,_|_.__/_\___| + */ + + // ---------------------------------------------------------------------------- + // insert_value is the most complicated part of the toml spec. + // + // To parse a toml file correctly, we sometimes need to check an exising + // value is appendable or not. + // + // For example while parsing an inline array of tables, + // + // ```toml + // aot = [ + // {a = "foo"}, + // {a = "bar", b = "baz"}, + // ] + // ``` + // + // this `aot` is appendable until parser reaches to `]`. After that, it + // becomes non-appendable. + // + // On the other hand, a normal array of tables, such as + // + // ```toml + // [[aot]] + // a = "foo" + // + // [[aot]] + // a = "bar" + // b = "baz" + // ``` + // This `[[aot]]` is appendable until the parser reaches to the EOF. + // + // + // It becomes a bit more difficult in case of dotted keys. + // In TOML, it is allowed to append a key-value pair to a table that is + // *implicitly* defined by a subtable definitino. + // + // ```toml + // [x.y.z] + // w = 123 + // + // [x] + // a = "foo" # OK. x is defined implicitly by `[x.y.z]`. + // ``` + // + // But if the table is defined by a dotted keys, it is not appendable. 
+ // + // ```toml + // [x] + // y.z.w = 123 + // + // [x.y] + // # ERROR. x.y is already defined by a dotted table in the previous table. + // ``` + // + // Also, reopening a table using dotted keys is invalid. + // + // ```toml + // [x.y.z] + // w = 123 + // + // [x] + // y.z.v = 42 # ERROR. [x.y.z] is already defined. + // ``` + // + // + // ```toml + // [a] + // b.c = "foo" + // b.d = "bar" + // ``` + // + // + // ```toml + // a.b = "foo" + // [a] + // c = "bar" # ERROR + // ``` + // + // In summary, + // - a table must be defined only once. + // - assignment to an exising table is possible only when: + // - defining a subtable [x.y] to an existing table [x]. + // - defining supertable [x] explicitly after [x.y]. + // - adding dotted keys in the same table. + + enum class inserting_value_kind : std::uint8_t { + std_table, // insert [standard.table] + array_table, // insert [[array.of.tables]] + dotted_keys // insert a.b.c = "this" + }; + + template + result*, error_info> insert_value( + const inserting_value_kind kind, + typename basic_value::table_type* current_table_ptr, + const std::vector::key_type>& keys, + region key_reg, + basic_value val) { + using value_type = basic_value; + using array_type = typename basic_value::array_type; + using table_type = typename basic_value::table_type; + + auto key_loc = source_location(key_reg); + + assert(!keys.empty()); + + // dotted key can insert to dotted key tables defined at the same level. + // dotted key can NOT reopen a table even if it is implcitly-defined one. + // + // [x.y.z] # define x and x.y implicitly. + // a = 42 + // + // [x] # reopening implcitly defined table + // r.s.t = 3.14 # VALID r and r.s are new tables. + // r.s.u = 2.71 # VALID r and r.s are dotted-key tables. valid. + // + // y.z.b = "foo" # INVALID x.y.z are multiline table, not a dotted key. + // y.c = "bar" # INVALID x.y is implicit multiline table, not a dotted key. + + // a table cannot reopen dotted-key tables. 
+ // + // [t1] + // t2.t3.v = 0 + // [t1.t2] # INVALID t1.t2 is defined as a dotted-key table. + + for (std::size_t i = 0; i < keys.size(); ++i) { + const auto& key = keys.at(i); + table_type& current_table = *current_table_ptr; + + if (i + 1 < keys.size()) // there are more keys. go down recursively... + { + const auto found = current_table.find(key); + if (found == current_table.end()) // not found. add new table + { + table_format_info fmt; + fmt.indent_type = indent_char::none; + if (kind == inserting_value_kind::dotted_keys) { + fmt.fmt = table_format::dotted; + } else // table / array of tables + { + fmt.fmt = table_format::implicit; + } + current_table.emplace( + key, + value_type(table_type {}, fmt, std::vector {}, key_reg)); + + assert(current_table.at(key).is_table()); + current_table_ptr = std::addressof(current_table.at(key).as_table()); + } else if (found->second.is_table()) { + const auto fmt = found->second.as_table_fmt().fmt; + if (fmt == table_format::oneline || + fmt == table_format::multiline_oneline) { + // foo = {bar = "baz"} or foo = { \n bar = "baz" \n } + return err(make_error_info( + "toml::insert_value: " + "failed to insert a value: inline table is immutable", + key_loc, + "inserting this", + found->second.location(), + "to this table")); + } + // dotted key cannot reopen a table. 
+ if (kind == inserting_value_kind::dotted_keys && + fmt != table_format::dotted) { + return err(make_error_info("toml::insert_value: " + "reopening a table using dotted keys", + key_loc, + "dotted key cannot reopen a table", + found->second.location(), + "this table is already closed")); + } + assert(found->second.is_table()); + current_table_ptr = std::addressof(found->second.as_table()); + } else if (found->second.is_array_of_tables()) { + // aot = [{this = "type", of = "aot"}] # cannot be reopened + if (found->second.as_array_fmt().fmt != array_format::array_of_tables) { + return err(make_error_info("toml::insert_value:" + "inline array of tables are immutable", + key_loc, + "inserting this", + found->second.location(), + "inline array of tables")); + } + // appending to [[aot]] + + if (kind == inserting_value_kind::dotted_keys) { + // [[array.of.tables]] + // [array.of] # reopening supertable is okay + // tables.x = "foo" # appending `x` to the first table + return err( + make_error_info("toml::insert_value:" + "dotted key cannot reopen an array-of-tables", + key_loc, + "inserting this", + found->second.location(), + "to this array-of-tables.")); + } + + // insert_value_by_dotkeys::std_table + // [[array.of.tables]] + // [array.of.tables.subtable] # appending to the last aot + // + // insert_value_by_dotkeys::array_table + // [[array.of.tables]] + // [[array.of.tables.subtable]] # appending to the last aot + auto& current_array_table = found->second.as_array().back(); + + assert(current_array_table.is_table()); + current_table_ptr = std::addressof(current_array_table.as_table()); + } else { + return err( + make_error_info("toml::insert_value: " + "failed to insert a value, value already exists", + key_loc, + "while inserting this", + found->second.location(), + "non-table value already exists")); + } + } else // this is the last key. insert a new value. 
+ { + switch (kind) { + case inserting_value_kind::dotted_keys: { + if (current_table.find(key) != current_table.end()) { + return err(make_error_info( + "toml::insert_value: " + "failed to insert a value, value already exists", + key_loc, + "inserting this", + current_table.at(key).location(), + "but value already exists")); + } + current_table.emplace(key, std::move(val)); + return ok(std::addressof(current_table.at(key))); + } + case inserting_value_kind::std_table: { + // defining a new table or reopening supertable + auto found = current_table.find(key); + if (found == current_table.end()) // define a new aot + { + current_table.emplace(key, std::move(val)); + return ok(std::addressof(current_table.at(key))); + } else // the table is already defined, reopen it + { + // assigning a [std.table]. it must be an implicit table. + auto& target = found->second; + if (!target.is_table() || // could be an array-of-tables + target.as_table_fmt().fmt != table_format::implicit) { + return err(make_error_info( + "toml::insert_value: " + "failed to insert a table, table already defined", + key_loc, + "inserting this", + target.location(), + "this table is explicitly defined")); + } + + // merge table + for (const auto& kv : val.as_table()) { + if (target.contains(kv.first)) { + // [x.y.z] + // w = "foo" + // [x] + // y = "bar" + return err( + make_error_info("toml::insert_value: " + "failed to insert a table, table keys " + "conflict to each other", + key_loc, + "inserting this table", + kv.second.location(), + "having this value", + target.at(kv.first).location(), + "already defined here")); + } else { + target[kv.first] = kv.second; + } + } + // change implicit -> explicit + target.as_table_fmt().fmt = table_format::multiline; + // change definition region + change_region_of_value(target, val); + + return ok(std::addressof(current_table.at(key))); + } + } + case inserting_value_kind::array_table: { + auto found = current_table.find(key); + if (found == 
current_table.end()) // define a new aot + { + array_format_info fmt; + fmt.fmt = array_format::array_of_tables; + fmt.indent_type = indent_char::none; + + current_table.emplace(key, + value_type(array_type { std::move(val) }, + std::move(fmt), + std::vector {}, + std::move(key_reg))); + + assert(!current_table.at(key).as_array().empty()); + return ok(std::addressof(current_table.at(key).as_array().back())); + } else // the array is already defined, append to it + { + if (!found->second.is_array_of_tables()) { + return err(make_error_info( + "toml::insert_value: " + "failed to insert an array of tables, value already exists", + key_loc, + "while inserting this", + found->second.location(), + "non-table value already exists")); + } + if (found->second.as_array_fmt().fmt != + array_format::array_of_tables) { + return err(make_error_info("toml::insert_value: " + "failed to insert a table, inline " + "array of tables is immutable", + key_loc, + "while inserting this", + found->second.location(), + "this is inline array-of-tables")); + } + found->second.as_array().push_back(std::move(val)); + assert(!current_table.at(key).as_array().empty()); + return ok(std::addressof(current_table.at(key).as_array().back())); + } + } + default: { + assert(false); + } + } + } + } + return err(make_error_info("toml::insert_key: no keys found", + std::move(key_loc), + "here")); + } + + // ---------------------------------------------------------------------------- + + template + result, error_info> parse_inline_table(location& loc, + context& ctx) { + using table_type = typename basic_value::table_type; + + const auto num_errors = ctx.errors().size(); + + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + if (loc.eof() || loc.current() != '{') { + auto src = source_location(region(loc)); + return err(make_error_info("toml::parse_inline_table: " + "The next token is not an inline table", + std::move(src), + "here")); + } + loc.advance(); + + table_type table; + 
table_format_info fmt; + fmt.fmt = table_format::oneline; + fmt.indent_type = indent_char::none; + + cxx::optional> spacer(cxx::make_nullopt()); + + if (spec.v1_1_0_allow_newlines_in_inline_tables) { + spacer = skip_multiline_spacer(loc, ctx); + if (spacer.has_value() && spacer.value().newline_found) { + fmt.fmt = table_format::multiline_oneline; + } + } else { + skip_whitespace(loc, ctx); + } + + bool still_empty = true; + bool comma_found = false; + while (!loc.eof()) { + // closing! + if (loc.current() == '}') { + if (comma_found && !spec.v1_1_0_allow_trailing_comma_in_inline_tables) { + auto src = source_location(region(loc)); + return err(make_error_info("toml::parse_inline_table: trailing " + "comma is not allowed in TOML-v1.0.0)", + std::move(src), + "here")); + } + + if (spec.v1_1_0_allow_newlines_in_inline_tables) { + if (spacer.has_value() && spacer.value().newline_found && + spacer.value().indent_type != indent_char::none) { + fmt.indent_type = spacer.value().indent_type; + fmt.closing_indent = spacer.value().indent; + } + } + break; + } + + // if we already found a value and didn't found `,` nor `}`, error. + if (!comma_found && !still_empty) { + auto src = source_location(region(loc)); + return err( + make_error_info("toml::parse_inline_table: " + "expected value-separator `,` or closing `}`", + std::move(src), + "here")); + } + + // parse indent. + if (spacer.has_value() && spacer.value().newline_found && + spacer.value().indent_type != indent_char::none) { + fmt.indent_type = spacer.value().indent_type; + fmt.body_indent = spacer.value().indent; + } + + still_empty = false; // parsing a value... 
+ if (auto kv_res = parse_key_value_pair(loc, ctx)) { + auto keys = std::move(kv_res.unwrap().first.first); + auto key_reg = std::move(kv_res.unwrap().first.second); + auto val = std::move(kv_res.unwrap().second); + + auto ins_res = insert_value(inserting_value_kind::dotted_keys, + std::addressof(table), + keys, + std::move(key_reg), + std::move(val)); + if (ins_res.is_err()) { + ctx.report_error(std::move(ins_res.unwrap_err())); + // we need to skip until the next value (or end of the table) + // because we don't have valid kv pair. + while (!loc.eof()) { + const auto c = loc.current(); + if (c == ',' || c == '\n' || c == '}') { + comma_found = (c == ','); + break; + } + loc.advance(); + } + continue; + } + + // if comment line follows immediately(without newline) after `,`, then + // the comment is for the elem. we need to check if comment follows `,`. + // + // (key) = (val) (ws|newline|comment-line)? `,` (ws)? (comment)? + + if (spec.v1_1_0_allow_newlines_in_inline_tables) { + if (spacer.has_value()) // copy previous comments to value + { + for (std::size_t i = 0; i < spacer.value().comments.size(); ++i) { + ins_res.unwrap()->comments().push_back( + spacer.value().comments.at(i)); + } + } + spacer = skip_multiline_spacer(loc, ctx); + if (spacer.has_value()) { + for (std::size_t i = 0; i < spacer.value().comments.size(); ++i) { + ins_res.unwrap()->comments().push_back( + spacer.value().comments.at(i)); + } + if (spacer.value().newline_found) { + fmt.fmt = table_format::multiline_oneline; + if (spacer.value().indent_type != indent_char::none) { + fmt.indent_type = spacer.value().indent_type; + fmt.body_indent = spacer.value().indent; + } + } + } + } else { + skip_whitespace(loc, ctx); + } + + comma_found = character(',').scan(loc).is_ok(); + + if (spec.v1_1_0_allow_newlines_in_inline_tables) { + auto com_res = parse_comment_line(loc, ctx); + if (com_res.is_err()) { + ctx.report_error(com_res.unwrap_err()); + } + const bool comment_found = com_res.is_ok() && + 
com_res.unwrap().has_value(); + if (comment_found) { + fmt.fmt = table_format::multiline_oneline; + ins_res.unwrap()->comments().push_back(com_res.unwrap().value()); + } + if (comma_found) { + spacer = skip_multiline_spacer(loc, ctx, comment_found); + if (spacer.has_value() && spacer.value().newline_found) { + fmt.fmt = table_format::multiline_oneline; + } + } + } else { + skip_whitespace(loc, ctx); + } + } else { + ctx.report_error(std::move(kv_res.unwrap_err())); + while (!loc.eof()) { + if (loc.current() == '}') { + break; + } + if (!spec.v1_1_0_allow_newlines_in_inline_tables && + loc.current() == '\n') { + break; + } + loc.advance(); + } + break; + } + } + + if (loc.current() != '}') { + auto src = source_location(region(loc)); + return err(make_error_info("toml::parse_inline_table: " + "missing closing bracket `}`", + std::move(src), + "expected `}`, reached line end")); + } else { + loc.advance(); // skip } + } + + // any error reported from this function + if (num_errors < ctx.errors().size()) { + assert(ctx.has_error()); // already reported + return err(ctx.pop_last_error()); + } + + basic_value retval(std::move(table), std::move(fmt), {}, region(first, loc)); + + return ok(std::move(retval)); + } + + /* ============================================================================ + * _ + * __ ____ _| |_ _ ___ + * \ V / _` | | || / -_) + * \_/\__,_|_|\_,_\___| + */ + + template + result guess_number_type(const location& first, + const context& ctx) { + const auto& spec = ctx.toml_spec(); + location loc = first; + + if (syntax::offset_datetime(spec).scan(loc).is_ok()) { + return ok(value_t::offset_datetime); + } + loc = first; + + if (syntax::local_datetime(spec).scan(loc).is_ok()) { + const auto curr = loc.current(); + // if offset_datetime contains bad offset, it syntax::offset_datetime + // fails to scan it. 
+ if (curr == '+' || curr == '-') { + return err( + make_syntax_error("bad offset: must be [+-]HH:MM or Z", + syntax::time_offset(spec), + loc, + std::string("Hint: valid : +09:00, -05:30\n" + "Hint: invalid: +9:00, -5:30\n"))); + } + return ok(value_t::local_datetime); + } + loc = first; + + if (syntax::local_date(spec).scan(loc).is_ok()) { + // bad time may appear after this. + + if (!loc.eof()) { + const auto c = loc.current(); + if (c == 'T' || c == 't') { + loc.advance(); + + return err(make_syntax_error( + "bad time: must be HH:MM:SS.subsec", + syntax::local_time(spec), + loc, + std::string( + "Hint: valid : 1979-05-27T07:32:00, 1979-05-27 " + "07:32:00.999999\n" + "Hint: invalid: 1979-05-27T7:32:00, 1979-05-27 17:32\n"))); + } + if (c == ' ') { + // A space is allowed as a delimiter between local time. + // But there is a case where bad time follows a space. + // - invalid: 2019-06-16 7:00:00 + // - valid : 2019-06-16 07:00:00 + loc.advance(); + if (!loc.eof() && ('0' <= loc.current() && loc.current() <= '9')) { + return err(make_syntax_error( + "bad time: must be HH:MM:SS.subsec", + syntax::local_time(spec), + loc, + std::string( + "Hint: valid : 1979-05-27T07:32:00, 1979-05-27 " + "07:32:00.999999\n" + "Hint: invalid: 1979-05-27T7:32:00, 1979-05-27 17:32\n"))); + } + } + if ('0' <= c && c <= '9') { + return err(make_syntax_error( + "bad datetime: missing T or space", + character_either { 'T', 't', ' ' }, + loc, + std::string( + "Hint: valid : 1979-05-27T07:32:00, 1979-05-27 " + "07:32:00.999999\n" + "Hint: invalid: 1979-05-27T7:32:00, 1979-05-27 17:32\n"))); + } + } + return ok(value_t::local_date); + } + loc = first; + + if (syntax::local_time(spec).scan(loc).is_ok()) { + return ok(value_t::local_time); + } + loc = first; + + if (syntax::floating(spec).scan(loc).is_ok()) { + if (!loc.eof() && loc.current() == '_') { + if (spec.ext_num_suffix && syntax::num_suffix(spec).scan(loc).is_ok()) { + return ok(value_t::floating); + } + auto src = 
source_location(region(loc)); + return err(make_error_info( + "bad float: `_` must be surrounded by digits", + std::move(src), + "invalid underscore", + "Hint: valid : +1.0, -2e-2, 3.141_592_653_589, inf, nan\n" + "Hint: invalid: .0, 1., _1.0, 1.0_, 1_.0, 1.0__0\n")); + } + return ok(value_t::floating); + } + loc = first; + + if (spec.ext_hex_float) { + if (syntax::hex_floating(spec).scan(loc).is_ok()) { + if (!loc.eof() && loc.current() == '_') { + if (spec.ext_num_suffix && syntax::num_suffix(spec).scan(loc).is_ok()) { + return ok(value_t::floating); + } + auto src = source_location(region(loc)); + return err(make_error_info( + "bad float: `_` must be surrounded by digits", + std::move(src), + "invalid underscore", + "Hint: valid : +1.0, -2e-2, 3.141_592_653_589, inf, nan\n" + "Hint: invalid: .0, 1., _1.0, 1.0_, 1_.0, 1.0__0\n")); + } + return ok(value_t::floating); + } + loc = first; + } + + if (auto int_reg = syntax::integer(spec).scan(loc)) { + if (!loc.eof()) { + const auto c = loc.current(); + if (c == '_') { + if (spec.ext_num_suffix && syntax::num_suffix(spec).scan(loc).is_ok()) { + return ok(value_t::integer); + } + + if (int_reg.length() <= 2 && + (int_reg.as_string() == "0" || int_reg.as_string() == "-0" || + int_reg.as_string() == "+0")) { + auto src = source_location(region(loc)); + return err(make_error_info( + "bad integer: leading zero is not allowed in decimal int", + std::move(src), + "leading zero", + "Hint: valid : -42, 1_000, 1_2_3_4_5, 0xC0FFEE, 0b0010, " + "0o755\n" + "Hint: invalid: _42, 1__000, 0123\n")); + } else { + auto src = source_location(region(loc)); + return err( + make_error_info("bad integer: `_` must be surrounded by digits", + std::move(src), + "invalid underscore", + "Hint: valid : -42, 1_000, 1_2_3_4_5, " + "0xC0FFEE, 0b0010, 0o755\n" + "Hint: invalid: _42, 1__000, 0123\n")); + } + } + if ('0' <= c && c <= '9') { + if (loc.current() == '0') { + loc.retrace(); + return err(make_error_info( + "bad integer: leading zero", + 
source_location(region(loc)), + "leading zero is not allowed", + std::string("Hint: valid : -42, 1_000, 1_2_3_4_5, 0xC0FFEE, " + "0b0010, 0o755\n" + "Hint: invalid: _42, 1__000, 0123\n"))); + } else // invalid digits, especially in oct/bin ints. + { + return err(make_error_info( + "bad integer: invalid digit after an integer", + source_location(region(loc)), + "this digit is not allowed", + std::string("Hint: valid : -42, 1_000, 1_2_3_4_5, 0xC0FFEE, " + "0b0010, 0o755\n" + "Hint: invalid: _42, 1__000, 0123\n"))); + } + } + if (c == ':' || c == '-') { + auto src = source_location(region(loc)); + return err(make_error_info( + "bad datetime: invalid format", + std::move(src), + "here", + std::string("Hint: valid : 1979-05-27T07:32:00-07:00, " + "1979-05-27 07:32:00.999999Z\n" + "Hint: invalid: 1979-05-27T7:32:00-7:00, 1979-05-27 " + "7:32-00:30"))); + } + if (c == '.' || c == 'e' || c == 'E') { + auto src = source_location(region(loc)); + return err(make_error_info( + "bad float: invalid format", + std::move(src), + "here", + std::string( + "Hint: valid : +1.0, -2e-2, 3.141_592_653_589, inf, nan\n" + "Hint: invalid: .0, 1., _1.0, 1.0_, 1_.0, 1.0__0\n"))); + } + } + return ok(value_t::integer); + } + if (!loc.eof() && loc.current() == '.') { + auto src = source_location(region(loc)); + return err(make_error_info( + "bad float: integer part is required before decimal point", + std::move(src), + "missing integer part", + std::string( + "Hint: valid : +1.0, -2e-2, 3.141_592_653_589, inf, nan\n" + "Hint: invalid: .0, 1., _1.0, 1.0_, 1_.0, 1.0__0\n"))); + } + if (!loc.eof() && loc.current() == '_') { + auto src = source_location(region(loc)); + return err(make_error_info( + "bad number: `_` must be surrounded by digits", + std::move(src), + "digits required before `_`", + std::string( + "Hint: valid : -42, 1_000, 1_2_3_4_5, 0xC0FFEE, 0b0010, 0o755\n" + "Hint: invalid: _42, 1__000, 0123\n"))); + } + + auto src = source_location(region(loc)); + return err(make_error_info("bad 
format: unknown value appeared", + std::move(src), + "here")); + } + + template + result guess_value_type(const location& loc, + const context& ctx) { + const auto& sp = ctx.toml_spec(); + location inner(loc); + + switch (loc.current()) { + case '"': { + return ok(value_t::string); + } + case '\'': { + return ok(value_t::string); + } + case '[': { + return ok(value_t::array); + } + case '{': { + return ok(value_t::table); + } + case 't': { + return ok(value_t::boolean); + } + case 'f': { + return ok(value_t::boolean); + } + case 'T': // invalid boolean. + { + return err(make_syntax_error("toml::parse_value: " + "`true` must be in lowercase. " + "A string must be surrounded by quotes.", + syntax::boolean(sp), + inner)); + } + case 'F': { + return err(make_syntax_error("toml::parse_value: " + "`false` must be in lowercase. " + "A string must be surrounded by quotes.", + syntax::boolean(sp), + inner)); + } + case 'i': // inf or string without quotes(syntax error). + { + if (literal("inf").scan(inner).is_ok()) { + return ok(value_t::floating); + } else { + return err( + make_syntax_error("toml::parse_value: " + "`inf` must be in lowercase. " + "A string must be surrounded by quotes.", + syntax::floating(sp), + inner)); + } + } + case 'I': // Inf or string without quotes(syntax error). + { + return err(make_syntax_error("toml::parse_value: " + "`inf` must be in lowercase. " + "A string must be surrounded by quotes.", + syntax::floating(sp), + inner)); + } + case 'n': // nan or null-extension + { + if (sp.ext_null_value) { + if (literal("nan").scan(inner).is_ok()) { + return ok(value_t::floating); + } else if (literal("null").scan(inner).is_ok()) { + return ok(value_t::empty); + } else { + return err( + make_syntax_error("toml::parse_value: " + "Both `nan` and `null` must be in lowercase. " + "A string must be surrounded by quotes.", + syntax::floating(sp), + inner)); + } + } else // must be nan. 
+ { + if (literal("nan").scan(inner).is_ok()) { + return ok(value_t::floating); + } else { + return err( + make_syntax_error("toml::parse_value: " + "`nan` must be in lowercase. " + "A string must be surrounded by quotes.", + syntax::floating(sp), + inner)); + } + } + } + case 'N': // nan or null-extension + { + if (sp.ext_null_value) { + return err( + make_syntax_error("toml::parse_value: " + "Both `nan` and `null` must be in lowercase. " + "A string must be surrounded by quotes.", + syntax::floating(sp), + inner)); + } else { + return err( + make_syntax_error("toml::parse_value: " + "`nan` must be in lowercase. " + "A string must be surrounded by quotes.", + syntax::floating(sp), + inner)); + } + } + default: { + return guess_number_type(loc, ctx); + } + } + } + + template + result, error_info> parse_value(location& loc, + context& ctx) { + const auto ty_res = guess_value_type(loc, ctx); + if (ty_res.is_err()) { + return err(ty_res.unwrap_err()); + } + + switch (ty_res.unwrap()) { + case value_t::empty: { + if (ctx.toml_spec().ext_null_value) { + return parse_null(loc, ctx); + } else { + auto src = source_location(region(loc)); + return err( + make_error_info("toml::parse_value: unknown value appeared", + std::move(src), + "here")); + } + } + case value_t::boolean: { + return parse_boolean(loc, ctx); + } + case value_t::integer: { + return parse_integer(loc, ctx); + } + case value_t::floating: { + return parse_floating(loc, ctx); + } + case value_t::string: { + return parse_string(loc, ctx); + } + case value_t::offset_datetime: { + return parse_offset_datetime(loc, ctx); + } + case value_t::local_datetime: { + return parse_local_datetime(loc, ctx); + } + case value_t::local_date: { + return parse_local_date(loc, ctx); + } + case value_t::local_time: { + return parse_local_time(loc, ctx); + } + case value_t::array: { + return parse_array(loc, ctx); + } + case value_t::table: { + return parse_inline_table(loc, ctx); + } + default: { + auto src = 
source_location(region(loc)); + return err( + make_error_info("toml::parse_value: unknown value appeared", + std::move(src), + "here")); + } + } + } + + /* ============================================================================ + * _____ _ _ + * |_ _|_ _| |__| |___ + * | |/ _` | '_ \ / -_) + * |_|\__,_|_.__/_\___| + */ + + template + result::key_type>, region>, error_info> + parse_table_key(location& loc, context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + auto reg = syntax::std_table(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error("toml::parse_table_key: invalid table key", + syntax::std_table(spec), + loc)); + } + + loc = first; + loc.advance(); // skip [ + skip_whitespace(loc, ctx); + + auto keys_res = parse_key(loc, ctx); + if (keys_res.is_err()) { + return err(std::move(keys_res.unwrap_err())); + } + + skip_whitespace(loc, ctx); + loc.advance(); // ] + + return ok(std::make_pair(std::move(keys_res.unwrap().first), std::move(reg))); + } + + template + result::key_type>, region>, error_info> + parse_array_table_key(location& loc, context& ctx) { + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + auto reg = syntax::array_table(spec).scan(loc); + if (!reg.is_ok()) { + return err(make_syntax_error( + "toml::parse_array_table_key: invalid array-of-tables key", + syntax::array_table(spec), + loc)); + } + + loc = first; + loc.advance(); // [ + loc.advance(); // [ + skip_whitespace(loc, ctx); + + auto keys_res = parse_key(loc, ctx); + if (keys_res.is_err()) { + return err(std::move(keys_res.unwrap_err())); + } + + skip_whitespace(loc, ctx); + loc.advance(); // ] + loc.advance(); // ] + + return ok(std::make_pair(std::move(keys_res.unwrap().first), std::move(reg))); + } + + // called after reading [table.keys] and comments around it. + // Since table may already contain a subtable ([x.y.z] can be defined before + // [x]), the table that is being parsed is passed as an argument. 
+ template + result parse_table(location& loc, + context& ctx, + basic_value& table) { + assert(table.is_table()); + + const auto num_errors = ctx.errors().size(); + const auto& spec = ctx.toml_spec(); + + // clear indent info + table.as_table_fmt().indent_type = indent_char::none; + + bool newline_found = true; + while (!loc.eof()) { + const auto start = loc; + + auto sp = skip_multiline_spacer(loc, ctx, newline_found); + + // if reached to EOF, the table ends here. return. + if (loc.eof()) { + break; + } + // if next table is comming, return. + if (sequence(syntax::ws(spec), character('[')).scan(loc).is_ok()) { + loc = start; + break; + } + // otherwise, it should be a key-value pair. + newline_found = newline_found || + (sp.has_value() && sp.value().newline_found); + if (!newline_found) { + return err(make_error_info("toml::parse_table: " + "newline (LF / CRLF) or EOF is expected", + source_location(region(loc)), + "here")); + } + if (sp.has_value() && sp.value().indent_type != indent_char::none) { + table.as_table_fmt().indent_type = sp.value().indent_type; + table.as_table_fmt().body_indent = sp.value().indent; + } + + newline_found = false; // reset + if (auto kv_res = parse_key_value_pair(loc, ctx)) { + auto keys = std::move(kv_res.unwrap().first.first); + auto key_reg = std::move(kv_res.unwrap().first.second); + auto val = std::move(kv_res.unwrap().second); + + if (sp.has_value()) { + for (const auto& com : sp.value().comments) { + val.comments().push_back(com); + } + } + + if (auto com_res = parse_comment_line(loc, ctx)) { + if (auto com_opt = com_res.unwrap()) { + val.comments().push_back(com_opt.value()); + newline_found = true; // comment includes newline at the end + } + } else { + ctx.report_error(std::move(com_res.unwrap_err())); + } + + auto ins_res = insert_value(inserting_value_kind::dotted_keys, + std::addressof(table.as_table()), + keys, + std::move(key_reg), + std::move(val)); + if (ins_res.is_err()) { + 
ctx.report_error(std::move(ins_res.unwrap_err())); + } + } else { + ctx.report_error(std::move(kv_res.unwrap_err())); + skip_key_value_pair(loc, ctx); + } + } + + if (num_errors < ctx.errors().size()) { + assert(ctx.has_error()); // already reported + return err(ctx.pop_last_error()); + } + return ok(); + } + + template + result, std::vector> parse_file(location& loc, + context& ctx) { + using value_type = basic_value; + using table_type = typename value_type::table_type; + + const auto first = loc; + const auto& spec = ctx.toml_spec(); + + if (loc.eof()) { + return ok(value_type(table_type(), table_format_info {}, {}, region(loc))); + } + + value_type root(table_type(), table_format_info {}, {}, region(loc)); + root.as_table_fmt().fmt = table_format::multiline; + root.as_table_fmt().indent_type = indent_char::none; + + // parse top comment. + // + // ```toml + // # this is a comment for the top-level table. + // + // key = "the first value" + // ``` + // + // ```toml + // # this is a comment for "the first value". + // key = "the first value" + // ``` + while (!loc.eof()) { + if (auto com_res = parse_comment_line(loc, ctx)) { + if (auto com_opt = com_res.unwrap()) { + root.comments().push_back(std::move(com_opt.value())); + } else // no comment found. + { + // if it is not an empty line, clear the root comment. 
+ if (!sequence(syntax::ws(spec), syntax::newline(spec)).scan(loc).is_ok()) { + loc = first; + root.comments().clear(); + } + break; + } + } else { + ctx.report_error(std::move(com_res.unwrap_err())); + skip_comment_block(loc, ctx); + } + } + + // parse root table + { + const auto res = parse_table(loc, ctx, root); + if (res.is_err()) { + ctx.report_error(std::move(res.unwrap_err())); + skip_until_next_table(loc, ctx); + } + } + + // parse tables + + while (!loc.eof()) { + auto sp = skip_multiline_spacer(loc, ctx, /*newline_found=*/true); + + if (auto key_res = parse_array_table_key(loc, ctx)) { + auto key = std::move(std::get<0>(key_res.unwrap())); + auto reg = std::move(std::get<1>(key_res.unwrap())); + + std::vector com; + if (sp.has_value()) { + for (std::size_t i = 0; i < sp.value().comments.size(); ++i) { + com.push_back(std::move(sp.value().comments.at(i))); + } + } + + // [table.def] must be followed by one of + // - a comment line + // - whitespace + newline + // - EOF + if (auto com_res = parse_comment_line(loc, ctx)) { + if (auto com_opt = com_res.unwrap()) { + com.push_back(com_opt.value()); + } else // if there is no comment, ws+newline must exist (or EOF) + { + skip_whitespace(loc, ctx); + if (!loc.eof() && + !syntax::newline(ctx.toml_spec()).scan(loc).is_ok()) { + ctx.report_error(make_syntax_error("toml::parse_file: " + "newline (or EOF) expected", + syntax::newline(ctx.toml_spec()), + loc)); + skip_until_next_table(loc, ctx); + continue; + } + } + } else // comment syntax error (rare) + { + ctx.report_error(com_res.unwrap_err()); + skip_until_next_table(loc, ctx); + continue; + } + + table_format_info fmt; + fmt.fmt = table_format::multiline; + fmt.indent_type = indent_char::none; + auto tab = value_type(table_type {}, std::move(fmt), std::move(com), reg); + + auto inserted = insert_value(inserting_value_kind::array_table, + std::addressof(root.as_table()), + key, + std::move(reg), + std::move(tab)); + + if (inserted.is_err()) { + 
ctx.report_error(inserted.unwrap_err()); + + // check errors in the table + auto tmp = basic_value(table_type()); + auto res = parse_table(loc, ctx, tmp); + if (res.is_err()) { + ctx.report_error(res.unwrap_err()); + skip_until_next_table(loc, ctx); + } + continue; + } + + auto tab_ptr = inserted.unwrap(); + assert(tab_ptr); + + const auto tab_res = parse_table(loc, ctx, *tab_ptr); + if (tab_res.is_err()) { + ctx.report_error(tab_res.unwrap_err()); + skip_until_next_table(loc, ctx); + } + + // parse_table first clears `indent_type`. + // to keep header indent info, we must store it later. + if (sp.has_value() && sp.value().indent_type != indent_char::none) { + tab_ptr->as_table_fmt().indent_type = sp.value().indent_type; + tab_ptr->as_table_fmt().name_indent = sp.value().indent; + } + continue; + } + if (auto key_res = parse_table_key(loc, ctx)) { + auto key = std::move(std::get<0>(key_res.unwrap())); + auto reg = std::move(std::get<1>(key_res.unwrap())); + + std::vector com; + if (sp.has_value()) { + for (std::size_t i = 0; i < sp.value().comments.size(); ++i) { + com.push_back(std::move(sp.value().comments.at(i))); + } + } + + // [table.def] must be followed by one of + // - a comment line + // - whitespace + newline + // - EOF + if (auto com_res = parse_comment_line(loc, ctx)) { + if (auto com_opt = com_res.unwrap()) { + com.push_back(com_opt.value()); + } else // if there is no comment, ws+newline must exist (or EOF) + { + skip_whitespace(loc, ctx); + if (!loc.eof() && + !syntax::newline(ctx.toml_spec()).scan(loc).is_ok()) { + ctx.report_error(make_syntax_error("toml::parse_file: " + "newline (or EOF) expected", + syntax::newline(ctx.toml_spec()), + loc)); + skip_until_next_table(loc, ctx); + continue; + } + } + } else // comment syntax error (rare) + { + ctx.report_error(com_res.unwrap_err()); + skip_until_next_table(loc, ctx); + continue; + } + + table_format_info fmt; + fmt.fmt = table_format::multiline; + fmt.indent_type = indent_char::none; + auto tab = 
value_type(table_type {}, std::move(fmt), std::move(com), reg); + + auto inserted = insert_value(inserting_value_kind::std_table, + std::addressof(root.as_table()), + key, + std::move(reg), + std::move(tab)); + + if (inserted.is_err()) { + ctx.report_error(inserted.unwrap_err()); + + // check errors in the table + auto tmp = basic_value(table_type()); + auto res = parse_table(loc, ctx, tmp); + if (res.is_err()) { + ctx.report_error(res.unwrap_err()); + skip_until_next_table(loc, ctx); + } + continue; + } + + auto tab_ptr = inserted.unwrap(); + assert(tab_ptr); + + const auto tab_res = parse_table(loc, ctx, *tab_ptr); + if (tab_res.is_err()) { + ctx.report_error(tab_res.unwrap_err()); + skip_until_next_table(loc, ctx); + } + if (sp.has_value() && sp.value().indent_type != indent_char::none) { + tab_ptr->as_table_fmt().indent_type = sp.value().indent_type; + tab_ptr->as_table_fmt().name_indent = sp.value().indent; + } + continue; + } + + // does not match array_table nor std_table. report an error. + const auto keytop = loc; + const auto maybe_array_of_tables = literal("[[").scan(loc).is_ok(); + loc = keytop; + + if (maybe_array_of_tables) { + ctx.report_error( + make_syntax_error("toml::parse_file: invalid array-table key", + syntax::array_table(spec), + loc)); + } else { + ctx.report_error( + make_syntax_error("toml::parse_file: invalid table key", + syntax::std_table(spec), + loc)); + } + skip_until_next_table(loc, ctx); + } + + if (!ctx.errors().empty()) { + return err(std::move(ctx.errors())); + } + return ok(std::move(root)); + } + + template + result, std::vector> parse_impl( + std::vector cs, + std::string fname, + const spec& s) { + using value_type = basic_value; + using table_type = typename value_type::table_type; + + // an empty file is a valid toml file. 
+ if (cs.empty()) { + auto src = std::make_shared>( + std::move(cs)); + location loc(std::move(src), std::move(fname)); + return ok(value_type(table_type(), + table_format_info {}, + std::vector {}, + region(loc))); + } + + // to simplify parser, add newline at the end if there is no LF. + // But, if it has raw CR, the file is invalid (in TOML, CR is not a valid + // newline char). if it ends with CR, do not add LF and report it. + if (cs.back() != '\n' && cs.back() != '\r') { + cs.push_back('\n'); + } + + auto src = std::make_shared>(std::move(cs)); + + location loc(std::move(src), std::move(fname)); + + // skip BOM if found + if (loc.source()->size() >= 3) { + auto first = loc.get_location(); + + const auto c0 = loc.current(); + loc.advance(); + const auto c1 = loc.current(); + loc.advance(); + const auto c2 = loc.current(); + loc.advance(); + + const auto bom_found = (c0 == 0xEF) && (c1 == 0xBB) && (c2 == 0xBF); + if (!bom_found) { + loc.set_location(first); + } + } + + context ctx(s); + + return parse_file(loc, ctx); + } + + } // namespace detail + + // ----------------------------------------------------------------------------- + // parse(byte array) + + template + result, std::vector> try_parse( + std::vector content, + std::string filename, + spec s = spec::default_version()) { + return detail::parse_impl(std::move(content), + std::move(filename), + std::move(s)); + } + + template + basic_value parse(std::vector content, + std::string filename, + spec s = spec::default_version()) { + auto res = try_parse(std::move(content), std::move(filename), std::move(s)); + if (res.is_ok()) { + return res.unwrap(); + } else { + std::string msg; + for (const auto& err : res.unwrap_err()) { + msg += format_error(err); + } + throw syntax_error(std::move(msg), std::move(res.unwrap_err())); + } + } + + // ----------------------------------------------------------------------------- + // parse(istream) + + template + result, std::vector> try_parse( + std::istream& is, + 
std::string fname = "unknown file", + spec s = spec::default_version()) { + const auto beg = is.tellg(); + is.seekg(0, std::ios::end); + const auto end = is.tellg(); + const auto fsize = end - beg; + is.seekg(beg); + + // read whole file as a sequence of char + assert(fsize >= 0); + std::vector letters(static_cast(fsize), + '\0'); + is.read(reinterpret_cast(letters.data()), + static_cast(fsize)); + + return detail::parse_impl(std::move(letters), + std::move(fname), + std::move(s)); + } + + template + basic_value parse(std::istream& is, + std::string fname = "unknown file", + spec s = spec::default_version()) { + auto res = try_parse(is, std::move(fname), std::move(s)); + if (res.is_ok()) { + return res.unwrap(); + } else { + std::string msg; + for (const auto& err : res.unwrap_err()) { + msg += format_error(err); + } + throw syntax_error(std::move(msg), std::move(res.unwrap_err())); + } + } + + // ----------------------------------------------------------------------------- + // parse(filename) + + template + result, std::vector> try_parse( + std::string fname, + spec s = spec::default_version()) { + std::ifstream ifs(fname, std::ios_base::binary); + if (!ifs.good()) { + std::vector e; + e.push_back( + error_info("toml::parse: Error opening file \"" + fname + "\"", {})); + return err(std::move(e)); + } + ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit); + + return try_parse(ifs, std::move(fname), std::move(s)); + } + + template + basic_value parse(std::string fname, spec s = spec::default_version()) { + std::ifstream ifs(fname, std::ios_base::binary); + if (!ifs.good()) { + throw file_io_error("toml::parse: error opening file", fname); + } + ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit); + + return parse(ifs, std::move(fname), std::move(s)); + } + + template + result, std::vector> try_parse( + const char (&fname)[N], + spec s = spec::default_version()) { + return try_parse(std::string(fname), std::move(s)); + } + + template + 
basic_value parse(const char (&fname)[N], spec s = spec::default_version()) { + return parse(std::string(fname), std::move(s)); + } + + // ---------------------------------------------------------------------------- + // parse_str + + template + result, std::vector> try_parse_str( + std::string content, + spec s = spec::default_version(), + cxx::source_location loc = cxx::source_location::current()) { + std::istringstream iss(std::move(content)); + std::string name("internal string" + cxx::to_string(loc)); + return try_parse(iss, std::move(name), std::move(s)); + } + + template + basic_value parse_str( + std::string content, + spec s = spec::default_version(), + cxx::source_location loc = cxx::source_location::current()) { + auto res = try_parse_str(std::move(content), std::move(s), std::move(loc)); + if (res.is_ok()) { + return res.unwrap(); + } else { + std::string msg; + for (const auto& err : res.unwrap_err()) { + msg += format_error(err); + } + throw syntax_error(std::move(msg), std::move(res.unwrap_err())); + } + } + + // ---------------------------------------------------------------------------- + // filesystem + +#if defined(TOML11_HAS_FILESYSTEM) + + template + cxx::enable_if_t::value, + result, std::vector>> + try_parse(const FSPATH& fpath, spec s = spec::default_version()) { + std::ifstream ifs(fpath, std::ios_base::binary); + if (!ifs.good()) { + std::vector e; + e.push_back( + error_info("toml::parse: Error opening file \"" + fpath.string() + "\"", + {})); + return err(std::move(e)); + } + ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit); + + return try_parse(ifs, fpath.string(), std::move(s)); + } + + template + cxx::enable_if_t::value, basic_value> + parse(const FSPATH& fpath, spec s = spec::default_version()) { + std::ifstream ifs(fpath, std::ios_base::binary); + if (!ifs.good()) { + throw file_io_error("toml::parse: error opening file", fpath.string()); + } + ifs.exceptions(std::ifstream::failbit | std::ifstream::badbit); + + return 
parse(ifs, fpath.string(), std::move(s)); + } +#endif + + // ----------------------------------------------------------------------------- + // FILE* + + template + result, std::vector> try_parse( + FILE* fp, + std::string filename, + spec s = spec::default_version()) { + const long beg = std::ftell(fp); + if (beg == -1L) { + return err(std::vector { + error_info(std::string("Failed to access: \"") + filename + + "\", errno = " + std::to_string(errno), + {}) }); + } + + const int res_seekend = std::fseek(fp, 0, SEEK_END); + if (res_seekend != 0) { + return err(std::vector { + error_info(std::string("Failed to seek: \"") + filename + + "\", errno = " + std::to_string(errno), + {}) }); + } + + const long end = std::ftell(fp); + if (end == -1L) { + return err(std::vector { + error_info(std::string("Failed to access: \"") + filename + + "\", errno = " + std::to_string(errno), + {}) }); + } + + const auto fsize = end - beg; + + const auto res_seekbeg = std::fseek(fp, beg, SEEK_SET); + if (res_seekbeg != 0) { + return err(std::vector { + error_info(std::string("Failed to seek: \"") + filename + + "\", errno = " + std::to_string(errno), + {}) }); + } + + // read whole file as a sequence of char + assert(fsize >= 0); + std::vector letters( + static_cast(fsize)); + const auto actual = std::fread(letters.data(), + sizeof(char), + static_cast(fsize), + fp); + if (actual != static_cast(fsize)) { + return err(std::vector { + error_info(std::string("File size changed: \"") + filename + + std::string("\" make sure that FILE* is in binary mode " + "to avoid LF <-> CRLF conversion"), + {}) }); + } + + return detail::parse_impl(std::move(letters), + std::move(filename), + std::move(s)); + } + + template + basic_value parse(FILE* fp, + std::string filename, + spec s = spec::default_version()) { + const long beg = std::ftell(fp); + if (beg == -1L) { + throw file_io_error(errno, "Failed to access", filename); + } + + const int res_seekend = std::fseek(fp, 0, SEEK_END); + if 
(res_seekend != 0) { + throw file_io_error(errno, "Failed to seek", filename); + } + + const long end = std::ftell(fp); + if (end == -1L) { + throw file_io_error(errno, "Failed to access", filename); + } + + const auto fsize = end - beg; + + const auto res_seekbeg = std::fseek(fp, beg, SEEK_SET); + if (res_seekbeg != 0) { + throw file_io_error(errno, "Failed to seek", filename); + } + + // read whole file as a sequence of char + assert(fsize >= 0); + std::vector letters( + static_cast(fsize)); + const auto actual = std::fread(letters.data(), + sizeof(char), + static_cast(fsize), + fp); + if (actual != static_cast(fsize)) { + throw file_io_error( + errno, + "File size changed; make sure that " + "FILE* is in binary mode to avoid LF <-> CRLF conversion", + filename); + } + + auto res = detail::parse_impl(std::move(letters), + std::move(filename), + std::move(s)); + if (res.is_ok()) { + return res.unwrap(); + } else { + std::string msg; + for (const auto& err : res.unwrap_err()) { + msg += format_error(err); + } + throw syntax_error(std::move(msg), std::move(res.unwrap_err())); + } + } + +} // namespace toml + +#if defined(TOML11_COMPILE_SOURCES) +namespace toml { + struct type_config; + struct ordered_type_config; + + extern template result, std::vector> + try_parse(std::vector, std::string, spec); + extern template result, std::vector> + try_parse(std::istream&, std::string, spec); + extern template result, std::vector> + try_parse(std::string, spec); + extern template result, std::vector> + try_parse(FILE*, std::string, spec); + extern template result, std::vector> + try_parse_str(std::string, spec, cxx::source_location); + + extern template basic_value parse( + std::vector, + std::string, + spec); + extern template basic_value parse(std::istream&, + std::string, + spec); + extern template basic_value parse(std::string, spec); + extern template basic_value parse(FILE*, + std::string, + spec); + extern template basic_value parse_str( + std::string, + spec, + 
cxx::source_location); + + extern template result, std::vector> + try_parse(std::vector, std::string, spec); + extern template result, std::vector> + try_parse(std::istream&, std::string, spec); + extern template result, std::vector> + try_parse(std::string, spec); + extern template result, std::vector> + try_parse(FILE*, std::string, spec); + extern template result, std::vector> + try_parse_str(std::string, spec, cxx::source_location); + + extern template basic_value parse( + std::vector, + std::string, + spec); + extern template basic_value parse( + std::istream&, + std::string, + spec); + extern template basic_value parse( + std::string, + spec); + extern template basic_value parse( + FILE*, + std::string, + spec); + extern template basic_value parse_str( + std::string, + spec, + cxx::source_location); + + #if defined(TOML11_HAS_FILESYSTEM) + extern template cxx::enable_if_t< + std::is_same::value, + result, std::vector>> + try_parse(const std::filesystem::path&, + spec); + extern template cxx::enable_if_t< + std::is_same::value, + result, std::vector>> + try_parse( + const std::filesystem::path&, + spec); + extern template cxx::enable_if_t< + std::is_same::value, + basic_value> + parse(const std::filesystem::path&, spec); + extern template cxx::enable_if_t< + std::is_same::value, + basic_value> + parse(const std::filesystem::path&, + spec); + #endif // filesystem + +} // namespace toml +#endif // TOML11_COMPILE_SOURCES + +#endif // TOML11_PARSER_HPP +#ifndef TOML11_LITERAL_HPP +#define TOML11_LITERAL_HPP + +#ifndef TOML11_LITERAL_FWD_HPP + #define TOML11_LITERAL_FWD_HPP + +namespace toml { + + namespace detail { + // implementation + ::toml::value literal_internal_impl(location loc); + } // namespace detail + + inline namespace literals { + inline namespace toml_literals { + + ::toml::value operator"" _toml(const char* str, std::size_t len); + + #if defined(TOML11_HAS_CHAR8_T) + // value of u8"" literal has been changed from char to char8_t and char8_t + // is 
NOT compatible to char + ::toml::value operator"" _toml(const char8_t* str, std::size_t len); + #endif + + } // namespace toml_literals + } // namespace literals +} // namespace toml +#endif // TOML11_LITERAL_FWD_HPP + +#if !defined(TOML11_COMPILE_SOURCES) + #ifndef TOML11_LITERAL_IMPL_HPP + #define TOML11_LITERAL_IMPL_HPP + +namespace toml { + + namespace detail { + // implementation + TOML11_INLINE ::toml::value literal_internal_impl(location loc) { + const auto s = ::toml::spec::default_version(); + context ctx(s); + + const auto front = loc; + + // ------------------------------------------------------------------------ + // check if it is a raw value. + + // skip empty lines and comment lines + auto sp = skip_multiline_spacer(loc, ctx); + if (loc.eof()) { + ::toml::value val; + if (sp.has_value()) { + for (std::size_t i = 0; i < sp.value().comments.size(); ++i) { + val.comments().push_back(std::move(sp.value().comments.at(i))); + } + } + return val; + } + + // to distinguish arrays and tables, first check it is a table or not. + // + // "[1,2,3]"_toml; // json: [1, 2, 3] + // "[table]"_toml; // json: {"table": {}} + // "[[1,2,3]]"_toml; // json: [[1, 2, 3]] + // "[[table]]"_toml; // json: {"table": [{}]} + // + // "[[1]]"_toml; // json: {"1": [{}]} + // "1 = [{}]"_toml; // json: {"1": [{}]} + // "[[1,]]"_toml; // json: [[1]] + // "[[1],]"_toml; // json: [[1]] + const auto val_start = loc; + + const bool is_table_key = syntax::std_table(s).scan(loc).is_ok(); + loc = val_start; + const bool is_aots_key = syntax::array_table(s).scan(loc).is_ok(); + loc = val_start; + + // If it is neither a table-key or a array-of-table-key, it may be a value. 
+ if (!is_table_key && !is_aots_key) { + auto data = parse_value(loc, ctx); + if (data.is_ok()) { + auto val = std::move(data.unwrap()); + if (sp.has_value()) { + for (std::size_t i = 0; i < sp.value().comments.size(); ++i) { + val.comments().push_back(std::move(sp.value().comments.at(i))); + } + } + auto com_res = parse_comment_line(loc, ctx); + if (com_res.is_ok() && com_res.unwrap().has_value()) { + val.comments().push_back(com_res.unwrap().value()); + } + return val; + } + } + + // ------------------------------------------------------------------------- + // Note that still it can be a table, because the literal might be + // something like the following. + // ```cpp + // // c++11 raw-string literal + // const auto val = R"( + // key = "value" + // int = 42 + // )"_toml; + // ``` + // It is a valid toml file. + // It should be parsed as if we parse a file with this content. + + loc = front; + auto data = parse_file(loc, ctx); + if (data.is_ok()) { + return data.unwrap(); + } else // not a value && not a file. error. 
+ { + std::string msg; + for (const auto& err : data.unwrap_err()) { + msg += format_error(err); + } + throw ::toml::syntax_error(std::move(msg), std::move(data.unwrap_err())); + } + } + + } // namespace detail + + inline namespace literals { + inline namespace toml_literals { + + TOML11_INLINE ::toml::value operator"" _toml(const char* str, + std::size_t len) { + if (len == 0) { + return ::toml::value {}; + } + + ::toml::detail::location::container_type c(len); + std::copy( + reinterpret_cast(str), + reinterpret_cast(str + len), + c.begin()); + if (!c.empty() && c.back()) { + c.push_back('\n'); // to make it easy to parse comment, we add newline + } + + return literal_internal_impl(::toml::detail::location( + std::make_shared( + std::move(c)), + "TOML literal encoded in a C++ code")); + } + + #if defined(__cpp_char8_t) + #if __cpp_char8_t >= 201811L + #define TOML11_HAS_CHAR8_T 1 + #endif + #endif + + #if defined(TOML11_HAS_CHAR8_T) + // value of u8"" literal has been changed from char to char8_t and char8_t + // is NOT compatible to char + TOML11_INLINE ::toml::value operator"" _toml(const char8_t* str, + std::size_t len) { + if (len == 0) { + return ::toml::value {}; + } + + ::toml::detail::location::container_type c(len); + std::copy( + reinterpret_cast(str), + reinterpret_cast(str + len), + c.begin()); + if (!c.empty() && c.back()) { + c.push_back('\n'); // to make it easy to parse comment, we add newline + } + + return literal_internal_impl(::toml::detail::location( + std::make_shared( + std::move(c)), + "TOML literal encoded in a C++ code")); + } + #endif + + } // namespace toml_literals + } // namespace literals +} // namespace toml + #endif // TOML11_LITERAL_IMPL_HPP +#endif + +#endif // TOML11_LITERAL_HPP +#ifndef TOML11_SERIALIZER_HPP +#define TOML11_SERIALIZER_HPP + +#include +#include +#include +#include +#include + +namespace toml { + + struct serialization_error final : public ::toml::exception { + public: + explicit serialization_error(std::string 
what_arg, source_location loc) + : what_(std::move(what_arg)) + , loc_(std::move(loc)) {} + + ~serialization_error() noexcept override = default; + + const char* what() const noexcept override { + return what_.c_str(); + } + + const source_location& location() const noexcept { + return loc_; + } + + private: + std::string what_; + source_location loc_; + }; + + namespace detail { + template + class serializer { + public: + using value_type = basic_value; + + using key_type = typename value_type::key_type; + using comment_type = typename value_type::comment_type; + using boolean_type = typename value_type::boolean_type; + using integer_type = typename value_type::integer_type; + using floating_type = typename value_type::floating_type; + using string_type = typename value_type::string_type; + using local_time_type = typename value_type::local_time_type; + using local_date_type = typename value_type::local_date_type; + using local_datetime_type = typename value_type::local_datetime_type; + using offset_datetime_type = typename value_type::offset_datetime_type; + using array_type = typename value_type::array_type; + using table_type = typename value_type::table_type; + + using char_type = typename string_type::value_type; + + public: + explicit serializer(const spec& sp) + : spec_(sp) + , force_inline_(false) + , current_indent_(0) {} + + string_type operator()(const std::vector& ks, const value_type& v) { + for (const auto& k : ks) { + this->keys_.push_back(k); + } + return (*this)(v); + } + + string_type operator()(const key_type& k, const value_type& v) { + this->keys_.push_back(k); + return (*this)(v); + } + + string_type operator()(const value_type& v) { + switch (v.type()) { + case value_t::boolean: { + return (*this)(v.as_boolean(), v.as_boolean_fmt(), v.location()); + } + case value_t::integer: { + return (*this)(v.as_integer(), v.as_integer_fmt(), v.location()); + } + case value_t::floating: { + return (*this)(v.as_floating(), v.as_floating_fmt(), 
v.location()); + } + case value_t::string: { + return (*this)(v.as_string(), v.as_string_fmt(), v.location()); + } + case value_t::offset_datetime: { + return (*this)(v.as_offset_datetime(), + v.as_offset_datetime_fmt(), + v.location()); + } + case value_t::local_datetime: { + return (*this)(v.as_local_datetime(), + v.as_local_datetime_fmt(), + v.location()); + } + case value_t::local_date: { + return (*this)(v.as_local_date(), v.as_local_date_fmt(), v.location()); + } + case value_t::local_time: { + return (*this)(v.as_local_time(), v.as_local_time_fmt(), v.location()); + } + case value_t::array: { + return ( + *this)(v.as_array(), v.as_array_fmt(), v.comments(), v.location()); + } + case value_t::table: { + string_type retval; + if (this->keys_.empty()) // it might be the root table. emit comments here. + { + retval += format_comments(v.comments(), v.as_table_fmt().indent_type); + } + if (!retval.empty()) // we have comment. + { + retval += char_type('\n'); + } + + retval += (*this)(v.as_table(), + v.as_table_fmt(), + v.comments(), + v.location()); + return retval; + } + case value_t::empty: { + if (this->spec_.ext_null_value) { + return string_conv("null"); + } + break; + } + default: { + break; + } + } + throw serialization_error( + format_error("[error] toml::serializer: toml::basic_value " + "does not have any valid type.", + v.location(), + "here"), + v.location()); + } + + private: + string_type operator()(const boolean_type& b, + const boolean_format_info&, + const source_location&) // {{{ + { + if (b) { + return string_conv("true"); + } else { + return string_conv("false"); + } + } // }}} + + string_type operator()(const integer_type i, + const integer_format_info& fmt, + const source_location& loc) // {{{ + { + std::ostringstream oss; + this->set_locale(oss); + + const auto insert_spacer = [&fmt](std::string s) -> std::string { + if (fmt.spacer == 0) { + return s; + } + + std::string sign; + if (!s.empty() && (s.at(0) == '+' || s.at(0) == '-')) { + sign 
+= s.at(0); + s.erase(s.begin()); + } + + std::string spaced; + std::size_t counter = 0; + for (auto iter = s.rbegin(); iter != s.rend(); ++iter) { + if (counter != 0 && counter % fmt.spacer == 0) { + spaced += '_'; + } + spaced += *iter; + counter += 1; + } + if (!spaced.empty() && spaced.back() == '_') { + spaced.pop_back(); + } + + s.clear(); + std::copy(spaced.rbegin(), spaced.rend(), std::back_inserter(s)); + return sign + s; + }; + + std::string retval; + if (fmt.fmt == integer_format::dec) { + oss << std::setw(static_cast(fmt.width)) << std::dec << i; + retval = insert_spacer(oss.str()); + + if (this->spec_.ext_num_suffix && !fmt.suffix.empty()) { + retval += '_'; + retval += fmt.suffix; + } + } else { + if (i < 0) { + throw serialization_error( + format_error("binary, octal, hexadecimal " + "integer does not allow negative value", + loc, + "here"), + loc); + } + switch (fmt.fmt) { + case integer_format::hex: { + oss << std::noshowbase << std::setw(static_cast(fmt.width)) + << std::setfill('0') << std::hex; + if (fmt.uppercase) { + oss << std::uppercase; + } else { + oss << std::nouppercase; + } + oss << i; + retval = std::string("0x") + insert_spacer(oss.str()); + break; + } + case integer_format::oct: { + oss << std::setw(static_cast(fmt.width)) << std::setfill('0') + << std::oct << i; + retval = std::string("0o") + insert_spacer(oss.str()); + break; + } + case integer_format::bin: { + integer_type x { i }; + std::string tmp; + std::size_t bits(0); + while (x != 0) { + if (fmt.spacer != 0) { + if (bits != 0 && (bits % fmt.spacer) == 0) { + tmp += '_'; + } + } + if (x % 2 == 1) { + tmp += '1'; + } else { + tmp += '0'; + } + x >>= 1; + bits += 1; + } + for (; bits < fmt.width; ++bits) { + if (fmt.spacer != 0) { + if (bits != 0 && (bits % fmt.spacer) == 0) { + tmp += '_'; + } + } + tmp += '0'; + } + for (auto iter = tmp.rbegin(); iter != tmp.rend(); ++iter) { + oss << *iter; + } + retval = std::string("0b") + oss.str(); + break; + } + default: { + throw 
serialization_error( + format_error("none of dec, hex, oct, bin: " + to_string(fmt.fmt), + loc, + "here"), + loc); + } + } + } + return string_conv(retval); + } // }}} + + string_type operator()(const floating_type f, + const floating_format_info& fmt, + const source_location&) // {{{ + { + using std::isinf; + using std::isnan; + using std::signbit; + + std::ostringstream oss; + this->set_locale(oss); + + if (isnan(f)) { + if (signbit(f)) { + oss << '-'; + } + oss << "nan"; + if (this->spec_.ext_num_suffix && !fmt.suffix.empty()) { + oss << '_'; + oss << fmt.suffix; + } + return string_conv(oss.str()); + } + + if (isinf(f)) { + if (signbit(f)) { + oss << '-'; + } + oss << "inf"; + if (this->spec_.ext_num_suffix && !fmt.suffix.empty()) { + oss << '_'; + oss << fmt.suffix; + } + return string_conv(oss.str()); + } + + switch (fmt.fmt) { + case floating_format::defaultfloat: { + if (fmt.prec != 0) { + oss << std::setprecision(static_cast(fmt.prec)); + } + oss << f; + // since defaultfloat may omit point, we need to add it + std::string s = oss.str(); + if (s.find('.') == std::string::npos && + s.find('e') == std::string::npos && + s.find('E') == std::string::npos) { + s += ".0"; + } + if (this->spec_.ext_num_suffix && !fmt.suffix.empty()) { + s += '_'; + s += fmt.suffix; + } + return string_conv(s); + } + case floating_format::fixed: { + if (fmt.prec != 0) { + oss << std::setprecision(static_cast(fmt.prec)); + } + oss << std::fixed << f; + if (this->spec_.ext_num_suffix && !fmt.suffix.empty()) { + oss << '_' << fmt.suffix; + } + return string_conv(oss.str()); + } + case floating_format::scientific: { + if (fmt.prec != 0) { + oss << std::setprecision(static_cast(fmt.prec)); + } + oss << std::scientific << f; + if (this->spec_.ext_num_suffix && !fmt.suffix.empty()) { + oss << '_' << fmt.suffix; + } + return string_conv(oss.str()); + } + case floating_format::hex: { + if (this->spec_.ext_hex_float) { + oss << std::hexfloat << f; + // suffix is only for decimal numbers. 
+ return string_conv(oss.str()); + } else // no hex allowed. output with max precision. + { + oss << std::setprecision( + std::numeric_limits::max_digits10) + << std::scientific << f; + // suffix is only for decimal numbers. + return string_conv(oss.str()); + } + } + default: { + if (this->spec_.ext_num_suffix && !fmt.suffix.empty()) { + oss << '_' << fmt.suffix; + } + return string_conv(oss.str()); + } + } + } // }}} + + string_type operator()(string_type s, + const string_format_info& fmt, + const source_location& loc) // {{{ + { + string_type retval; + switch (fmt.fmt) { + case string_format::basic: { + retval += char_type('"'); + retval += this->escape_basic_string(s); + retval += char_type('"'); + return retval; + } + case string_format::literal: { + if (std::find(s.begin(), s.end(), char_type('\n')) != s.end()) { + throw serialization_error( + format_error( + "toml::serializer: " + "(non-multiline) literal string cannot have a newline", + loc, + "here"), + loc); + } + retval += char_type('\''); + retval += s; + retval += char_type('\''); + return retval; + } + case string_format::multiline_basic: { + retval += string_conv("\"\"\""); + if (fmt.start_with_newline) { + retval += char_type('\n'); + } + + retval += this->escape_ml_basic_string(s); + + retval += string_conv("\"\"\""); + return retval; + } + case string_format::multiline_literal: { + retval += string_conv("'''"); + if (fmt.start_with_newline) { + retval += char_type('\n'); + } + retval += s; + retval += string_conv("'''"); + return retval; + } + default: { + throw serialization_error( + format_error("[error] toml::serializer::operator()(string): " + "invalid string_format value", + loc, + "here"), + loc); + } + } + } // }}} + + string_type operator()(const local_date_type& d, + const local_date_format_info&, + const source_location&) // {{{ + { + std::ostringstream oss; + oss << d; + return string_conv(oss.str()); + } // }}} + + string_type operator()(const local_time_type& t, + const 
local_time_format_info& fmt, + const source_location&) // {{{ + { + return this->format_local_time(t, fmt.has_seconds, fmt.subsecond_precision); + } // }}} + + string_type operator()(const local_datetime_type& dt, + const local_datetime_format_info& fmt, + const source_location&) // {{{ + { + std::ostringstream oss; + oss << dt.date; + switch (fmt.delimiter) { + case datetime_delimiter_kind::upper_T: { + oss << 'T'; + break; + } + case datetime_delimiter_kind::lower_t: { + oss << 't'; + break; + } + case datetime_delimiter_kind::space: { + oss << ' '; + break; + } + default: { + oss << 'T'; + break; + } + } + return string_conv(oss.str()) + + this->format_local_time(dt.time, + fmt.has_seconds, + fmt.subsecond_precision); + } // }}} + + string_type operator()(const offset_datetime_type& odt, + const offset_datetime_format_info& fmt, + const source_location&) // {{{ + { + std::ostringstream oss; + oss << odt.date; + switch (fmt.delimiter) { + case datetime_delimiter_kind::upper_T: { + oss << 'T'; + break; + } + case datetime_delimiter_kind::lower_t: { + oss << 't'; + break; + } + case datetime_delimiter_kind::space: { + oss << ' '; + break; + } + default: { + oss << 'T'; + break; + } + } + oss << string_conv( + this->format_local_time(odt.time, fmt.has_seconds, fmt.subsecond_precision)); + oss << odt.offset; + return string_conv(oss.str()); + } // }}} + + string_type operator()(const array_type& a, + const array_format_info& fmt, + const comment_type& com, + const source_location& loc) // {{{ + { + array_format f = fmt.fmt; + if (fmt.fmt == array_format::default_format) { + // [[in.this.form]], you cannot add a comment to the array itself + // (but you can add a comment to each table). + // To keep comments, we need to avoid multiline array-of-tables + // if array itself has a comment. 
+ if (!this->keys_.empty() && !a.empty() && com.empty() && + std::all_of(a.begin(), a.end(), [](const value_type& e) { + return e.is_table(); + })) { + f = array_format::array_of_tables; + } else { + f = array_format::oneline; + + // check if it becomes long + std::size_t approx_len = 0; + for (const auto& e : a) { + // have a comment. cannot be inlined + if (!e.comments().empty()) { + f = array_format::multiline; + break; + } + // possibly long types ... + if (e.is_array() || e.is_table() || e.is_offset_datetime() || + e.is_local_datetime()) { + f = array_format::multiline; + break; + } else if (e.is_boolean()) { + approx_len += + (*this)(e.as_boolean(), e.as_boolean_fmt(), e.location()).size(); + } else if (e.is_integer()) { + approx_len += + (*this)(e.as_integer(), e.as_integer_fmt(), e.location()).size(); + } else if (e.is_floating()) { + approx_len += + (*this)(e.as_floating(), e.as_floating_fmt(), e.location()).size(); + } else if (e.is_string()) { + if (e.as_string_fmt().fmt == string_format::multiline_basic || + e.as_string_fmt().fmt == string_format::multiline_literal) { + f = array_format::multiline; + break; + } + approx_len += + 2 + + (*this)(e.as_string(), e.as_string_fmt(), e.location()).size(); + } else if (e.is_local_date()) { + approx_len += 10; // 1234-56-78 + } else if (e.is_local_time()) { + approx_len += 15; // 12:34:56.789012 + } + + if (approx_len > 60) // key, ` = `, `[...]` < 80 + { + f = array_format::multiline; + break; + } + approx_len += 2; // `, ` + } + } + } + if (this->force_inline_ && f == array_format::array_of_tables) { + f = array_format::multiline; + } + if (a.empty() && f == array_format::array_of_tables) { + f = array_format::oneline; + } + + // -------------------------------------------------------------------- + + if (f == array_format::array_of_tables) { + if (this->keys_.empty()) { + throw serialization_error("array of table must have its key. 
" + "use format(key, v)", + loc); + } + string_type retval; + for (const auto& e : a) { + assert(e.is_table()); + + this->current_indent_ += e.as_table_fmt().name_indent; + retval += this->format_comments(e.comments(), + e.as_table_fmt().indent_type); + retval += this->format_indent(e.as_table_fmt().indent_type); + this->current_indent_ -= e.as_table_fmt().name_indent; + + retval += string_conv("[["); + retval += this->format_keys(this->keys_).value(); + retval += string_conv("]]\n"); + + retval += this->format_ml_table(e.as_table(), e.as_table_fmt()); + } + return retval; + } else if (f == array_format::oneline) { + // ignore comments. we cannot emit comments + string_type retval; + retval += char_type('['); + for (const auto& e : a) { + this->force_inline_ = true; + retval += (*this)(e); + retval += string_conv(", "); + } + if (!a.empty()) { + retval.pop_back(); // ` ` + retval.pop_back(); // `,` + } + retval += char_type(']'); + this->force_inline_ = false; + return retval; + } else { + assert(f == array_format::multiline); + + string_type retval; + retval += string_conv("[\n"); + + for (const auto& e : a) { + this->current_indent_ += fmt.body_indent; + retval += this->format_comments(e.comments(), fmt.indent_type); + retval += this->format_indent(fmt.indent_type); + this->current_indent_ -= fmt.body_indent; + + this->force_inline_ = true; + retval += (*this)(e); + retval += string_conv(",\n"); + } + this->force_inline_ = false; + + this->current_indent_ += fmt.closing_indent; + retval += this->format_indent(fmt.indent_type); + this->current_indent_ -= fmt.closing_indent; + + retval += char_type(']'); + return retval; + } + } // }}} + + string_type operator()(const table_type& t, + const table_format_info& fmt, + const comment_type& com, + const source_location& loc) // {{{ + { + if (this->force_inline_) { + if (fmt.fmt == table_format::multiline_oneline) { + return this->format_ml_inline_table(t, fmt); + } else { + return this->format_inline_table(t, fmt); + } 
+ } else { + if (fmt.fmt == table_format::multiline) { + string_type retval; + // comment is emitted inside format_ml_table + if (auto k = this->format_keys(this->keys_)) { + this->current_indent_ += fmt.name_indent; + retval += this->format_comments(com, fmt.indent_type); + retval += this->format_indent(fmt.indent_type); + this->current_indent_ -= fmt.name_indent; + retval += char_type('['); + retval += k.value(); + retval += string_conv("]\n"); + } + // otherwise, its the root. + + retval += this->format_ml_table(t, fmt); + return retval; + } else if (fmt.fmt == table_format::oneline) { + return this->format_inline_table(t, fmt); + } else if (fmt.fmt == table_format::multiline_oneline) { + return this->format_ml_inline_table(t, fmt); + } else if (fmt.fmt == table_format::dotted) { + std::vector keys; + if (this->keys_.empty()) { + throw serialization_error( + format_error( + "toml::serializer: " + "dotted table must have its key. use format(key, v)", + loc, + "here"), + loc); + } + keys.push_back(this->keys_.back()); + + const auto retval = this->format_dotted_table(t, fmt, loc, keys); + keys.pop_back(); + return retval; + } else { + assert(fmt.fmt == table_format::implicit); + + string_type retval; + for (const auto& kv : t) { + const auto& k = kv.first; + const auto& v = kv.second; + + if (!v.is_table() && !v.is_array_of_tables()) { + throw serialization_error( + format_error("toml::serializer: " + "an implicit table cannot have non-table value.", + v.location(), + "here"), + v.location()); + } + if (v.is_table()) { + if (v.as_table_fmt().fmt != table_format::multiline && + v.as_table_fmt().fmt != table_format::implicit) { + throw serialization_error( + format_error( + "toml::serializer: " + "an implicit table cannot have non-multiline table", + v.location(), + "here"), + v.location()); + } + } else { + assert(v.is_array()); + for (const auto& e : v.as_array()) { + if (e.as_table_fmt().fmt != table_format::multiline && + v.as_table_fmt().fmt != 
table_format::implicit) { + throw serialization_error( + format_error( + "toml::serializer: " + "an implicit table cannot have non-multiline table", + e.location(), + "here"), + e.location()); + } + } + } + + keys_.push_back(k); + retval += (*this)(v); + keys_.pop_back(); + } + return retval; + } + } + } // }}} + + private: + string_type escape_basic_string(const string_type& s) const // {{{ + { + string_type retval; + for (const char_type c : s) { + switch (c) { + case char_type('\\'): { + retval += string_conv("\\\\"); + break; + } + case char_type('\"'): { + retval += string_conv("\\\""); + break; + } + case char_type('\b'): { + retval += string_conv("\\b"); + break; + } + case char_type('\t'): { + retval += string_conv("\\t"); + break; + } + case char_type('\f'): { + retval += string_conv("\\f"); + break; + } + case char_type('\n'): { + retval += string_conv("\\n"); + break; + } + case char_type('\r'): { + retval += string_conv("\\r"); + break; + } + default: { + if (c == char_type(0x1B) && spec_.v1_1_0_add_escape_sequence_e) { + retval += string_conv("\\e"); + } else if ((char_type(0x00) <= c && c <= char_type(0x08)) || + (char_type(0x0A) <= c && c <= char_type(0x1F)) || + c == char_type(0x7F)) { + if (spec_.v1_1_0_add_escape_sequence_x) { + retval += string_conv("\\x"); + } else { + retval += string_conv("\\u00"); + } + const auto c1 = c / 16; + const auto c2 = c % 16; + retval += static_cast('0' + c1); + if (c2 < 10) { + retval += static_cast('0' + c2); + } else // 10 <= c2 + { + retval += static_cast('A' + (c2 - 10)); + } + } else { + retval += c; + } + } + } + } + return retval; + } // }}} + + string_type escape_ml_basic_string(const string_type& s) // {{{ + { + string_type retval; + for (const char_type c : s) { + switch (c) { + case char_type('\\'): { + retval += string_conv("\\\\"); + break; + } + case char_type('\b'): { + retval += string_conv("\\b"); + break; + } + case char_type('\t'): { + retval += string_conv("\\t"); + break; + } + case 
char_type('\f'): { + retval += string_conv("\\f"); + break; + } + case char_type('\n'): { + retval += string_conv("\n"); + break; + } + case char_type('\r'): { + retval += string_conv("\\r"); + break; + } + default: { + if (c == char_type(0x1B) && spec_.v1_1_0_add_escape_sequence_e) { + retval += string_conv("\\e"); + } else if ((char_type(0x00) <= c && c <= char_type(0x08)) || + (char_type(0x0A) <= c && c <= char_type(0x1F)) || + c == char_type(0x7F)) { + if (spec_.v1_1_0_add_escape_sequence_x) { + retval += string_conv("\\x"); + } else { + retval += string_conv("\\u00"); + } + const auto c1 = c / 16; + const auto c2 = c % 16; + retval += static_cast('0' + c1); + if (c2 < 10) { + retval += static_cast('0' + c2); + } else // 10 <= c2 + { + retval += static_cast('A' + (c2 - 10)); + } + } else { + retval += c; + } + } + } + } + // Only 1 or 2 consecutive `"`s are allowed in multiline basic string. + // 3 consecutive `"`s are considered as a closing delimiter. + // We need to check if there are 3 or more consecutive `"`s and insert + // backslash to break them down into several short `"`s like the `str6` + // in the following example. + // ```toml + // str4 = """Here are two quotation marks: "". 
Simple enough.""" + // # str5 = """Here are three quotation marks: """.""" # INVALID + // str5 = """Here are three quotation marks: ""\".""" + // str6 = """Here are fifteen quotation marks: ""\"""\"""\"""\"""\".""" + // ``` + auto found_3_quotes = retval.find(string_conv("\"\"\"")); + while (found_3_quotes != string_type::npos) { + retval.replace(found_3_quotes, 3, string_conv("\"\"\\\"")); + found_3_quotes = retval.find(string_conv("\"\"\"")); + } + return retval; + } // }}} + + string_type format_local_time(const local_time_type& t, + const bool has_seconds, + const std::size_t subsec_prec) // {{{ + { + std::ostringstream oss; + oss << std::setfill('0') << std::setw(2) << static_cast(t.hour); + oss << ':'; + oss << std::setfill('0') << std::setw(2) << static_cast(t.minute); + if (has_seconds) { + oss << ':'; + oss << std::setfill('0') << std::setw(2) << static_cast(t.second); + if (subsec_prec != 0) { + std::ostringstream subsec; + subsec << std::setfill('0') << std::setw(3) + << static_cast(t.millisecond); + subsec << std::setfill('0') << std::setw(3) + << static_cast(t.microsecond); + subsec << std::setfill('0') << std::setw(3) + << static_cast(t.nanosecond); + std::string subsec_str = subsec.str(); + oss << '.' 
<< subsec_str.substr(0, subsec_prec); + } + } + return string_conv(oss.str()); + } // }}} + + string_type format_ml_table(const table_type& t, + const table_format_info& fmt) // {{{ + { + const auto format_later = [](const value_type& v) -> bool { + const bool is_ml_table = v.is_table() && + v.as_table_fmt().fmt != table_format::oneline && + v.as_table_fmt().fmt != + table_format::multiline_oneline && + v.as_table_fmt().fmt != table_format::dotted; + + const bool is_ml_array_table = v.is_array_of_tables() && + v.as_array_fmt().fmt != + array_format::oneline && + v.as_array_fmt().fmt != + array_format::multiline; + + return is_ml_table || is_ml_array_table; + }; + + string_type retval; + this->current_indent_ += fmt.body_indent; + for (const auto& kv : t) { + const auto& key = kv.first; + const auto& val = kv.second; + if (format_later(val)) { + continue; + } + this->keys_.push_back(key); + + retval += format_comments(val.comments(), fmt.indent_type); + retval += format_indent(fmt.indent_type); + if (val.is_table() && val.as_table_fmt().fmt == table_format::dotted) { + retval += (*this)(val); + } else { + retval += format_key(key); + retval += string_conv(" = "); + retval += (*this)(val); + retval += char_type('\n'); + } + this->keys_.pop_back(); + } + this->current_indent_ -= fmt.body_indent; + + if (!retval.empty()) { + retval += char_type('\n'); // for readability, add empty line between tables + } + for (const auto& kv : t) { + if (!format_later(kv.second)) { + continue; + } + // must be a [multiline.table] or [[multiline.array.of.tables]]. + // comments will be generated inside it. 
+ this->keys_.push_back(kv.first); + retval += (*this)(kv.second); + this->keys_.pop_back(); + } + return retval; + } // }}} + + string_type format_inline_table(const table_type& t, + const table_format_info&) // {{{ + { + // comments are ignored because we cannot write without newline + string_type retval; + retval += char_type('{'); + for (const auto& kv : t) { + this->force_inline_ = true; + retval += this->format_key(kv.first); + retval += string_conv(" = "); + retval += (*this)(kv.second); + retval += string_conv(", "); + } + if (!t.empty()) { + retval.pop_back(); // ' ' + retval.pop_back(); // ',' + } + retval += char_type('}'); + this->force_inline_ = false; + return retval; + } // }}} + + string_type format_ml_inline_table(const table_type& t, + const table_format_info& fmt) // {{{ + { + string_type retval; + retval += string_conv("{\n"); + this->current_indent_ += fmt.body_indent; + for (const auto& kv : t) { + this->force_inline_ = true; + retval += format_comments(kv.second.comments(), fmt.indent_type); + retval += format_indent(fmt.indent_type); + retval += kv.first; + retval += string_conv(" = "); + + this->force_inline_ = true; + retval += (*this)(kv.second); + + retval += string_conv(",\n"); + } + if (!t.empty()) { + retval.pop_back(); // '\n' + retval.pop_back(); // ',' + } + this->current_indent_ -= fmt.body_indent; + this->force_inline_ = false; + + this->current_indent_ += fmt.closing_indent; + retval += format_indent(fmt.indent_type); + this->current_indent_ -= fmt.closing_indent; + + retval += char_type('}'); + return retval; + } // }}} + + string_type format_dotted_table(const table_type& t, + const table_format_info& fmt, // {{{ + const source_location&, + std::vector& keys) { + // lets say we have: `{"a": {"b": {"c": {"d": "foo", "e": "bar"} } }` + // and `a` and `b` are `dotted`. 
+ // + // - in case if `c` is `oneline`: + // ```toml + // a.b.c = {d = "foo", e = "bar"} + // ``` + // + // - in case if and `c` is `dotted`: + // ```toml + // a.b.c.d = "foo" + // a.b.c.e = "bar" + // ``` + + string_type retval; + + for (const auto& kv : t) { + const auto& key = kv.first; + const auto& val = kv.second; + + keys.push_back(key); + + // format recursive dotted table? + if (val.is_table() && val.as_table_fmt().fmt != table_format::oneline && + val.as_table_fmt().fmt != table_format::multiline_oneline) { + retval += this->format_dotted_table(val.as_table(), + val.as_table_fmt(), + val.location(), + keys); + } else // non-table or inline tables. format normally + { + retval += format_comments(val.comments(), fmt.indent_type); + retval += format_indent(fmt.indent_type); + retval += format_keys(keys).value(); + retval += string_conv(" = "); + this->force_inline_ = true; // sub-table must be inlined + retval += (*this)(val); + retval += char_type('\n'); + this->force_inline_ = false; + } + keys.pop_back(); + } + return retval; + } // }}} + + string_type format_key(const key_type& key) // {{{ + { + if (key.empty()) { + return string_conv("\"\""); + } + + // check the key can be a bare (unquoted) key + auto loc = detail::make_temporary_location(string_conv(key)); + auto reg = detail::syntax::unquoted_key(this->spec_).scan(loc); + if (reg.is_ok() && loc.eof()) { + return key; + } + + // if it includes special characters, then format it in a "quoted" key. 
+ string_type formatted = string_conv("\""); + for (const char_type c : key) { + switch (c) { + case char_type('\\'): { + formatted += string_conv("\\\\"); + break; + } + case char_type('\"'): { + formatted += string_conv("\\\""); + break; + } + case char_type('\b'): { + formatted += string_conv("\\b"); + break; + } + case char_type('\t'): { + formatted += string_conv("\\t"); + break; + } + case char_type('\f'): { + formatted += string_conv("\\f"); + break; + } + case char_type('\n'): { + formatted += string_conv("\\n"); + break; + } + case char_type('\r'): { + formatted += string_conv("\\r"); + break; + } + default: { + // ASCII ctrl char + if ((char_type(0x00) <= c && c <= char_type(0x08)) || + (char_type(0x0A) <= c && c <= char_type(0x1F)) || + c == char_type(0x7F)) { + if (spec_.v1_1_0_add_escape_sequence_x) { + formatted += string_conv("\\x"); + } else { + formatted += string_conv("\\u00"); + } + const auto c1 = c / 16; + const auto c2 = c % 16; + formatted += static_cast('0' + c1); + if (c2 < 10) { + formatted += static_cast('0' + c2); + } else // 10 <= c2 + { + formatted += static_cast('A' + (c2 - 10)); + } + } else { + formatted += c; + } + break; + } + } + } + formatted += string_conv("\""); + return formatted; + } // }}} + + cxx::optional format_keys(const std::vector& keys) // {{{ + { + if (keys.empty()) { + return cxx::make_nullopt(); + } + + string_type formatted; + for (const auto& ky : keys) { + formatted += format_key(ky); + formatted += char_type('.'); + } + formatted.pop_back(); // remove the last dot '.' 
+ return formatted; + } // }}} + + string_type format_comments(const discard_comments&, + const indent_char) const // {{{ + { + return string_conv(""); + } // }}} + + string_type format_comments(const preserve_comments& comments, + const indent_char indent_type) const // {{{ + { + string_type retval; + for (const auto& c : comments) { + if (c.empty()) { + continue; + } + retval += format_indent(indent_type); + if (c.front() != '#') { + retval += char_type('#'); + } + retval += string_conv(c); + if (c.back() != '\n') { + retval += char_type('\n'); + } + } + return retval; + } // }}} + + string_type format_indent(const indent_char indent_type) const // {{{ + { + const auto indent = static_cast( + (std::max)(0, this->current_indent_)); + if (indent_type == indent_char::space) { + return string_conv(make_string(indent, ' ')); + } else if (indent_type == indent_char::tab) { + return string_conv(make_string(indent, '\t')); + } else { + return string_type {}; + } + } // }}} + + std::locale set_locale(std::ostream& os) const { + return os.imbue(std::locale::classic()); + } + + private: + spec spec_; + bool force_inline_; // table inside an array without fmt specification + std::int32_t current_indent_; + std::vector keys_; + }; + } // namespace detail + + template + typename basic_value::string_type format( + const basic_value& v, + const spec s = spec::default_version()) { + detail::serializer ser(s); + return ser(v); + } + + template + typename basic_value::string_type format( + const typename basic_value::key_type& k, + const basic_value& v, + const spec s = spec::default_version()) { + detail::serializer ser(s); + return ser(k, v); + } + + template + typename basic_value::string_type format( + const std::vector::key_type>& ks, + const basic_value& v, + const spec s = spec::default_version()) { + detail::serializer ser(s); + return ser(ks, v); + } + + template + std::ostream& operator<<(std::ostream& os, const basic_value& v) { + os << format(v); + return os; + } + +} 
// namespace toml + +#if defined(TOML11_COMPILE_SOURCES) +namespace toml { + struct type_config; + struct ordered_type_config; + + extern template typename basic_value::string_type format( + const basic_value&, + const spec); + + extern template typename basic_value::string_type format( + const typename basic_value::key_type& k, + const basic_value& v, + const spec); + + extern template typename basic_value::string_type format( + const std::vector::key_type>& ks, + const basic_value& v, + const spec s); + + extern template typename basic_value::string_type + format(const basic_value&, + const spec); + + extern template typename basic_value::string_type format( + const typename basic_value::key_type& k, + const basic_value& v, + const spec); + + extern template typename basic_value::string_type format( + const std::vector::key_type>& ks, + const basic_value& v, + const spec s); + + namespace detail { + extern template class serializer<::toml::type_config>; + extern template class serializer<::toml::ordered_type_config>; + } // namespace detail +} // namespace toml +#endif // TOML11_COMPILE_SOURCES + +#endif // TOML11_SERIALIZER_HPP +#ifndef TOML11_TOML_HPP +#define TOML11_TOML_HPP + +// The MIT License (MIT) +// +// Copyright (c) 2017-now Toru Niina +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// IWYU pragma: begin_exports +// IWYU pragma: end_exports + +#endif // TOML11_TOML_HPP diff --git a/src/global/utils/tools.h b/src/global/utils/tools.h index 451c24264..8a568ae20 100644 --- a/src/global/utils/tools.h +++ b/src/global/utils/tools.h @@ -2,23 +2,27 @@ * @file utils/tools.h * @brief Helper functions for general use * @implements + * - tools::ArrayImbalance -> unsigned short * - tools::TensorProduct<> -> boundaries_t - * - tools::decompose1D -> std::vector + * - tools::decompose1D -> std::vector * - tools::divideInProportions2D -> std::tuple * - tools::divideInProportions3D -> std::tuple - * - tools::Decompose -> std::vector> + * - tools::Decompose -> std::vector> + * - tools::Tracker * @namespaces: * - tools:: */ -#ifndef UTILS_HELPERS_H -#define UTILS_HELPERS_H +#ifndef UTILS_TOOLS_H +#define UTILS_TOOLS_H #include "global.h" +#include "utils/comparators.h" #include "utils/error.h" #include "utils/numeric.h" +#include #include #include #include @@ -26,14 +30,38 @@ namespace tools { + /** + * @brief Compute the imbalance of a list of nonnegative values + * @param values List of values + * @return Imbalance of the list (0...100) + */ + template + auto ArrayImbalance(const std::vector& values) -> unsigned short { + raise::ErrorIf(values.empty(), "Disbalance error: value array is empty", HERE); + const auto mean = static_cast(std::accumulate(values.begin(), + values.end(), + static_cast(0))) / + static_cast(values.size()); + const auto sq_sum = 
static_cast(std::inner_product(values.begin(), + values.end(), + values.begin(), + static_cast(0))); + if (cmp::AlmostZero_host(sq_sum) || cmp::AlmostZero_host(mean)) { + return 0; + } + const auto cv = std::sqrt( + sq_sum / static_cast(values.size()) / mean - 1.0); + return static_cast(100.0 / (1.0 + math::exp(-cv))); + } + /** * @brief Compute a tensor product of a list of vectors * @param list List of vectors * @return Tensor product of list */ template - inline auto TensorProduct(const std::vector>& list) - -> std::vector> { + inline auto TensorProduct( + const std::vector>& list) -> std::vector> { std::vector> result = { {} }; for (const auto& sublist : list) { std::vector> temp; @@ -53,16 +81,16 @@ namespace tools { * @param ndomains Number of domains * @param ncells Number of cells */ - inline auto decompose1D(unsigned int ndomains, std::size_t ncells) - -> std::vector { - auto size = (std::size_t)((double)ncells / (double)ndomains); - auto ncells_domain = std::vector(ndomains, size); - for (std::size_t i { 0 }; i < ncells - size * ndomains; ++i) { + inline auto decompose1D(unsigned int ndomains, + ncells_t ncells) -> std::vector { + auto size = (ncells_t)((double)ncells / (double)ndomains); + auto ncells_domain = std::vector(ndomains, size); + for (auto i { 0u }; i < ncells - size * ndomains; ++i) { ncells_domain[i] += 1; } auto sum = std::accumulate(ncells_domain.begin(), ncells_domain.end(), - (std::size_t)0); + (ncells_t)0); raise::ErrorIf(sum != ncells, "Decomposition error: sum != ncells", HERE); raise::ErrorIf(ncells_domain.size() != (std::size_t)ndomains, "Decomposition error: size != ndomains", @@ -79,8 +107,10 @@ namespace tools { * @param s1 Proportion of the first dimension * @param s2 Proportion of the second dimension */ - inline auto divideInProportions2D(unsigned int ntot, unsigned int s1, unsigned int s2) - -> std::tuple { + inline auto divideInProportions2D( + unsigned int ntot, + unsigned int s1, + unsigned int s2) -> std::tuple { auto n1 
= (unsigned int)(std::sqrt((double)ntot * (double)s1 / (double)s2)); if (n1 == 0) { return { 1, ntot }; @@ -102,11 +132,11 @@ namespace tools { * @param s2 Proportion of the second dimension * @param s3 Proportion of the third dimension */ - inline auto divideInProportions3D(unsigned int ntot, - unsigned int s1, - unsigned int s2, - unsigned int s3) - -> std::tuple { + inline auto divideInProportions3D( + unsigned int ntot, + unsigned int s1, + unsigned int s2, + unsigned int s3) -> std::tuple { auto n1 = (unsigned int)(std::cbrt( (double)ntot * (double)(SQR(s1)) / (double)(s2 * s3))); if (n1 > ntot) { @@ -135,10 +165,10 @@ namespace tools { * * @note If decomposition has -1, it will be calculated automatically */ - inline auto Decompose(unsigned int ndomains, - const std::vector& ncells, - const std::vector& decomposition) - -> std::vector> { + inline auto Decompose( + unsigned int ndomains, + const std::vector& ncells, + const std::vector& decomposition) -> std::vector> { const auto dimension = ncells.size(); raise::ErrorIf(dimension != decomposition.size(), "Decomposition error: dimension != decomposition.size", @@ -227,7 +257,7 @@ namespace tools { raise::ErrorIf(ndomains % n1 != 0, "Decomposition error: does not divide evenly", HERE); - std::tie(n2, + std::tie(n2, n3) = divideInProportions2D(ndomains / n1, ncells[1], ncells[2]); } else if (decomposition[0] < 0 && decomposition[1] < 0 && decomposition[2] < 0) { @@ -245,6 +275,81 @@ namespace tools { } } + /** + * Class for tracking the passage of time either in steps, physical time units, or walltime + * + * @note Primarily used for writing checkpoints and all types of outputs at specified intervals + */ + class Tracker { + bool m_initialized { false }; + + std::string m_type; + timestep_t m_interval; + simtime_t m_interval_time; + bool m_use_time; + + timestamp_t m_start_walltime; + timestamp_t m_end_walltime; + bool m_walltime_pending { false }; + + simtime_t m_last_output_time { -1.0 }; + + public: + 
Tracker() = default; + + Tracker(const std::string& type, + timestep_t interval, + simtime_t interval_time, + const std::string& end_walltime = "", + const timestamp_t& start_walltime = std::chrono::system_clock::now()) { + init(type, interval, interval_time, end_walltime, start_walltime); + } + + ~Tracker() = default; + + void init(const std::string& type, + timestep_t interval, + simtime_t interval_time, + const std::string& end_walltime = "", + const timestamp_t& start_walltime = std::chrono::system_clock::now()) { + m_initialized = true; + m_type = type; + m_interval = interval; + m_interval_time = interval_time; + m_use_time = interval_time > 0.0; + m_start_walltime = start_walltime; + if (not(end_walltime.empty() or end_walltime == "00:00:00")) { + m_walltime_pending = true; + raise::ErrorIf(end_walltime.size() != 8, + "invalid end walltime format, expected HH:MM:SS", + HERE); + m_end_walltime = m_start_walltime + + std::chrono::hours(std::stoi(end_walltime.substr(0, 2))) + + std::chrono::minutes(std::stoi(end_walltime.substr(3, 2))) + + std::chrono::seconds(std::stoi(end_walltime.substr(6, 2))); + } + } + + auto shouldWrite(timestep_t step, simtime_t time) -> bool { + raise::ErrorIf(!m_initialized, "Tracker not initialized", HERE); + if (m_walltime_pending and + (std::chrono::system_clock::now() > m_end_walltime)) { + m_walltime_pending = false; + return true; + } else if (m_use_time) { + if ((m_last_output_time < 0) or + (time - m_last_output_time >= m_interval_time)) { + m_last_output_time = time; + return true; + } else { + return false; + } + } else { + return step % m_interval == 0; + } + } + }; + } // namespace tools -#endif // UTILS_HELPERS_H +#endif // UTILS_TOOLS_H diff --git a/src/kernels/CMakeLists.txt b/src/kernels/CMakeLists.txt index d24dff0a4..60eda24cb 100644 --- a/src/kernels/CMakeLists.txt +++ b/src/kernels/CMakeLists.txt @@ -1,13 +1,20 @@ +# cmake-lint: disable=C0103 # ------------------------------ # @defines: ntt_kernels [INTERFACE] +# # 
@includes: -# - ../ +# +# * ../ +# # @depends: -# - ntt_global [required] +# +# * ntt_global [required] +# # @uses: -# - kokkos [required] -# - plog [required] -# - mpi [optional] +# +# * kokkos [required] +# * plog [required] +# * mpi [optional] # ------------------------------ add_library(ntt_kernels INTERFACE) @@ -17,5 +24,4 @@ add_dependencies(ntt_kernels ${libs}) target_link_libraries(ntt_kernels INTERFACE ${libs}) target_include_directories(ntt_kernels - INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../ -) \ No newline at end of file + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../) diff --git a/src/kernels/ampere_gr.hpp b/src/kernels/ampere_gr.hpp index 5af0fa4ef..327ce0cdf 100644 --- a/src/kernels/ampere_gr.hpp +++ b/src/kernels/ampere_gr.hpp @@ -37,7 +37,7 @@ namespace kernel::gr { ndfield_t Dout; const ndfield_t H; const M metric; - const std::size_t i2max; + const ncells_t i2max; const real_t coeff; bool is_axis_i2min { false }, is_axis_i2max { false }; @@ -47,7 +47,7 @@ namespace kernel::gr { const ndfield_t& H, const M& metric, real_t coeff, - std::size_t ni2, + ncells_t ni2, const boundaries_t& boundaries) : Din { Din } , Dout { Dout } @@ -57,9 +57,6 @@ namespace kernel::gr { , coeff { coeff } { if constexpr ((D == Dim::_2D) || (D == Dim::_3D)) { raise::ErrorIf(boundaries.size() < 2, "boundaries defined incorrectly", HERE); - raise::ErrorIf(boundaries[1].size() < 2, - "boundaries defined incorrectly", - HERE); is_axis_i2min = (boundaries[1].first == FldsBC::AXIS); is_axis_i2max = (boundaries[1].second == FldsBC::AXIS); } @@ -67,9 +64,9 @@ namespace kernel::gr { Inline void operator()(index_t i1, index_t i2) const { if constexpr (D == Dim::_2D) { - constexpr std::size_t i2min { N_GHOSTS }; - const real_t i1_ { COORD(i1) }; - const real_t i2_ { COORD(i2) }; + constexpr ncells_t i2min { N_GHOSTS }; + const real_t i1_ { COORD(i1) }; + const real_t i2_ { COORD(i2) }; const real_t inv_sqrt_detH_0pH { ONE / metric.sqrt_det_h({ i1_, i2_ + HALF }) }; @@ -77,6 +74,7 @@ 
namespace kernel::gr { if ((i2 == i2min) && is_axis_i2min) { // theta = 0 const real_t inv_polar_area_pH { ONE / metric.polar_area(i1_ + HALF) }; + const real_t inv_sqrt_detH_0pH { ONE / metric.sqrt_det_h({ i1_, HALF }) }; Dout(i1, i2, em::dx1) = Din(i1, i2, em::dx1) + inv_polar_area_pH * coeff * H(i1, i2, em::hx3); Dout(i1, i2, em::dx2) = Din(i1, i2, em::dx2) + @@ -118,7 +116,7 @@ namespace kernel::gr { ndfield_t Df; const ndfield_t J; const M metric; - const std::size_t i2max; + const ncells_t i2max; const real_t coeff; bool is_axis_i2min { false }; bool is_axis_i2max { false }; @@ -132,7 +130,7 @@ namespace kernel::gr { const ndfield_t& J, const M& metric, real_t coeff, - std::size_t ni2, + ncells_t ni2, const boundaries_t& boundaries) : Df { Df } , J { J } @@ -148,9 +146,9 @@ namespace kernel::gr { Inline void operator()(index_t i1, index_t i2) const { if constexpr (D == Dim::_2D) { - constexpr std::size_t i2min { N_GHOSTS }; - const real_t i1_ { COORD(i1) }; - const real_t i2_ { COORD(i2) }; + constexpr ncells_t i2min { N_GHOSTS }; + const real_t i1_ { COORD(i1) }; + const real_t i2_ { COORD(i2) }; const real_t inv_sqrt_detH_0pH { ONE / metric.sqrt_det_h({ i1_, i2_ + HALF }) }; diff --git a/src/kernels/ampere_mink.hpp b/src/kernels/ampere_mink.hpp index 16ed1655a..45773ead3 100644 --- a/src/kernels/ampere_mink.hpp +++ b/src/kernels/ampere_mink.hpp @@ -15,10 +15,15 @@ #include "arch/kokkos_aliases.h" #include "utils/error.h" +#include "utils/numeric.h" namespace kernel::mink { using namespace ntt; + struct NoCurrent_t { + NoCurrent_t() {} + }; + /** * @brief Algorithm for the Ampere's law: `dE/dt = curl B` in Minkowski space. * @tparam D Dimension. @@ -88,33 +93,59 @@ namespace kernel::mink { * @brief `coeff` includes metric coefficient. * @tparam D Dimension. 
*/ - template + template class CurrentsAmpere_kernel { - ndfield_t E; - ndfield_t J; - // coeff = -dt * q0 * n0 / (B0 * V0) - const real_t coeff; - const real_t inv_n0; + static constexpr auto ExtCurrent = not std::is_same::value; + ndfield_t E; + ndfield_t J; + // coeff = -dt * q0 / (B0 * V0) + const real_t coeff; + const real_t ppc0; + const C ext_current; + real_t x1min { ZERO }; + real_t x2min { ZERO }; + real_t x3min { ZERO }; + real_t dx; public: + CurrentsAmpere_kernel(const ndfield_t& E, + const ndfield_t J, + real_t coeff, + real_t ppc0, + const C& ext_current, + const std::vector xmin, + real_t dx) + : E { E } + , J { J } + , coeff { coeff } + , ppc0 { ppc0 } + , ext_current { ext_current } + , x1min { xmin.size() > 0 ? xmin[0] : ZERO } + , x2min { xmin.size() > 1 ? xmin[1] : ZERO } + , x3min { xmin.size() > 2 ? xmin[2] : ZERO } + , dx { dx } {} + CurrentsAmpere_kernel(const ndfield_t& E, const ndfield_t J, real_t coeff, real_t inv_n0) - : E { E } - , J { J } - , coeff { coeff } - , inv_n0 { inv_n0 } {} + : CurrentsAmpere_kernel { E, J, coeff, inv_n0, NoCurrent_t {}, {}, ZERO } {} Inline void operator()(index_t i1) const { if constexpr (D == Dim::_1D) { - J(i1, cur::jx1) *= inv_n0; - J(i1, cur::jx2) *= inv_n0; - J(i1, cur::jx3) *= inv_n0; - + if constexpr (ExtCurrent) { + const auto i1_ = COORD(i1); + J(i1, cur::jx1) += ppc0 * ext_current.jx1({ (i1_ + HALF) * dx + x1min }); + J(i1, cur::jx2) += ppc0 * ext_current.jx2({ i1_ * dx + x1min }); + J(i1, cur::jx3) += ppc0 * ext_current.jx3({ i1_ * dx + x1min }); + } E(i1, em::ex1) += J(i1, cur::jx1) * coeff; E(i1, em::ex2) += J(i1, cur::jx2) * coeff; E(i1, em::ex3) += J(i1, cur::jx3) * coeff; + + J(i1, cur::jx1) /= ppc0; + J(i1, cur::jx2) /= ppc0; + J(i1, cur::jx3) /= ppc0; } else { raise::KernelError( HERE, @@ -124,14 +155,24 @@ namespace kernel::mink { Inline void operator()(index_t i1, index_t i2) const { if constexpr (D == Dim::_2D) { - J(i1, i2, cur::jx1) *= inv_n0; - J(i1, i2, cur::jx2) *= inv_n0; - J(i1, 
i2, cur::jx3) *= inv_n0; - + if constexpr (ExtCurrent) { + const auto i1_ = COORD(i1); + const auto i2_ = COORD(i2); + J(i1, i2, cur::jx1) += ppc0 * ext_current.jx1({ (i1_ + HALF) * dx + x1min, + i2_ * dx + x2min }); + J(i1, i2, cur::jx2) += ppc0 * + ext_current.jx2({ i1_ * dx + x1min, + (i2_ + HALF) * dx + x2min }); + J(i1, i2, cur::jx3) += ppc0 * ext_current.jx3({ i1_ * dx + x1min, + i2_ * dx + x2min }); + } E(i1, i2, em::ex1) += J(i1, i2, cur::jx1) * coeff; E(i1, i2, em::ex2) += J(i1, i2, cur::jx2) * coeff; E(i1, i2, em::ex3) += J(i1, i2, cur::jx3) * coeff; + J(i1, i2, cur::jx1) /= ppc0; + J(i1, i2, cur::jx2) /= ppc0; + J(i1, i2, cur::jx3) /= ppc0; } else { raise::KernelError( HERE, @@ -141,13 +182,30 @@ namespace kernel::mink { Inline void operator()(index_t i1, index_t i2, index_t i3) const { if constexpr (D == Dim::_3D) { - J(i1, i2, i3, cur::jx1) *= inv_n0; - J(i1, i2, i3, cur::jx2) *= inv_n0; - J(i1, i2, i3, cur::jx3) *= inv_n0; - + if constexpr (ExtCurrent) { + const auto i1_ = COORD(i1); + const auto i2_ = COORD(i2); + const auto i3_ = COORD(i3); + J(i1, i2, i3, cur::jx1) += ppc0 * + ext_current.jx1({ (i1_ + HALF) * dx + x1min, + i2_ * dx + x2min, + i3_ * dx + x3min }); + J(i1, i2, i3, cur::jx2) += ppc0 * + ext_current.jx2({ i1_ * dx + x1min, + (i2_ + HALF) * dx + x2min, + i3_ * dx + x3min }); + J(i1, i2, i3, cur::jx3) += ppc0 * ext_current.jx3( + { i1_ * dx + x1min, + i2_ * dx + x2min, + (i3_ + HALF) * dx + x3min }); + } E(i1, i2, i3, em::ex1) += J(i1, i2, i3, cur::jx1) * coeff; E(i1, i2, i3, em::ex2) += J(i1, i2, i3, cur::jx2) * coeff; E(i1, i2, i3, em::ex3) += J(i1, i2, i3, cur::jx3) * coeff; + + J(i1, i2, i3, cur::jx1) /= ppc0; + J(i1, i2, i3, cur::jx2) /= ppc0; + J(i1, i2, i3, cur::jx3) /= ppc0; } else { raise::KernelError( HERE, diff --git a/src/kernels/ampere_sr.hpp b/src/kernels/ampere_sr.hpp index e4faec6ce..dbe9d3dbf 100644 --- a/src/kernels/ampere_sr.hpp +++ b/src/kernels/ampere_sr.hpp @@ -32,17 +32,17 @@ namespace kernel::sr { 
static_assert(M::is_metric, "M must be a metric class"); static constexpr auto D = M::Dim; - ndfield_t EB; - const M metric; - const std::size_t i2max; - const real_t coeff; - bool is_axis_i2min { false }, is_axis_i2max { false }; + ndfield_t EB; + const M metric; + const ncells_t i2max; + const real_t coeff; + bool is_axis_i2min { false }, is_axis_i2max { false }; public: Ampere_kernel(const ndfield_t& EB, const M& metric, real_t coeff, - std::size_t ni2, + ncells_t ni2, const boundaries_t& boundaries) : EB { EB } , metric { metric } @@ -57,9 +57,9 @@ namespace kernel::sr { Inline void operator()(index_t i1, index_t i2) const { if constexpr (D == Dim::_2D) { - constexpr std::size_t i2min { N_GHOSTS }; - const real_t i1_ { COORD(i1) }; - const real_t i2_ { COORD(i2) }; + constexpr ncells_t i2min { N_GHOSTS }; + const real_t i1_ { COORD(i1) }; + const real_t i2_ { COORD(i2) }; const real_t inv_sqrt_detH_0pH { ONE / metric.sqrt_det_h({ i1_, i2_ + HALF }) }; @@ -122,18 +122,18 @@ namespace kernel::sr { */ template class CurrentsAmpere_kernel { - static constexpr auto D = M::Dim; - static constexpr std::size_t i2min = N_GHOSTS; + static constexpr auto D = M::Dim; + static constexpr ncells_t i2min = N_GHOSTS; - ndfield_t E; - ndfield_t J; - const M metric; - const std::size_t i2max; + ndfield_t E; + ndfield_t J; + const M metric; + const ncells_t i2max; // coeff = -dt * q0 * n0 / B0 - const real_t coeff; - const real_t inv_n0; - bool is_axis_i2min { false }; - bool is_axis_i2max { false }; + const real_t coeff; + const real_t inv_n0; + bool is_axis_i2min { false }; + bool is_axis_i2max { false }; public: /** @@ -145,7 +145,7 @@ namespace kernel::sr { const M& metric, real_t coeff, real_t inv_n0, - std::size_t ni2, + ncells_t ni2, const boundaries_t& boundaries) : E { E } , J { J } diff --git a/src/kernels/aux_fields_gr.hpp b/src/kernels/aux_fields_gr.hpp index 5744c3092..c2e45d6de 100644 --- a/src/kernels/aux_fields_gr.hpp +++ b/src/kernels/aux_fields_gr.hpp @@ -230,6 
+230,86 @@ namespace kernel::gr { } } }; + + /** + * @brief Kernel for computing time average of B and D + * @tparam M Metric + */ + template + class TimeAverageDB_kernel { + static_assert(M::is_metric, "M must be a metric class"); + static constexpr auto D = M::Dim; + + const ndfield_t BDf; + ndfield_t BDf0; + const M metric; + + public: + TimeAverageDB_kernel(const ndfield_t& BDf, + const ndfield_t& BDf0, + const M& metric) + : BDf { BDf } + , BDf0 { BDf0 } + , metric { metric } {} + + Inline void operator()(index_t i1, index_t i2) const { + if constexpr (D == Dim::_2D) { + BDf0(i1, i2, em::bx1) = HALF * + (BDf0(i1, i2, em::bx1) + BDf(i1, i2, em::bx1)); + BDf0(i1, i2, em::bx2) = HALF * + (BDf0(i1, i2, em::bx2) + BDf(i1, i2, em::bx2)); + BDf0(i1, i2, em::bx3) = HALF * + (BDf0(i1, i2, em::bx3) + BDf(i1, i2, em::bx3)); + BDf0(i1, i2, em::ex1) = HALF * + (BDf0(i1, i2, em::ex1) + BDf(i1, i2, em::ex1)); + BDf0(i1, i2, em::ex2) = HALF * + (BDf0(i1, i2, em::ex2) + BDf(i1, i2, em::ex2)); + BDf0(i1, i2, em::ex3) = HALF * + (BDf0(i1, i2, em::ex3) + BDf(i1, i2, em::ex3)); + } else { + raise::KernelError( + HERE, + "ComputeAuxH_kernel: 2D implementation called for D != 2"); + } + } + }; + + /** + * @brief Kernel for computing time average of J + * @tparam M Metric + */ + template + class TimeAverageJ_kernel { + static_assert(M::is_metric, "M must be a metric class"); + static constexpr auto D = M::Dim; + + ndfield_t Jf; + const ndfield_t Jf0; + const M metric; + + public: + TimeAverageJ_kernel(const ndfield_t& Jf, + const ndfield_t& Jf0, + const M& metric) + : Jf { Jf } + , Jf0 { Jf0 } + , metric { metric } {} + + Inline void operator()(index_t i1, index_t i2) const { + if constexpr (D == Dim::_2D) { + Jf(i1, i2, cur::jx1) = HALF * + (Jf0(i1, i2, cur::jx1) + Jf(i1, i2, cur::jx1)); + Jf(i1, i2, cur::jx2) = HALF * + (Jf0(i1, i2, cur::jx2) + Jf(i1, i2, cur::jx2)); + Jf(i1, i2, cur::jx3) = HALF * + (Jf0(i1, i2, cur::jx3) + Jf(i1, i2, cur::jx3)); + } else { + raise::KernelError( + 
HERE, + "ComputeAuxH_kernel: 2D implementation called for D != 2"); + } + } + }; } // namespace kernel::gr #endif // KERNELS_AUX_FIELDS_GR_HPP diff --git a/src/kernels/comm.hpp b/src/kernels/comm.hpp new file mode 100644 index 000000000..60251d8c6 --- /dev/null +++ b/src/kernels/comm.hpp @@ -0,0 +1,342 @@ +/** + * @file kernels/comm.hpp + * @brief Kernels used during communications + * @implements + * - kernel::comm::PrepareOutgoingPrtls_kernel<> + * - kernel::comm::PopulatePrtlSendBuffer_kernel<> + * - kernel::comm::ExtractReceivedPrtls_kernel<> + * @namespaces: + * - kernel::comm:: + */ + +#ifndef KERNELS_COMM_HPP +#define KERNELS_COMM_HPP + +#include "enums.h" +#include "global.h" + +#include "arch/kokkos_aliases.h" + +#include + +namespace kernel::comm { + using namespace ntt; + + template + class PrepareOutgoingPrtls_kernel { + const array_t shifts_in_x1, shifts_in_x2, shifts_in_x3; + array_t outgoing_indices; + + const npart_t npart, npart_alive, npart_dead; + const std::size_t ntags; + + array_t i1, i1_prev, i2, i2_prev, i3, i3_prev; + const array_t tag; + + const array_t tag_offsets; + + array_t current_offset; + + public: + PrepareOutgoingPrtls_kernel(const array_t& shifts_in_x1, + const array_t& shifts_in_x2, + const array_t& shifts_in_x3, + array_t& outgoing_indices, + npart_t npart, + npart_t npart_alive, + npart_t npart_dead, + std::size_t ntags, + array_t& i1, + array_t& i1_prev, + array_t& i2, + array_t& i2_prev, + array_t& i3, + array_t& i3_prev, + const array_t& tag, + const array_t& tag_offsets) + : shifts_in_x1 { shifts_in_x1 } + , shifts_in_x2 { shifts_in_x2 } + , shifts_in_x3 { shifts_in_x3 } + , outgoing_indices { outgoing_indices } + , npart { npart } + , npart_alive { npart_alive } + , npart_dead { npart_dead } + , ntags { ntags } + , i1 { i1 } + , i1_prev { i1_prev } + , i2 { i2 } + , i2_prev { i2_prev } + , i3 { i3 } + , i3_prev { i3_prev } + , tag { tag } + , tag_offsets { tag_offsets } + , current_offset { "current_offset", ntags } {} + 
+ Inline void operator()(index_t p) const { + if (tag(p) != ParticleTag::alive) { + // dead or to-be-sent + auto idx_for_tag = Kokkos::atomic_fetch_add(¤t_offset(tag(p)), 1); + if (tag(p) != ParticleTag::dead) { + idx_for_tag += npart_dead; + } + if (tag(p) > 2) { + idx_for_tag += tag_offsets(tag(p) - 3); + } + if (idx_for_tag >= npart - npart_alive) { + raise::KernelError(HERE, "Outgoing indices idx exceeds the array size"); + } + outgoing_indices(idx_for_tag) = p; + // apply offsets + if (tag(p) != ParticleTag::dead) { + if constexpr (D == Dim::_1D or D == Dim::_2D or D == Dim::_3D) { + i1(p) += shifts_in_x1(tag(p) - 2); + i1_prev(p) += shifts_in_x1(tag(p) - 2); + } + if constexpr (D == Dim::_2D or D == Dim::_3D) { + i2(p) += shifts_in_x2(tag(p) - 2); + i2_prev(p) += shifts_in_x2(tag(p) - 2); + } + if constexpr (D == Dim::_3D) { + i3(p) += shifts_in_x3(tag(p) - 2); + i3_prev(p) += shifts_in_x3(tag(p) - 2); + } + } + } + } + }; + + template + class PopulatePrtlSendBuffer_kernel { + array_t send_buff_int; + array_t send_buff_real; + array_t send_buff_prtldx; + array_t send_buff_pld; + + const unsigned short NINTS, NREALS, NPRTLDX, NPLDS; + const npart_t idx_offset; + + const array_t i1, i1_prev, i2, i2_prev, i3, i3_prev; + const array_t dx1, dx1_prev, dx2, dx2_prev, dx3, dx3_prev; + const array_t ux1, ux2, ux3, weight, phi; + const array_t pld; + array_t tag; + const array_t outgoing_indices; + + public: + PopulatePrtlSendBuffer_kernel(array_t& send_buff_int, + array_t& send_buff_real, + array_t& send_buff_prtldx, + array_t& send_buff_pld, + unsigned short NINTS, + unsigned short NREALS, + unsigned short NPRTLDX, + unsigned short NPLDS, + npart_t idx_offset, + const array_t& i1, + const array_t& i1_prev, + const array_t& dx1, + const array_t& dx1_prev, + const array_t& i2, + const array_t& i2_prev, + const array_t& dx2, + const array_t& dx2_prev, + const array_t& i3, + const array_t& i3_prev, + const array_t& dx3, + const array_t& dx3_prev, + const array_t& ux1, + 
const array_t& ux2, + const array_t& ux3, + const array_t& weight, + const array_t& phi, + const array_t& pld, + array_t& tag, + const array_t& outgoing_indices) + : send_buff_int { send_buff_int } + , send_buff_real { send_buff_real } + , send_buff_prtldx { send_buff_prtldx } + , send_buff_pld { send_buff_pld } + , NINTS { NINTS } + , NREALS { NREALS } + , NPRTLDX { NPRTLDX } + , NPLDS { NPLDS } + , idx_offset { idx_offset } + , i1 { i1 } + , i1_prev { i1_prev } + , i2 { i2 } + , i2_prev { i2_prev } + , i3 { i3 } + , i3_prev { i3_prev } + , dx1 { dx1 } + , dx1_prev { dx1_prev } + , dx2 { dx2 } + , dx2_prev { dx2_prev } + , dx3 { dx3 } + , dx3_prev { dx3_prev } + , ux1 { ux1 } + , ux2 { ux2 } + , ux3 { ux3 } + , weight { weight } + , phi { phi } + , pld { pld } + , tag { tag } + , outgoing_indices { outgoing_indices } {} + + Inline void operator()(index_t p) const { + const auto idx = outgoing_indices(idx_offset + p); + if constexpr (D == Dim::_1D or D == Dim::_2D or D == Dim::_3D) { + send_buff_int(NINTS * p + 0) = i1(idx); + send_buff_int(NINTS * p + 1) = i1_prev(idx); + send_buff_prtldx(NPRTLDX * p + 0) = dx1(idx); + send_buff_prtldx(NPRTLDX * p + 1) = dx1_prev(idx); + } + if constexpr (D == Dim::_2D or D == Dim::_3D) { + send_buff_int(NINTS * p + 2) = i2(idx); + send_buff_int(NINTS * p + 3) = i2_prev(idx); + send_buff_prtldx(NPRTLDX * p + 2) = dx2(idx); + send_buff_prtldx(NPRTLDX * p + 3) = dx2_prev(idx); + } + if constexpr (D == Dim::_3D) { + send_buff_int(NINTS * p + 4) = i3(idx); + send_buff_int(NINTS * p + 5) = i3_prev(idx); + send_buff_prtldx(NPRTLDX * p + 4) = dx3(idx); + send_buff_prtldx(NPRTLDX * p + 5) = dx3_prev(idx); + } + send_buff_real(NREALS * p + 0) = ux1(idx); + send_buff_real(NREALS * p + 1) = ux2(idx); + send_buff_real(NREALS * p + 2) = ux3(idx); + send_buff_real(NREALS * p + 3) = weight(idx); + if constexpr (D == Dim::_2D and C != Coord::Cart) { + send_buff_real(NREALS * p + 4) = phi(idx); + } + if (NPLDS > 0) { + for (auto l { 0u }; l < 
NPLDS; ++l) { + send_buff_pld(NPLDS * p + l) = pld(idx, l); + } + } + tag(idx) = ParticleTag::dead; + } + }; + + template + class ExtractReceivedPrtls_kernel { + const array_t recv_buff_int; + const array_t recv_buff_real; + const array_t recv_buff_prtldx; + const array_t recv_buff_pld; + + const unsigned short NINTS, NREALS, NPRTLDX, NPLDS; + const npart_t npart, npart_holes; + + array_t i1, i1_prev, i2, i2_prev, i3, i3_prev; + array_t dx1, dx1_prev, dx2, dx2_prev, dx3, dx3_prev; + array_t ux1, ux2, ux3, weight, phi; + array_t pld; + array_t tag; + const array_t outgoing_indices; + + public: + ExtractReceivedPrtls_kernel(const array_t& recv_buff_int, + const array_t& recv_buff_real, + const array_t& recv_buff_prtldx, + const array_t& recv_buff_pld, + unsigned short NINTS, + unsigned short NREALS, + unsigned short NPRTLDX, + unsigned short NPLDS, + npart_t npart, + array_t& i1, + array_t& i1_prev, + array_t& dx1, + array_t& dx1_prev, + array_t& i2, + array_t& i2_prev, + array_t& dx2, + array_t& dx2_prev, + array_t& i3, + array_t& i3_prev, + array_t& dx3, + array_t& dx3_prev, + array_t& ux1, + array_t& ux2, + array_t& ux3, + array_t& weight, + array_t& phi, + array_t& pld, + array_t& tag, + const array_t& outgoing_indices) + : recv_buff_int { recv_buff_int } + , recv_buff_real { recv_buff_real } + , recv_buff_prtldx { recv_buff_prtldx } + , recv_buff_pld { recv_buff_pld } + , NINTS { NINTS } + , NREALS { NREALS } + , NPRTLDX { NPRTLDX } + , NPLDS { NPLDS } + , npart { npart } + , npart_holes { outgoing_indices.extent(0) } + , i1 { i1 } + , i1_prev { i1_prev } + , i2 { i2 } + , i2_prev { i2_prev } + , i3 { i3 } + , i3_prev { i3_prev } + , dx1 { dx1 } + , dx1_prev { dx1_prev } + , dx2 { dx2 } + , dx2_prev { dx2_prev } + , dx3 { dx3 } + , dx3_prev { dx3_prev } + , ux1 { ux1 } + , ux2 { ux2 } + , ux3 { ux3 } + , weight { weight } + , phi { phi } + , pld { pld } + , tag { tag } + , outgoing_indices { outgoing_indices } {} + + Inline void operator()(index_t p) const { + 
npart_t idx; + if (p >= npart_holes) { + idx = npart + p - npart_holes; + } else { + idx = outgoing_indices(p); + } + if constexpr (D == Dim::_1D or D == Dim::_2D or D == Dim::_3D) { + i1(idx) = recv_buff_int(NINTS * p + 0); + i1_prev(idx) = recv_buff_int(NINTS * p + 1); + dx1(idx) = recv_buff_prtldx(NPRTLDX * p + 0); + dx1_prev(idx) = recv_buff_prtldx(NPRTLDX * p + 1); + } + if constexpr (D == Dim::_2D or D == Dim::_3D) { + i2(idx) = recv_buff_int(NINTS * p + 2); + i2_prev(idx) = recv_buff_int(NINTS * p + 3); + dx2(idx) = recv_buff_prtldx(NPRTLDX * p + 2); + dx2_prev(idx) = recv_buff_prtldx(NPRTLDX * p + 3); + } + if constexpr (D == Dim::_3D) { + i3(idx) = recv_buff_int(NINTS * p + 4); + i3_prev(idx) = recv_buff_int(NINTS * p + 5); + dx3(idx) = recv_buff_prtldx(NPRTLDX * p + 4); + dx3_prev(idx) = recv_buff_prtldx(NPRTLDX * p + 5); + } + ux1(idx) = recv_buff_real(NREALS * p + 0); + ux2(idx) = recv_buff_real(NREALS * p + 1); + ux3(idx) = recv_buff_real(NREALS * p + 2); + weight(idx) = recv_buff_real(NREALS * p + 3); + if constexpr (D == Dim::_2D and C != Coord::Cart) { + phi(idx) = recv_buff_real(NREALS * p + 4); + } + if (NPLDS > 0) { + for (auto l { 0u }; l < NPLDS; ++l) { + pld(idx, l) = recv_buff_pld(NPLDS * p + l); + } + } + tag(idx) = ParticleTag::alive; + } + }; + +} // namespace kernel::comm + +#endif // KERNELS_COMM_HPP diff --git a/src/kernels/currents_deposit.hpp b/src/kernels/currents_deposit.hpp index ca9a94878..bd84554f8 100644 --- a/src/kernels/currents_deposit.hpp +++ b/src/kernels/currents_deposit.hpp @@ -67,8 +67,8 @@ namespace kernel { const array_t& weight, const array_t& tag, const M& metric, - const real_t& charge, - const real_t& dt) + real_t charge, + real_t dt) : J { scatter_cur } , i1 { i1 } , i2 { i2 } @@ -100,45 +100,84 @@ namespace kernel { if (tag(p) == ParticleTag::dead) { return; } - // _f = final, _i = initial - tuple_t Ip_f, Ip_i; - coord_t xp_f, xp_i, xp_r; + // recover particle velocity to deposit in unsimulated direction vec_t vp 
{ ZERO }; + { + coord_t xp { ZERO }; + if constexpr (D == Dim::_1D) { + xp[0] = i_di_to_Xi(i1(p), dx1(p)); + } else if constexpr (D == Dim::_2D) { + if constexpr (M::PrtlDim == Dim::_3D) { + xp[0] = i_di_to_Xi(i1(p), dx1(p)); + xp[1] = i_di_to_Xi(i2(p), dx2(p)); + xp[2] = phi(p); + } else { + xp[0] = i_di_to_Xi(i1(p), dx1(p)); + xp[1] = i_di_to_Xi(i2(p), dx2(p)); + } + } else { + xp[0] = i_di_to_Xi(i1(p), dx1(p)); + xp[1] = i_di_to_Xi(i2(p), dx2(p)); + xp[2] = i_di_to_Xi(i3(p), dx3(p)); + } + auto inv_energy { ZERO }; + if constexpr (S == SimEngine::SRPIC) { + metric.template transform_xyz(xp, + { ux1(p), ux2(p), ux3(p) }, + vp); + inv_energy = ONE / math::sqrt(ONE + NORM_SQR(ux1(p), ux2(p), ux3(p))); + } else { + coord_t xp_ { ZERO }; + xp_[0] = xp[0]; + real_t theta_Cd { xp[1] }; + const real_t theta_Ph { metric.template convert<2, Crd::Cd, Crd::Ph>( + theta_Cd) }; + const real_t small_angle { static_cast(constant::SMALL_ANGLE_GR) }; + const auto large_angle { static_cast(constant::PI) - small_angle }; + if (theta_Ph < small_angle) { + theta_Cd = metric.template convert<2, Crd::Ph, Crd::Cd>(small_angle); + } else if (theta_Ph >= large_angle) { + theta_Cd = metric.template convert<2, Crd::Ph, Crd::Cd>(large_angle); + } + xp_[1] = theta_Cd; + metric.template transform(xp_, + { ux1(p), ux2(p), ux3(p) }, + vp); + inv_energy = metric.alpha(xp_) / + math::sqrt(ONE + ux1(p) * vp[0] + ux2(p) * vp[1] + + ux3(p) * vp[2]); + } + if (Kokkos::isnan(vp[2]) || Kokkos::isinf(vp[2])) { + vp[2] = ZERO; + } + vp[0] *= inv_energy; + vp[1] *= inv_energy; + vp[2] *= inv_energy; + } - // get [i, di]_init and [i, di]_final (per dimension) - getDepositInterval(p, Ip_f, Ip_i, xp_f, xp_i, xp_r); - // recover particle velocity to deposit in unsimulated direction - getPrtl3Vel(p, vp); const real_t coeff { weight(p) * charge }; - depositCurrentsFromParticle(coeff, vp, Ip_f, Ip_i, xp_f, xp_i, xp_r); - } - /** - * @brief Deposit currents from a single particle. 
- * @param[in] coeff Particle weight x charge. - * @param[in] vp Particle 3-velocity. - * @param[in] Ip_f Final position of the particle (cell index). - * @param[in] Ip_i Initial position of the particle (cell index). - * @param[in] xp_f Final position. - * @param[in] xp_i Previous step position. - * @param[in] xp_r Intermediate point used in zig-zag deposit. - */ - Inline auto depositCurrentsFromParticle(const real_t& coeff, - const vec_t& vp, - const tuple_t& Ip_f, - const tuple_t& Ip_i, - const coord_t& xp_f, - const coord_t& xp_i, - const coord_t& xp_r) const -> void { - const real_t Wx1_1 { HALF * (xp_i[0] + xp_r[0]) - - static_cast(Ip_i[0]) }; - const real_t Wx1_2 { HALF * (xp_f[0] + xp_r[0]) - - static_cast(Ip_f[0]) }; - const real_t Fx1_1 { (xp_r[0] - xp_i[0]) * coeff * inv_dt }; - const real_t Fx1_2 { (xp_f[0] - xp_r[0]) * coeff * inv_dt }; + const auto dxp_r_1 { static_cast(i1(p) == i1_prev(p)) * + (dx1(p) + dx1_prev(p)) * static_cast(INV_2) }; + + const real_t Wx1_1 { INV_2 * (dxp_r_1 + dx1_prev(p) + + static_cast(i1(p) > i1_prev(p))) }; + const real_t Wx1_2 { INV_2 * (dx1(p) + dxp_r_1 + + static_cast( + static_cast(i1(p) > i1_prev(p)) + + i1_prev(p) - i1(p))) }; + const real_t Fx1_1 { (static_cast(i1(p) > i1_prev(p)) + dxp_r_1 - + dx1_prev(p)) * + coeff * inv_dt }; + const real_t Fx1_2 { (static_cast( + i1(p) - i1_prev(p) - + static_cast(i1(p) > i1_prev(p))) + + dx1(p) - dxp_r_1) * + coeff * inv_dt }; auto J_acc = J.access(); + // tuple_t dxp_r; if constexpr (D == Dim::_1D) { const real_t Fx2_1 { HALF * vp[1] * coeff }; const real_t Fx2_2 { HALF * vp[1] * coeff }; @@ -146,265 +185,210 @@ namespace kernel { const real_t Fx3_1 { HALF * vp[2] * coeff }; const real_t Fx3_2 { HALF * vp[2] * coeff }; - J_acc(Ip_i[0] + N_GHOSTS, cur::jx1) += Fx1_1; - J_acc(Ip_f[0] + N_GHOSTS, cur::jx1) += Fx1_2; + J_acc(i1_prev(p) + N_GHOSTS, cur::jx1) += Fx1_1; + J_acc(i1(p) + N_GHOSTS, cur::jx1) += Fx1_2; - J_acc(Ip_i[0] + N_GHOSTS, cur::jx2) += Fx2_1 * (ONE - Wx1_1); - 
J_acc(Ip_i[0] + N_GHOSTS + 1, cur::jx2) += Fx2_1 * Wx1_1; - J_acc(Ip_f[0] + N_GHOSTS, cur::jx2) += Fx2_2 * (ONE - Wx1_2); - J_acc(Ip_f[0] + N_GHOSTS + 1, cur::jx2) += Fx2_2 * Wx1_2; + J_acc(i1_prev(p) + N_GHOSTS, cur::jx2) += Fx2_1 * (ONE - Wx1_1); + J_acc(i1_prev(p) + N_GHOSTS + 1, cur::jx2) += Fx2_1 * Wx1_1; + J_acc(i1(p) + N_GHOSTS, cur::jx2) += Fx2_2 * (ONE - Wx1_2); + J_acc(i1(p) + N_GHOSTS + 1, cur::jx2) += Fx2_2 * Wx1_2; - J_acc(Ip_i[0] + N_GHOSTS, cur::jx3) += Fx3_1 * (ONE - Wx1_1); - J_acc(Ip_i[0] + N_GHOSTS + 1, cur::jx3) += Fx3_1 * Wx1_1; - J_acc(Ip_f[0] + N_GHOSTS, cur::jx3) += Fx3_2 * (ONE - Wx1_2); - J_acc(Ip_f[0] + N_GHOSTS + 1, cur::jx3) += Fx3_2 * Wx1_2; + J_acc(i1_prev(p) + N_GHOSTS, cur::jx3) += Fx3_1 * (ONE - Wx1_1); + J_acc(i1_prev(p) + N_GHOSTS + 1, cur::jx3) += Fx3_1 * Wx1_1; + J_acc(i1(p) + N_GHOSTS, cur::jx3) += Fx3_2 * (ONE - Wx1_2); + J_acc(i1(p) + N_GHOSTS + 1, cur::jx3) += Fx3_2 * Wx1_2; } else if constexpr (D == Dim::_2D || D == Dim::_3D) { - const real_t Wx2_1 { HALF * (xp_i[1] + xp_r[1]) - - static_cast(Ip_i[1]) }; - const real_t Wx2_2 { HALF * (xp_f[1] + xp_r[1]) - - static_cast(Ip_f[1]) }; - const real_t Fx2_1 { (xp_r[1] - xp_i[1]) * coeff * inv_dt }; - const real_t Fx2_2 { (xp_f[1] - xp_r[1]) * coeff * inv_dt }; + const auto dxp_r_2 { static_cast(i2(p) == i2_prev(p)) * + (dx2(p) + dx2_prev(p)) * + static_cast(INV_2) }; + + const real_t Wx2_1 { INV_2 * (dxp_r_2 + dx2_prev(p) + + static_cast(i2(p) > i2_prev(p))) }; + const real_t Wx2_2 { INV_2 * (dx2(p) + dxp_r_2 + + static_cast( + static_cast(i2(p) > i2_prev(p)) + + i2_prev(p) - i2(p))) }; + const real_t Fx2_1 { (static_cast(i2(p) > i2_prev(p)) + + dxp_r_2 - dx2_prev(p)) * + coeff * inv_dt }; + const real_t Fx2_2 { (static_cast( + i2(p) - i2_prev(p) - + static_cast(i2(p) > i2_prev(p))) + + dx2(p) - dxp_r_2) * + coeff * inv_dt }; if constexpr (D == Dim::_2D) { const real_t Fx3_1 { HALF * vp[2] * coeff }; const real_t Fx3_2 { HALF * vp[2] * coeff }; - J_acc(Ip_i[0] + N_GHOSTS, 
Ip_i[1] + N_GHOSTS, cur::jx1) += Fx1_1 * - (ONE - Wx2_1); - J_acc(Ip_i[0] + N_GHOSTS, Ip_i[1] + N_GHOSTS + 1, cur::jx1) += Fx1_1 * - Wx2_1; - J_acc(Ip_f[0] + N_GHOSTS, Ip_f[1] + N_GHOSTS, cur::jx1) += Fx1_2 * - (ONE - Wx2_2); - J_acc(Ip_f[0] + N_GHOSTS, Ip_f[1] + N_GHOSTS + 1, cur::jx1) += Fx1_2 * - Wx2_2; - - J_acc(Ip_i[0] + N_GHOSTS, Ip_i[1] + N_GHOSTS, cur::jx2) += Fx2_1 * - (ONE - Wx1_1); - J_acc(Ip_i[0] + N_GHOSTS + 1, Ip_i[1] + N_GHOSTS, cur::jx2) += Fx2_1 * - Wx1_1; - J_acc(Ip_f[0] + N_GHOSTS, Ip_f[1] + N_GHOSTS, cur::jx2) += Fx2_2 * - (ONE - Wx1_2); - J_acc(Ip_f[0] + N_GHOSTS + 1, Ip_f[1] + N_GHOSTS, cur::jx2) += Fx2_2 * - Wx1_2; - - J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS, + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS, + cur::jx1) += Fx1_1 * (ONE - Wx2_1); + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS + 1, + cur::jx1) += Fx1_1 * Wx2_1; + J_acc(i1(p) + N_GHOSTS, i2(p) + N_GHOSTS, cur::jx1) += Fx1_2 * + (ONE - Wx2_2); + J_acc(i1(p) + N_GHOSTS, i2(p) + N_GHOSTS + 1, cur::jx1) += Fx1_2 * Wx2_2; + + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS, + cur::jx2) += Fx2_1 * (ONE - Wx1_1); + J_acc(i1_prev(p) + N_GHOSTS + 1, + i2_prev(p) + N_GHOSTS, + cur::jx2) += Fx2_1 * Wx1_1; + J_acc(i1(p) + N_GHOSTS, i2(p) + N_GHOSTS, cur::jx2) += Fx2_2 * + (ONE - Wx1_2); + J_acc(i1(p) + N_GHOSTS + 1, i2(p) + N_GHOSTS, cur::jx2) += Fx2_2 * Wx1_2; + + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS, cur::jx3) += Fx3_1 * (ONE - Wx1_1) * (ONE - Wx2_1); - J_acc(Ip_i[0] + N_GHOSTS + 1, - Ip_i[1] + N_GHOSTS, - cur::jx3) += Fx3_1 * Wx1_2 * (ONE - Wx2_1); - J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS + 1, + J_acc(i1_prev(p) + N_GHOSTS + 1, + i2_prev(p) + N_GHOSTS, + cur::jx3) += Fx3_1 * Wx1_1 * (ONE - Wx2_1); + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS + 1, cur::jx3) += Fx3_1 * (ONE - Wx1_1) * Wx2_1; - J_acc(Ip_i[0] + N_GHOSTS + 1, - Ip_i[1] + N_GHOSTS + 1, + J_acc(i1_prev(p) + N_GHOSTS + 1, + i2_prev(p) + N_GHOSTS + 1, cur::jx3) += Fx3_1 
* Wx1_1 * Wx2_1; - J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + N_GHOSTS, - cur::jx3) += Fx3_2 * (ONE - Wx1_2) * (ONE - Wx2_2); - J_acc(Ip_f[0] + N_GHOSTS + 1, - Ip_f[1] + N_GHOSTS, + J_acc(i1(p) + N_GHOSTS, i2(p) + N_GHOSTS, cur::jx3) += Fx3_2 * + (ONE - Wx1_2) * + (ONE - Wx2_2); + J_acc(i1(p) + N_GHOSTS + 1, + i2(p) + N_GHOSTS, cur::jx3) += Fx3_2 * Wx1_2 * (ONE - Wx2_2); - J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + N_GHOSTS + 1, + J_acc(i1(p) + N_GHOSTS, + i2(p) + N_GHOSTS + 1, cur::jx3) += Fx3_2 * (ONE - Wx1_2) * Wx2_2; - J_acc(Ip_f[0] + N_GHOSTS + 1, - Ip_f[1] + N_GHOSTS + 1, - cur::jx3) += Fx3_2 * Wx1_2 * Wx2_2; + J_acc(i1(p) + N_GHOSTS + 1, i2(p) + N_GHOSTS + 1, cur::jx3) += Fx3_2 * + Wx1_2 * + Wx2_2; } else { - const real_t Wx3_1 { HALF * (xp_i[2] + xp_r[2]) - - static_cast(Ip_i[2]) }; - const real_t Wx3_2 { HALF * (xp_f[2] + xp_r[2]) - - static_cast(Ip_f[2]) }; - const real_t Fx3_1 { (xp_r[2] - xp_i[2]) * coeff * inv_dt }; - const real_t Fx3_2 { (xp_f[2] - xp_r[2]) * coeff * inv_dt }; - - J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS, - Ip_i[2] + N_GHOSTS, + const auto dxp_r_3 { static_cast(i3(p) == i3_prev(p)) * + (dx3(p) + dx3_prev(p)) * + static_cast(INV_2) }; + const real_t Wx3_1 { INV_2 * (dxp_r_3 + dx3_prev(p) + + static_cast(i3(p) > i3_prev(p))) }; + const real_t Wx3_2 { INV_2 * (dx3(p) + dxp_r_3 + + static_cast( + static_cast(i3(p) > i3_prev(p)) + + i3_prev(p) - i3(p))) }; + const real_t Fx3_1 { (static_cast(i3(p) > i3_prev(p)) + + dxp_r_3 - dx3_prev(p)) * + coeff * inv_dt }; + const real_t Fx3_2 { (static_cast( + i3(p) - i3_prev(p) - + static_cast(i3(p) > i3_prev(p))) + + dx3(p) - dxp_r_3) * + coeff * inv_dt }; + + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS, + i3_prev(p) + N_GHOSTS, cur::jx1) += Fx1_1 * (ONE - Wx2_1) * (ONE - Wx3_1); - J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS + 1, - Ip_i[2] + N_GHOSTS, + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS + 1, + i3_prev(p) + N_GHOSTS, cur::jx1) += Fx1_1 * Wx2_1 * (ONE - Wx3_1); - 
J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS, - Ip_i[2] + N_GHOSTS + 1, + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS, + i3_prev(p) + N_GHOSTS + 1, cur::jx1) += Fx1_1 * (ONE - Wx2_1) * Wx3_1; - J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS + 1, - Ip_i[2] + N_GHOSTS + 1, + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS + 1, + i3_prev(p) + N_GHOSTS + 1, cur::jx1) += Fx1_1 * Wx2_1 * Wx3_1; - J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + N_GHOSTS, - Ip_f[2] + N_GHOSTS, + J_acc(i1(p) + N_GHOSTS, + i2(p) + N_GHOSTS, + i3(p) + N_GHOSTS, cur::jx1) += Fx1_2 * (ONE - Wx2_2) * (ONE - Wx3_2); - J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + N_GHOSTS + 1, - Ip_f[2] + N_GHOSTS, + J_acc(i1(p) + N_GHOSTS, + i2(p) + N_GHOSTS + 1, + i3(p) + N_GHOSTS, cur::jx1) += Fx1_2 * Wx2_2 * (ONE - Wx3_2); - J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + N_GHOSTS, - Ip_f[2] + N_GHOSTS + 1, + J_acc(i1(p) + N_GHOSTS, + i2(p) + N_GHOSTS, + i3(p) + N_GHOSTS + 1, cur::jx1) += Fx1_2 * (ONE - Wx2_2) * Wx3_2; - J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + N_GHOSTS + 1, - Ip_f[2] + N_GHOSTS + 1, + J_acc(i1(p) + N_GHOSTS, + i2(p) + N_GHOSTS + 1, + i3(p) + N_GHOSTS + 1, cur::jx1) += Fx1_2 * Wx2_2 * Wx3_2; - J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS, - Ip_i[2] + N_GHOSTS, + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS, + i3_prev(p) + N_GHOSTS, cur::jx2) += Fx2_1 * (ONE - Wx1_1) * (ONE - Wx3_1); - J_acc(Ip_i[0] + N_GHOSTS + 1, - Ip_i[1] + N_GHOSTS, - Ip_i[2] + N_GHOSTS, + J_acc(i1_prev(p) + N_GHOSTS + 1, + i2_prev(p) + N_GHOSTS, + i3_prev(p) + N_GHOSTS, cur::jx2) += Fx2_1 * Wx1_1 * (ONE - Wx3_1); - J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS, - Ip_i[2] + N_GHOSTS + 1, + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS, + i3_prev(p) + N_GHOSTS + 1, cur::jx2) += Fx2_1 * (ONE - Wx1_1) * Wx3_1; - J_acc(Ip_i[0] + N_GHOSTS + 1, - Ip_i[1] + N_GHOSTS, - Ip_i[2] + N_GHOSTS + 1, + J_acc(i1_prev(p) + N_GHOSTS + 1, + i2_prev(p) + N_GHOSTS, + i3_prev(p) + N_GHOSTS + 1, cur::jx2) += Fx2_1 * Wx1_1 * Wx3_1; - 
J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + N_GHOSTS, - Ip_f[2] + N_GHOSTS, + J_acc(i1(p) + N_GHOSTS, + i2(p) + N_GHOSTS, + i3(p) + N_GHOSTS, cur::jx2) += Fx2_2 * (ONE - Wx1_2) * (ONE - Wx3_2); - J_acc(Ip_f[0] + N_GHOSTS + 1, - Ip_f[1] + N_GHOSTS, - Ip_f[2] + N_GHOSTS, + J_acc(i1(p) + N_GHOSTS + 1, + i2(p) + N_GHOSTS, + i3(p) + N_GHOSTS, cur::jx2) += Fx2_2 * Wx1_2 * (ONE - Wx3_2); - J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + N_GHOSTS, - Ip_f[2] + N_GHOSTS + 1, + J_acc(i1(p) + N_GHOSTS, + i2(p) + N_GHOSTS, + i3(p) + N_GHOSTS + 1, cur::jx2) += Fx2_2 * (ONE - Wx1_2) * Wx3_2; - J_acc(Ip_f[0] + N_GHOSTS + 1, - Ip_f[1] + N_GHOSTS, - Ip_f[2] + N_GHOSTS + 1, + J_acc(i1(p) + N_GHOSTS + 1, + i2(p) + N_GHOSTS, + i3(p) + N_GHOSTS + 1, cur::jx2) += Fx2_2 * Wx1_2 * Wx3_2; - J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS, - Ip_i[2] + N_GHOSTS, + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS, + i3_prev(p) + N_GHOSTS, cur::jx3) += Fx3_1 * (ONE - Wx1_1) * (ONE - Wx2_1); - J_acc(Ip_i[0] + N_GHOSTS + 1, - Ip_i[1] + N_GHOSTS, - Ip_i[2] + N_GHOSTS, + J_acc(i1_prev(p) + N_GHOSTS + 1, + i2_prev(p) + N_GHOSTS, + i3_prev(p) + N_GHOSTS, cur::jx3) += Fx3_1 * Wx1_1 * (ONE - Wx2_1); - J_acc(Ip_i[0] + N_GHOSTS, - Ip_i[1] + N_GHOSTS + 1, - Ip_i[2] + N_GHOSTS, + J_acc(i1_prev(p) + N_GHOSTS, + i2_prev(p) + N_GHOSTS + 1, + i3_prev(p) + N_GHOSTS, cur::jx3) += Fx3_1 * (ONE - Wx1_1) * Wx2_1; - J_acc(Ip_i[0] + N_GHOSTS + 1, - Ip_i[1] + N_GHOSTS + 1, - Ip_i[2] + N_GHOSTS, + J_acc(i1_prev(p) + N_GHOSTS + 1, + i2_prev(p) + N_GHOSTS + 1, + i3_prev(p) + N_GHOSTS, cur::jx3) += Fx3_1 * Wx1_1 * Wx2_1; - J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + N_GHOSTS, - Ip_f[2] + N_GHOSTS, + J_acc(i1(p) + N_GHOSTS, + i2(p) + N_GHOSTS, + i3(p) + N_GHOSTS, cur::jx3) += Fx3_2 * (ONE - Wx1_2) * (ONE - Wx2_2); - J_acc(Ip_f[0] + N_GHOSTS + 1, - Ip_f[1] + N_GHOSTS, - Ip_f[2] + N_GHOSTS, + J_acc(i1(p) + N_GHOSTS + 1, + i2(p) + N_GHOSTS, + i3(p) + N_GHOSTS, cur::jx3) += Fx3_2 * Wx1_2 * (ONE - Wx2_2); - J_acc(Ip_f[0] + N_GHOSTS, - Ip_f[1] + 
N_GHOSTS + 1, - Ip_f[2] + N_GHOSTS, + J_acc(i1(p) + N_GHOSTS, + i2(p) + N_GHOSTS + 1, + i3(p) + N_GHOSTS, cur::jx3) += Fx3_2 * (ONE - Wx1_2) * Wx2_2; - J_acc(Ip_f[0] + N_GHOSTS + 1, - Ip_f[1] + N_GHOSTS + 1, - Ip_f[2] + N_GHOSTS, + J_acc(i1(p) + N_GHOSTS + 1, + i2(p) + N_GHOSTS + 1, + i3(p) + N_GHOSTS, cur::jx3) += Fx3_2 * Wx1_2 * Wx2_2; } } } - - /** - * @brief Get particle position in `coord_t` form. - * @param[in] p Index of particle. - * @param[out] Ip_f Final position of the particle (cell index). - * @param[out] Ip_i Initial position of the particle (cell index). - * @param[out] xp_f Final position. - * @param[out] xp_i Previous step position. - * @param[out] xp_r Intermediate point used in zig-zag deposit. - */ - Inline auto getDepositInterval(index_t& p, - tuple_t& Ip_f, - tuple_t& Ip_i, - coord_t& xp_f, - coord_t& xp_i, - coord_t& xp_r) const -> void { - Ip_f[0] = i1(p); - Ip_i[0] = i1_prev(p); - xp_f[0] = i_di_to_Xi(Ip_f[0], dx1(p)); - xp_i[0] = i_di_to_Xi(Ip_i[0], dx1_prev(p)); - if constexpr (D == Dim::_2D || D == Dim::_3D) { - Ip_f[1] = i2(p); - Ip_i[1] = i2_prev(p); - xp_f[1] = i_di_to_Xi(Ip_f[1], dx2(p)); - xp_i[1] = i_di_to_Xi(Ip_i[1], dx2_prev(p)); - } - if constexpr (D == Dim::_3D) { - Ip_f[2] = i3(p); - Ip_i[2] = i3_prev(p); - xp_f[2] = i_di_to_Xi(Ip_f[2], dx3(p)); - xp_i[2] = i_di_to_Xi(Ip_i[2], dx3_prev(p)); - } - for (auto i = 0u; i < D; ++i) { - xp_r[i] = math::fmin(static_cast(IMIN(Ip_i[i], Ip_f[i]) + 1), - math::fmax(static_cast(IMAX(Ip_i[i], Ip_f[i])), - HALF * (xp_i[i] + xp_f[i]))); - } - } - - // Getters - Inline void getPrtlPos(index_t& p, coord_t& xp) const { - if constexpr (D == Dim::_1D) { - xp[0] = i_di_to_Xi(i1(p), dx1(p)); - } else if constexpr (D == Dim::_2D) { - if constexpr (M::PrtlDim == Dim::_3D) { - xp[0] = i_di_to_Xi(i1(p), dx1(p)); - xp[1] = i_di_to_Xi(i2(p), dx2(p)); - xp[2] = phi(p); - } else { - xp[0] = i_di_to_Xi(i1(p), dx1(p)); - xp[1] = i_di_to_Xi(i2(p), dx2(p)); - } - } else { - xp[0] = i_di_to_Xi(i1(p), dx1(p)); - 
xp[1] = i_di_to_Xi(i2(p), dx2(p)); - xp[2] = i_di_to_Xi(i3(p), dx3(p)); - } - } - - Inline void getPrtl3Vel(index_t& p, vec_t& vp) const { - coord_t xp { ZERO }; - getPrtlPos(p, xp); - auto inv_energy { ZERO }; - if constexpr (S == SimEngine::SRPIC) { - metric.template transform_xyz(xp, - { ux1(p), ux2(p), ux3(p) }, - vp); - inv_energy = ONE / math::sqrt(ONE + NORM_SQR(ux1(p), ux2(p), ux3(p))); - } else { - metric.template transform(xp, { ux1(p), ux2(p), ux3(p) }, vp); - inv_energy = ONE / math::sqrt(ONE + ux1(p) * vp[0] + ux2(p) * vp[1] + - ux3(p) * vp[2]); - } - if (Kokkos::isnan(vp[2]) || Kokkos::isinf(vp[2])) { - vp[2] = ZERO; - } - vp[0] *= inv_energy; - vp[1] *= inv_energy; - vp[2] *= inv_energy; - } }; } // namespace kernel diff --git a/src/kernels/digital_filter.hpp b/src/kernels/digital_filter.hpp index 5d05fad2d..5ac60327d 100644 --- a/src/kernels/digital_filter.hpp +++ b/src/kernels/digital_filter.hpp @@ -17,9 +17,40 @@ #include "utils/error.h" #include "utils/numeric.h" -#define FILTER_IN_I1(ARR, COMP, I, J) \ +#define FILTER2D_IN_I1(ARR, COMP, I, J) \ INV_2*(ARR)((I), (J), (COMP)) + \ - INV_4*((ARR)((I)-1, (J), (COMP)) + (ARR)((I) + 1, (J), (COMP))) + INV_4*((ARR)((I) - 1, (J), (COMP)) + (ARR)((I) + 1, (J), (COMP))) + +#define FILTER2D_IN_I2(ARR, COMP, I, J) \ + INV_2*(ARR)((I), (J), (COMP)) + \ + INV_4*((ARR)((I), (J) - 1, (COMP)) + (ARR)((I), (J) + 1, (COMP))) + +#define FILTER3D_IN_I1_I2(ARR, COMP, I, J, K) \ + INV_4*(ARR)(I, J, K, (COMP)) + \ + INV_8*((ARR)((I) - 1, (J), (K), (COMP)) + (ARR)((I) + 1, (J), (K), (COMP)) + \ + (ARR)((I), (J) - 1, (K), (COMP)) + (ARR)((I), (J) + 1, (K), (COMP))) + \ + INV_16*((ARR)((I) - 1, (J) - 1, (K), (COMP)) + \ + (ARR)((I) + 1, (J) + 1, (K), (COMP)) + \ + (ARR)((I) - 1, (J) + 1, (K), (COMP)) + \ + (ARR)((I) + 1, (J) - 1, (K), (COMP))) + +#define FILTER3D_IN_I2_I3(ARR, COMP, I, J, K) \ + INV_4*(ARR)(I, J, K, (COMP)) + \ + INV_8*((ARR)((I), (J) - 1, (K), (COMP)) + (ARR)((I), (J) + 1, (K), (COMP)) + \ + (ARR)((I), 
(J), (K) - 1, (COMP)) + (ARR)((I), (J), (K) + 1, (COMP))) + \ + INV_16*((ARR)((I), (J) - 1, (K) - 1, (COMP)) + \ + (ARR)((I), (J) + 1, (K) + 1, (COMP)) + \ + (ARR)((I), (J) - 1, (K) + 1, (COMP)) + \ + (ARR)((I), (J) + 1, (K) - 1, (COMP))) + +#define FILTER3D_IN_I1_I3(ARR, COMP, I, J, K) \ + INV_4*(ARR)(I, J, K, (COMP)) + \ + INV_8*((ARR)((I) - 1, (J), (K), (COMP)) + (ARR)((I) + 1, (J), (K), (COMP)) + \ + (ARR)((I), (J), (K) - 1, (COMP)) + (ARR)((I), (J), (K) + 1, (COMP))) + \ + INV_16*((ARR)((I) - 1, (J), (K) - 1, (COMP)) + \ + (ARR)((I) + 1, (J), (K) + 1, (COMP)) + \ + (ARR)((I) - 1, (J), (K) + 1, (COMP)) + \ + (ARR)((I) + 1, (J), (K) - 1, (COMP))) namespace kernel { using namespace ntt; @@ -28,31 +59,67 @@ namespace kernel { class DigitalFilter_kernel { ndfield_t array; const ndfield_t buffer; - bool is_axis_i2min { false }, is_axis_i2max { false }; - static constexpr auto i2_min = N_GHOSTS; - const std::size_t i2_max; + const bool is_axis_i2min, is_axis_i2max; + const bool is_conductor_i1min, is_conductor_i1max; + const bool is_conductor_i2min, is_conductor_i2max; + const bool is_conductor_i3min, is_conductor_i3max; + static constexpr auto i1_min = N_GHOSTS, i2_min = N_GHOSTS, i3_min = N_GHOSTS; + const ncells_t i1_max, i2_max, i3_max; + + // @TODO: Current implementation might have issues + // ... at the corners between two conductors public: DigitalFilter_kernel(ndfield_t& array, const ndfield_t& buffer, - const std::size_t (&size_)[D], + const ncells_t (&size_)[D], const boundaries_t& boundaries) : array { array } , buffer { buffer } - , i2_max { (short)D > 1 ? 
size_[1] + N_GHOSTS : 0 } { - if constexpr ((C != Coord::Cart) && (D != Dim::_1D)) { - raise::ErrorIf(boundaries.size() < 2, "boundaries defined incorrectly", HERE); - is_axis_i2min = (boundaries[1].first == FldsBC::AXIS); - is_axis_i2max = (boundaries[1].second == FldsBC::AXIS); - } - } + , is_axis_i2min { (D == Dim::_2D) and (boundaries[1].first == FldsBC::AXIS) } + , is_axis_i2max { (D == Dim::_2D) and (boundaries[1].second == FldsBC::AXIS) } + , is_conductor_i1min { boundaries[0].first == FldsBC::CONDUCTOR } + , is_conductor_i1max { boundaries[0].second == FldsBC::CONDUCTOR } + , is_conductor_i2min { (short)D > 1 + ? (boundaries[1].first == FldsBC::CONDUCTOR) + : false } + , is_conductor_i2max { (short)D > 1 + ? (boundaries[1].second == FldsBC::CONDUCTOR) + : false } + , is_conductor_i3min { (short)D > 2 + ? (boundaries[2].first == FldsBC::CONDUCTOR) + : false } + , is_conductor_i3max { (short)D > 2 + ? (boundaries[2].second == FldsBC::CONDUCTOR) + : false } + , i1_max { size_[0] + N_GHOSTS } + , i2_max { (short)D > 1 ? (size_[1] + N_GHOSTS) : 0 } + , i3_max { (short)D > 2 ? (size_[2] + N_GHOSTS) : 0 } {} Inline void operator()(index_t i1) const { if constexpr ((D == Dim::_1D) && (C == Coord::Cart)) { + if ((is_conductor_i1min and i1 == i1_min) or + (is_conductor_i1max and i1 == i1_max - 1)) { + const auto i1side = is_conductor_i1min ? (i1 + 1) : (i1 - 1); + array(i1, cur::jx1) = (THREE * INV_4) * buffer(i1, cur::jx1) + + (INV_4)*buffer(i1side, cur::jx1); + } else if ((is_conductor_i1min and i1 == i1_min + 1) or + (is_conductor_i1max and i1 == i1_max - 2)) { + const auto i1side = is_conductor_i1min ? 
(i1 + 1) : (i1 - 1); + array(i1, cur::jx1) = INV_2 * buffer(i1, cur::jx1) + + INV_4 * (buffer(i1 - 1, cur::jx1) + + buffer(i1 + 1, cur::jx1)); + array(i1, cur::jx2) = (INV_2)*buffer(i1, cur::jx2) + + (INV_4)*buffer(i1side, cur::jx2); + array(i1, cur::jx3) = (INV_2)*buffer(i1, cur::jx3) + + (INV_4)*buffer(i1side, cur::jx3); + } else { #pragma unroll - for (const auto& comp : { cur::jx1, cur::jx2, cur::jx3 }) { - array(i1, comp) = INV_2 * buffer(i1, comp) + - INV_4 * (buffer(i1 - 1, comp) + buffer(i1 + 1, comp)); + for (const auto& comp : { cur::jx1, cur::jx2, cur::jx3 }) { + array(i1, comp) = INV_2 * buffer(i1, comp) + + INV_4 * + (buffer(i1 - 1, comp) + buffer(i1 + 1, comp)); + } } } else { raise::KernelError(HERE, "DigitalFilter_kernel: 1D implementation called for D != 1 or for non-Cartesian metric"); @@ -62,25 +129,78 @@ namespace kernel { Inline void operator()(index_t i1, index_t i2) const { if constexpr (D == Dim::_2D) { if constexpr (C == Coord::Cart) { + if ((is_conductor_i1min and i1 == i1_min) or + (is_conductor_i1max and i1 == i1_max - 1)) { + const auto i1side = is_conductor_i1min ? (i1 + 1) : (i1 - 1); + array(i1, i2, cur::jx1) = + (THREE * INV_4) * (FILTER2D_IN_I2(buffer, cur::jx1, i1, i2)) + + (INV_4) * (FILTER2D_IN_I2(buffer, cur::jx1, i1side, i2)); + } else if ((is_conductor_i1min and i1 == i1_min + 1) or + (is_conductor_i1max and i1 == i1_max - 2)) { + const auto i1side = is_conductor_i1min ? 
(i1 + 1) : (i1 - 1); + array(i1, + i2, + cur::jx1) = INV_2 * (FILTER2D_IN_I2(buffer, cur::jx1, i1, i2)) + + INV_4 * + ((FILTER2D_IN_I2(buffer, cur::jx1, i1 - 1, i2)) + + (FILTER2D_IN_I2(buffer, cur::jx1, i1 + 1, i2))); + array(i1, + i2, + cur::jx2) = INV_2 * (FILTER2D_IN_I2(buffer, cur::jx2, i1, i2)) + + INV_4 * + (FILTER2D_IN_I2(buffer, cur::jx2, i1side, i2)); + array(i1, + i2, + cur::jx3) = INV_2 * (FILTER2D_IN_I2(buffer, cur::jx3, i1, i2)) + + INV_4 * + (FILTER2D_IN_I2(buffer, cur::jx3, i1side, i2)); + } else if ((is_conductor_i2min and i2 == i2_min) or + (is_conductor_i2max and i2 == i2_max - 1)) { + const auto i2side = is_conductor_i2min ? (i2 + 1) : (i2 - 1); + array(i1, i2, cur::jx2) = + (THREE * INV_4) * (FILTER2D_IN_I1(buffer, cur::jx2, i1, i2)) + + (INV_4) * (FILTER2D_IN_I1(buffer, cur::jx2, i1, i2side)); + } else if ((is_conductor_i2min and i2 == i2_min + 1) or + (is_conductor_i2max and i2 == i2_max - 2)) { + const auto i2side = is_conductor_i2min ? (i2 + 1) : (i2 - 1); + array(i1, + i2, + cur::jx1) = INV_2 * (FILTER2D_IN_I1(buffer, cur::jx1, i1, i2)) + + INV_4 * + (FILTER2D_IN_I1(buffer, cur::jx1, i1, i2side)); + array(i1, + i2, + cur::jx2) = INV_2 * (FILTER2D_IN_I1(buffer, cur::jx2, i1, i2)) + + INV_4 * + ((FILTER2D_IN_I1(buffer, cur::jx2, i1, i2 - 1)) + + (FILTER2D_IN_I1(buffer, cur::jx2, i1, i2 + 1))); + array(i1, + i2, + cur::jx3) = INV_2 * (FILTER2D_IN_I1(buffer, cur::jx3, i1, i2)) + + INV_4 * + (FILTER2D_IN_I1(buffer, cur::jx3, i1, i2side)); + } else { #pragma unroll - for (const auto comp : { cur::jx1, cur::jx2, cur::jx3 }) { - array(i1, i2, comp) = INV_4 * buffer(i1, i2, comp) + - INV_8 * (buffer(i1 - 1, i2, comp) + - buffer(i1 + 1, i2, comp) + - buffer(i1, i2 - 1, comp) + - buffer(i1, i2 + 1, comp)) + - INV_16 * (buffer(i1 - 1, i2 - 1, comp) + - buffer(i1 + 1, i2 + 1, comp) + - buffer(i1 - 1, i2 + 1, comp) + - buffer(i1 + 1, i2 - 1, comp)); + for (const auto comp : { cur::jx1, cur::jx2, cur::jx3 }) { + array(i1, i2, comp) = INV_4 * buffer(i1, 
i2, comp) + + INV_8 * (buffer(i1 - 1, i2, comp) + + buffer(i1 + 1, i2, comp) + + buffer(i1, i2 - 1, comp) + + buffer(i1, i2 + 1, comp)) + + INV_16 * (buffer(i1 - 1, i2 - 1, comp) + + buffer(i1 + 1, i2 + 1, comp) + + buffer(i1 - 1, i2 + 1, comp) + + buffer(i1 + 1, i2 - 1, comp)); + } } } else { // spherical + // @TODO: get rid of temporary variables real_t cur_00, cur_0p1, cur_0m1; if (is_axis_i2min && (i2 == i2_min)) { /* --------------------------------- r, phi --------------------------------- */ // ... filter in r - cur_00 = FILTER_IN_I1(buffer, cur::jx1, i1, i2); - cur_0p1 = FILTER_IN_I1(buffer, cur::jx1, i1, i2 + 1); + cur_00 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2); + cur_0p1 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2 + 1); // ... filter in theta array(i1, i2, cur::jx1) = INV_2 * cur_00 + INV_2 * cur_0p1; @@ -88,58 +208,58 @@ namespace kernel { /* ---------------------------------- theta --------------------------------- */ // ... filter in r - cur_00 = FILTER_IN_I1(buffer, cur::jx2, i1, i2); - cur_0p1 = FILTER_IN_I1(buffer, cur::jx2, i1, i2 + 1); + cur_00 = FILTER2D_IN_I1(buffer, cur::jx2, i1, i2); + cur_0p1 = FILTER2D_IN_I1(buffer, cur::jx2, i1, i2 + 1); // ... filter in theta array(i1, i2, cur::jx2) = INV_4 * (cur_00 + cur_0p1); } else if (is_axis_i2min && (i2 == i2_min + 1)) { /* --------------------------------- r, phi --------------------------------- */ // ... filter in r - cur_00 = FILTER_IN_I1(buffer, cur::jx1, i1, i2); - cur_0p1 = FILTER_IN_I1(buffer, cur::jx1, i1, i2 + 1); - cur_0m1 = FILTER_IN_I1(buffer, cur::jx1, i1, i2 - 1); + cur_00 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2); + cur_0p1 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2 + 1); + cur_0m1 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2 - 1); // ... filter in theta array(i1, i2, cur::jx1) = INV_2 * cur_00 + INV_4 * (cur_0p1 + cur_0m1); // ... 
filter in r - cur_00 = FILTER_IN_I1(buffer, cur::jx3, i1, i2); - cur_0p1 = FILTER_IN_I1(buffer, cur::jx3, i1, i2 + 1); + cur_00 = FILTER2D_IN_I1(buffer, cur::jx3, i1, i2); + cur_0p1 = FILTER2D_IN_I1(buffer, cur::jx3, i1, i2 + 1); // ... filter in theta array(i1, i2, cur::jx3) = INV_2 * cur_00 + INV_4 * cur_0p1; /* ---------------------------------- theta --------------------------------- */ // ... filter in r - cur_00 = FILTER_IN_I1(buffer, cur::jx2, i1, i2); - cur_0p1 = FILTER_IN_I1(buffer, cur::jx2, i1, i2 + 1); - cur_0m1 = FILTER_IN_I1(buffer, cur::jx2, i1, i2 - 1); + cur_00 = FILTER2D_IN_I1(buffer, cur::jx2, i1, i2); + cur_0p1 = FILTER2D_IN_I1(buffer, cur::jx2, i1, i2 + 1); + cur_0m1 = FILTER2D_IN_I1(buffer, cur::jx2, i1, i2 - 1); // ... filter in theta array(i1, i2, cur::jx2) = INV_2 * cur_00 + INV_4 * (cur_0m1 + cur_0p1); } else if (is_axis_i2max && (i2 == i2_max - 1)) { /* --------------------------------- r, phi --------------------------------- */ // ... filter in r - cur_00 = FILTER_IN_I1(buffer, cur::jx1, i1, i2); - cur_0p1 = FILTER_IN_I1(buffer, cur::jx1, i1, i2 + 1); - cur_0m1 = FILTER_IN_I1(buffer, cur::jx1, i1, i2 - 1); + cur_00 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2); + cur_0p1 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2 + 1); + cur_0m1 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2 - 1); // ... filter in theta array(i1, i2, cur::jx1) = INV_2 * cur_00 + INV_4 * (cur_0m1 + cur_0p1); // ... filter in r - cur_00 = FILTER_IN_I1(buffer, cur::jx3, i1, i2); - cur_0m1 = FILTER_IN_I1(buffer, cur::jx3, i1, i2 - 1); + cur_00 = FILTER2D_IN_I1(buffer, cur::jx3, i1, i2); + cur_0m1 = FILTER2D_IN_I1(buffer, cur::jx3, i1, i2 - 1); // ... filter in theta array(i1, i2, cur::jx3) = INV_2 * cur_00 + INV_4 * cur_0m1; /* ---------------------------------- theta --------------------------------- */ // ... 
filter in r - cur_00 = FILTER_IN_I1(buffer, cur::jx2, i1, i2); - cur_0m1 = FILTER_IN_I1(buffer, cur::jx2, i1, i2 - 1); + cur_00 = FILTER2D_IN_I1(buffer, cur::jx2, i1, i2); + cur_0m1 = FILTER2D_IN_I1(buffer, cur::jx2, i1, i2 - 1); // ... filter in theta array(i1, i2, cur::jx2) = INV_4 * (cur_00 + cur_0m1); } else if (is_axis_i2max && (i2 == i2_max)) { /* --------------------------------- r, phi --------------------------------- */ // ... filter in r - cur_00 = FILTER_IN_I1(buffer, cur::jx1, i1, i2); - cur_0m1 = FILTER_IN_I1(buffer, cur::jx1, i1, i2 - 1); + cur_00 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2); + cur_0m1 = FILTER2D_IN_I1(buffer, cur::jx1, i1, i2 - 1); // ... filter in theta array(i1, i2, cur::jx1) = INV_2 * cur_00 + INV_2 * cur_0m1; @@ -169,33 +289,93 @@ namespace kernel { Inline void operator()(index_t i1, index_t i2, index_t i3) const { if constexpr (D == Dim::_3D) { if constexpr (C == Coord::Cart) { + if ((is_conductor_i1min and i1 == i1_min) or + (is_conductor_i1max and i1 == i1_max - 1)) { + const auto i1side = is_conductor_i1min ? (i1 + 1) : (i1 - 1); + array(i1, i2, i3, cur::jx1) = + (THREE * INV_4) * (FILTER3D_IN_I2_I3(buffer, cur::jx1, i1, i2, i3)) + + (INV_4) * (FILTER3D_IN_I2_I3(buffer, cur::jx1, i1side, i2, i3)); + } else if ((is_conductor_i1min and i1 == i1_min + 1) or + (is_conductor_i1max and i1 == i1_max - 2)) { + const auto i1side = is_conductor_i1min ? 
(i1 + 1) : (i1 - 1); + array(i1, i2, i3, cur::jx1) = + INV_2 * (FILTER3D_IN_I2_I3(buffer, cur::jx1, i1, i2, i3)) + + INV_4 * ((FILTER3D_IN_I2_I3(buffer, cur::jx1, i1 - 1, i2, i3)) + + (FILTER3D_IN_I2_I3(buffer, cur::jx1, i1 + 1, i2, i3))); + array(i1, i2, i3, cur::jx2) = + INV_2 * (FILTER3D_IN_I2_I3(buffer, cur::jx2, i1, i2, i3)) + + INV_4 * (FILTER3D_IN_I2_I3(buffer, cur::jx2, i1side, i2, i3)); + array(i1, i2, i3, cur::jx3) = + INV_2 * (FILTER3D_IN_I2_I3(buffer, cur::jx3, i1, i2, i3)) + + INV_4 * (FILTER3D_IN_I2_I3(buffer, cur::jx3, i1side, i2, i3)); + } else if ((is_conductor_i2min and i2 == i2_min) or + (is_conductor_i2max and i2 == i2_max - 1)) { + const auto i2side = is_conductor_i2min ? (i2 + 1) : (i2 - 1); + array(i1, i2, i3, cur::jx2) = + (THREE * INV_4) * (FILTER3D_IN_I1_I3(buffer, cur::jx2, i1, i2, i3)) + + (INV_4) * (FILTER3D_IN_I1_I3(buffer, cur::jx2, i1, i2side, i3)); + } else if ((is_conductor_i2min and i2 == i2_min + 1) or + (is_conductor_i2max and i2 == i2_max - 2)) { + const auto i2side = is_conductor_i2min ? (i2 + 1) : (i2 - 1); + array(i1, i2, i3, cur::jx1) = + INV_2 * (FILTER3D_IN_I1_I3(buffer, cur::jx1, i1, i2, i3)) + + INV_4 * (FILTER3D_IN_I1_I3(buffer, cur::jx1, i1, i2side, i3)); + array(i1, i2, i3, cur::jx2) = + INV_2 * (FILTER3D_IN_I1_I3(buffer, cur::jx2, i1, i2, i3)) + + INV_4 * ((FILTER3D_IN_I1_I3(buffer, cur::jx2, i1, i2 - 1, i3)) + + (FILTER3D_IN_I1_I3(buffer, cur::jx2, i1, i2 + 1, i3))); + array(i1, i2, i3, cur::jx3) = + INV_2 * (FILTER3D_IN_I1_I3(buffer, cur::jx3, i1, i2, i3)) + + INV_4 * (FILTER3D_IN_I1_I3(buffer, cur::jx3, i1, i2side, i3)); + } else if ((is_conductor_i3min and i3 == i3_min) or + (is_conductor_i3max and i3 == i3_max - 1)) { + const auto i3side = is_conductor_i3min ? 
(i3 + 1) : (i3 - 1); + array(i1, i2, i3, cur::jx3) = + (THREE * INV_4) * (FILTER3D_IN_I1_I2(buffer, cur::jx3, i1, i2, i3)) + + (INV_4) * (FILTER3D_IN_I1_I2(buffer, cur::jx3, i1, i2, i3side)); + } else if ((is_conductor_i3min and i3 == i3_min + 1) or + (is_conductor_i3max and i3 == i3_max - 2)) { + const auto i3side = is_conductor_i3min ? (i3 + 1) : (i3 - 1); + array(i1, i2, i3, cur::jx1) = + INV_2 * (FILTER3D_IN_I1_I2(buffer, cur::jx1, i1, i2, i3)) + + INV_4 * (FILTER3D_IN_I1_I2(buffer, cur::jx1, i1, i2, i3side)); + array(i1, i2, i3, cur::jx2) = + INV_2 * (FILTER3D_IN_I1_I2(buffer, cur::jx2, i1, i2, i3)) + + INV_4 * (FILTER3D_IN_I1_I2(buffer, cur::jx2, i1, i2, i3side)); + array(i1, i2, i3, cur::jx3) = + INV_2 * (FILTER3D_IN_I1_I2(buffer, cur::jx3, i1, i2, i3)) + + INV_4 * ((FILTER3D_IN_I1_I2(buffer, cur::jx3, i1, i2, i3 - 1)) + + (FILTER3D_IN_I1_I2(buffer, cur::jx3, i1, i2, i3 + 1))); + } else { #pragma unroll - for (auto& comp : { cur::jx1, cur::jx2, cur::jx3 }) { - array(i1, i2, i3, comp) = - INV_8 * buffer(i1, i2, i3, comp) + - INV_16 * - (buffer(i1 - 1, i2, i3, comp) + buffer(i1 + 1, i2, i3, comp) + - buffer(i1, i2 - 1, i3, comp) + buffer(i1, i2 + 1, i3, comp) + - buffer(i1, i2, i3 - 1, comp) + buffer(i1, i2, i3 + 1, comp)) + - INV_32 * - (buffer(i1 - 1, i2 - 1, i3, comp) + - buffer(i1 + 1, i2 + 1, i3, comp) + - buffer(i1 - 1, i2 + 1, i3, comp) + - buffer(i1 + 1, i2 - 1, i3, comp) + - buffer(i1, i2 - 1, i3 - 1, comp) + - buffer(i1, i2 + 1, i3 + 1, comp) + buffer(i1, i2, i3 - 1, comp) + - buffer(i1, i2, i3 + 1, comp) + buffer(i1 - 1, i2, i3 - 1, comp) + - buffer(i1 + 1, i2, i3 + 1, comp) + - buffer(i1 - 1, i2, i3 + 1, comp) + - buffer(i1 + 1, i2, i3 - 1, comp)) + - INV_64 * (buffer(i1 - 1, i2 - 1, i3 - 1, comp) + - buffer(i1 + 1, i2 + 1, i3 + 1, comp) + - buffer(i1 - 1, i2 + 1, i3 + 1, comp) + - buffer(i1 + 1, i2 - 1, i3 - 1, comp) + - buffer(i1 - 1, i2 - 1, i3 + 1, comp) + - buffer(i1 + 1, i2 + 1, i3 - 1, comp) + - buffer(i1 - 1, i2 + 1, i3 - 1, comp) + - 
buffer(i1 + 1, i2 - 1, i3 + 1, comp)); + for (auto& comp : { cur::jx1, cur::jx2, cur::jx3 }) { + array(i1, i2, i3, comp) = + INV_8 * buffer(i1, i2, i3, comp) + + INV_16 * + (buffer(i1 - 1, i2, i3, comp) + buffer(i1 + 1, i2, i3, comp) + + buffer(i1, i2 - 1, i3, comp) + buffer(i1, i2 + 1, i3, comp) + + buffer(i1, i2, i3 - 1, comp) + buffer(i1, i2, i3 + 1, comp)) + + INV_32 * + (buffer(i1 - 1, i2 - 1, i3, comp) + + buffer(i1 + 1, i2 + 1, i3, comp) + + buffer(i1 - 1, i2 + 1, i3, comp) + + buffer(i1 + 1, i2 - 1, i3, comp) + + buffer(i1, i2 - 1, i3 - 1, comp) + + buffer(i1, i2 + 1, i3 + 1, comp) + + buffer(i1, i2, i3 - 1, comp) + buffer(i1, i2, i3 + 1, comp) + + buffer(i1 - 1, i2, i3 - 1, comp) + + buffer(i1 + 1, i2, i3 + 1, comp) + + buffer(i1 - 1, i2, i3 + 1, comp) + + buffer(i1 + 1, i2, i3 - 1, comp)) + + INV_64 * (buffer(i1 - 1, i2 - 1, i3 - 1, comp) + + buffer(i1 + 1, i2 + 1, i3 + 1, comp) + + buffer(i1 - 1, i2 + 1, i3 + 1, comp) + + buffer(i1 + 1, i2 - 1, i3 - 1, comp) + + buffer(i1 - 1, i2 - 1, i3 + 1, comp) + + buffer(i1 + 1, i2 + 1, i3 - 1, comp) + + buffer(i1 - 1, i2 + 1, i3 - 1, comp) + + buffer(i1 + 1, i2 - 1, i3 + 1, comp)); + } } } else { raise::KernelNotImplementedError(HERE); @@ -210,6 +390,11 @@ namespace kernel { } // namespace kernel -#undef FILTER_IN_I1 +#undef FILTER3D_IN_I1_I3 +#undef FILTER3D_IN_I2_I3 +#undef FILTER3D_IN_I1_I2 + +#undef FILTER2D_IN_I2 +#undef FILTER2D_IN_I1 #endif // DIGITAL_FILTER_HPP diff --git a/src/kernels/divergences.hpp b/src/kernels/divergences.hpp new file mode 100644 index 000000000..c60be564b --- /dev/null +++ b/src/kernels/divergences.hpp @@ -0,0 +1,123 @@ +/** + * @file kernels/divergences.hpp + * @brief Compute covariant divergences of fields + * @implements + * - kernel::ComputeDivergence_kernel<> + * @namespaces: + * - kernel:: + */ + +#ifndef KERNELS_DIVERGENCES_HPP +#define KERNELS_DIVERGENCES_HPP + +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/error.h" + +namespace kernel { + using 
namespace ntt; + + // @TODO: take care of boundaries + template + class ComputeDivergence_kernel { + const M metric; + + const ndfield_t fields; + ndfield_t buff; + const idx_t buff_idx; + + public: + ComputeDivergence_kernel(const M& metric, + const ndfield_t& fields, + ndfield_t& buff, + idx_t buff_idx) + : metric { metric } + , fields { fields } + , buff { buff } + , buff_idx { buff_idx } { + raise::ErrorIf(buff_idx >= N, "Invalid component index", HERE); + } + + Inline void operator()(index_t i1) const { + if constexpr (M::Dim == Dim::_1D) { + if constexpr (M::CoordType == Coord::Cart) { + buff(i1, buff_idx) = fields(i1, em::ex1) - fields(i1 - 1, em::ex1); + } else { + const auto i1_ = COORD(i1); + buff(i1, buff_idx) = (fields(i1, em::ex1) * + metric.sqrt_det_h({ i1_ + HALF }) - + fields(i1 - 1, em::ex1) * + metric.sqrt_det_h({ i1_ - HALF })) / + metric.sqrt_det_h({ i1_ }); + } + } else { + raise::KernelError( + HERE, + "1D implementation of ComputeDivergence_kernel called for non-1D"); + } + } + + Inline void operator()(index_t i1, index_t i2) const { + if constexpr (M::Dim == Dim::_2D) { + if constexpr (M::CoordType == Coord::Cart) { + buff(i1, i2, buff_idx) = fields(i1, i2, em::ex1) - + fields(i1 - 1, i2, em::ex1) + + fields(i1, i2, em::ex2) - + fields(i1, i2 - 1, em::ex2); + } else { + const auto i1_ = COORD(i1); + const auto i2_ = COORD(i2); + buff(i1, i2, buff_idx) = + (fields(i1, i2, em::ex1) * metric.sqrt_det_h({ i1_ + HALF, i2_ }) - + fields(i1 - 1, i2, em::ex1) * metric.sqrt_det_h({ i1_ - HALF, i2_ }) + + fields(i1, i2, em::ex2) * metric.sqrt_det_h({ i1_, i2_ + HALF }) - + fields(i1, i2 - 1, em::ex2) * metric.sqrt_det_h({ i1_, i2_ - HALF })) / + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF }); + } + } else { + raise::KernelError( + HERE, + "2D implementation of ComputeDivergence_kernel called for non-2D"); + } + } + + Inline void operator()(index_t i1, index_t i2, index_t i3) const { + if constexpr (M::Dim == Dim::_3D) { + if constexpr (M::CoordType == 
Coord::Cart) { + buff(i1, i2, i3, buff_idx) = fields(i1, i2, i3, em::ex1) - + fields(i1 - 1, i2, i3, em::ex1) + + fields(i1, i2, i3, em::ex2) - + fields(i1, i2 - 1, i3, em::ex2) + + fields(i1, i2, i3, em::ex3) - + fields(i1, i2, i3 - 1, em::ex3); + } else { + const auto i1_ = COORD(i1); + const auto i2_ = COORD(i2); + const auto i3_ = COORD(i3); + buff(i1, i2, i3, buff_idx) = + (fields(i1, i2, i3, em::ex1) * + metric.sqrt_det_h({ i1_ + HALF, i2_, i3_ }) - + fields(i1 - 1, i2, i3, em::ex1) * + metric.sqrt_det_h({ i1_ - HALF, i2_, i3_ }) + + fields(i1, i2, i3, em::ex2) * + metric.sqrt_det_h({ i1_, i2_ + HALF, i3_ }) - + fields(i1, i2 - 1, i3, em::ex2) * + metric.sqrt_det_h({ i1_, i2_ - HALF, i3_ }) + + fields(i1, i2, i3, em::ex3) * + metric.sqrt_det_h({ i1_, i2_, i3_ + HALF }) - + fields(i1, i2, i3 - 1, em::ex3) * + metric.sqrt_det_h({ i1_, i2_, i3_ - HALF })) / + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF, i3_ + HALF }); + } + } else { + raise::KernelError( + HERE, + "3D implementation of ComputeDivergence_kernel called for non-3D"); + } + } + }; + +} // namespace kernel + +#endif // KERNELS_DIVERGENCES_HPP diff --git a/src/kernels/faraday_gr.hpp b/src/kernels/faraday_gr.hpp index 19eede5f2..b79afa460 100644 --- a/src/kernels/faraday_gr.hpp +++ b/src/kernels/faraday_gr.hpp @@ -36,7 +36,7 @@ namespace kernel::gr { ndfield_t Bout; const ndfield_t E; const M metric; - const std::size_t i2max; + const ncells_t i2max; const real_t coeff; bool is_axis_i2min { false }; @@ -73,7 +73,9 @@ namespace kernel::gr { Bout(i1, i2, em::bx1) = Bin(i1, i2, em::bx1) + coeff * inv_sqrt_detH_0pH * (E(i1, i2, em::ex3) - E(i1, i2 + 1, em::ex3)); - if ((i2 != i2min) || !is_axis_i2min) { + if ((i2 == i2min) && is_axis_i2min) { + Bout(i1, i2, em::bx2) = ZERO; + } else { const real_t inv_sqrt_detH_pH0 { ONE / metric.sqrt_det_h( { i1_ + HALF, i2_ }) }; Bout(i1, i2, em::bx2) = Bin(i1, i2, em::bx2) + diff --git a/src/kernels/faraday_sr.hpp b/src/kernels/faraday_sr.hpp index 727974248..be3e60dd0 
100644 --- a/src/kernels/faraday_sr.hpp +++ b/src/kernels/faraday_sr.hpp @@ -52,9 +52,9 @@ namespace kernel::sr { Inline void operator()(index_t i1, index_t i2) const { if constexpr (D == Dim::_2D) { - constexpr std::size_t i2min { N_GHOSTS }; - const real_t i1_ { COORD(i1) }; - const real_t i2_ { COORD(i2) }; + constexpr ncells_t i2min { N_GHOSTS }; + const real_t i1_ { COORD(i1) }; + const real_t i2_ { COORD(i2) }; const real_t inv_sqrt_detH_0pH { ONE / metric.sqrt_det_h({ i1_, i2_ + HALF }) }; diff --git a/src/kernels/fields_bcs.hpp b/src/kernels/fields_bcs.hpp index e617010b4..5a1074970 100644 --- a/src/kernels/fields_bcs.hpp +++ b/src/kernels/fields_bcs.hpp @@ -1,10 +1,22 @@ /** - * @brief: kernels/fields_bcs.hpp + * @file kernels/fields_bcs.hpp + * @brief Kernels used for field boundary conditions + * @implements + * - kernel::bc::MatchBoundaries_kernel<> + * - kernel::bc::AxisBoundaries_kernel<> + * - kernel::bc::AxisBoundariesGR_kernel<> + * - kernel::bc::AbsorbCurrentsGR_kernel<> + * - kernel::bc::EnforcedBoundaries_kernel<> + * - kernel::bc::HorizonBoundaries_kernel<> + * - kernel::bc::ConductorBoundaries_kernel<> + * @namespaces: + * - kernel::bc:: */ #ifndef KERNELS_FIELDS_BCS_HPP #define KERNELS_FIELDS_BCS_HPP +#include "enums.h" #include "global.h" #include "arch/kokkos_aliases.h" @@ -12,64 +24,158 @@ #include "utils/error.h" #include "utils/numeric.h" -namespace kernel { +namespace kernel::bc { using namespace ntt; - template - struct AbsorbBoundaries_kernel { + /* + * @tparam S: Simulation Engine + * @tparam I: Field Setter class + * @tparam M: Metric + * @tparam o: Orientation + * + * @brief Applies matching boundary conditions (with a smooth profile) in a specific direction. + * @note If a component is not specified in the field setter, it is ignored. + * @note It is supposed to only be called on the active side of the absorbing edge (so sign is not needed). 
+ */ + template + struct MatchBoundaries_kernel { static_assert(M::is_metric, "M must be a metric class"); - static_assert(i <= static_cast(M::Dim), + static_assert(static_cast(o) < static_cast(M::Dim), "Invalid component index"); + static constexpr auto D = M::Dim; + static constexpr idx_t i = static_cast(o) + 1u; + static constexpr bool defines_dx1 = traits::has_method::value; + static constexpr bool defines_dx2 = traits::has_method::value; + static constexpr bool defines_dx3 = traits::has_method::value; + static constexpr bool defines_ex1 = traits::has_method::value; + static constexpr bool defines_ex2 = traits::has_method::value; + static constexpr bool defines_ex3 = traits::has_method::value; + static constexpr bool defines_bx1 = traits::has_method::value; + static constexpr bool defines_bx2 = traits::has_method::value; + static constexpr bool defines_bx3 = traits::has_method::value; + static_assert( + (S == SimEngine::SRPIC and (defines_ex1 or defines_ex2 or defines_ex3 or + defines_bx1 or defines_bx2 or defines_bx3)) or + ((S == SimEngine::GRPIC) and (defines_dx1 or defines_dx2 or defines_dx3 or + defines_bx1 or defines_bx2 or defines_bx3)), + "none of the components of E/D or B are specified in PGEN"); ndfield_t Fld; + const I fset; const M metric; const real_t xg_edge; const real_t dx_abs; const BCTags tags; - AbsorbBoundaries_kernel(ndfield_t Fld, - const M& metric, - real_t xg_edge, - real_t dx_abs, - BCTags tags) + ncells_t extent_2 { 0u }; + bool is_axis_i2min { false }, is_axis_i2max { false }; + + MatchBoundaries_kernel(ndfield_t Fld, + const I& fset, + const M& metric, + real_t xg_edge, + real_t dx_abs, + BCTags tags, + const boundaries_t& boundaries) : Fld { Fld } + , fset { fset } , metric { metric } , xg_edge { xg_edge } , dx_abs { dx_abs } - , tags { tags } {} + , tags { tags } { + if constexpr ((M::CoordType != Coord::Cart) && + ((D == Dim::_2D) || (D == Dim::_3D))) { + raise::ErrorIf(boundaries.size() < 2, "boundaries defined incorrectly", 
HERE); + is_axis_i2min = (boundaries[1].first == FldsBC::AXIS); + is_axis_i2max = (boundaries[1].second == FldsBC::AXIS); + extent_2 = static_cast(Fld.extent(1)); + } + } + + Inline auto shape(const real_t& dx) const -> real_t { + return math::tanh(dx * FOUR / dx_abs); + } Inline void operator()(index_t i1) const { if constexpr (M::Dim == Dim::_1D) { const auto i1_ = COORD(i1); - for (const auto comp : - { em::ex1, em::ex2, em::ex3, em::bx1, em::bx2, em::bx3 }) { - if ((comp == em::ex1) and not(tags & BC::Ex1)) { - continue; - } else if ((comp == em::ex2) and not(tags & BC::Ex2)) { - continue; - } else if ((comp == em::ex3) and not(tags & BC::Ex3)) { - continue; - } else if ((comp == em::bx1) and not(tags & BC::Bx1)) { - continue; - } else if ((comp == em::bx2) and not(tags & BC::Bx2)) { - continue; - } else if ((comp == em::bx3) and not(tags & BC::Bx3)) { - continue; + + if constexpr (S == SimEngine::SRPIC) { + coord_t x_Ph_0 { ZERO }; + coord_t x_Ph_H { ZERO }; + metric.template convert({ i1_ }, x_Ph_0); + metric.template convert({ i1_ + HALF }, x_Ph_H); + + if constexpr (defines_ex1 or defines_bx2 or defines_bx3) { + const auto s = shape(math::abs( + metric.template convert(i1_ + HALF) - xg_edge)); + if constexpr (defines_ex1) { + if (tags & BC::E) { + Fld(i1, em::ex1) = s * Fld(i1, em::ex1) + + (ONE - s) * + metric.template transform<1, Idx::T, Idx::U>( + { i1_ + HALF }, + fset.ex1(x_Ph_H)); + } + } + if constexpr (defines_bx2 or defines_bx3) { + if (tags & BC::B) { + if constexpr (defines_bx2) { + Fld(i1, em::bx2) = s * Fld(i1, em::bx2) + + (ONE - s) * + metric.template transform<2, Idx::T, Idx::U>( + { i1_ + HALF }, + fset.bx2(x_Ph_H)); + } + if constexpr (defines_bx3) { + Fld(i1, em::bx3) = s * Fld(i1, em::bx3) + + (ONE - s) * + metric.template transform<3, Idx::T, Idx::U>( + { i1_ + HALF }, + fset.bx3(x_Ph_H)); + } + } + } } - coord_t x_Cd { ZERO }; - if (comp == em::ex1 or comp == em::bx2 or comp == em::bx3) { - x_Cd[0] = i1_ + HALF; - } else if (comp == 
em::ex2 or comp == em::bx1 or comp == em::ex3) { - x_Cd[0] = i1_; + if constexpr (defines_bx1 or defines_ex2 or defines_ex3) { + const auto s = shape(math::abs( + metric.template convert(i1_) - xg_edge)); + if constexpr (defines_bx1) { + if (tags & BC::B) { + Fld(i1, em::bx1) = s * Fld(i1, em::bx1) + + (ONE - s) * + metric.template transform<1, Idx::T, Idx::U>( + { i1_ }, + fset.bx1(x_Ph_0)); + } + } + if constexpr (defines_ex2 or defines_ex3) { + if (tags & BC::E) { + if constexpr (defines_ex2) { + Fld(i1, em::ex2) = s * Fld(i1, em::ex2) + + (ONE - s) * + metric.template transform<2, Idx::T, Idx::U>( + { i1_ }, + fset.ex2(x_Ph_0)); + } + if constexpr (defines_ex3) { + Fld(i1, em::ex3) = s * Fld(i1, em::ex3) + + (ONE - s) * + metric.template transform<3, Idx::T, Idx::U>( + { i1_ }, + fset.ex3(x_Ph_0)); + } + } + } } - const auto dx = math::abs( - metric.template convert(x_Cd[i - 1]) - xg_edge); - Fld(i1, comp) *= math::tanh(dx / (INV_4 * dx_abs)); + } else { + // GRPIC + raise::KernelError(HERE, "1D GRPIC not implemented"); } } else { raise::KernelError( HERE, - "AbsorbFields_kernel: 1D implementation called for D != 1"); + "MatchBoundaries_kernel: 1D implementation called for D != 1"); } } @@ -77,43 +183,168 @@ namespace kernel { if constexpr (M::Dim == Dim::_2D) { const auto i1_ = COORD(i1); const auto i2_ = COORD(i2); - for (const auto comp : - { em::ex1, em::ex2, em::ex3, em::bx1, em::bx2, em::bx3 }) { - if ((comp == em::ex1) and not(tags & BC::Ex1)) { - continue; - } else if ((comp == em::ex2) and not(tags & BC::Ex2)) { - continue; - } else if ((comp == em::ex3) and not(tags & BC::Ex3)) { - continue; - } else if ((comp == em::bx1) and not(tags & BC::Bx1)) { - continue; - } else if ((comp == em::bx2) and not(tags & BC::Bx2)) { - continue; - } else if ((comp == em::bx3) and not(tags & BC::Bx3)) { - continue; + + // SRPIC + if constexpr (defines_ex1 or defines_dx1 or defines_bx2) { + // i1 + 1/2, i2 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_ + 
HALF; + } else { + xi_Cd = i2_; } - coord_t x_Cd { ZERO }; - if (comp == em::ex1 or comp == em::bx2) { - x_Cd[0] = i1_ + HALF; - x_Cd[1] = i2_; - } else if (comp == em::ex2 or comp == em::bx1) { - x_Cd[0] = i1_; - x_Cd[1] = i2_ + HALF; - } else if (comp == em::ex3) { - x_Cd[0] = i1_; - x_Cd[1] = i2_; - } else if (comp == em::bx3) { - x_Cd[0] = i1_ + HALF; - x_Cd[1] = i2_ + HALF; + + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_H0 { ZERO }; + metric.template convert({ i1_ + HALF, i2_ }, x_Ph_H0); + + if constexpr (defines_ex1 or defines_dx1) { + if ((tags & BC::E) or (tags & BC::D)) { + if constexpr (defines_ex1 and S == SimEngine::SRPIC) { + Fld(i1, i2, em::ex1) = s * Fld(i1, i2, em::ex1) + + (ONE - s) * + metric.template transform<1, Idx::T, Idx::U>( + { i1_ + HALF, i2_ }, + fset.ex1(x_Ph_H0)); + } else if constexpr (defines_dx1 and S == SimEngine::GRPIC) { + Fld(i1, i2, em::dx1) = s * Fld(i1, i2, em::dx1) + + (ONE - s) * fset.dx1(x_Ph_H0); + } + } + } + + if constexpr (defines_bx2) { + if (tags & BC::B) { + if constexpr (S == SimEngine::SRPIC) { + Fld(i1, i2, em::bx2) = s * Fld(i1, i2, em::bx2) + + (ONE - s) * + metric.template transform<2, Idx::T, Idx::U>( + { i1_ + HALF, i2_ }, + fset.bx2(x_Ph_H0)); + } else if constexpr (S == SimEngine::GRPIC) { + Fld(i1, i2, em::bx2) = s * Fld(i1, i2, em::bx2) + + (ONE - s) * fset.bx2(x_Ph_H0); + } + } + } + } + + if constexpr (defines_ex2 or defines_dx2 or defines_bx1) { + // i1, i2 + 1/2 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_; + } else { + xi_Cd = i2_ + HALF; + } + + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_0H { ZERO }; + metric.template convert({ i1_, i2_ + HALF }, x_Ph_0H); + + if constexpr (defines_ex2 or defines_dx2) { + if ((tags & BC::E) or (tags & BC::D)) { + if constexpr (defines_ex2 and S == SimEngine::SRPIC) { + Fld(i1, i2, em::ex2) = s * Fld(i1, i2, em::ex2) + + (ONE - s) * + 
metric.template transform<2, Idx::T, Idx::U>( + { i1_, i2_ + HALF }, + fset.ex2(x_Ph_0H)); + } else if constexpr (defines_dx2 and S == SimEngine::GRPIC) { + Fld(i1, i2, em::dx2) = s * Fld(i1, i2, em::dx2) + + (ONE - s) * fset.dx2(x_Ph_0H); + } + } + } + + if constexpr (defines_bx1) { + if (tags & BC::B) { + if constexpr (S == SimEngine::SRPIC) { + Fld(i1, i2, em::bx1) = s * Fld(i1, i2, em::bx1) + + (ONE - s) * + metric.template transform<1, Idx::T, Idx::U>( + { i1_, i2_ + HALF }, + fset.bx1(x_Ph_0H)); + } else if constexpr (S == SimEngine::GRPIC) { + Fld(i1, i2, em::bx1) = s * Fld(i1, i2, em::bx1) + + (ONE - s) * fset.bx1(x_Ph_0H); + } + } + } + } + + if constexpr (defines_ex3 or defines_dx3) { + if (tags & BC::E) { + // i1, i2 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_; + } else { + xi_Cd = i2_; + } + + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_00 { ZERO }; + metric.template convert({ i1_, i2_ }, x_Ph_00); + + if constexpr (defines_ex3 and S == SimEngine::SRPIC) { + Fld(i1, i2, em::ex3) = s * Fld(i1, i2, em::ex3); + if ((!is_axis_i2min or (i2 > N_GHOSTS)) and + (!is_axis_i2max or (i2 < extent_2 - N_GHOSTS))) { + Fld(i1, i2, em::ex3) += (ONE - s) * + metric.template transform<3, Idx::T, Idx::U>( + { i1_, i2_ }, + fset.ex3(x_Ph_00)); + } + } else if constexpr (defines_dx3 and S == SimEngine::GRPIC) { + Fld(i1, i2, em::dx3) = s * Fld(i1, i2, em::dx3); + if ((!is_axis_i2min or (i2 > N_GHOSTS)) and + (!is_axis_i2max or (i2 < extent_2 - N_GHOSTS))) { + Fld(i1, i2, em::dx3) += (ONE - s) * fset.dx3(x_Ph_00); + } + } + } + } + + if constexpr (defines_bx3) { + if (tags & BC::B) { + // i1 + 1/2, i2 + 1/2 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_ + HALF; + } else { + xi_Cd = i2_ + HALF; + } + + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_HH { ZERO }; + metric.template convert({ i1_ + HALF, i2_ + HALF }, + x_Ph_HH); + + if constexpr (S == 
SimEngine::SRPIC) { + Fld(i1, i2, em::bx3) = s * Fld(i1, i2, em::bx3) + + (ONE - s) * + metric.template transform<3, Idx::T, Idx::U>( + { i1_ + HALF, i2_ + HALF }, + fset.bx3(x_Ph_HH)); + } else if constexpr (S == SimEngine::GRPIC) { + Fld(i1, i2, em::bx3) = s * Fld(i1, i2, em::bx3) + + (ONE - s) * fset.bx3(x_Ph_HH); + } } - const auto dx = math::abs( - metric.template convert(x_Cd[i - 1]) - xg_edge); - Fld(i1, i2, comp) *= math::tanh(dx / (INV_4 * dx_abs)); } } else { raise::KernelError( HERE, - "AbsorbFields_kernel: 2D implementation called for D != 2"); + "MatchBoundaries_kernel: 2D implementation called for D != 2"); } } @@ -122,66 +353,486 @@ namespace kernel { const auto i1_ = COORD(i1); const auto i2_ = COORD(i2); const auto i3_ = COORD(i3); - for (const auto comp : - { em::ex1, em::ex2, em::ex3, em::bx1, em::bx2, em::bx3 }) { - if ((comp == em::ex1) and not(tags & BC::Ex1)) { - continue; - } else if ((comp == em::ex2) and not(tags & BC::Ex2)) { - continue; - } else if ((comp == em::ex3) and not(tags & BC::Ex3)) { - continue; - } else if ((comp == em::bx1) and not(tags & BC::Bx1)) { - continue; - } else if ((comp == em::bx2) and not(tags & BC::Bx2)) { - continue; - } else if ((comp == em::bx3) and not(tags & BC::Bx3)) { - continue; + + if constexpr (S == SimEngine::SRPIC) { + // SRPIC + if constexpr (defines_ex1 or defines_ex2 or defines_ex3) { + if (tags & BC::E) { + if constexpr (defines_ex1) { + // i1 + 1/2, i2, i3 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_ + HALF; + } else if constexpr (o == in::x2) { + xi_Cd = i2_; + } else { + xi_Cd = i3_; + } + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_H00 { ZERO }; + metric.template convert({ i1_ + HALF, i2_, i3_ }, + x_Ph_H00); + + Fld(i1, i2, i3, em::ex1) = + s * Fld(i1, i2, i3, em::ex1) + + (ONE - s) * metric.template transform<1, Idx::T, Idx::U>( + { i1_ + HALF, i2_, i3_ }, + fset.ex1(x_Ph_H00)); + } + + if constexpr (defines_ex2) { + // i1, 
i2 + 1/2, i3 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_; + } else if constexpr (o == in::x2) { + xi_Cd = i2_ + HALF; + } else { + xi_Cd = i3_; + } + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_0H0 { ZERO }; + metric.template convert({ i1_, i2_ + HALF, i3_ }, + x_Ph_0H0); + + Fld(i1, i2, i3, em::ex2) = + s * Fld(i1, i2, i3, em::ex2) + + (ONE - s) * metric.template transform<2, Idx::T, Idx::U>( + { i1_, i2_ + HALF, i3_ }, + fset.ex2(x_Ph_0H0)); + } + + if constexpr (defines_ex3) { + // i1, i2, i3 + 1/2 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_; + } else if constexpr (o == in::x2) { + xi_Cd = i2_; + } else { + xi_Cd = i3_ + HALF; + } + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_00H { ZERO }; + metric.template convert({ i1_, i2_, i3_ + HALF }, + x_Ph_00H); + Fld(i1, i2, i3, em::ex3) = s * Fld(i1, i2, i3, em::ex3); + if ((!is_axis_i2min or (i2 > N_GHOSTS)) and + (!is_axis_i2max or (i2 < extent_2 - N_GHOSTS))) { + Fld(i1, i2, i3, em::ex3) += + (ONE - s) * metric.template transform<3, Idx::T, Idx::U>( + { i1_, i2_, i3_ + HALF }, + fset.ex3(x_Ph_00H)); + } + } + } } - coord_t x_Cd { ZERO }; - if (comp == em::ex1) { - x_Cd[0] = i1_ + HALF; - x_Cd[1] = i2_; - x_Cd[2] = i3_; - } else if (comp == em::ex2) { - x_Cd[0] = i1_; - x_Cd[1] = i2_ + HALF; - x_Cd[2] = i3_; - } else if (comp == em::ex3) { - x_Cd[0] = i1_; - x_Cd[1] = i2_; - x_Cd[2] = i3_ + HALF; - } else if (comp == em::bx1) { - x_Cd[0] = i1_; - x_Cd[1] = i2_ + HALF; - x_Cd[2] = i3_ + HALF; - } else if (comp == em::bx2) { - x_Cd[0] = i1_ + HALF; - x_Cd[1] = i2_; - x_Cd[2] = i3_ + HALF; - } else if (comp == em::bx3) { - x_Cd[0] = i1_ + HALF; - x_Cd[1] = i2_ + HALF; - x_Cd[2] = i3_; + + if constexpr (defines_bx1 or defines_bx2 or defines_bx3) { + if (tags & BC::B) { + if constexpr (defines_bx1) { + // i1, i2 + 1/2, i3 + 1/2 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_; 
+ } else if constexpr (o == in::x2) { + xi_Cd = i2_ + HALF; + } else { + xi_Cd = i3_ + HALF; + } + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_0HH { ZERO }; + metric.template convert( + { i1_, i2_ + HALF, i3_ + HALF }, + x_Ph_0HH); + + Fld(i1, i2, i3, em::bx1) = + s * Fld(i1, i2, i3, em::bx1) + + (ONE - s) * metric.template transform<1, Idx::T, Idx::U>( + { i1_, i2_ + HALF, i3_ + HALF }, + fset.bx1(x_Ph_0HH)); + } + + if constexpr (defines_bx2) { + // i1 + 1/2, i2, i3 + 1/2 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_ + HALF; + } else if constexpr (o == in::x2) { + xi_Cd = i2_; + } else { + xi_Cd = i3_ + HALF; + } + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_H0H { ZERO }; + metric.template convert( + { i1_ + HALF, i2_, i3_ + HALF }, + x_Ph_H0H); + + Fld(i1, i2, i3, em::bx2) = + s * Fld(i1, i2, i3, em::bx2) + + (ONE - s) * metric.template transform<2, Idx::T, Idx::U>( + { i1_ + HALF, i2_, i3_ + HALF }, + fset.bx2(x_Ph_H0H)); + } + + if constexpr (defines_bx3) { + // i1 + 1/2, i2 + 1/2, i3 + real_t xi_Cd; + if constexpr (o == in::x1) { + xi_Cd = i1_ + HALF; + } else if constexpr (o == in::x2) { + xi_Cd = i2_ + HALF; + } else { + xi_Cd = i3_; + } + + const auto s = shape(math::abs( + metric.template convert(xi_Cd) - xg_edge)); + + coord_t x_Ph_HH0 { ZERO }; + metric.template convert( + { i1_ + HALF, i2_ + HALF, i3_ }, + x_Ph_HH0); + + Fld(i1, i2, i3, em::bx3) = + s * Fld(i1, i2, i3, em::bx3) + + (ONE - s) * metric.template transform<3, Idx::T, Idx::U>( + { i1_ + HALF, i2_ + HALF, i3_ }, + fset.bx3(x_Ph_HH0)); + } + } } - const auto dx = math::abs( - metric.template convert(x_Cd[i - 1]) - xg_edge); - Fld(i1, i2, i3, comp) *= math::tanh(dx / (INV_4 * dx_abs)); + } else { + // GRPIC + raise::KernelError(HERE, "GRPIC not implemented"); } } else { raise::KernelError( HERE, - "AbsorbFields_kernel: 3D implementation called for D != 3"); + 
"MatchBoundaries_kernel: 3D implementation called for D != 3"); } } }; - template - struct AxisBoundaries_kernel { + template + struct ConductorBoundaries_kernel { + static_assert(static_cast(o) < static_cast(D), + "Invalid component index"); + ndfield_t Fld; const std::size_t i_edge; - const bool setE, setB; + const BCTags tags; + + ConductorBoundaries_kernel(ndfield_t Fld, std::size_t i_edge, BCTags tags) + : Fld { Fld } + , i_edge { i_edge } + , tags { tags } {} + + Inline void operator()(index_t i1) const { + if constexpr (D == Dim::_1D) { + if (tags & BC::E) { + if (i1 == 0) { + Fld(i_edge, em::ex2) = ZERO; + Fld(i_edge, em::ex3) = ZERO; + } else { + if constexpr (not P) { + Fld(i_edge - i1, em::ex1) = Fld(i_edge + i1 - 1, em::ex1); + Fld(i_edge - i1, em::ex2) = -Fld(i_edge + i1, em::ex2); + Fld(i_edge - i1, em::ex3) = -Fld(i_edge + i1, em::ex3); + } else { + Fld(i_edge + i1 - 1, em::ex1) = Fld(i_edge - i1, em::ex1); + Fld(i_edge + i1, em::ex2) = -Fld(i_edge - i1, em::ex2); + Fld(i_edge + i1, em::ex3) = -Fld(i_edge - i1, em::ex3); + } + } + } + + if (tags & BC::B) { + if (i1 == 0) { + Fld(i_edge, em::bx1) = ZERO; + } else { + if constexpr (not P) { + Fld(i_edge - i1, em::bx1) = -Fld(i_edge + i1, em::bx1); + Fld(i_edge - i1, em::bx2) = Fld(i_edge + i1 - 1, em::bx2); + Fld(i_edge - i1, em::bx3) = Fld(i_edge + i1 - 1, em::bx3); + } else { + Fld(i_edge + i1, em::bx1) = -Fld(i_edge - i1, em::bx1); + Fld(i_edge + i1 - 1, em::bx2) = Fld(i_edge - i1, em::bx2); + Fld(i_edge + i1 - 1, em::bx3) = Fld(i_edge - i1, em::bx3); + } + } + } + } else { + raise::KernelError( + HERE, + "ConductorBoundaries_kernel: 1D implementation called for D != 1"); + } + } + + Inline void operator()(index_t i1, index_t i2) const { + if constexpr (D == Dim::_2D) { + if constexpr (o == in::x1) { + if (tags & BC::E) { + if (i1 == 0) { + Fld(i_edge, i2, em::ex2) = ZERO; + Fld(i_edge, i2, em::ex3) = ZERO; + } else { + if constexpr (not P) { + Fld(i_edge - i1, i2, em::ex1) = Fld(i_edge + i1 - 1, 
i2, em::ex1); + Fld(i_edge - i1, i2, em::ex2) = -Fld(i_edge + i1, i2, em::ex2); + Fld(i_edge - i1, i2, em::ex3) = -Fld(i_edge + i1, i2, em::ex3); + } else { + Fld(i_edge + i1 - 1, i2, em::ex1) = Fld(i_edge - i1, i2, em::ex1); + Fld(i_edge + i1, i2, em::ex2) = -Fld(i_edge - i1, i2, em::ex2); + Fld(i_edge + i1, i2, em::ex3) = -Fld(i_edge - i1, i2, em::ex3); + } + } + } + + if (tags & BC::B) { + if (i1 == 0) { + Fld(i_edge, i2, em::bx1) = ZERO; + } else { + if constexpr (not P) { + Fld(i_edge - i1, i2, em::bx1) = -Fld(i_edge + i1, i2, em::bx1); + Fld(i_edge - i1, i2, em::bx2) = Fld(i_edge + i1 - 1, i2, em::bx2); + Fld(i_edge - i1, i2, em::bx3) = Fld(i_edge + i1 - 1, i2, em::bx3); + } else { + Fld(i_edge + i1, i2, em::bx1) = -Fld(i_edge - i1, i2, em::bx1); + Fld(i_edge + i1 - 1, i2, em::bx2) = Fld(i_edge - i1, i2, em::bx2); + Fld(i_edge + i1 - 1, i2, em::bx3) = Fld(i_edge - i1, i2, em::bx3); + } + } + } + } else { + if (tags & BC::E) { + if (i2 == 0) { + Fld(i1, i_edge, em::ex1) = ZERO; + Fld(i1, i_edge, em::ex3) = ZERO; + } else { + if constexpr (not P) { + Fld(i1, i_edge - i2, em::ex1) = -Fld(i1, i_edge + i2, em::ex1); + Fld(i1, i_edge - i2, em::ex2) = Fld(i1, i_edge + i2 - 1, em::ex2); + Fld(i1, i_edge - i2, em::ex3) = -Fld(i1, i_edge + i2, em::ex3); + } else { + Fld(i1, i_edge + i2, em::ex1) = -Fld(i1, i_edge - i2, em::ex1); + Fld(i1, i_edge + i2 - 1, em::ex2) = Fld(i1, i_edge - i2, em::ex2); + Fld(i1, i_edge + i2, em::ex3) = -Fld(i1, i_edge - i2, em::ex3); + } + } + } + + if (tags & BC::B) { + if (i2 == 0) { + Fld(i1, i_edge, em::bx2) = ZERO; + } else { + if constexpr (not P) { + Fld(i1, i_edge - i2, em::bx1) = Fld(i1, i_edge + i2 - 1, em::bx1); + Fld(i1, i_edge - i2, em::bx2) = -Fld(i1, i_edge + i2, em::bx2); + Fld(i1, i_edge - i2, em::bx3) = Fld(i1, i_edge + i2 - 1, em::bx3); + } else { + Fld(i1, i_edge + i2 - 1, em::bx1) = Fld(i1, i_edge - i2, em::bx1); + Fld(i1, i_edge + i2, em::bx2) = -Fld(i1, i_edge - i2, em::bx2); + Fld(i1, i_edge + i2 - 1, em::bx3) = 
Fld(i1, i_edge - i2, em::bx3); + } + } + } + } + } else { + raise::KernelError( + HERE, + "ConductorBoundaries_kernel: 2D implementation called for D != 2"); + } + } + + Inline void operator()(index_t i1, index_t i2, index_t i3) const { + if constexpr (D == Dim::_3D) { + if constexpr (o == in::x1) { + if (tags & BC::E) { + if (i1 == 0) { + Fld(i_edge, i2, i3, em::ex2) = ZERO; + Fld(i_edge, i2, i3, em::ex3) = ZERO; + } else { + if constexpr (not P) { + Fld(i_edge - i1, i2, i3, em::ex1) = Fld(i_edge + i1 - 1, + i2, + i3, + em::ex1); + Fld(i_edge - i1, i2, i3, em::ex2) = -Fld(i_edge + i1, i2, i3, em::ex2); + Fld(i_edge - i1, i2, i3, em::ex3) = -Fld(i_edge + i1, i2, i3, em::ex3); + } else { + Fld(i_edge + i1 - 1, i2, i3, em::ex1) = Fld(i_edge - i1, + i2, + i3, + em::ex1); + Fld(i_edge + i1, i2, i3, em::ex2) = -Fld(i_edge - i1, i2, i3, em::ex2); + Fld(i_edge + i1, i2, i3, em::ex3) = -Fld(i_edge - i1, i2, i3, em::ex3); + } + } + } - AxisBoundaries_kernel(ndfield_t Fld, std::size_t i_edge, BCTags tags) + if (tags & BC::B) { + if (i1 == 0) { + Fld(i_edge, i2, i3, em::bx1) = ZERO; + } else { + if constexpr (not P) { + Fld(i_edge - i1, i2, i3, em::bx1) = -Fld(i_edge + i1, i2, i3, em::bx1); + Fld(i_edge - i1, i2, i3, em::bx2) = Fld(i_edge + i1 - 1, + i2, + i3, + em::bx2); + Fld(i_edge - i1, i2, i3, em::bx3) = Fld(i_edge + i1 - 1, + i2, + i3, + em::bx3); + } else { + Fld(i_edge + i1, i2, i3, em::bx1) = -Fld(i_edge - i1, i2, i3, em::bx1); + Fld(i_edge + i1 - 1, i2, i3, em::bx2) = Fld(i_edge - i1, + i2, + i3, + em::bx2); + Fld(i_edge + i1 - 1, i2, i3, em::bx3) = Fld(i_edge - i1, + i2, + i3, + em::bx3); + } + } + } + } else if (o == in::x2) { + if (tags & BC::E) { + if (i2 == 0) { + Fld(i1, i_edge, i3, em::ex1) = ZERO; + Fld(i1, i_edge, i3, em::ex3) = ZERO; + } else { + if constexpr (not P) { + Fld(i1, i_edge - i2, i3, em::ex1) = -Fld(i1, i_edge + i2, i3, em::ex1); + Fld(i1, i_edge - i2, i3, em::ex2) = Fld(i1, + i_edge + i2 - 1, + i3, + em::ex2); + Fld(i1, i_edge - i2, i3, 
em::ex3) = -Fld(i1, i_edge + i2, i3, em::ex3); + } else { + Fld(i1, i_edge + i2, i3, em::ex1) = -Fld(i1, i_edge - i2, i3, em::ex1); + Fld(i1, i_edge + i2 - 1, i3, em::ex2) = Fld(i1, + i_edge - i2, + i3, + em::ex2); + Fld(i1, i_edge + i2, i3, em::ex3) = -Fld(i1, i_edge - i2, i3, em::ex3); + } + } + } + + if (tags & BC::B) { + if (i2 == 0) { + Fld(i1, i_edge, i3, em::bx2) = ZERO; + } else { + if constexpr (not P) { + Fld(i1, i_edge - i2, i3, em::bx1) = Fld(i1, + i_edge + i2 - 1, + i3, + em::bx1); + Fld(i1, i_edge - i2, i3, em::bx2) = -Fld(i1, i_edge + i2, i3, em::bx2); + Fld(i1, i_edge - i2, i3, em::bx3) = Fld(i1, + i_edge + i2 - 1, + i3, + em::bx3); + } else { + Fld(i1, i_edge + i2 - 1, i3, em::bx1) = Fld(i1, + i_edge - i2, + i3, + em::bx1); + Fld(i1, i_edge + i2, i3, em::bx2) = -Fld(i1, i_edge - i2, i3, em::bx2); + Fld(i1, i_edge + i2 - 1, i3, em::bx3) = Fld(i1, + i_edge - i2, + i3, + em::bx3); + } + } + } + } else { + if (tags & BC::E) { + if (i3 == 0) { + Fld(i1, i2, i_edge, em::ex1) = ZERO; + Fld(i1, i2, i_edge, em::ex2) = ZERO; + } else { + if constexpr (not P) { + Fld(i1, i2, i_edge - i3, em::ex1) = -Fld(i1, i2, i_edge + i3, em::ex1); + Fld(i1, i2, i_edge - i3, em::ex2) = -Fld(i1, i2, i_edge + i3, em::ex2); + Fld(i1, i2, i_edge - i3, em::ex3) = Fld(i1, + i2, + i_edge + i3 - 1, + em::ex3); + } else { + Fld(i1, i2, i_edge + i3, em::ex1) = -Fld(i1, i2, i_edge - i3, em::ex1); + Fld(i1, i2, i_edge + i3, em::ex2) = -Fld(i1, i2, i_edge - i3, em::ex2); + Fld(i1, i2, i_edge + i3 - 1, em::ex3) = Fld(i1, + i2, + i_edge - i3, + em::ex3); + } + } + } + + if (tags & BC::B) { + if (i3 == 0) { + Fld(i1, i2, i_edge, em::bx3) = ZERO; + } else { + if constexpr (not P) { + Fld(i1, i2, i_edge - i3, em::bx1) = Fld(i1, + i2, + i_edge + i3 - 1, + em::bx1); + Fld(i1, i2, i_edge - i3, em::bx2) = Fld(i1, + i2, + i_edge + i3 - 1, + em::bx2); + Fld(i1, i2, i_edge - i3, em::bx3) = -Fld(i1, i2, i_edge + i3, em::bx3); + } else { + Fld(i1, i2, i_edge + i3 - 1, em::bx1) = Fld(i1, + i2, + 
i_edge - i3, + em::bx1); + Fld(i1, i2, i_edge + i3 - 1, em::bx2) = Fld(i1, + i2, + i_edge - i3, + em::bx2); + Fld(i1, i2, i_edge + i3, em::bx3) = -Fld(i1, i2, i_edge - i3, em::bx3); + } + } + } + } + } else { + raise::KernelError( + HERE, + "ConductorBoundaries_kernel: 3D implementation called for D != 3"); + } + } + }; + + /* + * @tparam D: Dimension + * @tparam P: Positive/Negative direction + * + * @brief Applies boundary conditions near the polar axis + */ + template + struct AxisBoundaries_kernel { + ndfield_t Fld; + const ncells_t i_edge; + const bool setE, setB; + + AxisBoundaries_kernel(ndfield_t Fld, ncells_t i_edge, BCTags tags) : Fld { Fld } , i_edge { i_edge } , setE { tags & BC::Ex1 or tags & BC::Ex2 or tags & BC::Ex3 } @@ -189,25 +840,30 @@ namespace kernel { Inline void operator()(index_t i1) const { if constexpr (D == Dim::_2D) { + // ! TODO: not all components are necessary if constexpr (not P) { if (setE) { Fld(i1, i_edge - 1, em::ex2) = -Fld(i1, i_edge, em::ex2); Fld(i1, i_edge, em::ex3) = ZERO; + Fld(i1, i_edge - 1, em::ex3) = Fld(i1, i_edge + 1, em::ex3); } if (setB) { Fld(i1, i_edge - 1, em::bx1) = Fld(i1, i_edge, em::bx1); Fld(i1, i_edge, em::bx2) = ZERO; + Fld(i1, i_edge - 1, em::bx2) = -Fld(i1, i_edge + 1, em::bx2); Fld(i1, i_edge - 1, em::bx3) = Fld(i1, i_edge, em::bx3); } } else { if (setE) { - Fld(i1, i_edge, em::ex2) = -Fld(i1, i_edge - 1, em::ex2); - Fld(i1, i_edge, em::ex3) = ZERO; + Fld(i1, i_edge, em::ex2) = -Fld(i1, i_edge - 1, em::ex2); + Fld(i1, i_edge, em::ex3) = ZERO; + Fld(i1, i_edge + 1, em::ex3) = Fld(i1, i_edge - 1, em::ex3); } if (setB) { - Fld(i1, i_edge, em::bx1) = Fld(i1, i_edge - 1, em::bx1); - Fld(i1, i_edge, em::bx2) = ZERO; - Fld(i1, i_edge, em::bx3) = Fld(i1, i_edge - 1, em::bx3); + Fld(i1, i_edge, em::bx1) = Fld(i1, i_edge - 1, em::bx1); + Fld(i1, i_edge, em::bx2) = ZERO; + Fld(i1, i_edge + 1, em::bx2) = -Fld(i1, i_edge - 1, em::bx2); + Fld(i1, i_edge, em::bx3) = Fld(i1, i_edge - 1, em::bx3); } } } else { @@ 
-216,8 +872,60 @@ namespace kernel { } }; + // /* + // * @tparam I: Field Setter class + // * @tparam M: Metric + // * @tparam P: Positive/Negative direction + // * @tparam O: Orientation + // * + // * @brief Applies enforced boundary conditions (fixed value) + // */ + // template + // struct AxisBoundariesGR_kernel { + // ndfield_t Fld; + // const std::size_t i_edge; + // const bool setE, setB; + // + // AxisBoundariesGR_kernel(ndfield_t Fld, std::size_t i_edge, BCTags tags) + // : Fld { Fld } // , i_edge { i_edge } + // , i_edge { P ? (i_edge + 1) : i_edge } + // , setE { tags & BC::Ex1 or tags & BC::Ex2 or tags & BC::Ex3 } + // , setB { tags & BC::Bx1 or tags & BC::Bx2 or tags & BC::Bx3 } {} + // + // Inline void operator()(index_t i1) const { + // if constexpr (D == Dim::_2D) { + // // if (setB) { + // // Fld(i1, i_edge, em::bx2) = ZERO; + // // } + // if constexpr (not P) { + // if (setE) { + // Fld(i1, i_edge - 1, em::ex2) = -Fld(i1, i_edge, em::ex2); + // Fld(i1, i_edge, em::ex3) = ZERO; + // } + // if (setB) { + // Fld(i1, i_edge - 1, em::bx1) = Fld(i1, i_edge, em::bx1); + // Fld(i1, i_edge, em::bx2) = ZERO; + // Fld(i1, i_edge - 1, em::bx3) = Fld(i1, i_edge, em::bx3); + // } + // } else { + // if (setE) { + // Fld(i1, i_edge + 1, em::ex2) = -Fld(i1, i_edge, em::ex2); + // Fld(i1, i_edge + 1, em::ex3) = ZERO; + // } + // if (setB) { + // Fld(i1, i_edge + 1, em::bx1) = Fld(i1, i_edge, em::bx1); + // Fld(i1, i_edge + 1, em::bx2) = ZERO; + // Fld(i1, i_edge + 1, em::bx3) = Fld(i1, i_edge, em::bx3); + // } + // } + // } else { + // raise::KernelError(HERE, "AxisBoundariesGR_kernel: D != 2"); + // } + // } + // }; + template - struct AtmosphereBoundaries_kernel { + struct EnforcedBoundaries_kernel { static constexpr Dimension D = M::Dim; static constexpr bool defines_ex1 = traits::has_method::value; static constexpr bool defines_ex2 = traits::has_method::value; @@ -226,31 +934,29 @@ namespace kernel { static constexpr bool defines_bx2 = 
traits::has_method::value; static constexpr bool defines_bx3 = traits::has_method::value; - static_assert(defines_ex1 and defines_ex2 and defines_ex3 and - defines_bx1 and defines_bx2 and defines_bx3, - "not all components of E or B are specified in PGEN"); + static_assert(defines_ex1 or defines_ex2 or defines_ex3 or defines_bx1 or + defines_bx2 or defines_bx3, + "none of the components of E or B are specified in PGEN"); static_assert(M::is_metric, "M must be a metric class"); - static_assert(static_cast(O) < - static_cast(M::Dim), + static_assert(static_cast(O) < static_cast(M::Dim), "Invalid Orientation"); - ndfield_t Fld; - const I finit; - const M metric; - const std::size_t i_edge; - const bool setE, setB; + ndfield_t Fld; + const I fset; + const M metric; + const ncells_t i_edge; + const BCTags tags; - AtmosphereBoundaries_kernel(ndfield_t& Fld, - const I& finit, - const M& metric, - std::size_t i_edge, - BCTags tags) + EnforcedBoundaries_kernel(ndfield_t& Fld, + const I& fset, + const M& metric, + ncells_t i_edge, + BCTags tags) : Fld { Fld } - , finit { finit } + , fset { fset } , metric { metric } , i_edge { i_edge + N_GHOSTS } - , setE { tags & BC::Ex1 or tags & BC::Ex2 or tags & BC::Ex3 } - , setB { tags & BC::Bx1 or tags & BC::Bx2 or tags & BC::Bx3 } {} + , tags { tags } {} Inline void operator()(index_t i1) const { if constexpr (D == Dim::_1D) { @@ -259,8 +965,12 @@ namespace kernel { coord_t x_Ph_H { ZERO }; metric.template convert({ i1_ }, x_Ph_0); metric.template convert({ i1_ + HALF }, x_Ph_H); - bool setEx1 = setE, setEx2 = setE, setEx3 = setE, setBx1 = setB, - setBx2 = setB, setBx3 = setB; + bool setEx1 = defines_ex1 and (tags & BC::E), + setEx2 = defines_ex2 and (tags & BC::E), + setEx3 = defines_ex3 and (tags & BC::E), + setBx1 = defines_bx1 and (tags & BC::B), + setBx2 = defines_bx2 and (tags & BC::B), + setBx3 = defines_bx3 and (tags & BC::B); if constexpr (O == in::x1) { // x1 -- normal // x2,x3 -- tangential @@ -276,35 +986,47 @@ namespace 
kernel { } else { raise::KernelError(HERE, "Invalid Orientation"); } - if (setEx1) { - Fld(i1, em::ex1) = metric.template transform<1, Idx::T, Idx::U>( - { i1_ + HALF }, - finit.ex1(x_Ph_H)); + if constexpr (defines_ex1) { + if (setEx1) { + Fld(i1, em::ex1) = metric.template transform<1, Idx::T, Idx::U>( + { i1_ + HALF }, + fset.ex1(x_Ph_H)); + } } - if (setEx2) { - Fld(i1, em::ex2) = metric.template transform<2, Idx::T, Idx::U>( - { i1_ }, - finit.ex2(x_Ph_0)); + if constexpr (defines_ex2) { + if (setEx2) { + Fld(i1, em::ex2) = metric.template transform<2, Idx::T, Idx::U>( + { i1_ }, + fset.ex2(x_Ph_0)); + } } - if (setEx3) { - Fld(i1, em::ex3) = metric.template transform<3, Idx::T, Idx::U>( - { i1_ }, - finit.ex3(x_Ph_0)); + if constexpr (defines_ex3) { + if (setEx3) { + Fld(i1, em::ex3) = metric.template transform<3, Idx::T, Idx::U>( + { i1_ }, + fset.ex3(x_Ph_0)); + } } - if (setBx1) { - Fld(i1, em::bx1) = metric.template transform<1, Idx::T, Idx::U>( - { i1_ }, - finit.bx1(x_Ph_0)); + if constexpr (defines_bx1) { + if (setBx1) { + Fld(i1, em::bx1) = metric.template transform<1, Idx::T, Idx::U>( + { i1_ }, + fset.bx1(x_Ph_0)); + } } - if (setBx2) { - Fld(i1, em::bx2) = metric.template transform<2, Idx::T, Idx::U>( - { i1_ + HALF }, - finit.bx2(x_Ph_H)); + if constexpr (defines_bx2) { + if (setBx2) { + Fld(i1, em::bx2) = metric.template transform<2, Idx::T, Idx::U>( + { i1_ + HALF }, + fset.bx2(x_Ph_H)); + } } - if (setBx3) { - Fld(i1, em::bx3) = metric.template transform<3, Idx::T, Idx::U>( - { i1_ + HALF }, - finit.bx3(x_Ph_H)); + if constexpr (defines_bx3) { + if (setBx3) { + Fld(i1, em::bx3) = metric.template transform<3, Idx::T, Idx::U>( + { i1_ + HALF }, + fset.bx3(x_Ph_H)); + } } } else { raise::KernelError(HERE, "Invalid Dimension"); @@ -324,8 +1046,12 @@ namespace kernel { metric.template convert({ i1_ + HALF, i2_ }, x_Ph_H0); metric.template convert({ i1_ + HALF, i2_ + HALF }, x_Ph_HH); - bool setEx1 = setE, setEx2 = setE, setEx3 = setE, setBx1 = setB, 
- setBx2 = setB, setBx3 = setB; + bool setEx1 = defines_ex1 and (tags & BC::E), + setEx2 = defines_ex2 and (tags & BC::E), + setEx3 = defines_ex3 and (tags & BC::E), + setBx1 = defines_bx1 and (tags & BC::B), + setBx2 = defines_bx2 and (tags & BC::B), + setBx3 = defines_bx3 and (tags & BC::B); if constexpr (O == in::x1) { // x1 -- normal // x2,x3 -- tangential @@ -353,35 +1079,47 @@ namespace kernel { } else { raise::KernelError(HERE, "Invalid Orientation"); } - if (setEx1) { - Fld(i1, i2, em::ex1) = metric.template transform<1, Idx::T, Idx::U>( - { i1_ + HALF, i2_ }, - finit.ex1(x_Ph_H0)); + if constexpr (defines_ex1) { + if (setEx1) { + Fld(i1, i2, em::ex1) = metric.template transform<1, Idx::T, Idx::U>( + { i1_ + HALF, i2_ }, + fset.ex1(x_Ph_H0)); + } } - if (setEx2) { - Fld(i1, i2, em::ex2) = metric.template transform<2, Idx::T, Idx::U>( - { i1_, i2_ + HALF }, - finit.ex2(x_Ph_0H)); + if constexpr (defines_ex2) { + if (setEx2) { + Fld(i1, i2, em::ex2) = metric.template transform<2, Idx::T, Idx::U>( + { i1_, i2_ + HALF }, + fset.ex2(x_Ph_0H)); + } } - if (setEx3) { - Fld(i1, i2, em::ex3) = metric.template transform<3, Idx::T, Idx::U>( - { i1_, i2_ }, - finit.ex3(x_Ph_00)); + if constexpr (defines_ex3) { + if (setEx3) { + Fld(i1, i2, em::ex3) = metric.template transform<3, Idx::T, Idx::U>( + { i1_, i2_ }, + fset.ex3(x_Ph_00)); + } } - if (setBx1) { - Fld(i1, i2, em::bx1) = metric.template transform<1, Idx::T, Idx::U>( - { i1_, i2_ + HALF }, - finit.bx1(x_Ph_0H)); + if constexpr (defines_bx1) { + if (setBx1) { + Fld(i1, i2, em::bx1) = metric.template transform<1, Idx::T, Idx::U>( + { i1_, i2_ + HALF }, + fset.bx1(x_Ph_0H)); + } } - if (setBx2) { - Fld(i1, i2, em::bx2) = metric.template transform<2, Idx::T, Idx::U>( - { i1_ + HALF, i2_ }, - finit.bx2(x_Ph_H0)); + if constexpr (defines_bx2) { + if (setBx2) { + Fld(i1, i2, em::bx2) = metric.template transform<2, Idx::T, Idx::U>( + { i1_ + HALF, i2_ }, + fset.bx2(x_Ph_H0)); + } } - if (setBx3) { - Fld(i1, i2, em::bx3) 
= metric.template transform<3, Idx::T, Idx::U>( - { i1_ + HALF, i2_ + HALF }, - finit.bx3(x_Ph_HH)); + if constexpr (defines_bx3) { + if (setBx3) { + Fld(i1, i2, em::bx3) = metric.template transform<3, Idx::T, Idx::U>( + { i1_ + HALF, i2_ + HALF }, + fset.bx3(x_Ph_HH)); + } } } else { raise::KernelError(HERE, "Invalid Dimension"); @@ -412,8 +1150,12 @@ namespace kernel { x_Ph_H0H); metric.template convert({ i1_, i2_ + HALF, i3_ + HALF }, x_Ph_0HH); - bool setEx1 = setE, setEx2 = setE, setEx3 = setE, setBx1 = setB, - setBx2 = setB, setBx3 = setB; + bool setEx1 = defines_ex1 and (tags & BC::E), + setEx2 = defines_ex2 and (tags & BC::E), + setEx3 = defines_ex3 and (tags & BC::E), + setBx1 = defines_bx1 and (tags & BC::B), + setBx2 = defines_bx2 and (tags & BC::B), + setBx3 = defines_bx3 and (tags & BC::B); if constexpr (O == in::x1) { // x1 -- normal // x2,x3 -- tangential @@ -453,35 +1195,47 @@ namespace kernel { } else { raise::KernelError(HERE, "Invalid Orientation"); } - if (setEx1) { - Fld(i1, i2, i3, em::ex1) = metric.template transform<1, Idx::T, Idx::U>( - { i1_ + HALF, i2_, i3_ }, - finit.ex1(x_Ph_H00)); + if constexpr (defines_ex1) { + if (setEx1) { + Fld(i1, i2, i3, em::ex1) = metric.template transform<1, Idx::T, Idx::U>( + { i1_ + HALF, i2_, i3_ }, + fset.ex1(x_Ph_H00)); + } } - if (setEx2) { - Fld(i1, i2, i3, em::ex2) = metric.template transform<2, Idx::T, Idx::U>( - { i1_, i2_ + HALF, i3_ }, - finit.ex2(x_Ph_0H0)); + if constexpr (defines_ex2) { + if (setEx2) { + Fld(i1, i2, i3, em::ex2) = metric.template transform<2, Idx::T, Idx::U>( + { i1_, i2_ + HALF, i3_ }, + fset.ex2(x_Ph_0H0)); + } } - if (setEx3) { - Fld(i1, i2, i3, em::ex3) = metric.template transform<3, Idx::T, Idx::U>( - { i1_, i2_, i3_ + HALF }, - finit.ex3(x_Ph_00H)); + if constexpr (defines_ex3) { + if (setEx3) { + Fld(i1, i2, i3, em::ex3) = metric.template transform<3, Idx::T, Idx::U>( + { i1_, i2_, i3_ + HALF }, + fset.ex3(x_Ph_00H)); + } } - if (setBx1) { - Fld(i1, i2, i3, em::bx1) = 
metric.template transform<1, Idx::T, Idx::U>( - { i1_, i2_ + HALF, i3_ + HALF }, - finit.bx1(x_Ph_0HH)); + if constexpr (defines_bx1) { + if (setBx1) { + Fld(i1, i2, i3, em::bx1) = metric.template transform<1, Idx::T, Idx::U>( + { i1_, i2_ + HALF, i3_ + HALF }, + fset.bx1(x_Ph_0HH)); + } } - if (setBx2) { - Fld(i1, i2, i3, em::bx2) = metric.template transform<2, Idx::T, Idx::U>( - { i1_ + HALF, i2_, i3_ + HALF }, - finit.bx2(x_Ph_H0H)); + if constexpr (defines_bx2) { + if (setBx2) { + Fld(i1, i2, i3, em::bx2) = metric.template transform<2, Idx::T, Idx::U>( + { i1_ + HALF, i2_, i3_ + HALF }, + fset.bx2(x_Ph_H0H)); + } } - if (setBx3) { - Fld(i1, i2, i3, em::bx3) = metric.template transform<3, Idx::T, Idx::U>( - { i1_ + HALF, i2_ + HALF, i3_ }, - finit.bx3(x_Ph_HH0)); + if constexpr (defines_bx3) { + if (setBx3) { + Fld(i1, i2, i3, em::bx3) = metric.template transform<3, Idx::T, Idx::U>( + { i1_ + HALF, i2_ + HALF, i3_ }, + fset.bx3(x_Ph_HH0)); + } } } else { raise::KernelError(HERE, "Invalid Dimension"); @@ -489,6 +1243,105 @@ namespace kernel { } }; -} // namespace kernel + namespace gr { + + template + struct HorizonBoundaries_kernel { + ndfield_t Fld; + const std::size_t i1_min; + const bool setE, setB; + const std::size_t nfilter; + + HorizonBoundaries_kernel(ndfield_t Fld, + std::size_t i1_min, + BCTags tags, + std::size_t nfilter) + : Fld { Fld } + , i1_min { i1_min } + , setE { (tags & BC::Ex1 or tags & BC::Ex2 or tags & BC::Ex3) or + (tags & BC::Dx1 or tags & BC::Dx2 or tags & BC::Dx3) } + , setB { (tags & BC::Bx1 or tags & BC::Bx2 or tags & BC::Bx3) or + (tags & BC::Hx1 or tags & BC::Hx2 or tags & BC::Hx3) } + , nfilter { nfilter } {} + + Inline void operator()(index_t i2) const { + if constexpr (M::Dim == Dim::_2D) { + if (setE) { + for (unsigned short i = 0; i <= 2 + nfilter; ++i) { + Fld(i1_min - N_GHOSTS + i, + i2, + em::dx1) = Fld(i1_min + 1 + nfilter, i2, em::dx1); + Fld(i1_min - N_GHOSTS + i, + i2, + em::dx2) = Fld(i1_min + 1 + nfilter, i2, em::dx2); 
+ Fld(i1_min - N_GHOSTS + i, + i2, + em::dx3) = Fld(i1_min + 1 + nfilter, i2, em::dx3); + } + } + if (setB) { + for (unsigned short i = 0; i <= 2 + nfilter; ++i) { + Fld(i1_min - N_GHOSTS + i, + i2, + em::bx1) = Fld(i1_min + 1 + nfilter, i2, em::bx1); + Fld(i1_min - N_GHOSTS + i, + i2, + em::bx2) = Fld(i1_min + 1 + nfilter, i2, em::bx2); + Fld(i1_min - N_GHOSTS + i, + i2, + em::bx3) = Fld(i1_min + 1 + nfilter, i2, em::bx3); + } + } + } else { + raise::KernelError( + HERE, + "HorizonBoundaries_kernel: 2D implementation called for D != 2"); + } + } + }; + + template + struct AbsorbCurrents_kernel { + static_assert(M::is_metric, "M must be a metric class"); + static_assert(i <= static_cast(M::Dim), + "Invalid component index"); + + ndfield_t J; + const M metric; + const real_t xg_edge; + const real_t dx_abs; + + AbsorbCurrents_kernel(ndfield_t J, + const M& metric, + real_t xg_edge, + real_t dx_abs) + : J { J } + , metric { metric } + , xg_edge { xg_edge } + , dx_abs { dx_abs } {} + + Inline void operator()(index_t i1, index_t i2) const { + if constexpr (M::Dim == Dim::_2D) { + const auto i1_ = COORD(i1); + const auto i2_ = COORD(i2); + coord_t x_Cd { ZERO }; + x_Cd[0] = i1_; + x_Cd[1] = i2_; + const auto dx = math::abs( + metric.template convert(x_Cd[i - 1]) - xg_edge); + J(i1, i2, 0) *= math::tanh(dx / (INV_4 * dx_abs)); + J(i1, i2, 1) *= math::tanh(dx / (INV_4 * dx_abs)); + J(i1, i2, 2) *= math::tanh(dx / (INV_4 * dx_abs)); + + } else { + raise::KernelError( + HERE, + "gr::AbsorbCurrents_kernel: 2D implementation called for D != 2"); + } + } + }; + } // namespace gr + +} // namespace kernel::bc #endif // KERNELS_FIELDS_BCS_HPP diff --git a/src/kernels/injectors.hpp b/src/kernels/injectors.hpp index 9d3fd7d81..c321d18f8 100644 --- a/src/kernels/injectors.hpp +++ b/src/kernels/injectors.hpp @@ -24,7 +24,6 @@ namespace kernel { using namespace ntt; - using spidx_t = unsigned short; template struct UniformInjector_kernel { @@ -47,9 +46,9 @@ namespace kernel { array_t 
weights_2; array_t tags_2; - std::size_t offset1, offset2; + npart_t offset1, offset2; const M metric; - const array_t ni; + const array_t xi_min, xi_max; const ED energy_dist; const real_t inv_V0; random_number_pool_t random_pool; @@ -58,10 +57,11 @@ namespace kernel { spidx_t spidx2, Particles& species1, Particles& species2, - std::size_t offset1, - std::size_t offset2, + npart_t offset1, + npart_t offset2, const M& metric, - const array_t& ni, + const array_t& xi_min, + const array_t& xi_max, const ED& energy_dist, real_t inv_V0, random_number_pool_t& random_pool) @@ -94,7 +94,8 @@ namespace kernel { , offset1 { offset1 } , offset2 { offset2 } , metric { metric } - , ni { ni } + , xi_min { xi_min } + , xi_max { xi_max } , energy_dist { energy_dist } , inv_V0 { inv_V0 } , random_pool { random_pool } {} @@ -104,12 +105,12 @@ namespace kernel { vec_t v1 { ZERO }, v2 { ZERO }; { // generate a random coordinate auto rand_gen = random_pool.get_state(); - x_Cd[0] = Random(rand_gen) * ni(0); + x_Cd[0] = xi_min(0) + Random(rand_gen) * (xi_max(0) - xi_min(0)); if constexpr (M::Dim == Dim::_2D or M::Dim == Dim::_3D) { - x_Cd[1] = Random(rand_gen) * ni(1); + x_Cd[1] = xi_min(1) + Random(rand_gen) * (xi_max(1) - xi_min(1)); } if constexpr (M::Dim == Dim::_3D) { - x_Cd[2] = Random(rand_gen) * ni(2); + x_Cd[2] = xi_min(2) + Random(rand_gen) * (xi_max(2) - xi_min(2)); } random_pool.free_state(rand_gen); } @@ -117,27 +118,24 @@ namespace kernel { coord_t x_Ph { ZERO }; metric.template convert(x_Cd, x_Ph); if constexpr (M::CoordType == Coord::Cart) { - vec_t v_Ph { ZERO }; - energy_dist(x_Ph, v_Ph, spidx1); - metric.template transform_xyz(x_Ph, v_Ph, v1); - energy_dist(x_Ph, v_Ph, spidx2); - metric.template transform_xyz(x_Ph, v_Ph, v2); + energy_dist(x_Ph, v1, spidx1); + energy_dist(x_Ph, v2, spidx2); } else if constexpr (S == SimEngine::SRPIC) { - coord_t x_Ph_ { ZERO }; - x_Ph_[0] = x_Ph[0]; - x_Ph_[1] = x_Ph[1]; - x_Ph_[2] = ZERO; // phi = 0 + coord_t x_Cd_ { ZERO }; + 
x_Cd_[0] = x_Cd[0]; + x_Cd_[1] = x_Cd[1]; + x_Cd_[2] = ZERO; // phi = 0 vec_t v_Ph { ZERO }; energy_dist(x_Ph, v_Ph, spidx1); - metric.template transform_xyz(x_Ph_, v_Ph, v1); + metric.template transform_xyz(x_Cd_, v_Ph, v1); energy_dist(x_Ph, v_Ph, spidx2); - metric.template transform_xyz(x_Ph_, v_Ph, v2); + metric.template transform_xyz(x_Cd_, v_Ph, v2); } else if constexpr (S == SimEngine::GRPIC) { vec_t v_Ph { ZERO }; energy_dist(x_Ph, v_Ph, spidx1); - metric.template transform(x_Ph, v_Ph, v1); + metric.template transform(x_Cd, v_Ph, v1); energy_dist(x_Ph, v_Ph, spidx2); - metric.template transform(x_Ph, v_Ph, v2); + metric.template transform(x_Cd, v_Ph, v2); } else { raise::KernelError(HERE, "Unknown simulation engine"); } @@ -185,6 +183,176 @@ namespace kernel { } }; // struct UniformInjector_kernel + namespace experimental { + + template + struct UniformInjector_kernel { + static_assert(ED1::is_energy_dist, + "ED1 must be an energy distribution class"); + static_assert(ED2::is_energy_dist, + "ED2 must be an energy distribution class"); + static_assert(M::is_metric, "M must be a metric class"); + + const spidx_t spidx1, spidx2; + + array_t i1s_1, i2s_1, i3s_1; + array_t dx1s_1, dx2s_1, dx3s_1; + array_t ux1s_1, ux2s_1, ux3s_1; + array_t phis_1; + array_t weights_1; + array_t tags_1; + + array_t i1s_2, i2s_2, i3s_2; + array_t dx1s_2, dx2s_2, dx3s_2; + array_t ux1s_2, ux2s_2, ux3s_2; + array_t phis_2; + array_t weights_2; + array_t tags_2; + + npart_t offset1, offset2; + const M metric; + const array_t xi_min, xi_max; + const ED1 energy_dist_1; + const ED2 energy_dist_2; + const real_t inv_V0; + random_number_pool_t random_pool; + + UniformInjector_kernel(spidx_t spidx1, + spidx_t spidx2, + Particles& species1, + Particles& species2, + npart_t offset1, + npart_t offset2, + const M& metric, + const array_t& xi_min, + const array_t& xi_max, + const ED1& energy_dist_1, + const ED2& energy_dist_2, + real_t inv_V0, + random_number_pool_t& random_pool) + : spidx1 { 
spidx1 } + , spidx2 { spidx2 } + , i1s_1 { species1.i1 } + , i2s_1 { species1.i2 } + , i3s_1 { species1.i3 } + , dx1s_1 { species1.dx1 } + , dx2s_1 { species1.dx2 } + , dx3s_1 { species1.dx3 } + , ux1s_1 { species1.ux1 } + , ux2s_1 { species1.ux2 } + , ux3s_1 { species1.ux3 } + , phis_1 { species1.phi } + , weights_1 { species1.weight } + , tags_1 { species1.tag } + , i1s_2 { species2.i1 } + , i2s_2 { species2.i2 } + , i3s_2 { species2.i3 } + , dx1s_2 { species2.dx1 } + , dx2s_2 { species2.dx2 } + , dx3s_2 { species2.dx3 } + , ux1s_2 { species2.ux1 } + , ux2s_2 { species2.ux2 } + , ux3s_2 { species2.ux3 } + , phis_2 { species2.phi } + , weights_2 { species2.weight } + , tags_2 { species2.tag } + , offset1 { offset1 } + , offset2 { offset2 } + , metric { metric } + , xi_min { xi_min } + , xi_max { xi_max } + , energy_dist_1 { energy_dist_1 } + , energy_dist_2 { energy_dist_2 } + , inv_V0 { inv_V0 } + , random_pool { random_pool } {} + + Inline void operator()(index_t p) const { + coord_t x_Cd { ZERO }; + vec_t v1 { ZERO }, v2 { ZERO }; + { // generate a random coordinate + auto rand_gen = random_pool.get_state(); + x_Cd[0] = xi_min(0) + Random(rand_gen) * (xi_max(0) - xi_min(0)); + if constexpr (M::Dim == Dim::_2D or M::Dim == Dim::_3D) { + x_Cd[1] = xi_min(1) + + Random(rand_gen) * (xi_max(1) - xi_min(1)); + } + if constexpr (M::Dim == Dim::_3D) { + x_Cd[2] = xi_min(2) + + Random(rand_gen) * (xi_max(2) - xi_min(2)); + } + random_pool.free_state(rand_gen); + } + { // generate the velocity + coord_t x_Ph { ZERO }; + metric.template convert(x_Cd, x_Ph); + if constexpr (M::CoordType == Coord::Cart) { + energy_dist_1(x_Ph, v1, spidx1); + energy_dist_2(x_Ph, v2, spidx2); + } else if constexpr (S == SimEngine::SRPIC) { + coord_t x_Cd_ { ZERO }; + x_Cd_[0] = x_Cd[0]; + x_Cd_[1] = x_Cd[1]; + x_Cd_[2] = ZERO; // phi = 0 + vec_t v_Ph { ZERO }; + energy_dist_1(x_Ph, v_Ph, spidx1); + metric.template transform_xyz(x_Cd_, v_Ph, v1); + energy_dist_2(x_Ph, v_Ph, spidx2); + 
metric.template transform_xyz(x_Cd_, v_Ph, v2); + } else if constexpr (S == SimEngine::GRPIC) { + vec_t v_Ph { ZERO }; + energy_dist_1(x_Ph, v_Ph, spidx1); + metric.template transform(x_Cd, v_Ph, v1); + energy_dist_2(x_Ph, v_Ph, spidx2); + metric.template transform(x_Cd, v_Ph, v2); + } else { + raise::KernelError(HERE, "Unknown simulation engine"); + } + } + // inject + i1s_1(p + offset1) = static_cast(x_Cd[0]); + dx1s_1(p + offset1) = static_cast( + x_Cd[0] - static_cast(i1s_1(p + offset1))); + i1s_2(p + offset2) = i1s_1(p + offset1); + dx1s_2(p + offset2) = dx1s_1(p + offset1); + if constexpr (M::Dim == Dim::_2D or M::Dim == Dim::_3D) { + i2s_1(p + offset1) = static_cast(x_Cd[1]); + dx2s_1(p + offset1) = static_cast( + x_Cd[1] - static_cast(i2s_1(p + offset1))); + i2s_2(p + offset2) = i2s_1(p + offset1); + dx2s_2(p + offset2) = dx2s_1(p + offset1); + if constexpr (S == SimEngine::SRPIC && M::CoordType != Coord::Cart) { + phis_1(p + offset1) = ZERO; + phis_2(p + offset2) = ZERO; + } + } + if constexpr (M::Dim == Dim::_3D) { + i3s_1(p + offset1) = static_cast(x_Cd[2]); + dx3s_1(p + offset1) = static_cast( + x_Cd[2] - static_cast(i3s_1(p + offset1))); + i3s_2(p + offset2) = i3s_1(p + offset1); + dx3s_2(p + offset2) = dx3s_1(p + offset1); + } + ux1s_1(p + offset1) = v1[0]; + ux2s_1(p + offset1) = v1[1]; + ux3s_1(p + offset1) = v1[2]; + ux1s_2(p + offset2) = v2[0]; + ux2s_2(p + offset2) = v2[1]; + ux3s_2(p + offset2) = v2[2]; + tags_1(p + offset1) = ParticleTag::alive; + tags_2(p + offset2) = ParticleTag::alive; + if constexpr (M::CoordType == Coord::Cart) { + weights_1(p + offset1) = ONE; + weights_2(p + offset2) = ONE; + } else { + const auto sqrt_det_h = metric.sqrt_det_h(x_Cd); + weights_1(p + offset1) = sqrt_det_h * inv_V0; + weights_2(p + offset2) = sqrt_det_h * inv_V0; + } + } + }; // struct UniformInjector_kernel + + } // namespace experimental + template struct GlobalInjector_kernel { static_assert(M::is_metric, "M must be a metric class"); @@ -201,20 +369,20 
@@ namespace kernel { array_t in_phi; array_t in_wei; - array_t idx { "idx" }; - array_t i1s, i2s, i3s; - array_t dx1s, dx2s, dx3s; - array_t ux1s, ux2s, ux3s; - array_t phis; - array_t weights; - array_t tags; + array_t idx { "idx" }; + array_t i1s, i2s, i3s; + array_t dx1s, dx2s, dx3s; + array_t ux1s, ux2s, ux3s; + array_t phis; + array_t weights; + array_t tags; - const std::size_t offset; + const npart_t offset; M global_metric; - real_t x1_min, x1_max, x2_min, x2_max, x3_min, x3_max; - std::size_t i1_offset, i2_offset, i3_offset; + real_t x1_min, x1_max, x2_min, x2_max, x3_min, x3_max; + ncells_t i1_offset, i2_offset, i3_offset; GlobalInjector_kernel(Particles& species, const M& global_metric, @@ -269,18 +437,18 @@ namespace kernel { void copy_from_vector(const std::string& name, array_t& arr, const std::map>& data, - std::size_t n_inject) { + npart_t n_inject) { raise::ErrorIf(data.find(name) == data.end(), name + " not found in data", HERE); raise::ErrorIf(data.at(name).size() != n_inject, "Inconsistent data size", HERE); arr = array_t { name, n_inject }; auto arr_h = Kokkos::create_mirror_view(arr); - for (std::size_t i = 0; i < data.at(name).size(); ++i) { + for (auto i = 0u; i < data.at(name).size(); ++i) { arr_h(i) = data.at(name)[i]; } Kokkos::deep_copy(arr, arr_h); } - auto number_injected() const -> std::size_t { + auto number_injected() const -> npart_t { auto idx_h = Kokkos::create_mirror_view(idx); Kokkos::deep_copy(idx_h, idx); return idx_h(); @@ -298,7 +466,7 @@ namespace kernel { global_metric.template transform_xyz(x_Cd, u_Ph, u_XYZ); const auto i1 = static_cast( - static_cast(x_Cd[0]) - i1_offset); + static_cast(x_Cd[0]) - i1_offset); const auto dx1 = static_cast( x_Cd[0] - static_cast(i1 + i1_offset)); @@ -322,9 +490,8 @@ namespace kernel { vec_t u_Ph { in_ux1(p), in_ux2(p), in_ux3(p) }; coord_t x_Cd_ { ZERO }; - auto index { - offset + Kokkos::atomic_fetch_add(&idx(), static_cast(1)) - }; + auto index { offset + + 
Kokkos::atomic_fetch_add(&idx(), static_cast(1)) }; global_metric.template convert({ in_x1(p), in_x2(p) }, x_Cd); x_Cd_[0] = x_Cd[0]; @@ -335,16 +502,16 @@ namespace kernel { if constexpr (S == SimEngine::SRPIC) { global_metric.template transform_xyz(x_Cd_, u_Ph, u_Cd); } else if constexpr (S == SimEngine::GRPIC) { - global_metric.template transform(x_Cd, u_Ph, u_Cd); + global_metric.template transform(x_Cd, u_Ph, u_Cd); } else { raise::KernelError(HERE, "Unknown simulation engine"); } const auto i1 = static_cast( - static_cast(x_Cd[0]) - i1_offset); + static_cast(x_Cd[0]) - i1_offset); const auto dx1 = static_cast( x_Cd[0] - static_cast(i1 + i1_offset)); const auto i2 = static_cast( - static_cast(x_Cd[1]) - i2_offset); + static_cast(x_Cd[1]) - i2_offset); const auto dx2 = static_cast( x_Cd[1] - static_cast(i2 + i2_offset)); @@ -380,20 +547,20 @@ namespace kernel { if constexpr (S == SimEngine::SRPIC) { global_metric.template transform_xyz(x_Cd, u_Ph, u_Cd); } else if constexpr (S == SimEngine::GRPIC) { - global_metric.template transform(x_Cd, u_Ph, u_Cd); + global_metric.template transform(x_Cd, u_Ph, u_Cd); } else { raise::KernelError(HERE, "Unknown simulation engine"); } const auto i1 = static_cast( - static_cast(x_Cd[0]) - i1_offset); + static_cast(x_Cd[0]) - i1_offset); const auto dx1 = static_cast( x_Cd[0] - static_cast(i1 + i1_offset)); const auto i2 = static_cast( - static_cast(x_Cd[1]) - i2_offset); + static_cast(x_Cd[1]) - i2_offset); const auto dx2 = static_cast( x_Cd[1] - static_cast(i2 + i2_offset)); const auto i3 = static_cast( - static_cast(x_Cd[2]) - i3_offset); + static_cast(x_Cd[2]) - i3_offset); const auto dx3 = static_cast( x_Cd[2] - static_cast(i3 + i3_offset)); @@ -440,9 +607,9 @@ namespace kernel { array_t weights_2; array_t tags_2; - array_t idx { "idx" }; + array_t idx { "idx" }; - std::size_t offset1, offset2; + npart_t offset1, offset2; M metric; const ED energy_dist; const SD spatial_dist; @@ -454,8 +621,8 @@ namespace kernel { spidx_t 
spidx2, Particles& species1, Particles& species2, - std::size_t offset1, - std::size_t offset2, + npart_t offset1, + npart_t offset2, const M& metric, const ED& energy_dist, const SD& spatial_dist, @@ -496,7 +663,7 @@ namespace kernel { , inv_V0 { inv_V0 } , random_pool { random_pool } {} - auto number_injected() const -> std::size_t { + auto number_injected() const -> npart_t { auto idx_h = Kokkos::create_mirror_view(idx); Kokkos::deep_copy(idx_h, idx); return idx_h(); @@ -508,7 +675,7 @@ namespace kernel { coord_t x_Cd { i1_ + HALF }; coord_t x_Ph { ZERO }; metric.template convert(x_Cd, x_Ph); - const auto ppc = static_cast(ppc0 * spatial_dist(x_Ph)); + const auto ppc = static_cast(ppc0 * spatial_dist(x_Ph)); if (ppc == 0) { return; } @@ -564,7 +731,7 @@ namespace kernel { x_Cd_[2] = ZERO; } metric.template convert(x_Cd, x_Ph); - const auto ppc = static_cast(ppc0 * spatial_dist(x_Ph)); + const auto ppc = static_cast(ppc0 * spatial_dist(x_Ph)); if (ppc == 0) { return; } @@ -631,7 +798,7 @@ namespace kernel { coord_t x_Cd { i1_ + HALF, i2_ + HALF, i3_ + HALF }; coord_t x_Ph { ZERO }; metric.template convert(x_Cd, x_Ph); - const auto ppc = static_cast(ppc0 * spatial_dist(x_Ph)); + const auto ppc = static_cast(ppc0 * spatial_dist(x_Ph)); if (ppc == 0) { return; } diff --git a/src/kernels/particle_moments.hpp b/src/kernels/particle_moments.hpp index 8b668a036..9c66aed01 100644 --- a/src/kernels/particle_moments.hpp +++ b/src/kernels/particle_moments.hpp @@ -14,11 +14,10 @@ #include "global.h" #include "arch/kokkos_aliases.h" +#include "utils/comparators.h" #include "utils/error.h" #include "utils/numeric.h" -#include - #include namespace kernel { @@ -40,13 +39,15 @@ namespace kernel { static_assert(M::is_metric, "M must be a metric class"); static constexpr auto D = M::Dim; - static_assert((F == FldsID::Rho) || (F == FldsID::Charge) || - (F == FldsID::N) || (F == FldsID::Nppc) || (F == FldsID::T), + static_assert(!((S == SimEngine::GRPIC) && (F == FldsID::V)), + "Bulk 
velocity not supported for GRPIC"); + static_assert((F == FldsID::Rho) || (F == FldsID::Charge) || (F == FldsID::N) || + (F == FldsID::Nppc) || (F == FldsID::T) || (F == FldsID::V), "Invalid field ID"); const unsigned short c1, c2; scatter_ndfield_t Buff; - const unsigned short buff_idx; + const idx_t buff_idx; const array_t i1, i2, i3; const array_t dx1, dx2, dx3; const array_t ux1, ux2, ux3; @@ -58,7 +59,6 @@ namespace kernel { const bool use_weights; const M metric; const int ni2; - const real_t inv_n0; const unsigned short window; const real_t contrib; @@ -68,7 +68,7 @@ namespace kernel { public: ParticleMoments_kernel(const std::vector& components, const scatter_ndfield_t& scatter_buff, - unsigned short buff_idx, + idx_t buff_idx, const array_t& i1, const array_t& i2, const array_t& i3, @@ -86,11 +86,11 @@ namespace kernel { bool use_weights, const M& metric, const boundaries_t& boundaries, - std::size_t ni2, + ncells_t ni2, real_t inv_n0, unsigned short window) - : c1 { (components.size() == 2) ? components[0] - : static_cast(0) } + : c1 { (components.size() > 0) ? components[0] + : static_cast(0) } , c2 { (components.size() == 2) ? 
components[1] : static_cast(0) } , Buff { scatter_buff } @@ -112,11 +112,10 @@ namespace kernel { , use_weights { use_weights } , metric { metric } , ni2 { static_cast(ni2) } - , inv_n0 { inv_n0 } , window { window } , contrib { get_contrib(mass, charge) } - , smooth { ONE / (real_t)(math::pow(TWO * (real_t)window + ONE, - static_cast(D))) } { + , smooth { inv_n0 / (real_t)(math::pow(TWO * (real_t)window + ONE, + static_cast(D))) } { raise::ErrorIf(buff_idx >= N, "Invalid buffer index", HERE); raise::ErrorIf(window > N_GHOSTS, "Window size too large", HERE); raise::ErrorIf(((F == FldsID::Rho) || (F == FldsID::Charge)) && (mass == ZERO), @@ -200,30 +199,117 @@ namespace kernel { coeff *= u_Phys[c - 1]; } } + } else if constexpr (F == FldsID::V) { + real_t gamma { ZERO }; + // for bulk 3vel (tetrad basis) + vec_t u_Phys { ZERO }; + if constexpr (M::CoordType == Coord::Cart) { + u_Phys[0] = ux1(p); + u_Phys[1] = ux2(p); + u_Phys[2] = ux3(p); + } else { + coord_t x_Code { ZERO }; + x_Code[0] = static_cast(i1(p)) + static_cast(dx1(p)); + x_Code[1] = static_cast(i2(p)) + static_cast(dx2(p)); + if constexpr (D == Dim::_3D) { + x_Code[2] = static_cast(i3(p)) + static_cast(dx3(p)); + } else { + x_Code[2] = phi(p); + } + metric.template transform_xyz(x_Code, + { ux1(p), ux2(p), ux3(p) }, + u_Phys); + } + if (mass == ZERO) { + gamma = NORM(u_Phys[0], u_Phys[1], u_Phys[2]); + } else { + gamma = math::sqrt(ONE + NORM_SQR(u_Phys[0], u_Phys[1], u_Phys[2])); + } + // compute the corresponding moment + coeff = (mass == ZERO ? 
ONE : mass) * u_Phys[c1 - 1] / gamma; } else { // for other cases, use the `contrib` defined above coeff = contrib; } + if constexpr (F == FldsID::V) { + real_t gamma { ZERO }; + // for stress-energy tensor + vec_t u_Phys { ZERO }; + if constexpr (S == SimEngine::SRPIC) { + // SR + // stress-energy tensor for SR is computed in the tetrad (hatted) basis + if constexpr (M::CoordType == Coord::Cart) { + u_Phys[0] = ux1(p); + u_Phys[1] = ux2(p); + u_Phys[2] = ux3(p); + } else { + static_assert(D != Dim::_1D, "non-Cartesian SRPIC 1D"); + coord_t x_Code { ZERO }; + x_Code[0] = static_cast(i1(p)) + static_cast(dx1(p)); + x_Code[1] = static_cast(i2(p)) + static_cast(dx2(p)); + if constexpr (D == Dim::_3D) { + x_Code[2] = static_cast(i3(p)) + static_cast(dx3(p)); + } else { + x_Code[2] = phi(p); + } + metric.template transform_xyz( + x_Code, + { ux1(p), ux2(p), ux3(p) }, + u_Phys); + } + if (mass == ZERO) { + gamma = NORM(u_Phys[0], u_Phys[1], u_Phys[2]); + } else { + gamma = math::sqrt(ONE + NORM_SQR(u_Phys[0], u_Phys[1], u_Phys[2])); + } + } else { + // GR + // stress-energy tensor for GR is computed in contravariant basis + static_assert(D != Dim::_1D, "GRPIC 1D"); + coord_t x_Code { ZERO }; + x_Code[0] = static_cast(i1(p)) + static_cast(dx1(p)); + x_Code[1] = static_cast(i2(p)) + static_cast(dx2(p)); + if constexpr (D == Dim::_3D) { + x_Code[2] = static_cast(i3(p)) + static_cast(dx3(p)); + } + vec_t u_Cntrv { ZERO }; + // compute u_i u^i for energy + metric.template transform(x_Code, + { ux1(p), ux2(p), ux3(p) }, + u_Cntrv); + gamma = u_Cntrv[0] * ux1(p) + u_Cntrv[1] * ux2(p) + u_Cntrv[2] * ux3(p); + if (mass == ZERO) { + gamma = math::sqrt(gamma); + } else { + gamma = math::sqrt(ONE + gamma); + } + metric.template transform(x_Code, u_Cntrv, u_Phys); + } + // compute the corresponding moment + coeff = u_Phys[c1 - 1] / gamma; + } + if constexpr (F != FldsID::Nppc) { // for nppc calculation ... // ... 
do not take volume, weights or smoothing into account if constexpr (D == Dim::_1D) { - coeff *= inv_n0 / + coeff *= smooth / metric.sqrt_det_h({ static_cast(i1(p)) + HALF }); } else if constexpr (D == Dim::_2D) { - coeff *= inv_n0 / + coeff *= smooth / metric.sqrt_det_h({ static_cast(i1(p)) + HALF, static_cast(i2(p)) + HALF }); } else if constexpr (D == Dim::_3D) { - coeff *= inv_n0 / + coeff *= smooth / metric.sqrt_det_h({ static_cast(i1(p)) + HALF, static_cast(i2(p)) + HALF, static_cast(i3(p)) + HALF }); } - coeff *= weight(p) * smooth; + if (use_weights) { + coeff *= weight(p); + } } - auto buff_access = Buff.access(); if constexpr (D == Dim::_1D) { for (auto di1 { -window }; di1 <= window; ++di1) { @@ -289,6 +375,79 @@ namespace kernel { } }; + template + class NormalizeVectorByRho_kernel { + const ndfield_t Rho; + ndfield_t Vector; + const unsigned short c_rho, c_v1, c_v2, c_v3; + + public: + NormalizeVectorByRho_kernel(const ndfield_t& rho, + const ndfield_t& vector, + unsigned short crho, + unsigned short cv1, + unsigned short cv2, + unsigned short cv3) + : Rho { rho } + , Vector { vector } + , c_rho { crho } + , c_v1 { cv1 } + , c_v2 { cv2 } + , c_v3 { cv3 } { + raise::ErrorIf(c_rho >= N or c_v1 >= N or c_v2 >= N or c_v3 >= N, + "Invalid component index", + HERE); + raise::ErrorIf(c_rho == c_v1 or c_rho == c_v2 or c_rho == c_v3, + "Invalid component index", + HERE); + raise::ErrorIf(c_v1 == c_v2 or c_v1 == c_v3 or c_v2 == c_v3, + "Invalid component index", + HERE); + } + + Inline void operator()(index_t i1) const { + if constexpr (D == Dim::_1D) { + if (not cmp::AlmostZero(Rho(i1, c_rho))) { + Vector(i1, c_v1) /= Rho(i1, c_rho); + Vector(i1, c_v2) /= Rho(i1, c_rho); + Vector(i1, c_v3) /= Rho(i1, c_rho); + } + } else { + raise::KernelError( + HERE, + "1D implementation of NormalizeVectorByRho_kernel called for non-1D"); + } + } + + Inline void operator()(index_t i1, index_t i2) const { + if constexpr (D == Dim::_2D) { + if (not cmp::AlmostZero(Rho(i1, i2, 
c_rho))) { + Vector(i1, i2, c_v1) /= Rho(i1, i2, c_rho); + Vector(i1, i2, c_v2) /= Rho(i1, i2, c_rho); + Vector(i1, i2, c_v3) /= Rho(i1, i2, c_rho); + } + } else { + raise::KernelError( + HERE, + "2D implementation of NormalizeVectorByRho_kernel called for non-2D"); + } + } + + Inline void operator()(index_t i1, index_t i2, index_t i3) const { + if constexpr (D == Dim::_3D) { + if (not cmp::AlmostZero(Rho(i1, i2, i3, c_rho))) { + Vector(i1, i2, i3, c_v1) /= Rho(i1, i2, i3, c_rho); + Vector(i1, i2, i3, c_v2) /= Rho(i1, i2, i3, c_rho); + Vector(i1, i2, i3, c_v3) /= Rho(i1, i2, i3, c_rho); + } + } else { + raise::KernelError( + HERE, + "3D implementation of NormalizeVectorByRho_kernel called for non-3D"); + } + } + }; + } // namespace kernel #endif // KERNELS_PARTICLE_MOMENTS_HPP diff --git a/src/kernels/particle_pusher_gr.hpp b/src/kernels/particle_pusher_gr.hpp index 547463fa7..c1dfdf949 100644 --- a/src/kernels/particle_pusher_gr.hpp +++ b/src/kernels/particle_pusher_gr.hpp @@ -29,7 +29,7 @@ /* Local macros */ /* -------------------------------------------------------------------------- */ #define from_Xi_to_i(XI, I) \ - { I = static_cast((XI)); } + { I = static_cast((XI + 1)) - 1; } #define from_Xi_to_i_di(XI, I, DI) \ { \ @@ -65,6 +65,7 @@ namespace kernel::gr { static_assert(M::is_metric, "M must be a metric class"); static constexpr auto D = M::Dim; + private: const randacc_ndfield_t DB; const randacc_ndfield_t DB0; array_t i1, i2, i3; @@ -76,44 +77,43 @@ namespace kernel::gr { array_t tag; const M metric; - const real_t coeff, dt; - const int ni1, ni2, ni3; - const real_t epsilon; - const int niter; - const int i1_absorb; + const real_t coeff, dt; + const int ni1, ni2, ni3; + const real_t epsilon; + const unsigned short niter; bool is_axis_i2min { false }, is_axis_i2max { false }; bool is_absorb_i1min { false }, is_absorb_i1max { false }; public: - Pusher_kernel(const ndfield_t& DB, - const ndfield_t& DB0, - const array_t& i1, - const array_t& i2, - const 
array_t& i3, - const array_t& i1_prev, - const array_t& i2_prev, - const array_t& i3_prev, - const array_t& dx1, - const array_t& dx2, - const array_t& dx3, - const array_t& dx1_prev, - const array_t& dx2_prev, - const array_t& dx3_prev, - const array_t& ux1, - const array_t& ux2, - const array_t& ux3, - const array_t& phi, - const array_t& tag, - const M& metric, - const real_t& coeff, - const real_t& dt, - const int& ni1, - const int& ni2, - const int& ni3, - const real_t& epsilon, - const int& niter, - const boundaries_t& boundaries) + Pusher_kernel(const ndfield_t& DB, + const ndfield_t& DB0, + array_t& i1, + array_t& i2, + array_t& i3, + array_t& i1_prev, + array_t& i2_prev, + array_t& i3_prev, + array_t& dx1, + array_t& dx2, + array_t& dx3, + array_t& dx1_prev, + array_t& dx2_prev, + array_t& dx3_prev, + array_t& ux1, + array_t& ux2, + array_t& ux3, + array_t& phi, + array_t& tag, + const M& metric, + real_t coeff, + real_t dt, + int ni1, + int ni2, + int ni3, + const real_t& epsilon, + const unsigned short& niter, + const boundaries_t& boundaries) : DB { DB } , DB0 { DB0 } , i1 { i1 } @@ -140,10 +140,7 @@ namespace kernel::gr { , ni2 { ni2 } , ni3 { ni3 } , epsilon { epsilon } - , niter { niter } - , i1_absorb { static_cast(metric.template convert<1, Crd::Ph, Crd::Cd>( - metric.rhorizon())) - - 5 } { + , niter { niter } { raise::ErrorIf(boundaries.size() < 2, "boundaries defined incorrectly", HERE); is_absorb_i1min = (boundaries[0].first == PrtlBC::ABSORB) || @@ -333,32 +330,48 @@ namespace kernel::gr { // find contravariant midpoint velocity metric.template transform(xp, vp_mid, vp_mid_cntrv); - // find Gamma / alpha at midpoint + // find Gamma / alpha at midpointΡ‹ real_t u0 { computeGamma(T {}, vp_mid, vp_mid_cntrv) / metric.alpha(xp) }; // find updated velocity - vp_upd[0] = - vp[0] + - dt * - (-metric.alpha(xp) * u0 * DERIVATIVE_IN_R(metric.alpha, xp) + - vp_mid[0] * DERIVATIVE_IN_R(metric.beta1, xp) - - (HALF / u0) * - (DERIVATIVE_IN_R((metric.template 
h<1, 1>), xp) * SQR(vp_mid[0]) + - DERIVATIVE_IN_R((metric.template h<2, 2>), xp) * SQR(vp_mid[1]) + - DERIVATIVE_IN_R((metric.template h<3, 3>), xp) * SQR(vp_mid[2]) + - TWO * DERIVATIVE_IN_R((metric.template h<1, 3>), xp) * - vp_mid[0] * vp_mid[2])); - vp_upd[1] = - vp[1] + - dt * - (-metric.alpha(xp) * u0 * DERIVATIVE_IN_TH(alpha, xp) + - vp_mid[1] * DERIVATIVE_IN_TH(beta1, xp) - - (HALF / u0) * - (DERIVATIVE_IN_TH((metric.template h<1, 1>), xp) * SQR(vp_mid[0]) + - DERIVATIVE_IN_TH((metric.template h<2, 2>), xp) * SQR(vp_mid[1]) + - DERIVATIVE_IN_TH((metric.template h<3, 3>), xp) * SQR(vp_mid[2]) + - TWO * DERIVATIVE_IN_TH((metric.template h<1, 3>), xp) * - vp_mid[0] * vp_mid[2])); + // vp_upd[0] = + // vp[0] + + // dt * + // (-metric.alpha(xp) * u0 * DERIVATIVE_IN_R(metric.alpha, xp) + + // vp_mid[0] * DERIVATIVE_IN_R(metric.beta1, xp) - + // (HALF / u0) * + // (DERIVATIVE_IN_R((metric.template h<1, 1>), xp) * SQR(vp_mid[0]) + + // DERIVATIVE_IN_R((metric.template h<2, 2>), xp) * SQR(vp_mid[1]) + + // DERIVATIVE_IN_R((metric.template h<3, 3>), xp) * SQR(vp_mid[2]) + + // TWO * DERIVATIVE_IN_R((metric.template h<1, 3>), xp) * + // vp_mid[0] * vp_mid[2])); + // vp_upd[1] = + // vp[1] + + // dt * + // (-metric.alpha(xp) * u0 * DERIVATIVE_IN_TH(metric.alpha, xp) + + // vp_mid[0] * DERIVATIVE_IN_TH(metric.beta1, xp) - + // (HALF / u0) * + // (DERIVATIVE_IN_TH((metric.template h<1, 1>), xp) * SQR(vp_mid[0]) + + // DERIVATIVE_IN_TH((metric.template h<2, 2>), xp) * SQR(vp_mid[1]) + + // DERIVATIVE_IN_TH((metric.template h<3, 3>), xp) * SQR(vp_mid[2]) + + // TWO * DERIVATIVE_IN_TH((metric.template h<1, 3>), xp) * + // vp_mid[0] * vp_mid[2])); + vp_upd[0] = vp[0] + + dt * (-metric.alpha(xp) * u0 * metric.dr_alpha(xp) + + vp_mid[0] * metric.dr_beta1(xp) - + (HALF / u0) * + (metric.dr_h11(xp) * SQR(vp_mid[0]) + + metric.dr_h22(xp) * SQR(vp_mid[1]) + + metric.dr_h33(xp) * SQR(vp_mid[2]) + + TWO * metric.dr_h13(xp) * vp_mid[0] * vp_mid[2])); + vp_upd[1] = vp[1] + + dt * 
(-metric.alpha(xp) * u0 * metric.dt_alpha(xp) + + vp_mid[0] * metric.dt_beta1(xp) - + (HALF / u0) * + (metric.dt_h11(xp) * SQR(vp_mid[0]) + + metric.dt_h22(xp) * SQR(vp_mid[1]) + + metric.dt_h33(xp) * SQR(vp_mid[2]) + + TWO * metric.dt_h13(xp) * vp_mid[0] * vp_mid[2])); } } else if constexpr (D == Dim::_3D) { raise::KernelNotImplementedError(HERE); @@ -455,7 +468,7 @@ namespace kernel::gr { dt * (-metric.alpha(xp_mid) * u0 * DERIVATIVE_IN_TH(metric.alpha, xp_mid) + - vp_mid[1] * DERIVATIVE_IN_TH(metric.beta1, xp_mid) - + vp_mid[0] * DERIVATIVE_IN_TH(metric.beta1, xp_mid) - (HALF / u0) * (DERIVATIVE_IN_TH((metric.template h<1, 1>), xp_mid) * SQR(vp_mid[0]) + @@ -659,6 +672,20 @@ namespace kernel::gr { xp[0] = i_di_to_Xi(i1(p), dx1(p)); xp[1] = i_di_to_Xi(i2(p), dx2(p)); + coord_t xp_ { ZERO }; + xp_[0] = xp[0]; + real_t theta_Cd { xp[1] }; + const real_t theta_Ph { metric.template convert<2, Crd::Cd, Crd::Ph>( + theta_Cd) }; + const real_t small_angle { constant::SMALL_ANGLE_GR }; + const auto large_angle { constant::PI - small_angle }; + if (theta_Ph < small_angle) { + theta_Cd = metric.template convert<2, Crd::Ph, Crd::Cd>(small_angle); + } else if (theta_Ph >= large_angle) { + theta_Cd = metric.template convert<2, Crd::Ph, Crd::Cd>(large_angle); + } + xp_[1] = theta_Cd; + vec_t Dp_cntrv { ZERO }, Bp_cntrv { ZERO }, Dp_hat { ZERO }, Bp_hat { ZERO }; interpolateFields(p, Dp_cntrv, Bp_cntrv); @@ -675,7 +702,7 @@ namespace kernel::gr { vp[0] = vp_upd[0]; vp[1] = vp_upd[1]; vp[2] = vp_upd[2]; - GeodesicMomentumPush(Massive_t {}, xp, vp, vp_upd); + GeodesicMomentumPush(Massive_t {}, xp_, vp, vp_upd); /* u**_i(n) -> u_i(n + 1/2) */ vp[0] = vp_upd[0]; vp[1] = vp_upd[1]; @@ -718,19 +745,23 @@ namespace kernel::gr { template Inline void Pusher_kernel::boundaryConditions(index_t& p) const { if constexpr (D == Dim::_1D || D == Dim::_2D || D == Dim::_3D) { - if (i1(p) < i1_absorb && is_absorb_i1min) { + if (i1(p) < 0 && is_absorb_i1min) { tag(p) = ParticleTag::dead; } else if 
(i1(p) >= ni1 && is_absorb_i1max) { tag(p) = ParticleTag::dead; } } if constexpr (D == Dim::_2D || D == Dim::_3D) { - if (i2(p) < 1) { + if (i2(p) < 0) { if (is_axis_i2min) { + i2(p) = 0; + dx2(p) = ONE - dx2(p); ux2(p) = -ux2(p); } - } else if (i2(p) >= ni2 - 1) { - if (is_axis_i2min) { + } else if (i2(p) >= ni2) { + if (is_axis_i2max) { + i2(p) = ni2 - 1; + dx2(p) = ONE - dx2(p); ux2(p) = -ux2(p); } } diff --git a/src/kernels/particle_pusher_sr.hpp b/src/kernels/particle_pusher_sr.hpp index 0deb73c6f..91bc6a760 100644 --- a/src/kernels/particle_pusher_sr.hpp +++ b/src/kernels/particle_pusher_sr.hpp @@ -90,7 +90,7 @@ namespace kernel::sr { Force(const F& pgen_force) : Force { pgen_force, - {ZERO, ZERO, ZERO}, + { ZERO, ZERO, ZERO }, ZERO, ZERO } { @@ -102,10 +102,10 @@ namespace kernel::sr { raise::ErrorIf(ExtForce, "External force not provided", HERE); } - Inline auto fx1(const unsigned short& sp, - const real_t& time, - bool ext_force, - const coord_t& x_Ph) const -> real_t { + Inline auto fx1(const spidx_t& sp, + const simtime_t& time, + bool ext_force, + const coord_t& x_Ph) const -> real_t { real_t f_x1 = ZERO; if constexpr (ExtForce) { if (ext_force) { @@ -128,10 +128,10 @@ namespace kernel::sr { return f_x1; } - Inline auto fx2(const unsigned short& sp, - const real_t& time, - bool ext_force, - const coord_t& x_Ph) const -> real_t { + Inline auto fx2(const spidx_t& sp, + const simtime_t& time, + bool ext_force, + const coord_t& x_Ph) const -> real_t { real_t f_x2 = ZERO; if constexpr (ExtForce) { if (ext_force) { @@ -154,10 +154,10 @@ namespace kernel::sr { return f_x2; } - Inline auto fx3(const unsigned short& sp, - const real_t& time, - bool ext_force, - const coord_t& x_Ph) const -> real_t { + Inline auto fx3(const spidx_t& sp, + const simtime_t& time, + bool ext_force, + const coord_t& x_Ph) const -> real_t { real_t f_x3 = ZERO; if constexpr (ExtForce) { if (ext_force) { @@ -198,7 +198,7 @@ namespace kernel::sr { const CoolingTags cooling; const 
randacc_ndfield_t EB; - const unsigned short sp; + const spidx_t sp; array_t i1, i2, i3; array_t i1_prev, i2_prev, i3_prev; array_t dx1, dx2, dx3; @@ -227,41 +227,41 @@ namespace kernel::sr { const real_t coeff_sync; public: - Pusher_kernel(const PrtlPusher::type& pusher, - bool GCA, - bool ext_force, - CoolingTags cooling, - const ndfield_t& EB, - unsigned short sp, - array_t& i1, - array_t& i2, - array_t& i3, - array_t& i1_prev, - array_t& i2_prev, - array_t& i3_prev, - array_t& dx1, - array_t& dx2, - array_t& dx3, - array_t& dx1_prev, - array_t& dx2_prev, - array_t& dx3_prev, - array_t& ux1, - array_t& ux2, - array_t& ux3, - array_t& phi, - array_t& tag, - const M& metric, - const F& force, - real_t time, - real_t coeff, - real_t dt, - int ni1, - int ni2, - int ni3, - const boundaries_t& boundaries, - real_t gca_larmor_max, - real_t gca_eovrb_max, - real_t coeff_sync) + Pusher_kernel(const PrtlPusher::type& pusher, + bool GCA, + bool ext_force, + CoolingTags cooling, + const randacc_ndfield_t& EB, + spidx_t sp, + array_t& i1, + array_t& i2, + array_t& i3, + array_t& i1_prev, + array_t& i2_prev, + array_t& i3_prev, + array_t& dx1, + array_t& dx2, + array_t& dx3, + array_t& dx1_prev, + array_t& dx2_prev, + array_t& dx3_prev, + array_t& ux1, + array_t& ux2, + array_t& ux3, + array_t& phi, + array_t& tag, + const M& metric, + const F& force, + real_t time, + real_t coeff, + real_t dt, + int ni1, + int ni2, + int ni3, + const boundaries_t& boundaries, + real_t gca_larmor_max, + real_t gca_eovrb_max, + real_t coeff_sync) : pusher { pusher } , GCA { GCA } , ext_force { ext_force } @@ -336,7 +336,7 @@ namespace kernel::sr { bool ext_force, CoolingTags cooling, const ndfield_t& EB, - unsigned short sp, + spidx_t sp, array_t& i1, array_t& i2, array_t& i3, @@ -355,7 +355,7 @@ namespace kernel::sr { array_t& phi, array_t& tag, const M& metric, - real_t time, + simtime_t time, real_t coeff, real_t dt, int ni1, @@ -562,45 +562,85 @@ namespace kernel::sr { Inline void 
posUpd(bool massive, index_t& p, coord_t& xp) const { // get cartesian velocity - const real_t inv_energy { - massive ? ONE / math::sqrt(ONE + SQR(ux1(p)) + SQR(ux2(p)) + SQR(ux3(p))) - : ONE / math::sqrt(SQR(ux1(p)) + SQR(ux2(p)) + SQR(ux3(p))) - }; - vec_t vp_Cart { ux1(p) * inv_energy, - ux2(p) * inv_energy, - ux3(p) * inv_energy }; - // get cartesian position - coord_t xp_Cart { ZERO }; - metric.template convert_xyz(xp, xp_Cart); - // update cartesian position - for (auto d = 0u; d < M::PrtlDim; ++d) { - xp_Cart[d] += vp_Cart[d] * dt; - } - // transform back to code - metric.template convert_xyz(xp_Cart, xp); - - // update x1 - if constexpr (D == Dim::_1D || D == Dim::_2D || D == Dim::_3D) { - i1_prev(p) = i1(p); - dx1_prev(p) = dx1(p); - from_Xi_to_i_di(xp[0], i1(p), dx1(p)); - } + if constexpr (M::CoordType == Coord::Cart) { + // i+di push for Cartesian basis + const real_t dt_inv_energy { + massive + ? (dt / math::sqrt(ONE + SQR(ux1(p)) + SQR(ux2(p)) + SQR(ux3(p)))) + : (dt / math::sqrt(SQR(ux1(p)) + SQR(ux2(p)) + SQR(ux3(p)))) + }; + if constexpr (D == Dim::_1D || D == Dim::_2D || D == Dim::_3D) { + i1_prev(p) = i1(p); + dx1_prev(p) = dx1(p); + dx1(p) += metric.template transform<1, Idx::XYZ, Idx::U>(xp, ux1(p)) * + dt_inv_energy; + i1(p) += static_cast(dx1(p) >= ONE) - + static_cast(dx1(p) < ZERO); + dx1(p) -= (dx1(p) >= ONE); + dx1(p) += (dx1(p) < ZERO); + } + if constexpr (D == Dim::_2D || D == Dim::_3D) { + i2_prev(p) = i2(p); + dx2_prev(p) = dx2(p); + dx2(p) += metric.template transform<2, Idx::XYZ, Idx::U>(xp, ux2(p)) * + dt_inv_energy; + i2(p) += static_cast(dx2(p) >= ONE) - + static_cast(dx2(p) < ZERO); + dx2(p) -= (dx2(p) >= ONE); + dx2(p) += (dx2(p) < ZERO); + } + if constexpr (D == Dim::_3D) { + i3_prev(p) = i3(p); + dx3_prev(p) = dx3(p); + dx3(p) += metric.template transform<3, Idx::XYZ, Idx::U>(xp, ux3(p)) * + dt_inv_energy; + i3(p) += static_cast(dx3(p) >= ONE) - + static_cast(dx3(p) < ZERO); + dx3(p) -= (dx3(p) >= ONE); + dx3(p) += (dx3(p) < 
ZERO); + } + } else { + // full Cartesian coordinate push in non-Cartesian basis + const real_t inv_energy { + massive ? ONE / math::sqrt(ONE + SQR(ux1(p)) + SQR(ux2(p)) + SQR(ux3(p))) + : ONE / math::sqrt(SQR(ux1(p)) + SQR(ux2(p)) + SQR(ux3(p))) + }; + vec_t vp_Cart { ux1(p) * inv_energy, + ux2(p) * inv_energy, + ux3(p) * inv_energy }; + // get cartesian position + coord_t xp_Cart { ZERO }; + metric.template convert_xyz(xp, xp_Cart); + // update cartesian position + for (auto d = 0u; d < M::PrtlDim; ++d) { + xp_Cart[d] += vp_Cart[d] * dt; + } + // transform back to code + metric.template convert_xyz(xp_Cart, xp); + + // update x1 + if constexpr (D == Dim::_1D || D == Dim::_2D || D == Dim::_3D) { + i1_prev(p) = i1(p); + dx1_prev(p) = dx1(p); + from_Xi_to_i_di(xp[0], i1(p), dx1(p)); + } - // update x2 & phi - if constexpr (D == Dim::_2D || D == Dim::_3D) { - i2_prev(p) = i2(p); - dx2_prev(p) = dx2(p); - from_Xi_to_i_di(xp[1], i2(p), dx2(p)); - if constexpr (D == Dim::_2D && M::PrtlDim == Dim::_3D) { - phi(p) = xp[2]; + // update x2 & phi + if constexpr (D == Dim::_2D || D == Dim::_3D) { + i2_prev(p) = i2(p); + dx2_prev(p) = dx2(p); + from_Xi_to_i_di(xp[1], i2(p), dx2(p)); + if constexpr (D == Dim::_2D && M::PrtlDim == Dim::_3D) { + phi(p) = xp[2]; + } } - } - // update x3 - if constexpr (D == Dim::_3D) { - i3_prev(p) = i3(p); - dx3_prev(p) = dx3(p); - from_Xi_to_i_di(xp[2], i3(p), dx3(p)); + // update x3 + if constexpr (D == Dim::_3D) { + i3_prev(p) = i3(p); + dx3_prev(p) = dx3(p); + from_Xi_to_i_di(xp[2], i3(p), dx3(p)); + } } boundaryConditions(p, xp); } diff --git a/src/kernels/prtls_to_phys.hpp b/src/kernels/prtls_to_phys.hpp index f12eefc95..4dd7d88b0 100644 --- a/src/kernels/prtls_to_phys.hpp +++ b/src/kernels/prtls_to_phys.hpp @@ -31,7 +31,7 @@ namespace kernel { static constexpr Dimension D = M::Dim; protected: - const std::size_t stride; + const npart_t stride; array_t buff_x1; array_t buff_x2; array_t buff_x3; @@ -47,7 +47,7 @@ namespace kernel { const M 
metric; public: - PrtlToPhys_kernel(std::size_t stride, + PrtlToPhys_kernel(npart_t stride, array_t& buff_x1, array_t& buff_x2, array_t& buff_x3, diff --git a/src/kernels/reduced_stats.hpp b/src/kernels/reduced_stats.hpp new file mode 100644 index 000000000..68e6aa97b --- /dev/null +++ b/src/kernels/reduced_stats.hpp @@ -0,0 +1,551 @@ +/** + * @file kernels/reduced_stats.hpp + * @brief Compute reduced field/moment quantities for stats output + * @implements + * - kernel::PrtlToPhys_kernel<> + * @namespaces: + * - kernel:: + */ + +#ifndef KERNELS_REDUCED_STATS_HPP +#define KERNELS_REDUCED_STATS_HPP + +#include "enums.h" +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/numeric.h" + +namespace kernel { + using namespace ntt; + + template + class ReducedFields_kernel { + static_assert(M::is_metric, "M must be a metric class"); + static_assert(I <= 3, + "I must be less than or equal to 3 for ReducedFields_kernel"); + static constexpr auto D = M::Dim; + + ndfield_t EM; + ndfield_t J; + const M metric; + + public: + ReducedFields_kernel(const ndfield_t& EM, + const ndfield_t& J, + const M& metric) + : EM { EM } + , J { J } + , metric { metric } {} + + Inline void operator()(index_t i1, real_t& buff) const { + const auto i1_ = COORD(i1); + if constexpr (F == StatsID::B2) { + if constexpr (I == 1) { + const auto b1_u = EM(i1, em::bx1); + const auto b1_d = metric.template transform<1, Idx::U, Idx::D>({ i1_ }, + b1_u); + buff += b1_u * b1_d * metric.sqrt_det_h({ i1_ }); + } else if constexpr (I == 2) { + const auto b2_u = EM(i1, em::bx2); + const auto b2_d = metric.template transform<2, Idx::U, Idx::D>( + { i1_ + HALF }, + b2_u); + buff += b2_u * b2_d * metric.sqrt_det_h({ i1_ + HALF }); + } else { + const auto b3_u = EM(i1, em::bx3); + const auto b3_d = metric.template transform<3, Idx::U, Idx::D>( + { i1_ + HALF }, + b3_u); + buff += b3_u * b3_d * metric.sqrt_det_h({ i1_ + HALF }); + } + } else if constexpr (F == StatsID::E2) { + if constexpr (I == 
1) { + const auto e1_u = EM(i1, em::ex1); + const auto e1_d = metric.template transform<1, Idx::U, Idx::D>( + { i1_ + HALF }, + e1_u); + buff += e1_u * e1_d * metric.sqrt_det_h({ i1_ + HALF }); + } else if constexpr (I == 2) { + const auto e2_u = EM(i1, em::ex2); + const auto e2_d = metric.template transform<2, Idx::U, Idx::D>({ i1_ }, + e2_u); + buff += e2_u * e2_d * metric.sqrt_det_h({ i1_ }); + } else { + const auto e3_u = EM(i1, em::ex3); + const auto e3_d = metric.template transform<3, Idx::U, Idx::D>({ i1_ }, + e3_u); + buff += e3_u * e3_d * metric.sqrt_det_h({ i1_ }); + } + } else if constexpr (F == StatsID::ExB) { + if constexpr (I == 1) { + const auto e2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF }, + HALF * (EM(i1, em::ex2) + EM(i1 + 1, em::ex2))); + const auto e3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF }, + HALF * (EM(i1, em::ex3) + EM(i1 + 1, em::ex3))); + const auto b2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF }, + EM(i1, em::bx2)); + const auto b3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF }, + EM(i1, em::bx3)); + buff += (e2_t * b3_t - e3_t * b2_t) * metric.sqrt_det_h({ i1_ + HALF }); + } else if constexpr (I == 2) { + const auto e1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF }, + EM(i1, em::ex1)); + const auto e3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF }, + HALF * (EM(i1, em::ex3) + EM(i1 + 1, em::ex3))); + const auto b1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF }, + HALF * (EM(i1, em::bx1) + EM(i1 + 1, em::bx1))); + const auto b3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF }, + EM(i1, em::bx3)); + buff += (e3_t * b1_t - e1_t * b3_t) * metric.sqrt_det_h({ i1_ + HALF }); + } else { + const auto e1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF }, + EM(i1, em::ex1)); + const auto e2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF }, + 
HALF * (EM(i1, em::ex2) + EM(i1 + 1, em::ex2))); + const auto b1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF }, + HALF * (EM(i1, em::bx1) + EM(i1 + 1, em::bx1))); + const auto b2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF }, + EM(i1, em::bx2)); + buff += (e1_t * b2_t - e2_t * b1_t) * metric.sqrt_det_h({ i1_ + HALF }); + } + } else if constexpr (F == StatsID::JdotE) { + vec_t e_t { ZERO }; + vec_t j_t { ZERO }; + metric.template transform( + { i1_ + HALF }, + { EM(i1, em::ex1), + HALF * (EM(i1, em::ex2) + EM(i1 + 1, em::ex2)), + HALF * (EM(i1, em::ex3) + EM(i1 + 1, em::ex3)) }, + e_t); + metric.template transform( + { i1_ + HALF }, + { J(i1, cur::jx1), + HALF * (J(i1, cur::jx2) + J(i1 + 1, cur::jx2)), + HALF * (J(i1, cur::jx3) + J(i1 + 1, cur::jx3)) }, + j_t); + buff += (e_t[0] * j_t[0] + e_t[1] * j_t[1] + e_t[2] * j_t[2]) * + metric.sqrt_det_h({ i1_ + HALF }); + } + } + + Inline void operator()(index_t i1, index_t i2, real_t& buff) const { + const auto i1_ = COORD(i1); + const auto i2_ = COORD(i2); + if constexpr (F == StatsID::B2) { + if constexpr (I == 1) { + const auto b1_u = EM(i1, i2, em::bx1); + const auto b1_d = metric.template transform<1, Idx::U, Idx::D>( + { i1_, i2_ + HALF }, + b1_u); + buff += b1_u * b1_d * metric.sqrt_det_h({ i1_, i2_ + HALF }); + } else if constexpr (I == 2) { + const auto b2_u = EM(i1, i2, em::bx2); + const auto b2_d = metric.template transform<2, Idx::U, Idx::D>( + { i1_ + HALF, i2_ }, + b2_u); + buff += b2_u * b2_d * metric.sqrt_det_h({ i1_ + HALF, i2_ }); + } else { + const auto b3_u = EM(i1, i2, em::bx3); + const auto b3_d = metric.template transform<3, Idx::U, Idx::D>( + { i1_ + HALF, i2_ + HALF }, + b3_u); + buff += b3_u * b3_d * metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF }); + } + } else if constexpr (F == StatsID::E2) { + if constexpr (I == 1) { + const auto e1_u = EM(i1, i2, em::ex1); + const auto e1_d = metric.template transform<1, Idx::U, Idx::D>( + { i1_ + HALF, i2_ }, + e1_u); + 
buff += e1_u * e1_d * metric.sqrt_det_h({ i1_ + HALF, i2_ }); + } else if constexpr (I == 2) { + const auto e2_u = EM(i1, i2, em::ex2); + const auto e2_d = metric.template transform<2, Idx::U, Idx::D>( + { i1_, i2_ + HALF }, + e2_u); + buff += e2_u * e2_d * metric.sqrt_det_h({ i1_, i2_ + HALF }); + } else { + const auto e3_u = EM(i1, i2, em::ex3); + const auto e3_d = metric.template transform<3, Idx::U, Idx::D>( + { i1_, i2_ }, + e3_u); + buff += e3_u * e3_d * metric.sqrt_det_h({ i1_, i2_ }); + } + } else if constexpr (F == StatsID::ExB) { + if constexpr (I == 1) { + const auto e2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + HALF * (EM(i1, i2, em::ex2) + EM(i1 + 1, i2, em::ex2))); + const auto e3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + INV_4 * (EM(i1, i2, em::ex3) + EM(i1 + 1, i2, em::ex3) + + EM(i1, i2 + 1, em::ex3) + EM(i1 + 1, i2 + 1, em::ex3))); + const auto b2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + HALF * (EM(i1, i2, em::bx2) + EM(i1, i2 + 1, em::bx2))); + const auto b3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + EM(i1, i2, em::bx3)); + buff += (e2_t * b3_t - e3_t * b2_t) * + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF }); + } else if constexpr (I == 2) { + const auto e1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + HALF * (EM(i1, i2, em::ex1) + EM(i1, i2 + 1, em::ex1))); + const auto e3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + INV_4 * (EM(i1, i2, em::ex3) + EM(i1 + 1, i2, em::ex3) + + EM(i1, i2 + 1, em::ex3) + EM(i1 + 1, i2 + 1, em::ex3))); + const auto b1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + HALF * (EM(i1, i2, em::bx1) + EM(i1 + 1, i2, em::bx1))); + const auto b3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + EM(i1, i2, em::bx3)); + buff += (e3_t * 
b1_t - e1_t * b3_t) * + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF }); + } else { + const auto e1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + HALF * (EM(i1, i2, em::ex1) + EM(i1, i2 + 1, em::ex1))); + const auto e2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + HALF * (EM(i1, i2, em::ex2) + EM(i1 + 1, i2, em::ex2))); + const auto b1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + HALF * (EM(i1, i2, em::bx1) + EM(i1 + 1, i2, em::bx1))); + const auto b2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF }, + HALF * (EM(i1, i2, em::bx2) + EM(i1, i2 + 1, em::bx2))); + buff += (e1_t * b2_t - e2_t * b1_t) * + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF }); + } + } else if constexpr (F == StatsID::JdotE) { + vec_t e_t { ZERO }; + vec_t j_t { ZERO }; + metric.template transform( + { i1_ + HALF, i2_ + HALF }, + { HALF * (EM(i1, i2, em::ex1) + EM(i1, i2 + 1, em::ex1)), + HALF * (EM(i1, i2, em::ex2) + EM(i1 + 1, i2, em::ex2)), + INV_4 * (EM(i1, i2, em::ex3) + EM(i1 + 1, i2, em::ex3) + + EM(i1, i2 + 1, em::ex3) + EM(i1 + 1, i2 + 1, em::ex3)) }, + e_t); + metric.template transform( + { i1_ + HALF, i2_ + HALF }, + { HALF * (J(i1, i2, cur::jx1) + J(i1, i2 + 1, cur::jx1)), + HALF * (J(i1, i2, cur::jx2) + J(i1 + 1, i2, cur::jx2)), + INV_4 * (J(i1, i2, cur::jx3) + J(i1 + 1, i2, cur::jx3) + + J(i1, i2 + 1, cur::jx3) + J(i1 + 1, i2 + 1, cur::jx3)) }, + j_t); + buff += (e_t[0] * j_t[0] + e_t[1] * j_t[1] + e_t[2] * j_t[2]) * + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF }); + } + } + + Inline void operator()(index_t i1, index_t i2, index_t i3, real_t& buff) const { + const auto i1_ = COORD(i1); + const auto i2_ = COORD(i2); + const auto i3_ = COORD(i3); + if constexpr (F == StatsID::B2) { + if constexpr (I == 1) { + const auto b1_u = EM(i1, i2, i3, em::bx1); + const auto b1_d = metric.template transform<1, Idx::U, Idx::D>( + { i1_, i2_ + HALF, i3_ + HALF }, 
+ b1_u); + buff += b1_u * b1_d * metric.sqrt_det_h({ i1_, i2_ + HALF, i3_ + HALF }); + } else if constexpr (I == 2) { + const auto b2_u = EM(i1, i2, i3, em::bx2); + const auto b2_d = metric.template transform<2, Idx::U, Idx::D>( + { i1_ + HALF, i2_, i3_ + HALF }, + b2_u); + buff += b2_u * b2_d * metric.sqrt_det_h({ i1_ + HALF, i2_, i3_ + HALF }); + } else { + const auto b3_u = EM(i1, i2, i3, em::bx3); + const auto b3_d = metric.template transform<3, Idx::U, Idx::D>( + { i1_ + HALF, i2_ + HALF, i3_ }, + b3_u); + buff += b3_u * b3_d * metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF, i3_ }); + } + } else if constexpr (F == StatsID::E2) { + if constexpr (I == 1) { + const auto e1_u = EM(i1, i2, i3, em::ex1); + const auto e1_d = metric.template transform<1, Idx::U, Idx::D>( + { i1_ + HALF, i2_, i3_ }, + e1_u); + buff += e1_u * e1_d * metric.sqrt_det_h({ i1_ + HALF, i2_, i3_ }); + } else if constexpr (I == 2) { + const auto e2_u = EM(i1, i2, i3, em::ex2); + const auto e2_d = metric.template transform<2, Idx::U, Idx::D>( + { i1_, i2_ + HALF, i3_ }, + e2_u); + buff += e2_u * e2_d * metric.sqrt_det_h({ i1_, i2_ + HALF, i3_ }); + } else { + const auto e3_u = EM(i1, i2, i3, em::ex3); + const auto e3_d = metric.template transform<3, Idx::U, Idx::D>( + { i1_, i2_, i3_ + HALF }, + e3_u); + buff += e3_u * e3_d * metric.sqrt_det_h({ i1_, i2_, i3_ + HALF }); + } + } else if constexpr (F == StatsID::ExB) { + if constexpr (I == 1) { + const auto e2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + INV_4 * + (EM(i1, i2, i3, em::ex2) + EM(i1 + 1, i2, i3, em::ex2) + + EM(i1, i2, i3 + 1, em::ex2) + EM(i1 + 1, i2, i3 + 1, em::ex2))); + const auto e3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + INV_4 * + (EM(i1, i2, i3, em::ex3) + EM(i1 + 1, i2, i3, em::ex3) + + EM(i1, i2 + 1, i3, em::ex3) + EM(i1 + 1, i2 + 1, i3, em::ex3))); + const auto b2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + 
HALF, i2_ + HALF, i3_ + HALF }, + HALF * (EM(i1, i2, i3, em::bx2) + EM(i1, i2 + 1, i3, em::bx2))); + const auto b3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + HALF * (EM(i1, i2, i3, em::bx3) + EM(i1, i2, i3 + 1, em::bx3))); + buff += (e2_t * b3_t - e3_t * b2_t) * + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF, i3_ + HALF }); + } else if constexpr (I == 2) { + const auto e1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + INV_4 * + (EM(i1, i2, i3, em::ex1) + EM(i1, i2 + 1, i3, em::ex1) + + EM(i1, i2, i3 + 1, em::ex1) + EM(i1, i2 + 1, i3 + 1, em::ex1))); + const auto e3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + INV_4 * + (EM(i1, i2, i3, em::ex3) + EM(i1 + 1, i2, i3, em::ex3) + + EM(i1, i2 + 1, i3, em::ex3) + EM(i1 + 1, i2 + 1, i3, em::ex3))); + const auto b1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + HALF * (EM(i1, i2, i3, em::bx1) + EM(i1 + 1, i2, i3, em::bx1))); + const auto b3_t = metric.template transform<3, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + HALF * (EM(i1, i2, i3, em::bx3) + EM(i1, i2, i3 + 1, em::bx3))); + buff += (e3_t * b1_t - e1_t * b3_t) * + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF, i3_ + HALF }); + } else { + const auto e1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + INV_4 * + (EM(i1, i2, i3, em::ex1) + EM(i1, i2 + 1, i3, em::ex1) + + EM(i1, i2, i3 + 1, em::ex1) + EM(i1, i2 + 1, i3 + 1, em::ex1))); + const auto e2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + INV_4 * + (EM(i1, i2, i3, em::ex2) + EM(i1 + 1, i2, i3, em::ex2) + + EM(i1, i2, i3 + 1, em::ex2) + EM(i1 + 1, i2, i3 + 1, em::ex2))); + const auto b1_t = metric.template transform<1, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + HALF * (EM(i1, i2, i3, em::bx1) + EM(i1 + 
1, i2, i3, em::bx1))); + const auto b2_t = metric.template transform<2, Idx::U, Idx::T>( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + HALF * (EM(i1, i2, i3, em::bx2) + EM(i1, i2 + 1, i3, em::bx2))); + buff += (e1_t * b2_t - e2_t * b1_t) * + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF, i3_ + HALF }); + } + } else if constexpr (F == StatsID::JdotE) { + vec_t e_t { ZERO }; + vec_t j_t { ZERO }; + metric.template transform( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + { INV_4 * (EM(i1, i2, i3, em::ex1) + EM(i1, i2 + 1, i3, em::ex1) + + EM(i1, i2, i3 + 1, em::ex1) + EM(i1, i2 + 1, i3 + 1, em::ex1)), + INV_4 * (EM(i1, i2, i3, em::ex2) + EM(i1 + 1, i2, i3, em::ex2) + + EM(i1, i2, i3 + 1, em::ex2) + EM(i1 + 1, i2, i3 + 1, em::ex2)), + INV_4 * (EM(i1, i2, i3, em::ex3) + EM(i1 + 1, i2, i3, em::ex3) + + EM(i1, i2 + 1, i3, em::ex3) + EM(i1 + 1, i2 + 1, i3, em::ex3)) }, + e_t); + metric.template transform( + { i1_ + HALF, i2_ + HALF, i3_ + HALF }, + { INV_4 * (J(i1, i2, i3, cur::jx1) + J(i1, i2 + 1, i3, cur::jx1) + + J(i1, i2, i3 + 1, cur::jx1) + J(i1, i2 + 1, i3 + 1, cur::jx1)), + INV_4 * (J(i1, i2, i3, cur::jx2) + J(i1 + 1, i2, i3, cur::jx2) + + J(i1, i2, i3 + 1, cur::jx2) + J(i1 + 1, i2, i3 + 1, cur::jx2)), + INV_4 * (J(i1, i2, i3, cur::jx3) + J(i1 + 1, i2, i3, cur::jx3) + + J(i1, i2 + 1, i3, cur::jx3) + J(i1 + 1, i2 + 1, i3, cur::jx3)) }, + j_t); + buff += (e_t[0] * j_t[0] + e_t[1] * j_t[1] + e_t[2] * j_t[2]) * + metric.sqrt_det_h({ i1_ + HALF, i2_ + HALF, i3_ + HALF }); + } + } + }; + + template + auto get_contrib(float mass, float charge) -> real_t { + if constexpr (P == StatsID::Rho) { + return mass; + } else if constexpr (P == StatsID::Charge) { + return charge; + } else { + return ONE; + } + } + + template + class ReducedParticleMoments_kernel { + static_assert(M::is_metric, "M must be a metric class"); + static constexpr auto D = M::Dim; + + static_assert((P == StatsID::Rho) || (P == StatsID::Charge) || + (P == StatsID::N) || (P == StatsID::Npart) || + (P == StatsID::T), 
+ "Invalid stats ID"); + + const unsigned short c1, c2; + const array_t i1, i2, i3; + const array_t dx1, dx2, dx3; + const array_t ux1, ux2, ux3; + const array_t phi; + const array_t weight; + const array_t tag; + const float mass; + const float charge; + const bool use_weights; + const M metric; + + const real_t contrib; + + public: + ReducedParticleMoments_kernel(const std::vector& components, + const array_t& i1, + const array_t& i2, + const array_t& i3, + const array_t& dx1, + const array_t& dx2, + const array_t& dx3, + const array_t& ux1, + const array_t& ux2, + const array_t& ux3, + const array_t& phi, + const array_t& weight, + const array_t& tag, + float mass, + float charge, + bool use_weights, + const M& metric) + : c1 { (components.size() > 0) ? components[0] + : static_cast(0) } + , c2 { (components.size() == 2) ? components[1] + : static_cast(0) } + , i1 { i1 } + , i2 { i2 } + , i3 { i3 } + , dx1 { dx1 } + , dx2 { dx2 } + , dx3 { dx3 } + , ux1 { ux1 } + , ux2 { ux2 } + , ux3 { ux3 } + , phi { phi } + , weight { weight } + , tag { tag } + , mass { mass } + , charge { charge } + , use_weights { use_weights } + , metric { metric } + , contrib { get_contrib

(mass, charge) } { + raise::ErrorIf(((P == StatsID::Rho) || (P == StatsID::Charge)) && + (mass == ZERO), + "Rho & Charge for massless particles not defined", + HERE); + } + + Inline void operator()(index_t p, real_t& buff) const { + if (tag(p) == ParticleTag::dead) { + return; + } + if constexpr (P == StatsID::Npart) { + buff += ONE; + return; + } else if constexpr (P == StatsID::N or P == StatsID::Rho or + P == StatsID::Charge) { + buff += use_weights ? weight(p) : contrib; + return; + } else { + // for stress-energy tensor + real_t energy { ZERO }; + vec_t u_Phys { ZERO }; + if constexpr (S == SimEngine::SRPIC) { + // SR + // stress-energy tensor for SR is computed in the tetrad (hatted) basis + if constexpr (M::CoordType == Coord::Cart) { + u_Phys[0] = ux1(p); + u_Phys[1] = ux2(p); + u_Phys[2] = ux3(p); + } else { + static_assert(D != Dim::_1D, "non-Cartesian SRPIC 1D"); + coord_t x_Code { ZERO }; + x_Code[0] = static_cast(i1(p)) + static_cast(dx1(p)); + x_Code[1] = static_cast(i2(p)) + static_cast(dx2(p)); + if constexpr (D == Dim::_3D) { + x_Code[2] = static_cast(i3(p)) + static_cast(dx3(p)); + } else { + x_Code[2] = phi(p); + } + metric.template transform_xyz( + x_Code, + { ux1(p), ux2(p), ux3(p) }, + u_Phys); + } + if (mass == ZERO) { + energy = NORM(u_Phys[0], u_Phys[1], u_Phys[2]); + } else { + energy = mass * + math::sqrt(ONE + NORM_SQR(u_Phys[0], u_Phys[1], u_Phys[2])); + } + } else { + // GR + // stress-energy tensor for GR is computed in contravariant basis + static_assert(D != Dim::_1D, "GRPIC 1D"); + coord_t x_Code { ZERO }; + x_Code[0] = static_cast(i1(p)) + static_cast(dx1(p)); + x_Code[1] = static_cast(i2(p)) + static_cast(dx2(p)); + if constexpr (D == Dim::_3D) { + x_Code[2] = static_cast(i3(p)) + static_cast(dx3(p)); + } + vec_t u_Cntrv { ZERO }; + // compute u_i u^i for energy + metric.template transform(x_Code, + { ux1(p), ux2(p), ux3(p) }, + u_Cntrv); + energy = u_Cntrv[0] * ux1(p) + u_Cntrv[1] * ux2(p) + u_Cntrv[2] * ux3(p); + if (mass == 
ZERO) { + energy = math::sqrt(energy); + } else { + energy = mass * math::sqrt(ONE + energy); + } + metric.template transform(x_Code, u_Cntrv, u_Phys); + } + // compute the corresponding moment + real_t coeff = ONE; +#pragma unroll + for (const auto& c : { c1, c2 }) { + if (c == 0) { + coeff *= energy; + } else { + coeff *= u_Phys[c - 1]; + } + } + buff += coeff / energy; + } + } + }; + +} // namespace kernel + +#endif diff --git a/src/kernels/tests/CMakeLists.txt b/src/kernels/tests/CMakeLists.txt index e55dbc111..2702e0526 100644 --- a/src/kernels/tests/CMakeLists.txt +++ b/src/kernels/tests/CMakeLists.txt @@ -1,9 +1,12 @@ +# cmake-lint: disable=C0103,C0111 # ------------------------------ # @brief: Generates tests for the `ntt_kernels` module +# # @uses: -# - kokkos [required] -# - plog [required] -# - mpi [optional] +# +# * kokkos [required] +# * plog [required] +# * mpi [optional] # ------------------------------ set(SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../) @@ -29,3 +32,7 @@ gen_test(fields_to_phys) gen_test(prtls_to_phys) gen_test(gca_pusher) gen_test(prtl_bc) +gen_test(flds_bc) +gen_test(pusher) +gen_test(ext_force) +gen_test(reduced_stats) diff --git a/src/kernels/tests/ampere_mink.cpp b/src/kernels/tests/ampere_mink.cpp index 80af88fa5..181946613 100644 --- a/src/kernels/tests/ampere_mink.cpp +++ b/src/kernels/tests/ampere_mink.cpp @@ -25,7 +25,7 @@ void errorIf(bool condition, const std::string& message) { Inline auto equal(real_t a, real_t b, const char* msg, real_t acc) -> bool { if (not(math::abs(a - b) < acc)) { - printf("%.12e != %.12e [%.12e] %s\n", a, b, math::abs(a - b), msg); + Kokkos::printf("%.12e != %.12e [%.12e] %s\n", a, b, math::abs(a - b), msg); return false; } return true; @@ -108,7 +108,7 @@ void testAmpere(const std::vector& res) { const real_t sx = constant::TWO_PI, sy = 4.0 * constant::PI; const auto metric = Minkowski { res, - {{ ZERO, sx }, { ZERO, sy }} + { { ZERO, sx }, { ZERO, sy } } }; auto emfield = ndfield_t { "emfield", 
res[0] + 2 * N_GHOSTS, @@ -116,7 +116,7 @@ void testAmpere(const std::vector& res) { const std::size_t i1min = N_GHOSTS, i1max = res[0] + N_GHOSTS; const std::size_t i2min = N_GHOSTS, i2max = res[1] + N_GHOSTS; const auto range = CreateRangePolicy({ i1min, i2min }, - { i1max, i2max }); + { i1max, i2max }); const auto range_ext = CreateRangePolicy( { 0, 0 }, { res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS }); @@ -213,7 +213,7 @@ void testAmpere(const std::vector& res) { sz = constant::TWO_PI; const auto metric = Minkowski { res, - {{ ZERO, sx }, { ZERO, sy }, { ZERO, sz }} + { { ZERO, sx }, { ZERO, sy }, { ZERO, sz } } }; auto emfield = ndfield_t { "emfield", res[0] + 2 * N_GHOSTS, @@ -223,7 +223,7 @@ void testAmpere(const std::vector& res) { const std::size_t i2min = N_GHOSTS, i2max = res[1] + N_GHOSTS; const std::size_t i3min = N_GHOSTS, i3max = res[2] + N_GHOSTS; const auto range = CreateRangePolicy({ i1min, i2min, i3min }, - { i1max, i2max, i3max }); + { i1max, i2max, i3max }); const auto range_ext = CreateRangePolicy( { 0, 0, 0 }, { res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS, res[2] + 2 * N_GHOSTS }); diff --git a/src/kernels/tests/deposit.cpp b/src/kernels/tests/deposit.cpp index 9a8ae1cc6..3df2828b5 100644 --- a/src/kernels/tests/deposit.cpp +++ b/src/kernels/tests/deposit.cpp @@ -27,13 +27,16 @@ void errorIf(bool condition, const std::string& message) { } } -inline static constexpr auto epsilon = std::numeric_limits::epsilon(); - -Inline auto equal(real_t a, real_t b, const char* msg = "", real_t acc = ONE) - -> bool { - const auto eps = epsilon * acc; - if (not cmp::AlmostEqual(a, b, eps)) { - printf("%.12e != %.12e %s\n", a, b, msg); +const real_t eps = std::is_same_v ? 
(real_t)(1e-6) + : (real_t)(1e-3); + +Inline auto equal(real_t a, real_t b, const char* msg, real_t eps) -> bool { /* relative-tolerance comparison: fails when |a - b| >= eps * max(|a|, |b|); the magnitude is used because a signed difference would accept every a < b */ + if (math::fabs(a - b) >= eps * math::max(math::fabs(a), math::fabs(b))) { + Kokkos::printf("%.12e != %.12e %s\n", a, b, msg); + Kokkos::printf("%.12e >= %.12e %s\n", + math::fabs(a - b), + eps * math::max(math::fabs(a), math::fabs(b)), + msg); return false; } return true; @@ -49,13 +52,18 @@ void put_value(array_t arr, T value, int i) { template void testDeposit(const std::vector& res, const boundaries_t& ext, - const std::map& params = {}, - const real_t acc = ONE) { + const std::map& params, + const real_t eps) { static_assert(M::Dim == 2); errorIf(res.size() != M::Dim, "res.size() != M::Dim"); using namespace ntt; - M metric { res, ext, params }; + auto extents = ext; + if constexpr (M::CoordType != Coord::Cart) { + extents.emplace_back(ZERO, (real_t)(constant::PI)); + } + + M metric { res, extents, params }; const auto nx1 = res[0]; const auto nx2 = res[1]; @@ -81,9 +89,7 @@ void testDeposit(const std::vector& res, array_t tag { "tag", 10 }; const real_t charge { 1.0 }, inv_dt { 1.0 }; - auto J_scat = Kokkos::Experimental::create_scatter_view(J); - - const int i0 = 4, j0 = 4; + const int i0 = 40, j0 = 40; const prtldx_t dxi = 0.53, dxf = 0.47; const prtldx_t dyi = 0.34, dyf = 0.52; @@ -119,33 +125,25 @@ void testDeposit(const std::vector& res, put_value(dx2, dyf, 0); put_value(dx1_prev, dxi, 0); put_value(dx2_prev, dyi, 0); + put_value(ux1, ZERO, 0); + put_value(ux2, ZERO, 0); + put_value(ux3, ZERO, 0); put_value(weight, 1.0, 0); put_value(tag, ParticleTag::alive, 0); - Kokkos::parallel_for("CurrentsDeposit", - 10, + auto J_scat = Kokkos::Experimental::create_scatter_view(J); + + // clang-format off + Kokkos::parallel_for("CurrentsDeposit", 1, kernel::DepositCurrents_kernel(J_scat, - i1, - i2, - i3, - i1_prev, - i2_prev, - i3_prev, - dx1, - dx2, - dx3, - dx1_prev, - dx2_prev, - dx3_prev, - ux1, - ux2, - ux3, - phi, - weight, - tag, - metric, - charge, - inv_dt)); + 
i1, i2, i3, + i1_prev, i2_prev, i3_prev, + dx1, dx2, dx3, + dx1_prev, dx2_prev, dx3_prev, + ux1, ux2, ux3, + phi, weight, tag, + metric, charge, inv_dt)); + // clang-format on Kokkos::Experimental::contribute(J, J_scat); @@ -166,13 +164,13 @@ void testDeposit(const std::vector& res, if (not cmp::AlmostZero(SumDivJ)) { throw std::logic_error("DepositCurrents_kernel::SumDivJ != 0"); } - errorIf(not equal(J_h(i0 + N_GHOSTS, j0 + N_GHOSTS, cur::jx1), Jx1, "", acc), + errorIf(not equal(J_h(i0 + N_GHOSTS, j0 + N_GHOSTS, cur::jx1), Jx1, "", eps), "DepositCurrents_kernel::Jx1 is incorrect"); - errorIf(not equal(J_h(i0 + N_GHOSTS, j0 + 1 + N_GHOSTS, cur::jx1), Jx2, "", acc), + errorIf(not equal(J_h(i0 + N_GHOSTS, j0 + 1 + N_GHOSTS, cur::jx1), Jx2, "", eps), "DepositCurrents_kernel::Jx2 is incorrect"); - errorIf(not equal(J_h(i0 + N_GHOSTS, j0 + N_GHOSTS, cur::jx2), Jy1, "", acc), + errorIf(not equal(J_h(i0 + N_GHOSTS, j0 + N_GHOSTS, cur::jx2), Jy1, "", eps), "DepositCurrents_kernel::Jy1 is incorrect"); - errorIf(not equal(J_h(i0 + 1 + N_GHOSTS, j0 + N_GHOSTS, cur::jx2), Jy2, "", acc), + errorIf(not equal(J_h(i0 + 1 + N_GHOSTS, j0 + N_GHOSTS, cur::jx2), Jy2, "", eps), "DepositCurrents_kernel::Jy2 is incorrect"); } @@ -183,59 +181,26 @@ auto main(int argc, char* argv[]) -> int { using namespace ntt; using namespace metric; - testDeposit, SimEngine::SRPIC>( - { - 10, - 10 - }, - { { 0.0, 55.0 }, { 0.0, 55.0 } }, - {}, - 30); - - testDeposit, SimEngine::SRPIC>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - {}, - 30); - - testDeposit, SimEngine::SRPIC>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - { { "r0", 0.0 }, { "h", 0.25 } }, - 30); - - testDeposit, SimEngine::GRPIC>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - { { "a", 0.9 } }, - 30); - - testDeposit, SimEngine::GRPIC>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - { { "r0", 0.0 }, { "h", 0.25 }, { "a", 0.9 } }, - 30); - - testDeposit, SimEngine::GRPIC>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - { { "a", 0.9 } }, - 30); + 
const auto res = std::vector { 100, 100 }; + const auto r_extent = boundaries_t { + { 1.0, 100.0 } + }; + const auto xy_extent = boundaries_t { + { 0.0, 55.0 }, + { 0.0, 55.0 } + }; + const std::map params { + { "r0", 0.0 }, + { "h", 0.25 }, + { "a", 0.9 } + }; + + testDeposit, SimEngine::SRPIC>(res, xy_extent, {}, eps); + testDeposit, SimEngine::SRPIC>(res, r_extent, {}, eps); + testDeposit, SimEngine::SRPIC>(res, r_extent, params, eps); + testDeposit, SimEngine::GRPIC>(res, r_extent, params, eps); + testDeposit, SimEngine::GRPIC>(res, r_extent, params, eps); + testDeposit, SimEngine::GRPIC>(res, r_extent, params, eps); } catch (std::exception& e) { std::cerr << e.what() << std::endl; diff --git a/src/kernels/tests/digital_filter.cpp b/src/kernels/tests/digital_filter.cpp index e0cc352f5..30059e471 100644 --- a/src/kernels/tests/digital_filter.cpp +++ b/src/kernels/tests/digital_filter.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include void errorIf(bool condition, const std::string& message) { @@ -34,18 +33,25 @@ void testFilter(const std::vector& res, const boundaries_t& ext, const std::map& params = {}) { static_assert(M::Dim == 2); - errorIf(res.size() != M::Dim, "res.size() != M::Dim"); + errorIf(res.size() != static_cast(M::Dim), "res.size() != M::Dim"); using namespace ntt; auto boundaries = boundaries_t {}; + auto extents = ext; if constexpr (M::CoordType != Coord::Cart) { boundaries = { - {FldsBC::CUSTOM, FldsBC::CUSTOM}, - { FldsBC::AXIS, FldsBC::AXIS} + { FldsBC::CUSTOM, FldsBC::CUSTOM }, + { FldsBC::AXIS, FldsBC::AXIS } + }; + extents.emplace_back(ZERO, (real_t)(constant::PI)); + } else { + boundaries = { + { FldsBC::PERIODIC, FldsBC::PERIODIC }, + { FldsBC::PERIODIC, FldsBC::PERIODIC } }; } - M metric { res, ext, params }; + M metric { res, extents, params }; const auto nx1 = res[0]; const auto nx2 = res[1]; @@ -128,53 +134,45 @@ auto main(int argc, char* argv[]) -> int { using namespace ntt; using namespace metric; - testFilter>( - { 
- 10, - 10 - }, - { { 0.0, 55.0 }, { 0.0, 55.0 } }, - {}); - - testFilter>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - {}); - - testFilter>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - { { "r0", 0.0 }, { "h", 0.25 } }); - - testFilter>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - { { "a", 0.9 } }); - - testFilter>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - { { "r0", 0.0 }, { "h", 0.25 }, { "a", 0.9 } }); - - testFilter>( - { - 10, - 10 - }, - { { 1.0, 100.0 } }, - { { "a", 0.9 } }); + const auto res = std::vector { 10, 10 }; + const auto r_extent = boundaries_t { + { 0.0, 100.0 } + }; + const auto xy_extent = boundaries_t { + { 0.0, 55.0 }, + { 0.0, 55.0 } + }; + + testFilter>(res, xy_extent, {}); + + testFilter>(res, r_extent, {}); + + testFilter>(res, + r_extent, + { + { "r0", 0.0 }, + { "h", 0.25 } + }); + + testFilter>(res, + r_extent, + { + { "a", 0.9 } + }); + + testFilter>(res, + r_extent, + { + { "r0", 0.0 }, + { "h", 0.25 }, + { "a", 0.9 } + }); + + testFilter>(res, + r_extent, + { + { "a", 0.9 } + }); } catch (std::exception& e) { std::cerr << e.what() << std::endl; @@ -183,4 +181,4 @@ auto main(int argc, char* argv[]) -> int { } Kokkos::finalize(); return 0; -} \ No newline at end of file +} diff --git a/src/kernels/tests/ext_force.cpp b/src/kernels/tests/ext_force.cpp new file mode 100644 index 000000000..12e3466cf --- /dev/null +++ b/src/kernels/tests/ext_force.cpp @@ -0,0 +1,290 @@ +#include "enums.h" +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/error.h" +#include "utils/numeric.h" +#include "utils/plog.h" + +#include "metrics/minkowski.h" + +#include "kernels/particle_pusher_sr.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using namespace ntt; +using namespace metric; + +void check_value(unsigned int t, + real_t target, + real_t value, + real_t eps, + const std::string& msg) { + const auto msg_ = fmt::format("%s: %e != %e @ %u", msg.c_str(), 
target, value, t); + const auto diff = math::abs(target - value); + const auto sum = HALF * (math::abs(target) + math::abs(value)); + raise::ErrorIf(((sum > eps) and (diff / sum > eps)) or + ((sum <= eps) and (diff > eps / 10.0)), + msg_ + " " + fmt::format("%.12e, %.12e", diff, sum), + HERE); +} + +template +void put_value(array_t& arr, T v, index_t p) { + auto h = Kokkos::create_mirror_view(arr); + Kokkos::deep_copy(h, arr); + h(p) = v; + Kokkos::deep_copy(arr, h); +} + +struct Force { + const std::vector species { 1 }; + + Force(real_t force) : force { force } {} + + Inline auto fx1(const spidx_t&, + const simtime_t&, + const coord_t&) const -> real_t { + return force * math::sin(ONE) * math::sin(ONE); + } + + Inline auto fx2(const spidx_t&, + const simtime_t&, + const coord_t&) const -> real_t { + return force * math::sin(ONE) * math::cos(ONE); + } + + Inline auto fx3(const spidx_t&, + const simtime_t&, + const coord_t&) const -> real_t { + return force * math::cos(ONE); + } + +private: + const real_t force; +}; + +template +void testPusher(const std::vector& res) { + static_assert(M::Dim == 3); + raise::ErrorIf(res.size() != M::Dim, "res.size() != M::Dim", HERE); + + M metric { + res, + { { 0.0, (real_t)(res[0]) }, { 0.0, (real_t)(res[1]) }, { 0.0, (real_t)(res[2]) } }, + {} + }; + + const int nx1 = res[0]; + const int nx2 = res[1]; + const int nx3 = res[2]; + + const auto range_ext = CreateRangePolicy( + { 0, 0, 0 }, + { res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS, res[2] + 2 * N_GHOSTS }); + + auto emfield = ndfield_t { "emfield", + res[0] + 2 * N_GHOSTS, + res[1] + 2 * N_GHOSTS, + res[2] + 2 * N_GHOSTS }; + + const real_t x1_0 = 1.15, x2_0 = 1.85, x3_0 = 1.25; + const real_t ux1_0 = 0.02, ux2_0 = -0.2, ux3_0 = 0.1; + // const real_t gamma_0 = math::sqrt(ONE + NORM_SQR(ux1_0, ux2_0, ux3_0)); + const real_t omegaB0 = 1.0; + const real_t dt = 0.01; + const real_t f_mag = 0.01; + + Kokkos::parallel_for( + "init 3D", + range_ext, + Lambda(index_t i1, index_t 
i2, index_t i3) { + emfield(i1, i2, i3, em::ex1) = ZERO; + emfield(i1, i2, i3, em::ex2) = ZERO; + emfield(i1, i2, i3, em::ex3) = ZERO; + emfield(i1, i2, i3, em::bx1) = ZERO; + emfield(i1, i2, i3, em::bx2) = ZERO; + emfield(i1, i2, i3, em::bx3) = ZERO; + }); + + array_t i1 { "i1", 2 }; + array_t i2 { "i2", 2 }; + array_t i3 { "i3", 2 }; + array_t i1_prev { "i1_prev", 2 }; + array_t i2_prev { "i2_prev", 2 }; + array_t i3_prev { "i3_prev", 2 }; + array_t dx1 { "dx1", 2 }; + array_t dx2 { "dx2", 2 }; + array_t dx3 { "dx3", 2 }; + array_t dx1_prev { "dx1_prev", 2 }; + array_t dx2_prev { "dx2_prev", 2 }; + array_t dx3_prev { "dx3_prev", 2 }; + array_t ux1 { "ux1", 2 }; + array_t ux2 { "ux2", 2 }; + array_t ux3 { "ux3", 2 }; + array_t phi { "phi", 2 }; + array_t weight { "weight", 2 }; + array_t tag { "tag", 2 }; + + put_value(i1, (int)(x1_0), 0); + put_value(i2, (int)(x2_0), 0); + put_value(i3, (int)(x3_0), 0); + put_value(dx1, (prtldx_t)(x1_0 - (int)(x1_0)), 0); + put_value(dx2, (prtldx_t)(x2_0 - (int)(x2_0)), 0); + put_value(dx3, (prtldx_t)(x3_0 - (int)(x3_0)), 0); + put_value(ux1, ux1_0, 0); + put_value(ux2, ux2_0, 0); + put_value(ux3, ux3_0, 0); + put_value(tag, ParticleTag::alive, 0); + + put_value(i1, (int)(x1_0), 1); + put_value(i2, (int)(x2_0), 1); + put_value(i3, (int)(x3_0), 1); + put_value(dx1, (prtldx_t)(x1_0 - (int)(x1_0)), 1); + put_value(dx2, (prtldx_t)(x2_0 - (int)(x2_0)), 1); + put_value(dx3, (prtldx_t)(x3_0 - (int)(x3_0)), 1); + put_value(ux1, -ux1_0, 1); + put_value(ux2, -ux2_0, 1); + put_value(ux3, -ux3_0, 1); + put_value(tag, ParticleTag::alive, 1); + + // Particle boundaries + auto boundaries = boundaries_t {}; + boundaries = { + { PrtlBC::PERIODIC, PrtlBC::PERIODIC }, + { PrtlBC::PERIODIC, PrtlBC::PERIODIC }, + { PrtlBC::PERIODIC, PrtlBC::PERIODIC } + }; + + const spidx_t sp { 1u }; + + const real_t coeff = HALF * dt * omegaB0; + + const real_t eps = std::is_same_v ? 
1e-4 : 1e-6; + + const auto ext_force = Force { f_mag }; + const auto force = + kernel::sr::Force { ext_force }; + + static plog::RollingFileAppender file_appender( + "pusher_log.csv"); + plog::init(plog::verbose, &file_appender); + PLOGD << "t,i1,i2,i3,dx1,dx2,dx3,ux1,ux2,ux3"; + + for (auto t { 0u }; t < 100; ++t) { + const real_t time = t * dt; + + // clang-format off + Kokkos::parallel_for( + "pusher", + CreateRangePolicy({0}, {2}), + kernel::sr::Pusher_kernel, decltype(force)>(PrtlPusher::BORIS, + false, true, kernel::sr::Cooling::None, + emfield, + sp, + i1, i2, i3, + i1_prev, i2_prev, i3_prev, + dx1, dx2, dx3, + dx1_prev, dx2_prev, dx3_prev, + ux1, ux2, ux3, + phi, tag, + metric, force, + (simtime_t)time, coeff, dt, + nx1, nx2, nx3, + boundaries, + ZERO, ZERO, ZERO)); + + auto i1_prev_ = Kokkos::create_mirror_view(i1_prev); + auto i2_prev_ = Kokkos::create_mirror_view(i2_prev); + auto i3_prev_ = Kokkos::create_mirror_view(i3_prev); + auto i1_ = Kokkos::create_mirror_view(i1); + auto i2_ = Kokkos::create_mirror_view(i2); + auto i3_ = Kokkos::create_mirror_view(i3); + Kokkos::deep_copy(i1_prev_, i1_prev); + Kokkos::deep_copy(i2_prev_, i2_prev); + Kokkos::deep_copy(i3_prev_, i3_prev); + Kokkos::deep_copy(i1_, i1); + Kokkos::deep_copy(i2_, i2); + Kokkos::deep_copy(i3_, i3); + + auto dx1_prev_ = Kokkos::create_mirror_view(dx1_prev); + auto dx2_prev_ = Kokkos::create_mirror_view(dx2_prev); + auto dx3_prev_ = Kokkos::create_mirror_view(dx3_prev); + auto dx1_ = Kokkos::create_mirror_view(dx1); + auto dx2_ = Kokkos::create_mirror_view(dx2); + auto dx3_ = Kokkos::create_mirror_view(dx3); + auto ux1_ = Kokkos::create_mirror_view(ux1); + auto ux2_ = Kokkos::create_mirror_view(ux2); + auto ux3_ = Kokkos::create_mirror_view(ux3); + Kokkos::deep_copy(dx1_prev_, dx1_prev); + Kokkos::deep_copy(dx2_prev_, dx2_prev); + Kokkos::deep_copy(dx3_prev_, dx3_prev); + Kokkos::deep_copy(dx1_, dx1); + Kokkos::deep_copy(dx2_, dx2); + Kokkos::deep_copy(dx3_, dx3); + 
Kokkos::deep_copy(ux1_, ux1); + Kokkos::deep_copy(ux2_, ux2); + Kokkos::deep_copy(ux3_, ux3); + + PLOGD.printf("%e,%d,%d,%d,%e,%e,%e,%e,%e,%e", + time, + i1_(1), + i2_(1), + i3_(1), + dx1_( 1), + dx2_( 1), + dx3_( 1), + ux1_( 1), + ux2_( 1), + ux3_( 1)); + + { + const real_t ux1_expect = ux1_0 + (time + dt) * f_mag * std::sin(ONE) * std::sin(ONE); + const real_t ux2_expect = ux2_0 + (time + dt) * f_mag * std::sin(ONE) * std::cos(ONE); + const real_t ux3_expect = ux3_0 + (time + dt) * f_mag * std::cos(ONE); + + check_value(t, ux1_(0), ux1_expect, eps, "Particle #1 ux1"); + check_value(t, ux2_(0), ux2_expect, eps, "Particle #1 ux2"); + check_value(t, ux3_(0), ux3_expect, eps, "Particle #1 ux3"); + } + + { + const real_t ux1_expect = -ux1_0 + (time + dt) * f_mag * std::sin(ONE) * std::sin(ONE); + const real_t ux2_expect = -ux2_0 + (time + dt) * f_mag * std::sin(ONE) * std::cos(ONE); + const real_t ux3_expect = -ux3_0 + (time + dt) * f_mag * std::cos(ONE); + + check_value(t, ux1_(1), ux1_expect, eps, "Particle #2 ux1"); + check_value(t, ux2_(1), ux2_expect, eps, "Particle #2 ux2"); + check_value(t, ux3_(1), ux3_expect, eps, "Particle #2 ux3"); + } + + } +} + +auto main(int argc, char* argv[]) -> int { + Kokkos::initialize(argc, argv); + + try { + using namespace ntt; + + testPusher>({ 10, 10, 10 }); + + } catch (std::exception& e) { + std::cerr << e.what() << std::endl; + Kokkos::finalize(); + return 1; + } + Kokkos::finalize(); + return 0; +} diff --git a/src/kernels/tests/faraday_mink.cpp b/src/kernels/tests/faraday_mink.cpp index 74c2b9b1a..db3b6d5bc 100644 --- a/src/kernels/tests/faraday_mink.cpp +++ b/src/kernels/tests/faraday_mink.cpp @@ -25,7 +25,7 @@ void errorIf(bool condition, const std::string& message) { Inline auto equal(real_t a, real_t b, const char* msg, real_t acc) -> bool { if (not(math::abs(a - b) < acc)) { - printf("%.12e != %.12e [%.12e] %s\n", a, b, math::abs(a - b), msg); + Kokkos::printf("%.12e != %.12e [%.12e] %s\n", a, b, math::abs(a - b), 
msg); return false; } return true; @@ -108,7 +108,7 @@ void testFaraday(const std::vector& res) { const real_t sx = constant::TWO_PI, sy = 4.0 * constant::PI; const auto metric = Minkowski { res, - {{ ZERO, sx }, { ZERO, sy }} + { { ZERO, sx }, { ZERO, sy } } }; auto emfield = ndfield_t { "emfield", res[0] + 2 * N_GHOSTS, @@ -116,7 +116,7 @@ void testFaraday(const std::vector& res) { const std::size_t i1min = N_GHOSTS, i1max = res[0] + N_GHOSTS; const std::size_t i2min = N_GHOSTS, i2max = res[1] + N_GHOSTS; const auto range = CreateRangePolicy({ i1min, i2min }, - { i1max, i2max }); + { i1max, i2max }); const auto range_ext = CreateRangePolicy( { 0, 0 }, { res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS }); @@ -212,7 +212,7 @@ void testFaraday(const std::vector& res) { sz = constant::TWO_PI; const auto metric = Minkowski { res, - {{ ZERO, sx }, { ZERO, sy }, { ZERO, sz }} + { { ZERO, sx }, { ZERO, sy }, { ZERO, sz } } }; auto emfield = ndfield_t { "emfield", res[0] + 2 * N_GHOSTS, @@ -222,7 +222,7 @@ void testFaraday(const std::vector& res) { const std::size_t i2min = N_GHOSTS, i2max = res[1] + N_GHOSTS; const std::size_t i3min = N_GHOSTS, i3max = res[2] + N_GHOSTS; const auto range = CreateRangePolicy({ i1min, i2min, i3min }, - { i1max, i2max, i3max }); + { i1max, i2max, i3max }); const auto range_ext = CreateRangePolicy( { 0, 0, 0 }, { res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS, res[2] + 2 * N_GHOSTS }); diff --git a/src/kernels/tests/fields_to_phys.cpp b/src/kernels/tests/fields_to_phys.cpp index 84554ed83..d05488242 100644 --- a/src/kernels/tests/fields_to_phys.cpp +++ b/src/kernels/tests/fields_to_phys.cpp @@ -45,7 +45,7 @@ void testFlds2Phys(const std::vector& res, } else { extent = { ext[0], - {ZERO, constant::PI} + { ZERO, constant::PI } }; } diff --git a/src/kernels/tests/flds_bc.cpp b/src/kernels/tests/flds_bc.cpp new file mode 100644 index 000000000..4a9f50b5c --- /dev/null +++ b/src/kernels/tests/flds_bc.cpp @@ -0,0 +1,225 @@ +#include "enums.h" +#include 
"global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/comparators.h" +#include "utils/error.h" + +#include "metrics/minkowski.h" + +#include "kernels/fields_bcs.hpp" + +#include + +#include +#include +#include + +using namespace ntt; +using namespace kernel::bc; +using namespace metric; + +void errorIf(bool condition, const std::string& message) { + if (condition) { + throw std::runtime_error(message); + } +} + +template +struct DummyFieldsBCs { + DummyFieldsBCs() {} + + Inline auto ex1(const coord_t&) const -> real_t { + return TWO; + } + + Inline auto ex2(const coord_t&) const -> real_t { + return THREE; + } + + Inline auto bx2(const coord_t&) const -> real_t { + return FOUR; + } + + Inline auto bx3(const coord_t&) const -> real_t { + return FIVE; + } +}; + +Inline auto equal(real_t a, real_t b, const char* msg, real_t acc) -> bool { + if (not(math::abs(a - b) < acc)) { + Kokkos::printf("%.12e != %.12e [%.12e] %s\n", a, b, math::abs(a - b), msg); + return false; + } + return true; +} + +template +void testFldsBCs(const std::vector& res) { + errorIf(res.size() != (dim_t)D, "res.size() != D"); + boundaries_t sx; + for (const auto& r : res) { + sx.emplace_back(ZERO, r); + } + const auto metric = Minkowski { res, sx }; + auto fset = DummyFieldsBCs {}; + ndfield_t flds; + if constexpr (D == Dim::_1D) { + flds = ndfield_t { "flds", res[0] + 2 * N_GHOSTS }; + } else if constexpr (D == Dim::_2D) { + flds = ndfield_t { "flds", res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS }; + } else if constexpr (D == Dim::_3D) { + flds = ndfield_t { "flds", + res[0] + 2 * N_GHOSTS, + res[1] + 2 * N_GHOSTS, + res[2] + 2 * N_GHOSTS }; + } + + range_t range; + + if constexpr (D == Dim::_1D) { + range = CreateRangePolicy({ res[0] / 2 + N_GHOSTS }, + { res[0] + 2 * N_GHOSTS }); + } else if constexpr (D == Dim::_2D) { + range = CreateRangePolicy({ res[0] / 2 + N_GHOSTS, 0 }, + { res[0] + 2 * N_GHOSTS, res[1] + N_GHOSTS }); + } else if constexpr (D == Dim::_3D) { + range = 
CreateRangePolicy( + { res[0] / 2 + N_GHOSTS, 0, 0 }, + { res[0] + 2 * N_GHOSTS, res[1] + N_GHOSTS, res[2] + N_GHOSTS }); + } + + const auto xg_edge = (real_t)(sx[0].second); + const auto dx_abs = (real_t)(res[0] / 10.0); + boundaries_t flds_bc { + { FldsBC::PERIODIC, FldsBC::PERIODIC }, + { FldsBC::PERIODIC, FldsBC::PERIODIC }, + { FldsBC::PERIODIC, FldsBC::PERIODIC } + }; + Kokkos::parallel_for( + "MatchBoundaries_kernel", + range, + MatchBoundaries_kernel( + flds, + fset, + metric, + xg_edge, + dx_abs, + BC::E | BC::B, + flds_bc)); + + if constexpr (D == Dim::_1D) { + Kokkos::parallel_for( + "MatchBoundaries_kernel", + CreateRangePolicy({ N_GHOSTS }, { res[0] + N_GHOSTS }), + Lambda(index_t i1) { + const auto x = static_cast(i1 - N_GHOSTS); + const auto factor1 = math::tanh( + FOUR * math::abs(x + HALF - xg_edge) / dx_abs); + const auto factor2 = math::tanh(FOUR * math::abs(x - xg_edge) / dx_abs); + if (not cmp::AlmostEqual(flds(i1, em::ex1), TWO * (ONE - factor1))) { + Kokkos::printf("%f != %f\n", flds(i1, em::ex1), TWO * (ONE - factor1)); + raise::KernelError(HERE, "incorrect ex1"); + } + if (not cmp::AlmostEqual(flds(i1, em::ex2), THREE * (ONE - factor2))) { + Kokkos::printf("%f != %f\n", flds(i1, em::ex2), THREE * (ONE - factor2)); + raise::KernelError(HERE, "incorrect ex2"); + } + if (not cmp::AlmostEqual(flds(i1, em::bx2), FOUR * (ONE - factor1))) { + Kokkos::printf("%f != %f\n", flds(i1, em::bx2), FOUR * (ONE - factor1)); + raise::KernelError(HERE, "incorrect bx2"); + } + if (not cmp::AlmostEqual(flds(i1, em::bx3), FIVE * (ONE - factor1))) { + Kokkos::printf("%f != %f\n", flds(i1, em::bx3), FIVE * (ONE - factor1)); + raise::KernelError(HERE, "incorrect bx3"); + } + }); + } else if constexpr (D == Dim::_2D) { + Kokkos::parallel_for( + "MatchBoundaries_kernel", + CreateRangePolicy({ N_GHOSTS, N_GHOSTS }, + { res[0] + N_GHOSTS, res[1] + N_GHOSTS }), + Lambda(index_t i1, index_t i2) { + const auto x = static_cast(i1 - N_GHOSTS); + const auto factor1 = 
math::tanh( + FOUR * math::abs(x + HALF - xg_edge) / dx_abs); + const auto factor2 = math::tanh(FOUR * math::abs(x - xg_edge) / dx_abs); + if (not cmp::AlmostEqual(flds(i1, i2, em::ex1), TWO * (ONE - factor1))) { + Kokkos::printf("%f != %f\n", flds(i1, i2, em::ex1), TWO * (ONE - factor1)); + raise::KernelError(HERE, "incorrect ex1"); + } + if (not cmp::AlmostEqual(flds(i1, i2, em::ex2), THREE * (ONE - factor2))) { + Kokkos::printf("%f != %f\n", + flds(i1, i2, em::ex2), + THREE * (ONE - factor2)); + raise::KernelError(HERE, "incorrect ex2"); + } + if (not cmp::AlmostEqual(flds(i1, i2, em::bx2), FOUR * (ONE - factor1))) { + Kokkos::printf("%f != %f\n", flds(i1, i2, em::bx2), FOUR * (ONE - factor1)); + raise::KernelError(HERE, "incorrect bx2"); + } + if (not cmp::AlmostEqual(flds(i1, i2, em::bx3), FIVE * (ONE - factor1))) { + Kokkos::printf("%f != %f\n", flds(i1, i2, em::bx3), FIVE * (ONE - factor1)); + raise::KernelError(HERE, "incorrect bx3"); + } + }); + } else if constexpr (D == Dim::_3D) { + Kokkos::parallel_for( + "MatchBoundaries_kernel", + CreateRangePolicy( + { N_GHOSTS, N_GHOSTS, N_GHOSTS }, + { res[0] + N_GHOSTS, res[1] + N_GHOSTS, res[2] + N_GHOSTS }), + Lambda(index_t i1, index_t i2, index_t i3) { + const auto x = static_cast(i1 - N_GHOSTS); + const auto factor1 = math::tanh( + FOUR * math::abs(x + HALF - xg_edge) / dx_abs); + const auto factor2 = math::tanh(FOUR * math::abs(x - xg_edge) / dx_abs); + if (not cmp::AlmostEqual(flds(i1, i2, i3, em::ex1), TWO * (ONE - factor1))) { + Kokkos::printf("%f != %f\n", + flds(i1, i2, i3, em::ex1), + TWO * (ONE - factor1)); + raise::KernelError(HERE, "incorrect ex1"); + } + if (not cmp::AlmostEqual(flds(i1, i2, i3, em::ex2), + THREE * (ONE - factor2))) { + Kokkos::printf("%f != %f\n", + flds(i1, i2, i3, em::ex2), + THREE * (ONE - factor2)); + raise::KernelError(HERE, "incorrect ex2"); + } + if (not cmp::AlmostEqual(flds(i1, i2, i3, em::bx2), + FOUR * (ONE - factor1))) { + Kokkos::printf("%f != %f\n", + flds(i1, i2, 
i3, em::bx2), + FOUR * (ONE - factor1)); + raise::KernelError(HERE, "incorrect bx2"); + } + if (not cmp::AlmostEqual(flds(i1, i2, i3, em::bx3), + FIVE * (ONE - factor1))) { + Kokkos::printf("%f != %f\n", + flds(i1, i2, i3, em::bx3), + FIVE * (ONE - factor1)); + raise::KernelError(HERE, "incorrect bx3"); + } + }); + } +} + +auto main(int argc, char* argv[]) -> int { + Kokkos::initialize(argc, argv); + + try { + using namespace ntt; + + testFldsBCs({ 24 }); + testFldsBCs({ 64, 32 }); + testFldsBCs({ 14, 22, 15 }); + + } catch (std::exception& e) { + std::cerr << e.what() << std::endl; + Kokkos::finalize(); + return 1; + } + Kokkos::finalize(); + return 0; +} diff --git a/src/kernels/tests/gca_pusher.cpp b/src/kernels/tests/gca_pusher.cpp index c96ce3d66..5630de414 100644 --- a/src/kernels/tests/gca_pusher.cpp +++ b/src/kernels/tests/gca_pusher.cpp @@ -2,6 +2,7 @@ #include "global.h" #include "arch/kokkos_aliases.h" +#include "utils/error.h" #include "utils/numeric.h" #include "metrics/minkowski.h" @@ -10,20 +11,35 @@ #include #include +#include +#include +#include +#include +#include #include -#include -#include #include #include using namespace ntt; using namespace metric; -void errorIf(bool condition, const std::string& message) { - if (condition) { - throw std::runtime_error(message); - } +void check_value(unsigned int t, + real_t target, + real_t value, + real_t eps, + const std::string& msg) { + const auto msg_ = fmt::format("%s: %.12e != %.12e @ %u", + msg.c_str(), + target, + value, + t); + const auto diff = math::abs(target - value); + const auto sum = HALF * (math::abs(target) + math::abs(value)); + raise::ErrorIf(((sum > eps) and (diff / sum > eps)) or + ((sum <= eps) and (diff > eps / 10.0)), + msg_ + " " + fmt::format("%.12e, %.12e", diff, sum), + HERE); } template @@ -35,24 +51,20 @@ void put_value(array_t& arr, T v, index_t p) { } template -void testGCAPusher(const std::vector& res, - const boundaries_t& ext, - const std::map& params = {}) { +void 
testPusher(const std::vector& res) { static_assert(M::Dim == 3); - errorIf(res.size() != M::Dim, "res.size() != M::Dim"); - - boundaries_t extent; - extent = ext; + raise::ErrorIf(res.size() != M::Dim, "res.size() != M::Dim", HERE); - M metric { res, extent, params }; + M metric { + res, + { { 0.0, (real_t)(res[0]) }, { 0.0, (real_t)(res[1]) }, { 0.0, (real_t)(res[2]) } }, + {} + }; const int nx1 = res[0]; const int nx2 = res[1]; const int nx3 = res[2]; - auto coeff = real_t { 1.0 }; - auto dt = real_t { 0.01 }; - const auto range_ext = CreateRangePolicy( { 0, 0, 0 }, { res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS, res[2] + 2 * N_GHOSTS }); @@ -62,16 +74,29 @@ void testGCAPusher(const std::vector& res, res[1] + 2 * N_GHOSTS, res[2] + 2 * N_GHOSTS }; + const real_t bx1 = 0.66, bx2 = 0.55, bx3 = 0.44; + const real_t x1_0 = 1.15, x2_0 = 1.85, x3_0 = 1.25; + const real_t ux1_0 = 1.0, ux2_0 = -2.0, ux3_0 = 0.1; + const real_t omegaB0 = 0.2; + const real_t dt = 0.01; + + const real_t b_mag = math::sqrt(NORM_SQR(bx1, bx2, bx3)); + const real_t upar_0 = DOT(ux1_0, ux2_0, ux3_0, bx1, bx2, bx3) / b_mag; + + const real_t ux1_expect = bx1 * upar_0 / (b_mag); + const real_t ux2_expect = bx2 * upar_0 / (b_mag); + const real_t ux3_expect = bx3 * upar_0 / (b_mag); + Kokkos::parallel_for( "init 3D", range_ext, Lambda(index_t i1, index_t i2, index_t i3) { - emfield(i1, i2, i3, em::ex1) = 0.0; - emfield(i1, i2, i3, em::ex2) = 0.0; - emfield(i1, i2, i3, em::ex3) = 0.0; - emfield(i1, i2, i3, em::bx1) = 0.22; - emfield(i1, i2, i3, em::bx2) = 0.44; - emfield(i1, i2, i3, em::bx3) = 0.66; + emfield(i1, i2, i3, em::ex1) = ZERO; + emfield(i1, i2, i3, em::ex2) = ZERO; + emfield(i1, i2, i3, em::ex3) = ZERO; + emfield(i1, i2, i3, em::bx1) = bx1; + emfield(i1, i2, i3, em::bx2) = bx2; + emfield(i1, i2, i3, em::bx3) = bx3; }); array_t i1 { "i1", 2 }; @@ -93,119 +118,77 @@ void testGCAPusher(const std::vector& res, array_t weight { "weight", 2 }; array_t tag { "tag", 2 }; - put_value(i1, 5, 0); - 
put_value(i2, 5, 0); - put_value(i3, 5, 0); - put_value(dx1, (prtldx_t)(0.15), 0); - put_value(dx2, (prtldx_t)(0.85), 0); - put_value(dx3, (prtldx_t)(0.25), 0); - put_value(ux1, (real_t)(1.0), 0); - put_value(ux2, (real_t)(-2.0), 0); - put_value(ux3, (real_t)(0.1), 0); + put_value(i1, (int)(x1_0), 0); + put_value(i2, (int)(x2_0), 0); + put_value(i3, (int)(x3_0), 0); + put_value(dx1, (prtldx_t)(x1_0 - (int)(x1_0)), 0); + put_value(dx2, (prtldx_t)(x2_0 - (int)(x2_0)), 0); + put_value(dx3, (prtldx_t)(x3_0 - (int)(x3_0)), 0); + put_value(ux1, ux1_0, 0); + put_value(ux2, ux2_0, 0); + put_value(ux3, ux3_0, 0); put_value(tag, ParticleTag::alive, 0); - put_value(i1, 5, 1); - put_value(i2, 5, 1); - put_value(i3, 5, 1); - put_value(dx1, (prtldx_t)(0.15), 1); - put_value(dx2, (prtldx_t)(0.85), 1); - put_value(dx3, (prtldx_t)(0.25), 1); - put_value(ux1, (real_t)(1.0), 1); - put_value(ux2, (real_t)(-2.0), 1); - put_value(ux3, (real_t)(0.1), 1); + put_value(i1, (int)(x1_0), 1); + put_value(i2, (int)(x2_0), 1); + put_value(i3, (int)(x3_0), 1); + put_value(dx1, (prtldx_t)(x1_0 - (int)(x1_0)), 1); + put_value(dx2, (prtldx_t)(x2_0 - (int)(x2_0)), 1); + put_value(dx3, (prtldx_t)(x3_0 - (int)(x3_0)), 1); + put_value(ux1, -ux1_0, 1); + put_value(ux2, -ux2_0, 1); + put_value(ux3, -ux3_0, 1); put_value(tag, ParticleTag::alive, 1); // Particle boundaries auto boundaries = boundaries_t {}; boundaries = { - {PrtlBC::PERIODIC, PrtlBC::PERIODIC}, - {PrtlBC::PERIODIC, PrtlBC::PERIODIC}, - {PrtlBC::PERIODIC, PrtlBC::PERIODIC} + { PrtlBC::PERIODIC, PrtlBC::PERIODIC }, + { PrtlBC::PERIODIC, PrtlBC::PERIODIC }, + { PrtlBC::PERIODIC, PrtlBC::PERIODIC } }; - // clang-format off - Kokkos::parallel_for( - "pusher", - 1, - kernel::sr::Pusher_kernel>(PrtlPusher::BORIS, - true, false, kernel::sr::Cooling::None, - emfield, - 1, - i1, i2, i3, - i1_prev, i2_prev, i3_prev, - dx1, dx2, dx3, - dx1_prev, dx2_prev, dx3_prev, - ux1, ux2, ux3, - phi, tag, - metric, - ZERO, coeff, dt, - nx1, nx2, nx3, - boundaries, 
- (real_t)100000.0, (real_t)1.0, ZERO)); - - Kokkos::parallel_for( - "pusher", - CreateRangePolicy({ 0 }, { 1 }), - kernel::sr::Pusher_kernel>(PrtlPusher::BORIS, - true, false, kernel::sr::Cooling::None, - emfield, - 1, - i1, i2, i3, - i1_prev, i2_prev, i3_prev, - dx1, dx2, dx3, - dx1_prev, dx2_prev, dx3_prev, - ux1, ux2, ux3, - phi, tag, - metric, - ZERO, -coeff, dt, - nx1, nx2, nx3, - boundaries, - (real_t)100000.0, (real_t)1.0, ZERO)); - // clang-format on - - auto i1_prev_ = Kokkos::create_mirror_view(i1_prev); - auto i2_prev_ = Kokkos::create_mirror_view(i2_prev); - auto i3_prev_ = Kokkos::create_mirror_view(i3_prev); - auto i1_ = Kokkos::create_mirror_view(i1); - auto i2_ = Kokkos::create_mirror_view(i2); - auto i3_ = Kokkos::create_mirror_view(i3); - Kokkos::deep_copy(i1_prev_, i1_prev); - Kokkos::deep_copy(i2_prev_, i2_prev); - Kokkos::deep_copy(i3_prev_, i3_prev); - Kokkos::deep_copy(i1_, i1); - Kokkos::deep_copy(i2_, i2); - Kokkos::deep_copy(i3_, i3); - - auto dx1_prev_ = Kokkos::create_mirror_view(dx1_prev); - auto dx2_prev_ = Kokkos::create_mirror_view(dx2_prev); - auto dx3_prev_ = Kokkos::create_mirror_view(dx3_prev); - auto dx1_ = Kokkos::create_mirror_view(dx1); - auto dx2_ = Kokkos::create_mirror_view(dx2); - auto dx3_ = Kokkos::create_mirror_view(dx3); - Kokkos::deep_copy(dx1_prev_, dx1_prev); - Kokkos::deep_copy(dx2_prev_, dx2_prev); - Kokkos::deep_copy(dx3_prev_, dx3_prev); - Kokkos::deep_copy(dx1_, dx1); - Kokkos::deep_copy(dx2_, dx2); - Kokkos::deep_copy(dx3_, dx3); - - auto disx = i1_[0] + dx1_[0] - i1_prev_[0] - dx1_prev_[0]; - auto disy = i2_[0] + dx2_[0] - i2_prev_[0] - dx2_prev_[0]; - auto disz = i3_[0] + dx3_[0] - i3_prev_[0] - dx3_prev_[0]; - - auto disdotB = (disx * 0.22 + disy * 0.44 + disz * 0.66) / - (0.823165 * math::sqrt(SQR(disx) + SQR(disy) + SQR(disz))); - - printf("%.12e \n", (1 - math::abs(disdotB))); - - disx = i1_[1] + dx1_[1] - i1_prev_[1] - dx1_prev_[1]; - disy = i2_[1] + dx2_[1] - i2_prev_[1] - dx2_prev_[1]; - disz = 
i3_[1] + dx3_[1] - i3_prev_[1] - dx3_prev_[1]; - - disdotB = (disx * 0.22 + disy * 0.44 + disz * 0.66) / - (0.823165 * math::sqrt(SQR(disx) + SQR(disy) + SQR(disz))); - - printf("%.12e \n", (1 - math::abs(disdotB))); + const spidx_t sp { 1u }; + + const real_t coeff = HALF * dt * omegaB0; + + const real_t eps = std::is_same_v ? 1e-3 : 1e-6; + + for (auto t { 0u }; t < 2000; ++t) { + // clang-format off + Kokkos::parallel_for( + "pusher", + CreateRangePolicy({0}, {2}), + kernel::sr::Pusher_kernel>(PrtlPusher::BORIS, + true, false, kernel::sr::Cooling::None, + emfield, + sp, + i1, i2, i3, + i1_prev, i2_prev, i3_prev, + dx1, dx2, dx3, + dx1_prev, dx2_prev, dx3_prev, + ux1, ux2, ux3, + phi, tag, + metric, + ZERO, coeff, dt, + nx1, nx2, nx3, + boundaries, + (real_t)10000.0, ONE, ZERO)); + + auto ux1_ = Kokkos::create_mirror_view(ux1); + auto ux2_ = Kokkos::create_mirror_view(ux2); + auto ux3_ = Kokkos::create_mirror_view(ux3); + Kokkos::deep_copy(ux1_, ux1); + Kokkos::deep_copy(ux2_, ux2); + Kokkos::deep_copy(ux3_, ux3); + + check_value(t, ux1_(0), ux1_expect, eps, "Particle #1 ux1"); + check_value(t, ux2_(0), ux2_expect, eps, "Particle #1 ux2"); + check_value(t, ux3_(0), ux3_expect, eps, "Particle #1 ux3"); + check_value(t, ux1_(1), -ux1_expect, eps, "Particle #2 ux1"); + check_value(t, ux2_(1), -ux2_expect, eps, "Particle #2 ux2"); + check_value(t, ux3_(1), -ux3_expect, eps, "Particle #2 ux3"); + } } auto main(int argc, char* argv[]) -> int { @@ -214,14 +197,7 @@ auto main(int argc, char* argv[]) -> int { try { using namespace ntt; - testGCAPusher>( - { - 10, - 10, - 10 - }, - { { 0.0, 10.0 }, { 0.0, 10.0 }, { 0.0, 10.0 } }, - {}); + testPusher>({ 10, 10, 10 }); } catch (std::exception& e) { std::cerr << e.what() << std::endl; diff --git a/src/kernels/tests/particle_moments.cpp b/src/kernels/tests/particle_moments.cpp index 25ed7d4d9..ca3c2a7a0 100644 --- a/src/kernels/tests/particle_moments.cpp +++ b/src/kernels/tests/particle_moments.cpp @@ -12,6 +12,8 @@ #include 
"metrics/qspherical.h" #include "metrics/spherical.h" +#include "kernels/reduced_stats.hpp" + #include #include @@ -20,7 +22,6 @@ #include #include #include -#include #include using namespace ntt; @@ -56,7 +57,7 @@ void testParticleMoments(const std::vector& res, } else { extent = { ext[0], - {ZERO, constant::PI} + { ZERO, constant::PI } }; } @@ -106,8 +107,8 @@ void testParticleMoments(const std::vector& res, auto boundaries = boundaries_t {}; if constexpr (M::CoordType != Coord::Cart) { boundaries = { - {FldsBC::CUSTOM, FldsBC::CUSTOM}, - { FldsBC::AXIS, FldsBC::AXIS} + { FldsBC::CUSTOM, FldsBC::CUSTOM }, + { FldsBC::AXIS, FldsBC::AXIS } }; } @@ -117,84 +118,85 @@ void testParticleMoments(const std::vector& res, const unsigned short window = 1; auto scatter_buff = Kokkos::Experimental::create_scatter_view(buff); + // clang-format off Kokkos::parallel_for( - "ParticleMoments", - 10, - kernel::ParticleMoments_kernel(comp1, - scatter_buff, - 0, - i1, - i2, - i3, - dx1, - dx2, - dx3, - ux1, - ux2, - ux3, - phi, - weight, - tag, - mass, - charge, + "ParticleMoments", 10, + kernel::ParticleMoments_kernel(comp1, scatter_buff, 0, + i1, i2, i3, + dx1, dx2, dx3, + ux1, ux2, ux3, + phi, weight, tag, + mass, charge, use_weights, metric, - boundaries, - nx2, - inv_n0, - window)); + boundaries, nx2, inv_n0, window)); + Kokkos::parallel_for( - "ParticleMoments", - 10, - kernel::ParticleMoments_kernel(comp2, - scatter_buff, - 1, - i1, - i2, - i3, - dx1, - dx2, - dx3, - ux1, - ux2, - ux3, - phi, - weight, - tag, - mass, - charge, + "ParticleMoments", 10, + kernel::ParticleMoments_kernel(comp2, scatter_buff, 1, + i1, i2, i3, + dx1, dx2, dx3, + ux1, ux2, ux3, + phi, weight, tag, + mass, charge, use_weights, metric, - boundaries, - nx2, - inv_n0, - window)); + boundaries, nx2, inv_n0, window)); Kokkos::parallel_for( - "ParticleMoments", - 10, - kernel::ParticleMoments_kernel(comp3, - scatter_buff, - 2, - i1, - i2, - i3, - dx1, - dx2, - dx3, - ux1, - ux2, - ux3, - phi, - weight, - 
tag, - mass, - charge, + "ParticleMoments", 10, + kernel::ParticleMoments_kernel(comp3, scatter_buff, 2, + i1, i2, i3, + dx1, dx2, dx3, + ux1, ux2, ux3, + phi, weight, tag, + mass, charge, use_weights, metric, - boundaries, - nx2, - inv_n0, - window)); + boundaries, nx2, inv_n0, window)); + + real_t n = ZERO, npart = ZERO, rho = ZERO, t00 = ZERO; + Kokkos::parallel_reduce( + "ReducedParticleMoments", 10, + kernel::ReducedParticleMoments_kernel({}, + i1, i2, i3, + dx1, dx2, dx3, + ux1, ux2, ux3, + phi, weight, tag, + mass, charge, + use_weights, + metric), n); + + Kokkos::parallel_reduce( + "ReducedParticleMoments", 10, + kernel::ReducedParticleMoments_kernel({}, + i1, i2, i3, + dx1, dx2, dx3, + ux1, ux2, ux3, + phi, weight, tag, + mass, charge, + use_weights, + metric), npart); + Kokkos::parallel_reduce( + "ReducedParticleMoments", 10, + kernel::ReducedParticleMoments_kernel({}, + i1, i2, i3, + dx1, dx2, dx3, + ux1, ux2, ux3, + phi, weight, tag, + mass, charge, + use_weights, + metric), rho); + Kokkos::parallel_reduce( + "ReducedParticleMoments", 10, + kernel::ReducedParticleMoments_kernel({0u, 0u}, + i1, i2, i3, + dx1, dx2, dx3, + ux1, ux2, ux3, + phi, weight, tag, + mass, charge, + use_weights, + metric), t00); + // clang-format on Kokkos::Experimental::contribute(buff, scatter_buff); auto i1_h = Kokkos::create_mirror_view(i1); @@ -231,6 +233,9 @@ void testParticleMoments(const std::vector& res, const real_t gammaSQR_1_expect = 15.0; const real_t gammaSQR_2_expect = 15.0; + const real_t n_expect = 2.0; + const real_t t00_expect = 2.0 * math::sqrt(15.0); + errorIf(not cmp::AlmostEqual_host(gammaSQR_1, gammaSQR_1_expect, epsilon * acc), fmt::format("wrong gamma_1 %.12e %.12e for %dD %s", gammaSQR_1, @@ -243,6 +248,31 @@ void testParticleMoments(const std::vector& res, gammaSQR_2_expect, metric.Dim, metric.Label)); + + errorIf(not cmp::AlmostEqual_host(n, n_expect, epsilon * acc), + fmt::format("wrong n reduction %.12e %.12e for %dD %s", + n, + n_expect, + 
metric.Dim, + metric.Label)); + errorIf(not cmp::AlmostEqual_host(npart, n_expect, epsilon * acc), + fmt::format("wrong npart reduction %.12e %.12e for %dD %s", + npart, + n_expect, + metric.Dim, + metric.Label)); + errorIf(not cmp::AlmostEqual_host(rho, n_expect, epsilon * acc), + fmt::format("wrong rho reduction %.12e %.12e for %dD %s", + rho, + n_expect, + metric.Dim, + metric.Label)); + errorIf(not cmp::AlmostEqual_host(t00, t00_expect, epsilon * acc), + fmt::format("wrong t00 reduction %.12e %.12e for %dD %s", + t00, + t00_expect, + metric.Dim, + metric.Label)); } } @@ -286,4 +316,4 @@ auto main(int argc, char* argv[]) -> int { } Kokkos::finalize(); return 0; -} \ No newline at end of file +} diff --git a/src/kernels/tests/prtl_bc.cpp b/src/kernels/tests/prtl_bc.cpp index c8f9eae04..f7f8be43b 100644 --- a/src/kernels/tests/prtl_bc.cpp +++ b/src/kernels/tests/prtl_bc.cpp @@ -29,7 +29,7 @@ void errorIf(bool condition, const std::string& message = "") { Inline auto equal(real_t a, real_t b, const std::string& msg) -> bool { if (not(math::abs(a - b) < 1e-4)) { - printf("%.12e != %.12e %s\n", a, b, msg.c_str()); + Kokkos::printf("%.12e != %.12e %s\n", a, b, msg.c_str()); return false; } return true; @@ -53,35 +53,45 @@ void testPeriodicBC(const std::vector& res, const auto NoGCA = false; const auto NoExtForce = false; - boundaries_t extent; - extent = ext; - const auto sx = static_cast(extent[0].second - extent[0].first); - const auto sy = static_cast( - extent.size() > 1 ? extent[1].second - extent[1].first : 0); - const auto sz = static_cast( - extent.size() > 2 ? 
extent[2].second - extent[2].first : 0); + real_t sx = ZERO, sy = ZERO, sz = ZERO; + if (ext.size() > 0) { + sx = static_cast(ext.at(0).second - ext.at(0).first); + } + if (ext.size() > 1) { + sy = static_cast(ext.at(1).second - ext.at(1).first); + } + if (ext.size() > 2) { + sz = static_cast(ext.at(2).second - ext.at(2).first); + } - M metric { res, extent, params }; + M metric { res, ext, params }; - const int nx1 = res[0]; - const int nx2 = res[1]; - const int nx3 = res[2]; + int nx1 = 0, nx2 = 0, nx3 = 0; + if (res.size() > 0) { + nx1 = static_cast(res.at(0)); + } + if (res.size() > 1) { + nx2 = static_cast(res.at(1)); + } + if (res.size() > 2) { + nx3 = static_cast(res.at(2)); + } - const real_t dt = 0.1 * (extent[0].second - extent[0].first) / sx; + const real_t dt = 0.1 * (ext.at(0).second - ext.at(0).first) / sx; const real_t coeff = HALF * dt; ndfield_t emfield; if constexpr (M::Dim == Dim::_1D) { - emfield = ndfield_t { "emfield", res[0] + 2 * N_GHOSTS }; + emfield = ndfield_t { "emfield", res.at(0) + 2 * N_GHOSTS }; } else if constexpr (M::Dim == Dim::_2D) { emfield = ndfield_t { "emfield", - res[0] + 2 * N_GHOSTS, - res[1] + 2 * N_GHOSTS }; + res.at(0) + 2 * N_GHOSTS, + res.at(1) + 2 * N_GHOSTS }; } else { emfield = ndfield_t { "emfield", - res[0] + 2 * N_GHOSTS, - res[1] + 2 * N_GHOSTS, - res[2] + 2 * N_GHOSTS }; + res.at(0) + 2 * N_GHOSTS, + res.at(1) + 2 * N_GHOSTS, + res.at(2) + 2 * N_GHOSTS }; } const short sp_idx = 1; @@ -106,18 +116,26 @@ void testPeriodicBC(const std::vector& res, array_t phi; // init parameters of prtl #1 - real_t xi_1 = 0.515460 * sx + extent[0].first; - real_t yi_1 = 0.340680 * sy + extent[1].first; - real_t zi_1 = 0.940722 * sz + extent[2].first; + real_t xi_1 = ZERO, yi_1 = ZERO, zi_1 = ZERO; + real_t xi_2 = ZERO, yi_2 = ZERO, zi_2 = ZERO; + if constexpr (M::Dim == Dim::_1D or M::Dim == Dim::_2D or M::Dim == Dim::_3D) { + xi_1 = 0.515460 * sx + ext.at(0).first; + xi_2 = 0.149088 * sx + ext.at(0).first; + } + if constexpr 
(M::Dim == Dim::_2D or M::Dim == Dim::_3D) { + yi_1 = 0.340680 * sy + ext.at(1).first; + yi_2 = 0.997063 * sy + ext.at(1).first; + } + if constexpr (M::Dim == Dim::_3D) { + zi_1 = 0.940722 * sz + ext.at(2).first; + zi_2 = 0.607354 * sz + ext.at(2).first; + } real_t ux_1 = 0.569197; real_t uy_1 = 0.716085; real_t uz_1 = 0.760101; real_t gamma_1 = math::sqrt(1.0 + SQR(ux_1) + SQR(uy_1) + SQR(uz_1)); // init parameters of prtl #2 - real_t xi_2 = 0.149088 * sx + extent[0].first; - real_t yi_2 = 0.997063 * sy + extent[1].first; - real_t zi_2 = 0.607354 * sz + extent[2].first; real_t ux_2 = -0.872069; real_t uy_2 = 0.0484461; real_t uz_2 = -0.613575; @@ -142,21 +160,28 @@ void testPeriodicBC(const std::vector& res, xi[2] = zi_1; } metric.template convert_xyz(xi, xCd); - put_value(i1, static_cast(xCd[0]), prtl_idx); - put_value(i2, static_cast(xCd[1]), prtl_idx); - put_value(i3, static_cast(xCd[2]), prtl_idx); - put_value(dx1, - static_cast(xCd[0]) - - static_cast(static_cast(xCd[0])), - prtl_idx); - put_value(dx2, - static_cast(xCd[1]) - - static_cast(static_cast(xCd[1])), - prtl_idx); - put_value(dx3, - static_cast(xCd[2]) - - static_cast(static_cast(xCd[2])), - prtl_idx); + if constexpr (M::PrtlDim == Dim::_1D or M::PrtlDim == Dim::_2D or + M::PrtlDim == Dim::_3D) { + put_value(i1, static_cast(xCd[0]), prtl_idx); + put_value(dx1, + static_cast(xCd[0]) - + static_cast(static_cast(xCd[0])), + prtl_idx); + } + if constexpr (M::PrtlDim == Dim::_2D or M::PrtlDim == Dim::_3D) { + put_value(i2, static_cast(xCd[1]), prtl_idx); + put_value(dx2, + static_cast(xCd[1]) - + static_cast(static_cast(xCd[1])), + prtl_idx); + } + if constexpr (M::PrtlDim == Dim::_3D) { + put_value(i3, static_cast(xCd[2]), prtl_idx); + put_value(dx3, + static_cast(xCd[2]) - + static_cast(static_cast(xCd[2])), + prtl_idx); + } put_value(ux1, ux_1, prtl_idx); put_value(ux2, uy_1, prtl_idx); put_value(ux3, uz_1, prtl_idx); @@ -177,21 +202,28 @@ void testPeriodicBC(const std::vector& res, xi[2] = zi_2; } 
metric.template convert_xyz(xi, xCd); - put_value(i1, static_cast(xCd[0]), prtl_idx); - put_value(i2, static_cast(xCd[1]), prtl_idx); - put_value(i3, static_cast(xCd[2]), prtl_idx); - put_value(dx1, - static_cast(xCd[0]) - - static_cast(static_cast(xCd[0])), - prtl_idx); - put_value(dx2, - static_cast(xCd[1]) - - static_cast(static_cast(xCd[1])), - prtl_idx); - put_value(dx3, - static_cast(xCd[2]) - - static_cast(static_cast(xCd[2])), - prtl_idx); + if constexpr (M::PrtlDim == Dim::_1D or M::PrtlDim == Dim::_2D or + M::PrtlDim == Dim::_3D) { + put_value(i1, static_cast(xCd[0]), prtl_idx); + put_value(dx1, + static_cast(xCd[0]) - + static_cast(static_cast(xCd[0])), + prtl_idx); + } + if constexpr (M::PrtlDim == Dim::_2D or M::PrtlDim == Dim::_3D) { + put_value(i2, static_cast(xCd[1]), prtl_idx); + put_value(dx2, + static_cast(xCd[1]) - + static_cast(static_cast(xCd[1])), + prtl_idx); + } + if constexpr (M::PrtlDim == Dim::_3D) { + put_value(i3, static_cast(xCd[2]), prtl_idx); + put_value(dx3, + static_cast(xCd[2]) - + static_cast(static_cast(xCd[2])), + prtl_idx); + } put_value(ux1, ux_2, prtl_idx); put_value(ux2, uy_2, prtl_idx); put_value(ux3, uz_2, prtl_idx); @@ -201,9 +233,9 @@ void testPeriodicBC(const std::vector& res, // Particle boundaries auto boundaries = boundaries_t {}; boundaries = { - {PrtlBC::PERIODIC, PrtlBC::PERIODIC}, - {PrtlBC::PERIODIC, PrtlBC::PERIODIC}, - {PrtlBC::PERIODIC, PrtlBC::PERIODIC} + { PrtlBC::PERIODIC, PrtlBC::PERIODIC }, + { PrtlBC::PERIODIC, PrtlBC::PERIODIC }, + { PrtlBC::PERIODIC, PrtlBC::PERIODIC } }; real_t time = ZERO; @@ -268,16 +300,16 @@ void testPeriodicBC(const std::vector& res, if constexpr (M::Dim == Dim::_1D or M::Dim == Dim::_2D or M::Dim == Dim::_3D) { xi_1 += dt * ux_1 / gamma_1; xi_2 += dt * ux_2 / gamma_2; - if (xi_1 >= extent[0].second) { + if (xi_1 >= ext.at(0).second) { xi_1 -= sx; } - if (xi_1 < extent[0].first) { + if (xi_1 < ext.at(0).first) { xi_1 += sx; } - if (xi_2 >= extent[0].second) { + if (xi_2 >= 
ext.at(0).second) { xi_2 -= sx; } - if (xi_2 < extent[0].first) { + if (xi_2 < ext.at(0).first) { xi_2 += sx; } errorIf(not equal(xPh_1[0] / sx, @@ -290,16 +322,16 @@ void testPeriodicBC(const std::vector& res, if constexpr (M::Dim == Dim::_2D or M::Dim == Dim::_3D) { yi_1 += dt * uy_1 / gamma_1; yi_2 += dt * uy_2 / gamma_2; - if (yi_1 >= extent[1].second) { + if (yi_1 >= ext.at(1).second) { yi_1 -= sy; } - if (yi_1 < extent[1].first) { + if (yi_1 < ext.at(1).first) { yi_1 += sy; } - if (yi_2 >= extent[1].second) { + if (yi_2 >= ext.at(1).second) { yi_2 -= sy; } - if (yi_2 < extent[1].first) { + if (yi_2 < ext.at(1).first) { yi_2 += sy; } errorIf(not equal(xPh_1[1] / sy, @@ -312,16 +344,16 @@ void testPeriodicBC(const std::vector& res, if constexpr (M::Dim == Dim::_3D) { zi_1 += dt * uz_1 / gamma_1; zi_2 += dt * uz_2 / gamma_2; - if (zi_1 >= extent[2].second) { + if (zi_1 >= ext.at(2).second) { zi_1 -= sz; } - if (zi_1 < extent[2].first) { + if (zi_1 < ext.at(2).first) { zi_1 += sz; } - if (zi_2 >= extent[2].second) { + if (zi_2 >= ext.at(2).second) { zi_2 -= sz; } - if (zi_2 < extent[2].first) { + if (zi_2 < ext.at(2).first) { zi_2 += sz; } errorIf(not equal(xPh_1[2] / sz, @@ -343,18 +375,18 @@ auto main(int argc, char* argv[]) -> int { const std::vector res1d { 50 }; const boundaries_t ext1d { - {0.0, 1000.0}, + { 0.0, 1000.0 }, }; const std::vector res2d { 30, 20 }; const boundaries_t ext2d { - {-15.0, 15.0}, - {-10.0, 10.0}, + { -15.0, 15.0 }, + { -10.0, 10.0 }, }; const std::vector res3d { 10, 10, 10 }; const boundaries_t ext3d { - {0.0, 1.0}, - {0.0, 1.0}, - {0.0, 1.0} + { 0.0, 1.0 }, + { 0.0, 1.0 }, + { 0.0, 1.0 } }; testPeriodicBC>(res1d, ext1d, {}); testPeriodicBC>(res2d, ext2d, {}); diff --git a/src/kernels/tests/prtls_to_phys.cpp b/src/kernels/tests/prtls_to_phys.cpp index 4719fe6a1..962c21b5c 100644 --- a/src/kernels/tests/prtls_to_phys.cpp +++ b/src/kernels/tests/prtls_to_phys.cpp @@ -132,7 +132,7 @@ void testPrtl2PhysSR(const std::vector& res, extent 
= { ext[0], - {ZERO, constant::PI} + { ZERO, constant::PI } }; const M metric { res, extent, params }; @@ -177,28 +177,29 @@ void testPrtl2PhysSR(const std::vector& res, array_t buff_ux3 { "buff_ux3", nprtl / stride }; array_t buff_wei { "buff_wei", nprtl / stride }; - Kokkos::parallel_for("Init", - Kokkos::RangePolicy(0, nprtl / stride), - kernel::PrtlToPhys_kernel(stride, - buff_x1, - buff_x2, - buff_x3, - buff_ux1, - buff_ux2, - buff_ux3, - buff_wei, - i1, - i2, - i3, - dx1, - dx2, - dx3, - ux1, - ux2, - ux3, - phi, - weight, - metric)); + Kokkos::parallel_for( + "Init", + Kokkos::RangePolicy(0, nprtl / stride), + kernel::PrtlToPhys_kernel(stride, + buff_x1, + buff_x2, + buff_x3, + buff_ux1, + buff_ux2, + buff_ux3, + buff_wei, + i1, + i2, + i3, + dx1, + dx2, + dx3, + ux1, + ux2, + ux3, + phi, + weight, + metric)); Kokkos::parallel_for("Check", nprtl / stride, Checker(metric, diff --git a/src/kernels/tests/pusher.cpp b/src/kernels/tests/pusher.cpp new file mode 100644 index 000000000..8496b592d --- /dev/null +++ b/src/kernels/tests/pusher.cpp @@ -0,0 +1,274 @@ +#include "enums.h" +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/error.h" +#include "utils/numeric.h" + +#include "metrics/minkowski.h" + +#include "kernels/particle_pusher_sr.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace ntt; +using namespace metric; + +void check_value(unsigned int t, + real_t target, + real_t value, + real_t eps, + const std::string& msg) { + const auto msg_ = fmt::format("%s: %.12e != %.12e @ %u", + msg.c_str(), + target, + value, + t); + const auto diff = math::abs(target - value); + const auto sum = HALF * (math::abs(target) + math::abs(value)); + raise::ErrorIf(((sum > eps) and (diff / sum > eps)) or + ((sum <= eps) and (diff > eps / 10.0)), + msg_ + " " + fmt::format("%.12e, %.12e", diff, sum), + HERE); +} + +template +void put_value(array_t& arr, T v, index_t p) { + auto h 
= Kokkos::create_mirror_view(arr); + Kokkos::deep_copy(h, arr); + h(p) = v; + Kokkos::deep_copy(arr, h); +} + +template +void testPusher(const std::vector& res) { + static_assert(M::Dim == 3); + raise::ErrorIf(res.size() != M::Dim, "res.size() != M::Dim", HERE); + + M metric { + res, + { { 0.0, (real_t)(res[0]) }, { 0.0, (real_t)(res[1]) }, { 0.0, (real_t)(res[2]) } }, + {} + }; + + const int nx1 = res[0]; + const int nx2 = res[1]; + const int nx3 = res[2]; + + const auto range_ext = CreateRangePolicy( + { 0, 0, 0 }, + { res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS, res[2] + 2 * N_GHOSTS }); + + auto emfield = ndfield_t { "emfield", + res[0] + 2 * N_GHOSTS, + res[1] + 2 * N_GHOSTS, + res[2] + 2 * N_GHOSTS }; + + const real_t bx1 = 0.66, bx2 = 0.55, bx3 = 0.44; + const real_t b_mag = math::sqrt(NORM_SQR(bx1, bx2, bx3)); + const real_t x1_0 = 1.15, x2_0 = 1.85, x3_0 = 1.25; + const real_t ux1_0 = 1.0, ux2_0 = -2.0, ux3_0 = 0.1; + const real_t gamma_0 = math::sqrt(ONE + NORM_SQR(ux1_0, ux2_0, ux3_0)); + const real_t omegaB0 = 0.2; + const real_t dt = 0.01; + + Kokkos::parallel_for( + "init 3D", + range_ext, + Lambda(index_t i1, index_t i2, index_t i3) { + emfield(i1, i2, i3, em::ex1) = ZERO; + emfield(i1, i2, i3, em::ex2) = ZERO; + emfield(i1, i2, i3, em::ex3) = ZERO; + emfield(i1, i2, i3, em::bx1) = bx1; + emfield(i1, i2, i3, em::bx2) = bx2; + emfield(i1, i2, i3, em::bx3) = bx3; + }); + + array_t i1 { "i1", 2 }; + array_t i2 { "i2", 2 }; + array_t i3 { "i3", 2 }; + array_t i1_prev { "i1_prev", 2 }; + array_t i2_prev { "i2_prev", 2 }; + array_t i3_prev { "i3_prev", 2 }; + array_t dx1 { "dx1", 2 }; + array_t dx2 { "dx2", 2 }; + array_t dx3 { "dx3", 2 }; + array_t dx1_prev { "dx1_prev", 2 }; + array_t dx2_prev { "dx2_prev", 2 }; + array_t dx3_prev { "dx3_prev", 2 }; + array_t ux1 { "ux1", 2 }; + array_t ux2 { "ux2", 2 }; + array_t ux3 { "ux3", 2 }; + array_t phi { "phi", 2 }; + array_t weight { "weight", 2 }; + array_t tag { "tag", 2 }; + + put_value(i1, (int)(x1_0), 
0); + put_value(i2, (int)(x2_0), 0); + put_value(i3, (int)(x3_0), 0); + put_value(dx1, (prtldx_t)(x1_0 - (int)(x1_0)), 0); + put_value(dx2, (prtldx_t)(x2_0 - (int)(x2_0)), 0); + put_value(dx3, (prtldx_t)(x3_0 - (int)(x3_0)), 0); + put_value(ux1, ux1_0, 0); + put_value(ux2, ux2_0, 0); + put_value(ux3, ux3_0, 0); + put_value(tag, ParticleTag::alive, 0); + + put_value(i1, (int)(x1_0), 1); + put_value(i2, (int)(x2_0), 1); + put_value(i3, (int)(x3_0), 1); + put_value(dx1, (prtldx_t)(x1_0 - (int)(x1_0)), 1); + put_value(dx2, (prtldx_t)(x2_0 - (int)(x2_0)), 1); + put_value(dx3, (prtldx_t)(x3_0 - (int)(x3_0)), 1); + put_value(ux1, ux1_0, 1); + put_value(ux2, ux2_0, 1); + put_value(ux3, ux3_0, 1); + put_value(tag, ParticleTag::alive, 1); + + // Particle boundaries + auto boundaries = boundaries_t {}; + boundaries = { + { PrtlBC::PERIODIC, PrtlBC::PERIODIC }, + { PrtlBC::PERIODIC, PrtlBC::PERIODIC }, + { PrtlBC::PERIODIC, PrtlBC::PERIODIC } + }; + + const spidx_t sp { 1u }; + + const real_t coeff = HALF * dt * omegaB0; + + const auto u0_dot_b = (ux1_0 * bx1 + ux2_0 * bx2 + ux3_0 * bx3) / b_mag; + const auto u0_cross_b_x1 = (ux2_0 * bx3 - ux3_0 * bx2) / b_mag; + const auto u0_cross_b_x2 = (ux3_0 * bx1 - ux1_0 * bx3) / b_mag; + const auto u0_cross_b_x3 = (ux1_0 * bx2 - ux2_0 * bx1) / b_mag; + + const real_t eps = std::is_same_v ? 
1e-2 : 1e-3; + + for (auto t { 0u }; t < 2000; ++t) { + const real_t time = t * dt; + + // clang-format off + Kokkos::parallel_for( + "pusher", + CreateRangePolicy({0}, {1}), + kernel::sr::Pusher_kernel>(PrtlPusher::BORIS, + false, false, kernel::sr::Cooling::None, + emfield, + sp, + i1, i2, i3, + i1_prev, i2_prev, i3_prev, + dx1, dx2, dx3, + dx1_prev, dx2_prev, dx3_prev, + ux1, ux2, ux3, + phi, tag, + metric, + ZERO, coeff, dt, + nx1, nx2, nx3, + boundaries, + ZERO, ZERO, ZERO)); + + Kokkos::parallel_for( + "pusher", + CreateRangePolicy({1}, {2}), + kernel::sr::Pusher_kernel>(PrtlPusher::VAY, + false, false, kernel::sr::Cooling::None, + emfield, + sp, + i1, i2, i3, + i1_prev, i2_prev, i3_prev, + dx1, dx2, dx3, + dx1_prev, dx2_prev, dx3_prev, + ux1, ux2, ux3, + phi, tag, + metric, + ZERO, coeff, dt, + nx1, nx2, nx3, + boundaries, + ZERO, ZERO, ZERO)); + + auto i1_prev_ = Kokkos::create_mirror_view(i1_prev); + auto i2_prev_ = Kokkos::create_mirror_view(i2_prev); + auto i3_prev_ = Kokkos::create_mirror_view(i3_prev); + auto i1_ = Kokkos::create_mirror_view(i1); + auto i2_ = Kokkos::create_mirror_view(i2); + auto i3_ = Kokkos::create_mirror_view(i3); + Kokkos::deep_copy(i1_prev_, i1_prev); + Kokkos::deep_copy(i2_prev_, i2_prev); + Kokkos::deep_copy(i3_prev_, i3_prev); + Kokkos::deep_copy(i1_, i1); + Kokkos::deep_copy(i2_, i2); + Kokkos::deep_copy(i3_, i3); + + auto dx1_prev_ = Kokkos::create_mirror_view(dx1_prev); + auto dx2_prev_ = Kokkos::create_mirror_view(dx2_prev); + auto dx3_prev_ = Kokkos::create_mirror_view(dx3_prev); + auto dx1_ = Kokkos::create_mirror_view(dx1); + auto dx2_ = Kokkos::create_mirror_view(dx2); + auto dx3_ = Kokkos::create_mirror_view(dx3); + auto ux1_ = Kokkos::create_mirror_view(ux1); + auto ux2_ = Kokkos::create_mirror_view(ux2); + auto ux3_ = Kokkos::create_mirror_view(ux3); + Kokkos::deep_copy(dx1_prev_, dx1_prev); + Kokkos::deep_copy(dx2_prev_, dx2_prev); + Kokkos::deep_copy(dx3_prev_, dx3_prev); + Kokkos::deep_copy(dx1_, dx1); + 
Kokkos::deep_copy(dx2_, dx2); + Kokkos::deep_copy(dx3_, dx3); + Kokkos::deep_copy(ux1_, ux1); + Kokkos::deep_copy(ux2_, ux2); + Kokkos::deep_copy(ux3_, ux3); + + const real_t gamma1 = math::sqrt(ONE + NORM_SQR(ux1_(0), ux2_(0), ux3_(0))); + const real_t gamma2 = math::sqrt(ONE + NORM_SQR(ux1_(1), ux2_(1), ux3_(1))); + + check_value(t, gamma1, gamma_0, eps, "Particle #1 Lorentz factor"); + check_value(t, gamma2, gamma_0, eps, "Particle #2 Lorentz factor"); + + const real_t arg = (b_mag * omegaB0 * (time + dt)) / gamma_0; + const real_t ux1_expect = (bx1 / b_mag) * u0_dot_b + + (-(bx1 / b_mag) * u0_dot_b + ux1_0) * math::cos(arg) + + u0_cross_b_x1 * math::sin(arg); + const real_t ux2_expect = (bx2 / b_mag) * u0_dot_b + + (-(bx2 / b_mag) * u0_dot_b + ux2_0) * math::cos(arg) + + u0_cross_b_x2 * math::sin(arg); + const real_t ux3_expect = (bx3 / b_mag) * u0_dot_b + + (-(bx3 / b_mag) * u0_dot_b + ux3_0) * math::cos(arg) + + u0_cross_b_x3 * math::sin(arg); + + check_value(t, ux1_(0), ux1_expect, eps, "Particle #1 ux1"); + check_value(t, ux2_(0), ux2_expect, eps, "Particle #1 ux2"); + check_value(t, ux3_(0), ux3_expect, eps, "Particle #1 ux3"); + + check_value(t, ux1_(1), ux1_expect, eps, "Particle #2 ux1"); + check_value(t, ux2_(1), ux2_expect, eps, "Particle #2 ux2"); + check_value(t, ux3_(1), ux3_expect, eps, "Particle #2 ux3"); + + } +} + +auto main(int argc, char* argv[]) -> int { + Kokkos::initialize(argc, argv); + + try { + using namespace ntt; + + testPusher>({ 10, 10, 10 }); + + } catch (std::exception& e) { + std::cerr << e.what() << std::endl; + Kokkos::finalize(); + return 1; + } + Kokkos::finalize(); + return 0; +} diff --git a/src/kernels/tests/reduced_stats.cpp b/src/kernels/tests/reduced_stats.cpp new file mode 100644 index 000000000..ee036395a --- /dev/null +++ b/src/kernels/tests/reduced_stats.cpp @@ -0,0 +1,307 @@ +#include "kernels/reduced_stats.hpp" + +#include "enums.h" +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/error.h" 
+ +#include "metrics/minkowski.h" + +#include +#include + +using namespace ntt; +using namespace metric; + +template +class Fill_kernel { + ndfield_t arr; + real_t v; + unsigned short c; + +public: + Fill_kernel(ndfield_t& arr_, real_t v_, unsigned short c_) + : arr { arr_ } + , v { v_ } + , c { c_ } { + raise::ErrorIf(c_ >= N, "c > N", HERE); + } + + Inline void operator()(index_t i1) const { + arr(i1, c) = v; + } + + Inline void operator()(index_t i1, index_t i2) const { + arr(i1, i2, c) = v; + } + + Inline void operator()(index_t i1, index_t i2, index_t i3) const { + arr(i1, i2, i3, c) = v; + } +}; + +template +void put_value(ndfield_t& arr, real_t v, unsigned short c) { + range_t range; + if constexpr (D == Dim::_1D) { + range = { + { 0u, arr.extent(0) } + }; + } else if constexpr (D == Dim::_2D) { + range = { + { 0u, 0u }, + { arr.extent(0), arr.extent(1) } + }; + } else { + range = { + { 0u, 0u, 0u }, + { arr.extent(0), arr.extent(1), arr.extent(2) } + }; + } + Kokkos::parallel_for("Fill", range, Fill_kernel(arr, v, c)); +} + +template +auto compute_field_stat(const M& metric, + const ndfield_t& em, + const ndfield_t& j, + const range_t& range) -> real_t { + real_t buff = ZERO; + Kokkos::parallel_reduce("ReduceFields", + range, + kernel::ReducedFields_kernel(em, j, metric), + buff); + return buff / metric.totVolume(); +} + +auto almost_equal(real_t a, real_t b, real_t acc) -> bool { + return (math::fabs(a - b) < acc * math::max(math::fabs(a), math::fabs(b))) + + (real_t)1e-10; +} + +template +void testReducedStats(const std::vector& res, + const boundaries_t& ext, + const real_t acc) { + raise::ErrorIf(res.size() != M::Dim, "Invalid resolution size", HERE); + + M metric { res, ext, {} }; + + std::vector x_s, y_s, z_s; + + coord_t dummy { ZERO }; + std::vector values; + values.push_back(metric.template transform<1, Idx::T, Idx::U>(dummy, ONE)); + values.push_back(metric.template transform<2, Idx::T, Idx::U>(dummy, TWO)); + values.push_back(metric.template 
transform<3, Idx::T, Idx::U>(dummy, THREE)); + + values.push_back(metric.template transform<1, Idx::T, Idx::U>(dummy, FOUR * ONE)); + values.push_back(metric.template transform<2, Idx::T, Idx::U>(dummy, FOUR * TWO)); + values.push_back( + metric.template transform<3, Idx::T, Idx::U>(dummy, FOUR * THREE)); + + values.push_back(metric.template transform<1, Idx::T, Idx::U>(dummy, -ONE)); + values.push_back(metric.template transform<2, Idx::T, Idx::U>(dummy, -TWO)); + values.push_back(metric.template transform<3, Idx::T, Idx::U>(dummy, THREE)); + + values.push_back(metric.template transform<1, Idx::T, Idx::U>(dummy, FOUR)); + values.push_back(metric.template transform<2, Idx::T, Idx::U>(dummy, TWO)); + values.push_back(metric.template transform<3, Idx::T, Idx::U>(dummy, ONE)); + + ndfield_t EM; + ndfield_t J; + range_t cell_range; + + if constexpr (M::Dim == Dim::_1D) { + EM = ndfield_t { "EM", res[0] + 2 * N_GHOSTS }; + J = ndfield_t { "J", res[0] + 2 * N_GHOSTS }; + cell_range = { N_GHOSTS, res[0] + N_GHOSTS }; + + put_value(EM, values[0], em::ex1); + put_value(EM, values[1], em::ex2); + put_value(EM, values[2], em::ex3); + + put_value(EM, values[6], em::bx1); + put_value(EM, values[7], em::bx2); + put_value(EM, values[8], em::bx3); + + put_value(J, values[9], cur::jx1); + put_value(J, values[10], cur::jx2); + put_value(J, values[11], cur::jx3); + } else if constexpr (M::Dim == Dim::_2D) { + EM = ndfield_t { "EM", res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS }; + J = ndfield_t { "J", res[0] + 2 * N_GHOSTS, res[1] + 2 * N_GHOSTS }; + + cell_range = { + { N_GHOSTS, N_GHOSTS }, + { res[0] + N_GHOSTS, res[1] + N_GHOSTS } + }; + + put_value(EM, values[0], em::ex1); + put_value(EM, values[1], em::ex2); + put_value(EM, values[2], em::ex3); + + put_value(EM, values[6], em::bx1); + put_value(EM, values[7], em::bx2); + put_value(EM, values[8], em::bx3); + + put_value(J, values[9], cur::jx1); + put_value(J, values[10], cur::jx2); + put_value(J, values[11], cur::jx3); + } else { 
+ EM = ndfield_t { "EM", + res[0] + 2 * N_GHOSTS, + res[1] + 2 * N_GHOSTS, + res[2] + 2 * N_GHOSTS }; + J = ndfield_t { "J", + res[0] + 2 * N_GHOSTS, + res[1] + 2 * N_GHOSTS, + res[2] + 2 * N_GHOSTS }; + + cell_range = { + { N_GHOSTS, N_GHOSTS, N_GHOSTS }, + { res[0] + N_GHOSTS, res[1] + N_GHOSTS, res[2] + N_GHOSTS } + }; + + put_value(EM, values[0], em::ex1); + put_value(EM, values[1], em::ex2); + put_value(EM, values[2], em::ex3); + + put_value(EM, values[6], em::bx1); + put_value(EM, values[7], em::bx2); + put_value(EM, values[8], em::bx3); + + put_value(J, values[9], cur::jx1); + put_value(J, values[10], cur::jx2); + put_value(J, values[11], cur::jx3); + } + + { + const auto Ex_Sq = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(Ex_Sq, (real_t)(1), acc), + "Ex_Sq does not match expected value", + HERE); + } + + { + const auto Ey_Sq = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(Ey_Sq, (real_t)(4), acc), + "Ey_Sq does not match expected value", + HERE); + } + + { + const auto Ez_Sq = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(Ez_Sq, (real_t)(9), acc), + "Ez_Sq does not match expected value", + HERE); + } + + { + const auto Bx_Sq = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(Bx_Sq, (real_t)(1), acc), + "Bx_Sq does not match expected value", + HERE); + } + + { + const auto By_Sq = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(By_Sq, (real_t)(4), acc), + "By_Sq does not match expected value", + HERE); + } + + { + const auto Bz_Sq = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(Bz_Sq, (real_t)(9), acc), + "Bz_Sq does not match expected value", + HERE); + } + + { + const auto ExB_x = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(ExB_x, (real_t)(12), acc), + "ExB_x does not match 
expected value", + HERE); + } + + { + const auto ExB_y = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(ExB_y, (real_t)(-6), acc), + "ExB_y does not match expected value", + HERE); + } + + { + const auto ExB_z = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(ExB_z, (real_t)(0), acc), + "ExB_z does not match expected value", + HERE); + } + + { + const auto JdotE = compute_field_stat(metric, + EM, + J, + cell_range); + raise::ErrorIf(not almost_equal(JdotE, (real_t)(11), acc), + "JdotE does not match expected value", + HERE); + } +} + +auto main(int argc, char* argv[]) -> int { + Kokkos::initialize(argc, argv); + + try { + using namespace ntt; + + const ncells_t nx = 100, ny = 123, nz = 52; + std::pair x_ext { -2.0, 2.0 }; + std::pair y_ext { 0.0, 4.92 }; + std::pair z_ext { 0.0, 2.08 }; + + testReducedStats>({ nx }, { x_ext }, 1e-6); + testReducedStats>({ nx, ny }, + { x_ext, y_ext }, + 1e-6); + testReducedStats>({ nx, ny, nz }, + { x_ext, y_ext, z_ext }, + 1e-6); + + } catch (std::exception& e) { + std::cerr << e.what() << std::endl; + Kokkos::finalize(); + return 1; + } + Kokkos::finalize(); + return 0; +} diff --git a/src/kernels/utils.hpp b/src/kernels/utils.hpp new file mode 100644 index 000000000..628ed267f --- /dev/null +++ b/src/kernels/utils.hpp @@ -0,0 +1,66 @@ +/** + * @file kernels/utils.hpp + * @brief Commonly used generic kernels + * @implements + * - kernel::ComputeSum_kernel<> + * - kernel::ComputeDivergence_kernel<> + * @namespaces: + * - kernel:: + */ + +#ifndef KERNELS_UTILS_HPP +#define KERNELS_UTILS_HPP + +#include "global.h" + +#include "arch/kokkos_aliases.h" +#include "utils/error.h" + +namespace kernel { + + template + class ComputeSum_kernel { + const ndfield_t buff; + const idx_t buff_idx; + + public: + ComputeSum_kernel(const ndfield_t& buff, idx_t buff_idx) + : buff { buff } + , buff_idx { buff_idx } { + raise::ErrorIf(buff_idx >= N, "Invalid component 
index", HERE); + } + + /** + * 1D reduction body: adds buff(i1, buff_idx) to the running sum + * @param i1 cell index along the first axis + * @param lsum partial sum accumulated by the reduction + */ + Inline void operator()(index_t i1, real_t& lsum) const { + if constexpr (D == Dim::_1D) { + lsum += buff(i1, buff_idx); + } else { + raise::KernelError( + HERE, + "1D implementation of ComputeSum_kernel called for non-1D"); + } + } + + /** + * 2D reduction body: adds buff(i1, i2, buff_idx) to the running sum + * @param i1 cell index along the first axis + * @param i2 cell index along the second axis + * @param lsum partial sum accumulated by the reduction + * @note the dimension test is `if constexpr`, as in the 1D overload, so the branch + * that does not match D is discarded at compile time + */ + Inline void operator()(index_t i1, index_t i2, real_t& lsum) const { + if constexpr (D == Dim::_2D) { + lsum += buff(i1, i2, buff_idx); + } else { + raise::KernelError( + HERE, + "2D implementation of ComputeSum_kernel called for non-2D"); + } + } + + /** + * 3D reduction body: adds buff(i1, i2, i3, buff_idx) to the running sum + * @param i1 cell index along the first axis + * @param i2 cell index along the second axis + * @param i3 cell index along the third axis + * @param lsum partial sum accumulated by the reduction + * @note the dimension test is `if constexpr`, as in the 1D overload, so the branch + * that does not match D is discarded at compile time + */ + Inline void operator()(index_t i1, index_t i2, index_t i3, real_t& lsum) const { + if constexpr (D == Dim::_3D) { + lsum += buff(i1, i2, i3, buff_idx); + } else { + raise::KernelError( + HERE, + "3D implementation of ComputeSum_kernel called for non-3D"); + } + } + }; + +} // namespace kernel + +#endif // KERNELS_UTILS_HPP diff --git a/src/metrics/CMakeLists.txt b/src/metrics/CMakeLists.txt index 0f303fcfc..0bb5b977c 100644 --- a/src/metrics/CMakeLists.txt +++ b/src/metrics/CMakeLists.txt @@ -1,11 +1,18 @@ +# cmake-lint: disable=C0103 # ------------------------------ # @defines: ntt_metrics [INTERFACE] +# # @includes: -# - ../ +# +# * ../ +# # @depends: -# - ntt_global [required] +# +# * ntt_global [required] +# # @uses: -# - kokkos [required] +# +# * kokkos [required] # ------------------------------ add_library(ntt_metrics INTERFACE) @@ -15,5 +22,4 @@ add_dependencies(ntt_metrics ${libs}) target_link_libraries(ntt_metrics INTERFACE ${libs}) target_include_directories(ntt_metrics - INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../ -) \ No newline at end of file + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../) diff --git a/src/metrics/kerr_schild.h b/src/metrics/kerr_schild.h index 5a60def53..e8626fb79 100644 --- a/src/metrics/kerr_schild.h +++ b/src/metrics/kerr_schild.h @@ -16,6 +16,7 @@ #include "global.h" #include "arch/kokkos_aliases.h" +#include "utils/comparators.h" #include "utils/numeric.h" #include "metrics/metric_base.h" @@ -72,8 +73,8 @@ namespace metric { using MetricBase::nx3; 
using MetricBase::set_dxMin; - KerrSchild(std::vector res, - boundaries_t ext, + KerrSchild(const std::vector& res, + const boundaries_t& ext, const std::map& params) : MetricBase { res, ext } , a { params.at("a") } @@ -127,6 +128,15 @@ namespace metric { return min_dx; } + /** + * total volume of the region described by the metric (in physical units) + */ + [[nodiscard]] + auto totVolume() const -> real_t override { + // @TODO: Ask Alisa + return ZERO; + } + /** * metric component with lower indices: h_ij * @param x coordinate array in code units @@ -216,6 +226,29 @@ namespace metric { return ONE / math::sqrt(ONE + z(r, theta)); } + /** + * dr derivative of lapse function + * @param x coordinate array in code units + */ + Inline auto dr_alpha(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + const real_t dr_Sigma { TWO * r * dr }; + + return -(dr * Sigma(r, theta) - r * dr_Sigma) * CUBE(alpha(x)) / + SQR(Sigma(r, theta)); + } + + /** + * dtheta derivative of lapse function + * @param x coordinate array in code units + */ + Inline auto dt_alpha(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + return CUBE(alpha(x)) * r * dt_Sigma(theta) / SQR(Sigma(r, theta)); + } + /** * radial component of shift vector * @param x coordinate array in code units @@ -225,6 +258,156 @@ namespace metric { return dr_inv * z_ / (ONE + z_); } + /** + * dr derivative of radial component of shift vector + * @param x coordinate array in code units + */ + Inline auto dr_beta1(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + const real_t dr_Sigma { TWO * r * dr }; + + return dr_inv * TWO * (dr * Sigma(r, theta) - r * dr_Sigma) / + SQR(Sigma(r, theta) + TWO * r); + } + + /** + * dtheta derivative of radial component of shift vector + * @param x coordinate array 
in code units + */ + Inline auto dt_beta1(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + return -dr_inv * TWO * r * dt_Sigma(theta) / + SQR(Sigma(r, theta) * (ONE + z(r, theta))); + } + + /** + * dr derivative of h^11 + * @param x coordinate array in code units + */ + Inline auto dr_h11(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + const real_t dr_Sigma { TWO * r * dr }; + const real_t dr_Delta { TWO * dr * (r - ONE) }; + const real_t dr_A { FOUR * r * dr * (SQR(r) + SQR(a)) - + SQR(a) * SQR(math::sin(theta)) * dr_Delta }; + + return (Sigma(r, theta) * (Sigma(r, theta) + TWO * r) * dr_A - + TWO * A(r, theta) * + (r * dr_Sigma + Sigma(r, theta) * (dr_Sigma + dr))) / + (SQR(Sigma(r, theta) * (Sigma(r, theta) + TWO * r))) * SQR(dr_inv); + } + + /** + * dr derivative of h^22 + * @param x coordinate array in code units + */ + Inline auto dr_h22(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + const real_t dr_Sigma { TWO * r * dr }; + + return -dr_Sigma / SQR(Sigma(r, theta)) * SQR(dtheta_inv); + } + + /** + * dr derivative of h^33 + * @param x coordinate array in code units + */ + Inline auto dr_h33(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + const real_t dr_Sigma { TWO * r * dr }; + + return -dr_Sigma / SQR(Sigma(r, theta)) / SQR(math::sin(theta)); + } + + /** + * dr derivative of h^13 + * @param x coordinate array in code units + */ + Inline auto dr_h13(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + const real_t dr_Sigma { TWO * r * dr }; + + return -a * dr_Sigma / SQR(Sigma(r, theta)) * dr_inv; + } + + /** + * dtheta derivative of Sigma (carries the factor dtheta); a value accepted by + * cmp::AlmostZero is returned as exactly ZERO + * @param theta polar angle, passed by callers as 
x[1] * dtheta + x2_min (not a code-unit coordinate array) + */ + Inline auto dt_Sigma(const real_t& theta) const -> real_t { + const real_t dt_Sigma { -TWO * SQR(a) * math::sin(theta) * + math::cos(theta) * dtheta }; + if (cmp::AlmostZero(dt_Sigma)) { + return ZERO; + } else { + return dt_Sigma; + } + } + + /** + * dtheta derivative of A (carries the factor dtheta); a value accepted by + * cmp::AlmostZero is returned as exactly ZERO + * @param r radial coordinate, passed by callers as x[0] * dr + x1_min + * @param theta polar angle, passed by callers as x[1] * dtheta + x2_min + */ + Inline auto dt_A(const real_t& r, const real_t& theta) const -> real_t { + const real_t dt_A { -TWO * SQR(a) * math::sin(theta) * math::cos(theta) * + Delta(r) * dtheta }; + if (cmp::AlmostZero(dt_A)) { + return ZERO; + } else { + return dt_A; + } + } + + /** + * dtheta derivative of h^11 + * @param x coordinate array in code units + */ + Inline auto dt_h11(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + return (Sigma(r, theta) * (Sigma(r, theta) + TWO * r) * dt_A(r, theta) - + TWO * A(r, theta) * dt_Sigma(theta) * (r + Sigma(r, theta))) / + (SQR(Sigma(r, theta) * (Sigma(r, theta) + TWO * r))) * SQR(dr_inv); + } + + /** + * dtheta derivative of h^22 + * @param x coordinate array in code units + */ + Inline auto dt_h22(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + return -dt_Sigma(theta) / SQR(Sigma(r, theta)) * SQR(dtheta_inv); + } + + /** + * dtheta derivative of h^33 + * @param x coordinate array in code units + */ + Inline auto dt_h33(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + return -TWO * dtheta * math::cos(theta) * + (Sigma(r, theta) - SQR(a) * SQR(math::sin(theta))) / + CUBE(math::sin(theta)) / SQR(Sigma(r, theta)); + } + + /** + * dtheta derivative of h^13 + * @param x coordinate array in code units + */ + Inline auto dt_h13(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + return -a * 
dt_Sigma(theta) / SQR(Sigma(r, theta)) * dr_inv; + } + /** * sqrt(det(h_ij)) * @param x coordinate array in code units diff --git a/src/metrics/kerr_schild_0.h b/src/metrics/kerr_schild_0.h index 70689f4f0..142e88b7a 100644 --- a/src/metrics/kerr_schild_0.h +++ b/src/metrics/kerr_schild_0.h @@ -36,6 +36,7 @@ namespace metric { private: const real_t dr, dtheta, dphi; const real_t dr_inv, dtheta_inv, dphi_inv; + const real_t a, rg_, rh_; public: static constexpr const char* Label { "kerr_schild_0" }; @@ -53,10 +54,13 @@ namespace metric { using MetricBase::nx3; using MetricBase::set_dxMin; - KerrSchild0(std::vector res, - boundaries_t ext, + KerrSchild0(const std::vector& res, + const boundaries_t& ext, const std::map& = {}) : MetricBase { res, ext } + , a { ZERO } + , rg_ { ONE } + , rh_ { TWO } , dr { (x1_max - x1_min) / nx1 } , dtheta { (x2_max - x2_min) / nx2 } , dphi { (x3_max - x3_min) / nx3 } @@ -70,17 +74,17 @@ namespace metric { [[nodiscard]] Inline auto spin() const -> const real_t& { - return ZERO; + return a; } [[nodiscard]] Inline auto rhorizon() const -> const real_t& { - return ZERO; + return rh_; } [[nodiscard]] Inline auto rg() const -> const real_t& { - return ZERO; + return rg_; } /** @@ -104,6 +108,20 @@ namespace metric { return min_dx; } + /** + * total volume of the region described by the metric (in physical units) + */ + [[nodiscard]] + auto totVolume() const -> real_t override { + if constexpr (D == Dim::_1D) { + raise::Error("1D spherical metric not applicable", HERE); + } else if constexpr (D == Dim::_2D) { + return (SQR(x1_max) - SQR(x1_min)) * (x2_max - x2_min); + } else { + return (SQR(x1_max) - SQR(x1_min)) * (x2_max - x2_min) * (x3_max - x3_min); + } + } + /** * metric component with lower indices: h_ij * @param x coordinate array in code units @@ -167,6 +185,22 @@ namespace metric { return ONE; } + /** + * dr derivative of lapse function + * @param x coordinate array in code units + */ + Inline auto dr_alpha(const coord_t& x) const 
-> real_t { + return ZERO; + } + + /** + * dtheta derivative of lapse function + * @param x coordinate array in code units + */ + Inline auto dt_alpha(const coord_t& x) const -> real_t { + return ZERO; + } + /** * radial component of shift vector * @param x coordinate array in code units @@ -175,6 +209,92 @@ namespace metric { return ZERO; } + /** + * dr derivative of radial component of shift vector + * @param x coordinate array in code units + */ + Inline auto dr_beta1(const coord_t& x) const -> real_t { + return ZERO; + } + + /** + * dtheta derivative of radial component of shift vector + * @param x coordinate array in code units + */ + Inline auto dt_beta1(const coord_t& x) const -> real_t { + return ZERO; + } + + /** + * dr derivative of h^11 + * @param x coordinate array in code units + */ + Inline auto dr_h11(const coord_t& x) const -> real_t { + return ZERO; + } + + /** + * dr derivative of h^22 + * @param x coordinate array in code units + */ + Inline auto dr_h22(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + return -TWO / CUBE(r) * SQR(dtheta_inv) * dr; + } + + /** + * dr derivative of h^33 + * @param x coordinate array in code units + */ + Inline auto dr_h33(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + return -TWO / CUBE(r) / SQR(math::sin(theta)) * dr; + } + + /** + * dr derivative of h^13 + * @param x coordinate array in code units + */ + Inline auto dr_h13(const coord_t& x) const -> real_t { + return ZERO; + } + + /** + * dtheta derivative of h^11 + * @param x coordinate array in code units + */ + Inline auto dt_h11(const coord_t& x) const -> real_t { + return ZERO; + } + + /** + * dtheta derivative of h^22 + * @param x coordinate array in code units + */ + Inline auto dt_h22(const coord_t& x) const -> real_t { + return ZERO; + } + + /** + * dtheta derivative of h^33 + * @param x coordinate 
array in code units + */ + Inline auto dt_h33(const coord_t& x) const -> real_t { + const real_t r { x[0] * dr + x1_min }; + const real_t theta { x[1] * dtheta + x2_min }; + return -TWO * math::cos(theta) / SQR(r) / CUBE(math::sin(theta)) * dtheta; + } + + /** + * dtheta derivative of h^13 + * @param x coordinate array in code units + */ + Inline auto dt_h13(const coord_t& x) const -> real_t { + return ZERO; + } + /** * sqrt(det(h_ij)) * @param x coordinate array in code units diff --git a/src/metrics/metric_base.h b/src/metrics/metric_base.h index 321d39bbd..8b0e5cae0 100644 --- a/src/metrics/metric_base.h +++ b/src/metrics/metric_base.h @@ -46,12 +46,38 @@ #include "global.h" +#include "utils/error.h" #include "utils/numeric.h" #include namespace metric { + namespace { + template + auto getNXi(const std::vector& res) -> real_t { + if constexpr (i >= static_cast(D)) { + return ONE; + } else { + raise::ErrorIf(res.size() <= i, "Invalid res size provided to metric", HERE); + return static_cast(res.at(i)); + } + }; + + template + auto getExtent(const boundaries_t& ext) -> real_t { + if constexpr (i >= static_cast(D)) { + return ZERO; + } else { + raise::ErrorIf(ext.size() <= i, "Invalid ext size provided to metric", HERE); + return min ? ext.at(i).first : ext.at(i).second; + } + }; + + constexpr bool XMin = true; + constexpr bool XMax = false; + }; // namespace + /** * Virtual parent metric class template: h_ij * Coordinates vary from `0` to `nx1` ... (code units) @@ -61,22 +87,25 @@ namespace metric { static constexpr bool is_metric { true }; static constexpr Dimension Dim { D }; - MetricBase(std::vector res, boundaries_t ext) - : nx1 { res.size() > 0 ? (real_t)(res[0]) : ONE } - , nx2 { res.size() > 1 ? (real_t)(res[1]) : ONE } - , nx3 { res.size() > 2 ? (real_t)(res[2]) : ONE } - , x1_min { res.size() > 0 ? ext[0].first : ZERO } - , x1_max { res.size() > 0 ? ext[0].second : ZERO } - , x2_min { res.size() > 1 ? ext[1].first : ZERO } - , x2_max { res.size() > 1 ? 
ext[1].second : ZERO } - , x3_min { res.size() > 2 ? ext[2].first : ZERO } - , x3_max { res.size() > 2 ? ext[2].second : ZERO } {} + MetricBase(const std::vector& res, const boundaries_t& ext) + : nx1 { getNXi(res) } + , nx2 { getNXi(res) } + , nx3 { getNXi(res) } + , x1_min { getExtent(ext) } + , x1_max { getExtent(ext) } + , x2_min { getExtent(ext) } + , x2_max { getExtent(ext) } + , x3_min { getExtent(ext) } + , x3_max { getExtent(ext) } {} ~MetricBase() = default; [[nodiscard]] virtual auto find_dxMin() const -> real_t = 0; + [[nodiscard]] + virtual auto totVolume() const -> real_t = 0; + [[nodiscard]] auto dxMin() const -> real_t { return dx_min; diff --git a/src/metrics/minkowski.h b/src/metrics/minkowski.h index c22ac1ad9..cf4aad2b6 100644 --- a/src/metrics/minkowski.h +++ b/src/metrics/minkowski.h @@ -47,22 +47,26 @@ namespace metric { using MetricBase::nx3; using MetricBase::set_dxMin; - Minkowski(std::vector res, - boundaries_t ext, + Minkowski(const std::vector& res, + const boundaries_t& ext, const std::map& = {}) : MetricBase { res, ext } , dx { (x1_max - x1_min) / nx1 } , dx_inv { ONE / dx } { set_dxMin(find_dxMin()); + const auto epsilon = std::numeric_limits::epsilon() * + static_cast(100.0); if constexpr (D != Dim::_1D) { - raise::ErrorIf(not cmp::AlmostEqual((x2_max - x2_min) / (real_t)(nx2), dx), - "dx2 must be equal to dx1 in 2D", - HERE); + raise::ErrorIf( + not cmp::AlmostEqual((x2_max - x2_min) / (real_t)(nx2), dx, epsilon), + "dx2 must be equal to dx1 in 2D", + HERE); } if constexpr (D == Dim::_3D) { - raise::ErrorIf(not cmp::AlmostEqual((x3_max - x3_min) / (real_t)(nx3), dx), - "dx3 must be equal to dx1 in 3D", - HERE); + raise::ErrorIf( + not cmp::AlmostEqual((x3_max - x3_min) / (real_t)(nx3), dx, epsilon), + "dx3 must be equal to dx1 in 3D", + HERE); } } @@ -76,6 +80,20 @@ namespace metric { return dx / math::sqrt(static_cast(D)); } + /** + * total volume of the region described by the metric (in physical units) + */ + [[nodiscard]] + 
auto totVolume() const -> real_t override { + if constexpr (D == Dim::_1D) { + return x1_max - x1_min; + } else if constexpr (D == Dim::_2D) { + return (x1_max - x1_min) * (x2_max - x2_min); + } else { + return (x1_max - x1_min) * (x2_max - x2_min) * (x3_max - x3_min); + } + } + /** * metric component with lower indices: h_ij * @param x coordinate array in code units @@ -240,8 +258,7 @@ namespace metric { * @note tetrad/cart <-> cntrv <-> cov */ template - Inline auto transform(const coord_t& xi, const real_t& v_in) const - -> real_t { + Inline auto transform(const coord_t& xi, const real_t& v_in) const -> real_t { static_assert(i > 0 && i <= 3, "Invalid index i"); static_assert(in != out, "Invalid vector transformation"); if constexpr (i > static_cast(D)) { diff --git a/src/metrics/qkerr_schild.h b/src/metrics/qkerr_schild.h index d531b8b3b..c137b5bcf 100644 --- a/src/metrics/qkerr_schild.h +++ b/src/metrics/qkerr_schild.h @@ -72,8 +72,8 @@ namespace metric { using MetricBase::nx3; using MetricBase::set_dxMin; - QKerrSchild(std::vector res, - boundaries_t ext, + QKerrSchild(const std::vector& res, + const boundaries_t& ext, const std::map& params) : MetricBase { res, ext } , a { params.at("a") } @@ -132,6 +132,15 @@ namespace metric { return min_dx; } + /** + * total volume of the region described by the metric (in physical units) + */ + [[nodiscard]] + auto totVolume() const -> real_t override { + // @TODO: Ask Alisa + return ZERO; + } + /** * metric component with lower indices: h_ij * @param x coordinate array in code units @@ -234,6 +243,39 @@ namespace metric { return ONE / math::sqrt(ONE + z(r, theta)); } + /** + * dr derivative of lapse function + * @param x coordinate array in code units + */ + Inline auto dr_alpha(const coord_t& x) const -> real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t theta { eta2theta(x[1] * deta + eta_min) }; + const real_t dx_r { dchi * math::exp(x[0] * dchi + chi_min) }; + const real_t dr_Sigma { 
TWO * r * dx_r }; + return -(dx_r * Sigma(r, theta) - r * dr_Sigma) * CUBE(alpha(x)) / + SQR(Sigma(r, theta)); + } + + /** + * dtheta derivative of lapse function + * @param x coordinate array in code units + */ + Inline auto dt_alpha(const coord_t& x) const -> real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t eta { x[1] * deta + eta_min }; + const real_t theta { eta2theta(eta) }; + const real_t dx_dt { + deta * (ONE + TWO * h0 * static_cast(constant::INV_PI_SQR) * + (TWO * THREE * SQR(eta) - + TWO * THREE * static_cast(constant::PI) * eta + + static_cast(constant::PI_SQR))) + }; + const real_t dt_Sigma { -TWO * SQR(a) * math::sin(theta) * + math::cos(theta) * dx_dt }; + + return r * dt_Sigma * CUBE(alpha(x)) / SQR(Sigma(r, theta)); + } + /** * radial component of shift vector * @param x coordinate array in code units @@ -246,6 +288,182 @@ namespace metric { return math::exp(-chi) * dchi_inv * z_ / (ONE + z_); } + /** + * dr derivative of radial component of shift vector + * @param x coordinate array in code units + */ + Inline auto dr_beta1(const coord_t& x) const -> real_t { + const real_t chi { x[0] * dchi + chi_min }; + const real_t r { r0 + math::exp(chi) }; + const real_t theta { eta2theta(x[1] * deta + eta_min) }; + const real_t z_ { z(r, theta) }; + const real_t dx_r { dchi * math::exp(x[0] * dchi + chi_min) }; + const real_t dr_Sigma { TWO * r * dx_r }; + + return math::exp(-chi) * dchi_inv * TWO * + (dx_r * Sigma(r, theta) - r * dr_Sigma) / + SQR(Sigma(r, theta) + TWO * r) - + dchi * math::exp(-chi) * dchi_inv * z_ / (ONE + z_); + } + + /** + * dr derivative of radial component of shift vector + * @param x coordinate array in code units + */ + Inline auto dt_beta1(const coord_t& x) const -> real_t { + const real_t chi { x[0] * dchi + chi_min }; + const real_t r { r0 + math::exp(chi) }; + const real_t eta { x[1] * deta + eta_min }; + const real_t theta { eta2theta(eta) }; + return -math::exp(-chi) * dchi_inv * TWO * r * 
dt_Sigma(eta) / + SQR(Sigma(r, theta) * (ONE + z(r, theta))); + } + + /** + * dr derivative of h^11 + * @param x coordinate array in code units + */ + Inline auto dr_h11(const coord_t& x) const -> real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t theta { eta2theta(x[1] * deta + eta_min) }; + + const real_t dx_r { dchi * math::exp(x[0] * dchi + chi_min) }; + const real_t dr_Sigma { TWO * r * dx_r }; + const real_t dr_Delta { TWO * dx_r * (r - ONE) }; + const real_t dr_A { FOUR * r * dx_r * (SQR(r) + SQR(a)) - + SQR(a) * SQR(math::sin(theta)) * dr_Delta }; + + return (math::exp(-TWO * (x[0] * dchi + chi_min)) / SQR(dchi) * + (Sigma(r, theta) * (Sigma(r, theta) + TWO * r) * dr_A - + TWO * A(r, theta) * + (r * dr_Sigma + Sigma(r, theta) * (dr_Sigma + dx_r))) / + (SQR(Sigma(r, theta) * (Sigma(r, theta) + TWO * r)))) - + TWO * dchi * math::exp(-TWO * (x[0] * dchi + chi_min)) / SQR(dchi) * + A(r, theta) / (Sigma(r, theta) * (Sigma(r, theta) + TWO * r)); + } + + /** + * dr derivative of h^22 + * @param x coordinate array in code units + */ + Inline auto dr_h22(const coord_t& x) const -> real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t theta { eta2theta(x[1] * deta + eta_min) }; + const real_t dx_r { dchi * math::exp(x[0] * dchi + chi_min) }; + const real_t dr_Sigma { TWO * r * dx_r }; + + return -dr_Sigma / SQR(Sigma(r, theta)) / SQR(deta); + } + + /** + * dr derivative of h^33 + * @param x coordinate array in code units + */ + Inline auto dr_h33(const coord_t& x) const -> real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t theta { eta2theta(x[1] * deta + eta_min) }; + const real_t dx_r { dchi * math::exp(x[0] * dchi + chi_min) }; + const real_t dr_Sigma { TWO * r * dx_r }; + + return -dr_Sigma / SQR(Sigma(r, theta)) / SQR(math::sin(theta)); + } + + /** + * dr derivative of h^13 + * @param x coordinate array in code units + */ + Inline auto dr_h13(const coord_t& x) const -> 
real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t theta { eta2theta(x[1] * deta + eta_min) }; + const real_t dx_r { dchi * math::exp(x[0] * dchi + chi_min) }; + const real_t dr_Sigma { TWO * r * dx_r }; + + return -a * dr_Sigma / SQR(Sigma(r, theta)) * + (math::exp(-(x[0] * dchi + chi_min)) * dchi_inv) - + dchi * (math::exp(-(x[0] * dchi + chi_min)) * dchi_inv) * a / + Sigma(r, theta); + } + + /** + * dtheta derivative of Sigma + * @param x coordinate array in code units + */ + Inline auto dt_Sigma(const real_t& eta) const -> real_t { + const real_t theta { eta2theta(eta) }; + const real_t dt_Sigma { -TWO * SQR(a) * math::sin(theta) * + math::cos(theta) * dx_dt(eta) }; + if (cmp::AlmostZero(dt_Sigma)) { + return ZERO; + } else { + return dt_Sigma; + } + } + + /** + * dtheta derivative of A + * @param x coordinate array in code units + */ + Inline auto dt_A(const real_t& r, const real_t& eta) const -> real_t { + const real_t theta { eta2theta(eta) }; + const real_t dt_A { -TWO * SQR(a) * math::sin(theta) * math::cos(theta) * + Delta(r) * dx_dt(eta) }; + if (cmp::AlmostZero(dt_A)) { + return ZERO; + } else { + return dt_A; + } + } + + /** + * dtheta derivative of h^11 + * @param x coordinate array in code units + */ + Inline auto dt_h11(const coord_t& x) const -> real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t eta { x[1] * deta + eta_min }; + const real_t theta { eta2theta(eta) }; + return math::exp(-TWO * (x[0] * dchi + chi_min)) / SQR(dchi) * + (Sigma(r, theta) * (Sigma(r, theta) + TWO * r) * dt_A(r, eta) - + TWO * A(r, theta) * dt_Sigma(eta) * (r + Sigma(r, theta))) / + (SQR(Sigma(r, theta) * (Sigma(r, theta) + TWO * r))); + } + + /** + * dtheta derivative of h^22 + * @param x coordinate array in code units + */ + Inline auto dt_h22(const coord_t& x) const -> real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t eta { x[1] * deta + eta_min }; + const real_t theta { 
eta2theta(eta) }; + return -dt_Sigma(eta) / SQR(Sigma(r, theta)) / SQR(deta); + } + + /** + * dtheta derivative of h^33 + * @param x coordinate array in code units + */ + Inline auto dt_h33(const coord_t& x) const -> real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t eta { x[1] * deta + eta_min }; + const real_t theta { eta2theta(eta) }; + return -(dt_Sigma(eta) + TWO * math::cos(theta) / math::sin(theta) * + Sigma(r, theta) * dx_dt(eta)) / + SQR(Sigma(r, theta) * math::sin(theta)); + } + + /** + * dtheta derivative of h^13 + * @param x coordinate array in code units + */ + Inline auto dt_h13(const coord_t& x) const -> real_t { + const real_t r { r0 + math::exp(x[0] * dchi + chi_min) }; + const real_t eta { x[1] * deta + eta_min }; + const real_t theta { eta2theta(eta) }; + return -a * dt_Sigma(eta) / SQR(Sigma(r, theta)) * + (math::exp(-(x[0] * dchi + chi_min)) * dchi_inv); + } + /** * sqrt(det(h_ij)) * @param x coordinate array in code units @@ -265,7 +483,7 @@ namespace metric { } /** - * sqrt(det(h_ij)) + * sqrt(det(h_ij)) divided by sin(theta). 
* @param x coordinate array in code units */ Inline auto sqrt_det_h_tilde(const coord_t& x) const -> real_t { @@ -287,12 +505,14 @@ namespace metric { * @param x1 radial coordinate along the axis (code units) */ Inline auto polar_area(const real_t& x1) const -> real_t { - return dchi * math::exp(x1 * dchi + chi_min) * - (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a)) * - math::sqrt( - ONE + TWO * (r0 + math::exp(x1 * dchi + chi_min)) / - (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a))) * - (ONE - math::cos(eta2theta(HALF * deta + eta_min))); + if constexpr (D != Dim::_1D) { + return dchi * math::exp(x1 * dchi + chi_min) * + (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a)) * + math::sqrt( + ONE + TWO * (r0 + math::exp(x1 * dchi + chi_min)) / + (SQR(r0 + math::exp(x1 * dchi + chi_min)) + SQR(a))) * + (ONE - math::cos(eta2theta(HALF * deta))); + } } /** @@ -456,6 +676,20 @@ namespace metric { } } + /** + * @brief quasi-spherical eta to spherical theta + */ + Inline auto dx_dt(const real_t& eta) const -> real_t { + if (cmp::AlmostZero(h0)) { + return deta; + } else { + return deta * + (ONE + TWO * h0 * constant::INV_PI_SQR * + (TWO * THREE * SQR(eta) - + TWO * THREE * constant::PI * eta + constant::PI_SQR)); + } + } + /** * @brief spherical theta to quasi-spherical eta */ diff --git a/src/metrics/qspherical.h b/src/metrics/qspherical.h index 8062f7589..4acfda442 100644 --- a/src/metrics/qspherical.h +++ b/src/metrics/qspherical.h @@ -38,6 +38,7 @@ namespace metric { const real_t r0, h, chi_min, eta_min, phi_min; const real_t dchi, deta, dphi; const real_t dchi_inv, deta_inv, dphi_inv; + const bool small_angle; public: static constexpr const char* Label { "qspherical" }; @@ -55,8 +56,8 @@ namespace metric { using MetricBase::nx3; using MetricBase::set_dxMin; - QSpherical(std::vector res, - boundaries_t ext, + QSpherical(const std::vector& res, + const boundaries_t& ext, const std::map& params) : MetricBase { res, ext } , r0 { params.at("r0") } @@ -69,7 +70,8 
@@ namespace metric { , dphi { (x3_max - x3_min) / nx3 } , dchi_inv { ONE / dchi } , deta_inv { ONE / deta } - , dphi_inv { ONE / dphi } { + , dphi_inv { ONE / dphi } + , small_angle { eta2theta(HALF * deta) < constant::SMALL_ANGLE } { set_dxMin(find_dxMin()); } @@ -97,6 +99,20 @@ namespace metric { return min_dx; } + /** + * total volume of the region described by the metric (in physical units) + */ + [[nodiscard]] + auto totVolume() const -> real_t override { + if constexpr (D == Dim::_1D) { + raise::Error("1D spherical metric not applicable", HERE); + } else if constexpr (D == Dim::_2D) { + return (SQR(x1_max) - SQR(x1_min)) * (x2_max - x2_min); + } else { + return (SQR(x1_max) - SQR(x1_min)) * (x2_max - x2_min) * (x3_max - x3_min); + } + } + /** * metric component with lower indices: h_ij * @param x coordinate array in code units @@ -155,11 +171,11 @@ namespace metric { */ Inline auto sqrt_det_h(const coord_t& x) const -> real_t { if constexpr (D == Dim::_2D) { - real_t exp_chi { math::exp(x[0] * dchi + chi_min) }; + const real_t exp_chi { math::exp(x[0] * dchi + chi_min) }; return dchi * deta * exp_chi * dtheta_deta(x[1] * deta + eta_min) * SQR(r0 + exp_chi) * math::sin(eta2theta(x[1] * deta + eta_min)); } else if constexpr (D == Dim::_3D) { - real_t exp_chi { math::exp(x[0] * dchi + chi_min) }; + const real_t exp_chi { math::exp(x[0] * dchi + chi_min) }; return dchi * deta * dphi * exp_chi * dtheta_deta(x[1] * deta + eta_min) * SQR(r0 + exp_chi) * math::sin(eta2theta(x[1] * deta + eta_min)); } @@ -171,7 +187,7 @@ namespace metric { */ Inline auto sqrt_det_h_tilde(const coord_t& x) const -> real_t { if constexpr (D != Dim::_1D) { - real_t exp_chi { math::exp(x[0] * dchi + chi_min) }; + const real_t exp_chi { math::exp(x[0] * dchi + chi_min) }; return dchi * deta * exp_chi * dtheta_deta(x[1] * deta + eta_min) * SQR(r0 + exp_chi); } @@ -183,9 +199,16 @@ namespace metric { */ Inline auto polar_area(const real_t& x1) const -> real_t { if constexpr (D != Dim::_1D) 
{ - real_t exp_chi { math::exp(x1 * dchi + chi_min) }; - return dchi * exp_chi * SQR(r0 + exp_chi) * - (ONE - math::cos(eta2theta(HALF * deta))); + const real_t exp_chi { math::exp(x1 * dchi + chi_min) }; + if (small_angle) { + const real_t dtheta = eta2theta(HALF * deta); + return dchi * exp_chi * SQR(r0 + exp_chi) * + (static_cast(48) - SQR(dtheta)) * SQR(dtheta) / + static_cast(384); + } else { + return dchi * exp_chi * SQR(r0 + exp_chi) * + (ONE - math::cos(eta2theta(HALF * deta))); + } } } @@ -284,8 +307,7 @@ namespace metric { * @note tetrad/sph <-> cntrv <-> cov */ template - Inline auto transform(const coord_t& xi, const real_t& v_in) const - -> real_t { + Inline auto transform(const coord_t& xi, const real_t& v_in) const -> real_t { static_assert(i > 0 && i <= 3, "Invalid index i"); static_assert(in != out, "Invalid vector transformation"); if constexpr ((in == Idx::T && out == Idx::Sph) || diff --git a/src/metrics/spherical.h b/src/metrics/spherical.h index f4bbe2eea..388d3710e 100644 --- a/src/metrics/spherical.h +++ b/src/metrics/spherical.h @@ -33,6 +33,7 @@ namespace metric { const real_t dr, dtheta, dphi; const real_t dr_inv, dtheta_inv, dphi_inv; + const bool small_angle; public: static constexpr const char* Label { "spherical" }; @@ -50,8 +51,8 @@ namespace metric { using MetricBase::nx3; using MetricBase::set_dxMin; - Spherical(std::vector res, - boundaries_t ext, + Spherical(const std::vector& res, + const boundaries_t& ext, const std::map& = {}) : MetricBase { res, ext } , dr((x1_max - x1_min) / nx1) @@ -59,7 +60,8 @@ namespace metric { , dphi((x3_max - x3_min) / nx3) , dr_inv { ONE / dr } , dtheta_inv { ONE / dtheta } - , dphi_inv { ONE / dphi } { + , dphi_inv { ONE / dphi } + , small_angle { HALF * dtheta < constant::SMALL_ANGLE } { set_dxMin(find_dxMin()); } @@ -76,6 +78,20 @@ namespace metric { return ONE / math::sqrt(ONE / SQR(dx1) + ONE / SQR(dx2)); } + /** + * total volume of the region described by the metric (in physical units) + */ + 
[[nodiscard]] + auto totVolume() const -> real_t override { + if constexpr (D == Dim::_1D) { + raise::Error("1D spherical metric not applicable", HERE); + } else if constexpr (D == Dim::_2D) { + return (SQR(x1_max) - SQR(x1_min)) * (x2_max - x2_min); + } else { + return (SQR(x1_max) - SQR(x1_min)) * (x2_max - x2_min) * (x3_max - x3_min); + } + } + /** * metric component with lower indices: h_ij * @param x coordinate array in code units @@ -152,9 +168,16 @@ namespace metric { /** * differential area at the pole (used in axisymmetric solvers) * @param x1 radial coordinate along the axis (code units) + * @note uses small-angle approximation when the resolution is too high */ Inline auto polar_area(const real_t& x1) const -> real_t { - return dr * SQR(x1 * dr + x1_min) * (ONE - math::cos(HALF * dtheta)); + if (small_angle) { + return dr * SQR(x1 * dr + x1_min) * + (static_cast(48) - SQR(dtheta)) * SQR(dtheta) / + static_cast(384); + } else { + return dr * SQR(x1 * dr + x1_min) * (ONE - math::cos(HALF * dtheta)); + } } /** @@ -252,8 +275,7 @@ namespace metric { * @note tetrad/sph <-> cntrv <-> cov */ template - Inline auto transform(const coord_t& xi, const real_t& v_in) const - -> real_t { + Inline auto transform(const coord_t& xi, const real_t& v_in) const -> real_t { static_assert(i > 0 && i <= 3, "Invalid index i"); static_assert(in != out, "Invalid vector transformation"); if constexpr ((in == Idx::T && out == Idx::Sph) || diff --git a/src/metrics/tests/CMakeLists.txt b/src/metrics/tests/CMakeLists.txt index 117cb3295..0d661c318 100644 --- a/src/metrics/tests/CMakeLists.txt +++ b/src/metrics/tests/CMakeLists.txt @@ -1,9 +1,12 @@ +# cmake-lint: disable=C0103,C0111 # ------------------------------ # @brief: Generates tests for the `ntt_metrics` module +# # @uses: -# - kokkos [required] -# - plog [required] -# - mpi [optional] +# +# * kokkos [required] +# * plog [required] +# * mpi [optional] # ------------------------------ set(SOURCE_DIR 
${CMAKE_CURRENT_SOURCE_DIR}/../) @@ -13,7 +16,7 @@ function(gen_test title) set(src ${title}.cpp) add_executable(${exec} ${src}) - set (libs ntt_metrics) + set(libs ntt_metrics) add_dependencies(${exec} ${libs}) target_link_libraries(${exec} PRIVATE ${libs}) @@ -25,4 +28,4 @@ gen_test(vec_trans) gen_test(coord_trans) gen_test(sph-qsph) gen_test(ks-qks) -gen_test(sr-cart-sph) \ No newline at end of file +gen_test(sr-cart-sph) diff --git a/src/metrics/tests/coord_trans.cpp b/src/metrics/tests/coord_trans.cpp index f3779a852..f2bd7e464 100644 --- a/src/metrics/tests/coord_trans.cpp +++ b/src/metrics/tests/coord_trans.cpp @@ -31,9 +31,9 @@ Inline auto equal(const coord_t& a, const char* msg, real_t acc = ONE) -> bool { const auto eps = epsilon * acc; - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { if (not cmp::AlmostEqual(a[d], b[d], eps)) { - printf("%d : %.12e != %.12e %s\n", d, a[d], b[d], msg); + Kokkos::printf("%d : %.12e != %.12e %s\n", d, a[d], b[d], msg); return false; } } @@ -44,7 +44,7 @@ template Inline void unravel(std::size_t idx, tuple_t& ijk, const tuple_t& res) { - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { ijk[d] = idx % res[d]; idx /= res[d]; } @@ -82,7 +82,7 @@ void testMetric(const std::vector& res, coord_t x_Code_2 { ZERO }; coord_t x_Phys_1 { ZERO }; coord_t x_Sph_1 { ZERO }; - for (unsigned short d = 0; d < M::Dim; ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { x_Code_1[d] = (real_t)(idx[d]) + HALF; } metric.template convert(x_Code_1, x_Phys_1); @@ -127,24 +127,24 @@ auto main(int argc, char* argv[]) -> int { const auto res2d = std::vector { 64, 32 }; const auto res3d = std::vector { 64, 32, 16 }; const auto ext1dcart = boundaries_t { - {10.0, 20.0} + { 10.0, 20.0 } }; const auto ext2dcart = boundaries_t { - {0.0, 20.0}, - {0.0, 10.0} + { 0.0, 20.0 }, + { 0.0, 10.0 } }; const auto ext3dcart = boundaries_t { - {-2.0, 2.0}, - {-1.0, 1.0}, - {-0.5, 0.5} + { -2.0, 2.0 }, + { -1.0, 1.0 
}, + { -0.5, 0.5 } }; const auto extsph = boundaries_t { - {1.0, 10.0}, - {0.0, constant::PI} + { 1.0, 10.0 }, + { 0.0, constant::PI } }; const auto params = std::map { - {"r0", -ONE}, - { "h", (real_t)0.25} + { "r0", -ONE }, + { "h", (real_t)0.25 } }; testMetric>({ 128 }, ext1dcart); @@ -155,30 +155,30 @@ auto main(int argc, char* argv[]) -> int { const auto resks = std::vector { 64, 54 }; const auto extsks = boundaries_t { - {0.8, 50.0}, - {0.0, constant::PI} + { 0.8, 50.0 }, + { 0.0, constant::PI } }; const auto paramsks = std::map { - {"a", (real_t)0.95} + { "a", (real_t)0.95 } }; testMetric>(resks, extsks, 150, paramsks); const auto resqks = std::vector { 64, 42 }; const auto extqks = boundaries_t { - {0.8, 10.0}, - {0.0, constant::PI} + { 0.8, 10.0 }, + { 0.0, constant::PI } }; const auto paramsqks = std::map { - {"r0", -TWO}, - { "h", ZERO}, - { "a", (real_t)0.8} + { "r0", -TWO }, + { "h", ZERO }, + { "a", (real_t)0.8 } }; testMetric>(resqks, extqks, 500, paramsqks); const auto resks0 = std::vector { 64, 54 }; const auto extks0 = boundaries_t { - {0.5, 20.0}, - {0.0, constant::PI} + { 0.5, 20.0 }, + { 0.0, constant::PI } }; testMetric>(resks0, extks0, 150); diff --git a/src/metrics/tests/ks-qks.cpp b/src/metrics/tests/ks-qks.cpp index bed051e16..ea05c0f92 100644 --- a/src/metrics/tests/ks-qks.cpp +++ b/src/metrics/tests/ks-qks.cpp @@ -25,8 +25,9 @@ Inline auto equal(const vec_t& a, const char* msg, real_t acc = ONE) -> bool { const auto eps = epsilon * acc; - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { if (not cmp::AlmostEqual(a[d], b[d], eps)) { + Kokkos::printf("%s: %.12e : %.12e\n", msg, a[d], b[d]); return false; } } @@ -37,7 +38,7 @@ template Inline void unravel(std::size_t idx, tuple_t& ijk, const tuple_t& res) { - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { ijk[d] = idx % res[d]; idx /= res[d]; } @@ -74,7 +75,7 @@ void testMetric(const std::vector& res, coord_t x_Code { ZERO }; coord_t 
x_Phys { ZERO }; - for (unsigned short d = 0; d < M::Dim; ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { x_Code[d] = (real_t)(idx[d]) + HALF; } @@ -137,31 +138,31 @@ void testMetric(const std::vector& res, vec_t h_ij_expect { h_11_expect, h_22_expect, h_33_expect }; if (not equal(h_ij_predict, h_ij_expect, "h_ij", acc)) { - printf("h_ij: %.12e %.12e %.12e : %.12e %.12e %.12e\n", - h_ij_predict[0], - h_ij_predict[1], - h_ij_predict[2], - h_ij_expect[0], - h_ij_expect[1], - h_ij_expect[2]); + Kokkos::printf("h_ij: %.12e %.12e %.12e : %.12e %.12e %.12e\n", + h_ij_predict[0], + h_ij_predict[1], + h_ij_predict[2], + h_ij_expect[0], + h_ij_expect[1], + h_ij_expect[2]); Kokkos::abort("h_ij"); } if (not equal(h_13_predict, { h_13_expect }, "h_13", acc)) { - printf("h_13: %.12e : %.12e\n", h_13_predict[0], h_13_expect); + Kokkos::printf("h_13: %.12e : %.12e\n", h_13_predict[0], h_13_expect); Kokkos::abort("h_13"); } if (not equal(hij_predict, hij_expect, "hij", acc)) { - printf("hij: %.12e %.12e %.12e : %.12e %.12e %.12e\n", - hij_predict[0], - hij_predict[1], - hij_predict[2], - hij_expect[0], - hij_expect[1], - hij_expect[2]); + Kokkos::printf("hij: %.12e %.12e %.12e : %.12e %.12e %.12e\n", + hij_predict[0], + hij_predict[1], + hij_predict[2], + hij_expect[0], + hij_expect[1], + hij_expect[2]); Kokkos::abort("hij"); } if (not equal(h13_predict, { h13_expect }, "h13", acc)) { - printf("h13: %.12e : %.12e\n", h13_predict[0], h13_expect); + Kokkos::printf("h13: %.12e : %.12e\n", h13_predict[0], h13_expect); Kokkos::abort("h13"); } }); diff --git a/src/metrics/tests/minkowski.cpp b/src/metrics/tests/minkowski.cpp index 1ef27b4fa..2386d65d2 100644 --- a/src/metrics/tests/minkowski.cpp +++ b/src/metrics/tests/minkowski.cpp @@ -19,11 +19,10 @@ void errorIf(bool condition, const std::string& message) { inline static constexpr auto epsilon = std::numeric_limits::epsilon(); template -Inline auto equal(const coord_t& a, const coord_t& b, real_t acc = ONE) - -> bool { - for 
(unsigned short d = 0; d < D; ++d) { +Inline auto equal(const coord_t& a, const coord_t& b, real_t acc = ONE) -> bool { + for (auto d { 0u }; d < D; ++d) { if (not cmp::AlmostEqual(a[d], b[d], epsilon * acc)) { - printf("%d : %.12f != %.12f\n", d, a[d], b[d]); + Kokkos::printf("%d : %.12f != %.12f\n", d, a[d], b[d]); return false; } } diff --git a/src/metrics/tests/sph-qsph.cpp b/src/metrics/tests/sph-qsph.cpp index 230a763e1..faac28562 100644 --- a/src/metrics/tests/sph-qsph.cpp +++ b/src/metrics/tests/sph-qsph.cpp @@ -25,9 +25,9 @@ Inline auto equal(const vec_t& a, const char* msg, real_t acc = ONE) -> bool { const auto eps = epsilon * acc; - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { if (not cmp::AlmostEqual(a[d], b[d], eps)) { - printf("%d : %.12e != %.12e %s\n", d, a[d], b[d], msg); + Kokkos::printf("%d : %.12e != %.12e %s\n", d, a[d], b[d], msg); return false; } } @@ -38,7 +38,7 @@ template Inline void unravel(std::size_t idx, tuple_t& ijk, const tuple_t& res) { - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { ijk[d] = idx % res[d]; idx /= res[d]; } @@ -74,7 +74,7 @@ void testMetric(const std::vector& res, coord_t x_Code { ZERO }; coord_t x_Phys { ZERO }; - for (unsigned short d = 0; d < M::Dim; ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { x_Code[d] = (real_t)(idx[d]) + HALF; } @@ -115,12 +115,12 @@ auto main(int argc, char* argv[]) -> int { using namespace metric; const auto res = std::vector { 64, 32 }; const auto ext = boundaries_t { - {1.0, 10.0}, - {0.0, constant::PI} + { 1.0, 10.0 }, + { 0.0, constant::PI } }; const auto params = std::map { - {"r0", -ONE}, - { "h", (real_t)0.25} + { "r0", -ONE }, + { "h", (real_t)0.25 } }; testMetric>(res, ext, 10); diff --git a/src/metrics/tests/sr-cart-sph.cpp b/src/metrics/tests/sr-cart-sph.cpp index ec2f6ddc0..6ca6a52d7 100644 --- a/src/metrics/tests/sr-cart-sph.cpp +++ b/src/metrics/tests/sr-cart-sph.cpp @@ -28,9 +28,9 @@ Inline auto equal(const 
coord_t& a, const char* msg, real_t acc = ONE) -> bool { const auto eps = epsilon * acc; - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { if (not cmp::AlmostEqual(a[d], b[d], eps)) { - printf("%d : %.12e != %.12e %s\n", d, a[d], b[d], msg); + Kokkos::printf("%d : %.12e != %.12e %s\n", d, a[d], b[d], msg); return false; } } @@ -41,7 +41,7 @@ template Inline void unravel(std::size_t idx, tuple_t& ijk, const tuple_t& res) { - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { ijk[d] = idx % res[d]; idx /= res[d]; } @@ -81,7 +81,7 @@ void testMetric(const std::vector& res, coord_t x_Code_2 { ZERO }; coord_t x_Cart { ZERO }; - for (unsigned short d = 0; d < M::Dim; ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { x_Code_1[d] = (real_t)(idx[d]) + HALF; } metric.template convert_xyz(x_Code_1, x_Cart); @@ -95,7 +95,7 @@ void testMetric(const std::vector& res, coord_t x_Code_r1 { ZERO }; coord_t x_Code_r2 { ZERO }; coord_t x_Sph { ZERO }; - for (unsigned short d = 0; d < M::Dim; ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { x_Code_r1[d] = x_Code_1[d]; } metric.template convert(x_Code_r1, x_Sph); @@ -123,30 +123,30 @@ auto main(int argc, char* argv[]) -> int { const auto res2d = std::vector { 64, 32 }; const auto res3d = std::vector { 64, 32, 16 }; const auto ext1dcart = boundaries_t { - {10.0, 20.0} + { 10.0, 20.0 } }; const auto ext2dcart = boundaries_t { - {0.0, 20.0}, - {0.0, 10.0} + { 0.0, 20.0 }, + { 0.0, 10.0 } }; const auto ext3dcart = boundaries_t { - {-2.0, 2.0}, - {-1.0, 1.0}, - {-0.5, 0.5} + { -2.0, 2.0 }, + { -1.0, 1.0 }, + { -0.5, 0.5 } }; const auto extsph = boundaries_t { - {1.0, 10.0}, - {0.0, constant::PI} + { 1.0, 10.0 }, + { 0.0, constant::PI } }; const auto params = std::map { - {"r0", -ONE}, - { "h", (real_t)0.25} + { "r0", -ONE }, + { "h", (real_t)0.25 } }; testMetric>({ 128 }, ext1dcart); testMetric>(res2d, ext2dcart, 200); testMetric>(res3d, ext3dcart, 500); - testMetric>(res2d, extsph, 10); + 
testMetric>(res2d, extsph, 100); testMetric>(res2d, extsph, 200, params); } catch (std::exception& e) { diff --git a/src/metrics/tests/vec_trans.cpp b/src/metrics/tests/vec_trans.cpp index 31015115c..e9a03aa50 100644 --- a/src/metrics/tests/vec_trans.cpp +++ b/src/metrics/tests/vec_trans.cpp @@ -31,9 +31,9 @@ Inline auto equal(const vec_t& a, const char* msg, real_t acc = ONE) -> bool { const auto eps = epsilon * acc; - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { if (not cmp::AlmostEqual(a[d], b[d], eps)) { - printf("%d : %.12e != %.12e %s\n", d, a[d], b[d], msg); + Kokkos::printf("%d : %.12e != %.12e %s\n", d, a[d], b[d], msg); return false; } } @@ -44,7 +44,7 @@ template Inline void unravel(std::size_t idx, tuple_t& ijk, const tuple_t& res) { - for (unsigned short d = 0; d < D; ++d) { + for (auto d { 0u }; d < D; ++d) { ijk[d] = idx % res[d]; idx /= res[d]; } @@ -79,7 +79,7 @@ void testMetric(const std::vector& res, tuple_t idx; unravel(n, idx, res_tup); coord_t x_Code { ZERO }; - for (unsigned short d = 0; d < M::Dim; ++d) { + for (auto d { 0u }; d < M::Dim; ++d) { x_Code[d] = (real_t)(idx[d]) + HALF; } vec_t v_Hat_1 { ZERO }; @@ -94,7 +94,7 @@ void testMetric(const std::vector& res, vec_t v_PhysCov_2 { ZERO }; // init - for (unsigned short d = 0; d < Dim::_3D; ++d) { + for (auto d { 0u }; d < 3u; ++d) { v_Hat_1[d] += ONE; v_PhysCntrv_1[d] += ONE; v_PhysCov_1[d] += ONE; @@ -102,12 +102,12 @@ void testMetric(const std::vector& res, // hat <-> cntrv metric.template transform(x_Code, v_Hat_1, v_Cntrv_1); - for (unsigned short d = 0; d < Dim::_3D; ++d) { + for (auto d { 0u }; d < 3u; ++d) { vec_t e_d { ZERO }; vec_t v_Cntrv_temp { ZERO }; e_d[d] = ONE; metric.template transform(x_Code, e_d, v_Cntrv_temp); - for (unsigned short d = 0; d < Dim::_3D; ++d) { + for (auto d { 0u }; d < 3u; ++d) { v_Cntrv_2[d] += v_Cntrv_temp[d]; } } @@ -123,15 +123,15 @@ void testMetric(const std::vector& res, v_Cov_2, "cntrv->cov is equal to hat->cov", 
acc); - for (unsigned short d = 0; d < Dim::_3D; ++d) { + for (auto d { 0u }; d < 3u; ++d) { v_Cov_2[d] = ZERO; } - for (unsigned short d = 0; d < Dim::_3D; ++d) { + for (auto d { 0u }; d < 3u; ++d) { vec_t e_d { ZERO }; vec_t v_Cov_temp { ZERO }; e_d[d] = ONE; metric.template transform(x_Code, e_d, v_Cov_temp); - for (unsigned short d = 0; d < Dim::_3D; ++d) { + for (auto d { 0u }; d < 3u; ++d) { v_Cov_2[d] += v_Cov_temp[d]; } } @@ -179,24 +179,24 @@ auto main(int argc, char* argv[]) -> int { const auto res2d = std::vector { 64, 32 }; const auto res3d = std::vector { 64, 32, 16 }; const auto ext1dcart = boundaries_t { - {10.0, 20.0} + { 10.0, 20.0 } }; const auto ext2dcart = boundaries_t { - {0.0, 20.0}, - {0.0, 10.0} + { 0.0, 20.0 }, + { 0.0, 10.0 } }; const auto ext3dcart = boundaries_t { - {-2.0, 2.0}, - {-1.0, 1.0}, - {-0.5, 0.5} + { -2.0, 2.0 }, + { -1.0, 1.0 }, + { -0.5, 0.5 } }; const auto extsph = boundaries_t { - {1.0, 10.0}, - {0.0, constant::PI} + { 1.0, 10.0 }, + { 0.0, constant::PI } }; const auto params = std::map { - {"r0", -ONE}, - { "h", (real_t)0.25} + { "r0", -ONE }, + { "h", (real_t)0.25 } }; // testMetric>({ 128 }, ext1dcart); @@ -219,13 +219,13 @@ auto main(int argc, char* argv[]) -> int { // const auto resqks = std::vector { 64, 42 }; const auto extqks = boundaries_t { - {0.8, 10.0}, - {0.0, constant::PI} + { 0.8, 10.0 }, + { 0.0, constant::PI } }; const auto paramsqks = std::map { - {"r0", -TWO}, - { "h", ZERO}, - { "a", (real_t)0.8} + { "r0", -TWO }, + { "h", ZERO }, + { "a", (real_t)0.8 } }; testMetric>(resqks, extqks, 500, paramsqks); // diff --git a/src/output/CMakeLists.txt b/src/output/CMakeLists.txt index b262d7771..1b132fb60 100644 --- a/src/output/CMakeLists.txt +++ b/src/output/CMakeLists.txt @@ -1,33 +1,44 @@ +# cmake-lint: disable=C0103 # ------------------------------ # @defines: ntt_output [STATIC/SHARED] +# # @sources: -# - writer.cpp -# - fields.cpp -# - utils/interpret_prompt.cpp +# +# * writer.cpp +# * fields.cpp +# * 
stats.cpp +# * utils/interpret_prompt.cpp +# # @includes: -# - ../ +# +# * ../ +# # @depends: -# - ntt_global [required] +# +# * ntt_global [required] +# # @uses: -# - kokkos [required] -# - ADIOS2 [required] -# - mpi [optional] +# +# * kokkos [required] +# * ADIOS2 [optional] +# * mpi [optional] # ------------------------------ set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(SOURCES - ${SRC_DIR}/writer.cpp - ${SRC_DIR}/write_attrs.cpp - ${SRC_DIR}/fields.cpp - ${SRC_DIR}/utils/interpret_prompt.cpp -) + +set(SOURCES ${SRC_DIR}/stats.cpp ${SRC_DIR}/fields.cpp + ${SRC_DIR}/utils/interpret_prompt.cpp) +if(${output}) + list(APPEND SOURCES ${SRC_DIR}/writer.cpp) +endif() add_library(ntt_output ${SOURCES}) set(libs ntt_global) add_dependencies(ntt_output ${libs}) target_link_libraries(ntt_output PUBLIC ${libs}) +target_link_libraries(ntt_output PRIVATE stdc++fs) -target_include_directories(ntt_output +target_include_directories( + ntt_output PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../ - INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../ -) + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../) diff --git a/src/output/fields.cpp b/src/output/fields.cpp index aa5a752d4..e6b86296f 100644 --- a/src/output/fields.cpp +++ b/src/output/fields.cpp @@ -23,20 +23,30 @@ namespace out { // determine the field ID const auto pos = name.find("_"); auto name_raw = (pos == std::string::npos) ? 
name : name.substr(0, pos); - name_raw = name_raw.substr(0, name_raw.find_first_of("0123ijxyzt")); + if ((fmt::toLower(name_raw) != "dive") and + (fmt::toLower(name_raw) != "divd")) { + name_raw = name_raw.substr(0, name_raw.find_first_of("0123ijxyzt")); + } if (FldsID::contains(fmt::toLower(name_raw).c_str())) { m_id = FldsID::pick(fmt::toLower(name_raw).c_str()); } else { m_id = FldsID::Custom; } + // check compatibility + raise::ErrorIf(id() == FldsID::A and S != SimEngine::GRPIC, + "Output of A_phi not supported for non-GRPIC", + HERE); + raise::ErrorIf(id() == FldsID::V and S == SimEngine::GRPIC, + "Output of bulk 3-vel not supported for GRPIC", + HERE); // determine the species and components to output if (is_moment()) { species = InterpretSpecies(name); } else { species = {}; } - if (is_field() || is_current()) { - // always write all the field/current components + if (is_field() || is_current() || id() == FldsID::V) { + // always write all the field/current/bulk vel components comp = { { 1 }, { 2 }, { 3 } }; } else if (id() == FldsID::A) { // only write A3 @@ -44,12 +54,15 @@ namespace out { } else if (id() == FldsID::T) { // energy-momentum tensor comp = InterpretComponents({ name.substr(1, 1), name.substr(2, 1) }); + } else if (id() == FldsID::V) { + // energy-momentum tensor + comp = InterpretComponents({ name.substr(1, 1) }); } else { // scalar (Rho, divE, Custom, etc.) 
comp = {}; } // data preparation flags - if (not is_moment() && not is_custom()) { + if (not(is_moment() or is_custom() or is_divergence())) { if (S == SimEngine::SRPIC) { prepare_flag = PrepareOutput::ConvertToHat; } else { diff --git a/src/output/fields.h b/src/output/fields.h index a520a246d..bc1271084 100644 --- a/src/output/fields.h +++ b/src/output/fields.h @@ -34,7 +34,7 @@ namespace out { PrepareOutputFlags interp_flag { PrepareOutput::None }; std::vector> comp {}; - std::vector species {}; + std::vector species {}; OutputField(const SimEngine& S, const std::string&); @@ -43,7 +43,7 @@ namespace out { [[nodiscard]] auto is_moment() const -> bool { return (id() == FldsID::T || id() == FldsID::Rho || id() == FldsID::Nppc || - id() == FldsID::N || id() == FldsID::Charge); + id() == FldsID::N || id() == FldsID::Charge || id() == FldsID::V); } [[nodiscard]] @@ -94,7 +94,7 @@ namespace out { tmp += m_name.substr(1, 2); } else if (id() == FldsID::A) { tmp += "3"; - } else if (is_field()) { + } else if (is_field() || id() == FldsID::V) { tmp += "i"; } if (species.size() > 0) { @@ -105,6 +105,10 @@ namespace out { } tmp.pop_back(); } + if (tmp == "dive" || tmp == "divd") { + // capitalize E/D + tmp[3] = std::toupper(tmp[3]); + } // capitalize the first letter tmp[0] = std::toupper(tmp[0]); } @@ -112,7 +116,7 @@ namespace out { } [[nodiscard]] - inline auto name(const std::size_t& ci) const -> std::string { + inline auto name(std::size_t ci) const -> std::string { raise::ErrorIf( comp.size() == 0, "OutputField::name(ci) called but no components were available", @@ -138,6 +142,10 @@ namespace out { } tmp.pop_back(); } + if (tmp == "dive" || tmp == "divd") { + // capitalize E/D + tmp[3] = std::toupper(tmp[3]); + } // capitalize the first letter tmp[0] = std::toupper(tmp[0]); return "f" + tmp; diff --git a/src/output/particles.h b/src/output/particles.h index fb05fec7d..0936e66f9 100644 --- a/src/output/particles.h +++ b/src/output/particles.h @@ -8,20 +8,22 @@ #ifndef 
OUTPUT_PARTICLES_H #define OUTPUT_PARTICLES_H +#include "global.h" + #include namespace out { class OutputSpecies { - const unsigned short m_sp; + const spidx_t m_sp; public: - OutputSpecies(unsigned short sp) : m_sp { sp } {} + OutputSpecies(spidx_t sp) : m_sp { sp } {} ~OutputSpecies() = default; [[nodiscard]] - auto species() const -> unsigned short { + auto species() const -> spidx_t { return m_sp; } diff --git a/src/output/spectra.h b/src/output/spectra.h index 119495cd3..c3e5d13d7 100644 --- a/src/output/spectra.h +++ b/src/output/spectra.h @@ -8,20 +8,22 @@ #ifndef OUTPUT_SPECTRA_H #define OUTPUT_SPECTRA_H +#include "global.h" + #include namespace out { class OutputSpectra { - const unsigned short m_sp; + const spidx_t m_sp; public: - OutputSpectra(unsigned short sp) : m_sp { sp } {} + OutputSpectra(spidx_t sp) : m_sp { sp } {} ~OutputSpectra() = default; [[nodiscard]] - auto species() const -> unsigned short { + auto species() const -> spidx_t { return m_sp; } diff --git a/src/output/stats.cpp b/src/output/stats.cpp new file mode 100644 index 000000000..6aa65067a --- /dev/null +++ b/src/output/stats.cpp @@ -0,0 +1,111 @@ +#include "output/stats.h" + +#include "enums.h" +#include "global.h" + +#include "arch/mpi_aliases.h" +#include "utils/error.h" +#include "utils/formatting.h" + +#include "output/utils/interpret_prompt.h" + +#include + +#include +#include +#include + +using namespace ntt; +using namespace out; + +namespace stats { + + OutputStats::OutputStats(const std::string& name, bool is_custom) + : m_name { name } { + if (is_custom) { + m_id = StatsID::Custom; + comp = {}; + species = {}; + return; + } + // determine the stats ID + const auto pos = name.find("_"); + auto name_raw = (pos == std::string::npos) ? 
name : name.substr(0, pos); + if ((name_raw[0] != 'E') and (name_raw[0] != 'B') and (name_raw[0] != 'J')) { + name_raw = name_raw.substr(0, name_raw.find_first_of("0123ijxyzt")); + } + if (StatsID::contains(fmt::toLower(name_raw).c_str())) { + m_id = StatsID::pick(fmt::toLower(name_raw).c_str()); + } else { + raise::Error("Unrecognized stats name: " + name, HERE); + } + // determine the species and components to output + if (is_moment()) { + species = InterpretSpecies(name); + } else { + species = {}; + } + if (is_vector()) { + // always write all the E^2, B^2, ExB components + comp = { { 1 }, { 2 }, { 3 } }; + } else if (id() == StatsID::T) { + // energy-momentum tensor + comp = InterpretComponents({ name.substr(1, 1), name.substr(2, 1) }); + } else { + // scalar (e.g., Rho, Npart, etc.) + comp = {}; + } + } + + void Writer::init(timestep_t interval, simtime_t interval_time) { + m_tracker = tools::Tracker("stats", interval, interval_time); + } + + auto Writer::shouldWrite(timestep_t step, simtime_t time) -> bool { + return m_tracker.shouldWrite(step, time); + } + + void Writer::defineStatsFilename(const std::string& filename) { + m_fname = filename; + } + + void Writer::defineStatsOutputs(const std::vector& stats_to_write, + bool is_custom) { + for (const auto& stat : stats_to_write) { + m_stat_writers.emplace_back(stat, is_custom); + } + } + + void Writer::writeHeader() { + CallOnce( + [](auto& fname, auto& stat_writers) { + std::fstream StatsOut(fname, std::fstream::out | std::fstream::app); + StatsOut << std::setw(14) << "step" << "," << std::setw(14) << "time" + << ","; + for (const auto& stat : stat_writers) { + if (stat.is_vector()) { + for (auto i { 0u }; i < stat.comp.size(); ++i) { + StatsOut << std::setw(14) << stat.name(i) << ","; + } + } else { + StatsOut << std::setw(14) << stat.name() << ","; + } + } + StatsOut << std::endl; + StatsOut.close(); + }, + m_fname, + m_stat_writers); + } + + void Writer::endWriting() { + CallOnce( + [](auto& fname) { + 
std::fstream StatsOut(fname, std::fstream::out | std::fstream::app); + StatsOut << std::endl; + StatsOut.close(); + }, + m_fname); + } + +} // namespace stats diff --git a/src/output/stats.h b/src/output/stats.h new file mode 100644 index 000000000..fc5bbf3a2 --- /dev/null +++ b/src/output/stats.h @@ -0,0 +1,203 @@ +/** + * @file output/stats.h + * @brief Class defining the metadata necessary to prepare the stats for output + * @implements + * - out::OutputStats + * - out::Writer + * @cpp: + * - stats.cpp + * @namespaces: + * - stats:: + */ + +#ifndef OUTPUT_STATS_H +#define OUTPUT_STATS_H + +#include "enums.h" +#include "global.h" + +#include "utils/error.h" +#include "utils/formatting.h" +#include "utils/tools.h" + +#if defined(MPI_ENABLED) + #include "arch/mpi_aliases.h" + + #include +#endif + +#include +#include +#include +#include + +using namespace ntt; + +namespace stats { + + class OutputStats { + const std::string m_name; + StatsID m_id { StatsID::INVALID }; + + public: + std::vector> comp {}; + std::vector species {}; + + OutputStats(const std::string&, bool); + + ~OutputStats() = default; + + [[nodiscard]] + auto is_moment() const -> bool { + return (id() == StatsID::T || id() == StatsID::Rho || id() == StatsID::Npart || + id() == StatsID::N || id() == StatsID::Charge); + } + + [[nodiscard]] + auto is_vector() const -> bool { + return id() == StatsID::ExB || id() == StatsID::E2 || id() == StatsID::B2; + } + + [[nodiscard]] + auto is_custom() const -> bool { + return id() == StatsID::Custom; + } + + [[nodiscard]] + inline auto name() const -> std::string { + if (id() == StatsID::Custom) { + return m_name; + } + // generate the name + std::string tmp = std::string(id().to_string()); + if (tmp == "exb") { + tmp = "ExB"; + } else if (tmp == "j.e") { + tmp = "J.E"; + } else { + // capitalize the first letter + tmp[0] = std::toupper(tmp[0]); + } + if (id() == StatsID::T) { + tmp += m_name.substr(1, 2); + } else if (is_vector()) { + if (id() == StatsID::E2 || 
id() == StatsID::B2) { + tmp = fmt::format("%ci^2", tmp[0]); + } else { + tmp += "i"; + } + } + if (species.size() > 0) { + tmp += "_"; + for (auto& s : species) { + tmp += std::to_string(s); + tmp += "_"; + } + tmp.pop_back(); + } + return tmp; + } + + [[nodiscard]] + inline auto name(std::size_t ci) const -> std::string { + raise::ErrorIf( + comp.size() == 0, + "OutputField::name(ci) called but no components were available", + HERE); + raise::ErrorIf( + ci >= comp.size(), + "OutputField::name(ci) called with an invalid component index", + HERE); + raise::ErrorIf( + comp[ci].size() == 0, + "OutputField::name(ci) called but no components were available", + HERE); + // generate the name + auto tmp = std::string(id().to_string()); + // capitalize the first letter + if (tmp == "exb") { + tmp = "ExB"; + } else { + // capitalize the first letter + tmp[0] = std::toupper(tmp[0]); + } + if (tmp == "E^2" or tmp == "B^2") { + tmp = fmt::format("%c%d^2", tmp[0], comp[ci][0]); + } else { + for (auto& c : comp[ci]) { + tmp += std::to_string(c); + } + if (species.size() > 0) { + tmp += "_"; + for (auto& s : species) { + tmp += std::to_string(s); + tmp += "_"; + } + tmp.pop_back(); + } + } + return tmp; + } + + [[nodiscard]] + auto id() const -> StatsID { + return m_id; + } + }; + + class Writer { + std::string m_fname; + std::vector m_stat_writers; + tools::Tracker m_tracker; + + public: + Writer() {} + + ~Writer() = default; + + Writer(Writer&&) = default; + + void init(timestep_t, simtime_t); + void defineStatsFilename(const std::string&); + void defineStatsOutputs(const std::vector&, bool); + + void writeHeader(); + + [[nodiscard]] + auto shouldWrite(timestep_t, simtime_t) -> bool; + + template + inline void write(const T& value) const { + auto tot_value { static_cast(0) }; +#if defined(MPI_ENABLED) + MPI_Reduce(&value, + &tot_value, + 1, + mpi::get_type(), + MPI_SUM, + MPI_ROOT_RANK, + MPI_COMM_WORLD); +#else + tot_value = value; +#endif + CallOnce( + [](auto&& fname, auto&& 
value) { + std::fstream StatsOut(fname, std::fstream::out | std::fstream::app); + StatsOut << std::setw(14) << value << ","; + StatsOut.close(); + }, + m_fname, + tot_value); + } + + void endWriting(); + + [[nodiscard]] + auto statsWriters() const -> const std::vector& { + return m_stat_writers; + } + }; + +} // namespace stats + +#endif // OUTPUT_STATS_H diff --git a/src/output/tests/CMakeLists.txt b/src/output/tests/CMakeLists.txt index d33cc6c54..f6f460ae9 100644 --- a/src/output/tests/CMakeLists.txt +++ b/src/output/tests/CMakeLists.txt @@ -1,30 +1,41 @@ +# cmake-lint: disable=C0103,C0111 # ------------------------------ # @brief: Generates tests for the `ntt_output` module +# # @uses: -# - kokkos [required] -# - mpi [optional] -# - adios2 [optional] -# !TODO: -# - add more proper write tests for ADIOS2 +# +# * kokkos [required] +# * mpi [optional] +# * adios2 [optional] # ------------------------------ set(SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../) -function(gen_test title) +function(gen_test title is_parallel) set(exec test-output-${title}.xc) set(src ${title}.cpp) add_executable(${exec} ${src}) - set (libs ntt_output ntt_global ntt_metrics ntt_framework) + set(libs ntt_output ntt_global ntt_metrics ntt_framework) add_dependencies(${exec} ${libs}) target_link_libraries(${exec} PRIVATE ${libs} stdc++fs) - add_test(NAME "OUTPUT::${title}" COMMAND "${exec}") + if(${is_parallel}) + add_test(NAME "OUTPUT::${title}" + COMMAND "${MPIEXEC_EXECUTABLE}" "${MPIEXEC_NUMPROC_FLAG}" "4" + "${exec}") + else() + add_test(NAME "OUTPUT::${title}" COMMAND "${exec}") + endif() endfunction() -if (NOT ${mpi}) - gen_test(fields) - gen_test(writer-nompi) -else() - gen_test(writer-mpi) -endif() \ No newline at end of file +gen_test(stats false) + +if(${output}) + if(NOT ${mpi}) + gen_test(fields false) + gen_test(writer-nompi false) + else() + gen_test(writer-mpi true) + endif() +endif() diff --git a/src/output/tests/fields.cpp b/src/output/tests/fields.cpp index 
e09bed142..de86af2f4 100644 --- a/src/output/tests/fields.cpp +++ b/src/output/tests/fields.cpp @@ -45,11 +45,30 @@ auto main() -> int { raise::ErrorIf(rho.interp_flag != PrepareOutput::None, "Rho should not have any interp flags", HERE); - raise::ErrorIf(not(rho.species == std::vector { 1, 3 }), + raise::ErrorIf(not(rho.species == std::vector { 1, 3 }), "Rho should have species 1 and 3", HERE); } + { + const auto dive = OutputField(SimEngine::SRPIC, "divE"); + raise::ErrorIf(dive.is_moment(), "divE should not be a moment", HERE); + raise::ErrorIf(dive.is_field(), "divE should not be a field", HERE); + raise::ErrorIf(not dive.is_divergence(), "divE should be a divergence", HERE); + raise::ErrorIf(dive.id() != FldsID::divE, + "divE should have ID FldsID::divE", + HERE); + raise::ErrorIf(dive.name() != "fDivE", "divE should have name `fDivE`", HERE); + raise::ErrorIf(dive.comp.size() != 0, "divE should have 0 components", HERE); + raise::ErrorIf(dive.species.size() != 0, "divE should have no species", HERE); + raise::ErrorIf(dive.prepare_flag != PrepareOutput::None, + "divE should not have any prepare flags", + HERE); + raise::ErrorIf(dive.interp_flag != PrepareOutput::None, + "divE should not have any interp flags", + HERE); + } + { const auto t = OutputField(SimEngine::GRPIC, "Tti_2_3"); raise::ErrorIf(not t.is_moment(), "T should be a moment", HERE); diff --git a/src/output/tests/stats.cpp b/src/output/tests/stats.cpp new file mode 100644 index 000000000..11e9961e3 --- /dev/null +++ b/src/output/tests/stats.cpp @@ -0,0 +1,114 @@ +#include "output/stats.h" + +#include "enums.h" + +#include "utils/error.h" + +#include +#include +#include + +auto main() -> int { + using namespace stats; + using namespace ntt; + try { + { + const auto e = OutputStats("E^2", false); + raise::ErrorIf(not e.is_vector(), "E^2 should be a vector quantity", HERE); + raise::ErrorIf(e.is_moment(), "E^2 should not be a moment", HERE); + raise::ErrorIf(e.id() != StatsID::E2, "E^2 should have ID 
StatsID::E2", HERE); + raise::ErrorIf(e.species.size() != 0, "E^2 should have no species", HERE); + raise::ErrorIf(e.comp.size() != 3, "E^2 should have 3 components", HERE); + raise::ErrorIf(e.name() != "Ei^2", "E^2 should have name `Ei^2`", HERE); + } + + { + const auto e = OutputStats("ExB", false); + raise::ErrorIf(not e.is_vector(), "ExB should be a vector quantity", HERE); + raise::ErrorIf(e.is_moment(), "ExB should not be a moment", HERE); + raise::ErrorIf(e.id() != StatsID::ExB, "ExB should have ID StatsID::ExB", HERE); + raise::ErrorIf(e.species.size() != 0, "ExB should have no species", HERE); + raise::ErrorIf(e.comp.size() != 3, "ExB should have 3 components", HERE); + raise::ErrorIf(e.name() != "ExBi", "ExB should have name `ExBi`", HERE); + } + + { + const auto e = OutputStats("J.E", false); + raise::ErrorIf(e.is_vector(), "J.E should not be a vector quantity", HERE); + raise::ErrorIf(e.is_moment(), "J.E should not be a moment", HERE); + raise::ErrorIf(e.id() != StatsID::JdotE, + "J.E should have ID StatsID::JdotE", + HERE); + raise::ErrorIf(e.species.size() != 0, "J.E should have no species", HERE); + raise::ErrorIf(e.comp.size() != 0, "J.E should have no components", HERE); + raise::ErrorIf(e.name() != "J.E", "J.E should have name `J.E`", HERE); + } + + { + const auto rho = OutputStats("Rho_1_3", false); + raise::ErrorIf(not rho.is_moment(), "Rho should be a moment", HERE); + raise::ErrorIf(rho.id() != StatsID::Rho, + "Rho should have ID StatsID::Rho", + HERE); + raise::ErrorIf(rho.name() != "Rho_1_3", "Rho should have name `Rho_1_3`", HERE); + raise::ErrorIf(rho.comp.size() != 0, "Rho should have 0 components", HERE); + raise::ErrorIf(not(rho.species == std::vector { 1, 3 }), + "Rho should have species 1 and 3", + HERE); + } + + { + const auto t = OutputStats("Tti_2_3", false); + raise::ErrorIf(not t.is_moment(), "T should be a moment", HERE); + raise::ErrorIf(t.is_vector(), "T should not be a vector quantity", HERE); + raise::ErrorIf(t.id() != 
StatsID::T, "T should have ID StatsID::T", HERE); + raise::ErrorIf(t.name() != "Tti_2_3", "T should have name `Tti_2_3`", HERE); + raise::ErrorIf(t.name(0) != "T01_2_3", "T should have name `T01_2_3`", HERE); + raise::ErrorIf(t.name(1) != "T02_2_3", "T should have name `T02_2_3`", HERE); + raise::ErrorIf(t.name(2) != "T03_2_3", "T should have name `T03_2_3`", HERE); + raise::ErrorIf(t.comp.size() != 3, "T should have 3 component", HERE); + raise::ErrorIf(t.comp[0].size() != 2, + "T.comp[0] should have 2 components", + HERE); + raise::ErrorIf(t.comp[1].size() != 2, + "T.comp[1] should have 2 components", + HERE); + raise::ErrorIf(t.comp[2].size() != 2, + "T.comp[2] should have 2 components", + HERE); + raise::ErrorIf(t.comp[0] != std::vector { 0, 1 }, + "T.comp[0] should be {0, 1}", + HERE); + raise::ErrorIf(t.comp[1] != std::vector { 0, 2 }, + "T.comp[1] should be {0, 2}", + HERE); + raise::ErrorIf(t.comp[2] != std::vector { 0, 3 }, + "T.comp[2] should be {0, 3}", + HERE); + raise::ErrorIf(t.species.size() != 2, "T should have 2 species", HERE); + raise::ErrorIf(t.species[0] != 2, "T should have specie 2", HERE); + raise::ErrorIf(t.species[1] != 3, "T should have specie 3", HERE); + } + + { + const auto t = OutputStats("Tij", false); + raise::ErrorIf(t.comp.size() != 6, "T should have 6 component", HERE); + } + + { + const auto custom = OutputStats("C2x_12", true); + raise::ErrorIf(custom.name() != "C2x_12", + "Custom should have name `C2x_12`", + HERE); + raise::ErrorIf(not custom.is_custom(), + "Custom should be... well... 
a custom", + HERE); + raise::ErrorIf(custom.is_moment(), "Custom should not be a moment", HERE); + raise::ErrorIf(custom.is_vector(), "Custom should not be a vector", HERE); + } + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + return 1; + } + return 0; +} diff --git a/src/output/tests/writer-mpi.cpp b/src/output/tests/writer-mpi.cpp index 72cf46e35..bc95bbc81 100644 --- a/src/output/tests/writer-mpi.cpp +++ b/src/output/tests/writer-mpi.cpp @@ -1,9 +1,8 @@ #include "enums.h" #include "global.h" -#include "utils/formatting.h" +#include "arch/mpi_aliases.h" -#include "output/fields.h" #include "output/writer.h" #include @@ -13,59 +12,176 @@ #include #include -#include #include #include void cleanup() { namespace fs = std::filesystem; - // fs::path tempfile_path { "test.h5" }; - // fs::remove(tempfile_path); + fs::path tempfile_path { "test.h5" }; + fs::remove(tempfile_path); } +#define CEILDIV(a, b) \ + (static_cast( \ + math::ceil(static_cast(a) / static_cast(b)))) + auto main(int argc, char* argv[]) -> int { Kokkos::initialize(argc, argv); MPI_Init(&argc, &argv); - int rank, size; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &size); + int mpi_rank, mpi_size; + MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank); + MPI_Comm_size(MPI_COMM_WORLD, &mpi_size); try { using namespace ntt; - auto writer = out::Writer("hdf5"); - writer.defineMeshLayout({ static_cast(size) * 10 }, - { static_cast(rank) * 10 }, - { 10 }, - false, - Coord::Cart); - writer.defineFieldOutputs(SimEngine::SRPIC, { "E" }); - - ndfield_t field { "fld", 10 + 2 * N_GHOSTS }; - Kokkos::parallel_for( - "fill", - CreateRangePolicy({ N_GHOSTS }, { 10 + N_GHOSTS }), - Lambda(index_t i1) { - field(i1, 0) = i1; - field(i1, 1) = -(real_t)(i1); - field(i1, 2) = i1 / 2; - }); - std::vector names; - std::vector addresses; - for (auto i = 0; i < 3; ++i) { - names.push_back(writer.fieldWriters()[0].name(i)); - addresses.push_back(i); + constexpr auto nx1 = 10; + 
constexpr auto nx1_gh = nx1 + 2 * N_GHOSTS; + constexpr auto i1min = N_GHOSTS; + constexpr auto i1max = nx1 + N_GHOSTS; + constexpr auto dwn1 = 3; + + ndfield_t field { "fld", nx1_gh }; + std::vector field_names; + + { + // fill data + Kokkos::parallel_for( + "fill", + CreateRangePolicy({ i1min }, { i1max }), + Lambda(index_t i1) { + const auto i1_ = static_cast(i1); + field(i1, 0) = i1_; + field(i1, 1) = -i1_; + field(i1, 2) = SQR(i1_); + }); } - writer.beginWriting("test", 0, 0.0); - writer.writeField(names, field, addresses); - writer.endWriting(); - writer.beginWriting("test", 1, 0.1); - writer.writeField(names, field, addresses); - writer.endWriting(); + adios2::ADIOS adios { MPI_COMM_WORLD }; + + { + // write + auto writer = out::Writer(); + writer.init(&adios, "hdf5", "test", false); + writer.defineMeshLayout({ static_cast(mpi_size) * nx1 }, + { static_cast(mpi_rank) * nx1 }, + { nx1 }, + { mpi_rank, mpi_size }, + { dwn1 }, + false, + Coord::Cart); + writer.defineFieldOutputs(SimEngine::SRPIC, { "E" }); + + std::vector addresses; + for (auto i = 0; i < 3; ++i) { + field_names.push_back(writer.fieldWriters()[0].name(i)); + addresses.push_back(i); + } + writer.beginWriting(WriteMode::Fields, 0, 0.0); + writer.writeField(field_names, field, addresses); + writer.endWriting(WriteMode::Fields); + + writer.beginWriting(WriteMode::Fields, 1, 0.1); + writer.writeField(field_names, field, addresses); + writer.endWriting(WriteMode::Fields); + adios.ExitComputationBlock(); + } + + adios.FlushAll(); + + { + // read + adios2::IO io = adios.DeclareIO("read-test"); + io.SetEngine("HDF5"); + adios2::Engine reader = io.Open("test.h5", adios2::Mode::Read); + raise::ErrorIf(io.InquireAttribute("NGhosts").Data()[0] != 0, + "NGhosts is not correct", + HERE); + raise::ErrorIf(io.InquireAttribute("Dimension").Data()[0] != 1, + "Dimension is not correct", + HERE); + for (std::size_t step = 0; reader.BeginStep() == adios2::StepStatus::OK; + ++step) { + timestep_t step_read; + 
simtime_t time_read; + + reader.Get(io.InquireVariable("Step"), + &step_read, + adios2::Mode::Sync); + reader.Get(io.InquireVariable("Time"), + &time_read, + adios2::Mode::Sync); + raise::ErrorIf(step_read != step, "Step is not correct", HERE); + raise::ErrorIf((float)time_read != (float)step * 0.1f, + "Time is not correct", + HERE); + + const auto l_size = nx1; + const auto l_offset = nx1 * mpi_rank; + + const double n = l_size; + const double d = dwn1; + const double l = l_offset; + const double f = math::ceil(l / d) * d - l; + + const auto first_cell = static_cast(f); + const auto l_size_dwn = static_cast(math::ceil((n - f) / d)); + const auto l_corner_dwn = static_cast(math::ceil(l / d)); + + array_t field_read {}; + int cntr = 0; + for (const auto& name : field_names) { + auto fieldVar = io.InquireVariable(name); + if (fieldVar) { + raise::ErrorIf(fieldVar.Shape().size() != 1, + fmt::format("%s is not 1D", name.c_str()), + HERE); + auto dims = fieldVar.Shape(); + std::size_t nx1_r = dims[0]; + raise::ErrorIf((nx1_r != CEILDIV(nx1 * mpi_size, dwn1)), + fmt::format("%s = %ld is not %d", + name.c_str(), + nx1_r, + CEILDIV(nx1 * mpi_size, dwn1)), + HERE); + + fieldVar.SetSelection( + adios2::Box({ l_corner_dwn }, { l_size_dwn })); + field_read = array_t(name, l_size_dwn); + auto field_read_h = Kokkos::create_mirror_view(field_read); + reader.Get(fieldVar, field_read_h.data(), adios2::Mode::Sync); + Kokkos::deep_copy(field_read, field_read_h); + + Kokkos::parallel_for( + "check", + CreateRangePolicy({ 0 }, { l_size_dwn }), + Lambda(index_t i1) { + if (not cmp::AlmostEqual( + field_read(i1), + field(i1 * dwn1 + first_cell + i1min, cntr))) { + Kokkos::printf("\n:::::::::::::::\nfield_read(%ld) = %f != " + "field(%ld, %d) = %f\n:::::::::::::::\n", + i1, + field_read(i1), + i1 * dwn1 + first_cell + i1min, + cntr, + field(i1 * dwn1 + first_cell + i1min, cntr)); + raise::KernelError(HERE, "Field is not read correctly"); + } + }); + } else { + raise::Error("Field not 
found", HERE); + } + ++cntr; + } + } + reader.Close(); + } } catch (std::exception& e) { std::cerr << e.what() << std::endl; - cleanup(); + CallOnce([]() { + cleanup(); + }); MPI_Finalize(); Kokkos::finalize(); return 1; @@ -75,3 +191,5 @@ auto main(int argc, char* argv[]) -> int { Kokkos::finalize(); return 0; } + +#undef CEILDIV diff --git a/src/output/tests/writer-nompi.cpp b/src/output/tests/writer-nompi.cpp index a2a116e65..66d834f43 100644 --- a/src/output/tests/writer-nompi.cpp +++ b/src/output/tests/writer-nompi.cpp @@ -3,7 +3,6 @@ #include "utils/formatting.h" -#include "output/fields.h" #include "output/writer.h" #include @@ -12,94 +11,193 @@ #include #include -#include #include #include +using namespace ntt; + void cleanup() { namespace fs = std::filesystem; fs::path tempfile_path { "test.h5" }; fs::remove(tempfile_path); } +#define CEILDIV(a, b) \ + (static_cast( \ + math::ceil(static_cast(a) / static_cast(b)))) + auto main(int argc, char* argv[]) -> int { Kokkos::initialize(argc, argv); try { + constexpr auto nx1 = 10; + constexpr auto nx1_gh = nx1 + 2 * N_GHOSTS; + constexpr auto nx2 = 14; + constexpr auto nx2_gh = nx2 + 2 * N_GHOSTS; + constexpr auto nx3 = 17; + constexpr auto nx3_gh = nx3 + 2 * N_GHOSTS; + constexpr auto i1min = N_GHOSTS; + constexpr auto i2min = N_GHOSTS; + constexpr auto i3min = N_GHOSTS; + constexpr auto i1max = nx1 + N_GHOSTS; + constexpr auto i2max = nx2 + N_GHOSTS; + constexpr auto i3max = nx3 + N_GHOSTS; + + constexpr auto dwn1 = 2; + constexpr auto dwn2 = 1; + constexpr auto dwn3 = 5; + + ndfield_t field { "fld", nx1_gh, nx2_gh, nx3_gh }; + std::vector field_names; - using namespace ntt; - auto writer = out::Writer("hdf5"); - writer.defineMeshLayout({ 10, 10, 10 }, { 0, 0, 0 }, { 10, 10, 10 }, false, Coord::Cart); - writer.defineFieldOutputs(SimEngine::SRPIC, { "E", "B", "Rho_1_3", "N_2" }); - - ndfield_t field { "fld", - 10 + 2 * N_GHOSTS, - 10 + 2 * N_GHOSTS, - 10 + 2 * N_GHOSTS }; - Kokkos::parallel_for( - "fill", - 
CreateRangePolicy({ N_GHOSTS, N_GHOSTS, N_GHOSTS }, - { 10 + N_GHOSTS, 10 + N_GHOSTS, 10 + N_GHOSTS }), - Lambda(index_t i1, index_t i2, index_t i3) { - field(i1, i2, i3, 0) = i1 + i2 + i3; - field(i1, i2, i3, 1) = i1 * i2 / i3; - field(i1, i2, i3, 2) = i1 / i2 * i3; - }); - std::vector names; - std::vector addresses; - for (auto i = 0; i < 3; ++i) { - names.push_back(writer.fieldWriters()[0].name(i)); - addresses.push_back(i); + { + // fill data + Kokkos::parallel_for( + "fill", + CreateRangePolicy({ i1min, i2min, i3min }, + { i1max, i2max, i3max }), + Lambda(index_t i1, index_t i2, index_t i3) { + const auto i1_ = static_cast(i1); + const auto i2_ = static_cast(i2); + const auto i3_ = static_cast(i3); + field(i1, i2, i3, 0) = i1_; + field(i1, i2, i3, 1) = i2_; + field(i1, i2, i3, 2) = i3_; + }); } - writer.beginWriting("test", 0, 0.0); - writer.writeField(names, field, addresses); - writer.endWriting(); - writer.beginWriting("test", 1, 0.1); - writer.writeField(names, field, addresses); - writer.endWriting(); + adios2::ADIOS adios; + + { + // write + auto writer = out::Writer(); + writer.init(&adios, "hdf5", "test", false); + writer.defineMeshLayout({ nx1, nx2, nx3 }, + { 0, 0, 0 }, + { nx1, nx2, nx3 }, + { 0, 1 }, + { dwn1, dwn2, dwn3 }, + false, + Coord::Cart); + writer.defineFieldOutputs(SimEngine::SRPIC, { "E", "B", "Rho_1_3", "N_2" }); + + std::vector addresses; + for (auto i = 0; i < 3; ++i) { + field_names.push_back(writer.fieldWriters()[0].name(i)); + addresses.push_back(i); + } + writer.beginWriting(WriteMode::Fields, 10, 123.0); + writer.writeField(field_names, field, addresses); + writer.endWriting(WriteMode::Fields); + + writer.beginWriting(WriteMode::Fields, 20, 123.4); + writer.writeField(field_names, field, addresses); + writer.endWriting(WriteMode::Fields); + } + + adios.FlushAll(); { // read - adios2::ADIOS adios; - adios2::IO io = adios.DeclareIO("read-test"); + adios2::IO io = adios.DeclareIO("read-test"); io.SetEngine("hdf5"); adios2::Engine 
reader = io.Open("test.h5", adios2::Mode::Read); - - std::size_t step { 0 }; - long double time { 0.0 }; - reader.Get(io.InquireVariable("Step"), step); - reader.Get(io.InquireVariable("Time"), time); - raise::ErrorIf(step != 0, "Step is not 0", HERE); - raise::ErrorIf(time != 0.0, "Time is not 0.0", HERE); - - for (std::size_t step = 0; reader.BeginStep() == adios2::StepStatus::OK; - ++step) { - std::size_t step_read; - adios2::Variable stepVar = io.InquireVariable( - "Step"); - reader.Get(stepVar, step_read); - - long double time_read; - reader.Get(io.InquireVariable("Time"), time_read); - raise::ErrorIf(step_read != step, "Step is not correct", HERE); - raise::ErrorIf((float)time_read != (float)step / 10.0f, + const auto layoutRight = io.InquireAttribute("LayoutRight").Data()[0] == + 1; + + raise::ErrorIf(io.InquireAttribute("NGhosts").Data()[0] != 0, + "NGhosts is not correct", + HERE); + raise::ErrorIf(io.InquireAttribute("Dimension").Data()[0] != 3, + "Dimension is not correct", + HERE); + + for (auto step = 0u; reader.BeginStep() == adios2::StepStatus::OK; ++step) { + timestep_t step_read; + simtime_t time_read; + + reader.Get(io.InquireVariable("Step"), + &step_read, + adios2::Mode::Sync); + reader.Get(io.InquireVariable("Time"), + &time_read, + adios2::Mode::Sync); + raise::ErrorIf(step_read != (step + 1) * 10, "Step is not correct", HERE); + raise::ErrorIf((float)time_read != 123 + (float)step * 0.4f, "Time is not correct", HERE); - for (const auto& name : names) { - auto data = io.InquireVariable(name); - raise::ErrorIf(data.Shape().size() != 3, - fmt::format("%s is not 3D", name.c_str()), - HERE); - - auto dims = data.Shape(); - std::size_t nx1 = dims[0]; - std::size_t nx2 = dims[1]; - std::size_t nx3 = dims[2]; - raise::ErrorIf((nx1 != 10) || (nx2 != 10) || (nx3 != 10), - fmt::format("%s is not 10x10x10", name.c_str()), - HERE); + array_t field_read {}; + + int cntr = 0; + for (const auto& name : field_names) { + auto fieldVar = 
io.InquireVariable(name); + if (fieldVar) { + raise::ErrorIf(fieldVar.Shape().size() != 3, + fmt::format("%s is not 3D", name.c_str()), + HERE); + + auto dims = fieldVar.Shape(); + ncells_t nx1_r = dims[0]; + ncells_t nx2_r = dims[1]; + ncells_t nx3_r = dims[2]; + if (!layoutRight) { + std::swap(nx1_r, nx3_r); + } + raise::ErrorIf((nx1_r != CEILDIV(nx1, dwn1)) || + (nx2_r != CEILDIV(nx2, dwn2)) || + (nx3_r != CEILDIV(nx3, dwn3)), + fmt::format("%s = %ldx%ldx%ld is not %dx%dx%d", + name.c_str(), + nx1_r, + nx2_r, + nx3_r, + CEILDIV(nx1, dwn1), + CEILDIV(nx2, dwn2), + CEILDIV(nx3, dwn3)), + HERE); + + if (!layoutRight) { + std::swap(nx1_r, nx3_r); + } + fieldVar.SetSelection( + adios2::Box({ 0, 0, 0 }, { nx1_r, nx2_r, nx3_r })); + if (!layoutRight) { + std::swap(nx1_r, nx3_r); + } + field_read = array_t(name, nx1_r, nx2_r, nx3_r); + auto field_read_h = Kokkos::create_mirror_view(field_read); + reader.Get(fieldVar, field_read_h.data(), adios2::Mode::Sync); + Kokkos::deep_copy(field_read, field_read_h); + + Kokkos::parallel_for( + "check", + CreateRangePolicy({ 0, 0, 0 }, { nx1_r, nx2_r, nx3_r }), + Lambda(index_t i1, index_t i2, index_t i3) { + if (not cmp::AlmostEqual(field_read(i1, i2, i3), + field(i1 * dwn1 + i1min, + i2 * dwn2 + i2min, + i3 * dwn3 + i3min, + cntr))) { + Kokkos::printf( + "\n:::::::::::::::\nfield_read(%ld, %ld, %ld) = %f != " + "field(%ld, %ld, %ld, %d) = %f\n:::::::::::::::\n", + i1, + i2, + i3, + field_read(i1, i2, i3), + i1 * dwn1 + i1min, + i2 * dwn2 + i2min, + i3 * dwn3 + i3min, + cntr, + field(i1 * dwn1 + i1min, i2 * dwn2 + i2min, i3 * dwn3 + i3min, cntr)); + raise::KernelError(HERE, "Field is not read correctly"); + } + }); + } else { + raise::Error("Field not found", HERE); + } + ++cntr; } reader.EndStep(); } @@ -115,3 +213,5 @@ auto main(int argc, char* argv[]) -> int { Kokkos::finalize(); return 0; } + +#undef CEILDIV diff --git a/src/output/utils/attr_writer.h b/src/output/utils/attr_writer.h index 47f269d55..c8b21e4c2 100644 --- 
a/src/output/utils/attr_writer.h +++ b/src/output/utils/attr_writer.h @@ -40,7 +40,10 @@ namespace out { {typeid(int), defineAttribute}, {typeid(short), defineAttribute}, {typeid(unsigned int), defineAttribute}, - {typeid(std::size_t), defineAttribute}, + {typeid(long int), defineAttribute}, + {typeid(unsigned long int), defineAttribute}, + {typeid(long long int), defineAttribute}, + {typeid(unsigned long long int), defineAttribute}, {typeid(unsigned short), defineAttribute}, {typeid(float), defineAttribute}, {typeid(double), defineAttribute}, @@ -49,7 +52,10 @@ namespace out { {typeid(std::vector), defineAttribute>}, {typeid(std::vector), defineAttribute>}, {typeid(std::vector), defineAttribute>}, - {typeid(std::vector), defineAttribute>}, + {typeid(std::vector), defineAttribute>}, + {typeid(std::vector), defineAttribute>}, + {typeid(std::vector), defineAttribute>}, + {typeid(std::vector), defineAttribute>}, {typeid(std::vector), defineAttribute>}, {typeid(std::vector), defineAttribute>}, {typeid(std::vector), defineAttribute>}, diff --git a/src/output/utils/interpret_prompt.cpp b/src/output/utils/interpret_prompt.cpp index 7e6d92971..8506b29ff 100644 --- a/src/output/utils/interpret_prompt.cpp +++ b/src/output/utils/interpret_prompt.cpp @@ -10,12 +10,12 @@ namespace out { - auto InterpretSpecies(const std::string& in) -> std::vector { - std::vector species; + auto InterpretSpecies(const std::string& in) -> std::vector { + std::vector species; if (in.find("_") < in.size()) { auto species_str = fmt::splitString(in.substr(in.find("_") + 1), "_"); for (const auto& specie : species_str) { - species.push_back((unsigned short)(std::stoi(specie))); + species.push_back((spidx_t)(std::stoi(specie))); } } return species; diff --git a/src/output/utils/interpret_prompt.h b/src/output/utils/interpret_prompt.h index ebacaa980..032482cf8 100644 --- a/src/output/utils/interpret_prompt.h +++ b/src/output/utils/interpret_prompt.h @@ -4,7 +4,7 @@ * Defines the function that 
interprets ... * ... the user-defined species, e.g. when computing moments * @implements - * - out::InterpretSpecies -> std::vector + * - out::InterpretSpecies -> std::vector * - out::InterpretComponents -> std::vector> * @cpp: * - interpret_prompt.cpp @@ -17,15 +17,17 @@ #ifndef OUTPUT_UTILS_INTERPRET_PROMPT_H #define OUTPUT_UTILS_INTERPRET_PROMPT_H +#include "global.h" + #include #include namespace out { - auto InterpretSpecies(const std::string&) -> std::vector; + auto InterpretSpecies(const std::string&) -> std::vector; - auto InterpretComponents(const std::vector&) - -> std::vector>; + auto InterpretComponents( + const std::vector&) -> std::vector>; } // namespace out diff --git a/src/output/write_attrs.cpp b/src/output/write_attrs.cpp deleted file mode 100644 index e1e70f467..000000000 --- a/src/output/write_attrs.cpp +++ /dev/null @@ -1,136 +0,0 @@ -#include "enums.h" -#include "global.h" - -#include "output/writer.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace out { - - template - struct has_to_string : std::false_type {}; - - template - struct has_to_string().to_string())>> - : std::true_type {}; - - template - auto write(adios2::IO& io, const std::string& name, T var) -> - typename std::enable_if::value, void>::type { - io.DefineAttribute(name, std::string(var.to_string())); - } - - template - auto write(adios2::IO& io, const std::string& name, T var) - -> decltype(void(T()), void()) { - io.DefineAttribute(name, var); - } - - template <> - void write(adios2::IO& io, const std::string& name, bool var) { - io.DefineAttribute(name, var ? 
1 : 0); - } - - template <> - void write(adios2::IO& io, const std::string& name, Dimension var) { - io.DefineAttribute(name, (unsigned short)var); - } - - template - auto write_vec(adios2::IO& io, const std::string& name, std::vector var) -> - typename std::enable_if::value, void>::type { - std::vector var_str; - for (const auto& v : var) { - var_str.push_back(v.to_string()); - } - io.DefineAttribute(name, var_str.data(), var_str.size()); - } - - template - auto write_vec(adios2::IO& io, const std::string& name, std::vector var) - -> decltype(void(T()), void()) { - io.DefineAttribute(name, var.data(), var.size()); - } - - std::map> - write_functions; - - template - void register_write_function() { - write_functions[std::type_index(typeid(T))] = - [](adios2::IO& io, const std::string& name, std::any a) { - write(io, name, std::any_cast(a)); - }; - } - - template - void register_write_function_for_vector() { - write_functions[std::type_index(typeid(std::vector))] = - [](adios2::IO& io, const std::string& name, std::any a) { - write_vec(io, name, std::any_cast>(a)); - }; - } - - void write_any(adios2::IO& io, const std::string& name, std::any a) { - auto it = write_functions.find(a.type()); - if (it != write_functions.end()) { - it->second(io, name, a); - } else { - throw std::runtime_error("No write function registered for this type"); - } - } - - void Writer::writeAttrs(const prm::Parameters& params) { - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - register_write_function(); - 
register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - register_write_function_for_vector(); - - for (auto& [key, value] : params.allVars()) { - try { - write_any(m_io, key, value); - } catch (const std::exception& e) { - continue; - } - } - } -} // namespace out \ No newline at end of file diff --git a/src/output/writer.cpp b/src/output/writer.cpp index 91bf596a8..cc9ec0eb8 100644 --- a/src/output/writer.cpp +++ b/src/output/writer.cpp @@ -4,12 +4,13 @@ #include "arch/kokkos_aliases.h" #include "utils/error.h" +#include "utils/formatting.h" #include "utils/param_container.h" +#include "utils/tools.h" #include - -#include -#include +#include +#include #if defined(MPI_ENABLED) #include "arch/mpi_aliases.h" @@ -17,77 +18,115 @@ #include #endif +#include +#include +#include + namespace out { - Writer::Writer(const std::string& engine) : m_engine { engine } { - m_io = m_adios.DeclareIO("Entity::ADIOS2"); + void Writer::init(adios2::ADIOS* ptr_adios, + const std::string& engine, + const std::string& title, + bool use_separate_files) { + m_separate_files = use_separate_files; + m_engine = fmt::toLower(engine); + p_adios = ptr_adios; + + raise::ErrorIf(p_adios == nullptr, "ADIOS pointer is null", HERE); + + m_io = p_adios->DeclareIO("Entity::Output"); m_io.SetEngine(engine); - m_io.DefineVariable("Step"); - m_io.DefineVariable("Time"); + m_io.DefineVariable("Step"); + m_io.DefineVariable("Time"); + 
m_root = path_t(title); } void Writer::addTracker(const std::string& type, - std::size_t interval, - long double interval_time) { - m_trackers.insert(std::pair( - { type, Tracker(type, interval, interval_time) })); + timestep_t interval, + simtime_t interval_time) { + m_trackers.insert({ type, tools::Tracker(type, interval, interval_time) }); } auto Writer::shouldWrite(const std::string& type, - std::size_t step, - long double time) -> bool { + timestep_t step, + simtime_t time) -> bool { if (m_trackers.find(type) != m_trackers.end()) { return m_trackers.at(type).shouldWrite(step, time); } else { - raise::Error("Tracker type not found", HERE); + raise::Error(fmt::format("Tracker type %s not found", type.c_str()), HERE); return false; } } - void Writer::defineMeshLayout(const std::vector& glob_shape, - const std::vector& loc_corner, - const std::vector& loc_shape, - bool incl_ghosts, - Coord coords) { - m_flds_ghosts = incl_ghosts; + void Writer::setMode(adios2::Mode mode) { + m_mode = mode; + } + + void Writer::defineMeshLayout( + const std::vector& glob_shape, + const std::vector& loc_corner, + const std::vector& loc_shape, + const std::pair& domain_idx, + const std::vector& dwn, + bool incl_ghosts, + Coord coords) { + m_flds_ghosts = incl_ghosts; + m_dwn = dwn; + m_flds_g_shape = glob_shape; m_flds_l_corner = loc_corner; m_flds_l_shape = loc_shape; + for (auto i { 0u }; i < glob_shape.size(); ++i) { + raise::ErrorIf(dwn[i] != 1 && incl_ghosts, + "Downsampling with ghosts not supported", + HERE); + + const double g = glob_shape[i]; + const double d = m_dwn[i]; + const double l = loc_corner[i]; + const double n = loc_shape[i]; + const double f = math::ceil(l / d) * d - l; + m_flds_g_shape_dwn.push_back(static_cast(math::ceil(g / d))); + m_flds_l_corner_dwn.push_back(static_cast(math::ceil(l / d))); + m_flds_l_first.push_back(static_cast(f)); + m_flds_l_shape_dwn.push_back(static_cast(math::ceil((n - f) / d))); + } + m_io.DefineAttribute("NGhosts", incl_ghosts ? 
N_GHOSTS : 0); m_io.DefineAttribute("Dimension", m_flds_g_shape.size()); m_io.DefineAttribute("Coordinates", std::string(coords.to_string())); - for (std::size_t i { 0 }; i < m_flds_g_shape.size(); ++i) { + for (auto i { 0u }; i < m_flds_g_shape.size(); ++i) { // cell-centers - adios2::Dims g_shape = { m_flds_g_shape[i] }; - adios2::Dims l_corner = { m_flds_l_corner[i] }; - adios2::Dims l_shape = { m_flds_l_shape[i] }; m_io.DefineVariable("X" + std::to_string(i + 1), - g_shape, - l_corner, - l_shape, + { m_flds_g_shape_dwn[i] }, + { m_flds_l_corner_dwn[i] }, + { m_flds_l_shape_dwn[i] }, adios2::ConstantDims); // cell-edges - const auto is_last = (m_flds_l_corner[i] + m_flds_l_shape[i] == + const auto is_last = (m_flds_l_corner[i] + m_flds_l_shape[i] == m_flds_g_shape[i]); - adios2::Dims g_shape1 = { m_flds_g_shape[i] + 1 }; - adios2::Dims l_shape1 = { m_flds_l_shape[i] + (is_last ? 1 : 0) }; m_io.DefineVariable("X" + std::to_string(i + 1) + "e", - g_shape1, - l_corner, - l_shape1, + { m_flds_g_shape_dwn[i] + 1 }, + { m_flds_l_corner_dwn[i] }, + { m_flds_l_shape_dwn[i] + (is_last ? 
1 : 0) }, adios2::ConstantDims); + m_io.DefineVariable("N" + std::to_string(i + 1) + "l", + { 2 * domain_idx.second }, + { 2 * domain_idx.first }, + { 2 }, + adios2::ConstantDims); } if constexpr (std::is_same::array_layout, Kokkos::LayoutRight>::value) { m_io.DefineAttribute("LayoutRight", 1); } else { - std::reverse(m_flds_g_shape.begin(), m_flds_g_shape.end()); - std::reverse(m_flds_l_corner.begin(), m_flds_l_corner.end()); - std::reverse(m_flds_l_shape.begin(), m_flds_l_shape.end()); + std::reverse(m_flds_g_shape_dwn.begin(), m_flds_g_shape_dwn.end()); + std::reverse(m_flds_l_corner_dwn.begin(), m_flds_l_corner_dwn.end()); + std::reverse(m_flds_l_shape_dwn.begin(), m_flds_l_shape_dwn.end()); m_io.DefineAttribute("LayoutRight", 0); } } @@ -95,8 +134,9 @@ namespace out { void Writer::defineFieldOutputs(const SimEngine& S, const std::vector& flds_out) { m_flds_writers.clear(); - raise::ErrorIf((m_flds_g_shape.size() == 0) || (m_flds_l_corner.size() == 0) || - (m_flds_l_shape.size() == 0), + raise::ErrorIf((m_flds_g_shape_dwn.size() == 0) || + (m_flds_l_corner_dwn.size() == 0) || + (m_flds_l_shape_dwn.size() == 0), "Mesh layout must be defined before field output", HERE); for (const auto& fld : flds_out) { @@ -104,25 +144,27 @@ namespace out { } for (const auto& fld : m_flds_writers) { if (fld.comp.size() == 0) { + // scalar m_io.DefineVariable(fld.name(), - m_flds_g_shape, - m_flds_l_corner, - m_flds_l_shape, + m_flds_g_shape_dwn, + m_flds_l_corner_dwn, + m_flds_l_shape_dwn, adios2::ConstantDims); } else { - for (std::size_t i { 0 }; i < fld.comp.size(); ++i) { + // vector or tensor + for (auto i { 0u }; i < fld.comp.size(); ++i) { m_io.DefineVariable(fld.name(i), - m_flds_g_shape, - m_flds_l_corner, - m_flds_l_shape, + m_flds_g_shape_dwn, + m_flds_l_corner_dwn, + m_flds_l_shape_dwn, adios2::ConstantDims); } } } } - void Writer::defineParticleOutputs(Dimension dim, - const std::vector& specs) { + void Writer::defineParticleOutputs(Dimension dim, + const 
std::vector& specs) { m_prtl_writers.clear(); for (const auto& s : specs) { m_prtl_writers.emplace_back(s); @@ -147,7 +189,7 @@ namespace out { } } - void Writer::defineSpectraOutputs(const std::vector& specs) { + void Writer::defineSpectraOutputs(const std::vector& specs) { m_spectra_writers.clear(); for (const auto& s : specs) { m_spectra_writers.emplace_back(s); @@ -158,49 +200,128 @@ namespace out { } } + void Writer::writeAttrs(const prm::Parameters& params) { + params.write(m_io); + } + template - void WriteField(adios2::IO& io, - adios2::Engine& writer, - const std::string& varname, - const ndfield_t& field, - std::size_t comp, - bool ghosts) { - auto var = io.InquireVariable(varname); - const auto gh_zones = ghosts ? 0 : N_GHOSTS; + void WriteField(adios2::IO& io, + adios2::Engine& writer, + const std::string& varname, + const ndfield_t& field, + std::size_t comp, + std::vector dwn, + std::vector first_cell, + bool ghosts) { + // when dwn != 1 in any direction, it is assumed that ghosts == false + auto var = io.InquireVariable(varname); + const auto gh_zones = ghosts ? 
0 : N_GHOSTS; + ndarray_t output_field {}; if constexpr (D == Dim::_1D) { - auto slice_i1 = range_tuple_t(gh_zones, field.extent(0) - gh_zones); - auto slice = Kokkos::subview(field, slice_i1, comp); - auto output_field = array_t("output_field", slice.extent(0)); - Kokkos::deep_copy(output_field, slice); - auto output_field_host = Kokkos::create_mirror_view(output_field); - Kokkos::deep_copy(output_field_host, output_field); - writer.Put(var, output_field_host); + if (ghosts || dwn[0] == 1) { + auto slice_i1 = range_tuple_t(gh_zones, field.extent(0) - gh_zones); + auto slice = Kokkos::subview(field, slice_i1, comp); + output_field = array_t { "output_field", slice.extent(0) }; + Kokkos::deep_copy(output_field, slice); + } else { + + const auto dwn1 = dwn[0]; + const double first_cell1_d = first_cell[0]; + const double nx1_full = field.extent(0) - 2 * N_GHOSTS; + const auto first_cell1 = first_cell[0]; + + const auto nx1_dwn = static_cast( + math::ceil((nx1_full - first_cell1_d) / dwn1)); + + output_field = array_t { "output_field", nx1_dwn }; + Kokkos::parallel_for( + "outputField", + nx1_dwn, + Lambda(index_t i1) { + output_field(i1) = field(first_cell1 + i1 * dwn1 + N_GHOSTS, comp); + }); + } } else if constexpr (D == Dim::_2D) { - auto slice_i1 = range_tuple_t(gh_zones, field.extent(0) - gh_zones); - auto slice_i2 = range_tuple_t(gh_zones, field.extent(1) - gh_zones); - auto slice = Kokkos::subview(field, slice_i1, slice_i2, comp); - auto output_field = array_t("output_field", + if (ghosts || (dwn[0] == 1 && dwn[1] == 1)) { + auto slice_i1 = range_tuple_t(gh_zones, field.extent(0) - gh_zones); + auto slice_i2 = range_tuple_t(gh_zones, field.extent(1) - gh_zones); + auto slice = Kokkos::subview(field, slice_i1, slice_i2, comp); + output_field = array_t { "output_field", slice.extent(0), - slice.extent(1)); - Kokkos::deep_copy(output_field, slice); - auto output_field_host = Kokkos::create_mirror_view(output_field); - Kokkos::deep_copy(output_field_host, 
output_field); - writer.Put(var, output_field_host); + slice.extent(1) }; + Kokkos::deep_copy(output_field, slice); + } else { + const auto dwn1 = dwn[0]; + const auto dwn2 = dwn[1]; + const double first_cell1_d = first_cell[0]; + const double first_cell2_d = first_cell[1]; + const double nx1_full = field.extent(0) - 2 * N_GHOSTS; + const double nx2_full = field.extent(1) - 2 * N_GHOSTS; + const auto first_cell1 = first_cell[0]; + const auto first_cell2 = first_cell[1]; + + const auto nx1_dwn = static_cast( + math::ceil((nx1_full - first_cell1_d) / dwn1)); + const auto nx2_dwn = static_cast( + math::ceil((nx2_full - first_cell2_d) / dwn2)); + output_field = array_t { "output_field", nx1_dwn, nx2_dwn }; + Kokkos::parallel_for( + "outputField", + CreateRangePolicy({ 0, 0 }, { nx1_dwn, nx2_dwn }), + Lambda(index_t i1, index_t i2) { + output_field(i1, i2) = field(first_cell1 + i1 * dwn1 + N_GHOSTS, + first_cell2 + i2 * dwn2 + N_GHOSTS, + comp); + }); + } } else if constexpr (D == Dim::_3D) { - auto slice_i1 = range_tuple_t(gh_zones, field.extent(0) - gh_zones); - auto slice_i2 = range_tuple_t(gh_zones, field.extent(1) - gh_zones); - auto slice_i3 = range_tuple_t(gh_zones, field.extent(2) - gh_zones); - auto slice = Kokkos::subview(field, slice_i1, slice_i2, slice_i3, comp); - auto output_field = array_t("output_field", - slice.extent(0), - slice.extent(1), - slice.extent(2)); - Kokkos::deep_copy(output_field, slice); - auto output_field_host = Kokkos::create_mirror_view(output_field); - Kokkos::deep_copy(output_field_host, output_field); - writer.Put(var, output_field_host); + if (ghosts || (dwn[0] == 1 && dwn[1] == 1 && dwn[2] == 1)) { + auto slice_i1 = range_tuple_t(gh_zones, field.extent(0) - gh_zones); + auto slice_i2 = range_tuple_t(gh_zones, field.extent(1) - gh_zones); + auto slice_i3 = range_tuple_t(gh_zones, field.extent(2) - gh_zones); + auto slice = Kokkos::subview(field, slice_i1, slice_i2, slice_i3, comp); + output_field = array_t { "output_field", + 
slice.extent(0), + slice.extent(1), + slice.extent(2) }; + Kokkos::deep_copy(output_field, slice); + } else { + const auto dwn1 = dwn[0]; + const auto dwn2 = dwn[1]; + const auto dwn3 = dwn[2]; + const double first_cell1_d = first_cell[0]; + const double first_cell2_d = first_cell[1]; + const double first_cell3_d = first_cell[2]; + const double nx1_full = field.extent(0) - 2 * N_GHOSTS; + const double nx2_full = field.extent(1) - 2 * N_GHOSTS; + const double nx3_full = field.extent(2) - 2 * N_GHOSTS; + const auto first_cell1 = first_cell[0]; + const auto first_cell2 = first_cell[1]; + const auto first_cell3 = first_cell[2]; + + const auto nx1_dwn = static_cast( + math::ceil((nx1_full - first_cell1_d) / dwn1)); + const auto nx2_dwn = static_cast( + math::ceil((nx2_full - first_cell2_d) / dwn2)); + const auto nx3_dwn = static_cast( + math::ceil((nx3_full - first_cell3_d) / dwn3)); + + output_field = array_t { "output_field", nx1_dwn, nx2_dwn, nx3_dwn }; + Kokkos::parallel_for( + "outputField", + CreateRangePolicy({ 0, 0, 0 }, { nx1_dwn, nx2_dwn, nx3_dwn }), + Lambda(index_t i1, index_t i2, index_t i3) { + output_field(i1, i2, i3) = field(first_cell1 + i1 * dwn1 + N_GHOSTS, + first_cell2 + i2 * dwn2 + N_GHOSTS, + first_cell3 + i3 * dwn3 + N_GHOSTS, + comp); + }); + } } + auto output_field_h = Kokkos::create_mirror_view(output_field); + Kokkos::deep_copy(output_field_h, output_field); + writer.Put(var, output_field_h, adios2::Mode::Sync); } template @@ -213,14 +334,21 @@ namespace out { raise::ErrorIf(names.size() != addresses.size(), "# of names != # of addresses ", HERE); - for (std::size_t i { 0 }; i < addresses.size(); ++i) { - WriteField(m_io, m_writer, names[i], fld, addresses[i], m_flds_ghosts); + for (auto i { 0u }; i < addresses.size(); ++i) { + WriteField(m_io, + m_writer, + names[i], + fld, + addresses[i], + m_dwn, + m_flds_l_first, + m_flds_ghosts); } } void Writer::writeParticleQuantity(const array_t& array, - std::size_t glob_total, - std::size_t 
loc_offset, + npart_t glob_total, + npart_t loc_offset, const std::string& varname) { auto var = m_io.InquireVariable(varname); var.SetShape({ glob_total }); @@ -228,11 +356,12 @@ namespace out { adios2::Box({ loc_offset }, { array.extent(0) })); auto array_h = Kokkos::create_mirror_view(array); Kokkos::deep_copy(array_h, array); - m_writer.Put(var, array_h); + m_writer.Put(var, array_h, adios2::Mode::Sync); } void Writer::writeSpectrum(const array_t& counts, const std::string& varname) { + auto var = m_io.InquireVariable(varname); auto counts_h = Kokkos::create_mirror_view(counts); Kokkos::deep_copy(counts_h, counts); #if defined(MPI_ENABLED) @@ -248,121 +377,146 @@ namespace out { MPI_ROOT_RANK, MPI_COMM_WORLD); if (rank == MPI_ROOT_RANK) { - auto var = m_io.InquireVariable(varname); - var.SetSelection(adios2::Box({}, { counts.extent(0) })); - m_writer.Put(var, counts_h_all); + var.SetSelection( + adios2::Box({ 0u }, { counts_h_all.extent(0) })); + m_writer.Put(var, counts_h_all, adios2::Mode::Sync); + } else { + var.SetSelection(adios2::Box({ 0u }, { 0u })); + m_writer.Put(var, nullptr); } #else - auto var = m_io.InquireVariable(varname); var.SetSelection(adios2::Box({}, { counts.extent(0) })); - m_writer.Put(var, counts_h); + m_writer.Put(var, counts_h, adios2::Mode::Sync); #endif } void Writer::writeSpectrumBins(const array_t& e_bins, const std::string& varname) { + auto var = m_io.InquireVariable(varname); + auto e_bins_h = Kokkos::create_mirror_view(e_bins); + Kokkos::deep_copy(e_bins_h, e_bins); #if defined(MPI_ENABLED) int rank; MPI_Comm_rank(MPI_COMM_WORLD, &rank); - if (rank != MPI_ROOT_RANK) { - return; + if (rank == MPI_ROOT_RANK) { + var.SetSelection(adios2::Box({ 0u }, { e_bins_h.extent(0) })); + m_writer.Put(var, e_bins_h.data(), adios2::Mode::Sync); + } else { + var.SetSelection(adios2::Box({ 0u }, { 0u })); + m_writer.Put(var, nullptr, adios2::Mode::Sync); } +#else + var.SetSelection(adios2::Box({}, { e_bins_h.extent(0) })); + m_writer.Put(var, 
e_bins_h, adios2::Mode::Sync); #endif - auto var = m_io.InquireVariable(varname); - var.SetSelection(adios2::Box({}, { e_bins.extent(0) })); - auto e_bins_h = Kokkos::create_mirror_view(e_bins); - Kokkos::deep_copy(e_bins_h, e_bins); - m_writer.Put(var, e_bins_h); } - void Writer::writeMesh(unsigned short dim, - const array_t& xc, - const array_t& xe) { + void Writer::writeMesh(unsigned short dim, + const array_t& xc, + const array_t& xe, + const std::vector& loc_off_sz) { auto varc = m_io.InquireVariable("X" + std::to_string(dim + 1)); auto vare = m_io.InquireVariable("X" + std::to_string(dim + 1) + "e"); auto xc_h = Kokkos::create_mirror_view(xc); auto xe_h = Kokkos::create_mirror_view(xe); Kokkos::deep_copy(xc_h, xc); Kokkos::deep_copy(xe_h, xe); - m_writer.Put(varc, xc_h); - m_writer.Put(vare, xe_h); + m_writer.Put(varc, xc_h, adios2::Mode::Sync); + m_writer.Put(vare, xe_h, adios2::Mode::Sync); + auto vard = m_io.InquireVariable( + "N" + std::to_string(dim + 1) + "l"); + m_writer.Put(vard, loc_off_sz.data(), adios2::Mode::Sync); } - void Writer::beginWriting(const std::string& fname, - std::size_t tstep, - long double time) { - m_adios.ExitComputationBlock(); + void Writer::beginWriting(WriteModeTags write_mode, + timestep_t tstep, + simtime_t time) { + raise::ErrorIf(write_mode == WriteMode::None, "None is not a valid mode", HERE); + raise::ErrorIf(p_adios == nullptr, "ADIOS pointer is null", HERE); + if (m_active_mode != WriteMode::None) { + raise::Fatal("Already writing", HERE); + } try { - m_writer = m_io.Open(fname + (m_engine == "hdf5" ? ".h5" : ".bp"), m_mode); + path_t filename; + + const std::string ext = (m_engine == "hdf5") ? 
"h5" : "bp"; + if (m_separate_files) { + std::string mode_str; + if (write_mode == WriteMode::Fields) { + mode_str = "fields"; + } else if (write_mode == WriteMode::Particles) { + mode_str = "particles"; + } else if (write_mode == WriteMode::Spectra) { + mode_str = "spectra"; + } else { + raise::Fatal("Unknown write mode", HERE); + } + CallOnce( + [](auto&& main_path, auto&& mode_path) { + const path_t main { main_path }; + const path_t mode { mode_path }; + if (!std::filesystem::exists(main_path)) { + std::filesystem::create_directory(main_path); + } + if (!std::filesystem::exists(main_path / mode_path)) { + std::filesystem::create_directory(main_path / mode_path); + } + }, + m_root, + mode_str); +#if defined(MPI_ENABLED) + MPI_Barrier(MPI_COMM_WORLD); +#endif + filename = m_root / path_t(mode_str) / + fmt::format("%s.%08lu.%s", mode_str.c_str(), tstep, ext.c_str()); + m_mode = adios2::Mode::Write; + } else { + filename = fmt::format("%s.%s", m_root.c_str(), ext.c_str()); + m_mode = std::filesystem::exists(filename) ? 
adios2::Mode::Append + : adios2::Mode::Write; + } + m_writer = m_io.Open(filename, m_mode); + m_writer.BeginStep(); + m_writer.Put(m_io.InquireVariable("Step"), &tstep); + m_writer.Put(m_io.InquireVariable("Time"), &time); + m_active_mode = write_mode; } catch (std::exception& e) { raise::Fatal(e.what(), HERE); } - m_mode = adios2::Mode::Append; - m_writer.BeginStep(); - m_writer.Put(m_io.InquireVariable("Step"), &tstep); - m_writer.Put(m_io.InquireVariable("Time"), &time); } - void Writer::endWriting() { + void Writer::endWriting(WriteModeTags write_mode) { + raise::ErrorIf(write_mode == WriteMode::None, "None is not a valid mode", HERE); + raise::ErrorIf(p_adios == nullptr, "ADIOS pointer is null", HERE); + if (m_active_mode == WriteMode::None) { + raise::Fatal("Not writing", HERE); + } + if (m_active_mode != write_mode) { + raise::Fatal("Writing mode mismatch", HERE); + } + m_active_mode = WriteMode::None; m_writer.EndStep(); m_writer.Close(); - m_adios.EnterComputationBlock(); } - template void Writer::writeField(const std::vector&, - const ndfield_t&, - const std::vector&); - template void Writer::writeField(const std::vector&, - const ndfield_t&, - const std::vector&); - template void Writer::writeField(const std::vector&, - const ndfield_t&, - const std::vector&); - template void Writer::writeField(const std::vector&, - const ndfield_t&, - const std::vector&); - template void Writer::writeField(const std::vector&, - const ndfield_t&, - const std::vector&); - template void Writer::writeField(const std::vector&, - const ndfield_t&, - const std::vector&); - - template void WriteField(adios2::IO&, - adios2::Engine&, - const std::string&, - const ndfield_t&, - std::size_t, - bool); - template void WriteField(adios2::IO&, - adios2::Engine&, - const std::string&, - const ndfield_t&, - std::size_t, - bool); - template void WriteField(adios2::IO&, - adios2::Engine&, - const std::string&, - const ndfield_t&, - std::size_t, - bool); - template void 
WriteField(adios2::IO&, - adios2::Engine&, - const std::string&, - const ndfield_t&, - std::size_t, - bool); - template void WriteField(adios2::IO&, - adios2::Engine&, - const std::string&, - const ndfield_t&, - std::size_t, - bool); - template void WriteField(adios2::IO&, - adios2::Engine&, - const std::string&, - const ndfield_t&, - std::size_t, - bool); +#define WRITE_FIELD(D, N) \ + template void Writer::writeField(const std::vector&, \ + const ndfield_t&, \ + const std::vector&); \ + template void WriteField(adios2::IO&, \ + adios2::Engine&, \ + const std::string&, \ + const ndfield_t&, \ + std::size_t, \ + std::vector, \ + std::vector, \ + bool); + WRITE_FIELD(Dim::_1D, 3) + WRITE_FIELD(Dim::_1D, 6) + WRITE_FIELD(Dim::_2D, 3) + WRITE_FIELD(Dim::_2D, 6) + WRITE_FIELD(Dim::_3D, 3) + WRITE_FIELD(Dim::_3D, 6) +#undef WRITE_FIELD } // namespace out diff --git a/src/output/writer.h b/src/output/writer.h index 517a1655d..cc3edc733 100644 --- a/src/output/writer.h +++ b/src/output/writer.h @@ -11,6 +11,7 @@ #include "arch/kokkos_aliases.h" #include "utils/param_container.h" +#include "utils/tools.h" #include "output/fields.h" #include "output/particles.h" @@ -28,86 +29,76 @@ namespace out { - class Tracker { - const std::string m_type; - const std::size_t m_interval; - const long double m_interval_time; - const bool m_use_time; - - long double m_last_output_time { -1.0 }; - - public: - Tracker(const std::string& type, std::size_t interval, long double interval_time) - : m_type { type } - , m_interval { interval } - , m_interval_time { interval_time } - , m_use_time { interval_time > 0.0 } {} - - ~Tracker() = default; - - auto shouldWrite(std::size_t step, long double time) -> bool { - if (m_use_time) { - if (time - m_last_output_time >= m_interval_time) { - m_last_output_time = time; - return true; - } else { - return false; - } - } else { - return step % m_interval == 0; - } - } - }; - class Writer { -#if !defined(MPI_ENABLED) - adios2::ADIOS m_adios; -#else // 
MPI_ENABLED - adios2::ADIOS m_adios { MPI_COMM_WORLD }; -#endif + adios2::ADIOS* p_adios { nullptr }; + adios2::IO m_io; adios2::Engine m_writer; adios2::Mode m_mode { adios2::Mode::Write }; + bool m_separate_files; + // global shape of the fields array to output - adios2::Dims m_flds_g_shape; + std::vector m_flds_g_shape; // local corner of the fields array to output - adios2::Dims m_flds_l_corner; + std::vector m_flds_l_corner; // local shape of the fields array to output - adios2::Dims m_flds_l_shape; - bool m_flds_ghosts; - const std::string m_engine; + std::vector m_flds_l_shape; - std::map m_trackers; + // downsampling factors for each dimension + std::vector m_dwn; + // starting cell in each dimension (not including ghosts) + std::vector m_flds_l_first; + + // same but downsampled + adios2::Dims m_flds_g_shape_dwn; + adios2::Dims m_flds_l_corner_dwn; + adios2::Dims m_flds_l_shape_dwn; + + bool m_flds_ghosts; + std::string m_engine; + path_t m_root; + + std::map m_trackers; std::vector m_flds_writers; std::vector m_prtl_writers; std::vector m_spectra_writers; + WriteModeTags m_active_mode { WriteMode::None }; + public: - Writer() : m_engine { "disabled" } {} + Writer() {} - Writer(const std::string& engine); ~Writer() = default; Writer(Writer&&) = default; - void addTracker(const std::string&, std::size_t, long double); - auto shouldWrite(const std::string&, std::size_t, long double) -> bool; + void init(adios2::ADIOS*, const std::string&, const std::string&, bool); - void writeAttrs(const prm::Parameters& params); + void setMode(adios2::Mode); + + void addTracker(const std::string&, timestep_t, simtime_t); + auto shouldWrite(const std::string&, timestep_t, simtime_t) -> bool; + + void writeAttrs(const prm::Parameters&); void defineMeshLayout(const std::vector&, const std::vector&, const std::vector&, - bool incl_ghosts, + const std::pair&, + const std::vector&, + bool, Coord); void defineFieldOutputs(const SimEngine&, const std::vector&); - void 
defineParticleOutputs(Dimension, const std::vector&); - void defineSpectraOutputs(const std::vector&); + void defineParticleOutputs(Dimension, const std::vector&); + void defineSpectraOutputs(const std::vector&); - void writeMesh(unsigned short, const array_t&, const array_t&); + void writeMesh(unsigned short, + const array_t&, + const array_t&, + const std::vector&); template void writeField(const std::vector&, @@ -115,16 +106,20 @@ namespace out { const std::vector&); void writeParticleQuantity(const array_t&, - std::size_t, - std::size_t, + npart_t, + npart_t, const std::string&); void writeSpectrum(const array_t&, const std::string&); void writeSpectrumBins(const array_t&, const std::string&); - void beginWriting(const std::string&, std::size_t, long double); - void endWriting(); + void beginWriting(WriteModeTags, timestep_t, simtime_t); + void endWriting(WriteModeTags); /* getters -------------------------------------------------------------- */ + auto root() const -> const path_t& { + return m_root; + } + auto fieldWriters() const -> const std::vector& { return m_flds_writers; }