From 35d2974862753dedd3e00d823f17c07e2181fa8a Mon Sep 17 00:00:00 2001 From: mac/cli Date: Sat, 17 Jan 2026 20:53:21 -0500 Subject: [PATCH 1/9] wip --- src/rtsolver/dtridgl_impl.h | 39 ++++++++++++++++++++++ src/rtsolver/toon_mckay89.hpp | 4 +-- src/rtsolver/toon_mckay89_longwave_impl.h | 26 +++++++++++++-- src/rtsolver/toon_mckay89_shortwave_impl.h | 14 +++++++- 4 files changed, 77 insertions(+), 6 deletions(-) create mode 100644 src/rtsolver/dtridgl_impl.h diff --git a/src/rtsolver/dtridgl_impl.h b/src/rtsolver/dtridgl_impl.h new file mode 100644 index 0000000..09ef6ae --- /dev/null +++ b/src/rtsolver/dtridgl_impl.h @@ -0,0 +1,39 @@ +#pragma once + +// C/C++ +#include + +namespace harp { + +// Solves a tridiagonal system using the Thomas algorithm (TDMA) +template +void dtridgl(int n, const T *a, const T *b, const T *c, const T *d, T *x, + char *mem, int &offset) { + T *cp = (T *)get_mem(n, sizeof(T), mem, &offset); + T *dp = (T *)get_mem(n, sizeof(T), mem, &offset); + + if (!cp || !dp) { + // Handle memory allocation failure + exit(EXIT_FAILURE); + } + + // First row + cp[0] = c[0] / b[0]; + dp[0] = d[0] / b[0]; + + // Forward sweep + for (int i = 1; i < n; ++i) { + T denom = b[i] - a[i] * cp[i - 1]; + if (denom == 0.0) denom = 1e-12; // Avoid division by zero + cp[i] = (i < n - 1) ? c[i] / denom : 0.0; + dp[i] = (d[i] - a[i] * dp[i - 1]) / denom; + } + + // Back substitution + x[n - 1] = dp[n - 1]; + for (int i = n - 2; i >= 0; --i) { + x[i] = dp[i] - cp[i] * x[i + 1]; + } +} + +} // namespace harp diff --git a/src/rtsolver/toon_mckay89.hpp b/src/rtsolver/toon_mckay89.hpp index 53fd5f4..c033061 100644 --- a/src/rtsolver/toon_mckay89.hpp +++ b/src/rtsolver/toon_mckay89.hpp @@ -56,7 +56,7 @@ class ToonMcKay89Impl : public torch::nn::Cloneable { * Based on Elsie Lee's implementation in Exo-FMS_column_ck, which was * based on CHIMERA code by Mike Line. 
* Ported by Xi Zhang to Eigen - * Proted by Cheng Li to torch + * Ported by Cheng Li to torch * Reference: Toon, O.B., 1989, JGR, 94,16287-16301. */ torch::Tensor shortwave_solver(torch::Tensor Finc, torch::Tensor mu0, @@ -68,7 +68,7 @@ class ToonMcKay89Impl : public torch::nn::Cloneable { * Based on Elsie Lee's implementation in Exo-FMS_column_ck, which was * based on CHIMERA code by Mike Line. * Ported by Xi Zhang to Eigen - * Proted by Cheng Li to torch + * Ported by Cheng Li to torch * Reference: Toon, O.B., 1989, JGR, 94, 16287-16301. */ torch::Tensor longwave_solver(torch::Tensor be, torch::Tensor dtau, diff --git a/src/rtsolver/toon_mckay89_longwave_impl.h b/src/rtsolver/toon_mckay89_longwave_impl.h index 7003e34..eb0063e 100644 --- a/src/rtsolver/toon_mckay89_longwave_impl.h +++ b/src/rtsolver/toon_mckay89_longwave_impl.h @@ -1,9 +1,16 @@ +#pragma once + // C/C++ #include #include #include #include +// harp +#include "dtridgl_impl.h" + +namespace harp { + template void toon_mckay89_longwave(int nlay, int nlev, const T *be, const T *tau_in, const T *w_in, const T *g_in, T a_surf_in, T *flx_up, @@ -76,6 +83,12 @@ void toon_mckay89_longwave(int nlay, int nlev, const T *be, const T *tau_in, T *lw_up_g = (T *)get_mem(nlev, sizeof(T), mem, &offset); T *lw_down_g = (T *)get_mem(nlev, sizeof(T), mem, &offset); + if (offset > memsize) { + fprintf(stderr, + "Error: Memory allocation failed in toon_mckay89_shortwave\n"); + exit(EXIT_FAILURE); + } + // === Precomputations === for (int k = 0; k < nlay; ++k) dtau_in[k] = tau_in[k + 1] - tau_in[k]; @@ -118,7 +131,12 @@ void toon_mckay89_longwave(int nlay, int nlev, const T *be, const T *tau_in, T bottom = Bsurf + B1[nlay - 1] * ubari; // === Solve tridiagonal system (not shown again for brevity) === - dtridgl(l, Af, Bf, Cf, Df, xk); + dtridgl(l, Af, Bf, Cf, Df, xk, mem, offset); + if (offset > memsize) { + fprintf(stderr, + "Error: Memory allocation failed in toon_mckay89_shortwave\n"); + exit(EXIT_FAILURE); + } // === 
Calculate xk1, xk2 from xkk === for (int n = 0; n < nlay; ++n) { @@ -150,8 +168,8 @@ void toon_mckay89_longwave(int nlay, int nlev, const T *be, const T *tau_in, } // === Gaussian quadrature integration === - memset(flx_up, 0, nlev * sizeof(T), mem, &offset); - memset(flx_down, 0, nlev * sizeof(T), mem, &offset); + memset(flx_up, 0, nlev * sizeof(T)); + memset(flx_down, 0, nlev * sizeof(T)); for (int m = 0; m < nmu; ++m) { for (int k = 0; k < nlay; ++k) { @@ -184,3 +202,5 @@ void toon_mckay89_longwave(int nlay, int nlev, const T *be, const T *tau_in, } } } + +} // namespace harp diff --git a/src/rtsolver/toon_mckay89_shortwave_impl.h b/src/rtsolver/toon_mckay89_shortwave_impl.h index c1a64ff..8e0ae4f 100644 --- a/src/rtsolver/toon_mckay89_shortwave_impl.h +++ b/src/rtsolver/toon_mckay89_shortwave_impl.h @@ -4,6 +4,11 @@ #include #include +// harp +#include "dtridgl_impl.h" + +namespace harp { + template void toon_mckay89_shortwave(int nlay, int nlev, T F0_in, const T *mu_in, const T *tau_in, const T *w_in, const T *g_in, @@ -183,7 +188,12 @@ void toon_mckay89_shortwave(int nlay, int nlev, T F0_in, const T *mu_in, Cf[l - 1] = 0.0; Df[l - 1] = bsurf - Cp[nlay - 1] + w_surf_in * Cm[nlay - 1]; - dtridgl(l, Af, Bf, Cf, Df, xk); + dtridgl(l, Af, Bf, Cf, Df, xk, mem, offset); + if (offset > memsize) { + fprintf(stderr, + "Error: Memory allocation failed in toon_mckay89_shortwave\n"); + exit(EXIT_FAILURE); + } for (int n = 0; n < nlay; ++n) { xk1[n] = xk[2 * n] + xk[2 * n + 1]; @@ -204,3 +214,5 @@ void toon_mckay89_shortwave(int nlay, int nlev, T F0_in, const T *mu_in, for (int n = 0; n < nlev; ++n) flx_down[n] += dir[n]; } + +} // namespace harp From 6139b72c906d8d8820cdd3bbd26afa714c2915fb Mon Sep 17 00:00:00 2001 From: mac/cli Date: Sun, 18 Jan 2026 12:52:13 -0500 Subject: [PATCH 2/9] wip --- src/CMakeLists.txt | 2 +- src/rtsolver/dtridgl_impl.h | 28 +- src/rtsolver/rtsolver_dispatch.cpp | 59 +++- src/rtsolver/rtsolver_dispatch.cu | 51 +++ 
src/rtsolver/rtsolver_dispatch.hpp | 31 ++ src/rtsolver/toon_mckay89.cpp | 93 +++++- src/rtsolver/toon_mckay89.hpp | 48 ++- src/rtsolver/toon_mckay89_longwave.cpp | 152 --------- src/rtsolver/toon_mckay89_longwave_impl.h | 315 ++++++++++-------- src/rtsolver/toon_mckay89_shortwave.cpp | 221 ------------- src/rtsolver/toon_mckay89_shortwave_impl.h | 367 ++++++++++----------- 11 files changed, 605 insertions(+), 762 deletions(-) create mode 100644 src/rtsolver/rtsolver_dispatch.cu create mode 100644 src/rtsolver/rtsolver_dispatch.hpp delete mode 100644 src/rtsolver/toon_mckay89_longwave.cpp delete mode 100644 src/rtsolver/toon_mckay89_shortwave.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5a3d0eb..e2927be 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -34,7 +34,7 @@ file(GLOB src_files utils/*.cpp opacity/*.cpp radiation/*.cpp - #rtsolver/*.cpp + rtsolver/*.cpp integrator/*.cpp ) diff --git a/src/rtsolver/dtridgl_impl.h b/src/rtsolver/dtridgl_impl.h index 09ef6ae..55093fa 100644 --- a/src/rtsolver/dtridgl_impl.h +++ b/src/rtsolver/dtridgl_impl.h @@ -3,36 +3,30 @@ // C/C++ #include +// base +#include + namespace harp { // Solves a tridiagonal system using the Thomas algorithm (TDMA) template -void dtridgl(int n, const T *a, const T *b, const T *c, const T *d, T *x, - char *mem, int &offset) { - T *cp = (T *)get_mem(n, sizeof(T), mem, &offset); - T *dp = (T *)get_mem(n, sizeof(T), mem, &offset); - - if (!cp || !dp) { - // Handle memory allocation failure - exit(EXIT_FAILURE); - } - +DISPATCH_MACRO void dtridgl(int n, const T *a, const T *b, T *c, T *d, T *x) { // First row - cp[0] = c[0] / b[0]; - dp[0] = d[0] / b[0]; + c[0] = c[0] / b[0]; + d[0] = d[0] / b[0]; // Forward sweep for (int i = 1; i < n; ++i) { - T denom = b[i] - a[i] * cp[i - 1]; + T denom = b[i] - a[i] * c[i - 1]; if (denom == 0.0) denom = 1e-12; // Avoid division by zero - cp[i] = (i < n - 1) ? 
c[i] / denom : 0.0; - dp[i] = (d[i] - a[i] * dp[i - 1]) / denom; + c[i] = (i < n - 1) ? c[i] / denom : 0.0; + d[i] = (d[i] - a[i] * d[i - 1]) / denom; } // Back substitution - x[n - 1] = dp[n - 1]; + x[n - 1] = d[n - 1]; for (int i = n - 2; i >= 0; --i) { - x[i] = dp[i] - cp[i] * x[i + 1]; + x[i] = d[i] - c[i] * x[i + 1]; } } diff --git a/src/rtsolver/rtsolver_dispatch.cpp b/src/rtsolver/rtsolver_dispatch.cpp index f6143dc..b1ba855 100644 --- a/src/rtsolver/rtsolver_dispatch.cpp +++ b/src/rtsolver/rtsolver_dispatch.cpp @@ -5,4 +5,61 @@ #include #include -namespace harp {} // namespace harp +// harp +#include "toon_mckay89_longwave_impl.h" +#include "toon_mckay89_shortwave_impl.h" + +namespace harp { + +void call_toon89_sw_cpu(at::TensorIterator &iter) { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_sw_cpu", [&] { + int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); + int grain_size = iter.numel() / at::get_num_threads(); + + iter.for_each( + [&](char **data, const int64_t *strides, int64_t n) { + for (int i = 0; i < n; i++) { + auto out = reinterpret_cast(data[0] + i * strides[0]); + auto prop = reinterpret_cast(data[1] + i * strides[1]); + auto umu0 = reinterpret_cast(data[2] + i * strides[2]); + auto fbeam = reinterpret_cast(data[3] + i * strides[3]); + auto albedo = + reinterpret_cast(data[4] + i * strides[4]); + toon_mckay89_shortwave(nlay, *fbeam, umu0, prop, *albedo, out, work) + } + }, + grain_size); + }); +} + +void call_toon89_lw_cpu(at::TensorIterator &iter) { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_lw_cpu", [&] { + int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); + int grain_size = iter.numel() / at::get_num_threads(); + + iter.for_each( + [&](char **data, const int64_t *strides, int64_t n) { + for (int i = 0; i < n; i++) { + auto out = reinterpret_cast(data[0] + i * strides[0]); + auto prop = reinterpret_cast(data[1] + i * strides[1]); + auto albedo = + reinterpret_cast(data[4] + i * strides[4]); + 
auto be = reinterpret_cast(data[5] + i * strides[5]); + toon_mckay89_longwave(nlay, be, prop, *albedo, out, work); + } + }, + grain_size); + }); +} + +} // namespace harp + +namespace at::native { + +DEFINE_DISPATCH(call_toon89_lw); +DEFINE_DISPATCH(call_toon89_sw); + +REGISTER_ALL_CPU_DISPATCH(call_toon89_lw, &harp::call_toon89_lw_cpu); +REGISTER_ALL_CPU_DISPATCH(call_toon89_sw, &harp::call_toon89_sw_cpu); + +} // namespace at::native diff --git a/src/rtsolver/rtsolver_dispatch.cu b/src/rtsolver/rtsolver_dispatch.cu new file mode 100644 index 0000000..ecf7d64 --- /dev/null +++ b/src/rtsolver/rtsolver_dispatch.cu @@ -0,0 +1,51 @@ +// torch +#include +#include +#include +#include +#include + +// harp +#include +#include "disort_dispatch.hpp" +#include "disort_impl.h" + +namespace disort { + +void call_toon89_lw_cuda(at::TensorIterator& iter, int rank_in_column, + disort_state *ds, disort_output *ds_out) { + at::cuda::CUDAGuard device_guard(iter.device()); + + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_disort_cuda", [&] { + auto nprop = at::native::ensure_nonempty_size(iter.output(), -1); + + native::gpu_kernel<12>( + iter, [=] GPU_LAMBDA(char* const data[12], unsigned int strides[12]) { + auto out = reinterpret_cast(data[0] + strides[0]); + auto prop = reinterpret_cast(data[1] + strides[1]); + auto umu0 = reinterpret_cast(data[2] + strides[2]); + auto phi0 = reinterpret_cast(data[3] + strides[3]); + auto fbeam = reinterpret_cast(data[4] + strides[4]); + auto albedo = reinterpret_cast(data[5] + strides[5]); + auto fluor = reinterpret_cast(data[6] + strides[6]); + auto fisot = reinterpret_cast(data[7] + strides[7]); + auto temis = reinterpret_cast(data[8] + strides[8]); + auto btemp = reinterpret_cast(data[9] + strides[9]); + auto ttemp = reinterpret_cast(data[10] + strides[10]); + auto temf = reinterpret_cast(data[11] + strides[11]); + auto idxf = reinterpret_cast(data[12] + strides[12]); + int idx = static_cast(*idxf); + // disort_impl(out, prop, ftoa, temf, 
rank_in_column, ds[*idx], + // ds_out[*idx], nprop); + }); + }); +} + +} // namespace disort + +namespace at::native { + +REGISTER_CUDA_DISPATCH(call_toon89_lw, &disort::call_toon89_lw_cuda); +REGISTER_CUDA_DISPATCH(call_toon89_sw, &disort::call_toon89_sw_cuda); + +} // namespace at::native diff --git a/src/rtsolver/rtsolver_dispatch.hpp b/src/rtsolver/rtsolver_dispatch.hpp new file mode 100644 index 0000000..2b949ab --- /dev/null +++ b/src/rtsolver/rtsolver_dispatch.hpp @@ -0,0 +1,31 @@ +#pragma once + +// torch +#include +#include + +namespace at::native { + +using toon89_fn = void (*)(at::TensorIterator &iter); + +//! \brief Toon 1989 longwave solver +/*! + * Based on Elsie Lee's implementation in Exo-FMS_column_ck, which was + * based on CHIMERA code by Mike Line. + * Ported by Xi Zhang to Eigen + * Ported by Cheng Li to torch + * Reference: Toon, O.B., 1989, JGR, 94, 16287-16301. + */ +DECLARE_DISPATCH(toon89_fn, call_toon89_lw); + +//! \brief Toon 1989 shortwave solver +/*! + * Based on Elsie Lee's implementation in Exo-FMS_column_ck, which was + * based on CHIMERA code by Mike Line. + * Ported by Xi Zhang to Eigen + * Ported by Cheng Li to torch + * Reference: Toon, O.B., 1989, JGR, 94,16287-16301. 
+ */ +DECLARE_DISPATCH(toon89_fn, call_toon89_sw); + +} // namespace at::native diff --git a/src/rtsolver/toon_mckay89.cpp b/src/rtsolver/toon_mckay89.cpp index e5fe162..368d378 100644 --- a/src/rtsolver/toon_mckay89.cpp +++ b/src/rtsolver/toon_mckay89.cpp @@ -3,6 +3,8 @@ #include +#include "rtsolver_dispatch.hpp" + namespace harp { ToonMcKay89Impl::ToonMcKay89Impl(ToonMcKay89Options const& options) @@ -17,36 +19,81 @@ void ToonMcKay89Impl::reset() { torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, std::map* bc, torch::optional temf) { - int nlay = prop.size(-1); - int ncol = prop.size(1); + // check dimensions + TORCH_CHECK(prop.dim() == 4, "ToonMcKay89::forward: prop.dim() != 4"); - auto prop1 = prop.flip(-1); // from top to bottom + int nwave = prop.size(0); + int ncol = prop.size(1); + int nlyr = prop.size(2); // optical thickness - auto tau = prop.select(-1, 0); + auto tau = prop.select(-1, 0).flip(-1); // single scattering albedo - auto w0 = prop.select(-1, 1); + auto w0 = prop.select(-1, 1).flip(-1); // scattering asymmetry parameter - auto g = prop.select(-1, 2); + auto g = prop.select(-1, 2).flip(-1); // add slash if (bname.size() > 0 && bname.back() != '/') { bname += "/"; } - TORCH_CHECK(bc->count(bname + "albedo") > 0, - "Boundary condition for surface albedo not found."); + // check bc + if (bc->find(bname + "umu0") != bc->end()) { + TORCH_CHECK(bc->at(bname + "umu0").dim() == 1, + "DisortImpl::forward: bc->umu0.dim() != 1"); + TORCH_CHECK(bc->at(bname + "umu0").size(0) == ncol, + "DisortImpl::forward: bc->umu0.size(0) != ncol"); + (*bc)["umu0"] = bc->at(bname + "umu0"); + } else { + (*bc)["umu0"] = torch::ones({1, ncol}, prop.options()); + } + + if (bc->find(bname + "fbeam") != bc->end()) { + TORCH_CHECK(bc->at(bname + "fbeam").dim() == 2, + "DisortImpl::forward: bc->fbeam.dim() != 2"); + TORCH_CHECK(bc->at(bname + "fbeam").size(0) == nwave, + "DisortImpl::forward: bc->fbeam.size(0) != nwave"); + TORCH_CHECK(bc->at(bname + "fbeam").size(1) 
== ncol, + "DisortImpl::forward: bc->fbeam.size(1) != ncol"); + (*bc)["fbeam"] = bc->at(bname + "fbeam"); + } else { + (*bc)["fbeam"] = torch::zeros({nwave, ncol}, prop.options()); + } + + if (bc->find(bname + "albedo") != bc->end()) { + TORCH_CHECK(bc->at(bname + "albedo").dim() == 2, + "DisortImpl::forward: bc->albedo.dim() != 2"); + TORCH_CHECK(bc->at(bname + "albedo").size(0) == nwave, + "DisortImpl::forward: bc->albedo.size(0) != nwave"); + TORCH_CHECK(bc->at(bname + "albedo").size(1) == ncol, + "DisortImpl::forward: bc->albedo.size(1) != ncol"); + (*bc)["albedo"] = bc->at(bname + "albedo"); + } else { + (*bc)["albedo"] = torch::zeros({nwave, ncol}, prop.options()); + } + + auto flx = torch::zeros({nwave, ncol, nlyr + 1, 2}, prop.options()); if (!temf.has_value()) { // shortwave - TORCH_CHECK(bc->count(bname + "fbeam") > 0, - "Boundary condition for incoming flux not found."); - TORCH_CHECK(bc->count("umu0") > 0, "Boundary condition for mu0 not found."); - return shortwave_solver(bc->at(bname + "fbeam"), bc->at("umu0"), tau, w0, g, - bc->at(bname + "albedo")) - .flip(-2); + auto iter = at::TensorIteratorConfig() + .resize_outputs(false) + .check_all_same_dtype(true) + .declare_static_shape({nwave, ncol, nlyr + 1, 2}, + /*squash_dims=*/{2, 3}) + .add_output(flx) + .add_input(prop) + .add_owned_input(bc->at("umu0") + .view({1, ncol, 1, 1}) + .expand({nwave, ncol, nlyr, 1})) + .add_owned_input(bc->at("fbeam").view({nwave, ncol, 1, 1})) + .add_owned_input(bc->at("albedo").view({nwave, ncol, 1, 1})) + .build(); + at::native::call_toon89_sw(flx.device().type(), iter); + return flx; } else { // longwave /*Eigen::VectorXd temp(nlay + 1); Eigen::VectorXd be(nlay + 1); @@ -54,8 +101,22 @@ torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, temp(i) = ds_.temper[i]; be(i) = BB_integrate(ds_.temper[i], spec.wav1, spec.wav2); }*/ - auto be = bbflux_wavenumber(wave, temp); - return longwave_solver(be, tau, w0, g, bc->at(bname + "albedo")).flip(-2); + auto be = 
bbflux_wavenumber(wave, tempf.value()); + auto iter = at::TensorIteratorConfig() + .resize_outputs(false) + .check_all_same_dtype(true) + .declare_static_shape({nwave, ncol, nlyr + 1, 2}, + /*squash_dims=*/{2, 3}) + .add_output(flx) + .add_input(prop) + .add_owned_input(bc->at("fbeam").view({nwave, ncol, 1, 1})) + .add_owned_input(bc->at("albedo").view({nwave, ncol, 1, 1})) + .add_owned_input(be.view({1, ncol, nlyr + 1, 1}) + .expand({nwave, ncol, nlyr + 1, 1})) + .build(); + + at::native::call_toon89_lw(flx.device().type(), iter); + return flx; } } diff --git a/src/rtsolver/toon_mckay89.hpp b/src/rtsolver/toon_mckay89.hpp index c033061..cd13c3e 100644 --- a/src/rtsolver/toon_mckay89.hpp +++ b/src/rtsolver/toon_mckay89.hpp @@ -12,8 +12,23 @@ namespace harp { -struct ToonMcKay89Options { - ToonMcKay89Options() = default; +struct ToonMcKay89OptionsImpl { + ToonMcKay89Options(); + static std::shared_ptr create() { + return std::make_shared(); + } + + void report(std::ostream& os) const { + os << "* zenith_correction = " << zenith_correction() << "\n"; + + os << "* wave_lower = "; + for (auto const& v : wave_lower()) os << v << ", "; + os << "\n"; + + os << "* wave_upper = "; + for (auto const& v : wave_upper()) os << v << ", "; + os << "\n"; + } //! set lower wavenumber(length) at each bin ADD_ARG(std::vector, wave_lower) = {}; @@ -25,13 +40,15 @@ struct ToonMcKay89Options { ADD_ARG(bool, zenith_correction) = false; }; +using ToonMcKay89Options = std::shared_ptr; + class ToonMcKay89Impl : public torch::nn::Cloneable { public: //! options with which this `ToonMcKay89Impl` was constructed ToonMcKay89Options options; //! 
Constructor to initialize the layers - ToonMcKay89Impl() = default; + ToonMcKay89Impl() : options(ToonMcKay89OptionsImpl::create()) {} explicit ToonMcKay89Impl(ToonMcKay89Options const& options); void reset() override; @@ -49,31 +66,6 @@ class ToonMcKay89Impl : public torch::nn::Cloneable { std::map* bc, std::string bname = "", torch::optional temf = torch::nullopt); - - private: - //! \brief Toon 1989 shortwave solver - /*! - * Based on Elsie Lee's implementation in Exo-FMS_column_ck, which was - * based on CHIMERA code by Mike Line. - * Ported by Xi Zhang to Eigen - * Ported by Cheng Li to torch - * Reference: Toon, O.B., 1989, JGR, 94,16287-16301. - */ - torch::Tensor shortwave_solver(torch::Tensor Finc, torch::Tensor mu0, - torch::Tensor dtau, torch::Tensor w0, - torch::Tensor g, torch::Tensor albedo); - - //! \brief Toon 1989 longwave solver - /*! - * Based on Elsie Lee's implementation in Exo-FMS_column_ck, which was - * based on CHIMERA code by Mike Line. - * Ported by Xi Zhang to Eigen - * Ported by Cheng Li to torch - * Reference: Toon, O.B., 1989, JGR, 94, 16287-16301. 
- */ - torch::Tensor longwave_solver(torch::Tensor be, torch::Tensor dtau, - torch::Tensor w0, torch::Tensor g, - torch::Tensor albedo); }; } // namespace harp diff --git a/src/rtsolver/toon_mckay89_longwave.cpp b/src/rtsolver/toon_mckay89_longwave.cpp deleted file mode 100644 index 287276b..0000000 --- a/src/rtsolver/toon_mckay89_longwave.cpp +++ /dev/null @@ -1,152 +0,0 @@ -// C/C++ -#include - -// harp -#include - -#include "toon_mckay89.hpp" - -torch::Tensor ToonMcKay89Impl::longwave_solver(torch::Tensor be, - torch::Tensor tau_in, - torch::Tensor w_in, - torch::Tensor g_in, - torch::Tensor w_surf_in) { - const int nmu = 2; - const auto dtype = be.scalar_type(); - const auto device = be.device(); - const auto uarr = - torch::tensor({0.21132487, 0.78867513}, - torch::TensorOptions().dtype(dtype).device(device)); - const auto w = torch::tensor( - {0.5, 0.5}, torch::TensorOptions().dtype(dtype).device(device)); - const auto wuarr = uarr * w; - const double ubari = 0.5; - const double twopi = 6.283185307179586; - - int nlev = nlay + 1; - - auto out = torch::zeros({ncol, nlev, 2}, tau_cum.options()); - flx_down = out.select(-1, 0); - flx_up = out.select(-1, 1); - - // dtau = (1 - w * g^2) * (tau_in[1:] - tau_in[:-1]) - auto dtau_in = tau_in.narrow(-1, 1, nlay) - tau_in.narrow(-1, 0, nlay); - auto g2 = g_in * g_in; - auto dtau = (1.0 - w_in * g2) * dtau_in; - - auto w0 = ((1.0 - g2) * w_in) / (1.0 - w_in * g2); - auto hg = g_in / (1.0 + g_in); - - auto tau0 = torch::zeros_like(tau_in.select(-1, 0).unsqueeze(-1)); - auto tau = torch::cat({tau0, dtau.cumsum(-1)}, -1); - - auto denom = 1.0 - w0 * hg; - auto alp = ((1.0 - w0) / denom).sqrt(); - auto lam = alp * denom / ubari; - auto gam = (1.0 - alp) / (1.0 + alp); - auto term = ubari / denom; - - auto B0 = torch::empty_like(w0); - auto B1 = torch::empty_like(w0); - auto small_dtau_mask = dtau <= 1.0e-6; - - auto be_k = be.narrow(-1, 0, nlay); - auto be_k1 = be.narrow(-1, 1, nlay); - - B1.masked_fill_(small_dtau_mask, 
0.0); - B0.masked_scatter_(small_dtau_mask, - 0.5 * (be_k + be_k1).masked_select(small_dtau_mask)); - - auto B1_alt = (be_k1 - be_k) / dtau; - B1.masked_scatter_(~small_dtau_mask, B1_alt.masked_select(~small_dtau_mask)); - B0.masked_scatter_(~small_dtau_mask, be_k.masked_select(~small_dtau_mask)); - - auto term_B1 = B1 * term; - auto Cpm1 = B0 + term_B1; - auto Cmm1 = B0 - term_B1; - auto dtau_B1 = B1 * dtau; - auto Cp = B0 + dtau_B1 + term_B1; - auto Cm = B0 + dtau_B1 - term_B1; - - auto tautop = dtau.select(-1, 0) * std::exp(-1.0); - auto Btop = (1.0 - (tautop / ubari).neg().exp()) * be.select(-1, 0); - auto Bsurf = be.select(-1, nlev - 1); - auto bottom = Bsurf + B1.select(-1, nlay - 1) * ubari; - - auto exptrm = torch::min(lam * dtau, torch::tensor(35.0, lam.options())); - auto Ep = exptrm.exp(); - auto Em = 1.0 / Ep; - - auto E1 = Ep + gam * Em; - auto E2 = Ep - gam * Em; - auto E3 = gam * Ep + Em; - auto E4 = gam * Ep - Em; - - // ========================== Fill Af, Bf, Cf, Df ========================== - int l = 2 * nlay; - torch::Tensor Af_vec = torch::zeros_like(torch::empty({l}, dtau.options())) - .expand_as(B0) - .unsqueeze(-1) - .repeat({1, 1, l}); - torch::Tensor Bf_vec = Af_vec.clone(); - torch::Tensor Cf_vec = Af_vec.clone(); - torch::Tensor Df_vec = Af_vec.clone(); - - Bf_vec.select(-1, 0).copy_(gam.select(-1, 0) + 1.0); - Cf_vec.select(-1, 0).copy_(gam.select(-1, 0) - 1.0); - Df_vec.select(-1, 0).copy_(Btop.unsqueeze(-1) - Cmm1.select(-1, 0)); - - for (int i = 1, n = 1; i < l - 1; i += 2, ++n) { - auto gam_n = gam.select(-1, n); - auto gam_nm1 = gam.select(-1, n - 1); - - auto E1_n = E1.select(-1, n - 1); - auto E2_n = E2.select(-1, n - 1); - auto E3_n = E3.select(-1, n - 1); - auto E4_n = E4.select(-1, n - 1); - - auto Cp_nm1 = Cp.select(-1, n - 1); - auto Cpm1_n = Cpm1.select(-1, n); - auto Cm_nm1 = Cm.select(-1, n - 1); - auto Cmm1_n = Cmm1.select(-1, n); - - Af_vec.select(-1, i).copy_((E1_n + E3_n) * (gam_n - 1.0)); - Bf_vec.select(-1, 
i).copy_((E2_n + E4_n) * (gam_n - 1.0)); - Cf_vec.select(-1, i).copy_(2.0 * (1.0 - gam_n * gam_n)); - Df_vec.select(-1, i).copy_((gam_n - 1.0) * (Cpm1_n - Cp_nm1) + - (1.0 - gam_n) * (Cm_nm1 - Cmm1_n)); - } - - for (int i = 2, n = 1; i < l - 1; i += 2, ++n) { - auto gam_n = gam.select(-1, n); - auto gam_nm1 = gam.select(-1, n - 1); - - auto E1_n = E1.select(-1, n - 1); - auto E3_n = E3.select(-1, n - 1); - - auto Cp_nm1 = Cp.select(-1, n - 1); - auto Cpm1_n = Cpm1.select(-1, n); - auto Cm_nm1 = Cm.select(-1, n - 1); - auto Cmm1_n = Cmm1.select(-1, n); - - Af_vec.select(-1, i).copy_(2.0 * (1.0 - gam_nm1 * gam_nm1)); - Bf_vec.select(-1, i).copy_((E1_n - E3_n) * (1.0 + gam_n)); - Cf_vec.select(-1, i).copy_((E1_n + E3_n) * (gam_n - 1.0)); - Df_vec.select(-1, i).copy_(E3_n * (Cpm1_n - Cp_nm1) + - E1_n * (Cm_nm1 - Cmm1_n)); - } - - Af_vec.select(-1, l - 1).copy_(E1.select(-1, nlay - 1) - - a_surf_in * E3.select(-1, nlay - 1)); - Bf_vec.select(-1, l - 1).copy_(E2.select(-1, nlay - 1) - - a_surf_in * E4.select(-1, nlay - 1)); - Cf_vec.select(-1, l - 1).fill_(0.0); - Df_vec.select(-1, l - 1).copy_(Bsurf - Cp.select(-1, nlay - 1) + - a_surf_in * Cm.select(-1, nlay - 1)); - - // Fill output fluxes - flx_up = torch::zeros_like(be); - flx_down = torch::zeros_like(be); - - return out; -} diff --git a/src/rtsolver/toon_mckay89_longwave_impl.h b/src/rtsolver/toon_mckay89_longwave_impl.h index eb0063e..bfc786d 100644 --- a/src/rtsolver/toon_mckay89_longwave_impl.h +++ b/src/rtsolver/toon_mckay89_longwave_impl.h @@ -1,23 +1,31 @@ #pragma once // C/C++ -#include -#include -#include -#include +#include +#include +#include + +// base +#include // harp #include "dtridgl_impl.h" +#define DTAU_IN(i) prop[(nlay - i - 1) * 3] +#define W_IN(i) prop[(nlay - i - 1) * 3 + 1] +#define G_IN(i) prop[(nlay - i - 1) * 3 + 2] +#define FLX_UP(i) flx[2 * (nlev - i - 1)] +#define FLX_DN(i) flx[2 * (nlev - i - 1) + 1] + namespace harp { template -void toon_mckay89_longwave(int nlay, int nlev, const T 
*be, const T *tau_in, - const T *w_in, const T *g_in, T a_surf_in, T *flx_up, - T *flx_down, char *mem, int memsize) { +DISPATCH_MACRO void toon_mckay89_longwave(int nlay, const T *be, const T *prop, + T a_surf_in, T *flx, char *work) { + int nlev = nlay + 1; int l = 2 * nlay; - int lm1 = l - 1; int lm2 = l - 2; + int lm1 = l - 1; // Constants const int nmu = 5; @@ -29,88 +37,74 @@ void toon_mckay89_longwave(int nlay, int nlev, const T *be, const T *tau_in, const T *wuarr = {0.0157479145, 0.0739088701, 0.1463869871, 0.1671746381, 0.0967815902}; - // Scratch arrays - T *dtau_in = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *dtau = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *tau = (T *)get_mem(nlev, sizeof(T), mem, &offset); - T *w0 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *hg = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *B0 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *B1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *lam = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *gam = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *alp = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *term = (T *)get_mem(nlay, sizeof(T), mem, &offset); - - T *Cpm1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Cmm1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Cp = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Cm = (T *)get_mem(nlay, sizeof(T), mem, &offset); - - T *exptrm = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Ep = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Em = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *E1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *E2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *E3 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *E4 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - - T *Af = (T *)get_mem(l, sizeof(T), mem, &offset); - T *Bf = (T *)get_mem(l, sizeof(T), mem, &offset); - T *Cf = (T *)get_mem(l, sizeof(T), mem, &offset); - T *Df = (T *)get_mem(l, 
sizeof(T), mem, &offset); - T *xkk = (T *)get_mem(l, sizeof(T), mem, &offset); - T *xk1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *xk2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - - T *g = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *h = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *xj = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *xk = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *alpha1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *alpha2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *sigma1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *sigma2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - - T *em1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *obj = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *epp = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *obj2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *epp2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *em2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *em3 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - - T *lw_up_g = (T *)get_mem(nlev, sizeof(T), mem, &offset); - T *lw_down_g = (T *)get_mem(nlev, sizeof(T), mem, &offset); - - if (offset > memsize) { - fprintf(stderr, - "Error: Memory allocation failed in toon_mckay89_shortwave\n"); - exit(EXIT_FAILURE); - } - - // === Precomputations === - - for (int k = 0; k < nlay; ++k) dtau_in[k] = tau_in[k + 1] - tau_in[k]; - - for (int k = 0; k < nlay; ++k) { - T g2 = g_in[k] * g_in[k]; - T denom = 1.0 - w_in[k] * g2; - w0[k] = (1.0 - g2) * w_in[k] / denom; - dtau[k] = denom * dtau_in[k]; - hg[k] = g_in[k] / (1.0 + g_in[k]); + // --- Work Variables Allocation --- + dtau = alloc_from(work, nlay); + tau = alloc_from(work, nlev); + T *w0 = alloc_from(work, nlay); + T *hg = alloc_from(work, nlay); + T *B0 = alloc_from(work, nlay); + T *B1 = alloc_from(work, nlay); + T *lam = alloc_from(work, nlay); + T *gam = alloc_from(work, nlay); + T *alp = alloc_from(work, nlay); + T *term = alloc_from(work, nlay); + 
+ T *Cpm1 = alloc_from(work, nlay); + T *Cmm1 = alloc_from(work, nlay); + T *Cp = alloc_from(work, nlay); + T *Cm = alloc_from(work, nlay); + + T *exptrm = alloc_from(work, nlay); + T *Ep = alloc_from(work, nlay); + T *Em = alloc_from(work, nlay); + T *E1 = alloc_from(work, nlay); + T *E2 = alloc_from(work, nlay); + T *E3 = alloc_from(work, nlay); + T *E4 = alloc_from(work, nlay); + + T *Af = alloc_from(work, l); + T *Bf = alloc_from(work, l); + T *Cf = alloc_from(work, l); + T *Df = alloc_from(work, l); + T *xkk = alloc_from(work, l); + T *xk1 = alloc_from(work, nlay); + T *xk2 = alloc_from(work, nlay); + + T *g = alloc_from(work, nlay); + T *h = alloc_from(work, nlay); + T *xj = alloc_from(work, nlay); + T *xk = alloc_from(work, nlay); + T *alpha1 = alloc_from(work, nlay); + T *alpha2 = alloc_from(work, nlay); + T *sigma1 = alloc_from(work, nlay); + T *sigma2 = alloc_from(work, nlay); + + T *em1 = alloc_from(work, nlay); + T *em2 = alloc_from(work, nlay); + T *em3 = alloc_from(work, nlay); + T *lw_up_g = alloc_from(work, nlev); + T *lw_down_g = alloc_from(work, nlev); + + // Delta-Eddington Scaling + for (int i = 0; i < nlay; i++) { + T gsq = G_IN(i) * G_IN(i); + w0[i] = (1.0 - gsq) * W_IN(i) / (1.0 - W_IN(i) * gsq); + dtau[i] = (1.0 - W_IN(i) * gsq) * DTAU_IN(i); + hg[i] = G_IN(i) / (1.0 + G_IN(i)); } tau[0] = 0.0; - for (int k = 0; k < nlay; ++k) tau[k + 1] = tau[k] + dtau[k]; + for (int k = 0; k < nlay; k++) { + tau[k + 1] = tau[k] + dtau[k]; + } - for (int k = 0; k < nlay; ++k) { + for (int k = 0; k < nlay; k++) { alp[k] = sqrt((1.0 - w0[k]) / (1.0 - w0[k] * hg[k])); lam[k] = alp[k] * (1.0 - w0[k] * hg[k]) / ubari; gam[k] = (1.0 - alp[k]) / (1.0 + alp[k]); term[k] = ubari / (1.0 - w0[k] * hg[k]); - if (dtau[k] <= 1e-6) { + if (dtau[k] <= 1.0e-6) { B1[k] = 0.0; B0[k] = 0.5 * (be[k + 1] + be[k]); } else { @@ -122,85 +116,130 @@ void toon_mckay89_longwave(int nlay, int nlev, const T *be, const T *tau_in, Cmm1[k] = B0[k] - B1[k] * term[k]; Cp[k] = B0[k] + B1[k] * 
dtau[k] + B1[k] * term[k]; Cm[k] = B0[k] + B1[k] * dtau[k] - B1[k] * term[k]; + + exptrm[k] = fmin(lam[k] * dtau[k], 35.0); + Ep[k] = exp(exptrm[k]); + Em[k] = 1.0 / Ep[k]; + E1[k] = Ep[k] + gam[k] * Em[k]; + E2[k] = Ep[k] - gam[k] * Em[k]; + E3[k] = gam[k] * Ep[k] + Em[k]; + E4[k] = gam[k] * Ep[k] - Em[k]; } T tautop = dtau[0] * exp(-1.0); T Btop = (1.0 - exp(-tautop / ubari)) * be[0]; T Bsurf = be[nlev - 1]; - - T bottom = Bsurf + B1[nlay - 1] * ubari; - - // === Solve tridiagonal system (not shown again for brevity) === - dtridgl(l, Af, Bf, Cf, Df, xk, mem, offset); - if (offset > memsize) { - fprintf(stderr, - "Error: Memory allocation failed in toon_mckay89_shortwave\n"); - exit(EXIT_FAILURE); + T bsurf_flux = Bsurf; // Bsurf is local variable + + // --- Matrix Construction (1-based indices for solver) --- + Af[1] = 0.0; + Bf[1] = gam[0] + 1.0; + Cf[1] = gam[0] - 1.0; + Df[1] = Btop - Cmm1[0]; + + int n_idx = 0; + for (int i = 2; i <= lm2; i += 2) { + Af[i] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); + Bf[i] = (E2[n_idx] + E4[n_idx]) * (gam[n_idx + 1] - 1.0); + Cf[i] = 2.0 * (1.0 - gam[n_idx + 1] * gam[n_idx + 1]); + Df[i] = (gam[n_idx + 1] - 1.0) * (Cpm1[n_idx + 1] - Cp[n_idx]) + + (1.0 - gam[n_idx + 1]) * (Cm[n_idx] - Cmm1[n_idx + 1]); + n_idx++; } - // === Calculate xk1, xk2 from xkk === - for (int n = 0; n < nlay; ++n) { - xk1[n] = xkk[2 * n] + xkk[2 * n + 1]; - xk2[n] = xkk[2 * n] - xkk[2 * n + 1]; - if (fabs(xk2[n] / xkk[2 * n]) < 1e-30) xk2[n] = 0.0; + n_idx = 0; + for (int i = 3; i <= lm1; i += 2) { + Af[i] = 2.0 * (1.0 - gam[n_idx] * gam[n_idx]); + Bf[i] = (E1[n_idx] - E3[n_idx]) * (1.0 + gam[n_idx + 1]); + Cf[i] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); + Df[i] = E3[n_idx] * (Cpm1[n_idx + 1] - Cp[n_idx]) + + E1[n_idx] * (Cm[n_idx] - Cmm1[n_idx + 1]); + n_idx++; } - // === Conditional computation for g, h, xj, xk etc. 
=== - for (int k = 0; k < nlay; ++k) { - if (w0[k] <= 1e-4) { - g[k] = h[k] = xj[k] = xk[k] = 0.0; - alpha1[k] = sigma1[k] = twopi * B0[k]; - alpha2[k] = sigma2[k] = twopi * B1[k]; + Af[l] = E1[nlay - 1] - a_surf_in * E3[nlay - 1]; + Bf[l] = E2[nlay - 1] - a_surf_in * E4[nlay - 1]; + Cf[l] = 0.0; + Df[l] = bsurf_flux - Cp[nlay - 1] + a_surf_in * Cm[nlay - 1]; + + dtridgl(l, Af, Bf, Cf, Df, xkk); + + for (int n = 0; n < nlay; n++) { + xk1[n] = xkk[2 * n + 1] + xkk[2 * n + 2]; + xk2[n] = xkk[2 * n + 1] - xkk[2 * n + 2]; + if (fabs(xk2[n]) < 1e-30 * fabs(xkk[2 * n + 1])) xk2[n] = 0.0; + + if (w0[n] <= 1e-4) { + g[n] = 0.0; + h[n] = 0.0; + xj[n] = 0.0; + xk_ptr[n] = 0.0; // using xk_ptr because xk name conflict + alpha1[n] = twopi * B0[n]; + alpha2[n] = twopi * B1[n]; + sigma1[n] = alpha1[n]; + sigma2[n] = alpha2[n]; } else { - T f1 = (1.0 + hg[k] * alp[k]) / (1.0 + alp[k]); - T f2 = (1.0 - hg[k] * alp[k]) / (1.0 + alp[k]); - - g[k] = twopi * w0[k] * xk1[k] * f1; - h[k] = twopi * w0[k] * xk2[k] * f2; - xj[k] = twopi * w0[k] * xk1[k] * f2; - xk[k] = twopi * w0[k] * xk2[k] * f1; - - T fact = ubari * w0[k] * hg[k] / (1.0 - w0[k] * hg[k]); - alpha1[k] = twopi * (B0[k] + B1[k] * fact); - sigma1[k] = twopi * (B0[k] - B1[k] * fact); - alpha2[k] = sigma2[k] = twopi * B1[k]; + T common_den = 1.0 + alp[n]; + g[n] = twopi * w0[n] * xk1[n] * (1.0 + hg[n] * alp[n]) / common_den; + h[n] = twopi * w0[n] * xk2[n] * (1.0 - hg[n] * alp[n]) / common_den; + xj[n] = twopi * w0[n] * xk1[n] * (1.0 - hg[n] * alp[n]) / common_den; + xk[n] = twopi * w0[n] * xk2[n] * (1.0 + hg[n] * alp[n]) / common_den; + T term_val = ubari * w0[n] * hg[n] / (1.0 - w0[n] * hg[n]); + alpha1[n] = twopi * (B0[n] + B1[n] * term_val); + alpha2[n] = twopi * B1[n]; + sigma1[n] = twopi * (B0[n] - B1[n] * term_val); + sigma2[n] = alpha2[n]; } + em1[n] = 1.0 / exp(fmin(lam[n] * dtau[n], 35.0)); } - // === Gaussian quadrature integration === - memset(flx_up, 0, nlev * sizeof(T)); - memset(flx_down, 0, nlev * sizeof(T)); + 
for (int k = 0; k < nlev; k++) { + FLX_UP(k) = 0.0; + FLX_DN(k) = 0.0; + } - for (int m = 0; m < nmu; ++m) { - for (int k = 0; k < nlay; ++k) { - em2[k] = exp(-dtau[k] / uarr[m]); - em3[k] = em1[k] * em2[k]; - } + // --- Gaussian Quadrature Mu Loop --- + for (int m = 0; m < nmu; m++) { + T u = uarr[m]; + + // Downward loop + lw_down_g[0] = twopi * (1.0 - exp(-tautop / u)) * be[0]; + for (int k = 0; k < nlay; k++) { + em2[k] = exp(-dtau[k] / u); + T l_u_p1 = lam[k] * u + 1.0; + T l_u_m1 = lam[k] * u - 1.0; - lw_down_g[0] = twopi * (1.0 - exp(-tautop / uarr[m])) * be[0]; - for (int k = 0; k < nlay; ++k) { lw_down_g[k + 1] = - lw_down_g[k] * em2[k] + - (xj[k] / (lam[k] * uarr[m] + 1.0)) * (epp[k] - em2[k]) + - (xk[k] / (lam[k] * uarr[m] - 1.0)) * (em2[k] - em[k]) + - sigma1[k] * (1.0 - em2[k]) + - sigma2[k] * (uarr[m] * em2[k] + dtau[k] - uarr[m]); + lw_down_g[k] * em2[k] + (xj[k] / l_u_p1) * (Ep[k] - em2[k]) + + (xk[k] / l_u_m1) * (em2[k] - em1[k]) + sigma1[k] * (1.0 - em2[k]) + + sigma2[k] * (u * em2[k] + dtau[k] - u); } - lw_up_g[nlev - 1] = twopi * (Bsurf + B1[nlay - 1] * uarr[m]); - for (int k = nlay - 1; k >= 0; --k) { - lw_up_g[k] = lw_up_g[k + 1] * em2[k] + - (g[k] / (lam[k] * uarr[m] - 1.0)) * (epp[k] * em2[k] - 1.0) + - (h[k] / (lam[k] * uarr[m] + 1.0)) * (1.0 - em3[k]) + - alpha1[k] * (1.0 - em2[k]) + - alpha2[k] * (uarr[m] - (dtau[k] + uarr[m]) * em2[k]); + // Upward loop + lw_up_g[nlev - 1] = twopi * (Bsurf + B1[nlay - 1] * u); + for (int k = nlay - 1; k >= 0; k--) { + em2[k] = exp(-dtau[k] / u); + T em3_val = em1[k] * em2[k]; + T l_u_m1 = lam[k] * u - 1.0; + T l_u_p1 = lam[k] * u + 1.0; + + lw_up_g[k] = + lw_up_g[k + 1] * em2[k] + (g[k] / l_u_m1) * (Ep[k] * em2[k] - 1.0) + + (h[k] / l_u_p1) * (1.0 - em3_val) + alpha1[k] * (1.0 - em2[k]) + + alpha2[k] * (u - (dtau[k] + u) * em2[k]); } - for (int k = 0; k < nlev; ++k) { - flx_down[k] += lw_down_g[k] * wuarr[m]; - flx_up[k] += lw_up_g[k] * wuarr[m]; + for (int k = 0; k < nlev; k++) { + FLX_DN(k) += 
lw_down_g[k] * wuarr[m]; + FLX_UP(k) += lw_up_g[k] * wuarr[m]; } } } } // namespace harp + +#undef DTAU_IN +#undef W_IN +#undef G_IN +#undef FLX_UP +#undef FLX_DN diff --git a/src/rtsolver/toon_mckay89_shortwave.cpp b/src/rtsolver/toon_mckay89_shortwave.cpp deleted file mode 100644 index 60bc783..0000000 --- a/src/rtsolver/toon_mckay89_shortwave.cpp +++ /dev/null @@ -1,221 +0,0 @@ -// C/C++ -#include - -// harp -#include - -#include "toon_mckay89.hpp" - -torch::Tensor ToonMcKay89Impl::shortwave_solver( - torch::Tensor F0_in, torch::Tensor mu_in, torch::Tensor tau_in, - torch::Tensor w_in, torch::Tensor g_in, torch::Tensor w_surf_in) { - int nwave = tau_in.size(0); - int ncol = tau_in.size(1); - int nlay = tau_in.size(2); - - // Input validation - if (mu_in.size(0) != ncol || w_in.size(-1) != nlay || - g_in.size(-1) != nlay { - throw std::invalid_argument("Input vectors have incorrect sizes."); - } - - // increase the last dimension by 1 (lyr -> lvl) - auto shape = tau_in.sizes().vec(); - shape.back() += 1; - torch::Tensor tau_cum = torch::zeros(shape, tau_in.options()); - tau_cum.narrow(-1, 1, nlay) = tau_in.cumsum(-1); - - int nlev = tau_cum.size(-1); - - // Initialize output flux arrays - auto out = torch::zeros({ncol, nlev, 2}, tau_cum.options()); - flx_down = out.select(-1, 0); - flx_up = out.select(-1, 1); - - // Constants - const double sqrt3 = std::sqrt(3.0); - const double sqrt3d2 = sqrt3 / 2.0; - const double bsurf = 0.0; - const double btop = 0.0; - - // Check if all single scattering albedos are effectively zero - bool all_w0_zero = (w_in <= 1.0e-12).all().item(); - - if (all_w0_zero) { // no scattering - // Direct beam only - // No zenith correction, use regular method - if (!options.zenith_correction) { - flx_down = F0_in.unsqueeze(-1) * mu_in.unsqueeze(-1) * - (-tau_cum / mu_in.unsqueeze(-1)).exp(); - } else { - // Zenith angle correction using cumulative transmission - TORCH_CHECK(mu_in.size(-1) == nlay, - "The last dimension of mu_in should have 
layers"); - auto trans_cum = torch::zeros_like(tau_cum); - trans_cum.narrow(-1, 1, nlay) = tau_in / mu_in; - trans_cum.narrow(-1, 1, nlay) = torch::cumsum(trans_cum, -1); - - flx_down = F0_in.unsqueeze(-1) * mu_in * torch::exp(-trans_cum); - } - - // Adjust the downward flux at the surface layer for surface albedo - flx_down.select(-1, nlev - 1) *= 1.0 - w_surf_in; - - // Upward flux remains zero - return out; - } - - // Delta Eddington scaling - auto w0 = ((1.0 - g_in * g_in) * w_in) / (1.0 - w_in * g_in * g_in); - auto dtau = (1.0 - w_in * g_in * g_in) * tau_in; - auto hg = g_in / (1.0 + g_in); - - // Initialize tau_total - torch::Tensor tau_total = torch::zeros_like(tau_cum); - tau_total.narrow(-1, 1, nlay) = dtau.cumsum(-1); - - // Compute g1, g2, g3, g4 - auto g1 = sqrt3d2 * (2.0 - w0 * (1.0 + hg)); - auto g2 = sqrt3d2 * w0 * (1.0 - hg); - // Prevent division by zero - g2.clamp_(1.0e-10); - - // Compute mu_zm at midpoints - auto mu_zm = (mu_in.narrow(-1, 0, nlay) + mu_in.narrow(-1, 1, nlay)) / 2.0; - auto g3 = (1.0 - sqrt3 * hg * mu_zm) / 2.0; - auto g4 = 1.0 - g3; - - // Compute lam and gam - auto lam = (g1 * g1 - g2 * g2).square(); - auto gam = (g1 - lam) / g2; - - // Compute denom and handle denom == 0 - auto denom = lam * lam - (1.0 / mu_in.select(-1, nlev - 1).square()); - denom.clamp_(1.0e-10); - - // Compute Am and Ap - auto Am = F0_in * w0 * (g4 * (g1 + 1.0 / mu_in.select(-1, nlev - 1)) + g2 * g3) / denom; - auto Ap = F0_in * w0 * (g3 * (g1 - 1.0 / mu_in.select(-1, nlev - 1)) + g2 * g4) / denom; - - // Compute Cpm1 and Cmm1 at the top of the layer - auto Cpm1 = Ap * (-tau_total.narrow(-1, 0, nlay) / mu_in.select(-1, nlev - 1)).exp(); - auto Cmm1 = Am * (-tau_total.narrow(-1, 0, nlay) / mu_in.select(-1, nlev - 1)).exp(); - - // Compute Cp and Cm at the bottom of the layer - auto Cp = Ap * (-tau_total.narrow(-1, 1, nlay) / mu_in.select(-1, nlev - 1)).exp(); - auto Cm = Am * (-tau_total.narrow(-1, 1, nlay) / mu_in.select(-1, nlev - 1)).exp(); - - // 
Compute exponential terms, clamped to prevent overflow - auto exptrm = (lam * dtau).clamp_(35.0); - auto Ep = exptrm.exp(); - auto Em = 1.0 / Ep; - auto E1 = Ep + gam * Em; - auto E2 = Ep - gam * Em; - auto E3 = gam * Ep + Em; - auto E4 = gam * Ep - Em; - - // Initialize Af, Bf, Cf, Df - int l = 2 * nlay; - auto Af = torch::zeros({nwave, ncol, l}, tau_in.options()); - auto Bf = torch::zeros({nwave, ncol, l}, tau_in.options()); - auto Cf = torch::zeros({nwave, ncol, l}, tau_in.options()); - auto Df = torch::zeros({nwave, ncol, l}, tau_in.options()); - - // Boundary conditions at the top - Af.select(-1, 0) = 0.0; - Bf.select(-1, 0) = gam.select(-1, 0) + 1.0; - Cf.select(-1, 0) = gam.select(-1, 0) - 1.0; - Df.select(-1, 0) = btop - Cmm1.select(-1, 0); - for (int i = 1, n = 1; i < l - 1; i += 2, ++n) { - TORCK_CHECK(n < nlay, - "Index out of range in sw_Toon89 Af, Bf, Cf, Df population."); - - Af.select(-1, i) = (E1.select(-1, n - 1) + E3.select(-1, n - 1)) * - (gam.select(-1, n) - 1.0); - Bf.select(-1, i) = (E2.select(-1, n - 1) + E4.select(-1, n - 1)) * - (gam.select(-1, n) - 1.0); - Cf.select(-1, i) = 2.0 * (1.0 - gam.select(-1, n).square()); - Df.select(-1, i) = - (gam.select(-1, n) - 1.0) * - (Cpm1.select(-1, n) - Cp.select(-1, n - 1)) + - (1.0 - gam.select(-1, n)) * (Cm.select(-1, n - 1) - Cmm1.select(-1, n)); - } - - // Populate Af, Bf, Cf, Df for even indices - // Start from n=1 to avoid negative indexing (Cp(n-1) when n=0) - for (int i = 2, n = 1; i < l - 1; i += 2, ++n) { - TORCH_CHECK(n < nlay, - "Index out of range in sw_Toon89 Af, Bf, Cf, Df population."); - - Af.select(-1, i) = 2.0 * (1.0 - gam.select(-1, n).square()); - Bf.select(-1, i) = (E1.select(-1, n - 1) - E3.select(-1, n - 1)) * - (1.0 + gam.select(-1, n)); - Cf.select(-1, i) = (E1.select(-1, n - 1) + E3.select(-1, n - 1)) * - (gam.select(-1, n) - 1.0); - Df.select(-1, i) = - E3.select(-1, n - 1) * (Cpm1.select(-1, n) - Cp.select(-1, n - 1)) + - E1.select(-1, n - 1) * (Cm.select(-1, n - 1) - 
Cmm1.select(-1, n)); - } - - // Boundary conditions at l (last index) - Af.select(-1, l - 1) = E1.select(-1, nlay - 1) - w_surf_in * E3.select(-1, nlay - 1); - Bf.select(-1, l - 1) = E2.select(-1, nlay - 1) - w_surf_in * E4.select(-1, nlay - 1); - Cf.select(-1, l - 1) = 0.0; - Df.select(-1, l - 1) = bsurf - Cp.select(-1, nlay - 1) + w_surf_in * Cm.select(-1, nlay - 1); - - // Solve the tridiagonal system - tridiag_lu(Af, Bf, Cf); - tridiag_solve(Df, Af, Bf, Cf); - - // Compute xk1 and xk2 from xk - // select even and odd indices - auto xk_2n = Df.index_select(-1, torch::arange(0, tensor.size(0), 2)); - auto xk_2np1 = Df.index_select(-1, torch::arange(1, tensor.size(0), 2)); - - auto xk1 = xk_2n + xk_2np1; - auto xk2 = xk_2n - xk_2np1; - - xk2 = torch::where(torch::abs(xk2 / xk_2n) < 1e-30, torch::zeros_like(xk2), xk2); - - // Populate flx_up and flx_down for layers 1 to nlay - flx_up.select(-1, 0, nlay) = xk1 + g * xk2 + Cpm1; - flx_down.select(-1, 0, nlay) = xk1 * gam + xk2 + Cmm1; - - // Compute flx_up and flx_down at level nlev - flx_up.select(-1, 0, nlev - 1) = xk1.select(-1, nlay - 1) * std::exp(1.0) - + gam.select(-1, nlay - 1) * xk2.select(-1, nlay - 1) * std::exp(-1.0) - + Cp.select(-1, nlay - 1); - flx_down.select(-1, 0, nlev - 1) = xk1.select(-1, nlay - 1) * std::exp(1.0) - * gam.select(-1, nlay - 1) + xk2.select(-1, nlay - 1) * std::exp(-1.0) - + Cm.select(-1, nlay - 1); - - // Compute dir flux - Torch::Tensor dir; - if (!options.zenith_correction) { - // No zenith correction - dir = F0_in.unsqueeze(-1) * mu_in.unsqueeze(-1) * - (-tau_cum / mu_in.unsqueeze(-1)).exp(); - } else { - // Zenith angle correction - TORCH_CHECK(mu_in.size(-1) == nlay, - "The last dimension of mu_in should have layers"); - auto trans_cum = torch::zeros_like(tau_cum); - trans_cum.narrow(-1, 1, nlay) = tau_in / mu_in; - trans_cum.narrow(-1, 1, nlay) = torch::cumsum(trans_cum, -1); - - dir = F0_in * mu_in * torch::exp(-trans_cum); - } - - // Adjust the downward flux at the surface 
layer for surface albedo - dir.select(-1, nlev - 1) *= 1.0 - w_surf_in; - - // for(int i=0; i -#include -#include -#include +#include +#include +#include + +// base +#include // harp #include "dtridgl_impl.h" +#define DTAU_IN(i) prop[(nlay - i - 1) * 3] +#define W_IN(i) prop[(nlay - i - 1) * 3 + 1] +#define G_IN(i) prop[(nlay - i - 1) * 3 + 2] +#define FLX_UP(i) flx[2 * (nlev - i - 1)] +#define FLX_DN(i) flx[2 * (nlev - i - 1) + 1] + namespace harp { template -void toon_mckay89_shortwave(int nlay, int nlev, T F0_in, const T *mu_in, - const T *tau_in, const T *w_in, const T *g_in, - T w_surf_in, T *flx_down, T *flx_up, char *mem, - int memsize) { +DISPATCH_MACRO void toon_mckay89_shortwave(int nlay, T F0_in, T const *mu_in, + T const *prop, T w_surf_in, T *flx, + char *work) { + int nlev = nlay + 1; int l = 2 * nlay; - int lm1 = l - 1; int lm2 = l - 2; + int lm1 = l - 1; - // Constants - const T sqrt3 = sqrt(3.0); - const T sqrt3d2 = sqrt3 / 2.0; - const T bsurf = 0.0, btop = 0.0; - - // Scratch arrays - T *dir = (T *)get_mem(nlev, sizeof(T), mem, &offset); - T *tau = (T *)get_mem(nlev, sizeof(T), mem, &offset); - T *cum_trans = (T *)get_mem(nlev, sizeof(T), mem, &offset); - T *dtau_in = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *dtau = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *mu_zm = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *w0 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *hg = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *g1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *g2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *g3 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *g4 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *lam = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *gam = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *denom = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Am = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Ap = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Cpm1 = (T 
*)get_mem(nlay, sizeof(T), mem, &offset); - T *Cmm1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Cp = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Cm = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *exptrm = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Ep = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Em = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *E1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *E2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *E3 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *E4 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *Af = (T *)get_mem(l, sizeof(T), mem, &offset); - T *Bf = (T *)get_mem(l, sizeof(T), mem, &offset); - T *Cf = (T *)get_mem(l, sizeof(T), mem, &offset); - T *Df = (T *)get_mem(l, sizeof(T), mem, &offset); - T *xk = (T *)get_mem(l, sizeof(T), mem, &offset); - T *xk1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *xk2 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - T *opt1 = (T *)get_mem(nlay, sizeof(T), mem, &offset); - - if (offset > memsize) { - fprintf(stderr, - "Error: Memory allocation failed in toon_mckay89_shortwave\n"); - exit(EXIT_FAILURE); - } - - // Early exit if all single scattering albedos are ~0 - int all_zero = 1; - for (int k = 0; k < nlay; ++k) { - if (w_in[k] > 1.0e-12) { - all_zero = 0; + // --- Memory Allocation --- + T *dir = alloc_from(work, nlev); + T *tau = alloc_from(work, nlev); + T *cum_trans = alloc_from(work, nlev); + T *tau_in = alloc_from(work, nlev); + T *dtau = alloc_from(work, nlay); + T *mu_zm = alloc_from(work, nlay); + T *w0 = alloc_from(work, nlay); + T *hg = alloc_from(work, nlay); + T *g1 = alloc_from(work, nlay); + T *g2 = alloc_from(work, nlay); + T *g3 = alloc_from(work, nlay); + T *g4 = alloc_from(work, nlay); + T *lam = alloc_from(work, nlay); + T *gam = alloc_from(work, nlay); + T *denom = alloc_from(work, nlay); + T *Am = alloc_from(work, nlay); + T *Ap = alloc_from(work, nlay); + T *Cpm1 = alloc_from(work, nlay); + T *Cmm1 
= alloc_from(work, nlay); + T *Cp = alloc_from(work, nlay); + T *Cm = alloc_from(work, nlay); + T *exptrm = alloc_from(work, nlay); + T *Ep = alloc_from(work, nlay); + T *Em = alloc_from(work, nlay); + T *E1 = alloc_from(work, nlay); + T *E2 = alloc_from(work, nlay); + T *E3 = alloc_from(work, nlay); + T *E4 = alloc_from(work, nlay); + T *Af = alloc_from(work, l); + T *Bf = alloc_from(work, l); + T *Cf = alloc_from(work, l); + T *Df = alloc_from(work, l); + T *xk = alloc_from(work, l); + T *xk1 = alloc_from(work, nlay); + T *xk2 = alloc_from(work, nlay); + + const dp sqrt3 = sqrt(3.0); + const dp sqrt3d2 = sqrt3 / 2.0; + const dp btop = 0.0, bsurf = 0.0; + + // Check for zero albedo + bool all_zero_w = true; + for (int i = 0; i < nlay; i++) { + if (w_in[i] > 1.0e-12) { + all_zero_w = false; break; } } - if (all_zero) { + // compute integrated optical depth + tau_in[0] = 0.0; + for (int i = 0; i < nlay; i++) { + tau_in[i + 1] = tau_in[i] + DTAU_IN(i); + } + + if (all_zero_w) { + // --- Special Case: Direct Beam Only --- if (mu_in[nlev - 1] == mu_in[0]) { - for (int k = 0; k < nlev; ++k) - flx_down[k] = - F0_in * mu_in[nlev - 1] * exp(-tau_in[k] / mu_in[nlev - 1]); + for (int k = 0; k < nlev; k++) { + FLX_DN(k) = F0_in * mu_in[nlev - 1] * exp(-tau_in[k] / mu_in[nlev - 1]); + } } else { cum_trans[0] = tau_in[0] / mu_in[0]; - for (int k = 1; k < nlev; ++k) - cum_trans[k] = - cum_trans[k - 1] + (tau_in[k] - tau_in[k - 1]) / mu_in[k]; - for (int k = 0; k < nlev; ++k) - flx_down[k] = F0_in * mu_in[nlev - 1] * exp(-cum_trans[k]); + for (int k = 0; k < nlev - 1; k++) { + cum_trans[k + 1] = cum_trans[k] + DTAU_IN(k) / mu_in[k + 1]; + } + for (int k = 0; k < nlev; k++) { + FLX_DN(k) = F0_in * mu_in[nlev - 1] * exp(-cum_trans[k]); + } } - flx_down[nlev - 1] *= (1.0 - w_surf_in); - for (int k = 0; k < nlev; ++k) flx_up[k] = 0.0; - } + FLX_DN(nlev - 1) *= (1.0 - w_surf_in); + for (int k = 0; k < nlev; k++) FLX_UP(k) = 0.0; - // Continue with rest of code - for (int k = 0; k < 
nlay; ++k) dtau_in[k] = tau_in[k + 1] - tau_in[k]; - - for (int k = 0; k < nlay; ++k) { - T g2_val = g_in[k] * g_in[k]; - T denom_val = 1.0 - w_in[k] * g2_val; - w0[k] = ((1.0 - g2_val) * w_in[k]) / denom_val; - dtau[k] = denom_val * dtau_in[k]; - hg[k] = g_in[k] / (1.0 + g_in[k]); - } - - tau[0] = 0.0; - for (int k = 0; k < nlay; ++k) tau[k + 1] = tau[k] + dtau[k]; - - if (mu_in[nlev - 1] == mu_in[0]) { - for (int k = 0; k < nlev; ++k) - dir[k] = F0_in * mu_in[nlev - 1] * exp(-tau[k] / mu_in[nlev - 1]); - for (int k = 0; k < nlay; ++k) mu_zm[k] = mu_in[nlev - 1]; } else { - cum_trans[0] = tau[0] / mu_in[0]; - for (int k = 1; k < nlev; ++k) - cum_trans[k] = cum_trans[k - 1] + (tau[k] - tau[k - 1]) / mu_in[k]; - for (int k = 0; k < nlev; ++k) - dir[k] = F0_in * mu_in[nlev - 1] * exp(-cum_trans[k]); - for (int k = 0; k < nlay; ++k) mu_zm[k] = 0.5 * (mu_in[k] + mu_in[k + 1]); - } - - for (int k = 0; k < nlay; ++k) { - g1[k] = sqrt3d2 * (2.0 - w0[k] * (1.0 + hg[k])); - g2[k] = sqrt3d2 * w0[k] * (1.0 - hg[k]); - if (g2[k] == 0.0) g2[k] = 1e-10; - g3[k] = 0.5 * (1.0 - sqrt3 * hg[k] / mu_zm[k]); - g4[k] = 1.0 - g3[k]; - - lam[k] = sqrt(g1[k] * g1[k] - g2[k] * g2[k]); - gam[k] = (g1[k] - lam[k]) / g2[k]; - - denom[k] = lam[k] * lam[k] - 1.0 / (mu_zm[k] * mu_zm[k]); - if (denom[k] == 0.0) denom[k] = 1e-10; - - Am[k] = F0_in * w0[k] * (g4[k] * (g1[k] + 1.0 / mu_zm[k]) + g2[k] * g3[k]) / - denom[k]; - Ap[k] = F0_in * w0[k] * (g3[k] * (g1[k] - 1.0 / mu_zm[k]) + g2[k] * g4[k]) / - denom[k]; - } - - for (int k = 0; k < nlay; ++k) { - opt1[k] = exp(-tau[k] / mu_zm[k]); - Cpm1[k] = Ap[k] * opt1[k]; - Cmm1[k] = Am[k] * opt1[k]; - - opt1[k] = exp(-tau[k + 1] / mu_zm[k]); - Cp[k] = Ap[k] * opt1[k]; - Cm[k] = Am[k] * opt1[k]; + // --- General Case: Toon et al. 
1989 Solver --- - exptrm[k] = fmin(lam[k] * dtau[k], 35.0); - Ep[k] = exp(exptrm[k]); - Em[k] = 1.0 / Ep[k]; - - E1[k] = Ep[k] + gam[k] * Em[k]; - E2[k] = Ep[k] - gam[k] * Em[k]; - E3[k] = gam[k] * Ep[k] + Em[k]; - E4[k] = gam[k] * Ep[k] - Em[k]; - } - - // System assembly - Af[0] = 0.0; - Bf[0] = gam[0] + 1.0; - Cf[0] = gam[0] - 1.0; - Df[0] = btop - Cmm1[0]; - - n = 0; - for (int i = 1; i < lm2; i += 2) { - Af[i] = (E1[n] + E3[n]) * (gam[n + 1] - 1.0); - Bf[i] = (E2[n] + E4[n]) * (gam[n + 1] - 1.0); - Cf[i] = 2.0 * (1.0 - gam[n + 1] * gam[n + 1]); - Df[i] = (gam[n + 1] - 1.0) * (Cpm1[n + 1] - Cp[n]) + - (1.0 - gam[n + 1]) * (Cm[n] - Cmm1[n + 1]); - ++n; - } - - n = 0; - for (int i = 2; i < lm1; i += 2) { - Af[i] = 2.0 * (1.0 - gam[n] * gam[n]); - Bf[i] = (E1[n] - E3[n]) * (1.0 + gam[n + 1]); - Cf[i] = (E1[n] + E3[n]) * (gam[n + 1] - 1.0); - Df[i] = E3[n] * (Cpm1[n + 1] - Cp[n]) + E1[n] * (Cm[n] - Cmm1[n + 1]); - ++n; - } - - Af[l - 1] = E1[nlay - 1] - w_surf_in * E3[nlay - 1]; - Bf[l - 1] = E2[nlay - 1] - w_surf_in * E4[nlay - 1]; - Cf[l - 1] = 0.0; - Df[l - 1] = bsurf - Cp[nlay - 1] + w_surf_in * Cm[nlay - 1]; + for (int i = 0; i < nlay; i++) { + dp g_sq = g_in[i] * g_in[i]; + w0[i] = ((1.0 - g_sq) * w_in[i]) / (1.0 - w_in[i] * g_sq); + dtau[i] = (1.0 - w_in[i] * g_sq) * DTAU_IN(i); + hg[i] = g_in[i] / (1.0 + g_in[i]); + } - dtridgl(l, Af, Bf, Cf, Df, xk, mem, offset); - if (offset > memsize) { - fprintf(stderr, - "Error: Memory allocation failed in toon_mckay89_shortwave\n"); - exit(EXIT_FAILURE); - } + tau[0] = 0.0; + for (int k = 0; k < nlay; k++) tau[k + 1] = tau[k] + dtau[k]; - for (int n = 0; n < nlay; ++n) { - xk1[n] = xk[2 * n] + xk[2 * n + 1]; - xk2[n] = xk[2 * n] - xk[2 * n + 1]; - if (fabs(xk2[n] / xk[2 * n]) < 1e-30) xk2[n] = 0.0; - } + if (mu_in[nlev - 1] == mu_in[0]) { + dp mu_val = mu_in[nlev - 1]; + for (int k = 0; k < nlev; k++) + dir[k] = F0_in * mu_val * exp(-tau[k] / mu_val); + for (int i = 0; i < nlay; i++) mu_zm[i] = mu_val; + } else { + 
cum_trans[0] = tau[0] / mu_in[0]; + for (int k = 0; k < nlev - 1; k++) + cum_trans[k + 1] = cum_trans[k] + (tau[k + 1] - tau[k]) / mu_in[k + 1]; + for (int k = 0; k < nlev; k++) + dir[k] = F0_in * mu_in[nlev - 1] * exp(-cum_trans[k]); + for (int i = 0; i < nlay; i++) mu_zm[i] = (mu_in[i] + mu_in[i + 1]) / 2.0; + } - for (int n = 0; n < nlay; ++n) { - flx_up[n] = xk1[n] + gam[n] * xk2[n] + Cpm1[n]; - flx_down[n] = xk1[n] * gam[n] + xk2[n] + Cmm1[n]; - } + for (int i = 0; i < nlay; i++) { + g1[i] = sqrt3d2 * (2.0 - w0[i] * (1.0 + hg[i])); + g2[i] = (sqrt3d2 * w0[i]) * (1.0 - hg[i]); + if (g2[i] == 0.0) g2[i] = 1.0e-10; + g3[i] = (1.0 - sqrt3 * hg[i] * mu_zm[i]) / 2.0; + g4[i] = 1.0 - g3[i]; + lam[i] = sqrt(g1[i] * g1[i] - g2[i] * g2[i]); + gam[i] = (g1[i] - lam[i]) / g2[i]; + denom[i] = (lam[i] * lam[i]) - 1.0 / (mu_zm[i] * mu_zm[i]); + if (denom[i] == 0.0) denom[i] = 1.0e-10; + Ap[i] = F0_in * w0[i] * + (g3[i] * (g1[i] - 1.0 / mu_zm[i]) + g2[i] * g4[i]) / denom[i]; + Am[i] = F0_in * w0[i] * + (g4[i] * (g1[i] + 1.0 / mu_zm[i]) + g2[i] * g3[i]) / denom[i]; + Cpm1[i] = Ap[i] * exp(-tau[i] / mu_zm[i]); + Cmm1[i] = Am[i] * exp(-tau[i] / mu_zm[i]); + Cp[i] = Ap[i] * exp(-tau[i + 1] / mu_zm[i]); + Cm[i] = Am[i] * exp(-tau[i + 1] / mu_zm[i]); + exptrm[i] = fmin(lam[i] * dtau[i], 35.0); + Ep[i] = exp(exptrm[i]); + Em[i] = 1.0 / Ep[i]; + E1[i] = Ep[i] + gam[i] * Em[i]; + E2[i] = Ep[i] - gam[i] * Em[i]; + E3[i] = gam[i] * Ep[i] + Em[i]; + E4[i] = gam[i] * Ep[i] - Em[i]; + } - flx_up[nlev - 1] = xk1[nlay - 1] * Ep[nlay - 1] + - gam[nlay - 1] * xk2[nlay - 1] * Em[nlay - 1] + - Cp[nlay - 1]; - flx_down[nlev - 1] = xk1[nlay - 1] * Ep[nlay - 1] * gam[nlay - 1] + + // Matrix Setup + Af[1] = 0.0; + Bf[1] = gam[0] + 1.0; + Cf[1] = gam[0] - 1.0; + Df[1] = btop - Cmm1[0]; + int n_idx = 0; + for (int i = 2; i <= lm2; i += 2) { + Af[i] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); + Bf[i] = (E2[n_idx] + E4[n_idx]) * (gam[n_idx + 1] - 1.0); + Cf[i] = 2.0 * (1.0 - gam[n_idx + 1] * 
gam[n_idx + 1]); + Df[i] = (gam[n_idx + 1] - 1.0) * (Cpm1[n_idx + 1] - Cp[n_idx]) + + (1.0 - gam[n_idx + 1]) * (Cm[n_idx] - Cmm1[n_idx + 1]); + n_idx++; + } + n_idx = 0; + for (int i = 3; i <= lm1; i += 2) { + Af[i] = 2.0 * (1.0 - gam[n_idx] * gam[n_idx]); + Bf[i] = (E1[n_idx] - E3[n_idx]) * (1.0 + gam[n_idx + 1]); + Cf[i] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); + Df[i] = E3[n_idx] * (Cpm1[n_idx + 1] - Cp[n_idx]) + + E1[n_idx] * (Cm[n_idx] - Cmm1[n_idx + 1]); + n_idx++; + } + Af[l] = E1[nlay - 1] - w_surf_in * E3[nlay - 1]; + Bf[l] = E2[nlay - 1] - w_surf_in * E4[nlay - 1]; + Cf[l] = 0.0; + Df[l] = bsurf - Cp[nlay - 1] + w_surf_in * Cm[nlay - 1]; + + dtridgl(l, Af, Bf, Cf, Df, xk); + + for (int n = 0; n < nlay; n++) { + xk1[n] = xk[2 * n + 1] + xk[2 * n + 2]; + xk2[n] = xk[2 * n + 1] - xk[2 * n + 2]; + if (fabs(xk2[n]) < 1e-30 * fabs(xk[2 * n + 1])) xk2[n] = 0.0; + FLX_UP(n) = xk1[n] + gam[n] * xk2[n] + Cpm1[n]; + FLX_DN(n) = xk1[n] * gam[n] + xk2[n] + Cmm1[n]; + } + FLX_UP(nlev - 1) = xk1[nlay - 1] * Ep[nlay - 1] + + gam[nlay - 1] * xk2[nlay - 1] * Em[nlay - 1] + + Cp[nlay - 1]; + FLX_DN(nlev - 1) = xk1[nlay - 1] * Ep[nlay - 1] * gam[nlay - 1] + xk2[nlay - 1] * Em[nlay - 1] + Cm[nlay - 1]; - - for (int n = 0; n < nlev; ++n) flx_down[n] += dir[n]; + for (int k = 0; k < nlev; k++) FLX_DN(k) += dir[k]; + } } } // namespace harp + +#undef DTAU_IN +#undef W_IN +#undef G_IN +#undef FLX_UP +#undef FLX_DN From bd43fd84f264fd1951619cb514291cb925e7bd7b Mon Sep 17 00:00:00 2001 From: mac/cli Date: Sun, 18 Jan 2026 13:34:44 -0500 Subject: [PATCH 3/9] compiles --- src/radiation/bbflux.cpp | 70 +++++ src/radiation/bbflux.hpp | 4 + src/rtsolver/rt_solver_toon.cpp | 340 --------------------- src/rtsolver/rtsolver_dispatch.cpp | 12 +- src/rtsolver/toon_mckay89.cpp | 7 +- src/rtsolver/toon_mckay89.hpp | 4 +- src/rtsolver/toon_mckay89_longwave_impl.h | 16 +- src/rtsolver/toon_mckay89_shortwave_impl.h | 20 +- src/utils/alloc.h | 126 ++++++++ 9 files changed, 238 
insertions(+), 361 deletions(-) delete mode 100644 src/rtsolver/rt_solver_toon.cpp create mode 100644 src/utils/alloc.h diff --git a/src/radiation/bbflux.cpp b/src/radiation/bbflux.cpp index eeaa9b1..8f5983a 100644 --- a/src/radiation/bbflux.cpp +++ b/src/radiation/bbflux.cpp @@ -89,6 +89,76 @@ torch::Tensor bbflux_wavenumber(double wn1, double wn2, torch::Tensor temp) { return ans * sigdpi * torch::pow(temp, 4); } +torch::Tensor bbflux_wavenumber(torch::Tensor wn1, torch::Tensor wn2, + torch::Tensor temp) { + if ((wn2 < wn1).any().item() || (wn1 < 0.0).any().item()) { + TORCH_CHECK(false, "bbflux_wavenumber: Invalid wavenumbers"); + } + + TORCH_CHECK(temp.min().item() > 0.0, + "bbflux_wavenumber: Temperature must be positive"); + + const double C2 = 1.438786; // h * c / k in units cm * K + const double SIGMA = 5.67032e-8; // Stefan-Boltzmann constant in W/m²K⁴ + const double VCUT = 1.5; + const double sigdpi = SIGMA / M_PI; + const double vmax = std::log(DBL_MAX); + const double conc = 15.0 / std::pow(M_PI, 4); // Now computed at runtime + const double c1 = 1.1911e-18; // h * c^2, in units W/(m² * sr * cm⁻⁴) + const double A1 = 1.0 / 3.0; + const double A2 = -1.0 / 8.0; + const double A3 = 1.0 / 60.0; + const double A4 = -1.0 / 5040.0; + const double A5 = 1.0 / 272160.0; + const double A6 = -1.0 / 13305600.0; + + // Handle the case where wn1 == wn2 + if ((wn1 == wn2).all().item()) { + auto wn = wn1; + auto arg = torch::exp(-C2 * wn / temp); + return c1 * wn.pow(3) * arg / (1.0 - arg); + } + + torch::Tensor v[2] = {C2 * wn1 / temp, C2 * wn2 / temp}; + torch::Tensor smallv = torch::zeros_like(temp); + torch::Tensor p[2]; + torch::Tensor d[2]; + + // Handle different cases for wavenumbers + for (int i = 0; i <= 1; ++i) { + smallv += torch::where(v[i] < VCUT, torch::ones_like(temp), + torch::zeros_like(temp)); + + auto vsq = v[i] * v[i]; + p[i] = + conc * vsq * v[i] * + (A1 + v[i] * (A2 + v[i] * (A3 + vsq * (A4 + vsq * (A5 + vsq * A6))))); + p[i] = torch::where(v[i] < 
VCUT, p[i], torch::zeros_like(temp)); + + // Use exponential series expansion + const double vcp[7] = {10.25, 5.7, 3.9, 2.9, 2.3, 1.9, 0.0}; + + auto ex = torch::exp(-v[i]); + auto exm = torch::ones_like(temp); + d[i] = torch::zeros_like(temp); + + for (int m = 1; m <= 6; ++m) { + auto mv = static_cast(m) * v[i]; + exm *= ex; + d[i] += exm * (6.0 + mv * (6.0 + mv * (3.0 + mv))) / (m * m); + } + d[i] *= conc; + + d[i] = torch::where(v[i] > VCUT, d[i], torch::zeros_like(temp)); + } + + auto ans = + torch::where(smallv == 2, p[1] - p[0], + torch::where(smallv == 1, 1.0 - p[0] - d[1], d[0] - d[1])); + + return ans * sigdpi * torch::pow(temp, 4); +} + torch::Tensor bbflux_wavelength(torch::Tensor wave, double temp, int ncol) { // Check if wave is a 1D tensor TORCH_CHECK(wave.dim() == 1, "wavelength must be a 1D tensor"); diff --git a/src/radiation/bbflux.hpp b/src/radiation/bbflux.hpp index fce2a8d..b15cbcb 100644 --- a/src/radiation/bbflux.hpp +++ b/src/radiation/bbflux.hpp @@ -34,6 +34,10 @@ torch::Tensor bbflux_wavenumber(torch::Tensor wave, double temp, int ncol = 1); */ torch::Tensor bbflux_wavenumber(double wn1, double wn2, torch::Tensor temp); +//! \brief calculate integrated blackbody flux using wavenumber +torch::Tensor bbflux_wavenumber(torch::Tensor wn1, torch::Tensor wn2, + torch::Tensor temp); + //! \brief calculate blackbody flux using wavelength /*! * Formula: diff --git a/src/rtsolver/rt_solver_toon.cpp b/src/rtsolver/rt_solver_toon.cpp deleted file mode 100644 index da6c408..0000000 --- a/src/rtsolver/rt_solver_toon.cpp +++ /dev/null @@ -1,340 +0,0 @@ -// RT solvers based on Toon 1989 method by Xi Zhang -// Reference: Toon, O.B., 1989, JGR, 94, 16287-16301. 
- -#include -#include -#include -#include -#include - -// external -#include - -// athena -#include -#include - -// application -#include -#include - -// climath -#include - -// canoe -#include -#include - -// astro -#include - -// exo3 -#include -#include - -// harp -#include "radiation.hpp" -#include "rt_solvers.hpp" - -#ifdef RT_DISORT - -RadiationBand::RTSolverToon::RTSolverToon(RadiationBand *pmy_band, - YAML::Node const &rad) - : RTSolver(pmy_band, "Toon") { - Application::Logger app("harp"); - app->Log("Toon solver initialized for band " + pmy_band_->GetName()); -} - -//! \todo update based on band outdir -void RadiationBand::RTSolverToon::Resize(int nlyr, int nstr) { - RadiationBand::RTSolver::Resize(nlyr, nstr); - Unseal(); - SetAtmosphereDimension(nlyr, nstr, nstr); - Seal(); -} - -void RadiationBand::RTSolverToon::Prepare(MeshBlock const *pmb, int k, int j) { - auto &wmin = pmy_band_->wrange_.first; - auto &wmax = pmy_band_->wrange_.second; - - Real dist_au = 1.0; - Direction ray = pmb->pimpl->prad->GetRayInput(0); - auto planet = pmb->pimpl->planet; - - if (planet && pmy_band_->TestFlag(RadiationFlags::TimeDependent)) { - Real time = pmb->pmy_mesh->time; - Real lat, lon; - - CubedSphereUtility::get_latlon_on_sphere(&lat, &lon, pmb, k, j, pmb->is); - - ray = planet->ParentZenithAngle(time, lat, lon); - dist_au = planet->ParentDistanceInAu(time); - } else { - if (pmy_band_->HasPar("umu0")) { - ray.mu = pmy_band_->GetPar("umu0"); - } - - if (pmy_band_->HasPar("phi0")) { - ray.phi = pmy_band_->GetPar("phi0"); - } - - if (pmy_band_->HasPar("dist_au")) { - dist_au = pmy_band_->GetPar("dist_au"); - } - } - - // pack temperature - if (pmy_band_->TestFlag(RadiationFlags::ThermalEmission)) { - pmy_band_->packTemperature(); - } - - // pack spectral properties - pmy_band_->packSpectralProperties(); - ds_.bc.umu0 = ray.mu > 1.E-3 ? 
ray.mu : 1.E-3; - - if (pmy_band_->TestFlag(RadiationFlags::BroadBand)) { - // stellar source function overrides fbeam - if (pmy_band_->HasPar("S0")) { - ds_.bc.fbeam = pmy_band_->GetPar("S0"); - } else if (pmy_band_->HasPar("temp0")) { - Real temp0 = pmy_band_->GetPar("temp0"); - ds_.bc.fbeam = Constants::stefanBoltzmann * pow(temp0, 4); - } else if (planet && planet->HasParentFlux()) { - ds_.bc.fbeam = planet->ParentInsolationFlux(wmin, wmax, 1.); - } else { - ds_.bc.fbeam = 0.; - } - ds_.bc.fbeam /= dist_au * dist_au; - } - - pmb->pcoord->Face1Area(k, j, pmb->is, pmb->ie + 1, farea_); - pmb->pcoord->CellVolume(k, j, pmb->is, pmb->ie, vol_); -} - -void RadiationBand::RTSolverToon::CalBandFlux(MeshBlock const *pmb, int k, - int j) { - Real dist_au = 1.0; - auto planet = pmb->pimpl->planet; - - if (planet && pmy_band_->TestFlag(RadiationFlags::TimeDependent)) { - dist_au = planet->ParentDistanceInAu(pmb->pmy_mesh->time); - } else if (pmy_band_->HasPar("dist_au")) { - dist_au = pmy_band_->GetPar("dist_au"); - } - - // loop over spectral grids in the band - bool override_with_stellar_spectra = false; - if (!pmy_band_->TestFlag(RadiationFlags::BroadBand) && - !pmy_band_->HasPar("S0") && !pmy_band_->HasPar("temp0") && planet && - planet->HasParentFlux()) { - override_with_stellar_spectra = true; - } - - pmy_band_->pexv->GatherAll(pmb); - if (pmy_band_->TestFlag(RadiationFlags::ThermalEmission)) { - pmy_band_->unpackTemperature(&ds_); - } - - int b = 0; - for (auto &spec : pmy_band_->pgrid_->spec) { - if (override_with_stellar_spectra) { - // stellar source function - ds_.bc.fbeam = - planet->ParentInsolationFlux(spec.wav1, spec.wav2, dist_au); - } - - // Transfer spectral grid data - pmy_band_->unpackSpectralProperties(b, &ds_); - - // add spectral bin flux - addToonFlux(pmb->pcoord, b++, k, j, pmb->is, pmb->ie + 1, flux_up, - flux_down); - } -} - -void RadiationBand::RTSolverToon::addToonFlux( - Coordinates const *pcoord, int b, int k, int j, int il, int iu, - const 
Eigen::VectorXd &flux_up, const Eigen::VectorXd &flux_down) { - auto &bflxup = pmy_band_->bflxup; - auto &bflxdn = pmy_band_->bflxdn; - - auto &flxup = pmy_band_->flxup_; - auto &flxdn = pmy_band_->flxdn_; - auto const &spec = pmy_band_->pgrid_->spec; - - int rank_in_column = pmy_band_->pexv->GetRankInGroup(); - - // Accumulate flux from spectral bins - for (int i = il; i <= iu; ++i) { - int m = ds_.nlyr - (rank_in_column * (iu - il) + i - il); - // Flux up - flxup(b, k, j, i) = flux_up(m); - // Flux down - flxdn(b, k, j, i) = flux_down(m); - - bflxup(k, j, i) += spec[b].wght * flxup(b, k, j, i); - bflxdn(k, j, i) += spec[b].wght * flxdn(b, k, j, i); - } - - // Spherical correction - Real volh; - Real bflxup_iu = bflxup(k, j, iu); - Real bflxdn_iu = bflxdn(k, j, iu); - - for (int i = iu - 1; i >= il; --i) { - // Upward - volh = (bflxup_iu - bflxup(k, j, i)) / pcoord->dx1f(i) * vol_(i); - bflxup_iu = bflxup(k, j, i); - bflxup(k, j, i) = (bflxup(k, j, i + 1) * farea_(i + 1) - volh) / farea_(i); - - // Downward - volh = (bflxdn_iu - bflxdn(k, j, i)) / pcoord->dx1f(i) * vol_(i); - bflxdn_iu = bflxdn(k, j, i); - bflxdn(k, j, i) = (bflxdn(k, j, i + 1) * farea_(i + 1) - volh) / farea_(i); - } - - /* - for (int i = iu; i >= il; --i) { - std::cout << "i: " << iu-i+1 <<" flxup: " << bflxup(k, j, i) << " flxdn: " - << bflxdn(k, j, i) << " fluxdiff: " << bflxup(k, j, i) - bflxdn(k, j, i) << - std::endl; - } -*/ -} - -// Inegrate Planck function over a band, based on cdisort -double RadiationBand::RTSolverToon::BB_integrate(double T, double wn1, - double wn2) { - if (T < 1e-4 || wn2 < wn1 || wn1 < 0.0) { - throw std::invalid_argument( - "BB_integrate: Invalid temperature or wavenumbers"); - } - - constexpr double C2 = 1.438786; // h * c / k in units cm * K - constexpr double SIGMA = 5.67032e-8; // Stefan-Boltzmann constant in W/m²K⁴ - constexpr double VCUT = 1.5; - constexpr double sigdpi = SIGMA / M_PI; - const double vmax = std::log(DBL_MAX); - const double conc = 15.0 / 
std::pow(M_PI, 4); // Now computed at runtime - constexpr double c1 = 1.1911e-18; // h * c^2, in units W/(m² * sr * cm⁻⁴) - constexpr double A1 = 1.0 / 3.0; - constexpr double A2 = -1.0 / 8.0; - constexpr double A3 = 1.0 / 60.0; - constexpr double A4 = -1.0 / 5040.0; - constexpr double A5 = 1.0 / 272160.0; - constexpr double A6 = -1.0 / 13305600.0; - // Helper function to compute Planck integrand value - auto planck_function = [](double v) { - return std::pow(v, 3) / (std::exp(v) - 1.0); - }; - - // Handle the case where wn1 == wn2 - if (wn1 == wn2) { - double wn = wn1; - double arg = std::exp(-C2 * wn / T); - return c1 * std::pow(wn, 3) * arg / (1.0 - arg); - } - - double v[2] = {C2 * wn1 / T, C2 * wn2 / T}; - double p[2] = {0.0, 0.0}, d[2] = {0.0, 0.0}; - int smallv = 0; - - // Handle different cases for wavenumbers - for (int i = 0; i <= 1; ++i) { - if (v[i] < VCUT) { - // Use power series expansion - smallv++; - double vsq = v[i] * v[i]; - p[i] = - conc * vsq * v[i] * - (A1 + v[i] * (A2 + v[i] * (A3 + vsq * (A4 + vsq * (A5 + vsq * A6))))); - } else { - // Use exponential series expansion - int mmax = 1; - static const double vcp[7] = {10.25, 5.7, 3.9, 2.9, 2.3, 1.9, 0.0}; - while (v[i] < vcp[mmax - 1] && mmax < 7) { - ++mmax; - } - - double ex = std::exp(-v[i]); - double exm = 1.0; - d[i] = 0.0; - - for (int m = 1; m <= mmax; ++m) { - double mv = static_cast(m) * v[i]; - exm *= ex; - d[i] += exm * (6.0 + mv * (6.0 + mv * (3.0 + mv))) / (m * m); - } - d[i] *= conc; - } - } - - double ans; - if (smallv == 2) { - // Both wavenumbers are small - ans = p[1] - p[0]; - } else if (smallv == 1) { - // One wavenumber is small, the other is large - ans = 1.0 - p[0] - d[1]; - } else { - // Both wavenumbers are large - ans = d[0] - d[1]; - } - - ans *= sigdpi * T * T * T * T; - - if (ans == 0.0) { - std::cerr << "BB_integrate: Warning - result is zero; possible underflow" - << std::endl; - } - - return ans; -} - -// Tridiagonal Solver using the Thomas Algorithm -inline 
Eigen::VectorXd RadiationBand::RTSolverToon::tridiagonal_solver( - const Eigen::VectorXd &a, const Eigen::VectorXd &b, - const Eigen::VectorXd &c, const Eigen::VectorXd &d) { - int l = b.size(); - if (a.size() != static_cast(l - 1) || - c.size() != static_cast(l - 1) || - d.size() != static_cast(l)) { - throw std::invalid_argument( - "Incorrect vector sizes for tridiagonal_solver."); - } - - Eigen::VectorXd c_prime(l - 1); - Eigen::VectorXd d_prime(l); - - // Forward sweep - c_prime(0) = c(0) / b(0); - d_prime(0) = d(0) / b(0); - for (int i = 1; i < l - 1; ++i) { - double denom = b(i) - a(i - 1) * c_prime(i - 1); - if (std::abs(denom) < 1e-12) { - throw std::runtime_error( - "Tridiagonal solver failed: near-zero denominator."); - } - c_prime(i) = c(i) / denom; - d_prime(i) = (d(i) - a(i - 1) * d_prime(i - 1)) / denom; - } - - // Last equation - double denom_last = b(l - 1) - a(l - 2) * c_prime(l - 2); - if (std::abs(denom_last) < 1e-12) { - throw std::runtime_error( - "Tridiagonal solver failed: near-zero denominator at last equation."); - } - d_prime(l - 1) = (d(l - 1) - a(l - 2) * d_prime(l - 2)) / denom_last; - - // Back substitution - Eigen::VectorXd x(l); - x(l - 1) = d_prime(l - 1); - for (int i = l - 2; i >= 0; --i) { - x(i) = d_prime(i) - c_prime(i) * x(i + 1); - } - - return x; -} - -#endif diff --git a/src/rtsolver/rtsolver_dispatch.cpp b/src/rtsolver/rtsolver_dispatch.cpp index b1ba855..7ce13bd 100644 --- a/src/rtsolver/rtsolver_dispatch.cpp +++ b/src/rtsolver/rtsolver_dispatch.cpp @@ -6,6 +6,7 @@ #include // harp +#include "rtsolver_dispatch.hpp" #include "toon_mckay89_longwave_impl.h" #include "toon_mckay89_shortwave_impl.h" @@ -15,6 +16,8 @@ void call_toon89_sw_cpu(at::TensorIterator &iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_sw_cpu", [&] { int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); int grain_size = iter.numel() / at::get_num_threads(); + int mem_size = toon89_sw_space(nlay); + char *work = new char[mem_size]; 
iter.for_each( [&](char **data, const int64_t *strides, int64_t n) { @@ -25,10 +28,13 @@ void call_toon89_sw_cpu(at::TensorIterator &iter) { auto fbeam = reinterpret_cast(data[3] + i * strides[3]); auto albedo = reinterpret_cast(data[4] + i * strides[4]); - toon_mckay89_shortwave(nlay, *fbeam, umu0, prop, *albedo, out, work) + toon_mckay89_shortwave(nlay, *fbeam, umu0, prop, *albedo, out, + work); } }, grain_size); + + delete[] work; }); } @@ -36,6 +42,8 @@ void call_toon89_lw_cpu(at::TensorIterator &iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_lw_cpu", [&] { int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); int grain_size = iter.numel() / at::get_num_threads(); + int mem_size = toon89_sw_space(nlay); + char *work = new char[mem_size]; iter.for_each( [&](char **data, const int64_t *strides, int64_t n) { @@ -49,6 +57,8 @@ void call_toon89_lw_cpu(at::TensorIterator &iter) { } }, grain_size); + + delete[] work; }); } diff --git a/src/rtsolver/toon_mckay89.cpp b/src/rtsolver/toon_mckay89.cpp index 368d378..c9e03f6 100644 --- a/src/rtsolver/toon_mckay89.cpp +++ b/src/rtsolver/toon_mckay89.cpp @@ -1,7 +1,7 @@ // harp #include "toon_mckay89.hpp" -#include +#include #include "rtsolver_dispatch.hpp" @@ -18,6 +18,7 @@ void ToonMcKay89Impl::reset() { torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, std::map* bc, + std::string bname, torch::optional temf) { // check dimensions TORCH_CHECK(prop.dim() == 4, "ToonMcKay89::forward: prop.dim() != 4"); @@ -101,7 +102,9 @@ torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, temp(i) = ds_.temper[i]; be(i) = BB_integrate(ds_.temper[i], spec.wav1, spec.wav2); }*/ - auto be = bbflux_wavenumber(wave, tempf.value()); + auto wave_lo = torch::tensor(options->wave_lower(), prop.options()); + auto wave_hi = torch::tensor(options->wave_upper(), prop.options()); + auto be = bbflux_wavenumber(wave_lo, wave_hi, temf.value()); auto iter = at::TensorIteratorConfig() .resize_outputs(false) 
.check_all_same_dtype(true) diff --git a/src/rtsolver/toon_mckay89.hpp b/src/rtsolver/toon_mckay89.hpp index cd13c3e..8b9756b 100644 --- a/src/rtsolver/toon_mckay89.hpp +++ b/src/rtsolver/toon_mckay89.hpp @@ -13,7 +13,7 @@ namespace harp { struct ToonMcKay89OptionsImpl { - ToonMcKay89Options(); + ToonMcKay89OptionsImpl() {} static std::shared_ptr create() { return std::make_shared(); } @@ -40,7 +40,7 @@ struct ToonMcKay89OptionsImpl { ADD_ARG(bool, zenith_correction) = false; }; -using ToonMcKay89Options = std::shared_ptr; +using ToonMcKay89Options = std::shared_ptr; class ToonMcKay89Impl : public torch::nn::Cloneable { public: diff --git a/src/rtsolver/toon_mckay89_longwave_impl.h b/src/rtsolver/toon_mckay89_longwave_impl.h index bfc786d..77f4316 100644 --- a/src/rtsolver/toon_mckay89_longwave_impl.h +++ b/src/rtsolver/toon_mckay89_longwave_impl.h @@ -9,6 +9,8 @@ #include // harp +#include + #include "dtridgl_impl.h" #define DTAU_IN(i) prop[(nlay - i - 1) * 3] @@ -32,14 +34,14 @@ DISPATCH_MACRO void toon_mckay89_longwave(int nlay, const T *be, const T *prop, const T twopi = 2.0 * M_PI; const T ubari = 0.5; - const T *uarr = {0.0985350858, 0.3045357266, 0.5620251898, 0.8019865821, - 0.9601901429}; - const T *wuarr = {0.0157479145, 0.0739088701, 0.1463869871, 0.1671746381, - 0.0967815902}; + const T uarr[] = {0.0985350858, 0.3045357266, 0.5620251898, 0.8019865821, + 0.9601901429}; + const T wuarr[] = {0.0157479145, 0.0739088701, 0.1463869871, 0.1671746381, + 0.0967815902}; // --- Work Variables Allocation --- - dtau = alloc_from(work, nlay); - tau = alloc_from(work, nlev); + T *dtau = alloc_from(work, nlay); + T *tau = alloc_from(work, nlev); T *w0 = alloc_from(work, nlay); T *hg = alloc_from(work, nlay); T *B0 = alloc_from(work, nlay); @@ -173,7 +175,7 @@ DISPATCH_MACRO void toon_mckay89_longwave(int nlay, const T *be, const T *prop, g[n] = 0.0; h[n] = 0.0; xj[n] = 0.0; - xk_ptr[n] = 0.0; // using xk_ptr because xk name conflict + xk[n] = 0.0; alpha1[n] = twopi * 
B0[n]; alpha2[n] = twopi * B1[n]; sigma1[n] = alpha1[n]; diff --git a/src/rtsolver/toon_mckay89_shortwave_impl.h b/src/rtsolver/toon_mckay89_shortwave_impl.h index fca8897..58048b9 100644 --- a/src/rtsolver/toon_mckay89_shortwave_impl.h +++ b/src/rtsolver/toon_mckay89_shortwave_impl.h @@ -7,6 +7,8 @@ #include // harp +#include + #include "dtridgl_impl.h" #define DTAU_IN(i) prop[(nlay - i - 1) * 3] @@ -63,14 +65,14 @@ DISPATCH_MACRO void toon_mckay89_shortwave(int nlay, T F0_in, T const *mu_in, T *xk1 = alloc_from(work, nlay); T *xk2 = alloc_from(work, nlay); - const dp sqrt3 = sqrt(3.0); - const dp sqrt3d2 = sqrt3 / 2.0; - const dp btop = 0.0, bsurf = 0.0; + const T sqrt3 = sqrt(3.0); + const T sqrt3d2 = sqrt3 / 2.0; + const T btop = 0.0, bsurf = 0.0; // Check for zero albedo bool all_zero_w = true; for (int i = 0; i < nlay; i++) { - if (w_in[i] > 1.0e-12) { + if (W_IN(i) > 1.0e-12) { all_zero_w = false; break; } @@ -104,17 +106,17 @@ DISPATCH_MACRO void toon_mckay89_shortwave(int nlay, T F0_in, T const *mu_in, // --- General Case: Toon et al. 
1989 Solver --- for (int i = 0; i < nlay; i++) { - dp g_sq = g_in[i] * g_in[i]; - w0[i] = ((1.0 - g_sq) * w_in[i]) / (1.0 - w_in[i] * g_sq); - dtau[i] = (1.0 - w_in[i] * g_sq) * DTAU_IN(i); - hg[i] = g_in[i] / (1.0 + g_in[i]); + T g_sq = G_IN(i) * G_IN(i); + w0[i] = ((1.0 - g_sq) * W_IN(i)) / (1.0 - W_IN(i) * g_sq); + dtau[i] = (1.0 - W_IN(i) * g_sq) * DTAU_IN(i); + hg[i] = G_IN(i) / (1.0 + G_IN(i)); } tau[0] = 0.0; for (int k = 0; k < nlay; k++) tau[k + 1] = tau[k] + dtau[k]; if (mu_in[nlev - 1] == mu_in[0]) { - dp mu_val = mu_in[nlev - 1]; + T mu_val = mu_in[nlev - 1]; for (int k = 0; k < nlev; k++) dir[k] = F0_in * mu_val * exp(-tau[k] / mu_val); for (int i = 0; i < nlay; i++) mu_zm[i] = mu_val; diff --git a/src/utils/alloc.h b/src/utils/alloc.h new file mode 100644 index 0000000..9f74482 --- /dev/null +++ b/src/utils/alloc.h @@ -0,0 +1,126 @@ +#pragma once + +// C/C++ +#include +#include +#include + +// base +#include + +namespace harp { + +DISPATCH_MACRO inline uintptr_t align_up(uintptr_t p, size_t a) { + // a must be power of two; works for 4, 8, 16, ... 
+ return (p + (a - 1)) & ~(a - 1); +} + +template +DISPATCH_MACRO inline U* alloc_from(char*& cursor, size_t count) { + uintptr_t p = reinterpret_cast(cursor); + p = align_up(p, alignof(U)); + U* out = reinterpret_cast(p); + cursor = reinterpret_cast(p + count * sizeof(U)); + return out; +} + +template +size_t toon89_sw_space(int nlay) { + size_t bytes = 0; + auto bump = [&](size_t align, size_t nbytes) { + bytes = static_cast(align_up(bytes, align)) + nbytes; + }; + + int nlev = nlay + 1; + bump(alignof(T), nlev * sizeof(T)); // dir + bump(alignof(T), nlev * sizeof(T)); // tau + bump(alignof(T), nlev * sizeof(T)); // cum_trans + bump(alignof(T), nlev * sizeof(T)); // tau_in + bump(alignof(T), nlay * sizeof(T)); // dtau + bump(alignof(T), nlay * sizeof(T)); // mu_zm + bump(alignof(T), nlay * sizeof(T)); // w0 + bump(alignof(T), nlay * sizeof(T)); // hg + bump(alignof(T), nlay * sizeof(T)); // g1 + bump(alignof(T), nlay * sizeof(T)); // g2 + bump(alignof(T), nlay * sizeof(T)); // g3 + bump(alignof(T), nlay * sizeof(T)); // g4 + bump(alignof(T), nlay * sizeof(T)); // lam + bump(alignof(T), nlay * sizeof(T)); // gam + bump(alignof(T), nlay * sizeof(T)); // denom + bump(alignof(T), nlay * sizeof(T)); // Am + bump(alignof(T), nlay * sizeof(T)); // Ap + bump(alignof(T), nlay * sizeof(T)); // Cpm1 + bump(alignof(T), nlay * sizeof(T)); // Cmm1 + bump(alignof(T), nlay * sizeof(T)); // Cp + bump(alignof(T), nlay * sizeof(T)); // Cm + bump(alignof(T), nlay * sizeof(T)); // exptrm + bump(alignof(T), nlay * sizeof(T)); // Ep + bump(alignof(T), nlay * sizeof(T)); // Em + bump(alignof(T), nlay * sizeof(T)); // E1 + bump(alignof(T), nlay * sizeof(T)); // E2 + bump(alignof(T), nlay * sizeof(T)); // E3 + bump(alignof(T), nlay * sizeof(T)); // E4 + bump(alignof(T), (2 * nlay) * sizeof(T)); // Af + bump(alignof(T), (2 * nlay) * sizeof(T)); // Bf + bump(alignof(T), (2 * nlay) * sizeof(T)); // Cf + bump(alignof(T), (2 * nlay) * sizeof(T)); // Df + bump(alignof(T), (2 * nlay) * 
sizeof(T)); // xk + bump(alignof(T), nlay * sizeof(T)); // xk1 + bump(alignof(T), nlay * sizeof(T)); // xk2 + + return bytes; +} + +template +size_t toon89_lw_space(int nlay) { + size_t bytes = 0; + auto bump = [&](size_t align, size_t nbytes) { + bytes = static_cast(align_up(bytes, align)) + nbytes; + }; + int nlev = nlay + 1; + + bump(alignof(T), nlay * sizeof(T)); // dtau + bump(alignof(T), nlev * sizeof(T)); // tau + bump(alignof(T), nlay * sizeof(T)); // w0 + bump(alignof(T), nlay * sizeof(T)); // hg + bump(alignof(T), nlay * sizeof(T)); // B0 + bump(alignof(T), nlay * sizeof(T)); // B1 + bump(alignof(T), nlay * sizeof(T)); // lam + bump(alignof(T), nlay * sizeof(T)); // gam + bump(alignof(T), nlay * sizeof(T)); // alp + bump(alignof(T), nlay * sizeof(T)); // term + bump(alignof(T), nlay * sizeof(T)); // Cpm1 + bump(alignof(T), nlay * sizeof(T)); // Cmm1 + bump(alignof(T), nlay * sizeof(T)); // Cp + bump(alignof(T), nlay * sizeof(T)); // Cm + bump(alignof(T), nlay * sizeof(T)); // exptrm + bump(alignof(T), nlay * sizeof(T)); // Ep + bump(alignof(T), nlay * sizeof(T)); // Em + bump(alignof(T), nlay * sizeof(T)); // E1 + bump(alignof(T), nlay * sizeof(T)); // E2 + bump(alignof(T), nlay * sizeof(T)); // E3 + bump(alignof(T), nlay * sizeof(T)); // E4 + bump(alignof(T), (2 * nlay) * sizeof(T)); // Af + bump(alignof(T), (2 * nlay) * sizeof(T)); // Bf + bump(alignof(T), (2 * nlay) * sizeof(T)); // Cf + bump(alignof(T), (2 * nlay) * sizeof(T)); // Df + bump(alignof(T), (2 * nlay) * sizeof(T)); // xkk + bump(alignof(T), nlay * sizeof(T)); // xk1 + bump(alignof(T), nlay * sizeof(T)); // xk2 + bump(alignof(T), nlay * sizeof(T)); // g + bump(alignof(T), nlay * sizeof(T)); // h + bump(alignof(T), nlay * sizeof(T)); // xj + bump(alignof(T), nlay * sizeof(T)); // xk + bump(alignof(T), nlay * sizeof(T)); // alpha1 + bump(alignof(T), nlay * sizeof(T)); // alpha2 + bump(alignof(T), nlay * sizeof(T)); // sigma1 + bump(alignof(T), nlay * sizeof(T)); // sigma2 + bump(alignof(T), 
nlay * sizeof(T)); // em1 + bump(alignof(T), nlay * sizeof(T)); // em2 + bump(alignof(T), nlay * sizeof(T)); // em3 + bump(alignof(T), nlev * sizeof(T)); // lw_up_g + bump(alignof(T), nlev * sizeof(T)); // lw_down_g + + return bytes; +} +} // namespace harp From dbd2f2e77c1bfd95f2f74519729b3bb5c106dca2 Mon Sep 17 00:00:00 2001 From: mac/cli Date: Sun, 18 Jan 2026 13:46:33 -0500 Subject: [PATCH 4/9] wip --- src/loops.cuh | 101 ++++++++++++++++++++++++++++++ src/rtsolver/rtsolver_dispatch.cu | 63 +++++++++++-------- 2 files changed, 138 insertions(+), 26 deletions(-) create mode 100644 src/loops.cuh diff --git a/src/loops.cuh b/src/loops.cuh new file mode 100644 index 0000000..81a34ba --- /dev/null +++ b/src/loops.cuh @@ -0,0 +1,101 @@ +#pragma once + +// torch +#include +#include + +namespace harp { +namespace native { + +template +__global__ void element_kernel(int64_t numel, func_t f) { + int tid = threadIdx.x; + int idx = blockIdx.x * blockDim.x + tid; + + // Shared memory allocation + extern __shared__ unsigned char memory[]; + char* smem = reinterpret_cast(memory); + + if (idx < numel) { + f(idx, smem); + } +} + +template +void gpu_kernel(at::TensorIterator& iter, const func_t& f) { + TORCH_CHECK(iter.ninputs() + iter.noutputs() == Arity); + + std::array data; + for (int i = 0; i < Arity; i++) { + data[i] = reinterpret_cast(iter.data_ptr(i)); + } + + auto offset_calc = ::make_offset_calculator(iter); + int64_t numel = iter.numel(); + + at::native::launch_legacy_kernel<128, 1>(numel, + [=] __device__(int idx) { + auto offsets = offset_calc.get(idx); + f(data.data(), offsets.data()); + }); +} + +template +void gpu_mem_kernel(at::TensorIterator& iter, int work_size, const func_t& f) { + TORCH_CHECK(iter.ninputs() + iter.noutputs() == Arity); + + std::array data; + for (int i = 0; i < Arity; i++) { + data[i] = reinterpret_cast(iter.data_ptr(i)); + } + + auto offset_calc = ::make_offset_calculator(iter); + int64_t numel = iter.numel(); + + dim3 block(Threads); + 
dim3 grid((numel + block.x - 1) / block.x); + auto stream = at::cuda::getCurrentCUDAStream(); + size_t shared = block.x * work_size; + + // set attribute to allow max dynamic shared memory + int device; + cudaGetDevice(&device); + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, device); + + // query max allowed per-block shared memory + int max_dynamic_smem = prop.sharedMemPerBlockOptin; + //printf("max_dynamic_smem = %d\n", max_dynamic_smem); + + auto device_lambda = [=] __device__(int idx, char* smem) { + auto offsets = offset_calc.get(idx); + int tid = threadIdx.x; + f(data.data(), offsets.data(), smem + tid * work_size); + }; + + // request the full size + auto kernelPtr = element_kernel; + cudaFuncSetAttribute( + kernelPtr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + max_dynamic_smem); + + if (shared > (size_t)max_dynamic_smem) { + TORCH_CHECK(false, "Requested shared memory (", shared, + " bytes) exceeds device maximum (", + max_dynamic_smem, " bytes)."); + } + + /*std::cout << "block = " << block.x + << ", grid = " << grid.x + << ", shared = " << shared + << ", work_size = " << work_size + << std::endl;*/ + + element_kernel<<>>(numel, device_lambda); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace native +} // namespace harp diff --git a/src/rtsolver/rtsolver_dispatch.cu b/src/rtsolver/rtsolver_dispatch.cu index ecf7d64..cd5b0b1 100644 --- a/src/rtsolver/rtsolver_dispatch.cu +++ b/src/rtsolver/rtsolver_dispatch.cu @@ -6,46 +6,57 @@ #include // harp -#include -#include "disort_dispatch.hpp" -#include "disort_impl.h" +#include +#include "rtsolver_dispatch.hpp" +#include "toon_mckay89_longwave_impl.h" +#include "toon_mckay89_shortwave_impl.h" -namespace disort { +namespace harp { -void call_toon89_lw_cuda(at::TensorIterator& iter, int rank_in_column, - disort_state *ds, disort_output *ds_out) { +void call_toon89_sw_cuda(at::TensorIterator& iter) { at::cuda::CUDAGuard device_guard(iter.device()); - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), 
"call_disort_cuda", [&] { - auto nprop = at::native::ensure_nonempty_size(iter.output(), -1); + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_sw_cuda", [&] { + int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); + int mem_size = toon89_sw_space(nlay); - native::gpu_kernel<12>( - iter, [=] GPU_LAMBDA(char* const data[12], unsigned int strides[12]) { + native::gpu_mem_kernel<32, 5>( + iter, [=] GPU_LAMBDA( + char* const data[5], unsigned int strides[5], char *work) { auto out = reinterpret_cast(data[0] + strides[0]); auto prop = reinterpret_cast(data[1] + strides[1]); auto umu0 = reinterpret_cast(data[2] + strides[2]); - auto phi0 = reinterpret_cast(data[3] + strides[3]); - auto fbeam = reinterpret_cast(data[4] + strides[4]); - auto albedo = reinterpret_cast(data[5] + strides[5]); - auto fluor = reinterpret_cast(data[6] + strides[6]); - auto fisot = reinterpret_cast(data[7] + strides[7]); - auto temis = reinterpret_cast(data[8] + strides[8]); - auto btemp = reinterpret_cast(data[9] + strides[9]); - auto ttemp = reinterpret_cast(data[10] + strides[10]); - auto temf = reinterpret_cast(data[11] + strides[11]); - auto idxf = reinterpret_cast(data[12] + strides[12]); - int idx = static_cast(*idxf); - // disort_impl(out, prop, ftoa, temf, rank_in_column, ds[*idx], - // ds_out[*idx], nprop); + auto fbeam = reinterpret_cast(data[3] + strides[3]); + auto albedo = reinterpret_cast(data[4] + strides[4]); + toon_mckay89_shortwave(nlay, *fbeam, umu0, prop, *albedo, out, work); }); }); } -} // namespace disort +void call_toon89_lw_cuda(at::TensorIterator& iter) { + at::cuda::CUDAGuard device_guard(iter.device()); + + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_lw_cuda", [&] { + int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); + int mem_size = toon89_sw_space(nlay); + + native::gpu_mem_kernel<32, 4>( + iter, [=] GPU_LAMBDA( + char* const data[4], unsigned int strides[4], char *work) { + auto out = reinterpret_cast(data[0] + strides[0]); 
+ auto prop = reinterpret_cast(data[1] + strides[1]); + auto albedo = reinterpret_cast(data[2] + strides[2]); + auto be = reinterpret_cast(data[3] + strides[3]); + toon_mckay89_longwave(nlay, be, prop, *albedo, out, work); + }); + }); +} + +} // namespace harp namespace at::native { -REGISTER_CUDA_DISPATCH(call_toon89_lw, &disort::call_toon89_lw_cuda); -REGISTER_CUDA_DISPATCH(call_toon89_sw, &disort::call_toon89_sw_cuda); +REGISTER_CUDA_DISPATCH(call_toon89_lw, &harp::call_toon89_lw_cuda); +REGISTER_CUDA_DISPATCH(call_toon89_sw, &harp::call_toon89_sw_cuda); } // namespace at::native From aa840a97d57c1f62eb4c344c8988032ea4486b46 Mon Sep 17 00:00:00 2001 From: mac/cli Date: Sun, 18 Jan 2026 15:50:11 -0500 Subject: [PATCH 5/9] wip --- src/radiation/radiation_band.cpp | 6 ++++++ src/radiation/radiation_band.hpp | 2 ++ src/rtsolver/toon_mckay89.hpp | 1 + 3 files changed, 9 insertions(+) diff --git a/src/radiation/radiation_band.cpp b/src/radiation/radiation_band.cpp index 0bf98c5..c9a01bf 100644 --- a/src/radiation/radiation_band.cpp +++ b/src/radiation/radiation_band.cpp @@ -107,6 +107,10 @@ RadiationBandOptions RadiationBandOptionsImpl::from_yaml( if (op->verbose()) { std::cout << " Solver flags: " << op->disort()->flags() << std::endl; } + } else if (op->solver_name() == "toon") { + op->toon() = ToonMcKay89OptionsImpl::create(); + op->toon()->wave_lower(std::vector(op->nwave(), wmin)); + op->toon()->wave_upper(std::vector(op->nwave(), wmax)); } else if (op->solver_name() == "twostr") { TORCH_CHECK(false, "twostr solver not implemented"); } else { @@ -173,6 +177,8 @@ void RadiationBandImpl::reset() { } else if (options->solver_name() == "disort") { rtsolver = torch::nn::AnyModule(disort::Disort(options->disort())); register_module("solver", rtsolver.ptr()); + } else if (options->solver_name() == "toon") { + rtsolver = torch::nn::AnyModule(ToonMcKay89(options->toon())); } else { TORCH_CHECK(false, "Unknown solver: ", options->solver_name()); } diff --git 
a/src/radiation/radiation_band.hpp b/src/radiation/radiation_band.hpp index a07da23..e219174 100644 --- a/src/radiation/radiation_band.hpp +++ b/src/radiation/radiation_band.hpp @@ -11,6 +11,7 @@ // harp #include +#include // arg #include @@ -103,6 +104,7 @@ struct RadiationBandOptionsImpl { ADD_ARG(OpacityDict, opacities) = {}; ADD_ARG(disort::DisortOptions, disort); + ADD_ARG(ToonMcKay89Options, toon); ADD_ARG(std::vector, wavenumber); ADD_ARG(std::vector, weight); diff --git a/src/rtsolver/toon_mckay89.hpp b/src/rtsolver/toon_mckay89.hpp index 8b9756b..e128b5b 100644 --- a/src/rtsolver/toon_mckay89.hpp +++ b/src/rtsolver/toon_mckay89.hpp @@ -67,6 +67,7 @@ class ToonMcKay89Impl : public torch::nn::Cloneable { std::string bname = "", torch::optional temf = torch::nullopt); }; +TORCH_MODULE(ToonMcKay89); } // namespace harp From f591293b2db0172c43c9f9f46e0bf02f0e250fa8 Mon Sep 17 00:00:00 2001 From: mac/cli Date: Sun, 18 Jan 2026 16:31:42 -0500 Subject: [PATCH 6/9] wip --- python/csrc/pyharp.cpp | 2 ++ python/csrc/pyradiation.cpp | 2 ++ python/csrc/pyrtsolver.cpp | 65 +++++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+) create mode 100644 python/csrc/pyrtsolver.cpp diff --git a/python/csrc/pyharp.cpp b/python/csrc/pyharp.cpp index 9f2258a..8cae320 100644 --- a/python/csrc/pyharp.cpp +++ b/python/csrc/pyharp.cpp @@ -15,6 +15,7 @@ void bind_opacity(py::module &m); void bind_math(py::module &m); void bind_constants(py::module &m); void bind_integrator(py::module &); +void bind_rtsolver(py::module &); PYBIND11_MODULE(pyharp, m) { m.attr("__name__") = "pyharp"; @@ -31,6 +32,7 @@ PYBIND11_MODULE(pyharp, m) { bind_math(m); bind_constants(m); bind_integrator(m); + bind_rtsolver(m); m.def( "species_names", diff --git a/python/csrc/pyradiation.cpp b/python/csrc/pyradiation.cpp index 615b206..38da0c3 100644 --- a/python/csrc/pyradiation.cpp +++ b/python/csrc/pyradiation.cpp @@ -43,6 +43,8 @@ void bind_radiation(py::module &m) { .ADD_OPTION(std::string, 
harp::RadiationBandOptionsImpl, outdirs) .ADD_OPTION(std::string, harp::RadiationBandOptionsImpl, solver_name) .ADD_OPTION(disort::DisortOptions, harp::RadiationBandOptionsImpl, disort) + .ADD_OPTION(harp::ToonMcKay89Options, harp::RadiationBandOptionsImpl, + toon) .ADD_OPTION(std::vector, harp::RadiationBandOptionsImpl, wavenumber) .ADD_OPTION(std::vector, harp::RadiationBandOptionsImpl, weight) diff --git a/python/csrc/pyrtsolver.cpp b/python/csrc/pyrtsolver.cpp new file mode 100644 index 0000000..9f0a933 --- /dev/null +++ b/python/csrc/pyrtsolver.cpp @@ -0,0 +1,65 @@ +// torch +#include + +// fmt +#include + +// harp +#include + +// python +#include "pyoptions.hpp" + +namespace py = pybind11; + +void bind_rtsolver(py::module &m) { + auto pyToonMcKay89Options = + py::class_( + m, "ToonMcKay89Options"); + + pyToonMcKay89Options.def(py::init<>()) + .def("__repr__", + [](const harp::ToonMcKay89Options &a) { + std::stringstream ss; + a->report(ss); + return fmt::format("ToonMcKay89Options(\n{})", ss.str()); + }) + .ADD_OPTION(std::vector, harp::ToonMcKay89OptionsImpl, wave_lower) + .ADD_OPTION(std::vector, harp::ToonMcKay89OptionsImpl, wave_upper) + .ADD_OPTION(bool, harp::ToonMcKay89OptionsImpl, zenith_correction); + + torch::python::bind_module(m, "ToonMcKay89") + .def(py::init<>()) + .def(py::init(), py::arg("options")) + .def_readonly("options", &harp::ToonMcKay89Impl::options) + .def( + "forward", + [](harp::ToonMcKay89Impl &self, torch::Tensor prop, std::string bname, + torch::optional temf, const py::kwargs &kwargs) { + // get bc from kwargs + std::map bc; + for (auto item : kwargs) { + auto key = py::cast(item.first); + auto value = py::cast(item.second); + bc.emplace(std::move(key), std::move(value)); + } + + for (auto &[key, value] : bc) { + std::vector items = {"fbeam", "albedo", "umu0"}; + // broadcast dimensions to (nwave, ncol) + if (std::find(items.begin(), items.end(), key) != items.end()) { + while (value.dim() < 2) { + value = value.unsqueeze(0); + 
} + } + } + + // broadcast dimensions to (nwave, ncol, nlyr, nprop) + while (prop.dim() < 4) { + prop = prop.unsqueeze(0); + } + + return self.forward(prop, &bc, bname, temf); + }, + py::arg("prop"), py::arg("bname") = "", py::arg("temf") = py::none()); +}; From f1ebec84223869dcf56bcac1f630e9c445fc6511 Mon Sep 17 00:00:00 2001 From: mac/cli Date: Sun, 18 Jan 2026 17:40:30 -0500 Subject: [PATCH 7/9] wip --- src/radiation/bbflux.cpp | 8 ++--- src/rtsolver/rtsolver_dispatch.cpp | 8 ++--- src/rtsolver/rtsolver_dispatch.cu | 8 ++--- src/rtsolver/toon_mckay89.cpp | 17 ++++++---- src/rtsolver/toon_mckay89_longwave_impl.h | 10 +++--- src/rtsolver/toon_mckay89_shortwave_impl.h | 34 ++++++++++---------- tests/CMakeLists.txt | 2 +- tests/test_toon.cpp | 36 ++++++++++++++++++++++ 8 files changed, 83 insertions(+), 40 deletions(-) create mode 100644 tests/test_toon.cpp diff --git a/src/radiation/bbflux.cpp b/src/radiation/bbflux.cpp index 8f5983a..d421268 100644 --- a/src/radiation/bbflux.cpp +++ b/src/radiation/bbflux.cpp @@ -126,8 +126,8 @@ torch::Tensor bbflux_wavenumber(torch::Tensor wn1, torch::Tensor wn2, // Handle different cases for wavenumbers for (int i = 0; i <= 1; ++i) { - smallv += torch::where(v[i] < VCUT, torch::ones_like(temp), - torch::zeros_like(temp)); + smallv = smallv + torch::where(v[i] < VCUT, torch::ones_like(temp), + torch::zeros_like(temp)); auto vsq = v[i] * v[i]; p[i] = @@ -144,8 +144,8 @@ torch::Tensor bbflux_wavenumber(torch::Tensor wn1, torch::Tensor wn2, for (int m = 1; m <= 6; ++m) { auto mv = static_cast(m) * v[i]; - exm *= ex; - d[i] += exm * (6.0 + mv * (6.0 + mv * (3.0 + mv))) / (m * m); + exm = exm * ex; + d[i] = d[i] + exm * (6.0 + mv * (6.0 + mv * (3.0 + mv))) / (m * m); } d[i] *= conc; diff --git a/src/rtsolver/rtsolver_dispatch.cpp b/src/rtsolver/rtsolver_dispatch.cpp index 7ce13bd..d120a9d 100644 --- a/src/rtsolver/rtsolver_dispatch.cpp +++ b/src/rtsolver/rtsolver_dispatch.cpp @@ -14,7 +14,7 @@ namespace harp { void 
call_toon89_sw_cpu(at::TensorIterator &iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_sw_cpu", [&] { - int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); + int nlay = at::native::ensure_nonempty_size(iter.input(0), -2); int grain_size = iter.numel() / at::get_num_threads(); int mem_size = toon89_sw_space(nlay); char *work = new char[mem_size]; @@ -40,7 +40,7 @@ void call_toon89_sw_cpu(at::TensorIterator &iter) { void call_toon89_lw_cpu(at::TensorIterator &iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_lw_cpu", [&] { - int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); + int nlay = at::native::ensure_nonempty_size(iter.input(0), -2); int grain_size = iter.numel() / at::get_num_threads(); int mem_size = toon89_sw_space(nlay); char *work = new char[mem_size]; @@ -50,9 +50,9 @@ void call_toon89_lw_cpu(at::TensorIterator &iter) { for (int i = 0; i < n; i++) { auto out = reinterpret_cast(data[0] + i * strides[0]); auto prop = reinterpret_cast(data[1] + i * strides[1]); + auto be = reinterpret_cast(data[2] + i * strides[2]); auto albedo = - reinterpret_cast(data[4] + i * strides[4]); - auto be = reinterpret_cast(data[5] + i * strides[5]); + reinterpret_cast(data[3] + i * strides[3]); toon_mckay89_longwave(nlay, be, prop, *albedo, out, work); } }, diff --git a/src/rtsolver/rtsolver_dispatch.cu b/src/rtsolver/rtsolver_dispatch.cu index cd5b0b1..7f72991 100644 --- a/src/rtsolver/rtsolver_dispatch.cu +++ b/src/rtsolver/rtsolver_dispatch.cu @@ -20,7 +20,7 @@ void call_toon89_sw_cuda(at::TensorIterator& iter) { int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); int mem_size = toon89_sw_space(nlay); - native::gpu_mem_kernel<32, 5>( + native::gpu_mem_kernel<128, 5>( iter, [=] GPU_LAMBDA( char* const data[5], unsigned int strides[5], char *work) { auto out = reinterpret_cast(data[0] + strides[0]); @@ -40,13 +40,13 @@ void call_toon89_lw_cuda(at::TensorIterator& iter) { int nlay = 
at::native::ensure_nonempty_size(iter.input(1), -2); int mem_size = toon89_sw_space(nlay); - native::gpu_mem_kernel<32, 4>( + native::gpu_mem_kernel<128, 4>( iter, [=] GPU_LAMBDA( char* const data[4], unsigned int strides[4], char *work) { auto out = reinterpret_cast(data[0] + strides[0]); auto prop = reinterpret_cast(data[1] + strides[1]); - auto albedo = reinterpret_cast(data[2] + strides[2]); - auto be = reinterpret_cast(data[3] + strides[3]); + auto be = reinterpret_cast(data[2] + strides[2]); + auto albedo = reinterpret_cast(data[3] + strides[3]); toon_mckay89_longwave(nlay, be, prop, *albedo, out, work); }); }); diff --git a/src/rtsolver/toon_mckay89.cpp b/src/rtsolver/toon_mckay89.cpp index c9e03f6..6d79f81 100644 --- a/src/rtsolver/toon_mckay89.cpp +++ b/src/rtsolver/toon_mckay89.cpp @@ -88,7 +88,8 @@ torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, .add_input(prop) .add_owned_input(bc->at("umu0") .view({1, ncol, 1, 1}) - .expand({nwave, ncol, nlyr, 1})) + .expand({nwave, ncol, nlyr + 1, 1}) + .contiguous()) .add_owned_input(bc->at("fbeam").view({nwave, ncol, 1, 1})) .add_owned_input(bc->at("albedo").view({nwave, ncol, 1, 1})) .build(); @@ -102,9 +103,15 @@ torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, temp(i) = ds_.temper[i]; be(i) = BB_integrate(ds_.temper[i], spec.wav1, spec.wav2); }*/ - auto wave_lo = torch::tensor(options->wave_lower(), prop.options()); - auto wave_hi = torch::tensor(options->wave_upper(), prop.options()); + auto wave_lo = torch::tensor(options->wave_lower(), prop.options()) + .unsqueeze(-1) + .unsqueeze(-1); + auto wave_hi = torch::tensor(options->wave_upper(), prop.options()) + .unsqueeze(-1) + .unsqueeze(-1); + auto be = bbflux_wavenumber(wave_lo, wave_hi, temf.value()); + auto iter = at::TensorIteratorConfig() .resize_outputs(false) .check_all_same_dtype(true) @@ -112,10 +119,8 @@ torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, /*squash_dims=*/{2, 3}) .add_output(flx) .add_input(prop) - 
.add_owned_input(bc->at("fbeam").view({nwave, ncol, 1, 1})) + .add_input(be) .add_owned_input(bc->at("albedo").view({nwave, ncol, 1, 1})) - .add_owned_input(be.view({1, ncol, nlyr + 1, 1}) - .expand({nwave, ncol, nlyr + 1, 1})) .build(); at::native::call_toon89_lw(flx.device().type(), iter); diff --git a/src/rtsolver/toon_mckay89_longwave_impl.h b/src/rtsolver/toon_mckay89_longwave_impl.h index 77f4316..7b89b82 100644 --- a/src/rtsolver/toon_mckay89_longwave_impl.h +++ b/src/rtsolver/toon_mckay89_longwave_impl.h @@ -13,11 +13,11 @@ #include "dtridgl_impl.h" -#define DTAU_IN(i) prop[(nlay - i - 1) * 3] -#define W_IN(i) prop[(nlay - i - 1) * 3 + 1] -#define G_IN(i) prop[(nlay - i - 1) * 3 + 2] -#define FLX_UP(i) flx[2 * (nlev - i - 1)] -#define FLX_DN(i) flx[2 * (nlev - i - 1) + 1] +#define DTAU_IN(i) prop[(nlay - (i) - 1) * 3] +#define W_IN(i) prop[(nlay - (i) - 1) * 3 + 1] +#define G_IN(i) prop[(nlay - (i) - 1) * 3 + 2] +#define FLX_UP(i) flx[2 * (nlev - (i) - 1)] +#define FLX_DN(i) flx[2 * (nlev - (i) - 1) + 1] namespace harp { diff --git a/src/rtsolver/toon_mckay89_shortwave_impl.h b/src/rtsolver/toon_mckay89_shortwave_impl.h index 58048b9..82994bc 100644 --- a/src/rtsolver/toon_mckay89_shortwave_impl.h +++ b/src/rtsolver/toon_mckay89_shortwave_impl.h @@ -11,11 +11,12 @@ #include "dtridgl_impl.h" -#define DTAU_IN(i) prop[(nlay - i - 1) * 3] -#define W_IN(i) prop[(nlay - i - 1) * 3 + 1] -#define G_IN(i) prop[(nlay - i - 1) * 3 + 2] -#define FLX_UP(i) flx[2 * (nlev - i - 1)] -#define FLX_DN(i) flx[2 * (nlev - i - 1) + 1] +#define DTAU_IN(i) prop[(nlay - (i) - 1) * 3] +#define W_IN(i) prop[(nlay - (i) - 1) * 3 + 1] +#define G_IN(i) prop[(nlay - (i) - 1) * 3 + 2] +#define FLX_UP(i) flx[2 * (nlev - (i) - 1)] +#define FLX_DN(i) flx[2 * (nlev - (i) - 1) + 1] +#define MU_IN(i) mu_in[nlev - (i) - 1] namespace harp { @@ -86,17 +87,17 @@ DISPATCH_MACRO void toon_mckay89_shortwave(int nlay, T F0_in, T const *mu_in, if (all_zero_w) { // --- Special Case: Direct Beam Only --- 
- if (mu_in[nlev - 1] == mu_in[0]) { + if (MU_IN(nlev - 1) == MU_IN(0)) { for (int k = 0; k < nlev; k++) { - FLX_DN(k) = F0_in * mu_in[nlev - 1] * exp(-tau_in[k] / mu_in[nlev - 1]); + FLX_DN(k) = F0_in * MU_IN(nlev - 1) * exp(-tau_in[k] / MU_IN(nlev - 1)); } } else { - cum_trans[0] = tau_in[0] / mu_in[0]; + cum_trans[0] = tau_in[0] / MU_IN(0); for (int k = 0; k < nlev - 1; k++) { - cum_trans[k + 1] = cum_trans[k] + DTAU_IN(k) / mu_in[k + 1]; + cum_trans[k + 1] = cum_trans[k] + DTAU_IN(k) / MU_IN(k + 1); } for (int k = 0; k < nlev; k++) { - FLX_DN(k) = F0_in * mu_in[nlev - 1] * exp(-cum_trans[k]); + FLX_DN(k) = F0_in * MU_IN(nlev - 1) * exp(-cum_trans[k]); } } FLX_DN(nlev - 1) *= (1.0 - w_surf_in); @@ -115,18 +116,18 @@ DISPATCH_MACRO void toon_mckay89_shortwave(int nlay, T F0_in, T const *mu_in, tau[0] = 0.0; for (int k = 0; k < nlay; k++) tau[k + 1] = tau[k] + dtau[k]; - if (mu_in[nlev - 1] == mu_in[0]) { - T mu_val = mu_in[nlev - 1]; + if (MU_IN(nlev - 1) == MU_IN(0)) { + T mu_val = MU_IN(nlev - 1); for (int k = 0; k < nlev; k++) dir[k] = F0_in * mu_val * exp(-tau[k] / mu_val); for (int i = 0; i < nlay; i++) mu_zm[i] = mu_val; } else { - cum_trans[0] = tau[0] / mu_in[0]; + cum_trans[0] = tau[0] / MU_IN(0); for (int k = 0; k < nlev - 1; k++) - cum_trans[k + 1] = cum_trans[k] + (tau[k + 1] - tau[k]) / mu_in[k + 1]; + cum_trans[k + 1] = cum_trans[k] + (tau[k + 1] - tau[k]) / MU_IN(k + 1); for (int k = 0; k < nlev; k++) - dir[k] = F0_in * mu_in[nlev - 1] * exp(-cum_trans[k]); - for (int i = 0; i < nlay; i++) mu_zm[i] = (mu_in[i] + mu_in[i + 1]) / 2.0; + dir[k] = F0_in * MU_IN(nlev - 1) * exp(-cum_trans[k]); + for (int i = 0; i < nlay; i++) mu_zm[i] = (MU_IN(i) + MU_IN(i + 1)) / 2.0; } for (int i = 0; i < nlay; i++) { @@ -209,3 +210,4 @@ DISPATCH_MACRO void toon_mckay89_shortwave(int nlay, T F0_in, T const *mu_in, #undef G_IN #undef FLX_UP #undef FLX_DN +#undef MU_IN diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 408657d..319b0dc 100644 --- 
a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -12,7 +12,7 @@ setup_test(test_bbflux) setup_test(test_flux_utils) #setup_test(test_yaml_input) setup_test(test_tridiag) -setup_test(test_composition) +setup_test(test_toon) # Python tests diff --git a/tests/test_toon.cpp b/tests/test_toon.cpp new file mode 100644 index 0000000..ba00569 --- /dev/null +++ b/tests/test_toon.cpp @@ -0,0 +1,36 @@ +// harp +#include + +int main(int argc, char** argv) { + auto op = harp::ToonMcKay89OptionsImpl::create(); + op->wave_lower({200., 500., 1000.}); + op->wave_upper({500., 1000., 2000.}); + + op->report(std::cout); + harp::ToonMcKay89 toon(op); + + int nwave = op->wave_lower().size(); + int nlyr = 10; + int ncol = 2; + int nprop = 3; + + auto prop = 0.5 * torch::ones({nwave, ncol, nlyr, nprop}, torch::kFloat64); + prop.select(-1, 0) = 0.1; + prop.select(-1, 1) = 0.2; + prop.select(-1, 2) = 0.3; + std::map bc; + bc["fbeam"] = torch::ones({nwave, ncol}, torch::kFloat64); + bc["umu0"] = torch::ones({ncol}, torch::kFloat64) * 0.2; + bc["albedo"] = torch::ones({nwave, ncol}, torch::kFloat64) * 0.3; + + auto sw_flx = toon(prop, &bc); + + std::cout << "sw_flx_up = " << sw_flx.select(-1, 0) << "\n"; + std::cout << "sw_flx_dn = " << sw_flx.select(-1, 1) << "\n"; + + auto temf = torch::ones({ncol, nlyr + 1}, torch::kFloat64) * 300.0; + auto lw_flx = toon(prop, &bc, "", temf); + + std::cout << "lw_flx_up = " << lw_flx.select(-1, 0) << "\n"; + std::cout << "lw_flx_dn = " << lw_flx.select(-1, 1) << "\n"; +} From 175b2df3e3696f4f6e9f26c52aa1ce886dd3ca11 Mon Sep 17 00:00:00 2001 From: stormy/cli Date: Mon, 19 Jan 2026 15:12:32 -0500 Subject: [PATCH 8/9] add toon testing --- cmake/macros/macro_setup_test.cmake | 4 +- src/CMakeLists.txt | 1 + src/loops.cuh | 88 +++++++++++----------- src/rtsolver/rtsolver_dispatch.cpp | 2 +- src/rtsolver/rtsolver_dispatch.cu | 14 ++-- src/rtsolver/toon_mckay89_longwave_impl.h | 5 +- src/rtsolver/toon_mckay89_shortwave_impl.h | 35 ++++----- src/utils/alloc.h | 
8 -- tests/device_testing.hpp | 55 ++++++++++++++ tests/test_toon.cpp | 27 +++++-- 10 files changed, 145 insertions(+), 94 deletions(-) create mode 100644 tests/device_testing.hpp diff --git a/cmake/macros/macro_setup_test.cmake b/cmake/macros/macro_setup_test.cmake index 876e674..d05396a 100644 --- a/cmake/macros/macro_setup_test.cmake +++ b/cmake/macros/macro_setup_test.cmake @@ -22,7 +22,9 @@ macro(setup_test namel) ${TORCH_API_INCLUDE_DIR}) target_link_libraries(${namel}.${buildl} - PRIVATE pyharp::harp gtest_main) + PRIVATE pyharp::harp + $,pyharp::harp_cu,> + gtest_main) add_test(NAME ${namel}.${buildl} COMMAND ${namel}.${buildl}) endmacro() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e2927be..a3ae43d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -65,6 +65,7 @@ add_library(pyharp::harp ALIAS ${namel}_${buildl}) if (CUDAToolkit_FOUND) file(GLOB cu_src_files integrator/*.cu + rtsolver/*.cu ) add_library(${namel}_cuda_${buildl} diff --git a/src/loops.cuh b/src/loops.cuh index 81a34ba..4bba488 100644 --- a/src/loops.cuh +++ b/src/loops.cuh @@ -8,16 +8,11 @@ namespace harp { namespace native { template -__global__ void element_kernel(int64_t numel, func_t f) { +__global__ void element_kernel(int64_t numel, func_t f, char *work) { int tid = threadIdx.x; int idx = blockIdx.x * blockDim.x + tid; - - // Shared memory allocation - extern __shared__ unsigned char memory[]; - char* smem = reinterpret_cast(memory); - if (idx < numel) { - f(idx, smem); + f(idx, work); } } @@ -40,8 +35,8 @@ void gpu_kernel(at::TensorIterator& iter, const func_t& f) { }); } -template -void gpu_mem_kernel(at::TensorIterator& iter, int work_size, const func_t& f) { +template +void gpu_chunk_kernel(at::TensorIterator& iter, int work_size, const func_t& f) { TORCH_CHECK(iter.ninputs() + iter.noutputs() == Arity); std::array data; @@ -52,49 +47,50 @@ void gpu_mem_kernel(at::TensorIterator& iter, int work_size, const func_t& f) { auto offset_calc = 
::make_offset_calculator(iter); int64_t numel = iter.numel(); - dim3 block(Threads); - dim3 grid((numel + block.x - 1) / block.x); - auto stream = at::cuda::getCurrentCUDAStream(); - size_t shared = block.x * work_size; + // divide numel into Chunks parts to reduce memory usage + // allocate working memory pool + char* d_workspace = nullptr; - // set attribute to allow max dynamic shared memory - int device; - cudaGetDevice(&device); - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, device); + // workspace size per chunk + int chunks = Chunks > numel ? numel : Chunks; + int base = numel / chunks; + int rem = numel % chunks; - // query max allowed per-block shared memory - int max_dynamic_smem = prop.sharedMemPerBlockOptin; - //printf("max_dynamic_smem = %d\n", max_dynamic_smem); + size_t workspace_bytes = work_size * (base + (rem > 0 ? 1 : 0)); + cudaMalloc(&d_workspace, workspace_bytes); - auto device_lambda = [=] __device__(int idx, char* smem) { - auto offsets = offset_calc.get(idx); - int tid = threadIdx.x; - f(data.data(), offsets.data(), smem + tid * work_size); - }; - - // request the full size - auto kernelPtr = element_kernel; - cudaFuncSetAttribute( - kernelPtr, - cudaFuncAttributeMaxDynamicSharedMemorySize, - max_dynamic_smem); - - if (shared > (size_t)max_dynamic_smem) { - TORCH_CHECK(false, "Requested shared memory (", shared, - " bytes) exceeds device maximum (", - max_dynamic_smem, " bytes)."); - } + int chunk_start = 0; + + for (int n = 0; n < chunks; n++) { + int64_t chunk_numel = base + (n < rem ? 
1 : 0); + int64_t chunk_end = chunk_start + chunk_numel; // exclusive - /*std::cout << "block = " << block.x - << ", grid = " << grid.x - << ", shared = " << shared - << ", work_size = " << work_size - << std::endl;*/ + dim3 block(64); + dim3 grid((chunk_numel + block.x - 1) / block.x); - element_kernel<<>>(numel, device_lambda); + auto device_lambda = [=] __device__(int idx, char* work) { + auto offsets = offset_calc.get(idx + chunk_start); + f(data.data(), offsets.data(), work + idx * work_size); + }; + + /*std::cout << "chunk = " << n + << ", chunk_start = " << chunk_start + << ", chunk_end = " << chunk_end + << ", chunk_numel = " << chunk_numel + << ", block = " << block.x + << ", grid = " << grid.x + << ", work_size = " << work_size + << std::endl;*/ + + element_kernel<<>>(chunk_numel, device_lambda, d_workspace); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + cudaDeviceSynchronize(); + + chunk_start = chunk_end; + } - C10_CUDA_KERNEL_LAUNCH_CHECK(); + // free working memory pool + cudaFree(d_workspace); } } // namespace native diff --git a/src/rtsolver/rtsolver_dispatch.cpp b/src/rtsolver/rtsolver_dispatch.cpp index d120a9d..dc0a48c 100644 --- a/src/rtsolver/rtsolver_dispatch.cpp +++ b/src/rtsolver/rtsolver_dispatch.cpp @@ -42,7 +42,7 @@ void call_toon89_lw_cpu(at::TensorIterator &iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_lw_cpu", [&] { int nlay = at::native::ensure_nonempty_size(iter.input(0), -2); int grain_size = iter.numel() / at::get_num_threads(); - int mem_size = toon89_sw_space(nlay); + int mem_size = toon89_lw_space(nlay); char *work = new char[mem_size]; iter.for_each( diff --git a/src/rtsolver/rtsolver_dispatch.cu b/src/rtsolver/rtsolver_dispatch.cu index 7f72991..af75a88 100644 --- a/src/rtsolver/rtsolver_dispatch.cu +++ b/src/rtsolver/rtsolver_dispatch.cu @@ -17,11 +17,11 @@ void call_toon89_sw_cuda(at::TensorIterator& iter) { at::cuda::CUDAGuard device_guard(iter.device()); AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_sw_cuda", 
[&] { - int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); + int nlay = at::native::ensure_nonempty_size(iter.input(0), -2); int mem_size = toon89_sw_space(nlay); - native::gpu_mem_kernel<128, 5>( - iter, [=] GPU_LAMBDA( + native::gpu_chunk_kernel<8, 5>( + iter, mem_size, [=] GPU_LAMBDA( char* const data[5], unsigned int strides[5], char *work) { auto out = reinterpret_cast(data[0] + strides[0]); auto prop = reinterpret_cast(data[1] + strides[1]); @@ -37,11 +37,11 @@ void call_toon89_lw_cuda(at::TensorIterator& iter) { at::cuda::CUDAGuard device_guard(iter.device()); AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "call_toon89_lw_cuda", [&] { - int nlay = at::native::ensure_nonempty_size(iter.input(1), -2); - int mem_size = toon89_sw_space(nlay); + int nlay = at::native::ensure_nonempty_size(iter.input(0), -2); + int mem_size = toon89_lw_space(nlay); - native::gpu_mem_kernel<128, 4>( - iter, [=] GPU_LAMBDA( + native::gpu_chunk_kernel<8, 4>( + iter, mem_size, [=] GPU_LAMBDA( char* const data[4], unsigned int strides[4], char *work) { auto out = reinterpret_cast(data[0] + strides[0]); auto prop = reinterpret_cast(data[1] + strides[1]); diff --git a/src/rtsolver/toon_mckay89_longwave_impl.h b/src/rtsolver/toon_mckay89_longwave_impl.h index 7b89b82..99ff86b 100644 --- a/src/rtsolver/toon_mckay89_longwave_impl.h +++ b/src/rtsolver/toon_mckay89_longwave_impl.h @@ -56,7 +56,6 @@ DISPATCH_MACRO void toon_mckay89_longwave(int nlay, const T *be, const T *prop, T *Cp = alloc_from(work, nlay); T *Cm = alloc_from(work, nlay); - T *exptrm = alloc_from(work, nlay); T *Ep = alloc_from(work, nlay); T *Em = alloc_from(work, nlay); T *E1 = alloc_from(work, nlay); @@ -119,8 +118,8 @@ DISPATCH_MACRO void toon_mckay89_longwave(int nlay, const T *be, const T *prop, Cp[k] = B0[k] + B1[k] * dtau[k] + B1[k] * term[k]; Cm[k] = B0[k] + B1[k] * dtau[k] - B1[k] * term[k]; - exptrm[k] = fmin(lam[k] * dtau[k], 35.0); - Ep[k] = exp(exptrm[k]); + T exptrm = fmin(lam[k] * dtau[k], 35.0); + 
Ep[k] = exp(exptrm); Em[k] = 1.0 / Ep[k]; E1[k] = Ep[k] + gam[k] * Em[k]; E2[k] = Ep[k] - gam[k] * Em[k]; diff --git a/src/rtsolver/toon_mckay89_shortwave_impl.h b/src/rtsolver/toon_mckay89_shortwave_impl.h index 82994bc..9f9498d 100644 --- a/src/rtsolver/toon_mckay89_shortwave_impl.h +++ b/src/rtsolver/toon_mckay89_shortwave_impl.h @@ -38,20 +38,13 @@ DISPATCH_MACRO void toon_mckay89_shortwave(int nlay, T F0_in, T const *mu_in, T *mu_zm = alloc_from(work, nlay); T *w0 = alloc_from(work, nlay); T *hg = alloc_from(work, nlay); - T *g1 = alloc_from(work, nlay); - T *g2 = alloc_from(work, nlay); - T *g3 = alloc_from(work, nlay); - T *g4 = alloc_from(work, nlay); - T *lam = alloc_from(work, nlay); T *gam = alloc_from(work, nlay); - T *denom = alloc_from(work, nlay); T *Am = alloc_from(work, nlay); T *Ap = alloc_from(work, nlay); T *Cpm1 = alloc_from(work, nlay); T *Cmm1 = alloc_from(work, nlay); T *Cp = alloc_from(work, nlay); T *Cm = alloc_from(work, nlay); - T *exptrm = alloc_from(work, nlay); T *Ep = alloc_from(work, nlay); T *Em = alloc_from(work, nlay); T *E1 = alloc_from(work, nlay); @@ -131,25 +124,23 @@ DISPATCH_MACRO void toon_mckay89_shortwave(int nlay, T F0_in, T const *mu_in, } for (int i = 0; i < nlay; i++) { - g1[i] = sqrt3d2 * (2.0 - w0[i] * (1.0 + hg[i])); - g2[i] = (sqrt3d2 * w0[i]) * (1.0 - hg[i]); - if (g2[i] == 0.0) g2[i] = 1.0e-10; - g3[i] = (1.0 - sqrt3 * hg[i] * mu_zm[i]) / 2.0; - g4[i] = 1.0 - g3[i]; - lam[i] = sqrt(g1[i] * g1[i] - g2[i] * g2[i]); - gam[i] = (g1[i] - lam[i]) / g2[i]; - denom[i] = (lam[i] * lam[i]) - 1.0 / (mu_zm[i] * mu_zm[i]); - if (denom[i] == 0.0) denom[i] = 1.0e-10; - Ap[i] = F0_in * w0[i] * - (g3[i] * (g1[i] - 1.0 / mu_zm[i]) + g2[i] * g4[i]) / denom[i]; - Am[i] = F0_in * w0[i] * - (g4[i] * (g1[i] + 1.0 / mu_zm[i]) + g2[i] * g3[i]) / denom[i]; + T g1 = sqrt3d2 * (2.0 - w0[i] * (1.0 + hg[i])); + T g2 = (sqrt3d2 * w0[i]) * (1.0 - hg[i]); + if (g2 == 0.0) g2 = 1.0e-10; + T g3 = (1.0 - sqrt3 * hg[i] * mu_zm[i]) / 2.0; + T g4 = 
1.0 - g3; + T lam = sqrt(g1 * g1 - g2 * g2); + gam[i] = (g1 - lam) / g2; + T denom = (lam * lam) - 1.0 / (mu_zm[i] * mu_zm[i]); + if (denom == 0.0) denom = 1.0e-10; + Ap[i] = F0_in * w0[i] * (g3 * (g1 - 1.0 / mu_zm[i]) + g2 * g4) / denom; + Am[i] = F0_in * w0[i] * (g4 * (g1 + 1.0 / mu_zm[i]) + g2 * g3) / denom; Cpm1[i] = Ap[i] * exp(-tau[i] / mu_zm[i]); Cmm1[i] = Am[i] * exp(-tau[i] / mu_zm[i]); Cp[i] = Ap[i] * exp(-tau[i + 1] / mu_zm[i]); Cm[i] = Am[i] * exp(-tau[i + 1] / mu_zm[i]); - exptrm[i] = fmin(lam[i] * dtau[i], 35.0); - Ep[i] = exp(exptrm[i]); + T exptrm = fmin(lam * dtau[i], 35.0); + Ep[i] = exp(exptrm); Em[i] = 1.0 / Ep[i]; E1[i] = Ep[i] + gam[i] * Em[i]; E2[i] = Ep[i] - gam[i] * Em[i]; diff --git a/src/utils/alloc.h b/src/utils/alloc.h index 9f74482..4b502a2 100644 --- a/src/utils/alloc.h +++ b/src/utils/alloc.h @@ -40,20 +40,13 @@ size_t toon89_sw_space(int nlay) { bump(alignof(T), nlay * sizeof(T)); // mu_zm bump(alignof(T), nlay * sizeof(T)); // w0 bump(alignof(T), nlay * sizeof(T)); // hg - bump(alignof(T), nlay * sizeof(T)); // g1 - bump(alignof(T), nlay * sizeof(T)); // g2 - bump(alignof(T), nlay * sizeof(T)); // g3 - bump(alignof(T), nlay * sizeof(T)); // g4 - bump(alignof(T), nlay * sizeof(T)); // lam bump(alignof(T), nlay * sizeof(T)); // gam - bump(alignof(T), nlay * sizeof(T)); // denom bump(alignof(T), nlay * sizeof(T)); // Am bump(alignof(T), nlay * sizeof(T)); // Ap bump(alignof(T), nlay * sizeof(T)); // Cpm1 bump(alignof(T), nlay * sizeof(T)); // Cmm1 bump(alignof(T), nlay * sizeof(T)); // Cp bump(alignof(T), nlay * sizeof(T)); // Cm - bump(alignof(T), nlay * sizeof(T)); // exptrm bump(alignof(T), nlay * sizeof(T)); // Ep bump(alignof(T), nlay * sizeof(T)); // Em bump(alignof(T), nlay * sizeof(T)); // E1 @@ -93,7 +86,6 @@ size_t toon89_lw_space(int nlay) { bump(alignof(T), nlay * sizeof(T)); // Cmm1 bump(alignof(T), nlay * sizeof(T)); // Cp bump(alignof(T), nlay * sizeof(T)); // Cm - bump(alignof(T), nlay * sizeof(T)); // exptrm 
bump(alignof(T), nlay * sizeof(T)); // Ep bump(alignof(T), nlay * sizeof(T)); // Em bump(alignof(T), nlay * sizeof(T)); // E1 diff --git a/tests/device_testing.hpp b/tests/device_testing.hpp new file mode 100644 index 0000000..1109bd3 --- /dev/null +++ b/tests/device_testing.hpp @@ -0,0 +1,55 @@ +#pragma once + +// external +#include + +// torch +#include + +struct Parameters { + torch::DeviceType device_type; + torch::Dtype dtype; +}; + +inline void PrintTo(const Parameters& param, std::ostream* os) { + std::string device_str = torch::Device(param.device_type).str(); + std::string dtype_str = torch::toString(param.dtype); + *os << "Device: " << device_str << ", Dtype: " << dtype_str; +} + +class DeviceTest : public testing::TestWithParam { + protected: + torch::Device device = torch::kCPU; + torch::Dtype dtype = torch::kFloat32; + + void SetUp() override { + // Get the current parameters + auto param = GetParam(); + device = torch::Device(param.device_type); + dtype = param.dtype; + + // Check if the device is available, and skip the test if not + if (device.type() == torch::kCUDA && !torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA is not available, skipping test."; + } + + if (device.type() == torch::kMPS && !torch::hasMPS()) { + GTEST_SKIP() << "MPS is not available, skipping test."; + } + } +}; + +INSTANTIATE_TEST_SUITE_P( + DeviceAndDtype, DeviceTest, + testing::Values(Parameters{torch::kCPU, torch::kFloat32}, + Parameters{torch::kCPU, torch::kFloat64}, + // Parameters{torch::kMPS, torch::kFloat32}, + Parameters{torch::kCUDA, torch::kFloat32}, + Parameters{torch::kCUDA, torch::kFloat64}), + [](const testing::TestParamInfo& info) { + std::string name = torch::Device(info.param.device_type).str(); + name += "_"; + name += torch::toString(info.param.dtype); + std::replace(name.begin(), name.end(), '.', '_'); + return name; + }); diff --git a/tests/test_toon.cpp b/tests/test_toon.cpp index ba00569..a4f7494 100644 --- a/tests/test_toon.cpp +++ 
b/tests/test_toon.cpp @@ -1,36 +1,51 @@ +// external +#include + // harp #include -int main(int argc, char** argv) { +// tests +#include "device_testing.hpp" + +using namespace harp; + +TEST_P(DeviceTest, simple_toon_mckay89) { auto op = harp::ToonMcKay89OptionsImpl::create(); op->wave_lower({200., 500., 1000.}); op->wave_upper({500., 1000., 2000.}); op->report(std::cout); harp::ToonMcKay89 toon(op); + toon->to(device, dtype); int nwave = op->wave_lower().size(); int nlyr = 10; int ncol = 2; int nprop = 3; - auto prop = 0.5 * torch::ones({nwave, ncol, nlyr, nprop}, torch::kFloat64); + auto prop = 0.5 * torch::ones({nwave, ncol, nlyr, nprop}, + torch::device(device).dtype(dtype)); prop.select(-1, 0) = 0.1; prop.select(-1, 1) = 0.2; prop.select(-1, 2) = 0.3; std::map bc; - bc["fbeam"] = torch::ones({nwave, ncol}, torch::kFloat64); - bc["umu0"] = torch::ones({ncol}, torch::kFloat64) * 0.2; - bc["albedo"] = torch::ones({nwave, ncol}, torch::kFloat64) * 0.3; + bc["fbeam"] = torch::ones({nwave, ncol}, prop.options()); + bc["umu0"] = torch::ones({ncol}, prop.options()) * 0.2; + bc["albedo"] = torch::ones({nwave, ncol}, prop.options()) * 0.3; auto sw_flx = toon(prop, &bc); std::cout << "sw_flx_up = " << sw_flx.select(-1, 0) << "\n"; std::cout << "sw_flx_dn = " << sw_flx.select(-1, 1) << "\n"; - auto temf = torch::ones({ncol, nlyr + 1}, torch::kFloat64) * 300.0; + auto temf = torch::ones({ncol, nlyr + 1}, prop.options()) * 300.0; auto lw_flx = toon(prop, &bc, "", temf); std::cout << "lw_flx_up = " << lw_flx.select(-1, 0) << "\n"; std::cout << "lw_flx_dn = " << lw_flx.select(-1, 1) << "\n"; } + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From dbe23dcaba0c368f12d186bdc08892bc7edc1277 Mon Sep 17 00:00:00 2001 From: mac/cli Date: Mon, 19 Jan 2026 15:38:29 -0500 Subject: [PATCH 9/9] wip --- src/rtsolver/toon_mckay89.cpp | 16 ++++----- src/rtsolver/toon_mckay89_longwave_impl.h | 40 +++++++++++----------- 
src/rtsolver/toon_mckay89_shortwave_impl.h | 40 +++++++++++----------- 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/src/rtsolver/toon_mckay89.cpp b/src/rtsolver/toon_mckay89.cpp index 6d79f81..eedbcd6 100644 --- a/src/rtsolver/toon_mckay89.cpp +++ b/src/rtsolver/toon_mckay89.cpp @@ -44,9 +44,9 @@ torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, // check bc if (bc->find(bname + "umu0") != bc->end()) { TORCH_CHECK(bc->at(bname + "umu0").dim() == 1, - "DisortImpl::forward: bc->umu0.dim() != 1"); + "ToonMcKay89::forward: bc->umu0.dim() != 1"); TORCH_CHECK(bc->at(bname + "umu0").size(0) == ncol, - "DisortImpl::forward: bc->umu0.size(0) != ncol"); + "ToonMcKay89::forward: bc->umu0.size(0) != ncol"); (*bc)["umu0"] = bc->at(bname + "umu0"); } else { (*bc)["umu0"] = torch::ones({1, ncol}, prop.options()); @@ -54,11 +54,11 @@ torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, if (bc->find(bname + "fbeam") != bc->end()) { TORCH_CHECK(bc->at(bname + "fbeam").dim() == 2, - "DisortImpl::forward: bc->fbeam.dim() != 2"); + "ToonMcKay89::forward: bc->fbeam.dim() != 2"); TORCH_CHECK(bc->at(bname + "fbeam").size(0) == nwave, - "DisortImpl::forward: bc->fbeam.size(0) != nwave"); + "ToonMcKay89::forward: bc->fbeam.size(0) != nwave"); TORCH_CHECK(bc->at(bname + "fbeam").size(1) == ncol, - "DisortImpl::forward: bc->fbeam.size(1) != ncol"); + "ToonMcKay89::forward: bc->fbeam.size(1) != ncol"); (*bc)["fbeam"] = bc->at(bname + "fbeam"); } else { (*bc)["fbeam"] = torch::zeros({nwave, ncol}, prop.options()); @@ -66,11 +66,11 @@ torch::Tensor ToonMcKay89Impl::forward(torch::Tensor prop, if (bc->find(bname + "albedo") != bc->end()) { TORCH_CHECK(bc->at(bname + "albedo").dim() == 2, - "DisortImpl::forward: bc->albedo.dim() != 2"); + "ToonMcKay89::forward: bc->albedo.dim() != 2"); TORCH_CHECK(bc->at(bname + "albedo").size(0) == nwave, - "DisortImpl::forward: bc->albedo.size(0) != nwave"); + "ToonMcKay89::forward: bc->albedo.size(0) != nwave"); 
TORCH_CHECK(bc->at(bname + "albedo").size(1) == ncol, - "DisortImpl::forward: bc->albedo.size(1) != ncol"); + "ToonMcKay89::forward: bc->albedo.size(1) != ncol"); (*bc)["albedo"] = bc->at(bname + "albedo"); } else { (*bc)["albedo"] = torch::zeros({nwave, ncol}, prop.options()); diff --git a/src/rtsolver/toon_mckay89_longwave_impl.h b/src/rtsolver/toon_mckay89_longwave_impl.h index 99ff86b..769ae5d 100644 --- a/src/rtsolver/toon_mckay89_longwave_impl.h +++ b/src/rtsolver/toon_mckay89_longwave_impl.h @@ -133,41 +133,41 @@ DISPATCH_MACRO void toon_mckay89_longwave(int nlay, const T *be, const T *prop, T bsurf_flux = Bsurf; // Bsurf is local variable // --- Matrix Construction (1-based indices for solver) --- - Af[1] = 0.0; - Bf[1] = gam[0] + 1.0; - Cf[1] = gam[0] - 1.0; - Df[1] = Btop - Cmm1[0]; + Af[0] = 0.0; + Bf[0] = gam[0] + 1.0; + Cf[0] = gam[0] - 1.0; + Df[0] = Btop - Cmm1[0]; int n_idx = 0; for (int i = 2; i <= lm2; i += 2) { - Af[i] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); - Bf[i] = (E2[n_idx] + E4[n_idx]) * (gam[n_idx + 1] - 1.0); - Cf[i] = 2.0 * (1.0 - gam[n_idx + 1] * gam[n_idx + 1]); - Df[i] = (gam[n_idx + 1] - 1.0) * (Cpm1[n_idx + 1] - Cp[n_idx]) + - (1.0 - gam[n_idx + 1]) * (Cm[n_idx] - Cmm1[n_idx + 1]); + Af[i - 1] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); + Bf[i - 1] = (E2[n_idx] + E4[n_idx]) * (gam[n_idx + 1] - 1.0); + Cf[i - 1] = 2.0 * (1.0 - gam[n_idx + 1] * gam[n_idx + 1]); + Df[i - 1] = (gam[n_idx + 1] - 1.0) * (Cpm1[n_idx + 1] - Cp[n_idx]) + + (1.0 - gam[n_idx + 1]) * (Cm[n_idx] - Cmm1[n_idx + 1]); n_idx++; } n_idx = 0; for (int i = 3; i <= lm1; i += 2) { - Af[i] = 2.0 * (1.0 - gam[n_idx] * gam[n_idx]); - Bf[i] = (E1[n_idx] - E3[n_idx]) * (1.0 + gam[n_idx + 1]); - Cf[i] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); - Df[i] = E3[n_idx] * (Cpm1[n_idx + 1] - Cp[n_idx]) + - E1[n_idx] * (Cm[n_idx] - Cmm1[n_idx + 1]); + Af[i - 1] = 2.0 * (1.0 - gam[n_idx] * gam[n_idx]); + Bf[i - 1] = (E1[n_idx] - E3[n_idx]) * (1.0 + gam[n_idx 
+ 1]); + Cf[i - 1] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); + Df[i - 1] = E3[n_idx] * (Cpm1[n_idx + 1] - Cp[n_idx]) + + E1[n_idx] * (Cm[n_idx] - Cmm1[n_idx + 1]); n_idx++; } - Af[l] = E1[nlay - 1] - a_surf_in * E3[nlay - 1]; - Bf[l] = E2[nlay - 1] - a_surf_in * E4[nlay - 1]; - Cf[l] = 0.0; - Df[l] = bsurf_flux - Cp[nlay - 1] + a_surf_in * Cm[nlay - 1]; + Af[l - 1] = E1[nlay - 1] - a_surf_in * E3[nlay - 1]; + Bf[l - 1] = E2[nlay - 1] - a_surf_in * E4[nlay - 1]; + Cf[l - 1] = 0.0; + Df[l - 1] = bsurf_flux - Cp[nlay - 1] + a_surf_in * Cm[nlay - 1]; dtridgl(l, Af, Bf, Cf, Df, xkk); for (int n = 0; n < nlay; n++) { - xk1[n] = xkk[2 * n + 1] + xkk[2 * n + 2]; - xk2[n] = xkk[2 * n + 1] - xkk[2 * n + 2]; + xk1[n] = xkk[2 * n] + xkk[2 * n + 1]; + xk2[n] = xkk[2 * n] - xkk[2 * n + 1]; if (fabs(xk2[n]) < 1e-30 * fabs(xkk[2 * n + 1])) xk2[n] = 0.0; if (w0[n] <= 1e-4) { diff --git a/src/rtsolver/toon_mckay89_shortwave_impl.h b/src/rtsolver/toon_mckay89_shortwave_impl.h index 9f9498d..5338b10 100644 --- a/src/rtsolver/toon_mckay89_shortwave_impl.h +++ b/src/rtsolver/toon_mckay89_shortwave_impl.h @@ -149,38 +149,38 @@ DISPATCH_MACRO void toon_mckay89_shortwave(int nlay, T F0_in, T const *mu_in, } // Matrix Setup - Af[1] = 0.0; - Bf[1] = gam[0] + 1.0; - Cf[1] = gam[0] - 1.0; - Df[1] = btop - Cmm1[0]; + Af[0] = 0.0; + Bf[0] = gam[0] + 1.0; + Cf[0] = gam[0] - 1.0; + Df[0] = btop - Cmm1[0]; int n_idx = 0; for (int i = 2; i <= lm2; i += 2) { - Af[i] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); - Bf[i] = (E2[n_idx] + E4[n_idx]) * (gam[n_idx + 1] - 1.0); - Cf[i] = 2.0 * (1.0 - gam[n_idx + 1] * gam[n_idx + 1]); - Df[i] = (gam[n_idx + 1] - 1.0) * (Cpm1[n_idx + 1] - Cp[n_idx]) + - (1.0 - gam[n_idx + 1]) * (Cm[n_idx] - Cmm1[n_idx + 1]); + Af[i - 1] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); + Bf[i - 1] = (E2[n_idx] + E4[n_idx]) * (gam[n_idx + 1] - 1.0); + Cf[i - 1] = 2.0 * (1.0 - gam[n_idx + 1] * gam[n_idx + 1]); + Df[i - 1] = (gam[n_idx + 1] - 1.0) * (Cpm1[n_idx 
+ 1] - Cp[n_idx]) + + (1.0 - gam[n_idx + 1]) * (Cm[n_idx] - Cmm1[n_idx + 1]); n_idx++; } n_idx = 0; for (int i = 3; i <= lm1; i += 2) { - Af[i] = 2.0 * (1.0 - gam[n_idx] * gam[n_idx]); - Bf[i] = (E1[n_idx] - E3[n_idx]) * (1.0 + gam[n_idx + 1]); - Cf[i] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); - Df[i] = E3[n_idx] * (Cpm1[n_idx + 1] - Cp[n_idx]) + - E1[n_idx] * (Cm[n_idx] - Cmm1[n_idx + 1]); + Af[i - 1] = 2.0 * (1.0 - gam[n_idx] * gam[n_idx]); + Bf[i - 1] = (E1[n_idx] - E3[n_idx]) * (1.0 + gam[n_idx + 1]); + Cf[i - 1] = (E1[n_idx] + E3[n_idx]) * (gam[n_idx + 1] - 1.0); + Df[i - 1] = E3[n_idx] * (Cpm1[n_idx + 1] - Cp[n_idx]) + + E1[n_idx] * (Cm[n_idx] - Cmm1[n_idx + 1]); n_idx++; } - Af[l] = E1[nlay - 1] - w_surf_in * E3[nlay - 1]; - Bf[l] = E2[nlay - 1] - w_surf_in * E4[nlay - 1]; - Cf[l] = 0.0; - Df[l] = bsurf - Cp[nlay - 1] + w_surf_in * Cm[nlay - 1]; + Af[l - 1] = E1[nlay - 1] - w_surf_in * E3[nlay - 1]; + Bf[l - 1] = E2[nlay - 1] - w_surf_in * E4[nlay - 1]; + Cf[l - 1] = 0.0; + Df[l - 1] = bsurf - Cp[nlay - 1] + w_surf_in * Cm[nlay - 1]; dtridgl(l, Af, Bf, Cf, Df, xk); for (int n = 0; n < nlay; n++) { - xk1[n] = xk[2 * n + 1] + xk[2 * n + 2]; - xk2[n] = xk[2 * n + 1] - xk[2 * n + 2]; + xk1[n] = xk[2 * n] + xk[2 * n + 1]; + xk2[n] = xk[2 * n] - xk[2 * n + 1]; if (fabs(xk2[n]) < 1e-30 * fabs(xk[2 * n + 1])) xk2[n] = 0.0; FLX_UP(n) = xk1[n] + gam[n] * xk2[n] + Cpm1[n]; FLX_DN(n) = xk1[n] * gam[n] + xk2[n] + Cmm1[n];