diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index ccc720ed14..98b7a003ce 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -1,6 +1,7 @@ #include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" #include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" #include "op-attrs/computation_graph_op_attrs.h" @@ -59,6 +60,9 @@ tl::expected get_model_computation_graph(std::string const &model_name) { if (model_name == "transformer") { return get_default_transformer_computation_graph(); + } else if (model_name == "inception_v3") { + return get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); } else if (model_name == "split_test") { int batch_size = 8; return get_split_test_computation_graph(batch_size); @@ -132,7 +136,7 @@ int main(int argc, char **argv) { "for preprocessed to help check series-parallel structure"}); std::vector model_options = { - "transformer", "split_test", "single_operator"}; + "transformer", "inception_v3", "split_test", "single_operator"}; CLIArgumentKey key_model_name = cli_add_positional_argument( cli, CLIPositionalArgumentSpec{ diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc index ab537e73de..c9d84a8948 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,5 @@ #include 
"compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" #include "pcg/computation_graph.h" @@ -291,6 +292,16 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(sp_decomposition.has_value()); } + + SUBCASE("inception_v3") { + ComputationGraph cg = get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } } } diff --git a/lib/local-execution/src/ops/concat.cc b/lib/local-execution/src/ops/concat.cc index 35f663b1cd..4c3462e694 100644 --- a/lib/local-execution/src/ops/concat.cc +++ b/lib/local-execution/src/ops/concat.cc @@ -50,7 +50,7 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto inputs = acc.get_variadic_tensor(INPUTS); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + assert(inputs.size() <= MAX_NUM_INPUTS); return profile(forward_kernel, profiling, @@ -68,7 +68,7 @@ static std::optional auto input_grads = acc.get_variadic_tensor_grad(INPUTS); auto output_grad = acc.get_tensor_grad(OUTPUT); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + assert(input_grads.size() <= MAX_NUM_INPUTS); return profile(backward_kernel, profiling, diff --git a/lib/models/include/models/inception_v3/inception_v3.h b/lib/models/include/models/inception_v3/inception_v3.h new file mode 100644 index 0000000000..5c4754e441 --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 + +#include "models/inception_v3/inception_v3_config.dtg.h" +#include "pcg/computation_graph.dtg.h" + +namespace FlexFlow { + +/** + * @brief Get the default training config from 
https://arxiv.org/abs/1512.00567. + */ +InceptionV3Config get_default_inception_v3_training_config(); + +/** + * @brief Get a computation graph for Inception-v3 as described in + * https://arxiv.org/abs/1512.00567. + */ +ComputationGraph + get_inception_v3_computation_graph(InceptionV3Config const &config); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/inception_v3/inception_v3_config.struct.toml b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml new file mode 100644 index 0000000000..a2a75c83bb --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "InceptionV3Config" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "num_classes" +type = "int" + +[[fields]] +name = "batch_size" +type = "int" + +[[fields]] +name = "aux_logits" +type = "bool" diff --git a/lib/models/include/models/inception_v3/inception_v3_output.struct.toml b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml new file mode 100644 index 0000000000..066e6df02b --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "InceptionV3Output" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "standard_logits" +type = "::FlexFlow::tensor_guid_t" + +[[fields]] +name = "aux_logits" +type = "std::optional<::FlexFlow::tensor_guid_t>" diff --git a/lib/models/src/models/inception_v3/inception_v3.cc b/lib/models/src/models/inception_v3/inception_v3.cc new file mode 100644 index 0000000000..f540eae629 --- /dev/null +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -0,0 +1,750 @@ +#include "models/inception_v3/inception_v3.h" +#include "models/inception_v3/inception_v3_output.dtg.h" +#include 
"op-attrs/tensor_shape.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +struct CheckShape { + CheckShape(ComputationGraphBuilder const &cgb, + InceptionV3Config const &config) + : cgb(cgb), config(config) {} + + ComputationGraphBuilder const &cgb; + InceptionV3Config const &config; + + void operator()(tensor_guid_t t, int c, int h, int w) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + size_t_from_int(h), + size_t_from_int(w), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); + } + } + + void operator()(tensor_guid_t t, int c) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); + } + } +}; + +InceptionV3Config get_default_inception_v3_training_config() { + return InceptionV3Config{ + /*num_classes=*/1000, + + // see section 8 of https://arxiv.org/abs/1512.00567 for the source of the + // batch size + /*batch_size=*/32, + + // see section 4 of https://arxiv.org/abs/1512.00567 for a discussion of + // auxiliary logits. 
they are used by default in training + /*aux_logits=*/true, + }; +} + +static tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int filters, + int kernel_size_h, + int kernel_size_w, + int stride_h = 1, + int stride_w = 1, + int padding_h = 0, + int padding_w = 0, + bool use_bias = false) { + tensor_guid_t conv = cgb.conv2d(input, + /*outChannels=*/filters, + /*kernelH=*/kernel_size_h, + /*kernelW=*/kernel_size_w, + /*strideH=*/stride_h, + /*strideW=*/stride_w, + /*paddingH=*/padding_h, + /*paddingW=*/padding_w, + /*activation=*/std::nullopt, + /*groups=*/1, + /*use_bias=*/use_bias); + return cgb.batch_norm(conv, + /*affine=*/true, + /*activation=*/Activation::RELU, + /*eps=*/1e-5, + /*momentum=*/0.1); +} + +static tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int pool_features) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch5x5 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/48, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/2, + /*padding_w=*/2); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + 
/*kernelH=*/3,
+                   /*kernelW=*/3,
+                   /*strideH=*/1,
+                   /*strideW=*/1,
+                   /*paddingH=*/1,
+                   /*paddingW=*/1,
+                   /*type=*/PoolOp::AVG);
+    t = create_conv_block(cgb,
+                          t,
+                          /*filters=*/pool_features,
+                          /*kernel_size_h=*/1,
+                          /*kernel_size_w=*/1);
+    return t;
+  }();
+
+  return cgb.concat({branch1x1, branch5x5, branch3x3dbl, branch_pool},
+                    /*axis=*/1);
+}
+
+static tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb,
+                                               tensor_guid_t const &input) {
+  tensor_guid_t branch3x3 = create_conv_block(cgb,
+                                              input,
+                                              /*filters=*/384,
+                                              /*kernel_size_h=*/3,
+                                              /*kernel_size_w=*/3,
+                                              /*stride_h=*/2,
+                                              /*stride_w=*/2);
+
+  tensor_guid_t branch3x3dbl = [&] {
+    tensor_guid_t t = input;
+    t = create_conv_block(cgb,
+                          t,
+                          /*filters=*/64,
+                          /*kernel_size_h=*/1,
+                          /*kernel_size_w=*/1);
+    t = create_conv_block(cgb,
+                          t,
+                          /*filters=*/96,
+                          /*kernel_size_h=*/3,
+                          /*kernel_size_w=*/3,
+                          /*stride_h=*/1,
+                          /*stride_w=*/1,
+                          /*padding_h=*/1,
+                          /*padding_w=*/1);
+    t = create_conv_block(cgb,
+                          t,
+                          /*filters=*/96,
+                          /*kernel_size_h=*/3,
+                          /*kernel_size_w=*/3,
+                          /*stride_h=*/2,
+                          /*stride_w=*/2);
+    return t;
+  }();
+
+  tensor_guid_t branch_pool = cgb.pool2d(input,
+                                         /*kernelH=*/3,
+                                         /*kernelW=*/3,
+                                         /*strideH=*/2,
+                                         /*strideW=*/2,
+                                         /*paddingH=*/0,
+                                         /*paddingW=*/0,
+                                         /*type=*/PoolOp::MAX);
+
+  return cgb.concat({branch3x3, branch3x3dbl, branch_pool}, /*axis=*/1);
+}
+
+static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb,
+                                               CheckShape const &check_shape,
+                                               tensor_guid_t const &input,
+                                               int channels_7x7) {
+  tensor_guid_t branch1x1 = create_conv_block(cgb,
+                                              input,
+                                              /*filters=*/192,
+                                              /*kernel_size_h=*/1,
+                                              /*kernel_size_w=*/1);
+  check_shape(branch1x1, 192, 17, 17);
+
+  tensor_guid_t branch7x7 = [&] {
+    tensor_guid_t t = input;
+    t = create_conv_block(cgb,
+                          t,
+                          /*filters=*/channels_7x7,
+                          /*kernel_size_h=*/1,
+                          /*kernel_size_w=*/1);
+    t = create_conv_block(cgb,
+                          t,
+                          /*filters=*/channels_7x7,
+                          /*kernel_size_h=*/1,
+                          /*kernel_size_w=*/7, +
/*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + return t; + }(); + check_shape(branch7x7, 192, 17, 17); + + tensor_guid_t branch7x7dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + return t; + }(); + check_shape(branch7x7dbl, 192, 17, 17); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + check_shape(branch_pool, 192, 17, 17); + + return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + 
/*kernel_size_w=*/1); + t = create_conv_block(cgb, t, 320, 3, 3, 2, 2); + return t; + }(); + + tensor_guid_t branch7x7x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + return t; + }(); + + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + + return cgb.concat({branch3x3, branch7x7x3, branch_pool}, /*axis=*/1); +} + +static tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/320, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/448, + /*kernel_size_h=*/1, 
+ /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + + return cgb.concat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input) { + tensor_guid_t t = input; + + check_shape(t, 3, 299, 299); + + // Conv2d_1a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + check_shape(t, 32, 149, 149); + + // Conv2d_2a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3); + check_shape(t, 32, 147, 147); + + // Conv2d_2b_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + check_shape(t, 64, 147, 147); + + // maxpool1 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + 
check_shape(t, 64, 73, 73); + + // Conv2d_3b_1x1 + t = create_conv_block(cgb, + t, + /*filters=*/80, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(t, 80, 73, 73); + + // Conv2d_4a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3); + check_shape(t, 192, 71, 71); + + // maxpool2 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + check_shape(t, 192, 35, 35); + + return t; +} + +static tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { + // avgpool + tensor_guid_t x = cgb.pool2d(input, + /*kernelH=*/8, + /*kernelW=*/8, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 2048, 1, 1); + + // dropout + x = cgb.dropout(x, + /*rate=*/0.5); + check_shape(x, 2048, 1, 1); + + x = cgb.flat(x, + /*start_dim=*/1); + check_shape(x, 2048); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); + + // softmax (not in pytorch model, but shown in Table 1 on p6 of + // https://arxiv.org/abs/1512.00567) + x = cgb.softmax(x); + check_shape(x, num_classes); + + return x; +} + +static tensor_guid_t create_inception_aux(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { + tensor_guid_t x = input; + check_shape(x, 768, 17, 17); + + x = cgb.pool2d(x, + /*kernelH=*/5, + /*kernelW=*/5, + /*strideH=*/3, + /*strideW=*/3, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 768, 5, 5); + + // conv0 + x = create_conv_block(cgb, + x, + /*filters=*/128, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(x, 128, 5, 5); + + // conv1 + x = create_conv_block(cgb, + x, + /*filters=*/768, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5); + check_shape(x, 
768, 1, 1); + + x = cgb.adaptive_pool2d(x, + /*output_h=*/1, + /*output_w=*/1); + check_shape(x, 768, 1, 1); + + x = cgb.flat(x, + /*start_dim=*/1); + check_shape(x, 768); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); + + return x; +} + +static InceptionV3Output create_inception_v3(ComputationGraphBuilder &cgb, + InceptionV3Config const &config, + tensor_guid_t const &input) { + // NOTE: the shapes for check_shape (as well as the layer names in comments) + // are pulled from + // https://github.com/pytorch/vision/blob/6d7851bd5e2bedc294e40e90532f0e375fcfee04/torchvision/models/inception.py#L103-L155 + CheckShape check_shape = CheckShape{ + /*cgb=*/cgb, + /*config=*/config, + }; + + tensor_guid_t x = create_initial_layers(cgb, check_shape, input); + check_shape(x, 192, 35, 35); + + // Mixed_5b + x = create_inception_module_a(cgb, x, 32); + check_shape(x, 256, 35, 35); + + // Mixed_5c + x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + + // Mixed_5d + x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + + // Mixed_6a + x = create_inception_module_b(cgb, x); + check_shape(x, 768, 17, 17); + + // Mixed_6b + x = create_inception_module_c(cgb, check_shape, x, 128); + check_shape(x, 768, 17, 17); + + // Mixed_6c + x = create_inception_module_c(cgb, check_shape, x, 160); + check_shape(x, 768, 17, 17); + + // Mixed_6d + x = create_inception_module_c(cgb, check_shape, x, 160); + check_shape(x, 768, 17, 17); + + // Mixed_6e + x = create_inception_module_c(cgb, check_shape, x, 192); + check_shape(x, 768, 17, 17); + + std::optional aux; + if (config.aux_logits) { + aux = create_inception_aux(cgb, check_shape, x, config.num_classes); + check_shape(aux.value(), config.num_classes); + } + + // Mixed_7a + x = create_inception_module_d(cgb, x); + check_shape(x, 1280, 8, 8); + + // Mixed_7b + x = create_inception_module_e(cgb, x); + check_shape(x, 2048, 8, 8); + + // Mixed_7c + x = 
create_inception_module_e(cgb, x); + check_shape(x, 2048, 8, 8); + + x = create_final_layers(cgb, check_shape, x, config.num_classes); + check_shape(x, config.num_classes); + + return InceptionV3Output{ + x, + aux, + }; +} + +ComputationGraph + get_inception_v3_computation_graph(InceptionV3Config const &config) { + ComputationGraphBuilder cgb; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + 3, + 299, + 299, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES); + InceptionV3Output output = create_inception_v3(cgb, config, input); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/test/src/models/inception_v3/inception_v3.cc b/lib/models/test/src/models/inception_v3/inception_v3.cc new file mode 100644 index 0000000000..2b0fe82fd6 --- /dev/null +++ b/lib/models/test/src/models/inception_v3/inception_v3.cc @@ -0,0 +1,19 @@ +#include "models/inception_v3/inception_v3.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_inception_v3_computation_graph") { + InceptionV3Config config = get_default_inception_v3_training_config(); + + ComputationGraph result = get_inception_v3_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 522; + CHECK(result_num_layers == correct_num_layers); + } + } +} diff --git a/lib/models/test/src/models/transformer.cc b/lib/models/test/src/models/transformer/transformer.cc similarity index 100% rename from lib/models/test/src/models/transformer.cc rename to lib/models/test/src/models/transformer/transformer.cc diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/concat.h b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h new file mode 100644 index 0000000000..9b9eaf9b93 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h @@ 
-0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H + +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template +FFOrdered concat(FFOrdered const &l, FFOrdered const &r) { + std::vector l_vec = std::vector(l.cbegin(), l.cend()); + std::vector r_vec = std::vector(r.cbegin(), r.cend()); + + std::vector raw_result = concat_vectors(l_vec, r_vec); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +template +FFOrdered concat(std::vector> const &inputs) { + std::vector> vec_inputs = + transform(inputs, [](FFOrdered const &input) { + return std::vector(input.cbegin(), input.cend()); + }); + + std::vector raw_result = concat_vectors(vec_inputs); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h index 34d186e74e..6aa23d40fc 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h @@ -10,7 +10,7 @@ namespace FlexFlow { template struct DimOrdered { - DimOrdered() = delete; + DimOrdered() {} DimOrdered(std::initializer_list const &l) : contents(l.begin(), l.end()) {} @@ -138,6 +138,10 @@ struct DimOrdered { return this->contents.size(); } + size_t empty() const { + return this->contents.empty(); + } + size_t num_dims() const { return this->size(); } diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h new file mode 100644 index 0000000000..79d4929797 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h @@ -0,0 +1,29 @@ +#ifndef 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H + +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "op-attrs/dim_ordered/ff_ordered_of.h" + +namespace FlexFlow { + +template +FFOrdered ff_ordered_from_map(std::map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +template +FFOrdered ff_ordered_from_map(std::unordered_map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index 23b971da6b..e4c0e8e275 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -21,6 +21,13 @@ DimOrdered nonoverloaded_slice(DimOrdered const &d, subvec(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; } +template +FFOrdered slice(FFOrdered const &d, + std::optional const &start, + std::optional const &end) { + return nonoverloaded_slice(d, start, end); +} + template DimOrdered slice(DimOrdered const &d, std::optional const &start, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 8afcbb06b1..f2e95690d1 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -1,15 +1,42 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/ops/core.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include 
"op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &); +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &); + +tl::expected get_output_shape(BatchNormAttrs const &, + TensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, TensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, TensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_gamma_weights_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); + +tl::expected + get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, + ParallelTensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml index bc82f3c743..fdc3bce1fe 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml @@ -10,6 +10,28 @@ features = [ "fmt", ] +includes = [ + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + [[fields]] name = "relu" type = "bool" + +[[fields]] +name = "affine" +type = "bool" + +[[fields]] +name = "eps" +type = "float" + +[[fields]] +name = "momentum" +type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h 
b/lib/op-attrs/include/op-attrs/ops/concat.h index f3ac8494c0..f07f06df85 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -10,10 +10,11 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ConcatAttrs); -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml index 4faa870bc4..fab8132993 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -17,7 +17,3 @@ includes = [ [[fields]] name = "axis" type = "::FlexFlow::ff_dim_t" - -[[fields]] -name = "num_inputs" -type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 676d21c59b..710cbdb44b 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -3,6 +3,7 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -11,8 +12,11 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(FlatAttrs); TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &); +tl::expected + get_output_parallel_dim_degrees(FlatAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_output_shape(FlatAttrs const &, ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml 
b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml index e445535e29..7349e2a8c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -8,4 +8,23 @@ features = [ "rapidcheck", "fmt", ] -fields = [] + +includes = [ + "", + "op-attrs/ff_dim.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "start_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "end_dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 505fdd9f8c..1af22ad022 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -3,6 +3,7 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -10,9 +11,22 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &); +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation); + +tl::expected get_output_shape(Pool2DAttrs const &, + TensorShape const &); + +tl::expected + get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(Pool2DAttrs const &, + ParallelTensorDimDegrees const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml index 56bf682f50..20ca7deabc 100644 --- 
a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -12,6 +12,13 @@ features = [ includes = [ "op-attrs/pool_op.dtg.h", "op-attrs/activation.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] @@ -44,4 +51,4 @@ type = "::FlexFlow::PoolOp" [[fields]] name = "activation" -type = "::FlexFlow::Activation" +type = "std::optional<::FlexFlow::Activation>" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml new file mode 100644 index 0000000000..974b27d2a7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "ParallelTensorDimDegrees" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + "op-attrs/dim_ordered/dim_ordered.h", +] + +[[fields]] +name = "sum_degree" +type = "::FlexFlow::SumDegree" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "shard_degrees" +type = "::FlexFlow::FFOrdered" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 8e02e3607b..7a89b4bd78 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #include "op-attrs/parallel_dim.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_dims.dtg.h" #include "op-attrs/tensor_dims.dtg.h" @@ -14,6 +15,8 @@ std::unordered_set replica_dims(ParallelTensorDims const &); /* size_t 
get_volume(ParallelTensorDims const &); */ size_t num_shard_dims(ParallelTensorDims const &); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &); + int total_replica_degree(ParallelTensorDims const &); int total_shard_degree(ParallelTensorDims const &); int total_parallel_degree(ParallelTensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 76356b39d4..806a5f0de7 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,6 +1,7 @@ #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/tensor_shape.h" @@ -17,12 +18,17 @@ FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &); std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &); + ParallelTensorShape lift_to_parallel(TensorShape const &); ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &, + ParallelTensorDimDegrees const &); std::unordered_set replica_dims(ParallelTensorShape const &); diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc new file mode 100644 index 0000000000..cb29f708a3 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/concat.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 
0000000000..2de88f38c8 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc index c7febde1d6..21efc26466 100644 --- a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -1,5 +1,6 @@ #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" @@ -22,8 +23,8 @@ std::vector return std::vector{IncomingTensorRole::INPUT, IncomingTensorRole::INPUT}; }, - [](BatchNormAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; + [](BatchNormAttrs const &attrs) { + return get_batch_norm_incoming_tensor_roles(attrs); }, [](BroadcastAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; diff --git a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc index d91d1a1eca..0058ee35a2 100644 --- a/lib/op-attrs/src/op-attrs/get_output_shapes.cc +++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc @@ -14,6 +14,7 @@ #include "op-attrs/ops/input.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/replicate.h" #include "op-attrs/ops/weight.h" #include "utils/overload.h" @@ -29,7 +30,7 @@ std::vector get_output_shape(attrs, inputs.at(0), inputs.at(1)))}; }, [&](BatchNormAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](CastAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; @@ -38,7 +39,7 @@ std::vector return 
{throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](ConcatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs)}; + return {throw_if_unexpected(get_output_shape(attrs, inputs))}; }, [&](Conv2DAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; @@ -57,7 +58,7 @@ std::vector return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](FlatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](GatherAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0), inputs.at(1))}; @@ -71,6 +72,9 @@ std::vector [&](LinearAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, + [&](Pool2DAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, [&](ReplicateAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; }, diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index b75c3521c6..f394bb8473 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -1,15 +1,260 @@ #include "op-attrs/ops/batch_norm.h" +#include "op-attrs/dim_ordered/concat.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/any_of.h" +#include "utils/containers/extend.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, - TensorShape const &input_shape) { +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &attrs) { + std::vector result = {IncomingTensorRole::INPUT}; + + if (attrs.affine) { + extend(result, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return result; +} + 
+static std::optional + check_input_shape(BatchNormAttrs const &, TensorShape const &input_shape) { + if (num_dims(input_shape) < 2) { + return fmt::format( + "BatchNormAttrs expected input dims >= 2, but received input shape {}", + input_shape); + } + + if (input_shape.data_type != DataType::FLOAT) { + return fmt::format("BatchNormAttrs currently only supports data_type = " + "FLOAT, but received input data_type {}. " + "If you need this feature, please create an issue.", + input_shape.data_type); + } + + return std::nullopt; +} + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + return input_shape; } -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No gamma weights exist for attrs.affine = false"); + } + + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + + return TensorShape{ + TensorDims{FFOrdered{ + num_channels, + }}, + DataType::FLOAT, + }; +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + + if (!attrs.affine) { + return tl::unexpected("No beta weights exist for attrs.affine = false"); + } + + return get_gamma_weights_shape(attrs, input_shape); +} + +static std::optional + check_input_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.shard_degrees.size() < 2) { + return fmt::format("BatchNormAttrs expected input dims >= 2, but received " 
+ "input degrees {}", + input_degrees); + } + + if (input_degrees.sum_degree != SumDegree{1}) { + return fmt::format("Expected sum degree 1, but receieved sum degree {}", + input_degrees.sum_degree); + } + + if (input_degrees.discard_copy_degree != DiscardCopyDegree{1}) { + return fmt::format( + "Expected discard copy degree 1, but receieved discard copy degree {}", + input_degrees.discard_copy_degree); + } + + FFOrdered non_channel_degrees = + concat(slice(input_degrees.shard_degrees, ff_dim_t{0}, ff_dim_t{1}), + slice(input_degrees.shard_degrees, ff_dim_t{2}, std::nullopt)); + + if (any_of(non_channel_degrees, [](int degree) { return degree != 1; })) { + return fmt::format("Expected parallel degree of all non-channel dimensions " + "to be 1, but received input with degrees {}", + input_degrees); + } + + return std::nullopt; +} + +tl::expected + get_output_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + return input_degrees; +} + +tl::expected + get_gamma_weights_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No gamma weights exist for attrs.affine = false"); + } + + ff_dim_t channel_dim = ff_dim_t{1}; + + return ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{input_degrees.shard_degrees.at(channel_dim)}, + }; +} + +tl::expected + get_beta_weights_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if 
(maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No beta weights exist for attrs.affine = false"); + } + + return get_gamma_weights_parallel_dim_degrees(attrs, input_degrees); +} + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected returned = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = + get_gamma_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_gamma_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = + get_beta_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + 
get_beta_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 02fee70bea..74295f279e 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -1,24 +1,129 @@ #include "op-attrs/ops/concat.h" +#include "op-attrs/dim_ordered/enumerate.h" +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/all_of.h" +#include "utils/containers/are_all_same.h" +#include "utils/containers/require_all_same1.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/fmt/map.h" namespace FlexFlow { -/* bool ConcatAttrs::is_valid( */ -/* std::vector const &input) const { */ -/* bool valid = true; */ -/* for (auto p : input) { */ -/* valid &= p.is_valid(); */ -/* } */ -/* return valid; */ -/* } */ - -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + auto get_non_axis_dims = [&](TensorShape const &s) { + std::map dim_sizes = enumerate(ff_ordered(s.dims)); + dim_sizes.erase(attrs.axis); + return dim_sizes; + }; + + if (inputs.size() <= 1) { + return tl::unexpected(fmt::format("get_output_shape for Concat expected 2 " + "or more input, but receieved {}", + inputs)); + } + + if (attrs.axis.value < 0) { + return tl::unexpected(fmt::format("ConcatAttrs requires axis >= 0")); + } + + if (!are_all_same(transform( + inputs, [](TensorShape const &s) { return num_dims(s); }))) { + return tl::unexpected( + fmt::format("get_output_shape 
for Concat expected all inputs to have " + "the same number of dimensions, but receieved {}", + inputs)); + } + + std::map non_axis_dims = ({ + tl::expected, std::string> returned = + require_all_same1(transform(inputs, get_non_axis_dims)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + std::vector axis_dim_sizes = transform( + inputs, [&](TensorShape const &s) { return dim_at_idx(s, attrs.axis); }); + + size_t output_axis_dim_size = sum(axis_dim_sizes); + + non_axis_dims.insert({attrs.axis, output_axis_dim_size}); + + DataType datatype = ({ + tl::expected returned = require_all_same1( + transform(inputs, [](TensorShape const &s) { return s.data_type; })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return TensorShape{ + TensorDims{ + ff_ordered_from_map(non_axis_dims), + }, + datatype, + }; } -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + TensorShape unpar = ({ + tl::expected returned = + get_output_shape(attrs, transform(inputs, get_reduced_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + SumDegree sum_degree = ({ + tl::expected returned = + require_all_same1(transform(inputs, get_sum_degree)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + SumDegree{returned.value()}; + }); + + DiscardCopyDegree discard_copy_degree = ({ + tl::expected returned = + require_all_same1(transform(inputs, get_discard_copy_degree)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + DiscardCopyDegree{returned.value()}; + }); + + if (!all_of(inputs, [&](ParallelTensorShape const &s) { + return shard_dim_at_idx(s, attrs.axis).degree == 1; + })) { + return tl::unexpected(fmt::format( + 
"get_output_shape for Concat expected input tensors to have parallel " + "degree 1 in the concat axis dimension, but received {}", + inputs)); + } + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + require_all_same1(transform(inputs, [](ParallelTensorShape const &s) { + return get_parallel_degrees(s); + })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index f77daf451f..eac756cc15 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -54,11 +54,11 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, Conv2DInputShape input = parse_input_shape(raw_input_shape); size_t out_height = - (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / - attrs.stride_h; + (input.height + (2 * attrs.padding_h) - attrs.kernel_h) / attrs.stride_h + + 1; size_t out_width = - (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / - attrs.stride_w; + (input.width + (2 * attrs.padding_w) - attrs.kernel_w) / attrs.stride_w + + 1; assert(attrs.out_channels > 0); diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index 5d318207ee..e9833d5e3f 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -1,57 +1,85 @@ #include "op-attrs/ops/flat.h" +#include "op-attrs/dim_ordered/concat.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "utils/containers/any_of.h" +#include "utils/containers/product.h" #include namespace FlexFlow { -TensorShape get_output_shape(FlatAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); +TensorShape get_output_shape(FlatAttrs const &attrs, + TensorShape const &input_shape) { + 
FFOrdered leading_dims = + slice(ff_ordered(input_shape.dims), ff_dim_t{0}, attrs.start_dim); + FFOrdered flattened_dims = + slice(ff_ordered(input_shape.dims), attrs.start_dim, attrs.end_dim); + FFOrdered trailing_dims = + slice(ff_ordered(input_shape.dims), attrs.end_dim, std::nullopt); + + if (flattened_dims.empty()) { + return input_shape; + } + + return TensorShape{ + TensorDims{ + concat(std::vector{ + leading_dims, + {product(flattened_dims)}, + trailing_dims, + }), + }, + input_shape.data_type, + }; } -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_parallel_dim_degrees( + FlatAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + FFOrdered flattened_dim_degrees = + slice(input_degrees.shard_degrees, attrs.start_dim, attrs.end_dim); + + if (flattened_dim_degrees.empty()) { + return input_degrees; + } + + if (any_of(flattened_dim_degrees, [](int degree) { return degree != 1; })) { + return tl::unexpected( + fmt::format("get_output_parallel_dim_degrees for {} expected all shard " + "degrees of flattened dimensions to be 1, but received {}", + attrs, + input_degrees)); + } + + return ParallelTensorDimDegrees{ + /*sum_degree=*/input_degrees.sum_degree, + /*discard_copy_degree=*/input_degrees.discard_copy_degree, + /*shard_degrees=*/ + concat(std::vector{ + slice(input_degrees.shard_degrees, ff_dim_t{0}, attrs.start_dim), + {product(flattened_dim_degrees)}, + slice(input_degrees.shard_degrees, attrs.end_dim, std::nullopt), + }), + }; } -// namespace Input { -// constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, -// REPLICA = 4; -// } -// -// namespace Output { -// constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; -// } -// -/* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ - -/* bool is_valid = true; */ -/* is_valid &= 
input.is_valid(); */ -/* is_valid &= output_shape.is_valid(); */ -/* is_valid &= (input.at(Input::WIDTH).degree == 1); */ - -/* return is_valid; */ -/* } */ - -/* ParallelTensorShape FlatAttrs::calculate_output_shape(ParallelTensorShape - * const &input) const { */ -/* assert (input.num_dims() == Input::NUMDIM); */ -/* ParallelTensorShape output_dims; */ -/* output_dims.data_type = input.data_type; */ - -/* output_dims.at(Output::REPLICA) = input.at(Input::REPLICA); */ -/* output_dims.at(Output::SAMPLE) = input.at(Input::SAMPLE); */ - -/* output_dims.at(Output::CHANNEL).degree = input.at(Input::CHANNEL).degree; - */ -/* assert (input.at(Input::HEIGHT).degree == 1); */ -/* assert (input.at(Input::WIDTH).degree == 1); */ - -/* output_dims.at(Output::CHANNEL).size = input.at(Input::CHANNEL).size * - * input.at(Input::HEIGHT).size * input.at(Input::WIDTH).size; */ -/* output_dims.at(Output::CHANNEL).parallel_idx = - * input.at(Input::CHANNEL).parallel_idx; */ - -/* return output_dims; */ -/* } */ +tl::expected + get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = get_output_shape(attrs, get_reduced_shape(input_shape)); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index d3c00efbb9..0dd9ac7a17 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -112,7 +112,7 @@ static std::optional if (get_discard_copy_degree(input_shape) != 1) { return fmt::format( - "Expected discard copy degree 1, but received discartd copy degree {}", + "Expected discard copy degree 1, but received 
discard copy degree {}", get_discard_copy_degree(input_shape)); } diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index e1917efd89..95bcd8b336 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,62 +1,184 @@ #include "op-attrs/ops/pool_2d.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation) { + // AdaptivePool2D semantics pulled from + // https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993 -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} + if (num_dims(input_dims) != 4) { + return tl::unexpected( + fmt::format("make_adaptive_pool2d_attrs expected input tensor to " + "have 4 dims, but received dims {}", + input_dims)); + } -} // namespace FlexFlow + size_t num_samples = dim_at_idx(input_dims, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_dims, ff_dim_t{1}); + size_t input_h = dim_at_idx(input_dims, ff_dim_t{2}); + size_t input_w = dim_at_idx(input_dims, ff_dim_t{3}); -/* -#include "op-attrs/ops/pool_2d.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" + if (input_h % output_h != 0) { + return tl::unexpected(fmt::format( + "Currently make_adaptive_pool2d_attrs only supports input_h % output_h " + "== 0, but received input_h={} and output_h={} (input_dims={}). 
If you " + "need input_h % output_h != 0 supported, please create an issue.", + input_h, + output_h, + input_dims)); + } -namespace FlexFlow { + if (input_w % output_w != 0) { + return tl::unexpected(fmt::format( + "Currently make_adaptive_pool2d_attrs only supports input_w % output_w " + "== 0, but received input_w={} and output_w={} (input_dims={}). If you " + "need input_w % output_w != 0 supported, please create an issue.", + input_w, + output_w, + input_dims)); + } -namespace Input { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; + // Note that for some reason the stack overflow post linked above states that + // `kernel_size = ind - (outd-1)*stride`, but some simplification yields + // `kernel_size` = `ind - (outd - 1)*stride` + // = `ind - (outd - 1) * (ind / outd)` + // = `ind - ind + (ind /outd)` + // = `ind / outd` + // = `stride` -namespace Output { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; + int kernel_h = input_h / output_h; + int kernel_w = input_w / output_w; -bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { - ParallelTensorShape output_shape = this->calculate_output_shape(input); + int stride_h = kernel_h; + int stride_w = kernel_w; - return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); -} + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + + TensorShape expected_ouput_shape = TensorShape{ + TensorDims{FFOrdered{ + num_samples, + num_channels, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, + }; -static std::vector - construct_mappings(ParallelTensorShape const &input_shape) { - auto const outputMappings = construct_output_parallel_dims({ - {Input::REPLICA, MappingOperation::PARTITION, 
Output::REPLICA}, - {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, - {Input::CHANNEL, MappingOperation::PARTITION, Output::CHANNEL}, - {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, - {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}, + TensorShape output_shape = ({ + tl::expected result = + get_output_shape(attrs, TensorShape{input_dims, DataType::FLOAT}); + if (!result.has_value()) { + return tl::unexpected(result.error()); + } + result.value(); }); - return outputMappings; + if (output_shape != expected_ouput_shape) { + return tl::unexpected( + fmt::format("Result of make_adaptive_pool_2d (i.e., {}) should produce " + "expected output shape {}, but produced {}. This is a bug " + "in FlexFlow, Please create an issue.", + attrs, + expected_ouput_shape, + output_shape)); + } + + return attrs; } -static ParallelDimMappingSolution - solve_mappings(ParallelTensorShape const &input) { - return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); +tl::expected + get_output_shape(Pool2DAttrs const &attrs, TensorShape const &input_shape) { + if (num_dims(input_shape) != 4) { + return tl::unexpected( + fmt::format("get_output_shape for Pool2DAttrs expected input tensor to " + "have 4 dims, but received shape {}", + input_shape)); + } + + size_t num_samples = dim_at_idx(input_shape, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + size_t input_height = dim_at_idx(input_shape, ff_dim_t{2}); + size_t input_width = dim_at_idx(input_shape, ff_dim_t{3}); + + size_t output_height = + (input_height + 2 * attrs.padding_h - attrs.kernel_h) / attrs.stride_h + + 1; + + size_t output_width = + (input_width + 2 * attrs.padding_w - attrs.kernel_w) / attrs.stride_w + 1; + + return TensorShape{TensorDims{FFOrdered{ + num_samples, + num_channels, + output_height, + output_width, + }}, + input_shape.data_type}; } -ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape -const &input) 
const { return solve_mappings(input).output_shapes.at(0); +tl::expected + get_output_shape(Pool2DAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected result_degrees = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!result_degrees.has_value()) { + return tl::unexpected(result_degrees.error()); + } + result_degrees.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_output_parallel_dim_degrees( + Pool2DAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.sum_degree.value > 1) { + if (attrs.pool_type == PoolOp::MAX) { + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with PoolOp::MAX " + "expected input sum degree == 1, but received {}", + input_degrees)); + } else if (attrs.activation.has_value()) { + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with activation={} " + "expected input sum degree == 1, but received {}", + attrs.activation.value(), + input_degrees)); + } + } + + return input_degrees; } } // namespace FlexFlow -*/ diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 4bce5449f4..61062b84b0 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -29,6 +29,14 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { return dims.shard_dims.size(); } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) { + return ParallelTensorDimDegrees{ + d.replica_dims.sum_degree, + d.replica_dims.discard_copy_degree, + 
ff_ordered_shard_degrees(d), + }; +} + int total_replica_degree(ParallelTensorDims const &dims) { return dims.replica_dims.discard_copy_degree.value * dims.replica_dims.sum_degree.value; diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 10bf5027a4..3cd0f47a5d 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -59,6 +59,10 @@ std::optional } } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &s) { + return get_parallel_degrees(s.dims); +} + ParallelTensorShape lift_to_parallel(TensorShape const &s) { return ParallelTensorShape{lift_to_parallel(s.dims), s.data_type}; } @@ -75,6 +79,15 @@ ParallelTensorShape }; } +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &s, + ParallelTensorDimDegrees const &degrees) { + return lift_to_parallel_with_degrees(s, + degrees.sum_degree, + degrees.discard_copy_degree, + degrees.shard_degrees); +} + TensorShape require_not_parallel(ParallelTensorShape const &s) { int total_degree = get_total_parallel_degree(s); if (total_degree != 1) { diff --git a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc index 42ea07e6b5..84f1861f0b 100644 --- a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc @@ -5,8 +5,12 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ComputationGraphOpAttrs to/from json") { - ComputationGraphOpAttrs correct = - ComputationGraphOpAttrs{BatchNormAttrs{true}}; + ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, + }}; nlohmann::json j = correct; auto result = j.get(); diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc
b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc new file mode 100644 index 0000000000..2ac641cfc2 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1,66 @@ +#include "op-attrs/dim_ordered/concat.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("concat(FFOrdered, FFOrdered)") { + SUBCASE("inputs have elements") { + FFOrdered l_input = FFOrdered{1, 3, 1}; + FFOrdered r_input = FFOrdered{2, 1}; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = {1, 3, 1, 2, 1}; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + FFOrdered l_input = FFOrdered{}; + FFOrdered r_input = FFOrdered{}; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } + + TEST_CASE("concat(std::vector>)") { + SUBCASE("inputs have elements") { + std::vector> input = { + {1}, + {2, 1}, + {1}, + }; + + FFOrdered result = concat(input); + FFOrdered correct = { + 1, + 2, + 1, + 1, + }; + + CHECK(result == correct); + } + + SUBCASE("no inputs") { + std::vector> input = {}; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + std::vector> input = {{}, {}, {}}; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 0000000000..7bc1695e5c --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1,66 @@ +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("ff_ordered_from_map", + T, + std::map, + std::unordered_map) { + SUBCASE("input is empty") { + T m = {}; + + FFOrdered result = ff_ordered_from_map(m); + 
FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is missing keys") { + SUBCASE("missing key is in middle") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 2}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("missing key is 0 idx") { + T m = { + {ff_dim_t{1}, 2}, + {ff_dim_t{2}, 7}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + } + + SUBCASE("input has negative keys") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{-1}, 2}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("input is valid") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{2}, 2}, + {ff_dim_t{3}, 7}, + }; + + FFOrdered result = ff_ordered_from_map(m); + FFOrdered correct = {4, 5, 2, 7}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc index 60dedfe70a..33cc00c6a1 100644 --- a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc @@ -9,7 +9,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Concat") { int num_incoming = 4; ComputationGraphOpAttrs attrs = - ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}, num_incoming}}; + ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}}}; std::vector result = get_incoming_tensor_roles(attrs, num_incoming); diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc new file mode 100644 index 0000000000..4196394d00 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc @@ -0,0 +1,404 @@ +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + 
TEST_CASE("get_batch_norm_incoming_tensor_roles(BatchNormAttrs)") { + auto make_attrs = [](bool affine) { + return BatchNormAttrs{ + /*relu=*/false, + /*affine=*/affine, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + }; + + SUBCASE("affine = true") { + BatchNormAttrs attrs = make_attrs(/*affine=*/true); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + BatchNormAttrs attrs = make_attrs(/*affine=*/false); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("shape inference (BatchNorm)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + 18, + }}, + DataType::FLOAT, + }; + + TensorShape output = input; + + TensorShape gamma = TensorShape{ + TensorDims{FFOrdered{ + 14, + }}, + DataType::FLOAT, + }; + + TensorShape beta = gamma; + + SUBCASE("get_output_shape(BatchNormAttrs, TensorShape)") { + tl::expected result = + get_output_shape(attrs_affine_true, input); + tl::expected correct = output; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_shape(attrs_affine_true, input); + tl::expected correct = gamma; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + 
CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_shape(attrs_affine_true, input); + tl::expected correct = beta; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel dim degree inference (BatchNormAttrs)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + SUBCASE("partition parallelism (in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + degree, + 1, + 1, + }, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + tl::expected result = + get_output_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = + optional_from_expected(get_gamma_weights_parallel_dim_degrees( + attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + 
"ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = + optional_from_expected(get_beta_weights_parallel_dim_degrees( + attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + SUBCASE("partition parallelism (not in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, degree, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("sum parallelism") { + SumDegree sum_degree = SumDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + sum_degree, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + 
get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("discard copy parallelism") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + discard_copy_degree, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel shape inference (BatchNormAttrs)") { + // since most of the edge cases are already tested in the above test cases + // (i.e., shape 
inference and parallel degree inference) + // here we just do a basic check that they compose + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/true, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{18, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("get_output_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_gamma_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_beta_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc index df436da66c..3d86576279 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc @@ -5,7 +5,12 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { - BatchNormAttrs correct = BatchNormAttrs{true}; + BatchNormAttrs correct = BatchNormAttrs{ + 
/*relu=*/false, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, + }; nlohmann::json j = correct; BatchNormAttrs result = j.get(); diff --git a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc index 152df09eca..7abb98f3e3 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc @@ -73,8 +73,8 @@ TEST_SUITE(FF_TEST_SUITE) { }; size_t num_samples = 7; - size_t input_channels = 6; - size_t input_height = 10; + size_t input_channels = 4; + size_t input_height = 11; size_t input_width = 15; TensorShape input = TensorShape{ @@ -87,8 +87,8 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - size_t output_height = 3; - size_t output_width = 6; + size_t output_height = 6; + size_t output_width = 8; TensorShape output = TensorShape{ TensorDims{FFOrdered{ diff --git a/lib/op-attrs/test/src/op-attrs/ops/flat.cc b/lib/op-attrs/test/src/op-attrs/ops/flat.cc new file mode 100644 index 0000000000..d81ab95c35 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/flat.cc @@ -0,0 +1,244 @@ +#include "op-attrs/ops/flat.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(FlatAttrs, TensorShape)") { + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + SUBCASE("flatten all dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4 * 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten trailing dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + 
TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten leading dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten middle dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4 * 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim == end_dim)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = input_shape; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim > end_dim)") { + // start_dim (2) is greater than end_dim (1); the shape is expected to come back unchanged + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{1}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = input_shape; + + CHECK(result == correct); + } + } + + TEST_CASE( + "get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees)") { + FlatAttrs attrs = FlatAttrs{/*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}}; + + SUBCASE("allows shard parallelism in non-flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 1, 3}, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 3}, + }; + + CHECK(result == correct); + } + + SUBCASE("does
not allow shard parallelism in flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 2, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("allows sum parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_shape(FlatAttrs, ParallelTensorShape)") { + // since most of the edge cases are already tested in + // get_output_shape(FlatAttrs, TensorShape) and + // get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees), + // here we just do a basic check that they compose + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8, 1}, + ShardParallelDim{6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + tl::expected result = + get_output_shape(attrs, input_shape); + tl::expected 
correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8 * 6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc new file mode 100644 index 0000000000..0c14c0fc2a --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc @@ -0,0 +1,400 @@ +#include "op-attrs/ops/pool_2d.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("make_adaptive_pool2d") { + size_t input_n = 10; + size_t input_c = 11; + size_t input_h = 15; + size_t input_w = 20; + Activation activation = Activation::RELU; + PoolOp op = PoolOp::AVG; + + TensorDims input_dims = + TensorDims{FFOrdered{input_n, input_c, input_h, input_w}}; + + SUBCASE("input_h divisible by output_h && input_w divisible by output_w") { + int output_h = 5; + int output_w = 2; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/10, + /*stride_h=*/3, + /*stride_w=*/10, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); + tl::expected correct = correct_attrs; + + CHECK(result == correct); + } + + SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = + get_output_shape(correct_attrs, input_shape); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + input_n, + input_c, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + 
DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + SUBCASE("input_h not divisible by output_h") { + int output_h = 6; + int output_w = 2; + + std::optional result = + optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_w not divisible by output_w") { + int output_h = 5; + int output_w = 3; + + std::optional result = + optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_h == output_h and input_w == output_w") { + int output_h = input_h; + int output_w = input_w; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/1, + /*kernel_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); + tl::expected correct = correct_attrs; + + CHECK(result == correct); + } + + SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = + get_output_shape(correct_attrs, input_shape); + tl::expected correct = input_shape; + + CHECK(result == correct); + } + } + } + + TEST_CASE("get_output_shape(Pool2DAttrs, TensorShape)") { + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, + }; + + SUBCASE("fails on non-4d inputs") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + 14, + }}, + DataType::FLOAT, + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, 
input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("4d input") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{11, 13, 12, 6}}, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{11, 13, 6, 4}}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_parallel_dim_degrees(Pool2DAttrs, " + "ParallelTensorDimDegrees)") { + auto make_attrs = [](PoolOp pool_type, + std::optional const &activation) { + return Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + }; + + SUBCASE("allows data parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows arbitrary input sharding parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 2, + 5, + 6, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{3}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + 
CHECK(result == correct); + } + + SUBCASE("sum parallelism") { + SUBCASE("without activation") { + SUBCASE("PoolOp::MAX does not allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + std::optional result = + optional_from_expected( + get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("PoolOp::AVG does allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::AVG, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + } + + SUBCASE("with activation does not allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::AVG, /*activation=*/Activation::RELU); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("get_output_shape(Pool2DAttrs, ParallelTensorShape)") { + // this function is mostly covered by the tests above, so we + // just do a single test to make sure it works/exists + + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, + }; + + SUBCASE("valid parallelism") { + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + 
ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{12, 3}, + ShardParallelDim{6, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + // expected height = (12 + 2*1 - 3)/2 + 1 = 6 and width = (6 + 2*1 - 2)/2 + 1 = 4; shard and replica degrees carry over unchanged + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{6, 3}, + ShardParallelDim{4, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("invalid parallelism") { + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 1}, + ShardParallelDim{16, 1}, + ShardParallelDim{12, 1}, + ShardParallelDim{6, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 11e591545d..45cde0de57 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -137,6 +137,13 @@ struct ComputationGraphBuilder { PoolOp type = PoolOp::MAX, std::optional const &activation = std::nullopt, std::optional const &name = std::nullopt); + tensor_guid_t adaptive_pool2d( + tensor_guid_t const &input, + int output_h, + int output_w, + PoolOp type = PoolOp::MAX, + std::optional const &activation = std::nullopt, + std::optional const &name = std::nullopt); tensor_guid_t layer_norm(tensor_guid_t const &input, std::vector const &axes, @@ -145,7 +152,10 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); tensor_guid_t batch_norm(tensor_guid_t const &input, - bool relu = true, + bool affine, + std::optional const
&activation, + float eps, + std::optional const &momentum, std::optional const &name = std::nullopt); tensor_guid_t batch_matmul(tensor_guid_t const &A, @@ -170,11 +180,9 @@ struct ComputationGraphBuilder { DataType dtype, std::optional const &name = std::nullopt); // Add a concat layer - tensor_guid_t - concat(int n, - std::vector const &tensors, - int axis, - std::optional const &maybe_name = std::nullopt); + tensor_guid_t concat(std::vector const &tensors, + int axis, + std::optional const &name = std::nullopt); // Add a mean layer tensor_guid_t mean(tensor_guid_t const &input, std::vector const &dims, @@ -188,6 +196,8 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); // Add a flat layer tensor_guid_t flat(tensor_guid_t const &input, + int start_dim = 0, + std::optional const &end_dim = std::nullopt, std::optional const &name = std::nullopt); // Add a softmax layer tensor_guid_t softmax(tensor_guid_t const &input, @@ -252,9 +262,9 @@ struct ComputationGraphBuilder { std::vector const &weights, std::vector const &outputs); -private: TensorShape get_shape(tensor_guid_t const &) const; +private: tensor_guid_t broadcast(tensor_guid_t const &, TensorDims const &, std::string const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 3a7f67dcf0..019b120936 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -87,7 +87,10 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t batch_norm(parallel_tensor_guid_t const &input, - bool relu = true, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &name = std::nullopt); parallel_tensor_guid_t diff --git 
a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index a4f61cff98..4a565476bd 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -6,14 +6,17 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/concat.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/dropout.h" #include "op-attrs/ops/element_binary.h" #include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/flat.h" #include "op-attrs/ops/gather.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/tensor_dims.h" @@ -498,21 +501,130 @@ tensor_guid_t ComputationGraphBuilder::gather( return get_only( this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } +tensor_guid_t ComputationGraphBuilder::pool2d( + tensor_guid_t const &x, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { + + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/kernelH, + /*kernel_w=*/kernelW, + /*stride_h=*/strideH, + /*stride_w=*/strideW, + /*padding_h=*/paddingH, + /*padding_w=*/paddingW, + /*pool_type=*/type, + /*activation=*/activation, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} + +tensor_guid_t 
ComputationGraphBuilder::adaptive_pool2d( + tensor_guid_t const &uncasted_input, + int output_h, + int output_w, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { + + TensorDims input_dims = this->get_shape(uncasted_input).dims; + + Pool2DAttrs attrs = throw_if_unexpected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, type, activation)); + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t casted_input = + this->as_type(uncasted_input, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, this->get_shape(casted_input))); + + return get_only(this->add_layer( + layer, {casted_input}, {}, {make_output_attrs(output_shape)})); +} tensor_guid_t ComputationGraphBuilder::batch_norm( tensor_guid_t const &input, - bool relu, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. 
" + "If you need support for additional activation functions, please " + "create an issue.", + activation)); + } + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, + }; + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape input_shape = this->get_shape(input); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); - return get_only( - this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); + std::vector weights; + + if (affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + TensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + TensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } + + return get_only(this->add_layer( + layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -674,6 +786,50 @@ tensor_guid_t ComputationGraphBuilder::dense( layer, {input}, weights, {make_output_attrs(output_shape)})); } +tensor_guid_t ComputationGraphBuilder::concat( + std::vector const &inputs, + int axis, + std::optional const &maybe_name) { + + ConcatAttrs attrs = 
ConcatAttrs{ff_dim_t{axis}}; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + std::vector input_shapes = transform( + inputs, [&](tensor_guid_t const &i) { return this->get_shape(i); }); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shapes)); + + return get_only( + this->add_layer(layer, inputs, {}, {make_output_attrs(output_shape)})); +} + +tensor_guid_t ComputationGraphBuilder::flat( + tensor_guid_t const &input, + int start_dim, + std::optional const &end_dim, + std::optional const &maybe_name) { + int input_num_dims = num_dims(this->get_shape(input)); + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{start_dim}, + /*end_dim=*/ff_dim_t{end_dim.value_or(input_num_dims)}, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} + tensor_guid_t ComputationGraphBuilder::layer_norm( tensor_guid_t const &input, std::vector const &axes, diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 620dc035fc..ce00ea62f4 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -331,18 +331,56 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( parallel_tensor_guid_t const &input, - bool relu, + bool affine, + std::optional const &activation, + float eps, + std::optional const 
&momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. " + "If you need support for additional activation functions, please " + "create an issue.", + activation)); + } + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, + }; std::string name = maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + ParallelTensorShape input_shape = this->get_shape(input); + ParallelTensorShape output_shape = - get_output_shape(attrs, this->get_shape(input)); + throw_if_unexpected(get_output_shape(attrs, input_shape)); + + std::vector weights; + + if (attrs.affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + ParallelTensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + ParallelTensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } - return this->add_layer(layer, {input}, {}, {output_shape}); + return this->add_layer(layer, {input}, weights, {output_shape}); } diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml index 59e913750e..eb758ea4fc 100644 ---
a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml @@ -55,7 +55,8 @@ values = [ { name = "SHOULD_BROADCAST_LHS" }, { name = "SHOULD_BROADCAST_RHS" }, { name = "DIM" }, - { name = "ELEMENTWISE_AFFINE" }, + { name = "AFFINE" }, + { name = "MOMENTUM" }, { name = "REGULARIZER" }, { name = "SHAPE" }, { name = "SPLITS" }, diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index d5d735ef59..442d3345a1 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -19,8 +19,12 @@ std::optional get_attribute(BatchNormAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); - case OperatorAttributeKey::RELU: - return p.relu; + case OperatorAttributeKey::EPSILON: + return p.eps; + case OperatorAttributeKey::AFFINE: + return p.affine; + case OperatorAttributeKey::MOMENTUM: + return p.momentum; default: return std::nullopt; } @@ -189,6 +193,10 @@ std::optional get_attribute(LayerNormAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); + case OperatorAttributeKey::AFFINE: + return p.elementwise_affine; + case OperatorAttributeKey::AXES: + return vector_of(p.axes); default: return std::nullopt; } diff --git a/lib/substitutions/test/src/substitutions/substitution.cc b/lib/substitutions/test/src/substitutions/substitution.cc index 87ffc01f0b..1718b03b5c 100644 --- a/lib/substitutions/test/src/substitutions/substitution.cc +++ b/lib/substitutions/test/src/substitutions/substitution.cc @@ -21,7 +21,7 @@ TEST_SUITE(FF_TEST_SUITE) { // } TEST_CASE("evaluate_substitution_output(SubParallelComputationGraph, " - "Substituion, PCGPatternMatch)") { + "Substitution, PCGPatternMatch)") { // 
Currently Substitution creation is very verbose. // This is being addressed in // https://github.com/flexflow/FlexFlow/issues/1473. diff --git a/lib/utils/include/utils/containers/are_all_same.h b/lib/utils/include/utils/containers/are_all_same.h new file mode 100644 index 0000000000..37b1838146 --- /dev/null +++ b/lib/utils/include/utils/containers/are_all_same.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H + +namespace FlexFlow { + +template +bool are_all_same(C const &c) { + if (c.empty()) { + return true; + } + + auto const &first = *c.cbegin(); + for (auto const &v : c) { + if (v != first) { + return false; + } + } + return true; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_all_same1.h b/lib/utils/include/utils/containers/require_all_same1.h new file mode 100644 index 0000000000..2f42243857 --- /dev/null +++ b/lib/utils/include/utils/containers/require_all_same1.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H + +#include +#include + +namespace FlexFlow { + +template +tl::expected require_all_same1(C const &c) { + if (c.empty()) { + return tl::unexpected(fmt::format( + "require_all_same1 expected non-empty container, but received {}", c)); + } + + T const &first = *c.cbegin(); + for (T const &v : c) { + if (v != first) { + return tl::unexpected(fmt::format("require_all_same1 found non-same " + "elements {} and {} in container {}", + first, + v, + c)); + } + } + return first; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 52368f94ad..5ae90ec5ba 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -25,10 +25,15 @@ std::vector
subvec(std::vector const &v, if (maybe_start.has_value()) { begin_iter += resolve_loc(maybe_start.value()); } + if (maybe_end.has_value()) { end_iter = v.cbegin() + resolve_loc(maybe_end.value()); } + if (end_iter < begin_iter) { + end_iter = begin_iter; + } + std::vector output(begin_iter, end_iter); return output; } diff --git a/lib/utils/include/utils/containers/sum.h b/lib/utils/include/utils/containers/sum.h new file mode 100644 index 0000000000..5dbd620781 --- /dev/null +++ b/lib/utils/include/utils/containers/sum.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H + +namespace FlexFlow { + +template +T sum(C const &c) { + T result = 0; + for (T const &t : c) { + result += t; + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 3ec165d595..377561d70c 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H #include "utils/exception.h" #include "utils/fmt/optional.h" diff --git a/lib/utils/src/utils/containers/are_all_same.cc b/lib/utils/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..c515bceee2 --- /dev/null +++ b/lib/utils/src/utils/containers/are_all_same.cc @@ -0,0 +1 @@ +#include "utils/containers/are_all_same.h" diff --git a/lib/utils/src/utils/containers/require_all_same1.cc b/lib/utils/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..295339a91d --- /dev/null +++ b/lib/utils/src/utils/containers/require_all_same1.cc @@ -0,0 +1 @@ +#include "utils/containers/require_all_same1.h" diff --git a/lib/utils/src/utils/containers/sum.cc 
b/lib/utils/src/utils/containers/sum.cc new file mode 100644 index 0000000000..088b5f1983 --- /dev/null +++ b/lib/utils/src/utils/containers/sum.cc @@ -0,0 +1 @@ +#include "utils/containers/sum.h" diff --git a/lib/utils/test/src/utils/containers/are_all_same.cc b/lib/utils/test/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..fd8b321439 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_all_same.cc @@ -0,0 +1,36 @@ +#include "utils/containers/are_all_same.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_all_same(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("input elements are all same") { + std::vector input = {1, 1, 1}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("input elements are not all same") { + std::vector input = {1, 1, 2, 1}; + + bool result = are_all_same(input); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/require_all_same1.cc b/lib/utils/test/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..48c1ab0b99 --- /dev/null +++ b/lib/utils/test/src/utils/containers/require_all_same1.cc @@ -0,0 +1,54 @@ +#include "utils/containers/require_all_same1.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "utils/expected.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("require_all_same1(T)", + T, + std::vector, + 
std::unordered_set, + std::unordered_multiset, + std::set, + std::multiset) { + SUBCASE("input is empty") { + T input = {}; + + std::optional result = + optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input elements are all the same") { + T input = {1, 1, 1}; + + tl::expected result = require_all_same1(input); + tl::expected correct = 1; + + CHECK(result == correct); + } + + SUBCASE("input elements are not all the same") { + T input = {1, 1, 2, 1}; + + std::optional result = + optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/sum.cc b/lib/utils/test/src/utils/containers/sum.cc new file mode 100644 index 0000000000..32d8cd32a3 --- /dev/null +++ b/lib/utils/test/src/utils/containers/sum.cc @@ -0,0 +1,27 @@ +#include "utils/containers/sum.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("sum(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + + int result = sum(input); + int correct = 0; + + CHECK(result == correct); + } + + SUBCASE("input is not empty") { + std::vector input = {1, 3, 2}; + + int result = sum(input); + int correct = 6; + + CHECK(result == correct); + } + } +}