Skip to content
Merged
6 changes: 5 additions & 1 deletion bin/export-model-arch/src/export_model_arch.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h"
#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h"
#include "export_model_arch/json_sp_model_export.dtg.h"
#include "models/inception_v3/inception_v3.h"
#include "models/split_test/split_test.h"
#include "models/transformer/transformer.h"
#include "op-attrs/computation_graph_op_attrs.h"
Expand Down Expand Up @@ -59,6 +60,9 @@ tl::expected<ComputationGraph, std::string>
get_model_computation_graph(std::string const &model_name) {
if (model_name == "transformer") {
return get_default_transformer_computation_graph();
} else if (model_name == "inception_v3") {
return get_inception_v3_computation_graph(
get_default_inception_v3_training_config());
} else if (model_name == "split_test") {
int batch_size = 8;
return get_split_test_computation_graph(batch_size);
Expand Down Expand Up @@ -132,7 +136,7 @@ int main(int argc, char **argv) {
"for preprocessed to help check series-parallel structure"});

std::vector<std::string> model_options = {
"transformer", "split_test", "single_operator"};
"transformer", "inception_v3", "split_test", "single_operator"};
CLIArgumentKey key_model_name = cli_add_positional_argument(
cli,
CLIPositionalArgumentSpec{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h"
#include "models/inception_v3/inception_v3.h"
#include "models/split_test/split_test.h"
#include "models/transformer/transformer.h"
#include "pcg/computation_graph.h"
Expand Down Expand Up @@ -291,6 +292,16 @@ TEST_SUITE(FF_TEST_SUITE) {

CHECK(sp_decomposition.has_value());
}

SUBCASE("inception_v3") {
ComputationGraph cg = get_inception_v3_computation_graph(
get_default_inception_v3_training_config());

std::optional<SeriesParallelDecomposition> sp_decomposition =
get_computation_graph_series_parallel_decomposition(cg);

CHECK(sp_decomposition.has_value());
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions lib/local-execution/src/ops/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static std::optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
auto output = acc.get_tensor<Permissions::WO>(OUTPUT);
auto inputs = acc.get_variadic_tensor<Permissions::RO>(INPUTS);

assert(attrs.num_inputs <= MAX_NUM_INPUTS);
assert(inputs.size() <= MAX_NUM_INPUTS);

return profile(forward_kernel,
profiling,
Expand All @@ -68,7 +68,7 @@ static std::optional<float>
auto input_grads = acc.get_variadic_tensor_grad<Permissions::RW>(INPUTS);
auto output_grad = acc.get_tensor_grad<Permissions::RO>(OUTPUT);

assert(attrs.num_inputs <= MAX_NUM_INPUTS);
assert(input_grads.size() <= MAX_NUM_INPUTS);

return profile(backward_kernel,
profiling,
Expand Down
23 changes: 23 additions & 0 deletions lib/models/include/models/inception_v3/inception_v3.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3
#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3

#include "models/inception_v3/inception_v3_config.dtg.h"
#include "pcg/computation_graph.dtg.h"

namespace FlexFlow {

/**
* @brief Get the default training config from https://arxiv.org/abs/1512.00567.
*/
InceptionV3Config get_default_inception_v3_training_config();

/**
* @brief Get a computation graph for Inception-v3 as described in
* https://arxiv.org/abs/1512.00567.
*/
ComputationGraph
get_inception_v3_computation_graph(InceptionV3Config const &config);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
namespace = "FlexFlow"
name = "InceptionV3Config"

features = [
"eq",
"ord",
"hash",
"json",
"rapidcheck",
"fmt",
]

[[fields]]
name = "num_classes"
type = "int"

[[fields]]
name = "batch_size"
type = "int"

[[fields]]
name = "aux_logits"
type = "bool"
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
namespace = "FlexFlow"
name = "InceptionV3Output"
features = [
"eq",
"ord",
"hash",
"fmt",
]

includes = [
"pcg/tensor_guid_t.dtg.h",
"<optional>",
]

src_includes = [
"utils/fmt/optional.h",
]

[[fields]]
name = "standard_logits"
type = "::FlexFlow::tensor_guid_t"

[[fields]]
name = "aux_logits"
type = "std::optional<::FlexFlow::tensor_guid_t>"
Loading