-
Notifications
You must be signed in to change notification settings - Fork 1
Add toon test #80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add toon test #80
Changes from all commits
35d2974
6139b72
bd43fd8
dbd2f2e
aa840a9
f591293
f1ebec8
175b2df
dbe23dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| // torch | ||
| #include <torch/extension.h> | ||
|
|
||
| // fmt | ||
| #include <fmt/format.h> | ||
|
|
||
| // harp | ||
| #include <harp/rtsolver/toon_mckay89.hpp> | ||
|
|
||
| // python | ||
| #include "pyoptions.hpp" | ||
|
|
||
| namespace py = pybind11; | ||
|
|
||
| void bind_rtsolver(py::module &m) { | ||
| auto pyToonMcKay89Options = | ||
| py::class_<harp::ToonMcKay89OptionsImpl, harp::ToonMcKay89Options>( | ||
| m, "ToonMcKay89Options"); | ||
|
|
||
| pyToonMcKay89Options.def(py::init<>()) | ||
| .def("__repr__", | ||
| [](const harp::ToonMcKay89Options &a) { | ||
| std::stringstream ss; | ||
| a->report(ss); | ||
| return fmt::format("ToonMcKay89Options(\n{})", ss.str()); | ||
| }) | ||
| .ADD_OPTION(std::vector<double>, harp::ToonMcKay89OptionsImpl, wave_lower) | ||
| .ADD_OPTION(std::vector<double>, harp::ToonMcKay89OptionsImpl, wave_upper) | ||
| .ADD_OPTION(bool, harp::ToonMcKay89OptionsImpl, zenith_correction); | ||
|
|
||
| torch::python::bind_module<harp::ToonMcKay89Impl>(m, "ToonMcKay89") | ||
| .def(py::init<>()) | ||
| .def(py::init<harp::ToonMcKay89Options>(), py::arg("options")) | ||
| .def_readonly("options", &harp::ToonMcKay89Impl::options) | ||
| .def( | ||
| "forward", | ||
| [](harp::ToonMcKay89Impl &self, torch::Tensor prop, std::string bname, | ||
| torch::optional<torch::Tensor> temf, const py::kwargs &kwargs) { | ||
| // get bc from kwargs | ||
| std::map<std::string, torch::Tensor> bc; | ||
| for (auto item : kwargs) { | ||
| auto key = py::cast<std::string>(item.first); | ||
| auto value = py::cast<torch::Tensor>(item.second); | ||
| bc.emplace(std::move(key), std::move(value)); | ||
| } | ||
|
|
||
| for (auto &[key, value] : bc) { | ||
| std::vector<std::string> items = {"fbeam", "albedo", "umu0"}; | ||
| // broadcast dimensions to (nwave, ncol) | ||
| if (std::find(items.begin(), items.end(), key) != items.end()) { | ||
| while (value.dim() < 2) { | ||
| value = value.unsqueeze(0); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // broadcast dimensions to (nwave, ncol, nlyr, nprop) | ||
| while (prop.dim() < 4) { | ||
| prop = prop.unsqueeze(0); | ||
| } | ||
|
|
||
| return self.forward(prop, &bc, bname, temf); | ||
| }, | ||
| py::arg("prop"), py::arg("bname") = "", py::arg("temf") = py::none()); | ||
| }; |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| #pragma once | ||
|
|
||
| // torch | ||
| #include <ATen/TensorIterator.h> | ||
| #include <ATen/native/cuda/Loops.cuh> | ||
|
|
||
| namespace harp { | ||
| namespace native { | ||
|
|
||
| template <typename func_t> | ||
| __global__ void element_kernel(int64_t numel, func_t f, char *work) { | ||
| int tid = threadIdx.x; | ||
| int idx = blockIdx.x * blockDim.x + tid; | ||
| if (idx < numel) { | ||
| f(idx, work); | ||
| } | ||
| } | ||
|
|
||
| template <int Arity, typename func_t> | ||
| void gpu_kernel(at::TensorIterator& iter, const func_t& f) { | ||
| TORCH_CHECK(iter.ninputs() + iter.noutputs() == Arity); | ||
|
|
||
| std::array<char*, Arity> data; | ||
| for (int i = 0; i < Arity; i++) { | ||
| data[i] = reinterpret_cast<char*>(iter.data_ptr(i)); | ||
| } | ||
|
|
||
| auto offset_calc = ::make_offset_calculator<Arity>(iter); | ||
| int64_t numel = iter.numel(); | ||
|
|
||
| at::native::launch_legacy_kernel<128, 1>(numel, | ||
| [=] __device__(int idx) { | ||
| auto offsets = offset_calc.get(idx); | ||
| f(data.data(), offsets.data()); | ||
| }); | ||
| } | ||
|
|
||
| template <int Chunks, int Arity, typename func_t> | ||
| void gpu_chunk_kernel(at::TensorIterator& iter, int work_size, const func_t& f) { | ||
| TORCH_CHECK(iter.ninputs() + iter.noutputs() == Arity); | ||
|
|
||
| std::array<char*, Arity> data; | ||
| for (int i = 0; i < Arity; i++) { | ||
| data[i] = reinterpret_cast<char*>(iter.data_ptr(i)); | ||
| } | ||
|
|
||
| auto offset_calc = ::make_offset_calculator<Arity>(iter); | ||
| int64_t numel = iter.numel(); | ||
|
|
||
| // devide numel into Chunk parts to reduce memory usage | ||
| // allocate working memory pool | ||
| char* d_workspace = nullptr; | ||
|
|
||
| // workspace size per chunk | ||
| int chunks = Chunks > numel ? numel : Chunks; | ||
| int base = numel / chunks; | ||
| int rem = numel % chunks; | ||
|
|
||
| size_t workspace_bytes = work_size * (base + (rem > 0 ? 1 : 0)); | ||
| cudaMalloc(&d_workspace, workspace_bytes); | ||
|
|
||
| int chunk_start = 0; | ||
|
|
||
| for (int n = 0; n < chunks; n++) { | ||
| int64_t chunk_numel = base + (n < rem ? 1 : 0); | ||
| int64_t chunk_end = chunk_start + chunk_numel; // exclusive | ||
|
|
||
| dim3 block(64); | ||
| dim3 grid((chunk_numel + block.x - 1) / block.x); | ||
|
|
||
| auto device_lambda = [=] __device__(int idx, char* work) { | ||
| auto offsets = offset_calc.get(idx + chunk_start); | ||
| f(data.data(), offsets.data(), work + idx * work_size); | ||
| }; | ||
|
|
||
| /*std::cout << "chunk = " << n | ||
| << ", chunk_start = " << chunk_start | ||
| << ", chunk_end = " << chunk_end | ||
| << ", chunk_numel = " << chunk_numel | ||
| << ", block = " << block.x | ||
| << ", grid = " << grid.x | ||
| << ", work_size = " << work_size | ||
| << std::endl;*/ | ||
|
|
||
| element_kernel<<<grid, block>>>(chunk_numel, device_lambda, d_workspace); | ||
| C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||
| cudaDeviceSynchronize(); | ||
|
|
||
| chunk_start = chunk_end; | ||
| } | ||
|
|
||
| // free working memory pool | ||
| cudaFree(d_workspace); | ||
| } | ||
|
|
||
| } // namespace native | ||
| } // namespace harp | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,33 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #pragma once | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // C/C++ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <cstdlib> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // base | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <configure.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace harp { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Solves a tridiagonal system using the Thomas algorithm (TDMA) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DISPATCH_MACRO void dtridgl(int n, const T *a, const T *b, T *c, T *d, T *x) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // First row | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| c[0] = c[0] / b[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| d[0] = d[0] / b[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Forward sweep | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = 1; i < n; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T denom = b[i] - a[i] * c[i - 1]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (denom == 0.0) denom = 1e-12; // Avoid division by zero | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| c[i] = (i < n - 1) ? c[i] / denom : 0.0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| d[i] = (d[i] - a[i] * d[i - 1]) / denom; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Back substitution | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x[n - 1] = d[n - 1]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = n - 2; i >= 0; --i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+14
to
+28
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // First row | |
| c[0] = c[0] / b[0]; | |
| d[0] = d[0] / b[0]; | |
| // Forward sweep | |
| for (int i = 1; i < n; ++i) { | |
| T denom = b[i] - a[i] * c[i - 1]; | |
| if (denom == 0.0) denom = 1e-12; // Avoid division by zero | |
| c[i] = (i < n - 1) ? c[i] / denom : 0.0; | |
| d[i] = (d[i] - a[i] * d[i - 1]) / denom; | |
| } | |
| // Back substitution | |
| x[n - 1] = d[n - 1]; | |
| for (int i = n - 2; i >= 0; --i) { | |
| // First row (1-based indexing: use elements 1..n) | |
| c[1] = c[1] / b[1]; | |
| d[1] = d[1] / b[1]; | |
| // Forward sweep | |
| for (int i = 2; i <= n; ++i) { | |
| T denom = b[i] - a[i] * c[i - 1]; | |
| if (denom == 0.0) denom = 1e-12; // Avoid division by zero | |
| c[i] = (i < n) ? c[i] / denom : static_cast<T>(0); | |
| d[i] = (d[i] - a[i] * d[i - 1]) / denom; | |
| } | |
| // Back substitution | |
| x[n] = d[n]; | |
| for (int i = n - 1; i >= 1; --i) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo in comment. "devide" should be "divide".