From ede66f7faa9fc014db1c3ab29dfa4935ab2e7108 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Tue, 10 Sep 2024 16:40:50 -0700 Subject: [PATCH 01/10] inception v3 initial implementation --- lib/models/include/models/inceptionv3.h | 17 ++ .../models/inceptionv3_config.struct.toml | 31 +++ lib/models/src/models/inceptionv3.cc | 203 ++++++++++++++++++ lib/models/test/src/models/inceptionv3.cc | 19 ++ lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 25 ++- .../include/pcg/computation_graph_builder.h | 4 +- lib/pcg/src/pcg/computation_graph_builder.cc | 54 +++++ 7 files changed, 349 insertions(+), 4 deletions(-) create mode 100644 lib/models/include/models/inceptionv3.h create mode 100644 lib/models/include/models/inceptionv3_config.struct.toml create mode 100644 lib/models/src/models/inceptionv3.cc create mode 100644 lib/models/test/src/models/inceptionv3.cc diff --git a/lib/models/include/models/inceptionv3.h b/lib/models/include/models/inceptionv3.h new file mode 100644 index 0000000000..e124334171 --- /dev/null +++ b/lib/models/include/models/inceptionv3.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 + +#include "models/inceptionv3_config.dtg.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" + +namespace FlexFlow { + +InceptionV3Config get_default_inception_v3_config(); + +ComputationGraph + get_inception_v3_computation_graph(InceptionV3Config const &config); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/inceptionv3_config.struct.toml b/lib/models/include/models/inceptionv3_config.struct.toml new file mode 100644 index 0000000000..cae54c892e --- /dev/null +++ b/lib/models/include/models/inceptionv3_config.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "InceptionV3Config" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "input_height" +type = 
"size_t" + +[[fields]] +name = "input_width" +type = "size_t" + +[[fields]] +name = "input_num_channels" +type = "size_t" + +[[fields]] +name = "num_classes" +type = "size_t" + +[[fields]] +name = "batch_size" +type = "size_t" diff --git a/lib/models/src/models/inceptionv3.cc b/lib/models/src/models/inceptionv3.cc new file mode 100644 index 0000000000..25aac3a246 --- /dev/null +++ b/lib/models/src/models/inceptionv3.cc @@ -0,0 +1,203 @@ +#include "models/inceptionv3.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" + +namespace FlexFlow { + +InceptionV3Config get_default_inception_v3_config() { + return InceptionV3Config{/*input_height=*/299, + /*input_width=*/299, + /*input_num_channels=*/3, + /*num_classes=*/1000, + /*batch_size=*/32}; +} + +tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int filters, + int kernel_size_h, + int kernel_size_w, + int stride_h = 1, + int stride_w = 1, + int padding_h = 0, + int padding_w = 0, + bool use_bias = false) { + tensor_guid_t conv = cgb.conv2d(input, + filters, + kernel_size_h, + kernel_size_w, + stride_h, + stride_w, + padding_h, + padding_w, + std::nullopt, + 1, + use_bias); + return cgb.batch_norm(conv); +} + +tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int pool_features) { + tensor_guid_t branch1x1 = create_conv_block(cgb, input, 64, 1, 1); + + tensor_guid_t branch5x5 = create_conv_block(cgb, input, 48, 1, 1); + branch5x5 = create_conv_block(cgb, branch5x5, 64, 5, 5, 1, 1, 2, 2); + + tensor_guid_t branch3x3dbl = create_conv_block(cgb, input, 64, 1, 1); + branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 96, 3, 3, 1, 1, 1, 1); + branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 96, 3, 3, 1, 1, 1, 1); + + tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); + branch_pool = create_conv_block(cgb, branch_pool, pool_features, 1, 1); + + return cgb.concat(4, 
{branch1x1, branch5x5, branch3x3dbl, branch_pool}, 3); +} + +tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = create_conv_block(cgb, input, 384, 3, 3, 2, 2); + + tensor_guid_t branch3x3dbl = create_conv_block(cgb, input, 64, 1, 1); + branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 96, 3, 3, 1, 1, 1, 1); + branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 96, 3, 3, 2, 2); + + tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 2, 2, 0, 0, PoolOp::MAX); + + return cgb.concat(3, {branch3x3, branch3x3dbl, branch_pool}, 3); +} + +tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int channels_7x7) { + tensor_guid_t branch1x1 = create_conv_block(cgb, input, 192, 1, 1); + + tensor_guid_t branch7x7 = create_conv_block(cgb, input, channels_7x7, 1, 1); + branch7x7 = create_conv_block(cgb, branch7x7, channels_7x7, 1, 7, 1, 1, 0, 3); + branch7x7 = create_conv_block(cgb, branch7x7, 192, 7, 1, 1, 1, 3, 0); + + tensor_guid_t branch7x7dbl = + create_conv_block(cgb, input, channels_7x7, 1, 1); + branch7x7dbl = + create_conv_block(cgb, branch7x7dbl, channels_7x7, 7, 1, 1, 1, 3, 0); + branch7x7dbl = + create_conv_block(cgb, branch7x7dbl, channels_7x7, 1, 7, 1, 1, 0, 3); + branch7x7dbl = + create_conv_block(cgb, branch7x7dbl, channels_7x7, 7, 1, 1, 1, 3, 0); + branch7x7dbl = + create_conv_block(cgb, branch7x7dbl, channels_7x7, 1, 7, 1, 1, 0, 3); + + tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); + branch_pool = create_conv_block(cgb, branch_pool, 192, 1, 1); + + return cgb.concat(4, {branch1x1, branch7x7, branch7x7dbl, branch_pool}, 3); +} + +tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = create_conv_block(cgb, input, 192, 1, 1); + branch3x3 = create_conv_block(cgb, branch3x3, 320, 3, 3, 2, 2); + + tensor_guid_t branch7x7x3 = 
create_conv_block(cgb, input, 192, 1, 1); + branch7x7x3 = create_conv_block(cgb, branch7x7x3, 192, 1, 7, 1, 1, 0, 3); + branch7x7x3 = create_conv_block(cgb, branch7x7x3, 192, 7, 1, 1, 1, 3, 0); + branch7x7x3 = create_conv_block(cgb, branch7x7x3, 192, 3, 3, 2, 2); + + tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 2, 2, 0, 0, PoolOp::MAX); + + return cgb.concat(3, {branch3x3, branch7x7x3, branch_pool}, 3); +} + +tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch1x1 = create_conv_block(cgb, input, 320, 1, 1); + + tensor_guid_t branch3x3 = create_conv_block(cgb, input, 384, 1, 1); + tensor_guid_t branch3x3_1 = + create_conv_block(cgb, branch3x3, 384, 1, 3, 1, 1, 0, 1); + tensor_guid_t branch3x3_2 = + create_conv_block(cgb, branch3x3, 384, 3, 1, 1, 1, 1, 0); + branch3x3 = cgb.concat(2, {branch3x3_1, branch3x3_2}, 3); + + tensor_guid_t branch3x3dbl = create_conv_block(cgb, input, 448, 1, 1); + branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 384, 3, 3, 1, 1, 1, 1); + tensor_guid_t branch3x3dbl_1 = + create_conv_block(cgb, branch3x3dbl, 384, 1, 3, 1, 1, 0, 1); + tensor_guid_t branch3x3dbl_2 = + create_conv_block(cgb, branch3x3dbl, 384, 3, 1, 1, 1, 1, 0); + branch3x3dbl = cgb.concat(2, {branch3x3dbl_1, branch3x3dbl_2}, 3); + + tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); + branch_pool = create_conv_block(cgb, branch_pool, 192, 1, 1); + + return cgb.concat(4, {branch1x1, branch3x3, branch3x3dbl, branch_pool}, 3); +} + +tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t x = create_conv_block(cgb, input, 32, 3, 3, 2, 2); + x = create_conv_block(cgb, x, 32, 3, 3); + x = create_conv_block(cgb, x, 64, 3, 3, 1, 1, 1, 1); + x = cgb.pool2d(x, 3, 3, 2, 2, 0, 0, PoolOp::MAX); + + x = create_conv_block(cgb, x, 80, 1, 1); + x = create_conv_block(cgb, x, 192, 3, 3); + x = cgb.pool2d(x, 3, 3, 2, 2, 0, 0, 
PoolOp::MAX); + + return x; +} + +tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + size_t num_classes) { + tensor_guid_t x = cgb.pool2d(input, 8, 8, 1, 1, 0, 0, PoolOp::AVG); + x = cgb.dropout(x, 0.5); + x = cgb.dense(x, num_classes); + return x; +} + +tensor_guid_t create_inception_v3(ComputationGraphBuilder &cgb, + InceptionV3Config const &config, + tensor_guid_t const &input) { + tensor_guid_t x = create_initial_layers(cgb, input); + + x = create_inception_module_a(cgb, x, 32); + x = create_inception_module_a(cgb, x, 64); + x = create_inception_module_a(cgb, x, 64); + + x = create_inception_module_b(cgb, x); + + x = create_inception_module_c(cgb, x, 128); + x = create_inception_module_c(cgb, x, 160); + x = create_inception_module_c(cgb, x, 160); + x = create_inception_module_c(cgb, x, 192); + + x = create_inception_module_d(cgb, x); + + x = create_inception_module_e(cgb, x); + x = create_inception_module_e(cgb, x); + + x = create_final_layers(cgb, x, config.num_classes); + + return x; +} + +ComputationGraph + get_inception_v3_computation_graph(InceptionV3Config const &config) { + ComputationGraphBuilder cgb; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{config.batch_size, + config.input_height, + config.input_width, + config.input_num_channels}}, + DataType::FLOAT, + }; + + tensor_guid_t input = cgb.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t output = create_inception_v3(cgb, config, input); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/test/src/models/inceptionv3.cc b/lib/models/test/src/models/inceptionv3.cc new file mode 100644 index 0000000000..1f95db9f59 --- /dev/null +++ b/lib/models/test/src/models/inceptionv3.cc @@ -0,0 +1,19 @@ +#include "models/inceptionv3.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_inceptionv3_computation_graph") { + 
InceptionV3Config config = get_default_inception_v3_config(); + + // ComputationGraph result = get_inception_v3_computation_graph(config); + + SUBCASE("num layers") { + // int result_num_layers = get_layers(result).size(); + int correct_num_layers = -1; + // CHECK(result_num_layers == correct_num_layers); + } + } +} diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index e1917efd89..175f41d5db 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,10 +1,31 @@ #include "op-attrs/ops/pool_2d.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); +TensorShape get_output_shape(Pool2DAttrs const &attrs, + TensorShape const &input_shape) { + size_t num_samples = dim_at_idx(input_shape, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + size_t input_height = dim_at_idx(input_shape, ff_dim_t{2}); + size_t input_width = dim_at_idx(input_shape, ff_dim_t{3}); + + size_t output_height = + (input_height + 2 * attrs.padding_h - attrs.kernel_h) / attrs.stride_h + + 1; + + size_t output_width = + (input_width + 2 * attrs.padding_w - attrs.kernel_w) / attrs.stride_w + 1; + + return TensorShape{TensorDims{FFOrdered{ + num_samples, + num_channels, + output_height, + output_width, + }}, + input_shape.data_type}; } +// TODO(@pietro): add tests for this and concat ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &) { diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index c641aed6a4..b24d4b7d13 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -135,8 +135,8 @@ struct ComputationGraphBuilder { int paddingH, int paddingW, PoolOp type = PoolOp::MAX, - std::optional const &activation = std::nullopt, - 
std::optional const &name = std::nullopt); + Activation const &activation = Activation::RELU, + std::optional const &maybe_name = std::nullopt); tensor_guid_t layer_norm(tensor_guid_t const &input, std::vector const &axes, diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 3f2feaf619..abb3ba8bf7 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -5,6 +5,7 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/concat.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/dropout.h" #include "op-attrs/ops/element_binary.h" @@ -13,6 +14,7 @@ #include "op-attrs/ops/gather.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "pcg/computation_graph.h" @@ -485,6 +487,33 @@ tensor_guid_t ComputationGraphBuilder::gather( return this->add_layer(layer, {input}, {}, output_shape); } +tensor_guid_t ComputationGraphBuilder::pool2d( + tensor_guid_t const &x, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + PoolOp type, + Activation const &activation, + std::optional const &maybe_name) { + + Pool2DAttrs attrs = Pool2DAttrs{ + kernelH, kernelW, strideH, strideW, paddingH, paddingW, type, activation}; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + + return this->add_layer(layer, {input}, {}, output_shape); +} /* std::vector * ComputationGraphBuilder::get_shapes(std::vector const &ts) @@ -637,6 +666,31 @@ 
tensor_guid_t ComputationGraphBuilder::dense( return this->add_layer(layer, {input}, weights, output_shape); } +tensor_guid_t ComputationGraphBuilder::concat( + int n, + std::vector const &tensors, + int axis, + std::optional const &maybe_name) { + assert(n == tensors.size()); + ConcatAttrs attrs = ConcatAttrs{ff_dim_t{axis}, n}; + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + std::vector inputs = + transform(tensors, [&](tensor_guid_t const &t) { + return this->as_type(t, DataType::FLOAT, name + "input_pre_cast"); + }); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = + get_output_shape(attrs, transform(inputs, [&](tensor_guid_t const &t) { + return this->get_shape(t); + })); + + return this->add_layer(layer, inputs, {}, output_shape); +} + tensor_guid_t ComputationGraphBuilder::layer_norm( tensor_guid_t const &input, std::vector const &axes, From d724185dce337a6603ce9758228588bf8d96806a Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 10 Sep 2024 19:57:01 -0700 Subject: [PATCH 02/10] Add parallel shape inference for concat and pool2d --- lib/local-execution/src/ops/concat.cc | 4 +- .../inception_v3.h} | 2 +- .../inception_v3_config.struct.toml} | 0 .../models/{ => transformer}/transformer.h | 2 +- .../transformer_config.struct.toml | 0 .../inception_v3.cc} | 16 +- .../models/{ => transformer}/transformer.cc | 2 +- .../inception_v3.cc} | 6 +- .../models/{ => transformer}/transformer.cc | 2 +- lib/op-attrs/include/op-attrs/dim_ordered.h | 2 +- .../dim_ordered/ff_ordered_from_map.h | 29 ++ lib/op-attrs/include/op-attrs/ops/concat.h | 10 +- .../op-attrs/ops/concat_attrs.struct.toml | 4 - lib/op-attrs/include/op-attrs/ops/pool_2d.h | 12 +- .../op-attrs/ops/pool_2d_attrs.struct.toml | 8 +- .../parallel_tensor_dim_degrees.struct.toml | 28 ++ .../include/op-attrs/parallel_tensor_dims.h | 3 + .../include/op-attrs/parallel_tensor_shape.h | 6 + 
.../dim_ordered/ff_ordered_from_map.cc | 1 + .../src/op-attrs/get_output_shapes.cc | 6 +- lib/op-attrs/src/op-attrs/ops/concat.cc | 118 +++++++-- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 85 +++--- .../src/op-attrs/parallel_tensor_dims.cc | 8 + .../src/op-attrs/parallel_tensor_shape.cc | 10 + .../src/dim_ordered/ff_ordered_from_map.cc | 63 +++++ lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc | 250 ++++++++++++++++++ .../include/pcg/computation_graph_builder.h | 9 +- lib/pcg/src/pcg/computation_graph_builder.cc | 35 +-- .../test/src/substitutions/substitution.cc | 2 +- .../include/utils/containers/are_all_same.h | 23 ++ .../utils/containers/require_all_same1.h | 26 ++ lib/utils/include/utils/containers/sum.h | 18 ++ lib/utils/include/utils/optional.h | 4 +- .../src/utils/containers/are_all_same.cc | 1 + .../src/utils/containers/require_all_same1.cc | 1 + lib/utils/src/utils/containers/sum.cc | 1 + .../test/src/utils/containers/are_all_same.cc | 36 +++ .../src/utils/containers/require_all_same1.cc | 50 ++++ lib/utils/test/src/utils/containers/sum.cc | 27 ++ 39 files changed, 791 insertions(+), 119 deletions(-) rename lib/models/include/models/{inceptionv3.h => inception_v3/inception_v3.h} (87%) rename lib/models/include/models/{inceptionv3_config.struct.toml => inception_v3/inception_v3_config.struct.toml} (100%) rename lib/models/include/models/{ => transformer}/transformer.h (97%) rename lib/models/include/models/{ => transformer}/transformer_config.struct.toml (100%) rename lib/models/src/models/{inceptionv3.cc => inception_v3/inception_v3.cc} (93%) rename lib/models/src/models/{ => transformer}/transformer.cc (99%) rename lib/models/test/src/models/{inceptionv3.cc => inception_v3/inception_v3.cc} (68%) rename lib/models/test/src/models/{ => transformer}/transformer.cc (91%) create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml create mode 100644 
lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc create mode 100644 lib/op-attrs/test/src/dim_ordered/ff_ordered_from_map.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc create mode 100644 lib/utils/include/utils/containers/are_all_same.h create mode 100644 lib/utils/include/utils/containers/require_all_same1.h create mode 100644 lib/utils/include/utils/containers/sum.h create mode 100644 lib/utils/src/utils/containers/are_all_same.cc create mode 100644 lib/utils/src/utils/containers/require_all_same1.cc create mode 100644 lib/utils/src/utils/containers/sum.cc create mode 100644 lib/utils/test/src/utils/containers/are_all_same.cc create mode 100644 lib/utils/test/src/utils/containers/require_all_same1.cc create mode 100644 lib/utils/test/src/utils/containers/sum.cc diff --git a/lib/local-execution/src/ops/concat.cc b/lib/local-execution/src/ops/concat.cc index 35f663b1cd..4c3462e694 100644 --- a/lib/local-execution/src/ops/concat.cc +++ b/lib/local-execution/src/ops/concat.cc @@ -50,7 +50,7 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto inputs = acc.get_variadic_tensor(INPUTS); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + assert(inputs.size() <= MAX_NUM_INPUTS); return profile(forward_kernel, profiling, @@ -68,7 +68,7 @@ static std::optional auto input_grads = acc.get_variadic_tensor_grad(INPUTS); auto output_grad = acc.get_tensor_grad(OUTPUT); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + assert(input_grads.size() <= MAX_NUM_INPUTS); return profile(backward_kernel, profiling, diff --git a/lib/models/include/models/inceptionv3.h b/lib/models/include/models/inception_v3/inception_v3.h similarity index 87% rename from lib/models/include/models/inceptionv3.h rename to lib/models/include/models/inception_v3/inception_v3.h index e124334171..15b81ae45d 100644 --- a/lib/models/include/models/inceptionv3.h +++ b/lib/models/include/models/inception_v3/inception_v3.h 
@@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 #define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 -#include "models/inceptionv3_config.dtg.h" +#include "models/inception_v3/inception_v3_config.dtg.h" #include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" diff --git a/lib/models/include/models/inceptionv3_config.struct.toml b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml similarity index 100% rename from lib/models/include/models/inceptionv3_config.struct.toml rename to lib/models/include/models/inception_v3/inception_v3_config.struct.toml diff --git a/lib/models/include/models/transformer.h b/lib/models/include/models/transformer/transformer.h similarity index 97% rename from lib/models/include/models/transformer.h rename to lib/models/include/models/transformer/transformer.h index e50fa37709..037cd84ecb 100644 --- a/lib/models/include/models/transformer.h +++ b/lib/models/include/models/transformer/transformer.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H #define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H -#include "models/transformer_config.dtg.h" +#include "models/transformer/transformer_config.dtg.h" #include "pcg/computation_graph_builder.h" namespace FlexFlow { diff --git a/lib/models/include/models/transformer_config.struct.toml b/lib/models/include/models/transformer/transformer_config.struct.toml similarity index 100% rename from lib/models/include/models/transformer_config.struct.toml rename to lib/models/include/models/transformer/transformer_config.struct.toml diff --git a/lib/models/src/models/inceptionv3.cc b/lib/models/src/models/inception_v3/inception_v3.cc similarity index 93% rename from lib/models/src/models/inceptionv3.cc rename to lib/models/src/models/inception_v3/inception_v3.cc index 25aac3a246..375ef7e11d 100644 --- a/lib/models/src/models/inceptionv3.cc +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -1,4 +1,4 @@ 
-#include "models/inceptionv3.h" +#include "models/inception_v3/inception_v3.h" #include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" @@ -51,7 +51,7 @@ tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); branch_pool = create_conv_block(cgb, branch_pool, pool_features, 1, 1); - return cgb.concat(4, {branch1x1, branch5x5, branch3x3dbl, branch_pool}, 3); + return cgb.concat({branch1x1, branch5x5, branch3x3dbl, branch_pool}, /*axis=*/3); } tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, @@ -64,7 +64,7 @@ tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 2, 2, 0, 0, PoolOp::MAX); - return cgb.concat(3, {branch3x3, branch3x3dbl, branch_pool}, 3); + return cgb.concat({branch3x3, branch3x3dbl, branch_pool}, 3); } tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, @@ -90,7 +90,7 @@ tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); branch_pool = create_conv_block(cgb, branch_pool, 192, 1, 1); - return cgb.concat(4, {branch1x1, branch7x7, branch7x7dbl, branch_pool}, 3); + return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, 3); } tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, @@ -105,7 +105,7 @@ tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 2, 2, 0, 0, PoolOp::MAX); - return cgb.concat(3, {branch3x3, branch7x7x3, branch_pool}, 3); + return cgb.concat({branch3x3, branch7x7x3, branch_pool}, 3); } tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, @@ -117,7 +117,7 @@ tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, create_conv_block(cgb, branch3x3, 384, 1, 3, 1, 1, 0, 1); tensor_guid_t branch3x3_2 = 
create_conv_block(cgb, branch3x3, 384, 3, 1, 1, 1, 1, 0); - branch3x3 = cgb.concat(2, {branch3x3_1, branch3x3_2}, 3); + branch3x3 = cgb.concat({branch3x3_1, branch3x3_2}, 3); tensor_guid_t branch3x3dbl = create_conv_block(cgb, input, 448, 1, 1); branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 384, 3, 3, 1, 1, 1, 1); @@ -125,12 +125,12 @@ tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, create_conv_block(cgb, branch3x3dbl, 384, 1, 3, 1, 1, 0, 1); tensor_guid_t branch3x3dbl_2 = create_conv_block(cgb, branch3x3dbl, 384, 3, 1, 1, 1, 1, 0); - branch3x3dbl = cgb.concat(2, {branch3x3dbl_1, branch3x3dbl_2}, 3); + branch3x3dbl = cgb.concat({branch3x3dbl_1, branch3x3dbl_2}, 3); tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); branch_pool = create_conv_block(cgb, branch_pool, 192, 1, 1); - return cgb.concat(4, {branch1x1, branch3x3, branch3x3dbl, branch_pool}, 3); + return cgb.concat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, 3); } tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, diff --git a/lib/models/src/models/transformer.cc b/lib/models/src/models/transformer/transformer.cc similarity index 99% rename from lib/models/src/models/transformer.cc rename to lib/models/src/models/transformer/transformer.cc index 874cd85787..8725f2a8d1 100644 --- a/lib/models/src/models/transformer.cc +++ b/lib/models/src/models/transformer/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" namespace FlexFlow { diff --git a/lib/models/test/src/models/inceptionv3.cc b/lib/models/test/src/models/inception_v3/inception_v3.cc similarity index 68% rename from lib/models/test/src/models/inceptionv3.cc rename to lib/models/test/src/models/inception_v3/inception_v3.cc index 1f95db9f59..6ff891b26a 100644 --- a/lib/models/test/src/models/inceptionv3.cc +++ b/lib/models/test/src/models/inception_v3/inception_v3.cc @@ -1,14 +1,14 @@ 
-#include "models/inceptionv3.h" +#include "models/inception_v3/inception_v3.h" #include "pcg/computation_graph.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_inceptionv3_computation_graph") { + TEST_CASE("get_inception_v3_computation_graph") { InceptionV3Config config = get_default_inception_v3_config(); - // ComputationGraph result = get_inception_v3_computation_graph(config); + ComputationGraph result = get_inception_v3_computation_graph(config); SUBCASE("num layers") { // int result_num_layers = get_layers(result).size(); diff --git a/lib/models/test/src/models/transformer.cc b/lib/models/test/src/models/transformer/transformer.cc similarity index 91% rename from lib/models/test/src/models/transformer.cc rename to lib/models/test/src/models/transformer/transformer.cc index 2133e9965b..a13d512b92 100644 --- a/lib/models/test/src/models/transformer.cc +++ b/lib/models/test/src/models/transformer/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" #include diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index 6868ba083f..6035e62891 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -10,7 +10,7 @@ namespace FlexFlow { template struct DimOrdered { - DimOrdered() = delete; + DimOrdered() { } DimOrdered(std::initializer_list const &l) : contents(l.begin(), l.end()) {} diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h new file mode 100644 index 0000000000..ba85ec59c8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H + 
+#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/ff_ordered_of.h" + +namespace FlexFlow { + +template +FFOrdered ff_ordered_from_map(std::map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +template +FFOrdered ff_ordered_from_map(std::unordered_map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index f3ac8494c0..d270bd0c56 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -10,10 +10,12 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ConcatAttrs); -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, + std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml index 4faa870bc4..fab8132993 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -17,7 +17,3 @@ includes = [ [[fields]] name = "axis" type = "::FlexFlow::ff_dim_t" - -[[fields]] -name = "num_inputs" -type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 505fdd9f8c..94282c1806 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -5,14 +5,20 @@ #include "op-attrs/ops/pool_2d_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include 
"op-attrs/tensor_shape.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &); +tl::expected + get_output_shape(Pool2DAttrs const &, TensorShape const &); + +tl::expected + get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(Pool2DAttrs const &, ParallelTensorDimDegrees const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml index 56bf682f50..003469f6f0 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -12,6 +12,12 @@ features = [ includes = [ "op-attrs/pool_op.dtg.h", "op-attrs/activation.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json.h", ] [[fields]] @@ -44,4 +50,4 @@ type = "::FlexFlow::PoolOp" [[fields]] name = "activation" -type = "::FlexFlow::Activation" +type = "std::optional<::FlexFlow::Activation>" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml new file mode 100644 index 0000000000..9a93c64b13 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "ParallelTensorDimDegrees" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + "op-attrs/dim_ordered.h", +] + +[[fields]] +name = "sum_degree" +type = "::FlexFlow::SumDegree" + +[[fields]] +name = "discard_copy_degree" +type = 
"::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "shard_degrees" +type = "::FlexFlow::FFOrdered" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 8e02e3607b..ed49f9a8dd 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -4,6 +4,7 @@ #include "op-attrs/parallel_dim.h" #include "op-attrs/parallel_tensor_dims.dtg.h" #include "op-attrs/tensor_dims.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" namespace FlexFlow { @@ -14,6 +15,8 @@ std::unordered_set replica_dims(ParallelTensorDims const &); /* size_t get_volume(ParallelTensorDims const &); */ size_t num_shard_dims(ParallelTensorDims const &); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &); + int total_replica_degree(ParallelTensorDims const &); int total_shard_degree(ParallelTensorDims const &); int total_parallel_degree(ParallelTensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 76356b39d4..2b7cdf60a8 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -4,6 +4,7 @@ #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/tensor_shape.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include namespace FlexFlow { @@ -17,12 +18,17 @@ FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &); std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &); + ParallelTensorShape lift_to_parallel(TensorShape const &); ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); 
+ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &, + ParallelTensorDimDegrees const &); std::unordered_set replica_dims(ParallelTensorShape const &); diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 0000000000..2de88f38c8 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" diff --git a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc index d91d1a1eca..9fa1709b97 100644 --- a/lib/op-attrs/src/op-attrs/get_output_shapes.cc +++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc @@ -14,6 +14,7 @@ #include "op-attrs/ops/input.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/replicate.h" #include "op-attrs/ops/weight.h" #include "utils/overload.h" @@ -38,7 +39,7 @@ std::vector return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](ConcatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs)}; + return {throw_if_unexpected(get_output_shape(attrs, inputs))}; }, [&](Conv2DAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; @@ -71,6 +72,9 @@ std::vector [&](LinearAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, + [&](Pool2DAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, [&](ReplicateAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; }, diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 02fee70bea..056e2da0a1 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -1,24 +1,112 @@ #include 
"op-attrs/ops/concat.h" +#include "op-attrs/dim_ordered/enumerate.h" +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "utils/containers/are_all_same.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/require_all_same1.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/all_of.h" +#include "utils/fmt/map.h" namespace FlexFlow { -/* bool ConcatAttrs::is_valid( */ -/* std::vector const &input) const { */ -/* bool valid = true; */ -/* for (auto p : input) { */ -/* valid &= p.is_valid(); */ -/* } */ -/* return valid; */ -/* } */ - -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + auto get_non_axis_dims = [&](TensorShape const &s) { + std::map dim_sizes = enumerate(ff_ordered(s.dims)); + dim_sizes.erase(attrs.axis); + return dim_sizes; + }; + + if (inputs.size() <= 1) { + return tl::unexpected(fmt::format("get_output_shape for Concat expected 2 or more input, but receieved {}", inputs)); + } + + if (attrs.axis.value < 0) { + return tl::unexpected(fmt::format("ConcatAttrs requires axis >= 0")); + } + + if (!are_all_same(transform(inputs, [](TensorShape const &s) { return num_dims(s); }))) { + return tl::unexpected(fmt::format("get_output_shape for Concat expected all inputs to have the same number of dimensions, but receieved {}", inputs)); + } + + std::map non_axis_dims = ({ + tl::expected, std::string> returned = require_all_same1(transform(inputs, get_non_axis_dims)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + std::vector axis_dim_sizes = transform(inputs, [&](TensorShape const &s) { return dim_at_idx(s, attrs.axis); }); + + size_t output_axis_dim_size = 
sum(axis_dim_sizes); + + non_axis_dims.insert({attrs.axis, output_axis_dim_size}); + + DataType datatype = ({ + tl::expected returned = require_all_same1(transform(inputs, [](TensorShape const &s) { return s.data_type; })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return TensorShape{ + TensorDims{ + ff_ordered_from_map(non_axis_dims), + }, + datatype, + }; } -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + TensorShape unpar = ({ + tl::expected returned = + get_output_shape(attrs, transform(inputs, get_reduced_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + SumDegree sum_degree = ({ + tl::expected returned = require_all_same1(transform(inputs, get_sum_degree)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + SumDegree{returned.value()}; + }); + + DiscardCopyDegree discard_copy_degree = ({ + tl::expected returned = require_all_same1(transform(inputs, get_discard_copy_degree)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + DiscardCopyDegree{returned.value()}; + }); + + if (!all_of(inputs, [&](ParallelTensorShape const &s) { return shard_dim_at_idx(s, attrs.axis).degree == 1; })) { + return tl::unexpected(fmt::format("get_output_shape for Concat expected input tensors to have parallel degree 1 in the concat axis dimension, but received {}", inputs)); + } + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = require_all_same1(transform(inputs, [](ParallelTensorShape const &s) { return get_parallel_degrees(s); })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } + } // namespace FlexFlow diff --git 
a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index 175f41d5db..32a2485efd 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,10 +1,16 @@ #include "op-attrs/ops/pool_2d.h" +#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" namespace FlexFlow { -TensorShape get_output_shape(Pool2DAttrs const &attrs, +tl::expected + get_output_shape(Pool2DAttrs const &attrs, TensorShape const &input_shape) { + if (num_dims(input_shape) != 4) { + return tl::unexpected(fmt::format("get_output_shape for Pool2DAttrs expected input tensor to have 4 dims, but received shape {}", input_shape)); + } + size_t num_samples = dim_at_idx(input_shape, ff_dim_t{0}); size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); size_t input_height = dim_at_idx(input_shape, ff_dim_t{2}); @@ -25,59 +31,42 @@ TensorShape get_output_shape(Pool2DAttrs const &attrs, }}, input_shape.data_type}; } -// TODO(@pietro): add tests for this and concat - -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow - -/* -#include "op-attrs/ops/pool_2d.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" - -namespace FlexFlow { - -namespace Input { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; - -namespace Output { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; - -bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { - ParallelTensorShape output_shape = this->calculate_output_shape(input); - - return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); -} -static std::vector - construct_mappings(ParallelTensorShape const &input_shape) { - auto const outputMappings = construct_output_parallel_dims({ - {Input::REPLICA, MappingOperation::PARTITION, 
Output::REPLICA}, - {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, - {Input::CHANNEL, MappingOperation::PARTITION, Output::CHANNEL}, - {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, - {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}, +tl::expected + get_output_shape(Pool2DAttrs const &attrs, ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); }); - return outputMappings; -} + ParallelTensorDimDegrees degrees = ({ + tl::expected result_degrees = + get_output_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + if (!result_degrees.has_value()) { + return tl::unexpected(result_degrees.error()); + } + result_degrees.value(); + }); -static ParallelDimMappingSolution - solve_mappings(ParallelTensorShape const &input) { - return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); + return lift_to_parallel_with_degrees(unpar, degrees); } -ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape -const &input) const { return solve_mappings(input).output_shapes.at(0); +tl::expected + get_output_parallel_dim_degrees(Pool2DAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.sum_degree.value > 1) { + if (attrs.pool_type == PoolOp::MAX) { + return tl::unexpected(fmt::format("get_output_parallel_dim_degrees for Pool2DAttrs with PoolOp::MAX expected input sum degree == 1, but received {}", input_degrees)); + } else if (attrs.activation.has_value()) { + return tl::unexpected(fmt::format("get_output_parallel_dim_degrees for Pool2DAttrs with activation={} expected input sum degree == 1, but received {}", attrs.activation.value(), input_degrees)); + } + } + + return input_degrees; } } // namespace FlexFlow -*/ diff --git 
a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 73c0068826..c5dd501194 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -29,6 +29,14 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { return dims.shard_dims.size(); } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) { + return ParallelTensorDimDegrees{ + d.replica_dims.sum_degree, + d.replica_dims.discard_copy_degree, + ff_ordered_shard_degrees(d), + }; +} + int total_replica_degree(ParallelTensorDims const &dims) { return dims.replica_dims.discard_copy_degree.value * dims.replica_dims.sum_degree.value; diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 10bf5027a4..a2cc9a4dc5 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -59,6 +59,10 @@ std::optional } } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &s) { + return get_parallel_degrees(s.dims); +} + ParallelTensorShape lift_to_parallel(TensorShape const &s) { return ParallelTensorShape{lift_to_parallel(s.dims), s.data_type}; } @@ -75,6 +79,12 @@ ParallelTensorShape }; } +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &s, + ParallelTensorDimDegrees const °rees) { + return lift_to_parallel_with_degrees(s, degrees.sum_degree, degrees.discard_copy_degree, degrees.shard_degrees); +} + TensorShape require_not_parallel(ParallelTensorShape const &s) { int total_degree = get_total_parallel_degree(s); if (total_degree != 1) { diff --git a/lib/op-attrs/test/src/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/test/src/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 0000000000..0ef746dcb5 --- /dev/null +++ b/lib/op-attrs/test/src/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1,63 @@ +#include 
"op-attrs/dim_ordered/ff_ordered_from_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("ff_ordered_from_map", T, std::map, std::unordered_map) { + SUBCASE("input is empty") { + T m = {}; + + FFOrdered result = ff_ordered_from_map(m); + FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is missing keys") { + SUBCASE("missing key is in middle") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 2}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("missing key is 0 idx") { + T m = { + {ff_dim_t{1}, 2}, + {ff_dim_t{2}, 7}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + } + + SUBCASE("input has negative keys") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{-1}, 2}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("input is valid") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{2}, 2}, + {ff_dim_t{3}, 7}, + }; + + FFOrdered result = ff_ordered_from_map(m); + FFOrdered correct = {4, 5, 2, 7}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc new file mode 100644 index 0000000000..ddcc801f83 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc @@ -0,0 +1,250 @@ +#include +#include "op-attrs/ops/pool_2d.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(Pool2DAttrs, TensorShape)") { + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, + }; + + SUBCASE("fails on non-4d inputs") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 10, 12, 14, + }}, + DataType::FLOAT, + }; + + std::optional 
result = optional_from_expected(get_output_shape(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("4d input") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 11, 13, 12, 6 + }}, + DataType::FLOAT, + }; + + tl::expected result = get_output_shape(attrs, input); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + 11, 13, 6, 4 + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_parallel_dim_degrees(Pool2DAttrs, ParallelTensorDimDegrees)") { + auto make_attrs = [](PoolOp pool_type, std::optional const &activation) { + return Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + }; + + SUBCASE("allows data parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, 1, 1, 1, + }, + }; + + tl::expected result = get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows arbitrary input sharding parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, 2, 5, 6, + }, + }; + + tl::expected result = get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{3}, + FFOrdered{ + 1, 1, 1, 1, + }, + }; + + tl::expected result = get_output_parallel_dim_degrees(attrs, input); + 
tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("sum parallelism") { + SUBCASE("without activation") { + SUBCASE("PoolOp::MAX does not allow sum parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, 1, 1, 1, + }, + }; + + std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("PoolOp::AVG does allow sum parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::AVG, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, 1, 1, 1, + }, + }; + + tl::expected result = get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + } + + SUBCASE("with activation does not allow sum parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::AVG, /*activation=*/Activation::RELU); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, 1, 1, 1, + }, + }; + + std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("get_output_shape(Pool2DAttrs, ParallelTensorShape)") { + // this function is mostly covered by the tests above, so we + // just do a single test to make sure it works/exists + + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, + }; + + SUBCASE("valid parallelism") { + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + 
ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{12, 3}, + ShardParallelDim{6, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = get_output_shape(attrs, input); + tl::expected correct = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{6, 3}, + ShardParallelDim{4, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + } + + SUBCASE("invalid parallelism") { + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 1}, + ShardParallelDim{16, 1}, + ShardParallelDim{12, 1}, + ShardParallelDim{6, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + std::optional result = optional_from_expected(get_output_shape(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index b24d4b7d13..4793fc1530 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -135,8 +135,8 @@ struct ComputationGraphBuilder { int paddingH, int paddingW, PoolOp type = PoolOp::MAX, - Activation const &activation = Activation::RELU, - std::optional const &maybe_name = std::nullopt); + std::optional const &activation = std::nullopt, + std::optional const &name = std::nullopt); tensor_guid_t layer_norm(tensor_guid_t const &input, std::vector const &axes, @@ -168,10 +168,9 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); // Add a concat layer tensor_guid_t - concat(int n, - std::vector const &tensors, + concat(std::vector const &tensors, int axis, - std::optional const &maybe_name = std::nullopt); + std::optional const 
&name = std::nullopt); // Add a mean layer tensor_guid_t mean(tensor_guid_t const &input, std::vector const &dims, diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index abb3ba8bf7..5a6bb5c644 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -496,11 +496,19 @@ tensor_guid_t ComputationGraphBuilder::pool2d( int paddingH, int paddingW, PoolOp type, - Activation const &activation, + std::optional const &activation, std::optional const &maybe_name) { Pool2DAttrs attrs = Pool2DAttrs{ - kernelH, kernelW, strideH, strideW, paddingH, paddingW, type, activation}; + /*kernel_h=*/kernelH, + /*kernel_w=*/kernelW, + /*stride_h=*/strideH, + /*stride_w=*/strideW, + /*padding_h=*/paddingH, + /*padding_w=*/paddingW, + /*pool_type=*/type, + /*activation=*/activation, + }; std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); @@ -510,7 +518,7 @@ tensor_guid_t ComputationGraphBuilder::pool2d( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, output_shape); } @@ -667,27 +675,20 @@ tensor_guid_t ComputationGraphBuilder::dense( } tensor_guid_t ComputationGraphBuilder::concat( - int n, - std::vector const &tensors, + std::vector const &inputs, int axis, std::optional const &maybe_name) { - assert(n == tensors.size()); - ConcatAttrs attrs = ConcatAttrs{ff_dim_t{axis}, n}; + + ConcatAttrs attrs = ConcatAttrs{ff_dim_t{axis}}; + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - std::vector inputs = - transform(tensors, [&](tensor_guid_t const &t) { - return this->as_type(t, DataType::FLOAT, name + "input_pre_cast"); - }); - LayerAttrs layer = 
LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = - get_output_shape(attrs, transform(inputs, [&](tensor_guid_t const &t) { - return this->get_shape(t); - })); - + std::vector input_shapes = transform(inputs, [&](tensor_guid_t const &i) { return this->get_shape(i); }); + TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shapes)); + return this->add_layer(layer, inputs, {}, output_shape); } diff --git a/lib/substitutions/test/src/substitutions/substitution.cc b/lib/substitutions/test/src/substitutions/substitution.cc index 87ffc01f0b..1718b03b5c 100644 --- a/lib/substitutions/test/src/substitutions/substitution.cc +++ b/lib/substitutions/test/src/substitutions/substitution.cc @@ -21,7 +21,7 @@ TEST_SUITE(FF_TEST_SUITE) { // } TEST_CASE("evaluate_substitution_output(SubParallelComputationGraph, " - "Substituion, PCGPatternMatch)") { + "Substitution, PCGPatternMatch)") { // Currently Substitution creation is very verbose. // This is being addressed in // https://github.com/flexflow/FlexFlow/issues/1473. 
diff --git a/lib/utils/include/utils/containers/are_all_same.h b/lib/utils/include/utils/containers/are_all_same.h new file mode 100644 index 0000000000..37b1838146 --- /dev/null +++ b/lib/utils/include/utils/containers/are_all_same.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H + +namespace FlexFlow { + +template +bool are_all_same(C const &c) { + if (c.empty()) { + return true; + } + + auto const &first = *c.cbegin(); + for (auto const &v : c) { + if (v != first) { + return false; + } + } + return true; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_all_same1.h b/lib/utils/include/utils/containers/require_all_same1.h new file mode 100644 index 0000000000..ea167fa0df --- /dev/null +++ b/lib/utils/include/utils/containers/require_all_same1.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H + +#include +#include + +namespace FlexFlow { + +template +tl::expected require_all_same1(C const &c) { + if (c.empty()) { + return tl::unexpected(fmt::format("require_all_same1 expected non-empty container, but received {}", c)); + } + + T const &first = *c.cbegin(); + for (T const &v : c) { + if (v != first) { + return tl::unexpected(fmt::format("require_all_same1 found non-same elements {} and {} in containers {}", first, v, c)); + } + } + return first; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/sum.h b/lib/utils/include/utils/containers/sum.h new file mode 100644 index 0000000000..e34b74f6bb --- /dev/null +++ b/lib/utils/include/utils/containers/sum.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H + +namespace FlexFlow { + +template +T sum(C const &c) { + T result = 0; + 
for (T const &t : c) { + result += t; + } + return result; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 3448ec4e0e..764cc3d5e3 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H #include "utils/exception.h" #include "utils/fmt/optional.h" diff --git a/lib/utils/src/utils/containers/are_all_same.cc b/lib/utils/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..c515bceee2 --- /dev/null +++ b/lib/utils/src/utils/containers/are_all_same.cc @@ -0,0 +1 @@ +#include "utils/containers/are_all_same.h" diff --git a/lib/utils/src/utils/containers/require_all_same1.cc b/lib/utils/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..295339a91d --- /dev/null +++ b/lib/utils/src/utils/containers/require_all_same1.cc @@ -0,0 +1 @@ +#include "utils/containers/require_all_same1.h" diff --git a/lib/utils/src/utils/containers/sum.cc b/lib/utils/src/utils/containers/sum.cc new file mode 100644 index 0000000000..088b5f1983 --- /dev/null +++ b/lib/utils/src/utils/containers/sum.cc @@ -0,0 +1 @@ +#include "utils/containers/sum.h" diff --git a/lib/utils/test/src/utils/containers/are_all_same.cc b/lib/utils/test/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..fd8b321439 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_all_same.cc @@ -0,0 +1,36 @@ +#include "utils/containers/are_all_same.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_all_same(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); 
+ } + + SUBCASE("input elements are all same") { + std::vector input = {1, 1, 1}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("input elements are not all same") { + std::vector input = {1, 1, 2, 1}; + + bool result = are_all_same(input); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/require_all_same1.cc b/lib/utils/test/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..09e14ccde1 --- /dev/null +++ b/lib/utils/test/src/utils/containers/require_all_same1.cc @@ -0,0 +1,50 @@ +#include "utils/containers/require_all_same1.h" +#include +#include +#include +#include +#include "utils/expected.h" +#include "utils/fmt/optional.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/vector.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/unordered_multiset.h" +#include "utils/fmt/set.h" +#include "utils/fmt/multiset.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("require_all_same1(T)", T, std::vector, + std::unordered_set, + std::unordered_multiset, + std::set, + std::multiset) { + SUBCASE("input is empty") { + T input = {}; + + std::optional result = optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input elements are all the same") { + T input = {1, 1, 1}; + + tl::expected result = require_all_same1(input); + tl::expected correct = 1; + + CHECK(result == correct); + } + + SUBCASE("input elements are not all the same") { + T input = {1, 1, 2, 1}; + + std::optional result = optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/sum.cc b/lib/utils/test/src/utils/containers/sum.cc new file mode 100644 index 0000000000..32d8cd32a3 --- /dev/null +++ 
b/lib/utils/test/src/utils/containers/sum.cc @@ -0,0 +1,27 @@ +#include "utils/containers/sum.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("sum(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + + int result = sum(input); + int correct = 0; + + CHECK(result == correct); + } + + SUBCASE("input is not empty") { + std::vector input = {1, 3, 2}; + + int result = sum(input); + int correct = 6; + + CHECK(result == correct); + } + } +} From cc16481bac4ced5f7bc4dd6bdbda536bfdea1e6b Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 10 Sep 2024 20:02:23 -0700 Subject: [PATCH 03/10] Format --- .../src/models/inception_v3/inception_v3.cc | 3 +- lib/op-attrs/include/op-attrs/dim_ordered.h | 2 +- lib/op-attrs/include/op-attrs/ops/concat.h | 3 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 13 +- .../include/op-attrs/parallel_tensor_dims.h | 2 +- .../include/op-attrs/parallel_tensor_shape.h | 4 +- lib/op-attrs/src/op-attrs/ops/concat.cc | 62 ++-- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 34 ++- .../src/op-attrs/parallel_tensor_dims.cc | 6 +- .../src/op-attrs/parallel_tensor_shape.cc | 9 +- .../src/dim_ordered/ff_ordered_from_map.cc | 35 ++- lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc | 283 ++++++++++-------- .../include/pcg/computation_graph_builder.h | 7 +- lib/pcg/src/pcg/computation_graph_builder.cc | 27 +- .../utils/containers/require_all_same1.h | 9 +- lib/utils/include/utils/containers/sum.h | 1 - .../src/utils/containers/require_all_same1.cc | 38 +-- 17 files changed, 309 insertions(+), 229 deletions(-) diff --git a/lib/models/src/models/inception_v3/inception_v3.cc b/lib/models/src/models/inception_v3/inception_v3.cc index 375ef7e11d..a1c7f41c25 100644 --- a/lib/models/src/models/inception_v3/inception_v3.cc +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -51,7 +51,8 @@ tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, tensor_guid_t branch_pool = 
cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); branch_pool = create_conv_block(cgb, branch_pool, pool_features, 1, 1); - return cgb.concat({branch1x1, branch5x5, branch3x3dbl, branch_pool}, /*axis=*/3); + return cgb.concat({branch1x1, branch5x5, branch3x3dbl, branch_pool}, + /*axis=*/3); } tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index 6035e62891..96a3c254f7 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -10,7 +10,7 @@ namespace FlexFlow { template struct DimOrdered { - DimOrdered() { } + DimOrdered() {} DimOrdered(std::initializer_list const &l) : contents(l.begin(), l.end()) {} diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index d270bd0c56..f07f06df85 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -11,8 +11,7 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ConcatAttrs); tl::expected - get_output_shape(ConcatAttrs const &, - std::vector const &); + get_output_shape(ConcatAttrs const &, std::vector const &); tl::expected get_output_shape(ConcatAttrs const &, std::vector const &); diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 94282c1806..2c9ef9a1ce 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -3,22 +3,23 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" -#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); -tl::expected - get_output_shape(Pool2DAttrs const &, TensorShape const &); +tl::expected 
get_output_shape(Pool2DAttrs const &, + TensorShape const &); tl::expected - get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); + get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); -tl::expected - get_output_parallel_dim_degrees(Pool2DAttrs const &, ParallelTensorDimDegrees const &); +tl::expected + get_output_parallel_dim_degrees(Pool2DAttrs const &, + ParallelTensorDimDegrees const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index ed49f9a8dd..7a89b4bd78 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #include "op-attrs/parallel_dim.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_dims.dtg.h" #include "op-attrs/tensor_dims.dtg.h" -#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 2b7cdf60a8..806a5f0de7 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,10 +1,10 @@ #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/tensor_shape.h" -#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include namespace FlexFlow { @@ -27,7 +27,7 @@ ParallelTensorShape DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); ParallelTensorShape - lift_to_parallel_with_degrees(TensorShape const &, + lift_to_parallel_with_degrees(TensorShape const &, ParallelTensorDimDegrees const &); std::unordered_set 
diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 056e2da0a1..0e1f52d9ff 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -3,13 +3,13 @@ #include "op-attrs/dim_ordered/ff_ordered_from_map.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/all_of.h" #include "utils/containers/are_all_same.h" #include "utils/containers/as_vector.h" #include "utils/containers/require_all_same1.h" #include "utils/containers/sum.h" #include "utils/containers/transform.h" -#include "op-attrs/tensor_shape.h" -#include "utils/containers/all_of.h" #include "utils/fmt/map.h" namespace FlexFlow { @@ -24,33 +24,42 @@ tl::expected }; if (inputs.size() <= 1) { - return tl::unexpected(fmt::format("get_output_shape for Concat expected 2 or more input, but receieved {}", inputs)); + return tl::unexpected(fmt::format("get_output_shape for Concat expected 2 " + "or more input, but receieved {}", + inputs)); } if (attrs.axis.value < 0) { return tl::unexpected(fmt::format("ConcatAttrs requires axis >= 0")); } - if (!are_all_same(transform(inputs, [](TensorShape const &s) { return num_dims(s); }))) { - return tl::unexpected(fmt::format("get_output_shape for Concat expected all inputs to have the same number of dimensions, but receieved {}", inputs)); + if (!are_all_same(transform( + inputs, [](TensorShape const &s) { return num_dims(s); }))) { + return tl::unexpected( + fmt::format("get_output_shape for Concat expected all inputs to have " + "the same number of dimensions, but receieved {}", + inputs)); } std::map non_axis_dims = ({ - tl::expected, std::string> returned = require_all_same1(transform(inputs, get_non_axis_dims)); + tl::expected, std::string> returned = + require_all_same1(transform(inputs, get_non_axis_dims)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } 
returned.value(); }); - std::vector axis_dim_sizes = transform(inputs, [&](TensorShape const &s) { return dim_at_idx(s, attrs.axis); }); - + std::vector axis_dim_sizes = transform( + inputs, [&](TensorShape const &s) { return dim_at_idx(s, attrs.axis); }); + size_t output_axis_dim_size = sum(axis_dim_sizes); non_axis_dims.insert({attrs.axis, output_axis_dim_size}); DataType datatype = ({ - tl::expected returned = require_all_same1(transform(inputs, [](TensorShape const &s) { return s.data_type; })); + tl::expected returned = require_all_same1( + transform(inputs, [](TensorShape const &s) { return s.data_type; })); if (!returned.has_value()) { return tl::unexpected(returned.error()); } @@ -58,10 +67,10 @@ tl::expected }); return TensorShape{ - TensorDims{ - ff_ordered_from_map(non_axis_dims), - }, - datatype, + TensorDims{ + ff_ordered_from_map(non_axis_dims), + }, + datatype, }; } @@ -69,8 +78,8 @@ tl::expected get_output_shape(ConcatAttrs const &attrs, std::vector const &inputs) { TensorShape unpar = ({ - tl::expected returned = - get_output_shape(attrs, transform(inputs, get_reduced_shape)); + tl::expected returned = + get_output_shape(attrs, transform(inputs, get_reduced_shape)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } @@ -78,7 +87,8 @@ tl::expected }); SumDegree sum_degree = ({ - tl::expected returned = require_all_same1(transform(inputs, get_sum_degree)); + tl::expected returned = + require_all_same1(transform(inputs, get_sum_degree)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } @@ -86,19 +96,28 @@ tl::expected }); DiscardCopyDegree discard_copy_degree = ({ - tl::expected returned = require_all_same1(transform(inputs, get_discard_copy_degree)); + tl::expected returned = + require_all_same1(transform(inputs, get_discard_copy_degree)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } DiscardCopyDegree{returned.value()}; }); - - if (!all_of(inputs, [&](ParallelTensorShape const &s) 
{ return shard_dim_at_idx(s, attrs.axis).degree == 1; })) { - return tl::unexpected(fmt::format("get_output_shape for Concat expected input tensors to have parallel degree 1 in the concat axis dimension, but received {}", inputs)); + + if (!all_of(inputs, [&](ParallelTensorShape const &s) { + return shard_dim_at_idx(s, attrs.axis).degree == 1; + })) { + return tl::unexpected(fmt::format( + "get_output_shape for Concat expected input tensors to have parallel " + "degree 1 in the concat axis dimension, but received {}", + inputs)); } ParallelTensorDimDegrees degrees = ({ - tl::expected returned = require_all_same1(transform(inputs, [](ParallelTensorShape const &s) { return get_parallel_degrees(s); })); + tl::expected returned = + require_all_same1(transform(inputs, [](ParallelTensorShape const &s) { + return get_parallel_degrees(s); + })); if (!returned.has_value()) { return tl::unexpected(returned.error()); } @@ -108,5 +127,4 @@ tl::expected return lift_to_parallel_with_degrees(unpar, degrees); } - } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index 32a2485efd..6fe0ace109 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -5,10 +5,12 @@ namespace FlexFlow { tl::expected - get_output_shape(Pool2DAttrs const &attrs, - TensorShape const &input_shape) { + get_output_shape(Pool2DAttrs const &attrs, TensorShape const &input_shape) { if (num_dims(input_shape) != 4) { - return tl::unexpected(fmt::format("get_output_shape for Pool2DAttrs expected input tensor to have 4 dims, but received shape {}", input_shape)); + return tl::unexpected( + fmt::format("get_output_shape for Pool2DAttrs expected input tensor to " + "have 4 dims, but received shape {}", + input_shape)); } size_t num_samples = dim_at_idx(input_shape, ff_dim_t{0}); @@ -33,7 +35,8 @@ tl::expected } tl::expected - get_output_shape(Pool2DAttrs const &attrs, ParallelTensorShape const 
&input_shape) { + get_output_shape(Pool2DAttrs const &attrs, + ParallelTensorShape const &input_shape) { TensorShape unpar = ({ tl::expected result_unpar = get_output_shape(attrs, get_reduced_shape(input_shape)); @@ -45,7 +48,8 @@ tl::expected ParallelTensorDimDegrees degrees = ({ tl::expected result_degrees = - get_output_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); if (!result_degrees.has_value()) { return tl::unexpected(result_degrees.error()); } @@ -56,16 +60,24 @@ tl::expected } tl::expected - get_output_parallel_dim_degrees(Pool2DAttrs const &attrs, - ParallelTensorDimDegrees const &input_degrees) { + get_output_parallel_dim_degrees( + Pool2DAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { if (input_degrees.sum_degree.value > 1) { if (attrs.pool_type == PoolOp::MAX) { - return tl::unexpected(fmt::format("get_output_parallel_dim_degrees for Pool2DAttrs with PoolOp::MAX expected input sum degree == 1, but received {}", input_degrees)); - } else if (attrs.activation.has_value()) { - return tl::unexpected(fmt::format("get_output_parallel_dim_degrees for Pool2DAttrs with activation={} expected input sum degree == 1, but received {}", attrs.activation.value(), input_degrees)); + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with PoolOp::MAX " + "expected input sum degree == 1, but received {}", + input_degrees)); + } else if (attrs.activation.has_value()) { + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with activation={} " + "expected input sum degree == 1, but received {}", + attrs.activation.value(), + input_degrees)); } } - + return input_degrees; } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index c5dd501194..dfc6775954 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ 
b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -31,9 +31,9 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) { return ParallelTensorDimDegrees{ - d.replica_dims.sum_degree, - d.replica_dims.discard_copy_degree, - ff_ordered_shard_degrees(d), + d.replica_dims.sum_degree, + d.replica_dims.discard_copy_degree, + ff_ordered_shard_degrees(d), }; } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index a2cc9a4dc5..3cd0f47a5d 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -60,7 +60,7 @@ std::optional } ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &s) { - return get_parallel_degrees(s.dims); + return get_parallel_degrees(s.dims); } ParallelTensorShape lift_to_parallel(TensorShape const &s) { @@ -80,9 +80,12 @@ ParallelTensorShape } ParallelTensorShape - lift_to_parallel_with_degrees(TensorShape const &s, + lift_to_parallel_with_degrees(TensorShape const &s, ParallelTensorDimDegrees const °rees) { - return lift_to_parallel_with_degrees(s, degrees.sum_degree, degrees.discard_copy_degree, degrees.shard_degrees); + return lift_to_parallel_with_degrees(s, + degrees.sum_degree, + degrees.discard_copy_degree, + degrees.shard_degrees); } TensorShape require_not_parallel(ParallelTensorShape const &s) { diff --git a/lib/op-attrs/test/src/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/test/src/dim_ordered/ff_ordered_from_map.cc index 0ef746dcb5..7bc1695e5c 100644 --- a/lib/op-attrs/test/src/dim_ordered/ff_ordered_from_map.cc +++ b/lib/op-attrs/test/src/dim_ordered/ff_ordered_from_map.cc @@ -4,7 +4,10 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("ff_ordered_from_map", T, std::map, std::unordered_map) { + TEST_CASE_TEMPLATE("ff_ordered_from_map", + T, + std::map, + std::unordered_map) { 
SUBCASE("input is empty") { T m = {}; @@ -17,30 +20,30 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is missing keys") { SUBCASE("missing key is in middle") { T m = { - {ff_dim_t{0}, 4}, - {ff_dim_t{1}, 2}, - {ff_dim_t{3}, 5}, + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 2}, + {ff_dim_t{3}, 5}, }; - + CHECK_THROWS(ff_ordered_from_map(m)); } SUBCASE("missing key is 0 idx") { T m = { - {ff_dim_t{1}, 2}, - {ff_dim_t{2}, 7}, - {ff_dim_t{3}, 5}, + {ff_dim_t{1}, 2}, + {ff_dim_t{2}, 7}, + {ff_dim_t{3}, 5}, }; - + CHECK_THROWS(ff_ordered_from_map(m)); } } SUBCASE("input has negative keys") { T m = { - {ff_dim_t{0}, 4}, - {ff_dim_t{1}, 5}, - {ff_dim_t{-1}, 2}, + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{-1}, 2}, }; CHECK_THROWS(ff_ordered_from_map(m)); @@ -48,10 +51,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is valid") { T m = { - {ff_dim_t{0}, 4}, - {ff_dim_t{1}, 5}, - {ff_dim_t{2}, 2}, - {ff_dim_t{3}, 7}, + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{2}, 2}, + {ff_dim_t{3}, 7}, }; FFOrdered result = ff_ordered_from_map(m); diff --git a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc index ddcc801f83..7db95a545c 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc @@ -1,33 +1,36 @@ -#include #include "op-attrs/ops/pool_2d.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include "utils/fmt/optional.h" +#include -using namespace ::FlexFlow; +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(Pool2DAttrs, TensorShape)") { Pool2DAttrs attrs = Pool2DAttrs{ - /*kernel_h=*/3, - /*kernel_w=*/2, - /*stride_h=*/2, - /*stride_w=*/2, - /*padding_h=*/1, - /*padding_w=*/1, - /*pool_type=*/PoolOp::MAX, - /*activation=*/std::nullopt, + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, }; SUBCASE("fails on non-4d 
inputs") { TensorShape input = TensorShape{ - TensorDims{FFOrdered{ - 10, 12, 14, - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + 10, + 12, + 14, + }}, + DataType::FLOAT, }; - std::optional result = optional_from_expected(get_output_shape(attrs, input)); + std::optional result = + optional_from_expected(get_output_shape(attrs, input)); std::optional correct = std::nullopt; CHECK(result == correct); @@ -35,35 +38,34 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("4d input") { TensorShape input = TensorShape{ - TensorDims{FFOrdered{ - 11, 13, 12, 6 - }}, - DataType::FLOAT, + TensorDims{FFOrdered{11, 13, 12, 6}}, + DataType::FLOAT, }; - tl::expected result = get_output_shape(attrs, input); + tl::expected result = + get_output_shape(attrs, input); tl::expected correct = TensorShape{ - TensorDims{FFOrdered{ - 11, 13, 6, 4 - }}, - DataType::FLOAT, + TensorDims{FFOrdered{11, 13, 6, 4}}, + DataType::FLOAT, }; CHECK(result == correct); } } - TEST_CASE("get_output_parallel_dim_degrees(Pool2DAttrs, ParallelTensorDimDegrees)") { - auto make_attrs = [](PoolOp pool_type, std::optional const &activation) { + TEST_CASE("get_output_parallel_dim_degrees(Pool2DAttrs, " + "ParallelTensorDimDegrees)") { + auto make_attrs = [](PoolOp pool_type, + std::optional const &activation) { return Pool2DAttrs{ - /*kernel_h=*/3, - /*kernel_w=*/2, - /*stride_h=*/2, - /*stride_w=*/2, - /*padding_h=*/1, - /*padding_w=*/1, - /*pool_type=*/pool_type, - /*activation=*/activation, + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/pool_type, + /*activation=*/activation, }; }; @@ -71,14 +73,18 @@ TEST_SUITE(FF_TEST_SUITE) { Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{ - 4, 1, 1, 1, - }, + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 1, + 1, + 1, + }, }; - tl::expected result = 
get_output_parallel_dim_degrees(attrs, input); + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); tl::expected correct = input; CHECK(result == correct); @@ -88,14 +94,18 @@ TEST_SUITE(FF_TEST_SUITE) { Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{ - 4, 2, 5, 6, - }, + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 2, + 5, + 6, + }, }; - tl::expected result = get_output_parallel_dim_degrees(attrs, input); + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); tl::expected correct = input; CHECK(result == correct); @@ -105,14 +115,18 @@ TEST_SUITE(FF_TEST_SUITE) { Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{3}, - FFOrdered{ - 1, 1, 1, 1, - }, + SumDegree{1}, + DiscardCopyDegree{3}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, }; - tl::expected result = get_output_parallel_dim_degrees(attrs, input); + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); tl::expected correct = input; CHECK(result == correct); @@ -121,52 +135,68 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("sum parallelism") { SUBCASE("without activation") { SUBCASE("PoolOp::MAX does not allow sum parallelism") { - Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + Pool2DAttrs attrs = + make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{2}, - DiscardCopyDegree{1}, - FFOrdered{ - 1, 1, 1, 1, - }, + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, }; - std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional result = + optional_from_expected( + get_output_parallel_dim_degrees(attrs, input)); std::optional correct = 
std::nullopt; CHECK(result == correct); } SUBCASE("PoolOp::AVG does allow sum parallelism") { - Pool2DAttrs attrs = make_attrs(PoolOp::AVG, /*activation=*/std::nullopt); + Pool2DAttrs attrs = + make_attrs(PoolOp::AVG, /*activation=*/std::nullopt); ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{2}, - DiscardCopyDegree{1}, - FFOrdered{ - 1, 1, 1, 1, - }, + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, }; - tl::expected result = get_output_parallel_dim_degrees(attrs, input); + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); tl::expected correct = input; CHECK(result == correct); } } - + SUBCASE("with activation does not allow sum parallelism") { - Pool2DAttrs attrs = make_attrs(PoolOp::AVG, /*activation=*/Activation::RELU); + Pool2DAttrs attrs = + make_attrs(PoolOp::AVG, /*activation=*/Activation::RELU); ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{2}, - DiscardCopyDegree{1}, - FFOrdered{ - 1, 1, 1, 1, - }, + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, }; - std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs, input)); std::optional correct = std::nullopt; CHECK(result == correct); @@ -175,73 +205,76 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("get_output_shape(Pool2DAttrs, ParallelTensorShape)") { - // this function is mostly covered by the tests above, so we + // this function is mostly covered by the tests above, so we // just do a single test to make sure it works/exists Pool2DAttrs attrs = Pool2DAttrs{ - /*kernel_h=*/3, - /*kernel_w=*/2, - /*stride_h=*/2, - /*stride_w=*/2, - /*padding_h=*/1, - /*padding_w=*/1, - /*pool_type=*/PoolOp::MAX, - /*activation=*/std::nullopt, + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + 
/*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, }; SUBCASE("valid parallelism") { ParallelTensorShape input = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{14, 7}, - ShardParallelDim{16, 8}, - ShardParallelDim{12, 3}, - ShardParallelDim{6, 2}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{12, 3}, + ShardParallelDim{6, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{2}, - }, - }, - DataType::FLOAT, - }; - - tl::expected result = get_output_shape(attrs, input); - tl::expected correct = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{14, 7}, - ShardParallelDim{16, 8}, - ShardParallelDim{6, 3}, - ShardParallelDim{4, 2}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{2}, - }, - }, - DataType::FLOAT, + DataType::FLOAT, }; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{6, 3}, + ShardParallelDim{4, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; } SUBCASE("invalid parallelism") { ParallelTensorShape input = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{14, 1}, - ShardParallelDim{16, 1}, - ShardParallelDim{12, 1}, - ShardParallelDim{6, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{2}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 1}, + ShardParallelDim{16, 1}, + ShardParallelDim{12, 1}, + ShardParallelDim{6, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; - std::optional result = optional_from_expected(get_output_shape(attrs, input)); + std::optional result = + 
optional_from_expected(get_output_shape(attrs, input)); std::optional correct = std::nullopt; CHECK(result == correct); diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 4793fc1530..1d5dc1bec4 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -167,10 +167,9 @@ struct ComputationGraphBuilder { DataType dtype, std::optional const &name = std::nullopt); // Add a concat layer - tensor_guid_t - concat(std::vector const &tensors, - int axis, - std::optional const &name = std::nullopt); + tensor_guid_t concat(std::vector const &tensors, + int axis, + std::optional const &name = std::nullopt); // Add a mean layer tensor_guid_t mean(tensor_guid_t const &input, std::vector const &dims, diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 5a6bb5c644..86d6948c60 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -500,14 +500,14 @@ tensor_guid_t ComputationGraphBuilder::pool2d( std::optional const &maybe_name) { Pool2DAttrs attrs = Pool2DAttrs{ - /*kernel_h=*/kernelH, - /*kernel_w=*/kernelW, - /*stride_h=*/strideH, - /*stride_w=*/strideW, - /*padding_h=*/paddingH, - /*padding_w=*/paddingW, - /*pool_type=*/type, - /*activation=*/activation, + /*kernel_h=*/kernelH, + /*kernel_w=*/kernelW, + /*stride_h=*/strideH, + /*stride_w=*/strideW, + /*padding_h=*/paddingH, + /*padding_w=*/paddingW, + /*pool_type=*/type, + /*activation=*/activation, }; std::string name = @@ -518,7 +518,8 @@ tensor_guid_t ComputationGraphBuilder::pool2d( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, 
output_shape); } @@ -686,9 +687,11 @@ tensor_guid_t ComputationGraphBuilder::concat( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - std::vector input_shapes = transform(inputs, [&](tensor_guid_t const &i) { return this->get_shape(i); }); - TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shapes)); - + std::vector input_shapes = transform( + inputs, [&](tensor_guid_t const &i) { return this->get_shape(i); }); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shapes)); + return this->add_layer(layer, inputs, {}, output_shape); } diff --git a/lib/utils/include/utils/containers/require_all_same1.h b/lib/utils/include/utils/containers/require_all_same1.h index ea167fa0df..2f42243857 100644 --- a/lib/utils/include/utils/containers/require_all_same1.h +++ b/lib/utils/include/utils/containers/require_all_same1.h @@ -9,13 +9,18 @@ namespace FlexFlow { template tl::expected require_all_same1(C const &c) { if (c.empty()) { - return tl::unexpected(fmt::format("require_all_same1 expected non-empty container, but received {}", c)); + return tl::unexpected(fmt::format( + "require_all_same1 expected non-empty container, but received {}", c)); } T const &first = *c.cbegin(); for (T const &v : c) { if (v != first) { - return tl::unexpected(fmt::format("require_all_same1 found non-same elements {} and {} in containers {}", first, v, c)); + return tl::unexpected(fmt::format("require_all_same1 found non-same " + "elements {} and {} in containers {}", + first, + v, + c)); } } return first; diff --git a/lib/utils/include/utils/containers/sum.h b/lib/utils/include/utils/containers/sum.h index e34b74f6bb..5dbd620781 100644 --- a/lib/utils/include/utils/containers/sum.h +++ b/lib/utils/include/utils/containers/sum.h @@ -12,7 +12,6 @@ T sum(C const &c) { return result; } - } // namespace FlexFlow #endif diff --git a/lib/utils/test/src/utils/containers/require_all_same1.cc 
b/lib/utils/test/src/utils/containers/require_all_same1.cc index 09e14ccde1..a655ac02ef 100644 --- a/lib/utils/test/src/utils/containers/require_all_same1.cc +++ b/lib/utils/test/src/utils/containers/require_all_same1.cc @@ -1,29 +1,32 @@ #include "utils/containers/require_all_same1.h" -#include -#include -#include -#include #include "utils/expected.h" -#include "utils/fmt/optional.h" #include "utils/fmt/expected.h" -#include "utils/fmt/vector.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/set.h" #include "utils/fmt/multiset.h" +#include "utils/fmt/optional.h" +#include "utils/fmt/set.h" +#include "utils/fmt/unordered_multiset.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" +#include +#include +#include +#include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("require_all_same1(T)", T, std::vector, - std::unordered_set, - std::unordered_multiset, - std::set, - std::multiset) { + TEST_CASE_TEMPLATE("require_all_same1(T)", + T, + std::vector, + std::unordered_set, + std::unordered_multiset, + std::set, + std::multiset) { SUBCASE("input is empty") { T input = {}; - std::optional result = optional_from_expected(require_all_same1(input)); + std::optional result = + optional_from_expected(require_all_same1(input)); std::optional correct = std::nullopt; CHECK(result == correct); @@ -34,14 +37,15 @@ TEST_SUITE(FF_TEST_SUITE) { tl::expected result = require_all_same1(input); tl::expected correct = 1; - + CHECK(result == correct); } SUBCASE("input elements are not all the same") { T input = {1, 1, 2, 1}; - std::optional result = optional_from_expected(require_all_same1(input)); + std::optional result = + optional_from_expected(require_all_same1(input)); std::optional correct = std::nullopt; CHECK(result == correct); From 97b19d833513e4e238c631a83afdb352ec2d4fc1 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 16 Sep 2024 15:38:27 -0700 Subject: [PATCH 04/10] 
Respond to PR comments --- .../models/inception_v3/inception_v3.h | 12 +- .../inception_v3_config.struct.toml | 20 +- .../inception_v3_output.struct.toml | 25 + .../src/models/inception_v3/inception_v3.cc | 748 +++++++++++++++--- .../src/models/inception_v3/inception_v3.cc | 6 +- lib/op-attrs/include/op-attrs/dim_ordered.h | 4 + .../include/op-attrs/dim_ordered/concat.h | 34 + .../include/op-attrs/dim_ordered/slice.h | 8 + lib/op-attrs/include/op-attrs/ops/flat.h | 7 +- .../op-attrs/ops/flat_attrs.struct.toml | 21 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 7 + .../src/op-attrs/dim_ordered/concat.cc | 1 + .../src/op-attrs/get_output_shapes.cc | 2 +- lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 4 +- lib/op-attrs/src/op-attrs/ops/flat.cc | 107 +-- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 72 ++ .../test/src/op-attrs/dim_ordered/concat.cc | 71 ++ lib/op-attrs/test/src/op-attrs/ops/flat.cc | 235 ++++++ lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc | 119 +++ .../include/pcg/computation_graph_builder.h | 12 +- lib/pcg/src/pcg/computation_graph_builder.cc | 84 +- lib/utils/include/utils/containers/subvec.h | 5 + 22 files changed, 1388 insertions(+), 216 deletions(-) create mode 100644 lib/models/include/models/inception_v3/inception_v3_output.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/concat.h create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/concat.cc create mode 100644 lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/flat.cc diff --git a/lib/models/include/models/inception_v3/inception_v3.h b/lib/models/include/models/inception_v3/inception_v3.h index 15b81ae45d..5c4754e441 100644 --- a/lib/models/include/models/inception_v3/inception_v3.h +++ b/lib/models/include/models/inception_v3/inception_v3.h @@ -2,13 +2,19 @@ #define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 #include "models/inception_v3/inception_v3_config.dtg.h" -#include "pcg/computation_graph.h" -#include 
"pcg/computation_graph_builder.h" +#include "pcg/computation_graph.dtg.h" namespace FlexFlow { -InceptionV3Config get_default_inception_v3_config(); +/** + * @brief Get the default training config from https://arxiv.org/abs/1512.00567. + */ +InceptionV3Config get_default_inception_v3_training_config(); +/** + * @brief Get a computation graph for Inception-v3 as described in + * https://arxiv.org/abs/1512.00567. + */ ComputationGraph get_inception_v3_computation_graph(InceptionV3Config const &config); diff --git a/lib/models/include/models/inception_v3/inception_v3_config.struct.toml b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml index cae54c892e..a2a75c83bb 100644 --- a/lib/models/include/models/inception_v3/inception_v3_config.struct.toml +++ b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml @@ -10,22 +10,14 @@ features = [ "fmt", ] -[[fields]] -name = "input_height" -type = "size_t" - -[[fields]] -name = "input_width" -type = "size_t" - -[[fields]] -name = "input_num_channels" -type = "size_t" - [[fields]] name = "num_classes" -type = "size_t" +type = "int" [[fields]] name = "batch_size" -type = "size_t" +type = "int" + +[[fields]] +name = "aux_logits" +type = "bool" diff --git a/lib/models/include/models/inception_v3/inception_v3_output.struct.toml b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml new file mode 100644 index 0000000000..066e6df02b --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "InceptionV3Output" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "standard_logits" +type = "::FlexFlow::tensor_guid_t" + +[[fields]] +name = "aux_logits" +type = "std::optional<::FlexFlow::tensor_guid_t>" diff --git a/lib/models/src/models/inception_v3/inception_v3.cc 
b/lib/models/src/models/inception_v3/inception_v3.cc index a1c7f41c25..ef291c9bcb 100644 --- a/lib/models/src/models/inception_v3/inception_v3.cc +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -1,18 +1,70 @@ #include "models/inception_v3/inception_v3.h" +#include "op-attrs/tensor_shape.h" #include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" +#include "utils/integer_conversions.h" +#include "models/inception_v3/inception_v3_output.dtg.h" namespace FlexFlow { -InceptionV3Config get_default_inception_v3_config() { - return InceptionV3Config{/*input_height=*/299, - /*input_width=*/299, - /*input_num_channels=*/3, - /*num_classes=*/1000, - /*batch_size=*/32}; +struct CheckShape { + CheckShape(ComputationGraphBuilder const &cgb, + InceptionV3Config const &config) + : cgb(cgb), + config(config) + { } + + ComputationGraphBuilder const &cgb; + InceptionV3Config const &config; + + void operator()(tensor_guid_t t, int c, int h, int w) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + size_t_from_int(h), + size_t_from_int(w), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format("Expected activation shape {}, but found activation shape {}", expected_shape, current_shape)); + } + } + + void operator()(tensor_guid_t t, int c) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format("Expected activation shape {}, but found activation shape {}", expected_shape, current_shape)); + } + } +}; + + +InceptionV3Config get_default_inception_v3_training_config() { + return InceptionV3Config{ + /*num_classes=*/1000, + + // see 
section 8 of https://arxiv.org/abs/1512.00567 for the source of the batch size + /*batch_size=*/32, + + // see section 4 of https://arxiv.org/abs/1512.00567 for a discussion of auxiliary logits. + // they are used by default in training + /*aux_logits=*/true, + }; } -tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, +static tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, tensor_guid_t const &input, int filters, int kernel_size_h, @@ -36,151 +88,625 @@ tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, return cgb.batch_norm(conv); } -tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, +static tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, tensor_guid_t const &input, int pool_features) { - tensor_guid_t branch1x1 = create_conv_block(cgb, input, 64, 1, 1); - - tensor_guid_t branch5x5 = create_conv_block(cgb, input, 48, 1, 1); - branch5x5 = create_conv_block(cgb, branch5x5, 64, 5, 5, 1, 1, 2, 2); - - tensor_guid_t branch3x3dbl = create_conv_block(cgb, input, 64, 1, 1); - branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 96, 3, 3, 1, 1, 1, 1); - branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 96, 3, 3, 1, 1, 1, 1); - - tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); - branch_pool = create_conv_block(cgb, branch_pool, pool_features, 1, 1); + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch5x5 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/48, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/2, + /*padding_w=*/2); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, + 
/*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/pool_features, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); return cgb.concat({branch1x1, branch5x5, branch3x3dbl, branch_pool}, - /*axis=*/3); + /*axis=*/1); } -tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, +static tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, tensor_guid_t const &input) { - tensor_guid_t branch3x3 = create_conv_block(cgb, input, 384, 3, 3, 2, 2); - - tensor_guid_t branch3x3dbl = create_conv_block(cgb, input, 64, 1, 1); - branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 96, 3, 3, 1, 1, 1, 1); - branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 96, 3, 3, 2, 2); - - tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 2, 2, 0, 0, PoolOp::MAX); - - return cgb.concat({branch3x3, branch3x3dbl, branch_pool}, 3); + tensor_guid_t branch3x3 = create_conv_block(cgb, + input, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + t = create_conv_block(cgb, + t, + 
/*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + return t; + }(); + + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + + return cgb.concat({branch3x3, branch3x3dbl, branch_pool}, /*axis=*/1); } -tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, +static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, tensor_guid_t const &input, int channels_7x7) { tensor_guid_t branch1x1 = create_conv_block(cgb, input, 192, 1, 1); - tensor_guid_t branch7x7 = create_conv_block(cgb, input, channels_7x7, 1, 1); - branch7x7 = create_conv_block(cgb, branch7x7, channels_7x7, 1, 7, 1, 1, 0, 3); - branch7x7 = create_conv_block(cgb, branch7x7, 192, 7, 1, 1, 1, 3, 0); - - tensor_guid_t branch7x7dbl = - create_conv_block(cgb, input, channels_7x7, 1, 1); - branch7x7dbl = - create_conv_block(cgb, branch7x7dbl, channels_7x7, 7, 1, 1, 1, 3, 0); - branch7x7dbl = - create_conv_block(cgb, branch7x7dbl, channels_7x7, 1, 7, 1, 1, 0, 3); - branch7x7dbl = - create_conv_block(cgb, branch7x7dbl, channels_7x7, 7, 1, 1, 1, 3, 0); - branch7x7dbl = - create_conv_block(cgb, branch7x7dbl, channels_7x7, 1, 7, 1, 1, 0, 3); - - tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); - branch_pool = create_conv_block(cgb, branch_pool, 192, 1, 1); - - return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, 3); + tensor_guid_t branch7x7 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + 
/*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + return t; + }(); + + tensor_guid_t branch7x7dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + + return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, /*axis=*/1); } -tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, +static tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, tensor_guid_t const &input) { - tensor_guid_t branch3x3 = create_conv_block(cgb, input, 192, 1, 1); - branch3x3 = create_conv_block(cgb, branch3x3, 320, 3, 3, 2, 2); - - tensor_guid_t branch7x7x3 = create_conv_block(cgb, input, 192, 1, 1); - branch7x7x3 = create_conv_block(cgb, branch7x7x3, 192, 1, 7, 1, 1, 0, 3); - branch7x7x3 = create_conv_block(cgb, branch7x7x3, 192, 7, 1, 1, 1, 3, 0); - branch7x7x3 = 
create_conv_block(cgb, branch7x7x3, 192, 3, 3, 2, 2); - - tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 2, 2, 0, 0, PoolOp::MAX); - - return cgb.concat({branch3x3, branch7x7x3, branch_pool}, 3); + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, t, 320, 3, 3, 2, 2); + return t; + }(); + + tensor_guid_t branch7x7x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + return t; + }(); + + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + + return cgb.concat({branch3x3, branch7x7x3, branch_pool}, /*axis=*/1); } -tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, +static tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, tensor_guid_t const &input) { - tensor_guid_t branch1x1 = create_conv_block(cgb, input, 320, 1, 1); - - tensor_guid_t branch3x3 = create_conv_block(cgb, input, 384, 1, 1); - tensor_guid_t branch3x3_1 = - create_conv_block(cgb, branch3x3, 384, 1, 3, 1, 1, 0, 1); - tensor_guid_t branch3x3_2 = - create_conv_block(cgb, branch3x3, 384, 3, 1, 1, 1, 1, 0); - branch3x3 = cgb.concat({branch3x3_1, branch3x3_2}, 3); - - tensor_guid_t branch3x3dbl = create_conv_block(cgb, input, 448, 1, 1); - 
branch3x3dbl = create_conv_block(cgb, branch3x3dbl, 384, 3, 3, 1, 1, 1, 1); - tensor_guid_t branch3x3dbl_1 = - create_conv_block(cgb, branch3x3dbl, 384, 1, 3, 1, 1, 0, 1); - tensor_guid_t branch3x3dbl_2 = - create_conv_block(cgb, branch3x3dbl, 384, 3, 1, 1, 1, 1, 0); - branch3x3dbl = cgb.concat({branch3x3dbl_1, branch3x3dbl_2}, 3); - - tensor_guid_t branch_pool = cgb.pool2d(input, 3, 3, 1, 1, 1, 1, PoolOp::AVG); - branch_pool = create_conv_block(cgb, branch_pool, 192, 1, 1); - - return cgb.concat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, 3); + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/320, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + tensor_guid_t t_1 = + create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = + create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/448, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + tensor_guid_t t_1 = + create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = + create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + 
/*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + + return cgb.concat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, /*axis=*/1); } -tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, +static tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, tensor_guid_t const &input) { - tensor_guid_t x = create_conv_block(cgb, input, 32, 3, 3, 2, 2); - x = create_conv_block(cgb, x, 32, 3, 3); - x = create_conv_block(cgb, x, 64, 3, 3, 1, 1, 1, 1); - x = cgb.pool2d(x, 3, 3, 2, 2, 0, 0, PoolOp::MAX); + tensor_guid_t t = input; + + check_shape(t, 3, 299, 299); + + // Conv2d_1a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + check_shape(t, 32, 149, 149); + + // Conv2d_2a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3); + check_shape(t, 32, 147, 147); + + // Conv2d_2b_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + check_shape(t, 64, 147, 147); + + // maxpool1 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + check_shape(t, 64, 73, 73); + + // Conv2d_3b_1x1 + t = create_conv_block(cgb, + t, + /*filters=*/80, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(t, 80, 73, 73); + + // Conv2d_4a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + 
/*kernel_size_w=*/3); + check_shape(t, 192, 71, 71); + + // maxpool2 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + check_shape(t, 192, 35, 35); + + return t; +} - x = create_conv_block(cgb, x, 80, 1, 1); - x = create_conv_block(cgb, x, 192, 3, 3); - x = cgb.pool2d(x, 3, 3, 2, 2, 0, 0, PoolOp::MAX); +static tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { + // avgpool + tensor_guid_t x = cgb.pool2d(input, + /*kernelH=*/8, + /*kernelW=*/8, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 2048, 1, 1); + + // dropout + x = cgb.dropout(x, + /*rate=*/0.5); + check_shape(x, 2048, 1, 1); + + x = cgb.flat(x); + check_shape(x, 2048); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); return x; } -tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, - tensor_guid_t const &input, - size_t num_classes) { - tensor_guid_t x = cgb.pool2d(input, 8, 8, 1, 1, 0, 0, PoolOp::AVG); - x = cgb.dropout(x, 0.5); - x = cgb.dense(x, num_classes); +static tensor_guid_t create_inception_aux(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { + tensor_guid_t x = input; + check_shape(x, 768, 17, 17); + + x = cgb.pool2d(x, + /*kernelH=*/5, + /*kernelW=*/5, + /*strideH=*/3, + /*strideW=*/3, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 768, 5, 5); + + // conv0 + x = create_conv_block(cgb, + x, + /*filters=*/128, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(x, 128, 5, 5); + + // conv1 + x = create_conv_block(cgb, + x, + /*filters=*/128, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(x, 768, 1, 1); + + x = cgb.adaptive_pool2d(x, + /*output_h=*/1, + /*output_w=*/1); 
+ check_shape(x, 768, 1, 1); + + x = cgb.flat(input); + check_shape(x, 768); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); + return x; } -tensor_guid_t create_inception_v3(ComputationGraphBuilder &cgb, - InceptionV3Config const &config, - tensor_guid_t const &input) { - tensor_guid_t x = create_initial_layers(cgb, input); +static +InceptionV3Output + create_inception_v3(ComputationGraphBuilder &cgb, + InceptionV3Config const &config, + tensor_guid_t const &input) { + // NOTE: the shapes for check_shape (as well as the layer names in comments) are pulled from + // https://github.com/pytorch/vision/blob/6d7851bd5e2bedc294e40e90532f0e375fcfee04/torchvision/models/inception.py#L103-L155 + CheckShape check_shape = CheckShape{ + /*cgb=*/cgb, + /*config=*/config, + }; + + tensor_guid_t x = create_initial_layers(cgb, check_shape, input); + check_shape(x, 192, 35, 35); + // Mixed_5b x = create_inception_module_a(cgb, x, 32); + check_shape(x, 256, 35, 35); + + // Mixed_5c x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + + // Mixed_5d x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + // Mixed_6a x = create_inception_module_b(cgb, x); + check_shape(x, 768, 17, 17); + // Mixed_6b x = create_inception_module_c(cgb, x, 128); + check_shape(x, 768, 17, 17); + + // Mixed_6c x = create_inception_module_c(cgb, x, 160); + check_shape(x, 768, 17, 17); + + // Mixed_6d x = create_inception_module_c(cgb, x, 160); - x = create_inception_module_c(cgb, x, 192); + check_shape(x, 768, 17, 17); + // Mixed_6e + x = create_inception_module_c(cgb, x, 192); + check_shape(x, 768, 17, 17); + + std::optional aux; + if (config.aux_logits) { + aux = create_inception_aux(cgb, + check_shape, + x, + config.num_classes); + check_shape(aux.value(), config.num_classes); + } + + // Mixed_7a x = create_inception_module_d(cgb, x); + check_shape(x, 1280, 8, 8); + // Mixed_7b x = create_inception_module_e(cgb, x); + 
check_shape(x, 2048, 8, 8); + + // Mixed_7c x = create_inception_module_e(cgb, x); + check_shape(x, 2048, 8, 8); - x = create_final_layers(cgb, x, config.num_classes); + x = create_final_layers(cgb, check_shape, x, config.num_classes); + check_shape(x, config.num_classes); - return x; + return InceptionV3Output{ + x, + aux, + }; } ComputationGraph @@ -188,15 +714,17 @@ ComputationGraph ComputationGraphBuilder cgb; TensorShape input_shape = TensorShape{ - TensorDims{FFOrdered{config.batch_size, - config.input_height, - config.input_width, - config.input_num_channels}}, + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + 3, + 299, + 299, + }}, DataType::FLOAT, }; tensor_guid_t input = cgb.create_tensor(input_shape, CreateGrad::YES); - tensor_guid_t output = create_inception_v3(cgb, config, input); + InceptionV3Output output = create_inception_v3(cgb, config, input); return cgb.computation_graph; } diff --git a/lib/models/test/src/models/inception_v3/inception_v3.cc b/lib/models/test/src/models/inception_v3/inception_v3.cc index 6ff891b26a..8ec91d81fc 100644 --- a/lib/models/test/src/models/inception_v3/inception_v3.cc +++ b/lib/models/test/src/models/inception_v3/inception_v3.cc @@ -6,14 +6,14 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_inception_v3_computation_graph") { - InceptionV3Config config = get_default_inception_v3_config(); + InceptionV3Config config = get_default_inception_v3_training_config(); ComputationGraph result = get_inception_v3_computation_graph(config); SUBCASE("num layers") { - // int result_num_layers = get_layers(result).size(); + int result_num_layers = get_layers(result).size(); int correct_num_layers = -1; - // CHECK(result_num_layers == correct_num_layers); + CHECK(result_num_layers == correct_num_layers); } } } diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index 96a3c254f7..7ea5c3206b 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h 
+++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -138,6 +138,10 @@ struct DimOrdered { return this->contents.size(); } + size_t empty() const { + return this->contents.empty(); + } + size_t num_dims() const { return this->size(); } diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/concat.h b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h new file mode 100644 index 0000000000..c279fe9502 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H + +#include "op-attrs/dim_ordered.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template +FFOrdered concat(FFOrdered const &l, FFOrdered const &r) { + std::vector l_vec = std::vector(l.cbegin(), l.cend()); + std::vector r_vec = std::vector(r.cbegin(), r.cend()); + + std::vector raw_result = concat_vectors(l_vec, r_vec); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +template +FFOrdered concat(std::vector> const &inputs) { + std::vector> vec_inputs = transform(inputs, + [](FFOrdered const &input) { + return std::vector(input.cbegin(), input.cend()); + }); + + std::vector raw_result = concat_vectors(vec_inputs); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index f3dfe5d199..a36b8de29c 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -21,6 +21,14 @@ DimOrdered nonoverloaded_slice(DimOrdered const &d, subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))}; } +template +FFOrdered slice(FFOrdered const &d, + std::optional const &start, + std::optional const &end) { + return nonoverloaded_slice(d, start, end); +} + + 
template DimOrdered slice(DimOrdered const &d, std::optional const &start, diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 676d21c59b..3f0cdd7fa4 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -3,6 +3,7 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -11,8 +12,10 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(FlatAttrs); TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &); +tl::expected get_output_parallel_dim_degrees(FlatAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected get_output_shape(FlatAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml index e445535e29..aa286a03e7 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -8,4 +8,23 @@ features = [ "rapidcheck", "fmt", ] -fields = [] + +includes = [ + "", + "op-attrs/ff_dim.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/optional.h", + "utils/json.h", + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "start_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "end_dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 2c9ef9a1ce..36bec5f0d1 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -11,6 +11,13 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); +tl::expected make_adaptive_pool2d_attrs(TensorDims const 
&input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation); + + tl::expected get_output_shape(Pool2DAttrs const &, TensorShape const &); diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc new file mode 100644 index 0000000000..cb29f708a3 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/concat.h" diff --git a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc index 9fa1709b97..44d31d7143 100644 --- a/lib/op-attrs/src/op-attrs/get_output_shapes.cc +++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc @@ -58,7 +58,7 @@ std::vector return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](FlatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](GatherAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0), inputs.at(1))}; diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index 03ae18a1d9..5c4e537974 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -41,10 +41,10 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, size_t out_height = (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / - attrs.stride_h; + attrs.stride_h + 1; size_t out_width = (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / - attrs.stride_w; + attrs.stride_w + 1; assert(attrs.out_channels > 0); diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index 5d318207ee..824695ca48 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -1,57 +1,72 @@ #include "op-attrs/ops/flat.h" +#include "op-attrs/dim_ordered/concat.h" +#include 
"op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "utils/containers/any_of.h" +#include "utils/containers/product.h" +#include "op-attrs/dim_ordered/slice.h" #include namespace FlexFlow { -TensorShape get_output_shape(FlatAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); +TensorShape get_output_shape(FlatAttrs const &attrs, TensorShape const &input_shape) { + FFOrdered leading_dims = slice(ff_ordered(input_shape.dims), ff_dim_t{0}, attrs.start_dim); + FFOrdered flattened_dims = slice(ff_ordered(input_shape.dims), attrs.start_dim, attrs.end_dim); + FFOrdered trailing_dims = slice(ff_ordered(input_shape.dims), attrs.end_dim, std::nullopt); + + if (flattened_dims.empty()) { + return input_shape; + } + + return TensorShape{ + TensorDims{ + concat(std::vector{ + leading_dims, + {product(flattened_dims)}, + trailing_dims, + }), + }, + input_shape.data_type, + }; } -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected get_output_parallel_dim_degrees(FlatAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + FFOrdered flattened_dim_degrees = slice(input_degrees.shard_degrees, attrs.start_dim, attrs.end_dim); + + if (flattened_dim_degrees.empty()) { + return input_degrees; + } + + if (any_of(flattened_dim_degrees, [](int degree) { return degree != 1; })) { + return tl::unexpected(fmt::format("get_output_parallel_dim_degrees for {} expected all shard degrees of flattened dimensions to be 1, but received {}", attrs, input_degrees)); + } + + return ParallelTensorDimDegrees{ + /*sum_degree=*/input_degrees.sum_degree, + /*discard_copy_degree=*/input_degrees.discard_copy_degree, + /*shard_degrees=*/concat(std::vector{ + slice(input_degrees.shard_degrees, ff_dim_t{0}, attrs.start_dim), + {product(flattened_dim_degrees)}, + slice(input_degrees.shard_degrees, attrs.end_dim, std::nullopt), + }), + }; } -// namespace 
Input { -// constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, -// REPLICA = 4; -// } -// -// namespace Output { -// constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; -// } -// -/* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ - -/* bool is_valid = true; */ -/* is_valid &= input.is_valid(); */ -/* is_valid &= output_shape.is_valid(); */ -/* is_valid &= (input.at(Input::WIDTH).degree == 1); */ - -/* return is_valid; */ -/* } */ - -/* ParallelTensorShape FlatAttrs::calculate_output_shape(ParallelTensorShape - * const &input) const { */ -/* assert (input.num_dims() == Input::NUMDIM); */ -/* ParallelTensorShape output_dims; */ -/* output_dims.data_type = input.data_type; */ - -/* output_dims.at(Output::REPLICA) = input.at(Input::REPLICA); */ -/* output_dims.at(Output::SAMPLE) = input.at(Input::SAMPLE); */ - -/* output_dims.at(Output::CHANNEL).degree = input.at(Input::CHANNEL).degree; - */ -/* assert (input.at(Input::HEIGHT).degree == 1); */ -/* assert (input.at(Input::WIDTH).degree == 1); */ - -/* output_dims.at(Output::CHANNEL).size = input.at(Input::CHANNEL).size * - * input.at(Input::HEIGHT).size * input.at(Input::WIDTH).size; */ -/* output_dims.at(Output::CHANNEL).parallel_idx = - * input.at(Input::CHANNEL).parallel_idx; */ - -/* return output_dims; */ -/* } */ +tl::expected get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = get_output_shape(attrs, get_reduced_shape(input_shape)); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = get_output_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc 
b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index 6fe0ace109..f09e274cfa 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,9 +1,81 @@ #include "op-attrs/ops/pool_2d.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "utils/integer_conversions.h" namespace FlexFlow { +tl::expected make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation) { + // AdaptivePool2D semantics pulled from + // https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993 + + if (num_dims(input_dims) != 4) { + return tl::unexpected( + fmt::format("make_adaptive_pool2d_attrs expected input tensor to " + "have 4 dims, but received dims {}", + input_dims)); + } + + size_t num_samples = dim_at_idx(input_dims, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_dims, ff_dim_t{1}); + size_t input_h = dim_at_idx(input_dims, ff_dim_t{2}); + size_t input_w = dim_at_idx(input_dims, ff_dim_t{3}); + + if (input_h % output_h != 0) { + return tl::unexpected(fmt::format("Currently make_adaptive_pool2d_attrs only supports input_h % output_h == 0, but received input_h={} and output_h={} (input_dims={}). If you need input_h % output_h != 0 supported, please create an issue.", input_h, output_h, input_dims)); + } + + if (input_w % output_w != 0) { + return tl::unexpected(fmt::format("Currently make_adaptive_pool2d_attrs only supports input_w % output_w == 0, but received input_w={} and output_w={} (input_dims={}). 
If you need input_w % output_w != 0 supported, please create an issue.", input_w, output_w, input_dims)); + } + + int kernel_h = input_h / output_h; + int kernel_w = input_w / output_w; + + int stride_h = kernel_h; + int stride_w = kernel_w; + + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + + TensorShape expected_ouput_shape = TensorShape{ + TensorDims{FFOrdered{ + num_samples, + num_channels, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, + }; + + TensorShape output_shape = ({ + tl::expected result = get_output_shape(attrs, TensorShape{input_dims, DataType::FLOAT}); + if (!result.has_value()) { + return tl::unexpected(result.error()); + } + result.value(); + }); + + if (output_shape != expected_ouput_shape) { + return tl::unexpected(fmt::format("Result of make_adaptive_pool_2d (i.e., {}) should produce expected output shape {}, but produced {}. 
This is a bug in FlexFlow, Please create an issue.", attrs, expected_ouput_shape, output_shape)); + } + + return attrs; +} + tl::expected get_output_shape(Pool2DAttrs const &attrs, TensorShape const &input_shape) { if (num_dims(input_shape) != 4) { diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc new file mode 100644 index 0000000000..da95263743 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1,71 @@ +#include "op-attrs/dim_ordered/concat.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("concat(FFOrdered, FFOrdered)") { + SUBCASE("inputs have elements") { + FFOrdered l_input = FFOrdered{ + 1, 3, 1 + }; + FFOrdered r_input = FFOrdered{ + 2, 1 + }; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = { + 1, 3, 1, 2, 1 + }; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + FFOrdered l_input = FFOrdered{}; + FFOrdered r_input = FFOrdered{}; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } + + TEST_CASE("concat(std::vector>)") { + SUBCASE("inputs have elements") { + std::vector> input = { + {1}, + {2, 1}, + {1}, + }; + + FFOrdered result = concat(input); + FFOrdered correct = { + 1, 2, 1, 1, + }; + + CHECK(result == correct); + } + + SUBCASE("no inputs") { + std::vector> input = {}; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + std::vector> input = { + {}, {}, {} + }; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/flat.cc b/lib/op-attrs/test/src/op-attrs/ops/flat.cc new file mode 100644 index 0000000000..5e74139bea --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/flat.cc @@ -0,0 +1,235 @@ +#include "op-attrs/ops/flat.h" +#include 
"utils/expected.h" +#include +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(FlatAttrs, TensorShape)") { + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + SUBCASE("flatten all dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4 * 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten trailing dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten leading dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten middle dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4 * 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim == end_dim)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = input_shape; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim < 
end_dim)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{1}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = input_shape; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3} + }; + + SUBCASE("allows shard parallelism in non-flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 1, 3}, + }; + + tl::expected result = get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 3}, + }; + + CHECK(result == correct); + } + + SUBCASE("does not allow shard parallelism in flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 2, 1}, + }; + + std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("allows sum parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = ParallelTensorDimDegrees{ + SumDegree{1}, + 
DiscardCopyDegree{2}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_shape(FlatAttrs, ParallelTensorShape)") { + // since most of the edge cases are already tested in get_output_shape(FlatAttrs, TensorShape) + // and get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees), here we just do + // a basic check that they compose + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8, 1}, + ShardParallelDim{6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + tl::expected result = get_output_shape(attrs, input_shape); + tl::expected correct = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8*6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc index 7db95a545c..4efbe20cf4 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc @@ -2,11 +2,130 @@ #include "utils/expected.h" #include "utils/fmt/expected.h" #include "utils/fmt/optional.h" +#include "utils/integer_conversions.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("make_adaptive_pool2d") { + size_t input_n = 10; + size_t input_c = 11; + size_t input_h = 15; + size_t input_w = 20; + Activation activation = Activation::RELU; + PoolOp op = PoolOp::AVG; + + TensorDims input_dims = TensorDims{FFOrdered{ + input_n, input_c, input_h, input_w + }}; + + SUBCASE("input_h divisible by output_h && input_w divisible by output_w") { + int output_h = 
5; + int output_w = 2; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/10, + /*stride_h=*/3, + /*stride_w=*/10, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = make_adaptive_pool2d_attrs(input_dims, + output_h, + output_w, + op, + activation); + tl::expected correct = correct_attrs; + + CHECK(result == correct); + } + + SUBCASE("confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = get_output_shape(correct_attrs, input_shape); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + input_n, input_c, size_t_from_int(output_h), size_t_from_int(output_w), + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + SUBCASE("input_h not divisible by output_h") { + int output_h = 6; + int output_w = 2; + + std::optional result = optional_from_expected(make_adaptive_pool2d_attrs(input_dims, + output_h, + output_w, + op, + activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_w not divisible by output_w") { + int output_h = 5; + int output_w = 3; + + std::optional result = optional_from_expected(make_adaptive_pool2d_attrs(input_dims, + output_h, + output_w, + op, + activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_h == output_h and input_w == output_w") { + int output_h = input_h; + int output_w = input_w; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/1, + /*kernel_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = make_adaptive_pool2d_attrs(input_dims, + output_h, + output_w, + op, + activation); + tl::expected correct = correct_attrs; + + 
CHECK(result == correct); + } + + SUBCASE("confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = get_output_shape(correct_attrs, input_shape); + tl::expected correct = input_shape; + + CHECK(result == correct); + } + } + } + TEST_CASE("get_output_shape(Pool2DAttrs, TensorShape)") { Pool2DAttrs attrs = Pool2DAttrs{ /*kernel_h=*/3, diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 1d5dc1bec4..cab042e609 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -137,6 +137,13 @@ struct ComputationGraphBuilder { PoolOp type = PoolOp::MAX, std::optional const &activation = std::nullopt, std::optional const &name = std::nullopt); + tensor_guid_t + adaptive_pool2d(tensor_guid_t const &input, + int output_h, + int output_w, + PoolOp type = PoolOp::MAX, + std::optional const &activation = std::nullopt, + std::optional const &name = std::nullopt); tensor_guid_t layer_norm(tensor_guid_t const &input, std::vector const &axes, @@ -183,6 +190,8 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); // Add a flat layer tensor_guid_t flat(tensor_guid_t const &input, + int start_dim = 0, + std::optional const &end_dim = std::nullopt, std::optional const &name = std::nullopt); // Add a softmax layer tensor_guid_t softmax(tensor_guid_t const &input, @@ -238,9 +247,8 @@ struct ComputationGraphBuilder { std::vector const &weights, std::vector const &outputs); -private: TensorShape get_shape(tensor_guid_t const &) const; - +private: tensor_guid_t broadcast(tensor_guid_t const &, TensorShape const &, std::string const &); diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 86d6948c60..bc9bc2d0b8 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ 
b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -11,6 +11,7 @@ #include "op-attrs/ops/element_binary.h" #include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/flat.h" #include "op-attrs/ops/gather.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" @@ -524,38 +525,35 @@ tensor_guid_t ComputationGraphBuilder::pool2d( return this->add_layer(layer, {input}, {}, output_shape); } -/* std::vector - * ComputationGraphBuilder::get_shapes(std::vector const &ts) - * const { */ -/* return transform(ts, [&](tensor_guid_t const &t) { return - * this->get_shape(t); }); */ -/* } */ - -// tensor_guid_t ComputationGraphBuilder::aggregate( -// tensor_guid_t const &gate_preds, -// tensor_guid_t const &gate_assign, -// tensor_guid_t const &true_gate_assign, -// tensor_guid_t const &full_gate_gradients, -// std::vector const &exp_preds, -// int n, -// float lambda_bal, -// std::optional const &maybe_name) { -// AggregateAttrs attrs = {n, lambda_bal}; -// std::string name = maybe_name.value_or(get_default_name(attrs)); - -// LayerAttrs layer = {attrs, name}; -// TensorShape output_shape = get_output_shape(attrs, -// this->get_shape(gate_preds), -// this->get_shape(gate_assign), -// this->get_shape(true_gate_assign), -// this->get_shape(full_gate_gradients), -// this->get_shape(exp_preds)); - -// std::vector inputs = { -// gate_preds, gate_assign, true_gate_assign, full_gate_gradients}; -// extend(inputs, exp_preds); -// return this->add_layer(layer, inputs, {}, output_shape); -// } +tensor_guid_t + ComputationGraphBuilder::adaptive_pool2d(tensor_guid_t const &uncasted_input, + int output_h, + int output_w, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { + + TensorDims input_dims = this->get_shape(uncasted_input).dims; + + Pool2DAttrs attrs = throw_if_unexpected(make_adaptive_pool2d_attrs(input_dims, + output_h, + output_w, + type, + activation)); + + std::string name = + 
maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t casted_input = + this->as_type(uncasted_input, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(casted_input))); + + return this->add_layer(layer, {casted_input}, {}, output_shape); +} tensor_guid_t ComputationGraphBuilder::batch_norm( tensor_guid_t const &input, @@ -695,6 +693,28 @@ tensor_guid_t ComputationGraphBuilder::concat( return this->add_layer(layer, inputs, {}, output_shape); } +tensor_guid_t ComputationGraphBuilder::flat(tensor_guid_t const &input, + int start_dim, + std::optional const &end_dim, + std::optional const &maybe_name) { + int input_num_dims = num_dims(this->get_shape(input)); + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{start_dim}, + /*end_dim=*/ff_dim_t{end_dim.value_or(input_num_dims)}, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = + get_output_shape(attrs, this->get_shape(input)); + + return this->add_layer(layer, {input}, {}, output_shape); +} + tensor_guid_t ComputationGraphBuilder::layer_norm( tensor_guid_t const &input, std::vector const &axes, diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 52368f94ad..e8b9f4e441 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -25,10 +25,15 @@ std::vector subvec(std::vector const &v, if (maybe_start.has_value()) { begin_iter += resolve_loc(maybe_start.value()); } + if (maybe_end.has_value()) { end_iter = v.cbegin() + resolve_loc(maybe_end.value()); } + if (end_iter < begin_iter) { + end_iter = begin_iter; + } + std::vector output(begin_iter, end_iter); return output; 
} From de67ca4981ed54f906193b1efb133b029b5af8bd Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 16 Sep 2024 21:16:19 -0700 Subject: [PATCH 05/10] Fix model bugs --- .../src/models/inception_v3/inception_v3.cc | 37 ++++++++++++------- .../src/models/inception_v3/inception_v3.cc | 2 +- lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 4 +- lib/op-attrs/test/src/ops/conv_2d.cc | 8 ++-- 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/lib/models/src/models/inception_v3/inception_v3.cc b/lib/models/src/models/inception_v3/inception_v3.cc index ef291c9bcb..2770a436a8 100644 --- a/lib/models/src/models/inception_v3/inception_v3.cc +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -215,9 +215,15 @@ static tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, } static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, - tensor_guid_t const &input, - int channels_7x7) { - tensor_guid_t branch1x1 = create_conv_block(cgb, input, 192, 1, 1); + CheckShape const &check_shape, + tensor_guid_t const &input, + int channels_7x7) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(branch1x1, 192, 17, 17); tensor_guid_t branch7x7 = [&] { tensor_guid_t t = input; @@ -246,6 +252,7 @@ static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, /*padding_w=*/0); return t; }(); + check_shape(branch7x7, 192, 17, 17); tensor_guid_t branch7x7dbl = [&] { tensor_guid_t t = input; @@ -283,7 +290,7 @@ static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, /*padding_w=*/0); t = create_conv_block(cgb, t, - /*filters=*/channels_7x7, + /*filters=*/192, /*kernel_size_h=*/1, /*kernel_size_w=*/7, /*stride_h=*/1, @@ -292,6 +299,7 @@ static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, /*padding_w=*/3); return t; }(); + check_shape(branch7x7dbl, 192, 17, 17); tensor_guid_t branch_pool = [&] { 
tensor_guid_t t = input; @@ -310,6 +318,7 @@ static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, /*kernel_size_w=*/1); return t; }(); + check_shape(branch_pool, 192, 17, 17); return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, /*axis=*/1); } @@ -572,7 +581,8 @@ static tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, /*rate=*/0.5); check_shape(x, 2048, 1, 1); - x = cgb.flat(x); + x = cgb.flat(x, + /*start_dim=*/1); check_shape(x, 2048); // fc @@ -611,9 +621,9 @@ static tensor_guid_t create_inception_aux(ComputationGraphBuilder &cgb, // conv1 x = create_conv_block(cgb, x, - /*filters=*/128, - /*kernel_size_h=*/1, - /*kernel_size_w=*/1); + /*filters=*/768, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5); check_shape(x, 768, 1, 1); x = cgb.adaptive_pool2d(x, @@ -621,7 +631,8 @@ static tensor_guid_t create_inception_aux(ComputationGraphBuilder &cgb, /*output_w=*/1); check_shape(x, 768, 1, 1); - x = cgb.flat(input); + x = cgb.flat(x, + /*start_dim=*/1); check_shape(x, 768); // fc @@ -664,19 +675,19 @@ InceptionV3Output check_shape(x, 768, 17, 17); // Mixed_6b - x = create_inception_module_c(cgb, x, 128); + x = create_inception_module_c(cgb, check_shape, x, 128); check_shape(x, 768, 17, 17); // Mixed_6c - x = create_inception_module_c(cgb, x, 160); + x = create_inception_module_c(cgb, check_shape, x, 160); check_shape(x, 768, 17, 17); // Mixed_6d - x = create_inception_module_c(cgb, x, 160); + x = create_inception_module_c(cgb, check_shape, x, 160); check_shape(x, 768, 17, 17); // Mixed_6e - x = create_inception_module_c(cgb, x, 192); + x = create_inception_module_c(cgb, check_shape, x, 192); check_shape(x, 768, 17, 17); std::optional aux; diff --git a/lib/models/test/src/models/inception_v3/inception_v3.cc b/lib/models/test/src/models/inception_v3/inception_v3.cc index 8ec91d81fc..fedaf881b8 100644 --- a/lib/models/test/src/models/inception_v3/inception_v3.cc +++ 
b/lib/models/test/src/models/inception_v3/inception_v3.cc @@ -12,7 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("num layers") { int result_num_layers = get_layers(result).size(); - int correct_num_layers = -1; + int correct_num_layers = 329; CHECK(result_num_layers == correct_num_layers); } } diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index 5c4e537974..289a32c313 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -40,10 +40,10 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, Conv2DInputShape input = parse_input_shape(raw_input_shape); size_t out_height = - (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / + (input.height + (2 * attrs.padding_h) - attrs.kernel_h) / attrs.stride_h + 1; size_t out_width = - (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / + (input.width + (2 * attrs.padding_w) - attrs.kernel_w) / attrs.stride_w + 1; assert(attrs.out_channels > 0); diff --git a/lib/op-attrs/test/src/ops/conv_2d.cc b/lib/op-attrs/test/src/ops/conv_2d.cc index c4462eb7ec..3e851f2d4d 100644 --- a/lib/op-attrs/test/src/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/ops/conv_2d.cc @@ -31,8 +31,8 @@ TEST_SUITE(FF_TEST_SUITE) { }; size_t num_samples = 7; - size_t input_channels = 6; - size_t input_height = 10; + size_t input_channels = 4; + size_t input_height = 11; size_t input_width = 15; TensorShape input = TensorShape{ @@ -45,8 +45,8 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - size_t output_height = 3; - size_t output_width = 6; + size_t output_height = 6; + size_t output_width = 8; TensorShape output = TensorShape{ TensorDims{FFOrdered{ From fe31d8e415bc0d1059d59c50b42efe24f8458e93 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 16 Sep 2024 23:02:36 -0700 Subject: [PATCH 06/10] Update batch norm to match pytorch interface for inception v3 --- lib/local-execution/src/ops/batch_norm.cc | 23 +- 
.../src/models/inception_v3/inception_v3.cc | 27 +- .../include/op-attrs/dim_ordered/concat.h | 2 +- .../dim_ordered/ff_ordered_from_map.h | 2 +- .../include/op-attrs/ops/batch_norm.h | 33 +- .../op-attrs/ops/batch_norm_attrs.struct.toml | 20 +- .../op-attrs/ops/flat_attrs.struct.toml | 4 +- .../op-attrs/ops/pool_2d_attrs.struct.toml | 3 +- .../parallel_tensor_dim_degrees.struct.toml | 2 +- .../src/op-attrs/get_incoming_tensor_roles.cc | 5 +- .../src/op-attrs/get_output_shapes.cc | 2 +- lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 227 ++++++++++- lib/op-attrs/src/op-attrs/ops/concat.cc | 1 - lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 2 +- .../op-attrs/computation_graph_op_attrs.cc | 6 +- .../src/op-attrs/get_incoming_tensor_roles.cc | 2 +- .../test/src/op-attrs/ops/batch_norm.cc | 380 ++++++++++++++++++ .../test/src/op-attrs/ops/batch_norm_attrs.cc | 6 +- .../include/pcg/computation_graph_builder.h | 4 +- .../parallel_computation_graph_builder.h | 4 +- lib/pcg/src/pcg/computation_graph_builder.cc | 50 ++- .../parallel_computation_graph_builder.cc | 33 +- .../operator_attribute_key.enum.toml | 3 +- .../operator_pattern/get_attribute.cc | 12 +- .../src/utils/containers/require_all_same1.cc | 14 +- 25 files changed, 798 insertions(+), 69 deletions(-) create mode 100644 lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc diff --git a/lib/local-execution/src/ops/batch_norm.cc b/lib/local-execution/src/ops/batch_norm.cc index 851566fc02..5decfde631 100644 --- a/lib/local-execution/src/ops/batch_norm.cc +++ b/lib/local-execution/src/ops/batch_norm.cc @@ -82,17 +82,18 @@ static DeviceSpecificDeviceStates float *runningMean; - BatchNormPerDeviceState per_device_state = init_kernel(handle, - allocator, - runningMean, - output_n, - output_c, - output_h, - output_w, - attrs.relu); - - return DeviceSpecificDeviceStates{ - DeviceSpecific::create(per_device_state)}; + NOT_IMPLEMENTED(); // TODO @reyna fix me + // BatchNormPerDeviceState per_device_state = init_kernel(handle, + 
// allocator, + // runningMean, + // output_n, + // output_c, + // output_h, + // output_w, + // attrs.relu); + + // return DeviceSpecificDeviceStates{ + // DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { diff --git a/lib/models/src/models/inception_v3/inception_v3.cc b/lib/models/src/models/inception_v3/inception_v3.cc index 2770a436a8..af6fccb1a6 100644 --- a/lib/models/src/models/inception_v3/inception_v3.cc +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -75,17 +75,20 @@ static tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, int padding_w = 0, bool use_bias = false) { tensor_guid_t conv = cgb.conv2d(input, - filters, - kernel_size_h, - kernel_size_w, - stride_h, - stride_w, - padding_h, - padding_w, - std::nullopt, - 1, - use_bias); - return cgb.batch_norm(conv); + /*outChannels=*/filters, + /*kernelH=*/kernel_size_h, + /*kernelW=*/kernel_size_w, + /*strideH=*/stride_h, + /*strideW=*/stride_w, + /*paddingH=*/padding_h, + /*paddingW=*/padding_w, + /*activation=*/std::nullopt, + /*groups=*/1, + /*use_bias=*/use_bias); + return cgb.batch_norm(conv, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1); } static tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, @@ -734,7 +737,7 @@ ComputationGraph DataType::FLOAT, }; - tensor_guid_t input = cgb.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES); InceptionV3Output output = create_inception_v3(cgb, config, input); return cgb.computation_graph; diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/concat.h b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h index c279fe9502..dfc9869306 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/concat.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H #define 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/transform.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h index ba85ec59c8..79d4929797 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "op-attrs/dim_ordered/ff_ordered_of.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 8afcbb06b1..73bfd56803 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -1,15 +1,40 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/ops/core.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &); +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &); + +tl::expected + get_output_shape(BatchNormAttrs const &, TensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, 
TensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, TensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(BatchNormAttrs const &, ParallelTensorDimDegrees const &); +tl::expected + get_gamma_weights_parallel_dim_degrees(BatchNormAttrs const &, ParallelTensorDimDegrees const &); +tl::expected + get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &, ParallelTensorDimDegrees const &); + + +tl::expected + get_output_shape(BatchNormAttrs const &, + ParallelTensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml index bc82f3c743..e20183b41d 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml @@ -10,6 +10,24 @@ features = [ "fmt", ] +includes = [ + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + [[fields]] -name = "relu" +name = "eps" +type = "float" + +[[fields]] +name = "affine" type = "bool" + +[[fields]] +name = "momentum" +type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml index aa286a03e7..7349e2a8c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -16,8 +16,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", - "utils/optional.h", - "utils/json.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", "op-attrs/ff_dim.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml 
b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml index 003469f6f0..20ca7deabc 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -17,7 +17,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", - "utils/json.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml index 9a93c64b13..974b27d2a7 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml @@ -12,7 +12,7 @@ features = [ includes = [ "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", ] [[fields]] diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc index c7febde1d6..21efc26466 100644 --- a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -1,5 +1,6 @@ #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" @@ -22,8 +23,8 @@ std::vector return std::vector{IncomingTensorRole::INPUT, IncomingTensorRole::INPUT}; }, - [](BatchNormAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; + [](BatchNormAttrs const &attrs) { + return get_batch_norm_incoming_tensor_roles(attrs); }, [](BroadcastAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; diff --git a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc index 44d31d7143..0058ee35a2 100644 
--- a/lib/op-attrs/src/op-attrs/get_output_shapes.cc +++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc @@ -30,7 +30,7 @@ std::vector get_output_shape(attrs, inputs.at(0), inputs.at(1)))}; }, [&](BatchNormAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](CastAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index b75c3521c6..defc695675 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -1,15 +1,232 @@ #include "op-attrs/ops/batch_norm.h" +#include "op-attrs/dim_ordered/concat.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/any_of.h" +#include "utils/containers/extend.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, - TensorShape const &input_shape) { +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &attrs) { + std::vector result = {IncomingTensorRole::INPUT}; + + if (attrs.affine) { + extend(result, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return result; +} + +static std::optional + check_input_shape(BatchNormAttrs const &, TensorShape const &input_shape) { + if (num_dims(input_shape) < 2) { + return fmt::format("BatchNormAttrs expected input dims >= 2, but received input shape {}", input_shape); + } + + if (input_shape.data_type != DataType::FLOAT) { + return fmt::format("BatchNormAttrs currently only supports data_type = FLOAT, but received input data_type {}. 
" + "If you need this feature, please create an issue.", input_shape.data_type); + } + + return std::nullopt; +} + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + return input_shape; } -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected( + "No gamma weights exist for attrs.affine = false"); + } + + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + + return TensorShape{ + TensorDims{FFOrdered{ + num_channels, + }}, + DataType::FLOAT, + }; +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, TensorShape const &input_shape) { + + if (!attrs.affine) { + return tl::unexpected( + "No beta weights exist for attrs.affine = false"); + } + + return get_gamma_weights_shape(attrs, input_shape); +} + +static std::optional + check_input_degrees(BatchNormAttrs const &, ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.shard_degrees.size() < 2) { + return fmt::format("BatchNormAttrs expected input dims >= 2, but received input degrees {}", input_degrees); + } + + if (input_degrees.sum_degree != SumDegree{1}) { + return fmt::format("Expected sum degree 1, but received sum degree {}", + input_degrees.sum_degree); + } + + if (input_degrees.discard_copy_degree != DiscardCopyDegree{1}) { + return fmt::format("Expected discard copy degree 1, but received discard copy degree {}", + input_degrees.discard_copy_degree); + } + + FFOrdered 
non_channel_degrees = concat( + slice(input_degrees.shard_degrees, ff_dim_t{0}, ff_dim_t{1}), + slice(input_degrees.shard_degrees, ff_dim_t{2}, std::nullopt)); + + if (any_of(non_channel_degrees, [](int degree) { + return degree != 1; + })) { + return fmt::format("Expected parallel degree of all non-channel dimensions to be 1, but received input with degrees {}", input_degrees); + } + + return std::nullopt; +} + + +tl::expected + get_output_parallel_dim_degrees(BatchNormAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + return input_degrees; +} + +tl::expected + get_gamma_weights_parallel_dim_degrees(BatchNormAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected( + "No gamma weights exist for attrs.affine = false"); + } + + return input_degrees; +} + +tl::expected + get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected( + "No beta weights exist for attrs.affine = false"); + } + + return input_degrees; +} + + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected returned = get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + 
tl::expected returned = get_output_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = get_gamma_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = get_gamma_weights_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = get_beta_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = get_beta_weights_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 0e1f52d9ff..74295f279e 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -6,7 +6,6 @@ #include "op-attrs/tensor_shape.h" #include "utils/containers/all_of.h" #include "utils/containers/are_all_same.h" -#include "utils/containers/as_vector.h" #include 
"utils/containers/require_all_same1.h" #include "utils/containers/sum.h" #include "utils/containers/transform.h" diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index d3c00efbb9..0dd9ac7a17 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -112,7 +112,7 @@ static std::optional if (get_discard_copy_degree(input_shape) != 1) { return fmt::format( - "Expected discard copy degree 1, but received discartd copy degree {}", + "Expected discard copy degree 1, but received discard copy degree {}", get_discard_copy_degree(input_shape)); } diff --git a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc index 42ea07e6b5..7f244aa507 100644 --- a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc @@ -6,7 +6,11 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ComputationGraphOpAttrs to/from json") { ComputationGraphOpAttrs correct = - ComputationGraphOpAttrs{BatchNormAttrs{true}}; + ComputationGraphOpAttrs{BatchNormAttrs{ + /*eps=*/1e-5, + /*affine=*/true, + /*momentum=*/0.1, + }}; nlohmann::json j = correct; auto result = j.get(); diff --git a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc index 60dedfe70a..33cc00c6a1 100644 --- a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc @@ -9,7 +9,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Concat") { int num_incoming = 4; ComputationGraphOpAttrs attrs = - ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}, num_incoming}}; + ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}}}; std::vector result = get_incoming_tensor_roles(attrs, num_incoming); diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc 
b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc new file mode 100644 index 0000000000..d1074d8482 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc @@ -0,0 +1,380 @@ +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_batch_norm_incoming_tensor_roles(BatchNormAttrs)") { + auto make_attrs = [](bool affine) { + return BatchNormAttrs{ + /*eps=*/1.0, + /*affine=*/affine, + /*momentum=*/0.1, + }; + }; + + SUBCASE("affine = true") { + BatchNormAttrs attrs = make_attrs(/*affine=*/true); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + BatchNormAttrs attrs = make_attrs(/*affine=*/false); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("shape inference (BatchNorm)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*eps=*/1.0, + /*affine=*/true, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + 18, + }}, + DataType::FLOAT, + }; + + TensorShape output = input; + + TensorShape gamma = TensorShape{ + TensorDims{FFOrdered{ + 14, + }}, + DataType::FLOAT, + }; + + TensorShape beta = gamma; + + SUBCASE("get_output_shape(BatchNormAttrs, TensorShape)") { + tl::expected result = + get_output_shape(attrs_affine_true, input); + tl::expected correct = output; + + CHECK(result == correct); + } + + 
SUBCASE("get_gamma_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_shape(attrs_affine_true, input); + tl::expected correct = gamma; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_shape(attrs_affine_true, input); + tl::expected correct = beta; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel dim degree inference (BatchNormAttrs)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*eps=*/1.0, + /*affine=*/true, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + SUBCASE("partition parallelism (in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, degree, 1, 1, + }, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + tl::expected result = + get_output_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE( + "get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + 
FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + SUBCASE("partition parallelism (not in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, degree, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE( + "get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("sum parallelism") { + 
SumDegree sum_degree = SumDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + sum_degree, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE( + "get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("discard copy parallelism") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + discard_copy_degree, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE( + "get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + 
get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel shape inference (BatchNormAttrs)") { + // since most of the edge cases are already tested in the above test cases + // (i.e., shape inference and parallel degree inference) + // here we just do a basic check that they compose + + BatchNormAttrs attrs = BatchNormAttrs{ + /*eps=*/1.0, + /*affine=*/true, + /*momentum=*/0.1, + }; + + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{18, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("get_output_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = get_output_shape(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = get_gamma_weights_shape(attrs, input); + tl::expected correct = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = get_beta_weights_shape(attrs, input); + tl::expected correct = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc index df436da66c..b412b1a3b7 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc +++ 
b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc @@ -5,7 +5,11 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { - BatchNormAttrs correct = BatchNormAttrs{true}; + BatchNormAttrs correct = BatchNormAttrs{ + /*eps=*/1e-5, + /*affine=*/true, + /*momentum=*/0.1, + }; nlohmann::json j = correct; BatchNormAttrs result = j.get(); diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 6d56a4cada..69d1dc1313 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -152,7 +152,9 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); tensor_guid_t batch_norm(tensor_guid_t const &input, - bool relu = true, + bool affine, + float eps, + std::optional const &momentum, std::optional const &name = std::nullopt); tensor_guid_t batch_matmul(tensor_guid_t const &A, diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 3a7f67dcf0..00303e30d8 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -87,7 +87,9 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t batch_norm(parallel_tensor_guid_t const &input, - bool relu = true, + bool affine, + float eps, + std::optional const &momentum, std::optional const &name = std::nullopt); parallel_tensor_guid_t diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 0b16fd648b..162a124b52 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -535,7 +535,9 @@ tensor_guid_t ComputationGraphBuilder::pool2d( TensorShape output_shape = 
throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, {}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)}) + ); } tensor_guid_t @@ -565,20 +567,50 @@ tensor_guid_t TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(casted_input))); - return this->add_layer(layer, {casted_input}, {}, output_shape); + return get_only( + this->add_layer(layer, {casted_input}, {}, {make_output_attrs(output_shape)}) + ); } tensor_guid_t ComputationGraphBuilder::batch_norm( tensor_guid_t const &input, - bool relu, + bool affine, + float eps, + std::optional const &momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + + BatchNormAttrs attrs = BatchNormAttrs{ + /*eps=*/eps, + /*affine=*/affine, + /*momentum=*/momentum, + }; + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape input_shape = this->get_shape(input); + TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); + + std::vector weights; + + if (affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + TensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + TensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } 
return get_only( this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); @@ -760,7 +792,9 @@ tensor_guid_t ComputationGraphBuilder::concat( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shapes)); - return this->add_layer(layer, inputs, {}, output_shape); + return get_only( + this->add_layer(layer, inputs, {}, {make_output_attrs(output_shape)}) + ); } tensor_guid_t ComputationGraphBuilder::flat(tensor_guid_t const &input, @@ -782,7 +816,9 @@ tensor_guid_t ComputationGraphBuilder::flat(tensor_guid_t const &input, TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, {}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)}) + ); } tensor_guid_t ComputationGraphBuilder::layer_norm( diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 620dc035fc..c7ea8ea9dc 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -331,18 +331,45 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( parallel_tensor_guid_t const &input, - bool relu, + bool affine, + float eps, + std::optional const &momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + BatchNormAttrs attrs = BatchNormAttrs{ + /*eps=*/eps, + /*affine=*/affine, + /*momentum=*/momentum, + }; std::string name = maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + ParallelTensorShape input_shape = this->get_shape(input); + ParallelTensorShape output_shape = - get_output_shape(attrs, 
this->get_shape(input)); + throw_if_unexpected(get_output_shape(attrs, input_shape)); + + std::vector weights; + + if (attrs.affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + ParallelTensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + ParallelTensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } return this->add_layer(layer, {input}, {}, {output_shape}); } diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml index 59e913750e..eb758ea4fc 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml @@ -55,7 +55,8 @@ values = [ { name = "SHOULD_BROADCAST_LHS" }, { name = "SHOULD_BROADCAST_RHS" }, { name = "DIM" }, - { name = "ELEMENTWISE_AFFINE" }, + { name = "AFFINE" }, + { name = "MOMENTUM" }, { name = "REGULARIZER" }, { name = "SHAPE" }, { name = "SPLITS" }, diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index d5d735ef59..442d3345a1 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -19,8 +19,12 @@ std::optional get_attribute(BatchNormAttrs const &p, switch 
(key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); - case OperatorAttributeKey::RELU: - return p.relu; + case OperatorAttributeKey::EPSILON: + return p.eps; + case OperatorAttributeKey::AFFINE: + return p.affine; + case OperatorAttributeKey::MOMENTUM: + return p.momentum; default: return std::nullopt; } @@ -189,6 +193,10 @@ std::optional get_attribute(LayerNormAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); + case OperatorAttributeKey::AFFINE: + return p.elementwise_affine; + case OperatorAttributeKey::AXES: + return vector_of(p.axes); default: return std::nullopt; } diff --git a/lib/utils/test/src/utils/containers/require_all_same1.cc b/lib/utils/test/src/utils/containers/require_all_same1.cc index a655ac02ef..45a7fcfa78 100644 --- a/lib/utils/test/src/utils/containers/require_all_same1.cc +++ b/lib/utils/test/src/utils/containers/require_all_same1.cc @@ -1,12 +1,12 @@ #include "utils/containers/require_all_same1.h" #include "utils/expected.h" -#include "utils/fmt/expected.h" -#include "utils/fmt/multiset.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include #include #include From 51fa6a23f5f12acdb5606d110c1ec62c0d54d2a2 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 17 Sep 2024 10:29:09 -0700 Subject: [PATCH 07/10] Finishing touches for inception, re-add relu flag for batchnorm --- .../src/export_model_arch.cc | 6 ++++- ...ion_graph_series_parallel_decomposition.cc | 11 +++++++++ lib/local-execution/src/ops/batch_norm.cc | 23 +++++++++---------- 
.../src/models/inception_v3/inception_v3.cc | 6 +++++ .../src/models/inception_v3/inception_v3.cc | 2 +- .../op-attrs/ops/batch_norm_attrs.struct.toml | 8 +++++-- lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 12 ++++++++-- .../op-attrs/computation_graph_op_attrs.cc | 3 ++- .../test/src/op-attrs/ops/batch_norm.cc | 12 ++++++---- .../test/src/op-attrs/ops/batch_norm_attrs.cc | 3 ++- .../include/pcg/computation_graph_builder.h | 1 + .../parallel_computation_graph_builder.h | 1 + lib/pcg/src/pcg/computation_graph_builder.cc | 15 ++++++++++-- .../parallel_computation_graph_builder.cc | 9 +++++++- 14 files changed, 85 insertions(+), 27 deletions(-) diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index ccc720ed14..31ec2ebcd7 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -1,6 +1,7 @@ #include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" #include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" #include "op-attrs/computation_graph_op_attrs.h" @@ -59,6 +60,9 @@ tl::expected get_model_computation_graph(std::string const &model_name) { if (model_name == "transformer") { return get_default_transformer_computation_graph(); + } else if (model_name == "inception_v3") { + return get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); } else if (model_name == "split_test") { int batch_size = 8; return get_split_test_computation_graph(batch_size); @@ -132,7 +136,7 @@ int main(int argc, char **argv) { "for preprocessed to help check series-parallel structure"}); std::vector model_options = { - "transformer", "split_test", "single_operator"}; + "transformer", "inception_v3", "split_test", 
"single_operator"}; CLIArgumentKey key_model_name = cli_add_positional_argument( cli, CLIPositionalArgumentSpec{ diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc index ab537e73de..c4966673c7 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,5 @@ #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" #include "pcg/computation_graph.h" @@ -291,6 +292,16 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(sp_decomposition.has_value()); } + + SUBCASE("inception_v3") { + ComputationGraph cg = + get_inception_v3_computation_graph(get_default_inception_v3_training_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } } } diff --git a/lib/local-execution/src/ops/batch_norm.cc b/lib/local-execution/src/ops/batch_norm.cc index 5decfde631..851566fc02 100644 --- a/lib/local-execution/src/ops/batch_norm.cc +++ b/lib/local-execution/src/ops/batch_norm.cc @@ -82,18 +82,17 @@ static DeviceSpecificDeviceStates float *runningMean; - NOT_IMPLEMENTED(); // TODO @reyna fix me - // BatchNormPerDeviceState per_device_state = init_kernel(handle, - // allocator, - // runningMean, - // output_n, - // output_c, - // output_h, - // output_w, - // attrs.relu); - - // return DeviceSpecificDeviceStates{ - // DeviceSpecific::create(per_device_state)}; + BatchNormPerDeviceState per_device_state = init_kernel(handle, + allocator, + runningMean, + output_n, + output_c, + output_h, + output_w, + 
attrs.relu); + + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { diff --git a/lib/models/src/models/inception_v3/inception_v3.cc b/lib/models/src/models/inception_v3/inception_v3.cc index af6fccb1a6..3b3e377c3f 100644 --- a/lib/models/src/models/inception_v3/inception_v3.cc +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -87,6 +87,7 @@ static tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, /*use_bias=*/use_bias); return cgb.batch_norm(conv, /*affine=*/true, + /*activation=*/Activation::RELU, /*eps=*/1e-5, /*momentum=*/0.1); } @@ -593,6 +594,11 @@ static tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, /*outDim=*/num_classes); check_shape(x, num_classes); + // softmax (not in pytorch model, but shown in Table 1 on p6 of + // https://arxiv.org/abs/1512.00567) + x = cgb.softmax(x); + check_shape(x, num_classes); + return x; } diff --git a/lib/models/test/src/models/inception_v3/inception_v3.cc b/lib/models/test/src/models/inception_v3/inception_v3.cc index fedaf881b8..2b0fe82fd6 100644 --- a/lib/models/test/src/models/inception_v3/inception_v3.cc +++ b/lib/models/test/src/models/inception_v3/inception_v3.cc @@ -12,7 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("num layers") { int result_num_layers = get_layers(result).size(); - int correct_num_layers = 329; + int correct_num_layers = 522; CHECK(result_num_layers == correct_num_layers); } } diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml index e20183b41d..fdc3bce1fe 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml @@ -21,13 +21,17 @@ src_includes = [ ] [[fields]] -name = "eps" -type = "float" +name = "relu" +type = "bool" [[fields]] name = "affine" type = "bool" +[[fields]] +name = "eps" 
+type = "float" + [[fields]] name = "momentum" type = "std::optional" diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index defc695675..a35cfed307 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -141,7 +141,15 @@ tl::expected "No gamma weights exist for attrs.affine = false"); } - return input_degrees; + ff_dim_t channel_dim = ff_dim_t{1}; + + return ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + input_degrees.shard_degrees.at(channel_dim) + }, + }; } tl::expected @@ -159,7 +167,7 @@ tl::expected "No beta weights exist for attrs.affine = false"); } - return input_degrees; + return get_gamma_weights_parallel_dim_degrees(attrs, input_degrees); } diff --git a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc index 7f244aa507..37db098196 100644 --- a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc @@ -7,8 +7,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ComputationGraphOpAttrs to/from json") { ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{BatchNormAttrs{ - /*eps=*/1e-5, + /*relu=*/false, /*affine=*/true, + /*eps=*/1e-5, /*momentum=*/0.1, }}; nlohmann::json j = correct; diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc index d1074d8482..af5e61f8d3 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc @@ -11,8 +11,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_batch_norm_incoming_tensor_roles(BatchNormAttrs)") { auto make_attrs = [](bool affine) { return BatchNormAttrs{ - /*eps=*/1.0, + /*relu=*/false, /*affine=*/affine, + /*eps=*/1.0, /*momentum=*/0.1, }; }; @@ -46,8 +47,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("shape inference 
(BatchNorm)") { BatchNormAttrs attrs_affine_true = BatchNormAttrs{ - /*eps=*/1.0, + /*relu=*/false, /*affine=*/true, + /*eps=*/1.0, /*momentum=*/0.1, }; @@ -125,8 +127,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("parallel dim degree inference (BatchNormAttrs)") { BatchNormAttrs attrs_affine_true = BatchNormAttrs{ - /*eps=*/1.0, + /*relu=*/false, /*affine=*/true, + /*eps=*/1.0, /*momentum=*/0.1, }; @@ -313,8 +316,9 @@ TEST_SUITE(FF_TEST_SUITE) { // here we just do a basic check that they compose BatchNormAttrs attrs = BatchNormAttrs{ - /*eps=*/1.0, + /*relu=*/true, /*affine=*/true, + /*eps=*/1.0, /*momentum=*/0.1, }; diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc index b412b1a3b7..cbe5ff3c42 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc @@ -6,8 +6,9 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { BatchNormAttrs correct = BatchNormAttrs{ - /*eps=*/1e-5, + /*relu=*/false, /*affine=*/true, + /*eps=*/1e-5, /*momentum=*/0.1, }; diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 69d1dc1313..4d4e0dfd74 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -153,6 +153,7 @@ struct ComputationGraphBuilder { tensor_guid_t batch_norm(tensor_guid_t const &input, bool affine, + std::optional const &activation, float eps, std::optional const &momentum, std::optional const &name = std::nullopt); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 00303e30d8..019b120936 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ 
b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -88,6 +88,7 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t batch_norm(parallel_tensor_guid_t const &input, bool affine, + std::optional const &activation, float eps, std::optional const &momentum, std::optional const &name = std::nullopt); diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 162a124b52..696d558fe8 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -575,13 +575,20 @@ tensor_guid_t tensor_guid_t ComputationGraphBuilder::batch_norm( tensor_guid_t const &input, bool affine, + std::optional const &activation, float eps, std::optional const &momentum, std::optional const &maybe_name) { + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format("batch_norm currently only supports (1) no activation function, or (2) relu activation function, but received {}. 
" + "If you need support for additional activation functions, please create an issue.", activation)); + } + BatchNormAttrs attrs = BatchNormAttrs{ - /*eps=*/eps, + /*relu=*/activation.has_value(), /*affine=*/affine, + /*eps=*/eps, /*momentum=*/momentum, }; @@ -613,7 +620,11 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( } return get_only( - this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); + this->add_layer(layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::multihead_attention( diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index c7ea8ea9dc..956cf1d4f3 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -332,13 +332,20 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( parallel_tensor_guid_t const &input, bool affine, + std::optional const &activation, float eps, std::optional const &momentum, std::optional const &maybe_name) { + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format("batch_norm currently only supports (1) no activation function, or (2) relu activation function, but received {}. 
" + "If you need support for additional activation functions, please create an issue.", activation)); + } + BatchNormAttrs attrs = BatchNormAttrs{ - /*eps=*/eps, + /*relu=*/activation.has_value(), /*affine=*/affine, + /*eps=*/eps, /*momentum=*/momentum, }; From e501aea91ecc807e1b1af4007c22c61b4278050e Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 17 Sep 2024 10:40:00 -0700 Subject: [PATCH 08/10] Format --- .../src/export_model_arch.cc | 2 +- ...ion_graph_series_parallel_decomposition.cc | 4 +- .../src/models/inception_v3/inception_v3.cc | 536 +++++++++--------- .../include/op-attrs/dim_ordered/concat.h | 8 +- .../include/op-attrs/dim_ordered/slice.h | 1 - .../include/op-attrs/ops/batch_norm.h | 26 +- lib/op-attrs/include/op-attrs/ops/flat.h | 9 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 12 +- lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 118 ++-- lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 8 +- lib/op-attrs/src/op-attrs/ops/flat.cc | 67 ++- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 74 ++- .../op-attrs/computation_graph_op_attrs.cc | 13 +- .../test/src/op-attrs/dim_ordered/concat.cc | 27 +- .../test/src/op-attrs/ops/batch_norm.cc | 222 ++++---- .../test/src/op-attrs/ops/batch_norm_attrs.cc | 8 +- lib/op-attrs/test/src/op-attrs/ops/flat.cc | 231 ++++---- lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc | 96 ++-- .../include/pcg/computation_graph_builder.h | 15 +- lib/pcg/src/pcg/computation_graph_builder.cc | 88 ++- .../parallel_computation_graph_builder.cc | 16 +- lib/utils/include/utils/containers/subvec.h | 2 +- .../src/utils/containers/require_all_same1.cc | 2 +- 23 files changed, 831 insertions(+), 754 deletions(-) diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 31ec2ebcd7..98b7a003ce 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -62,7 +62,7 @@ tl::expected return get_default_transformer_computation_graph(); } else 
if (model_name == "inception_v3") { return get_inception_v3_computation_graph( - get_default_inception_v3_training_config()); + get_default_inception_v3_training_config()); } else if (model_name == "split_test") { int batch_size = 8; return get_split_test_computation_graph(batch_size); diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc index c4966673c7..c9d84a8948 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -294,8 +294,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("inception_v3") { - ComputationGraph cg = - get_inception_v3_computation_graph(get_default_inception_v3_training_config()); + ComputationGraph cg = get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); std::optional sp_decomposition = get_computation_graph_series_parallel_decomposition(cg); diff --git a/lib/models/src/models/inception_v3/inception_v3.cc b/lib/models/src/models/inception_v3/inception_v3.cc index 3b3e377c3f..f540eae629 100644 --- a/lib/models/src/models/inception_v3/inception_v3.cc +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -1,18 +1,16 @@ #include "models/inception_v3/inception_v3.h" +#include "models/inception_v3/inception_v3_output.dtg.h" #include "op-attrs/tensor_shape.h" #include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" #include "utils/integer_conversions.h" -#include "models/inception_v3/inception_v3_output.dtg.h" namespace FlexFlow { struct CheckShape { CheckShape(ComputationGraphBuilder const &cgb, InceptionV3Config const &config) - : cgb(cgb), - config(config) - { } + : cgb(cgb), config(config) {} ComputationGraphBuilder const &cgb; InceptionV3Config const &config; @@ 
-20,60 +18,66 @@ struct CheckShape { void operator()(tensor_guid_t t, int c, int h, int w) const { TensorShape current_shape = cgb.get_shape(t); TensorShape expected_shape = TensorShape{ - TensorDims{FFOrdered{ - size_t_from_int(config.batch_size), - size_t_from_int(c), - size_t_from_int(h), - size_t_from_int(w), - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + size_t_from_int(h), + size_t_from_int(w), + }}, + DataType::FLOAT, }; if (current_shape != expected_shape) { - throw mk_runtime_error(fmt::format("Expected activation shape {}, but found activation shape {}", expected_shape, current_shape)); + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); } } void operator()(tensor_guid_t t, int c) const { TensorShape current_shape = cgb.get_shape(t); TensorShape expected_shape = TensorShape{ - TensorDims{FFOrdered{ - size_t_from_int(config.batch_size), - size_t_from_int(c), - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + }}, + DataType::FLOAT, }; if (current_shape != expected_shape) { - throw mk_runtime_error(fmt::format("Expected activation shape {}, but found activation shape {}", expected_shape, current_shape)); + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); } } }; - InceptionV3Config get_default_inception_v3_training_config() { return InceptionV3Config{ - /*num_classes=*/1000, + /*num_classes=*/1000, - // see section 8 of https://arxiv.org/abs/1512.00567 for the source of the batch size - /*batch_size=*/32, + // see section 8 of https://arxiv.org/abs/1512.00567 for the source of the + // batch size + /*batch_size=*/32, - // see section 4 of https://arxiv.org/abs/1512.00567 for a discussion of auxiliary logits. 
- // they are used by default in training - /*aux_logits=*/true, + // see section 4 of https://arxiv.org/abs/1512.00567 for a discussion of + // auxiliary logits. they are used by default in training + /*aux_logits=*/true, }; } static tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, - tensor_guid_t const &input, - int filters, - int kernel_size_h, - int kernel_size_w, - int stride_h = 1, - int stride_w = 1, - int padding_h = 0, - int padding_w = 0, - bool use_bias = false) { + tensor_guid_t const &input, + int filters, + int kernel_size_h, + int kernel_size_w, + int stride_h = 1, + int stride_w = 1, + int padding_h = 0, + int padding_w = 0, + bool use_bias = false) { tensor_guid_t conv = cgb.conv2d(input, /*outChannels=*/filters, /*kernelH=*/kernel_size_h, @@ -93,75 +97,75 @@ static tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, } static tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, - tensor_guid_t const &input, - int pool_features) { - tensor_guid_t branch1x1 = create_conv_block(cgb, - input, - /*filters=*/64, - /*kernel_size_h=*/1, + tensor_guid_t const &input, + int pool_features) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/64, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); tensor_guid_t branch5x5 = [&] { tensor_guid_t t = input; - t = create_conv_block(cgb, - t, - /*filters=*/48, - /*kernel_size_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/48, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); - t = create_conv_block(cgb, - t, - /*filters=*/64, - /*kernel_size_h=*/5, - /*kernel_size_w=*/5, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/2, + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/2, /*padding_w=*/2); return t; }(); tensor_guid_t branch3x3dbl = [&] { tensor_guid_t t = input; - t = create_conv_block(cgb, - t, - /*filters=*/64, - /*kernel_size_h=*/1, + t = 
create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); - t = create_conv_block(cgb, - t, - /*filters=*/96, - /*kernel_size_h=*/3, - /*kernel_size_w=*/3, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, /*padding_w=*/1); - t = create_conv_block(cgb, - t, - /*filters=*/96, - /*kernel_size_h=*/3, - /*kernel_size_w=*/3, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, /*padding_w=*/1); return t; }(); tensor_guid_t branch_pool = [&] { tensor_guid_t t = input; - t = cgb.pool2d(t, - /*kernelH=*/3, - /*kernelW=*/3, - /*strideH=*/1, - /*strideW=*/1, - /*paddingH=*/1, - /*paddingW=*/1, + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, /*type=*/PoolOp::AVG); - t = create_conv_block(cgb, - t, - /*filters=*/pool_features, - /*kernel_stride_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/pool_features, + /*kernel_stride_h=*/1, /*kernel_stride_w=*/1); return t; }(); @@ -171,48 +175,48 @@ static tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, } static tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, - tensor_guid_t const &input) { - tensor_guid_t branch3x3 = create_conv_block(cgb, - input, - /*filters=*/384, - /*kernel_size_h=*/3, - /*kernel_size_w=*/3, - /*stride_h=*/2, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = create_conv_block(cgb, + input, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, /*stride_w=*/2); tensor_guid_t branch3x3dbl = [&] { tensor_guid_t t = input; - t = create_conv_block(cgb, - t, - /*filters=*/64, - /*kernel_size_h=*/1, + t = 
create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); - t = create_conv_block(cgb, - t, - /*filters=*/96, - /*kernel_size_h=*/3, - /*kernel_size_w=*/3, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, /*padding_w=*/1); - t = create_conv_block(cgb, - t, - /*filters=*/96, - /*kernel_stride_h=*/3, + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_stride_h=*/3, /*kernel_stride_w=*/3, - /*stride_h=*/2, + /*stride_h=*/2, /*stride_w=*/2); return t; }(); - tensor_guid_t branch_pool = cgb.pool2d(input, - /*kernelH=*/3, - /*kernelW=*/3, - /*strideH=*/2, - /*strideW=*/2, - /*paddingH=*/0, - /*paddingW=*/0, + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, /*type=*/PoolOp::MAX); return cgb.concat({branch3x3, branch3x3dbl, branch_pool}, /*axis=*/1); @@ -222,10 +226,10 @@ static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, CheckShape const &check_shape, tensor_guid_t const &input, int channels_7x7) { - tensor_guid_t branch1x1 = create_conv_block(cgb, - input, - /*filters=*/192, - /*kernel_size_h=*/1, + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/192, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); check_shape(branch1x1, 192, 17, 17); @@ -245,14 +249,14 @@ static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, /*stride_w=*/1, /*padding_h=*/0, /*padding_w=*/3); - t = create_conv_block(cgb, - t, - /*filters=*/192, - /*kernel_size_h=*/7, - /*kernel_size_w=*/1, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/3, + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, /*padding_w=*/0); return t; }(); @@ 
-260,46 +264,46 @@ static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, tensor_guid_t branch7x7dbl = [&] { tensor_guid_t t = input; - t = create_conv_block(cgb, - t, - /*filters=*/channels_7x7, - /*kernel_size_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); - t = create_conv_block(cgb, - t, - /*filters=*/channels_7x7, - /*kernel_size_h=*/7, - /*kernel_size_w=*/1, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/3, + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, /*padding_w=*/0); - t = create_conv_block(cgb, - t, - /*filters=*/channels_7x7, - /*kernel_size_h=*/1, - /*kernel_size_w=*/7, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/0, + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, /*padding_w=*/3); - t = create_conv_block(cgb, - t, - /*filters=*/channels_7x7, - /*kernel_size_h=*/7, - /*kernel_size_w=*/1, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/3, + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, /*padding_w=*/0); - t = create_conv_block(cgb, - t, - /*filters=*/192, - /*kernel_size_h=*/1, - /*kernel_size_w=*/7, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/0, + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, /*padding_w=*/3); return t; }(); @@ -307,34 +311,35 @@ static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, tensor_guid_t branch_pool = [&] { tensor_guid_t t = input; - t = cgb.pool2d(t, - /*kernelH=*/3, - /*kernelW=*/3, - /*strideH=*/1, - /*strideW=*/1, - /*paddingH=*/1, - 
/*paddingW=*/1, + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, /*type=*/PoolOp::AVG); - t = create_conv_block(cgb, - t, - /*filters=*/192, - /*kernel_size_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); return t; }(); check_shape(branch_pool, 192, 17, 17); - return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, /*axis=*/1); + return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, + /*axis=*/1); } static tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, - tensor_guid_t const &input) { + tensor_guid_t const &input) { tensor_guid_t branch3x3 = [&] { tensor_guid_t t = input; - t = create_conv_block(cgb, - t, - /*filters=*/192, - /*kernel_size_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); t = create_conv_block(cgb, t, 320, 3, 3, 2, 2); return t; @@ -375,7 +380,7 @@ static tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, return t; }(); - tensor_guid_t branch_pool = cgb.pool2d(input, + tensor_guid_t branch_pool = cgb.pool2d(input, /*kernelH=*/3, /*kernelW=*/3, /*strideH=*/2, @@ -388,50 +393,48 @@ static tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, } static tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, - tensor_guid_t const &input) { - tensor_guid_t branch1x1 = create_conv_block(cgb, - input, - /*filters=*/320, - /*kernel_size_h=*/1, + tensor_guid_t const &input) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/320, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); tensor_guid_t branch3x3 = [&] { tensor_guid_t t = input; - t = create_conv_block(cgb, - t, - /*filters=*/384, - /*kernel_size_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); - tensor_guid_t t_1 = - create_conv_block(cgb, - t, - 
/*filters=*/384, - /*kernel_size_h=*/1, - /*kernel_size_w=*/3, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/0, - /*padding_w=*/1); - tensor_guid_t t_2 = - create_conv_block(cgb, - t, - /*filters=*/384, - /*kernel_size_h=*/3, - /*kernel_size_w=*/1, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/1, - /*padding_w=*/0); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); t = cgb.concat({t_1, t_2}, /*axis=*/1); return t; }(); tensor_guid_t branch3x3dbl = [&] { tensor_guid_t t = input; - t = create_conv_block(cgb, - t, - /*filters=*/448, - /*kernel_size_h=*/1, + t = create_conv_block(cgb, + t, + /*filters=*/448, + /*kernel_size_h=*/1, /*kernel_size_w=*/1); t = create_conv_block(cgb, t, @@ -442,26 +445,24 @@ static tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, /*stride_w=*/1, /*padding_h=*/1, /*padding_w=*/1); - tensor_guid_t t_1 = - create_conv_block(cgb, - t, - /*filters=*/384, - /*kernel_size_h=*/1, - /*kernel_size_w=*/3, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/0, - /*padding_w=*/1); - tensor_guid_t t_2 = - create_conv_block(cgb, - t, - /*filters=*/384, - /*kernel_size_h=*/3, - /*kernel_size_w=*/1, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/1, - /*padding_w=*/0); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); t = cgb.concat({t_1, t_2}, /*axis=*/1); return t; }(); 
@@ -484,23 +485,24 @@ static tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, return t; }(); - return cgb.concat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, /*axis=*/1); + return cgb.concat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, + /*axis=*/1); } static tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, - CheckShape const &check_shape, - tensor_guid_t const &input) { + CheckShape const &check_shape, + tensor_guid_t const &input) { tensor_guid_t t = input; check_shape(t, 3, 299, 299); // Conv2d_1a_3x3 - t = create_conv_block(cgb, - t, - /*filters=*/32, - /*kernel_size_h=*/3, - /*kernel_size_w=*/3, - /*stride_h=*/2, + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, /*stride_w=*/2); check_shape(t, 32, 149, 149); @@ -566,9 +568,9 @@ static tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, } static tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, - CheckShape const &check_shape, - tensor_guid_t const &input, - size_t num_classes) { + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { // avgpool tensor_guid_t x = cgb.pool2d(input, /*kernelH=*/8, @@ -584,17 +586,17 @@ static tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, x = cgb.dropout(x, /*rate=*/0.5); check_shape(x, 2048, 1, 1); - + x = cgb.flat(x, /*start_dim=*/1); check_shape(x, 2048); - + // fc x = cgb.dense(x, /*outDim=*/num_classes); check_shape(x, num_classes); - // softmax (not in pytorch model, but shown in Table 1 on p6 of + // softmax (not in pytorch model, but shown in Table 1 on p6 of // https://arxiv.org/abs/1512.00567) x = cgb.softmax(x); check_shape(x, num_classes); @@ -640,28 +642,27 @@ static tensor_guid_t create_inception_aux(ComputationGraphBuilder &cgb, /*output_w=*/1); check_shape(x, 768, 1, 1); - x = cgb.flat(x, + x = cgb.flat(x, /*start_dim=*/1); check_shape(x, 768); // fc - x = cgb.dense(x, + x = 
cgb.dense(x, /*outDim=*/num_classes); check_shape(x, num_classes); return x; } -static -InceptionV3Output - create_inception_v3(ComputationGraphBuilder &cgb, - InceptionV3Config const &config, - tensor_guid_t const &input) { - // NOTE: the shapes for check_shape (as well as the layer names in comments) are pulled from +static InceptionV3Output create_inception_v3(ComputationGraphBuilder &cgb, + InceptionV3Config const &config, + tensor_guid_t const &input) { + // NOTE: the shapes for check_shape (as well as the layer names in comments) + // are pulled from // https://github.com/pytorch/vision/blob/6d7851bd5e2bedc294e40e90532f0e375fcfee04/torchvision/models/inception.py#L103-L155 CheckShape check_shape = CheckShape{ - /*cgb=*/cgb, - /*config=*/config, + /*cgb=*/cgb, + /*config=*/config, }; tensor_guid_t x = create_initial_layers(cgb, check_shape, input); @@ -701,13 +702,10 @@ InceptionV3Output std::optional aux; if (config.aux_logits) { - aux = create_inception_aux(cgb, - check_shape, - x, - config.num_classes); + aux = create_inception_aux(cgb, check_shape, x, config.num_classes); check_shape(aux.value(), config.num_classes); } - + // Mixed_7a x = create_inception_module_d(cgb, x); check_shape(x, 1280, 8, 8); @@ -724,8 +722,8 @@ InceptionV3Output check_shape(x, config.num_classes); return InceptionV3Output{ - x, - aux, + x, + aux, }; } @@ -735,10 +733,10 @@ ComputationGraph TensorShape input_shape = TensorShape{ TensorDims{FFOrdered{ - size_t_from_int(config.batch_size), - 3, - 299, - 299, + size_t_from_int(config.batch_size), + 3, + 299, + 299, }}, DataType::FLOAT, }; diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/concat.h b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h index dfc9869306..9b9eaf9b93 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/concat.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h @@ -19,10 +19,10 @@ FFOrdered concat(FFOrdered const &l, FFOrdered const &r) { template FFOrdered concat(std::vector> const &inputs) { - 
std::vector> vec_inputs = transform(inputs, - [](FFOrdered const &input) { - return std::vector(input.cbegin(), input.cend()); - }); + std::vector> vec_inputs = + transform(inputs, [](FFOrdered const &input) { + return std::vector(input.cbegin(), input.cend()); + }); std::vector raw_result = concat_vectors(vec_inputs); diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index 6c986cb1f1..e4c0e8e275 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -28,7 +28,6 @@ FFOrdered slice(FFOrdered const &d, return nonoverloaded_slice(d, start, end); } - template DimOrdered slice(DimOrdered const &d, std::optional const &start, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 73bfd56803..f2e95690d1 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -13,28 +13,30 @@ namespace FlexFlow { std::vector get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &); +tl::expected get_output_shape(BatchNormAttrs const &, + TensorShape const &); tl::expected - get_output_shape(BatchNormAttrs const &, TensorShape const &); + get_gamma_weights_shape(BatchNormAttrs const &, TensorShape const &); tl::expected - get_gamma_weights_shape(BatchNormAttrs const &, TensorShape const &); -tl::expected - get_beta_weights_shape(BatchNormAttrs const &, TensorShape const &); + get_beta_weights_shape(BatchNormAttrs const &, TensorShape const &); tl::expected - get_output_parallel_dim_degrees(BatchNormAttrs const &, ParallelTensorDimDegrees const &); + get_output_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); tl::expected - get_gamma_weights_parallel_dim_degrees(BatchNormAttrs const &, ParallelTensorDimDegrees const &); + get_gamma_weights_parallel_dim_degrees(BatchNormAttrs const &, + 
ParallelTensorDimDegrees const &); tl::expected - get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &, ParallelTensorDimDegrees const &); - + get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); tl::expected - get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &); + get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &); tl::expected - get_gamma_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); + get_gamma_weights_shape(BatchNormAttrs const &, + ParallelTensorShape const &); tl::expected - get_beta_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); + get_beta_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 3f0cdd7fa4..710cbdb44b 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -12,10 +12,11 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(FlatAttrs); TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); -tl::expected get_output_parallel_dim_degrees(FlatAttrs const &, - ParallelTensorDimDegrees const &); -tl::expected get_output_shape(FlatAttrs const &, - ParallelTensorShape const &); +tl::expected + get_output_parallel_dim_degrees(FlatAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_output_shape(FlatAttrs const &, ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 36bec5f0d1..1af22ad022 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -11,12 +11,12 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); -tl::expected make_adaptive_pool2d_attrs(TensorDims const &input_dims, - int output_h, - int output_w, - PoolOp pool_type, - 
std::optional const &activation); - +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation); tl::expected get_output_shape(Pool2DAttrs const &, TensorShape const &); diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index a35cfed307..f394bb8473 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -23,19 +23,24 @@ std::vector static std::optional check_input_shape(BatchNormAttrs const &, TensorShape const &input_shape) { if (num_dims(input_shape) < 2) { - return fmt::format("BatchNormAttrs expected input dims >= 2, but received input shape {}", input_shape); + return fmt::format( + "BatchNormAttrs expected input dims >= 2, but received input shape {}", + input_shape); } if (input_shape.data_type != DataType::FLOAT) { - return fmt::format("BatchNormAttrs currently only supports data_type = FLOAT, but received input data_type {}. " - "If you need this feature, please create an issue.", input_shape.data_type); + return fmt::format("BatchNormAttrs currently only supports data_type = " + "FLOAT, but received input data_type {}. 
" + "If you need this feature, please create an issue.", + input_shape.data_type); } return std::nullopt; } tl::expected - get_output_shape(BatchNormAttrs const &attrs, TensorShape const &input_shape) { + get_output_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { { std::optional maybe_err_msg = check_input_shape(attrs, input_shape); @@ -48,7 +53,8 @@ tl::expected } tl::expected - get_gamma_weights_shape(BatchNormAttrs const &attrs, TensorShape const &input_shape) { + get_gamma_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { { std::optional maybe_err_msg = check_input_shape(attrs, input_shape); @@ -58,35 +64,37 @@ tl::expected } if (!attrs.affine) { - return tl::unexpected( - "No gamma weights exist for attrs.affine = false"); + return tl::unexpected("No gamma weights exist for attrs.affine = false"); } size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); return TensorShape{ - TensorDims{FFOrdered{ - num_channels, - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + num_channels, + }}, + DataType::FLOAT, }; } tl::expected - get_beta_weights_shape(BatchNormAttrs const &attrs, TensorShape const &input_shape) { + get_beta_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { if (!attrs.affine) { - return tl::unexpected( - "No beta weights exist for attrs.affine = false"); + return tl::unexpected("No beta weights exist for attrs.affine = false"); } return get_gamma_weights_shape(attrs, input_shape); } static std::optional - check_input_degrees(BatchNormAttrs const &, ParallelTensorDimDegrees const &input_degrees) { + check_input_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &input_degrees) { if (input_degrees.shard_degrees.size() < 2) { - return fmt::format("BatchNormAttrs expected input dims >= 2, but received input degrees {}", input_degrees); + return fmt::format("BatchNormAttrs expected input dims >= 2, but received " + "input degrees {}", + input_degrees); } if 
(input_degrees.sum_degree != SumDegree{1}) { @@ -95,26 +103,28 @@ static std::optional } if (input_degrees.discard_copy_degree != DiscardCopyDegree{1}) { - return fmt::format("Expected discard copy degree 1, but receieved discard copy degree {}", - input_degrees.discard_copy_degree); + return fmt::format( + "Expected discard copy degree 1, but receieved discard copy degree {}", + input_degrees.discard_copy_degree); } - FFOrdered non_channel_degrees = concat( - slice(input_degrees.shard_degrees, ff_dim_t{0}, ff_dim_t{1}), - slice(input_degrees.shard_degrees, ff_dim_t{2}, std::nullopt)); + FFOrdered non_channel_degrees = + concat(slice(input_degrees.shard_degrees, ff_dim_t{0}, ff_dim_t{1}), + slice(input_degrees.shard_degrees, ff_dim_t{2}, std::nullopt)); - if (any_of(non_channel_degrees, [](int degree) { - return degree != 1; - })) { - return fmt::format("Expected parallel degree of all non-channel dimensions to be 1, but received input with degrees {}", input_degrees); + if (any_of(non_channel_degrees, [](int degree) { return degree != 1; })) { + return fmt::format("Expected parallel degree of all non-channel dimensions " + "to be 1, but received input with degrees {}", + input_degrees); } return std::nullopt; } - tl::expected - get_output_parallel_dim_degrees(BatchNormAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + get_output_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { { std::optional maybe_err_msg = check_input_degrees(attrs, input_degrees); @@ -127,7 +137,9 @@ tl::expected } tl::expected - get_gamma_weights_parallel_dim_degrees(BatchNormAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + get_gamma_weights_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { { std::optional maybe_err_msg = check_input_degrees(attrs, input_degrees); @@ -137,23 +149,22 @@ tl::expected } if (!attrs.affine) { - return tl::unexpected( 
- "No gamma weights exist for attrs.affine = false"); + return tl::unexpected("No gamma weights exist for attrs.affine = false"); } ff_dim_t channel_dim = ff_dim_t{1}; return ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{ - input_degrees.shard_degrees.at(channel_dim) - }, + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{input_degrees.shard_degrees.at(channel_dim)}, }; } tl::expected - get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + get_beta_weights_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { { std::optional maybe_err_msg = check_input_degrees(attrs, input_degrees); @@ -163,19 +174,18 @@ tl::expected } if (!attrs.affine) { - return tl::unexpected( - "No beta weights exist for attrs.affine = false"); + return tl::unexpected("No beta weights exist for attrs.affine = false"); } return get_gamma_weights_parallel_dim_degrees(attrs, input_degrees); } - tl::expected - get_output_shape(BatchNormAttrs const &attrs, - ParallelTensorShape const &input_shape) { + get_output_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { TensorShape unpar = ({ - tl::expected returned = get_output_shape(attrs, get_reduced_shape(input_shape)); + tl::expected returned = + get_output_shape(attrs, get_reduced_shape(input_shape)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } @@ -183,7 +193,9 @@ tl::expected }); ParallelTensorDimDegrees degrees = ({ - tl::expected returned = get_output_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } @@ -194,10 +206,12 @@ tl::expected } tl::expected - get_gamma_weights_shape(BatchNormAttrs const &attrs, ParallelTensorShape const &input_shape) { + 
get_gamma_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { TensorShape unpar = ({ - tl::expected returned = get_gamma_weights_shape(attrs, get_reduced_shape(input_shape)); + tl::expected returned = + get_gamma_weights_shape(attrs, get_reduced_shape(input_shape)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } @@ -205,7 +219,9 @@ tl::expected }); ParallelTensorDimDegrees degrees = ({ - tl::expected returned = get_gamma_weights_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + tl::expected returned = + get_gamma_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } @@ -216,10 +232,12 @@ tl::expected } tl::expected - get_beta_weights_shape(BatchNormAttrs const &attrs, ParallelTensorShape const &input_shape) { + get_beta_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { TensorShape unpar = ({ - tl::expected returned = get_beta_weights_shape(attrs, get_reduced_shape(input_shape)); + tl::expected returned = + get_beta_weights_shape(attrs, get_reduced_shape(input_shape)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } @@ -227,7 +245,9 @@ tl::expected }); ParallelTensorDimDegrees degrees = ({ - tl::expected returned = get_beta_weights_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + tl::expected returned = + get_beta_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index fb0039936c..eac756cc15 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -54,11 +54,11 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, Conv2DInputShape input = parse_input_shape(raw_input_shape); size_t 
out_height = - (input.height + (2 * attrs.padding_h) - attrs.kernel_h) / - attrs.stride_h + 1; + (input.height + (2 * attrs.padding_h) - attrs.kernel_h) / attrs.stride_h + + 1; size_t out_width = - (input.width + (2 * attrs.padding_w) - attrs.kernel_w) / - attrs.stride_w + 1; + (input.width + (2 * attrs.padding_w) - attrs.kernel_w) / attrs.stride_w + + 1; assert(attrs.out_channels > 0); diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index 824695ca48..e9833d5e3f 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -5,61 +5,74 @@ #include "op-attrs/tensor_dims.h" #include "utils/containers/any_of.h" #include "utils/containers/product.h" -#include "op-attrs/dim_ordered/slice.h" #include namespace FlexFlow { -TensorShape get_output_shape(FlatAttrs const &attrs, TensorShape const &input_shape) { - FFOrdered leading_dims = slice(ff_ordered(input_shape.dims), ff_dim_t{0}, attrs.start_dim); - FFOrdered flattened_dims = slice(ff_ordered(input_shape.dims), attrs.start_dim, attrs.end_dim); - FFOrdered trailing_dims = slice(ff_ordered(input_shape.dims), attrs.end_dim, std::nullopt); +TensorShape get_output_shape(FlatAttrs const &attrs, + TensorShape const &input_shape) { + FFOrdered leading_dims = + slice(ff_ordered(input_shape.dims), ff_dim_t{0}, attrs.start_dim); + FFOrdered flattened_dims = + slice(ff_ordered(input_shape.dims), attrs.start_dim, attrs.end_dim); + FFOrdered trailing_dims = + slice(ff_ordered(input_shape.dims), attrs.end_dim, std::nullopt); if (flattened_dims.empty()) { return input_shape; } return TensorShape{ - TensorDims{ - concat(std::vector{ - leading_dims, - {product(flattened_dims)}, - trailing_dims, - }), - }, - input_shape.data_type, + TensorDims{ + concat(std::vector{ + leading_dims, + {product(flattened_dims)}, + trailing_dims, + }), + }, + input_shape.data_type, }; } -tl::expected get_output_parallel_dim_degrees(FlatAttrs const &attrs, - 
ParallelTensorDimDegrees const &input_degrees) { - FFOrdered flattened_dim_degrees = slice(input_degrees.shard_degrees, attrs.start_dim, attrs.end_dim); +tl::expected + get_output_parallel_dim_degrees( + FlatAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + FFOrdered flattened_dim_degrees = + slice(input_degrees.shard_degrees, attrs.start_dim, attrs.end_dim); if (flattened_dim_degrees.empty()) { return input_degrees; } if (any_of(flattened_dim_degrees, [](int degree) { return degree != 1; })) { - return tl::unexpected(fmt::format("get_output_parallel_dim_degrees for {} expected all shard degrees of flattened dimensions to be 1, but received {}", attrs, input_degrees)); + return tl::unexpected( + fmt::format("get_output_parallel_dim_degrees for {} expected all shard " + "degrees of flattened dimensions to be 1, but received {}", + attrs, + input_degrees)); } return ParallelTensorDimDegrees{ - /*sum_degree=*/input_degrees.sum_degree, - /*discard_copy_degree=*/input_degrees.discard_copy_degree, - /*shard_degrees=*/concat(std::vector{ - slice(input_degrees.shard_degrees, ff_dim_t{0}, attrs.start_dim), - {product(flattened_dim_degrees)}, - slice(input_degrees.shard_degrees, attrs.end_dim, std::nullopt), - }), + /*sum_degree=*/input_degrees.sum_degree, + /*discard_copy_degree=*/input_degrees.discard_copy_degree, + /*shard_degrees=*/ + concat(std::vector{ + slice(input_degrees.shard_degrees, ff_dim_t{0}, attrs.start_dim), + {product(flattened_dim_degrees)}, + slice(input_degrees.shard_degrees, attrs.end_dim, std::nullopt), + }), }; } -tl::expected get_output_shape(FlatAttrs const &attrs, - ParallelTensorShape const &input_shape) { +tl::expected + get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input_shape) { TensorShape unpar = get_output_shape(attrs, get_reduced_shape(input_shape)); ParallelTensorDimDegrees degrees = ({ - tl::expected returned = get_output_parallel_dim_degrees(attrs, get_parallel_degrees(input_shape)); + 
tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); if (!returned.has_value()) { return tl::unexpected(returned.error()); } diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index f09e274cfa..9bc25929ce 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,19 +1,20 @@ #include "op-attrs/ops/pool_2d.h" #include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/tensor_shape.h" #include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" #include "utils/integer_conversions.h" namespace FlexFlow { -tl::expected make_adaptive_pool2d_attrs(TensorDims const &input_dims, - int output_h, - int output_w, - PoolOp pool_type, - std::optional const &activation) { +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation) { // AdaptivePool2D semantics pulled from // https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993 - + if (num_dims(input_dims) != 4) { return tl::unexpected( fmt::format("make_adaptive_pool2d_attrs expected input tensor to " @@ -27,11 +28,23 @@ tl::expected make_adaptive_pool2d_attrs(TensorDims con size_t input_w = dim_at_idx(input_dims, ff_dim_t{3}); if (input_h % output_h != 0) { - return tl::unexpected(fmt::format("Currently make_adaptive_pool2d_attrs only supports input_h % output_h == 0, but received input_h={} and output_h={} (input_dims={}). If you need input_h % output_h != 0 supported, please create an issue.", input_h, output_h, input_dims)); + return tl::unexpected(fmt::format( + "Currently make_adaptive_pool2d_attrs only supports input_h % output_h " + "== 0, but received input_h={} and output_h={} (input_dims={}). 
If you " + "need input_h % output_h != 0 supported, please create an issue.", + input_h, + output_h, + input_dims)); } if (input_w % output_w != 0) { - return tl::unexpected(fmt::format("Currently make_adaptive_pool2d_attrs only supports input_w % output_w == 0, but received input_w={} and output_w={} (input_dims={}). If you need input_w % output_w != 0 supported, please create an issue.", input_w, output_w, input_dims)); + return tl::unexpected(fmt::format( + "Currently make_adaptive_pool2d_attrs only supports input_w % output_w " + "== 0, but received input_w={} and output_w={} (input_dims={}). If you " + "need input_w % output_w != 0 supported, please create an issue.", + input_w, + output_w, + input_dims)); } int kernel_h = input_h / output_h; @@ -41,28 +54,29 @@ tl::expected make_adaptive_pool2d_attrs(TensorDims con int stride_w = kernel_w; Pool2DAttrs attrs = Pool2DAttrs{ - /*kernel_h=*/kernel_h, - /*kernel_w=*/kernel_w, - /*stride_h=*/stride_h, - /*stride_w=*/stride_w, - /*padding_h=*/0, - /*padding_w=*/0, - /*pool_type=*/pool_type, - /*activation=*/activation, + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/pool_type, + /*activation=*/activation, }; TensorShape expected_ouput_shape = TensorShape{ - TensorDims{FFOrdered{ - num_samples, - num_channels, - size_t_from_int(output_h), - size_t_from_int(output_w), - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + num_samples, + num_channels, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, }; TensorShape output_shape = ({ - tl::expected result = get_output_shape(attrs, TensorShape{input_dims, DataType::FLOAT}); + tl::expected result = + get_output_shape(attrs, TensorShape{input_dims, DataType::FLOAT}); if (!result.has_value()) { return tl::unexpected(result.error()); } @@ -70,10 +84,16 @@ tl::expected make_adaptive_pool2d_attrs(TensorDims con }); if (output_shape != 
expected_ouput_shape) { - return tl::unexpected(fmt::format("Result of make_adaptive_pool_2d (i.e., {}) should produce expected output shape {}, but produced {}. This is a bug in FlexFlow, Please create an issue.", attrs, expected_ouput_shape, output_shape)); + return tl::unexpected( + fmt::format("Result of make_adaptive_pool_2d (i.e., {}) should produce " + "expected output shape {}, but produced {}. This is a bug " + "in FlexFlow, Please create an issue.", + attrs, + expected_ouput_shape, + output_shape)); } - return attrs; + return attrs; } tl::expected diff --git a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc index 37db098196..84f1861f0b 100644 --- a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc @@ -5,13 +5,12 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ComputationGraphOpAttrs to/from json") { - ComputationGraphOpAttrs correct = - ComputationGraphOpAttrs{BatchNormAttrs{ - /*relu=*/false, - /*affine=*/true, - /*eps=*/1e-5, - /*momentum=*/0.1, - }}; + ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, + }}; nlohmann::json j = correct; auto result = j.get(); diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc index da95263743..2ac641cfc2 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc @@ -6,17 +6,11 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat(FFOrdered, FFOrdered)") { SUBCASE("inputs have elements") { - FFOrdered l_input = FFOrdered{ - 1, 3, 1 - }; - FFOrdered r_input = FFOrdered{ - 2, 1 - }; + FFOrdered l_input = FFOrdered{1, 3, 1}; + FFOrdered r_input = FFOrdered{2, 1}; FFOrdered result = concat(l_input, 
r_input); - FFOrdered correct = { - 1, 3, 1, 2, 1 - }; + FFOrdered correct = {1, 3, 1, 2, 1}; CHECK(result == correct); } @@ -35,14 +29,17 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat(std::vector>)") { SUBCASE("inputs have elements") { std::vector> input = { - {1}, - {2, 1}, - {1}, + {1}, + {2, 1}, + {1}, }; FFOrdered result = concat(input); FFOrdered correct = { - 1, 2, 1, 1, + 1, + 2, + 1, + 1, }; CHECK(result == correct); @@ -58,9 +55,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("inputs are empty") { - std::vector> input = { - {}, {}, {} - }; + std::vector> input = {{}, {}, {}}; FFOrdered result = concat(input); FFOrdered correct = {}; diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc index af5e61f8d3..4196394d00 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc @@ -11,10 +11,10 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_batch_norm_incoming_tensor_roles(BatchNormAttrs)") { auto make_attrs = [](bool affine) { return BatchNormAttrs{ - /*relu=*/false, - /*affine=*/affine, - /*eps=*/1.0, - /*momentum=*/0.1, + /*relu=*/false, + /*affine=*/affine, + /*eps=*/1.0, + /*momentum=*/0.1, }; }; @@ -47,10 +47,10 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("shape inference (BatchNorm)") { BatchNormAttrs attrs_affine_true = BatchNormAttrs{ - /*relu=*/false, - /*affine=*/true, - /*eps=*/1.0, - /*momentum=*/0.1, + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, }; BatchNormAttrs attrs_affine_false = [&] { @@ -127,10 +127,10 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("parallel dim degree inference (BatchNormAttrs)") { BatchNormAttrs attrs_affine_true = BatchNormAttrs{ - /*relu=*/false, - /*affine=*/true, - /*eps=*/1.0, - /*momentum=*/0.1, + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, }; BatchNormAttrs attrs_affine_false = [&] { @@ -143,14 +143,18 @@ TEST_SUITE(FF_TEST_SUITE) { int degree = 2; 
ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{ - 1, degree, 1, 1, - }, + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + degree, + 1, + 1, + }, }; - SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { tl::expected result = get_output_parallel_dim_degrees(attrs_affine_true, input); tl::expected correct = input; @@ -158,45 +162,50 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE( - "get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { SUBCASE("affine = true") { tl::expected result = get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input); - tl::expected correct = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{degree}, - }; + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; CHECK(result == correct); } SUBCASE("affine = false") { - std::optional result = optional_from_expected( - get_gamma_weights_parallel_dim_degrees(attrs_affine_false, input)); + std::optional result = + optional_from_expected(get_gamma_weights_parallel_dim_degrees( + attrs_affine_false, input)); std::optional correct = std::nullopt; CHECK(result == correct); } } - SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { SUBCASE("affine = true") { tl::expected result = get_beta_weights_parallel_dim_degrees(attrs_affine_true, input); - tl::expected correct = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{degree}, - }; + tl::expected correct = + ParallelTensorDimDegrees{ + 
SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; CHECK(result == correct); } SUBCASE("affine = false") { - std::optional result = optional_from_expected( - get_beta_weights_parallel_dim_degrees(attrs_affine_false, input)); + std::optional result = + optional_from_expected(get_beta_weights_parallel_dim_degrees( + attrs_affine_false, input)); std::optional correct = std::nullopt; CHECK(result == correct); @@ -208,12 +217,13 @@ TEST_SUITE(FF_TEST_SUITE) { int degree = 2; ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{1, 1, degree, 1}, + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, degree, 1}, }; - SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { std::optional result = optional_from_expected( get_output_parallel_dim_degrees(attrs_affine_true, input)); std::optional correct = std::nullopt; @@ -221,8 +231,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE( - "get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { std::optional result = optional_from_expected( get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); std::optional correct = std::nullopt; @@ -230,7 +240,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { std::optional result = optional_from_expected( get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); std::optional correct = std::nullopt; @@ -243,12 +254,13 @@ TEST_SUITE(FF_TEST_SUITE) { SumDegree sum_degree = SumDegree{2}; ParallelTensorDimDegrees input = 
ParallelTensorDimDegrees{ - sum_degree, - DiscardCopyDegree{1}, - FFOrdered{1, 1, 1, 1}, + sum_degree, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, }; - SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { std::optional result = optional_from_expected( get_output_parallel_dim_degrees(attrs_affine_true, input)); std::optional correct = std::nullopt; @@ -256,8 +268,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE( - "get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { std::optional result = optional_from_expected( get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); std::optional correct = std::nullopt; @@ -265,7 +277,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { std::optional result = optional_from_expected( get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); std::optional correct = std::nullopt; @@ -278,12 +291,13 @@ TEST_SUITE(FF_TEST_SUITE) { DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{1}, - discard_copy_degree, - FFOrdered{1, 1, 1, 1}, + SumDegree{1}, + discard_copy_degree, + FFOrdered{1, 1, 1, 1}, }; - SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { std::optional result = optional_from_expected( get_output_parallel_dim_degrees(attrs_affine_true, input)); std::optional correct = std::nullopt; @@ -291,8 +305,8 @@ 
TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE( - "get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { std::optional result = optional_from_expected( get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); std::optional correct = std::nullopt; @@ -300,7 +314,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, ParallelTensorDimDegrees)") { + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { std::optional result = optional_from_expected( get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); std::optional correct = std::nullopt; @@ -311,72 +326,77 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("parallel shape inference (BatchNormAttrs)") { - // since most of the edge cases are already tested in the above test cases + // since most of the edge cases are already tested in the above test cases // (i.e., shape inference and parallel degree inference) // here we just do a basic check that they compose BatchNormAttrs attrs = BatchNormAttrs{ - /*relu=*/true, - /*affine=*/true, - /*eps=*/1.0, - /*momentum=*/0.1, + /*relu=*/true, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, }; ParallelTensorShape input = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{12, 1}, - ShardParallelDim{14, 2}, - ShardParallelDim{16, 1}, - ShardParallelDim{18, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{18, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; - + SUBCASE("get_output_shape(BatchNormAttrs, ParallelTensorShape)") { - tl::expected 
result = get_output_shape(attrs, input); + tl::expected result = + get_output_shape(attrs, input); tl::expected correct = input; CHECK(result == correct); } SUBCASE("get_gamma_weights_shape(BatchNormAttrs, ParallelTensorShape)") { - tl::expected result = get_gamma_weights_shape(attrs, input); - tl::expected correct = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{14, 2}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT, - }; + tl::expected result = + get_gamma_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; CHECK(result == correct); } SUBCASE("get_beta_weights_shape(BatchNormAttrs, ParallelTensorShape)") { - tl::expected result = get_beta_weights_shape(attrs, input); - tl::expected correct = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{14, 2}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT, - }; + tl::expected result = + get_beta_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc index cbe5ff3c42..3d86576279 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc @@ -6,10 +6,10 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { BatchNormAttrs correct = BatchNormAttrs{ - /*relu=*/false, - /*affine=*/true, - /*eps=*/1e-5, - /*momentum=*/0.1, + /*relu=*/false, + 
/*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, }; nlohmann::json j = correct; diff --git a/lib/op-attrs/test/src/op-attrs/ops/flat.cc b/lib/op-attrs/test/src/op-attrs/ops/flat.cc index 5e74139bea..d81ab95c35 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/flat.cc @@ -1,35 +1,35 @@ #include "op-attrs/ops/flat.h" #include "utils/expected.h" -#include #include "utils/fmt/expected.h" #include "utils/fmt/optional.h" +#include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(FlatAttrs, TensorShape)") { TensorShape input_shape = TensorShape{ - TensorDims{FFOrdered{ - 2, - 4, - 2, - 3, - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + 2, + 4, + 2, + 3, + }}, + DataType::FLOAT, }; SUBCASE("flatten all dims") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{0}, - /*end_dim=*/ff_dim_t{4}, + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{4}, }; TensorShape result = get_output_shape(attrs, input_shape); TensorShape correct = TensorShape{ - TensorDims{FFOrdered{ - 2 * 4 * 2 * 3, - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + 2 * 4 * 2 * 3, + }}, + DataType::FLOAT, }; CHECK(result == correct); @@ -37,18 +37,18 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten trailing dims") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{2}, - /*end_dim=*/ff_dim_t{4}, + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{4}, }; TensorShape result = get_output_shape(attrs, input_shape); TensorShape correct = TensorShape{ - TensorDims{FFOrdered{ - 2, - 4, - 2 * 3, - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + 2, + 4, + 2 * 3, + }}, + DataType::FLOAT, }; CHECK(result == correct); @@ -56,18 +56,18 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten leading dims") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{0}, - /*end_dim=*/ff_dim_t{2}, + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{2}, }; TensorShape result = get_output_shape(attrs, input_shape); TensorShape correct = 
TensorShape{ - TensorDims{FFOrdered{ - 2 * 4, - 2, - 3, - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + 2 * 4, + 2, + 3, + }}, + DataType::FLOAT, }; CHECK(result == correct); @@ -75,18 +75,18 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten middle dims") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{1}, - /*end_dim=*/ff_dim_t{3}, + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, }; TensorShape result = get_output_shape(attrs, input_shape); TensorShape correct = TensorShape{ - TensorDims{FFOrdered{ - 2, - 4 * 2, - 3, - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + 2, + 4 * 2, + 3, + }}, + DataType::FLOAT, }; CHECK(result == correct); @@ -94,8 +94,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten no dims (start_dim == end_dim)") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{2}, - /*end_dim=*/ff_dim_t{2}, + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{2}, }; TensorShape result = get_output_shape(attrs, input_shape); @@ -106,8 +106,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten no dims (start_dim < end_dim)") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{2}, - /*end_dim=*/ff_dim_t{1}, + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{1}, }; TensorShape result = get_output_shape(attrs, input_shape); @@ -117,37 +117,39 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees)") { - FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{1}, - /*end_dim=*/ff_dim_t{3} - }; + TEST_CASE( + "get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees)") { + FlatAttrs attrs = FlatAttrs{/*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}}; SUBCASE("allows shard parallelism in non-flattened dims") { ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{2, 1, 1, 3}, + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 1, 3}, }; - tl::expected result = get_output_parallel_dim_degrees(attrs, input); - 
tl::expected correct = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{2, 1, 3}, - }; + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 3}, + }; CHECK(result == correct); } SUBCASE("does not allow shard parallelism in flattened dims") { ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{1, 1, 2, 1}, + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 2, 1}, }; - std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); std::optional correct = std::nullopt; CHECK(result == correct); @@ -155,80 +157,87 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("allows sum parallelism") { ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{2}, - DiscardCopyDegree{1}, - FFOrdered{1, 1, 1, 1}, + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, }; - std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); - std::optional correct = ParallelTensorDimDegrees{ - SumDegree{2}, - DiscardCopyDegree{1}, - FFOrdered{1, 1, 1}, - }; + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1}, + }; CHECK(result == correct); } SUBCASE("allows discard copy parallelism") { ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{2}, - FFOrdered{1, 1, 1, 1}, + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1, 1}, }; - std::optional result = optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); - std::optional correct = ParallelTensorDimDegrees{ - SumDegree{1}, - 
DiscardCopyDegree{2}, - FFOrdered{1, 1, 1}, - }; + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1}, + }; CHECK(result == correct); } } TEST_CASE("get_output_shape(FlatAttrs, ParallelTensorShape)") { - // since most of the edge cases are already tested in get_output_shape(FlatAttrs, TensorShape) - // and get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees), here we just do - // a basic check that they compose - + // since most of the edge cases are already tested in + // get_output_shape(FlatAttrs, TensorShape) and + // get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees), + // here we just do a basic check that they compose + ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{4, 2}, - ShardParallelDim{8, 1}, - ShardParallelDim{6, 1}, - ShardParallelDim{9, 3}, - }, - ReplicaParallelDimSet{ - SumDegree{7}, - DiscardCopyDegree{5}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8, 1}, + ShardParallelDim{6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{1}, - /*end_dim=*/ff_dim_t{3}, + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, }; - tl::expected result = get_output_shape(attrs, input_shape); - tl::expected correct = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{4, 2}, - ShardParallelDim{8*6, 1}, - ShardParallelDim{9, 3}, - }, - ReplicaParallelDimSet{ - SumDegree{7}, - DiscardCopyDegree{5}, - }, - }, - DataType::FLOAT, - }; + tl::expected result = + get_output_shape(attrs, input_shape); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, 
+ ShardParallelDim{8 * 6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc index 4efbe20cf4..0c14c0fc2a 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc @@ -16,45 +16,47 @@ TEST_SUITE(FF_TEST_SUITE) { Activation activation = Activation::RELU; PoolOp op = PoolOp::AVG; - TensorDims input_dims = TensorDims{FFOrdered{ - input_n, input_c, input_h, input_w - }}; + TensorDims input_dims = + TensorDims{FFOrdered{input_n, input_c, input_h, input_w}}; SUBCASE("input_h divisible by output_h && input_w divisible by output_w") { int output_h = 5; int output_w = 2; Pool2DAttrs correct_attrs = Pool2DAttrs{ - /*kernel_h=*/3, - /*kernel_w=*/10, - /*stride_h=*/3, - /*stride_w=*/10, - /*padding_h=*/0, - /*padding_w=*/0, - /*pool_type=*/op, - /*activation=*/activation, + /*kernel_h=*/3, + /*kernel_w=*/10, + /*stride_h=*/3, + /*stride_w=*/10, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, }; SUBCASE("returns correct attrs") { - tl::expected result = make_adaptive_pool2d_attrs(input_dims, - output_h, - output_w, - op, - activation); + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); tl::expected correct = correct_attrs; CHECK(result == correct); } - SUBCASE("confirm that output shape is as expected for the expected attrs") { - TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; - tl::expected result = get_output_shape(correct_attrs, input_shape); + tl::expected result = + get_output_shape(correct_attrs, input_shape); tl::expected correct = TensorShape{ 
- TensorDims{FFOrdered{ - input_n, input_c, size_t_from_int(output_h), size_t_from_int(output_w), - }}, - DataType::FLOAT, + TensorDims{FFOrdered{ + input_n, + input_c, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, }; CHECK(result == correct); @@ -65,11 +67,9 @@ TEST_SUITE(FF_TEST_SUITE) { int output_h = 6; int output_w = 2; - std::optional result = optional_from_expected(make_adaptive_pool2d_attrs(input_dims, - output_h, - output_w, - op, - activation)); + std::optional result = + optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); std::optional correct = std::nullopt; CHECK(result == correct); @@ -79,11 +79,9 @@ TEST_SUITE(FF_TEST_SUITE) { int output_h = 5; int output_w = 3; - std::optional result = optional_from_expected(make_adaptive_pool2d_attrs(input_dims, - output_h, - output_w, - op, - activation)); + std::optional result = + optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); std::optional correct = std::nullopt; CHECK(result == correct); @@ -94,31 +92,31 @@ TEST_SUITE(FF_TEST_SUITE) { int output_w = input_w; Pool2DAttrs correct_attrs = Pool2DAttrs{ - /*kernel_h=*/1, - /*kernel_w=*/1, - /*stride_h=*/1, - /*stride_w=*/1, - /*padding_h=*/0, - /*padding_w=*/0, - /*pool_type=*/op, - /*activation=*/activation, + /*kernel_h=*/1, + /*kernel_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, }; SUBCASE("returns correct attrs") { - tl::expected result = make_adaptive_pool2d_attrs(input_dims, - output_h, - output_w, - op, - activation); + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); tl::expected correct = correct_attrs; CHECK(result == correct); } - SUBCASE("confirm that output shape is as expected for the expected attrs") { - TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + 
SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; - tl::expected result = get_output_shape(correct_attrs, input_shape); + tl::expected result = + get_output_shape(correct_attrs, input_shape); tl::expected correct = input_shape; CHECK(result == correct); diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 4d4e0dfd74..45cde0de57 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -137,13 +137,13 @@ struct ComputationGraphBuilder { PoolOp type = PoolOp::MAX, std::optional const &activation = std::nullopt, std::optional const &name = std::nullopt); - tensor_guid_t - adaptive_pool2d(tensor_guid_t const &input, - int output_h, - int output_w, - PoolOp type = PoolOp::MAX, - std::optional const &activation = std::nullopt, - std::optional const &name = std::nullopt); + tensor_guid_t adaptive_pool2d( + tensor_guid_t const &input, + int output_h, + int output_w, + PoolOp type = PoolOp::MAX, + std::optional const &activation = std::nullopt, + std::optional const &name = std::nullopt); tensor_guid_t layer_norm(tensor_guid_t const &input, std::vector const &axes, @@ -263,6 +263,7 @@ struct ComputationGraphBuilder { std::vector const &outputs); TensorShape get_shape(tensor_guid_t const &) const; + private: tensor_guid_t broadcast(tensor_guid_t const &, TensorDims const &, std::string const &); diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 696d558fe8..4a565476bd 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -536,25 +536,21 @@ tensor_guid_t ComputationGraphBuilder::pool2d( throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return get_only( - this->add_layer(layer, {input}, {}, 
{make_output_attrs(output_shape)}) - ); + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } -tensor_guid_t - ComputationGraphBuilder::adaptive_pool2d(tensor_guid_t const &uncasted_input, - int output_h, - int output_w, - PoolOp type, - std::optional const &activation, - std::optional const &maybe_name) { +tensor_guid_t ComputationGraphBuilder::adaptive_pool2d( + tensor_guid_t const &uncasted_input, + int output_h, + int output_w, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { TensorDims input_dims = this->get_shape(uncasted_input).dims; - Pool2DAttrs attrs = throw_if_unexpected(make_adaptive_pool2d_attrs(input_dims, - output_h, - output_w, - type, - activation)); + Pool2DAttrs attrs = throw_if_unexpected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, type, activation)); std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); @@ -564,12 +560,11 @@ tensor_guid_t LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = - throw_if_unexpected(get_output_shape(attrs, this->get_shape(casted_input))); + TensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, this->get_shape(casted_input))); - return get_only( - this->add_layer(layer, {casted_input}, {}, {make_output_attrs(output_shape)}) - ); + return get_only(this->add_layer( + layer, {casted_input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::batch_norm( @@ -581,15 +576,19 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( std::optional const &maybe_name) { if (activation.has_value() && activation.value() != Activation::RELU) { - throw mk_runtime_error(fmt::format("batch_norm currently only supports (1) no activation function, or (2) relu activation function, but received {}. 
" - "If you need support for additional activation functions, please create an issue.", activation)); + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. " + "If you need support for additional activation functions, please " + "create an issue.", + activation)); } BatchNormAttrs attrs = BatchNormAttrs{ - /*relu=*/activation.has_value(), - /*affine=*/affine, - /*eps=*/eps, - /*momentum=*/momentum, + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, }; std::string name = @@ -598,7 +597,8 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape input_shape = this->get_shape(input); - TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); std::vector weights; @@ -619,12 +619,12 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } - return get_only( - this->add_layer(layer, - {input}, - transform(weights, - [&](TensorAttrs const &a) { return this->create_weight(a); }), - {make_output_attrs(output_shape)})); + return get_only(this->add_layer( + layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -804,32 +804,30 @@ tensor_guid_t ComputationGraphBuilder::concat( throw_if_unexpected(get_output_shape(attrs, input_shapes)); return get_only( - this->add_layer(layer, inputs, {}, {make_output_attrs(output_shape)}) - ); + this->add_layer(layer, inputs, {}, {make_output_attrs(output_shape)})); } -tensor_guid_t ComputationGraphBuilder::flat(tensor_guid_t const &input, - int start_dim, - std::optional const &end_dim, 
- std::optional const &maybe_name) { +tensor_guid_t ComputationGraphBuilder::flat( + tensor_guid_t const &input, + int start_dim, + std::optional const &end_dim, + std::optional const &maybe_name) { int input_num_dims = num_dims(this->get_shape(input)); FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{start_dim}, - /*end_dim=*/ff_dim_t{end_dim.value_or(input_num_dims)}, - }; + /*start_dim=*/ff_dim_t{start_dim}, + /*end_dim=*/ff_dim_t{end_dim.value_or(input_num_dims)}, + }; std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = - get_output_shape(attrs, this->get_shape(input)); + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); return get_only( - this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)}) - ); + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::layer_norm( diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 956cf1d4f3..ce00ea62f4 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -338,15 +338,19 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( std::optional const &maybe_name) { if (activation.has_value() && activation.value() != Activation::RELU) { - throw mk_runtime_error(fmt::format("batch_norm currently only supports (1) no activation function, or (2) relu activation function, but received {}. 
" - "If you need support for additional activation functions, please create an issue.", activation)); + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. " + "If you need support for additional activation functions, please " + "create an issue.", + activation)); } BatchNormAttrs attrs = BatchNormAttrs{ - /*relu=*/activation.has_value(), - /*affine=*/affine, - /*eps=*/eps, - /*momentum=*/momentum, + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, }; std::string name = diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index e8b9f4e441..5ae90ec5ba 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -31,7 +31,7 @@ std::vector subvec(std::vector const &v, } if (end_iter < begin_iter) { - end_iter = begin_iter; + end_iter = begin_iter; } std::vector output(begin_iter, end_iter); diff --git a/lib/utils/test/src/utils/containers/require_all_same1.cc b/lib/utils/test/src/utils/containers/require_all_same1.cc index 45a7fcfa78..48c1ab0b99 100644 --- a/lib/utils/test/src/utils/containers/require_all_same1.cc +++ b/lib/utils/test/src/utils/containers/require_all_same1.cc @@ -1,5 +1,4 @@ #include "utils/containers/require_all_same1.h" -#include "utils/expected.h" #include "test/utils/doctest/fmt/expected.h" #include "test/utils/doctest/fmt/multiset.h" #include "test/utils/doctest/fmt/optional.h" @@ -7,6 +6,7 @@ #include "test/utils/doctest/fmt/unordered_multiset.h" #include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/vector.h" +#include "utils/expected.h" #include #include #include From eabd822032368e90b98c95d4efce4f34e9d79e46 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 17 Sep 2024 11:09:08 -0700 Subject: [PATCH 09/10] Document adaptive pool2d formula simplification --- 
lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index 9bc25929ce..1e1f693184 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -47,6 +47,14 @@ tl::expected input_dims)); } + // Note that for some reason the stack overflow post linked above states that + // `kernel_size = ind - (outd-1)*stride`, but some simplification yields + // `kernel_size` = `ind - (outd - 1)*stride` + // = `ind - (outd - 1) * (ind / outd)` + // = `ind - ind + (ind /outd)` + // = `ind / outd` + // = `stride` + int kernel_h = input_h / output_h; int kernel_w = input_w / output_w; From 8a062aade78525301305ec86339ef5d016482edd Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 17 Sep 2024 11:37:16 -0700 Subject: [PATCH 10/10] Format --- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index 1e1f693184..95bcd8b336 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -47,12 +47,12 @@ tl::expected input_dims)); } - // Note that for some reason the stack overflow post linked above states that + // Note that for some reason the stack overflow post linked above states that // `kernel_size = ind - (outd-1)*stride`, but some simplification yields // `kernel_size` = `ind - (outd - 1)*stride` - // = `ind - (outd - 1) * (ind / outd)` - // = `ind - ind + (ind /outd)` - // = `ind / outd` + // = `ind - (outd - 1) * (ind / outd)` + // = `ind - ind + (ind /outd)` + // = `ind / outd` // = `stride` int kernel_h = input_h / output_h;