Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/csrc/pyintegrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ void bind_integrator(py::module &m) {
m, "IntegratorOptions");

pyIntegratorOptions.def(py::init<>())
.def_static("from_yaml", &harp::IntegratorOptionsImpl::from_yaml,
py::arg("filename"))
.def_static("from_yaml",
py::overload_cast<std::string const &, bool>(
&harp::IntegratorOptionsImpl::from_yaml),
py::arg("filename"), py::arg("verbose") = false)
.def("__repr__",
[](const harp::IntegratorOptions &a) {
std::stringstream ss;
Expand Down
39 changes: 28 additions & 11 deletions src/integrator/integrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,27 @@

namespace harp {

IntegratorOptions IntegratorOptionsImpl::from_yaml(
std::string const& filename) {
IntegratorOptions IntegratorOptionsImpl::from_yaml(std::string const& filename,
bool verbose) {
auto op = IntegratorOptionsImpl::create();
op->verbose() = verbose;

auto config = YAML::LoadFile(filename);
if (!config["integration"]) return op;
return from_yaml(config["integration"], verbose);
}

op->type() = config["integration"]["type"].as<std::string>("rk3");
op->cfl() = config["integration"]["cfl"].as<double>(0.9);
op->tlim() = config["integration"]["tlim"].as<double>(1.e9);
op->nlim() = config["integration"]["nlim"].as<int>(-1);
op->ncycle_out() = config["integration"]["ncycle_out"].as<int>(1);
op->restart() = config["integration"]["restart"].as<std::string>("");
IntegratorOptions IntegratorOptionsImpl::from_yaml(YAML::Node const& node,
bool verbose) {
auto op = IntegratorOptionsImpl::create();

op->type() = node["type"].as<std::string>("rk3");
op->cfl() = node["cfl"].as<double>(0.9);
op->tlim() = node["tlim"].as<double>(1.e9);
op->nlim() = node["nlim"].as<int>(-1);
op->ncycle_out() = node["ncycle_out"].as<int>(1);
op->restart() = node["restart"].as<std::string>("");
op->verbose() = node["verbose"].as<bool>(verbose);
return op;
}

Expand Down Expand Up @@ -100,7 +107,7 @@ torch::Tensor IntegratorImpl::forward(int s, torch::Tensor u0, torch::Tensor u1,
std::to_string(stages.size()) + ")");
}

auto out = torch::empty_like(u0);
/*auto out = torch::empty_like(u0);
auto iter = at::TensorIteratorConfig()
.add_output(out)
.add_input(u0)
Expand All @@ -109,8 +116,18 @@ torch::Tensor IntegratorImpl::forward(int s, torch::Tensor u0, torch::Tensor u1,
.build();

at::native::call_average3(out.device().type(), iter, stages[s].wght0(),
stages[s].wght1(), stages[s].wght2());

stages[s].wght1(), stages[s].wght2());*/
auto out =
stages[s].wght0() * u0 + stages[s].wght1() * u1 + stages[s].wght2() * u2;
return out;
}

std::shared_ptr<IntegratorImpl> IntegratorImpl::create(
IntegratorOptions const& opts, torch::nn::Module* p,
std::string const& name) {
TORCH_CHECK(p, "[Integrator] Parent module is null");
TORCH_CHECK(opts, "[Integrator] Options pointer is null");
return p->register_module(name, Integrator(opts));
}

} // namespace harp
24 changes: 23 additions & 1 deletion src/integrator/integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
// according to:
// https://gkeyll.readthedocs.io/en/latest/dev/ssp-rk.html

namespace YAML {
class Node;
} // namespace YAML

namespace harp {
struct IntegratorWeight {
void report(std::ostream& os) const {
Expand All @@ -29,7 +33,9 @@ struct IntegratorOptionsImpl {
return std::make_shared<IntegratorOptionsImpl>();
}
static std::shared_ptr<IntegratorOptionsImpl> from_yaml(
std::string const& filename);
std::string const& filename, bool verbose = false);
static std::shared_ptr<IntegratorOptionsImpl> from_yaml(
YAML::Node const& node, bool verbose = false);

void report(std::ostream& os) const {
os << "* type = " << type() << "\n"
Expand All @@ -38,6 +44,7 @@ struct IntegratorOptionsImpl {
<< "* nlim = " << nlim() << "\n"
<< "* ncycle_out = " << ncycle_out() << "\n"
<< "* max_redo = " << max_redo() << "\n"
<< "* verbose = " << verbose() << "\n"
<< "* restart = " << restart() << "\n";
}

Expand All @@ -47,12 +54,27 @@ struct IntegratorOptionsImpl {
ADD_ARG(int, nlim) = -1;
ADD_ARG(int, ncycle_out) = 1;
ADD_ARG(int, max_redo) = 5;
ADD_ARG(bool, verbose) = false;
ADD_ARG(std::string, restart) = "";
};
using IntegratorOptions = std::shared_ptr<IntegratorOptionsImpl>;

class IntegratorImpl : public torch::nn::Cloneable<IntegratorImpl> {
public:
//! Create and register an `Integrator` module
/*!
* This function registers the created module as a submodule
* of the given parent module `p`.
*
* \param[in] opts options for creating the `Integrator` module
* \param[in] p parent module for registering the created module
* \param[in] name name for registering the created module
* \return created `Integrator` module
*/
static std::shared_ptr<IntegratorImpl> create(
IntegratorOptions const& opts, torch::nn::Module* p,
std::string const& name = "intg");

int current_redo = 0;

//! options with which this `Integrator` was constructed
Expand Down
Loading