From 2d3ec713b4a7d2d038091c7480b54f7a16ab744f Mon Sep 17 00:00:00 2001 From: George Stelle Date: Mon, 3 Jul 2023 09:38:40 -0600 Subject: [PATCH 1/5] Added ConcatAttrs::is_valid --- lib/op-attrs/include/op-attrs/ops/concat.h | 4 ++++ .../include/op-attrs/parallel_tensor_shape.h | 2 ++ lib/op-attrs/src/concat.cc | 16 ++++++++-------- lib/op-attrs/src/parallel_tensor_shape.cc | 4 ++++ 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index b9bd14a231..0cae0f86b2 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -10,7 +10,11 @@ namespace FlexFlow { struct ConcatAttrs { ff_dim_t axis; +bool is_valid( + std::vector const &input) const; }; + + FF_VISITABLE_STRUCT(ConcatAttrs, axis); CHECK_VALID_OP_ATTR(ConcatAttrs); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index fd560352bb..f3229e5618 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -32,6 +32,8 @@ struct ParallelTensorShape : public use_visitable_cmp { ParallelDim const &operator[](ff_dim_t const &) const; ParallelDim &operator[](ff_dim_t const &); + bool is_valid(); + public: ParallelTensorDims dims; DataType data_type; diff --git a/lib/op-attrs/src/concat.cc b/lib/op-attrs/src/concat.cc index 065c58f365..4360e431e8 100644 --- a/lib/op-attrs/src/concat.cc +++ b/lib/op-attrs/src/concat.cc @@ -2,13 +2,13 @@ namespace FlexFlow { -/* bool ConcatAttrs::is_valid( */ -/* std::vector const &input) const { */ -/* bool valid = true; */ -/* for (auto p : input) { */ -/* valid &= p.is_valid(); */ -/* } */ -/* return valid; */ -/* } */ +bool ConcatAttrs::is_valid( + std::vector const &input) const { + bool valid = true; + for (auto p : input) { + valid &= p.is_valid(); + } + return valid; + } } // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_tensor_shape.cc b/lib/op-attrs/src/parallel_tensor_shape.cc index 9a36e7d11b..9f5ec21920 100644 --- a/lib/op-attrs/src/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/parallel_tensor_shape.cc @@ -37,4 +37,8 @@ bool is_valid(ParallelTensorShape const &shape) { return is_valid(shape.dims); } +bool ParallelTensorShape::is_valid() { + return FlexFlow::is_valid(*this); +} + } // namespace FlexFlow From 733da61b8f6865909e01bd978a4b146b264926f4 Mon Sep 17 00:00:00 2001 From: George Stelle Date: Mon, 21 Aug 2023 15:15:52 -0600 Subject: [PATCH 2/5] Added gather is_valid --- lib/op-attrs/include/op-attrs/ops/concat.h | 4 +-- lib/op-attrs/include/op-attrs/ops/gather.h | 3 +++ .../include/op-attrs/parallel_tensor_shape.h | 2 +- lib/op-attrs/src/concat.cc | 16 ++++++------ lib/op-attrs/src/gather.cc | 26 +++++++++---------- lib/op-attrs/src/parallel_tensor_shape.cc | 2 +- 6 files changed, 27 insertions(+), 26 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index 0cae0f86b2..ef2b96bae2 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -10,11 +10,9 @@ namespace FlexFlow { struct ConcatAttrs { ff_dim_t axis; -bool is_valid( - std::vector const &input) const; + bool is_valid(std::vector const &input) const; }; - FF_VISITABLE_STRUCT(ConcatAttrs, axis); CHECK_VALID_OP_ATTR(ConcatAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index ca2406ef75..b4343cf183 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -10,7 +10,10 @@ namespace FlexFlow { struct GatherAttrs { ff_dim_t dim; + bool is_valid(ParallelTensorShape const &lhs, + ParallelTensorShape const &rhs) const; }; + FF_VISITABLE_STRUCT(GatherAttrs, dim); CHECK_VALID_OP_ATTR(GatherAttrs); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index f3229e5618..c787403506 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -32,7 +32,7 @@ struct ParallelTensorShape : public use_visitable_cmp { ParallelDim const &operator[](ff_dim_t const &) const; ParallelDim &operator[](ff_dim_t const &); - bool is_valid(); + bool is_valid(); public: ParallelTensorDims dims; diff --git a/lib/op-attrs/src/concat.cc b/lib/op-attrs/src/concat.cc index 4360e431e8..9ffe42d246 100644 --- a/lib/op-attrs/src/concat.cc +++ b/lib/op-attrs/src/concat.cc @@ -2,13 +2,13 @@ namespace FlexFlow { -bool ConcatAttrs::is_valid( - std::vector const &input) const { - bool valid = true; - for (auto p : input) { - valid &= p.is_valid(); - } - return valid; - } +bool ConcatAttrs::is_valid( + std::vector const &input) const { + bool valid = true; + for (auto p : input) { + valid &= p.is_valid(); + } + return valid; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/gather.cc b/lib/op-attrs/src/gather.cc index 4f2c13c794..9f1d95a20e 100644 --- a/lib/op-attrs/src/gather.cc +++ b/lib/op-attrs/src/gather.cc @@ -2,18 +2,18 @@ namespace FlexFlow { -/* bool GatherAttrs::is_valid(ParallelTensorShape const &lhs, - * ParallelTensorShape const &rhs) const { */ -/* if (lhs.num_dims() != rhs.num_dims()) { */ -/* return false; */ -/* } */ -/* for (int i = 0; i < lhs.num_dims(); i++) { */ -/* if (i != this->legion_dim && */ -/* lhs.at(i).size < rhs.at(i).size) { */ -/* return false; */ -/* } */ -/* } */ -/* return true; */ -/* } */ +bool GatherAttrs::is_valid(ParallelTensorShape const &lhs, + ParallelTensorShape const &rhs) const { + if (lhs.dims.num_dims() != rhs.dims.num_dims()) { + return false; + } + for (auto i : lhs.dims) { + if (ff_dim_t(i.size) != this->dim && + lhs.at(ff_dim_t(i.size)).size < rhs.at(ff_dim_t(i.size)).size) { + return false; + } + } + return true; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_tensor_shape.cc b/lib/op-attrs/src/parallel_tensor_shape.cc index 9f5ec21920..dc9c94e796 100644 --- a/lib/op-attrs/src/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/parallel_tensor_shape.cc @@ -37,7 +37,7 @@ bool is_valid(ParallelTensorShape const &shape) { return is_valid(shape.dims); } -bool ParallelTensorShape::is_valid() { +bool ParallelTensorShape::is_valid() { return FlexFlow::is_valid(*this); } From 258e4b3944a6dd9b92579dcefd8b832add1db20e Mon Sep 17 00:00:00 2001 From: George Stelle Date: Wed, 6 Sep 2023 10:45:12 -0600 Subject: [PATCH 3/5] Added simple op-attrs gather valid test case --- lib/op-attrs/CMakeLists.txt | 1 + lib/op-attrs/test/CMakeLists.txt | 10 ++++++++++ lib/op-attrs/test/main.cc | 2 ++ lib/op-attrs/test/test_valid.cc | 17 +++++++++++++++++ 4 files changed, 30 insertions(+) create mode 100644 lib/op-attrs/test/CMakeLists.txt create mode 100644 lib/op-attrs/test/main.cc create mode 100644 lib/op-attrs/test/test_valid.cc diff --git a/lib/op-attrs/CMakeLists.txt b/lib/op-attrs/CMakeLists.txt index 778be53d7c..9a9721ef2d 100644 --- a/lib/op-attrs/CMakeLists.txt +++ b/lib/op-attrs/CMakeLists.txt @@ -12,3 +12,4 @@ ff_add_library( ) add_subdirectory(ffi) +add_subdirectory(test) diff --git a/lib/op-attrs/test/CMakeLists.txt b/lib/op-attrs/test/CMakeLists.txt new file mode 100644 index 0000000000..4027148387 --- /dev/null +++ b/lib/op-attrs/test/CMakeLists.txt @@ -0,0 +1,10 @@ +ff_add_test_executable( + NAME + op-attrs-test + SRC_PATTERNS + *.cc + DEPS + op-attrs + doctest + utils-test-common +) diff --git a/lib/op-attrs/test/main.cc b/lib/op-attrs/test/main.cc new file mode 100644 index 0000000000..9522fa7fdb --- /dev/null +++ b/lib/op-attrs/test/main.cc @@ -0,0 +1,2 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "doctest/doctest.h" diff --git a/lib/op-attrs/test/test_valid.cc b/lib/op-attrs/test/test_valid.cc new file mode 100644 index 0000000000..6b43c10ffe --- /dev/null +++ b/lib/op-attrs/test/test_valid.cc @@ -0,0 +1,17 @@ +#include "op-attrs/ops/gather.h" +#include "doctest/doctest.h" +#include +#include + +using namespace FlexFlow; + +TEST_CASE("gather_valid") { + + GatherAttrs g{ff_dim_t(2)}; + + TensorDims tds({2,2}); + ParallelTensorShape p(tds, DataType::FLOAT); + RC_ASSERT(g.is_valid(p, p)); + +}; + From 090401721d770bb1d74c9ed4ea04f8b5dc2c5a83 Mon Sep 17 00:00:00 2001 From: George Stelle Date: Fri, 15 Sep 2023 12:39:34 -0600 Subject: [PATCH 4/5] Added shape inference for some easier ops --- lib/op-attrs/src/cast.cc | 6 +++++ lib/op-attrs/src/combine.cc | 7 ++++++ lib/op-attrs/src/concat.cc | 8 +++++++ lib/op-attrs/src/dropout.cc | 11 +++++++++ lib/op-attrs/src/element_binary.cc | 28 ++++++++++++++++++++++- lib/op-attrs/src/element_unary.cc | 10 ++++++++- lib/op-attrs/src/flat.cc | 36 ++++++++++++------------------ lib/op-attrs/src/reduction.cc | 14 ++++++------ 8 files changed, 89 insertions(+), 31 deletions(-) create mode 100644 lib/op-attrs/src/dropout.cc diff --git a/lib/op-attrs/src/cast.cc b/lib/op-attrs/src/cast.cc index e4ab178a7e..bb237a78f9 100644 --- a/lib/op-attrs/src/cast.cc +++ b/lib/op-attrs/src/cast.cc @@ -2,6 +2,12 @@ namespace FlexFlow { +ParallelTensorShape get_output_shape(CastAttrs const &attrs, + ParallelTensorShape const &input) { + ParallelTensorShape output = input; + output.data_type = attrs.dtype; +} + /* bool CastAttrs::is_valid(ParallelTensorShape const &input) const { */ /* bool valid = input.is_valid(); */ /* valid &= (input.at(input.num_dims() - 1).degree == 1); */ diff --git a/lib/op-attrs/src/combine.cc b/lib/op-attrs/src/combine.cc index cdca524538..7530b636af 100644 --- a/lib/op-attrs/src/combine.cc +++ b/lib/op-attrs/src/combine.cc @@ -3,6 +3,13 @@ namespace FlexFlow { +ParallelTensorShape output_shape(CombineAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ParallelTensorShape output = input_shape; + output.at(attrs.combine_dim).degree /= attrs.combine_degree; + return output; +} + /* bool CombineAttrs::is_valid(ParallelTensorShape const &input) const { */ /* return input.at(this->combine_legion_dim).degree % this->combine_degree == * 0; */ diff --git a/lib/op-attrs/src/concat.cc b/lib/op-attrs/src/concat.cc index 9ffe42d246..7611f7bb68 100644 --- a/lib/op-attrs/src/concat.cc +++ b/lib/op-attrs/src/concat.cc @@ -11,4 +11,12 @@ bool ConcatAttrs::is_valid( return valid; } +ParallelTensorShape + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + ParallelTensorShape output = inputs[0]; + for (auto &i : inputs) { + output.at(attrs.axis).size += i.at(attrs.axis).size; + } +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/dropout.cc b/lib/op-attrs/src/dropout.cc new file mode 100644 index 0000000000..dc323fe067 --- /dev/null +++ b/lib/op-attrs/src/dropout.cc @@ -0,0 +1,11 @@ +#include "op-attrs/ops/dropout.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(DropoutAttrs const &attrs, + ParallelTensorShape const &input) { + ParallelTensorShape output = input; + return output; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/element_binary.cc b/lib/op-attrs/src/element_binary.cc index b713c6753f..2f014b2f73 100644 --- a/lib/op-attrs/src/element_binary.cc +++ b/lib/op-attrs/src/element_binary.cc @@ -1,3 +1,29 @@ #include "op-attrs/ops/element_binary.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &attrs, + ParallelTensorShape const &in1, + ParallelTensorShape const &in2) { + ParallelTensorShape output = in1.num_dims() >= in2.num_dims() ? in1 : in2; + for (int i = 0; i < output.num_dims(); i++) { + if (i >= in1.num_dims()) { + output.at(ff_dim_t(i)) = in2.at(ff_dim_t(i)); + } else if (i >= in2.num_dims()) { + output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i)); + } else if (in1.at(ff_dim_t(i)).size == in2.at(ff_dim_t(i)).size) { + output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i)); + } else if (in1.at(ff_dim_t(i)).size == 1) { + output.at(ff_dim_t(i)) = in2.at(ff_dim_t(i)); + } else if (in2.at(ff_dim_t(i)).size == 1) { + output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i)); + } else { + assert(false && "Operands could not be broadcast together"); + exit(0); + } + } + + return output; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/element_unary.cc b/lib/op-attrs/src/element_unary.cc index 481151fafb..aae5e7ecf0 100644 --- a/lib/op-attrs/src/element_unary.cc +++ b/lib/op-attrs/src/element_unary.cc @@ -1,3 +1,11 @@ #include "op-attrs/ops/element_unary.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ElementUnaryAttrs const &attrs, + ParallelTensorShape const &in) { + ParallelTensorShape out = in; + return out; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/flat.cc b/lib/op-attrs/src/flat.cc index 75d31beae4..03cf48843c 100644 --- a/lib/op-attrs/src/flat.cc +++ b/lib/op-attrs/src/flat.cc @@ -14,6 +14,20 @@ namespace Output { constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; } +ParallelTensorShape get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input) { + ParallelTensorShape output_shape(input.dims, input.data_type); + + output_shape.at(ff_dim_t(Output::CHANNEL)).size = + input.at(ff_dim_t(Input::CHANNEL)).size * + input.at(ff_dim_t(Input::HEIGHT)).size * + input.at(ff_dim_t(Input::WIDTH)).size; + output_shape.at(ff_dim_t(Output::CHANNEL)).degree = + input.at(ff_dim_t(Input::CHANNEL)).degree; + + return output_shape; +} + /* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */ /* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ @@ -25,26 +39,4 @@ constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; /* return is_valid; */ /* } */ -/* ParallelTensorShape FlatAttrs::calculate_output_shape(ParallelTensorShape - * const &input) const { */ -/* assert (input.num_dims() == Input::NUMDIM); */ -/* ParallelTensorShape output_dims; */ -/* output_dims.data_type = input.data_type; */ - -/* output_dims.at(Output::REPLICA) = input.at(Input::REPLICA); */ -/* output_dims.at(Output::SAMPLE) = input.at(Input::SAMPLE); */ - -/* output_dims.at(Output::CHANNEL).degree = input.at(Input::CHANNEL).degree; - */ -/* assert (input.at(Input::HEIGHT).degree == 1); */ -/* assert (input.at(Input::WIDTH).degree == 1); */ - -/* output_dims.at(Output::CHANNEL).size = input.at(Input::CHANNEL).size * - * input.at(Input::HEIGHT).size * input.at(Input::WIDTH).size; */ -/* output_dims.at(Output::CHANNEL).parallel_idx = - * input.at(Input::CHANNEL).parallel_idx; */ - -/* return output_dims; */ -/* } */ - } // namespace FlexFlow diff --git a/lib/op-attrs/src/reduction.cc b/lib/op-attrs/src/reduction.cc index 22fc9bab6a..1147045919 100644 --- a/lib/op-attrs/src/reduction.cc +++ b/lib/op-attrs/src/reduction.cc @@ -2,12 +2,12 @@ namespace FlexFlow { -/* ParallelTensorShape ReductionAttrs::output_shape(ParallelTensorShape const - * &input_shape) const { */ -/* ParallelTensorShape output = input_shape; */ -/* output.at(this->reduction_legion_dim).degree /= this->reduction_degree; */ -/* output.at(this->reduction_legion_dim).size /= this->reduction_degree; */ -/* return output; */ -/* } */ +ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ParallelTensorShape output(input_shape.dims, input_shape.data_type); + output.at(attrs.reduction_dim).degree /= attrs.reduction_degree; + output.at(attrs.reduction_dim).size /= attrs.reduction_degree; + return output; +} } // namespace FlexFlow From 31a8f2dbcc85bbe1b017d99419c90768bb6a54bf Mon Sep 17 00:00:00 2001 From: George Stelle Date: Fri, 15 Sep 2023 12:46:17 -0600 Subject: [PATCH 5/5] Addressed some of Colin's comments --- lib/op-attrs/src/parallel_tensor_shape.cc | 4 ---- lib/op-attrs/test/test_valid.cc | 14 ++++++-------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/lib/op-attrs/src/parallel_tensor_shape.cc b/lib/op-attrs/src/parallel_tensor_shape.cc index dc9c94e796..9a36e7d11b 100644 --- a/lib/op-attrs/src/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/parallel_tensor_shape.cc @@ -37,8 +37,4 @@ bool is_valid(ParallelTensorShape const &shape) { return is_valid(shape.dims); } -bool ParallelTensorShape::is_valid() { - return FlexFlow::is_valid(*this); -} - } // namespace FlexFlow diff --git a/lib/op-attrs/test/test_valid.cc b/lib/op-attrs/test/test_valid.cc index 6b43c10ffe..14a3b686cb 100644 --- a/lib/op-attrs/test/test_valid.cc +++ b/lib/op-attrs/test/test_valid.cc @@ -1,17 +1,15 @@ -#include "op-attrs/ops/gather.h" #include "doctest/doctest.h" +#include "op-attrs/ops/gather.h" #include #include using namespace FlexFlow; -TEST_CASE("gather_valid") { +TEST_CASE("GatherAttrs::is_valid") { - GatherAttrs g{ff_dim_t(2)}; - - TensorDims tds({2,2}); - ParallelTensorShape p(tds, DataType::FLOAT); - RC_ASSERT(g.is_valid(p, p)); + GatherAttrs g{ff_dim_t(2)}; + TensorDims tds({2, 2}); + ParallelTensorShape p(tds, DataType::FLOAT); + CHECK(g.is_valid(p, p)); }; -