Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/models/include/models/candle_uno/candle_uno.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ CandleUnoConfig get_default_candle_uno_config();
* this model.
*
* @param CandleUnoConfig The config of the Candle Uno model.
* @return ComputationGraph The PCG of a Transformer model.
* @return ComputationGraph The computation graph of a Candle Uno model.
*/
ComputationGraph get_candle_uno_computation_graph(CandleUnoConfig const &);

Expand Down
92 changes: 92 additions & 0 deletions lib/models/include/models/simvp/simvp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SIMVP_H
#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SIMVP_H

#include "pcg/computation_graph_builder.h"
#include "simvp_config.dtg.h"

namespace FlexFlow {

// Helper functions to construct the SimVP model

/**
* @brief Get the default configs of SimVP model.
*/
SimVPConfig get_default_simvp_config();

// Refer to
// https://github.com/chengtan9907/OpenSTL/blob/b658dab3da427c8750c8595316e7ae9d70b818df/openstl/models/simvp_model.py#L51
std::vector<bool> create_simvp_samplings(size_t N_S, bool reverse = false);

// Refer to
// https://github.com/chengtan9907/OpenSTL/blob/b658dab3da427c8750c8595316e7ae9d70b818df/openstl/modules/simvp_modules.py#L57
tensor_guid_t create_simvp_convsc(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &input,
size_t in_dim,
size_t out_dim,
int kernel_size = 3,
bool downsampling = false,
bool upsampling = false);

// Refer to
// https://github.com/chengtan9907/OpenSTL/blob/b658dab3da427c8750c8595316e7ae9d70b818df/openstl/models/simvp_model.py#L150
tensor_guid_t create_simvp_gsta_meta_block(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &input,
int in_channels,
int out_channels,
float mlp_ratio = 8.0,
float drop = 0.0,
float drop_path = 0.0);

// Refer to
// https://github.com/chengtan9907/OpenSTL/blob/b658dab3da427c8750c8595316e7ae9d70b818df/openstl/modules/simvp_modules.py#L181
tensor_guid_t create_simvp_ga_sub_block(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &input,
int dim,
int kernel_size = 21,
float mlp_ratio = 4.0,
float drop = 0.0,
float drop_path = 0.1,
float init_value = 1e-2);

// Refer to
// https://github.com/chengtan9907/OpenSTL/blob/b658dab3da427c8750c8595316e7ae9d70b818df/openstl/models/simvp_model.py#L57
std::pair<tensor_guid_t, tensor_guid_t>
create_simvp_encoder(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &input);

// Refer to
// https://github.com/chengtan9907/OpenSTL/blob/b658dab3da427c8750c8595316e7ae9d70b818df/openstl/models/simvp_model.py#L100
tensor_guid_t create_simvp_middle_net(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &embed,
int channel_in,
int channel_hid,
float mlp_ratio = 4.0,
float drop = 0.0,
float drop_path = 0.1);

// Refer to
// https://github.com/chengtan9907/OpenSTL/blob/b658dab3da427c8750c8595316e7ae9d70b818df/openstl/models/simvp_model.py#L78
tensor_guid_t create_simvp_decoder(ComputationGraphBuilder &cgb,
SimVPConfig const &config,
tensor_guid_t const &hid,
tensor_guid_t const &skip);

/**
* @brief Get the SimVP computation graph.
*
* @details Refered OpenSTL implementation at
* https://github.com/chengtan9907/OpenSTL/blob/b658dab3da427c8750c8595316e7ae9d70b818df/openstl/models/simvp_model.py#L9
*
* @param SimVPConfig The config of the SimVP model.
* @return ComputationGraph The computation graph of a SimVP model.
*/
ComputationGraph get_simvp_computation_graph(SimVPConfig const &config);

} // namespace FlexFlow

#endif
74 changes: 74 additions & 0 deletions lib/models/include/models/simvp/simvp_config.struct.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
namespace = "FlexFlow"
name = "SimVPConfig"

features = [
"eq",
"ord",
"hash",
"json",
"rapidcheck",
"fmt",
]

includes = [
"<vector>",
"<map>",
"<string>",
"models/simvp/simvp_model_type.dtg.h",
"utils/nonnegative_int/nonnegative_int.h",
]

src_includes = [
"utils/fmt/vector.h",
"utils/fmt/map.h",
"utils/hash/vector.h",
"utils/hash/map.h",
]

[[fields]]
name = "batch_size"
type = "::FlexFlow::nonnegative_int"

[[fields]]
name = "hid_S"
type = "::FlexFlow::nonnegative_int"

[[fields]]
name = "hid_T"
type = "::FlexFlow::nonnegative_int"

[[fields]]
name = "N_S"
type = "::FlexFlow::nonnegative_int"

[[fields]]
name = "N_T"
type = "::FlexFlow::nonnegative_int"

[[fields]]
name = "model_type"
type = "FlexFlow::SimVPModelType"

[[fields]]
name = "mlp_ratio"
type = "float"

[[fields]]
name = "drop"
type = "float"

[[fields]]
name = "drop_path"
type = "float"

[[fields]]
name = "spatio_kernel_enc"
type = "::FlexFlow::nonnegative_int"

[[fields]]
name = "spatio_kernel_dec"
type = "::FlexFlow::nonnegative_int"

[[fields]]
name = "in_shape"
type = "std::vector<::FlexFlow::nonnegative_int>"
7 changes: 7 additions & 0 deletions lib/models/include/models/simvp/simvp_model_type.enum.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace = "FlexFlow"
name = "SimVPModelType"

features = ["hash", "json", "rapidcheck", "fmt"]

[[values]]
name = "gSTA"
Loading
Loading