diff --git a/python/csrc/pyintegrator.cpp b/python/csrc/pyintegrator.cpp index 5ad2cf2..6e36f00 100644 --- a/python/csrc/pyintegrator.cpp +++ b/python/csrc/pyintegrator.cpp @@ -31,8 +31,10 @@ void bind_integrator(py::module &m) { m, "IntegratorOptions"); pyIntegratorOptions.def(py::init<>()) - .def_static("from_yaml", &harp::IntegratorOptionsImpl::from_yaml, - py::arg("filename")) + .def_static("from_yaml", + py::overload_cast<std::string const&, bool>( + &harp::IntegratorOptionsImpl::from_yaml), + py::arg("filename"), py::arg("verbose") = false) .def("__repr__", [](const harp::IntegratorOptions &a) { std::stringstream ss; diff --git a/src/integrator/integrator.cpp b/src/integrator/integrator.cpp index faba0b3..bfec375 100644 --- a/src/integrator/integrator.cpp +++ b/src/integrator/integrator.cpp @@ -7,20 +7,27 @@ namespace harp { -IntegratorOptions IntegratorOptionsImpl::from_yaml( - std::string const& filename) { +IntegratorOptions IntegratorOptionsImpl::from_yaml(std::string const& filename, + bool verbose) { auto op = IntegratorOptionsImpl::create(); + op->verbose() = verbose; auto config = YAML::LoadFile(filename); if (!config["integration"]) return op; + return from_yaml(config["integration"], verbose); +} - op->type() = config["integration"]["type"].as<std::string>("rk3"); - op->cfl() = config["integration"]["cfl"].as<double>(0.9); - op->tlim() = config["integration"]["tlim"].as<double>(1.e9); - op->nlim() = config["integration"]["nlim"].as<int>(-1); - op->ncycle_out() = config["integration"]["ncycle_out"].as<int>(1); - op->restart() = config["integration"]["restart"].as<std::string>(""); +IntegratorOptions IntegratorOptionsImpl::from_yaml(YAML::Node const& node, + bool verbose) { + auto op = IntegratorOptionsImpl::create(); + op->type() = node["type"].as<std::string>("rk3"); + op->cfl() = node["cfl"].as<double>(0.9); + op->tlim() = node["tlim"].as<double>(1.e9); + op->nlim() = node["nlim"].as<int>(-1); + op->ncycle_out() = node["ncycle_out"].as<int>(1); + op->restart() = node["restart"].as<std::string>(""); + op->verbose() = node["verbose"].as<bool>(verbose); return op; } @@ -100,7 +107,7 
@@ torch::Tensor IntegratorImpl::forward(int s, torch::Tensor u0, torch::Tensor u1, std::to_string(stages.size()) + ")"); } - auto out = torch::empty_like(u0); + /*auto out = torch::empty_like(u0); auto iter = at::TensorIteratorConfig() .add_output(out) .add_input(u0) @@ -109,8 +116,18 @@ torch::Tensor IntegratorImpl::forward(int s, torch::Tensor u0, torch::Tensor u1, .build(); at::native::call_average3(out.device().type(), iter, stages[s].wght0(), - stages[s].wght1(), stages[s].wght2()); - + stages[s].wght1(), stages[s].wght2());*/ + auto out = + stages[s].wght0() * u0 + stages[s].wght1() * u1 + stages[s].wght2() * u2; return out; } + +std::shared_ptr<IntegratorImpl> IntegratorImpl::create( + IntegratorOptions const& opts, torch::nn::Module* p, + std::string const& name) { + TORCH_CHECK(p, "[Integrator] Parent module is null"); + TORCH_CHECK(opts, "[Integrator] Options pointer is null"); + return p->register_module(name, Integrator(opts)); +} + } // namespace harp diff --git a/src/integrator/integrator.hpp b/src/integrator/integrator.hpp index a38040a..0c99497 100644 --- a/src/integrator/integrator.hpp +++ b/src/integrator/integrator.hpp @@ -12,6 +12,10 @@ // according to: // https://gkeyll.readthedocs.io/en/latest/dev/ssp-rk.html +namespace YAML { +class Node; +} // namespace YAML + namespace harp { struct IntegratorWeight { void report(std::ostream& os) const { @@ -29,7 +33,9 @@ struct IntegratorOptionsImpl { return std::make_shared<IntegratorOptionsImpl>(); } static std::shared_ptr<IntegratorOptionsImpl> from_yaml( - std::string const& filename); + std::string const& filename, bool verbose = false); + static std::shared_ptr<IntegratorOptionsImpl> from_yaml( + YAML::Node const& node, bool verbose = false); void report(std::ostream& os) const { os << "* type = " << type() << "\n" @@ -38,6 +44,7 @@ struct IntegratorOptionsImpl { << "* nlim = " << nlim() << "\n" << "* ncycle_out = " << ncycle_out() << "\n" << "* max_redo = " << max_redo() << "\n" + << "* verbose = " << verbose() << "\n" << "* restart = " << restart() << "\n"; } @@ -47,12 +54,27 @@ 
struct IntegratorOptionsImpl { ADD_ARG(int, nlim) = -1; ADD_ARG(int, ncycle_out) = 1; ADD_ARG(int, max_redo) = 5; + ADD_ARG(bool, verbose) = false; ADD_ARG(std::string, restart) = ""; }; using IntegratorOptions = std::shared_ptr<IntegratorOptionsImpl>; class IntegratorImpl : public torch::nn::Cloneable<IntegratorImpl> { public: + //! Create and register an `Integrator` module + /*! + * This function registers the created module as a submodule + * of the given parent module `p`. + * + * \param[in] opts options for creating the `Integrator` module + * \param[in] p parent module for registering the created module + * \param[in] name name for registering the created module + * \return created `Integrator` module + */ + static std::shared_ptr<IntegratorImpl> create( + IntegratorOptions const& opts, torch::nn::Module* p, + std::string const& name = "intg"); + int current_redo = 0; //! options with which this `Integrator` was constructed