Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/op-attrs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ ff_add_library(
)

add_subdirectory(ffi)
add_subdirectory(test)
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ namespace FlexFlow {

struct ConcatAttrs {
ff_dim_t axis;
bool is_valid(std::vector<ParallelTensorShape> const &input) const;
};

FF_VISITABLE_STRUCT(ConcatAttrs, axis);
CHECK_VALID_OP_ATTR(ConcatAttrs);

Expand Down
3 changes: 3 additions & 0 deletions lib/op-attrs/include/op-attrs/ops/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ namespace FlexFlow {

struct GatherAttrs {
ff_dim_t dim;
bool is_valid(ParallelTensorShape const &lhs,
ParallelTensorShape const &rhs) const;
};

FF_VISITABLE_STRUCT(GatherAttrs, dim);
CHECK_VALID_OP_ATTR(GatherAttrs);

Expand Down
2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct ParallelTensorShape : public use_visitable_cmp<ParallelTensorShape> {
ParallelDim const &operator[](ff_dim_t const &) const;
ParallelDim &operator[](ff_dim_t const &);

bool is_valid();

public:
ParallelTensorDims dims;
DataType data_type;
Expand Down
6 changes: 6 additions & 0 deletions lib/op-attrs/src/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

namespace FlexFlow {

ParallelTensorShape get_output_shape(CastAttrs const &attrs,
ParallelTensorShape const &input) {
ParallelTensorShape output = input;
output.data_type = attrs.dtype;
}

/* bool CastAttrs::is_valid(ParallelTensorShape const &input) const { */
/* bool valid = input.is_valid(); */
/* valid &= (input.at(input.num_dims() - 1).degree == 1); */
Expand Down
7 changes: 7 additions & 0 deletions lib/op-attrs/src/combine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

namespace FlexFlow {

ParallelTensorShape output_shape(CombineAttrs const &attrs,
ParallelTensorShape const &input_shape) {
ParallelTensorShape output = input_shape;
output.at(attrs.combine_dim).degree /= attrs.combine_degree;
return output;
}

/* bool CombineAttrs::is_valid(ParallelTensorShape const &input) const { */
/* return input.at(this->combine_legion_dim).degree % this->combine_degree ==
* 0; */
Expand Down
24 changes: 16 additions & 8 deletions lib/op-attrs/src/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@

namespace FlexFlow {

/* bool ConcatAttrs::is_valid( */
/* std::vector<ParallelTensorShape> const &input) const { */
/* bool valid = true; */
/* for (auto p : input) { */
/* valid &= p.is_valid(); */
/* } */
/* return valid; */
/* } */
bool ConcatAttrs::is_valid(
std::vector<ParallelTensorShape> const &input) const {
bool valid = true;
for (auto p : input) {
valid &= p.is_valid();
}
return valid;
}

ParallelTensorShape
get_output_shape(ConcatAttrs const &attrs,
std::vector<ParallelTensorShape> const &inputs) {
ParallelTensorShape output = inputs[0];
for (auto &i : inputs) {
output.at(attrs.axis).size += i.at(attrs.axis).size;
}
}
} // namespace FlexFlow
11 changes: 11 additions & 0 deletions lib/op-attrs/src/dropout.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "op-attrs/ops/dropout.h"

namespace FlexFlow {

ParallelTensorShape get_output_shape(DropoutAttrs const &attrs,
ParallelTensorShape const &input) {
ParallelTensorShape output = input;
return output;
}

} // namespace FlexFlow
28 changes: 27 additions & 1 deletion lib/op-attrs/src/element_binary.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
#include "op-attrs/ops/element_binary.h"

namespace FlexFlow {} // namespace FlexFlow
namespace FlexFlow {

ParallelTensorShape get_output_shape(ElementBinaryAttrs const &attrs,
ParallelTensorShape const &in1,
ParallelTensorShape const &in2) {
ParallelTensorShape output = in1.num_dims() >= in2.num_dims() ? in1 : in2;
for (int i = 0; i < output.num_dims(); i++) {
if (i >= in1.num_dims()) {
output.at(ff_dim_t(i)) = in2.at(ff_dim_t(i));
} else if (i >= in2.num_dims()) {
output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i));
} else if (in1.at(ff_dim_t(i)).size == in2.at(ff_dim_t(i)).size) {
output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i));
} else if (in1.at(ff_dim_t(i)).size == 1) {
output.at(ff_dim_t(i)) = in2.at(ff_dim_t(i));
} else if (in2.at(ff_dim_t(i)).size == 1) {
output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i));
} else {
assert(false && "Operands could not be broadcast together");
exit(0);
}
}

return output;
}

} // namespace FlexFlow
10 changes: 9 additions & 1 deletion lib/op-attrs/src/element_unary.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
#include "op-attrs/ops/element_unary.h"

namespace FlexFlow {} // namespace FlexFlow
namespace FlexFlow {

ParallelTensorShape get_output_shape(ElementUnaryAttrs const &attrs,
ParallelTensorShape const &in) {
ParallelTensorShape out = in;
return out;
}

} // namespace FlexFlow
36 changes: 14 additions & 22 deletions lib/op-attrs/src/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ namespace Output {
constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2;
}

ParallelTensorShape get_output_shape(FlatAttrs const &attrs,
ParallelTensorShape const &input) {
ParallelTensorShape output_shape(input.dims, input.data_type);

output_shape.at(ff_dim_t(Output::CHANNEL)).size =
input.at(ff_dim_t(Input::CHANNEL)).size *
input.at(ff_dim_t(Input::HEIGHT)).size *
input.at(ff_dim_t(Input::WIDTH)).size;
output_shape.at(ff_dim_t(Output::CHANNEL)).degree =
input.at(ff_dim_t(Input::CHANNEL)).degree;

return output_shape;
}

/* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */
/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */

Expand All @@ -25,26 +39,4 @@ constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2;
/* return is_valid; */
/* } */

/* ParallelTensorShape FlatAttrs::calculate_output_shape(ParallelTensorShape
* const &input) const { */
/* assert (input.num_dims() == Input::NUMDIM); */
/* ParallelTensorShape output_dims; */
/* output_dims.data_type = input.data_type; */

/* output_dims.at(Output::REPLICA) = input.at(Input::REPLICA); */
/* output_dims.at(Output::SAMPLE) = input.at(Input::SAMPLE); */

/* output_dims.at(Output::CHANNEL).degree = input.at(Input::CHANNEL).degree;
*/
/* assert (input.at(Input::HEIGHT).degree == 1); */
/* assert (input.at(Input::WIDTH).degree == 1); */

/* output_dims.at(Output::CHANNEL).size = input.at(Input::CHANNEL).size *
* input.at(Input::HEIGHT).size * input.at(Input::WIDTH).size; */
/* output_dims.at(Output::CHANNEL).parallel_idx =
* input.at(Input::CHANNEL).parallel_idx; */

/* return output_dims; */
/* } */

} // namespace FlexFlow
26 changes: 13 additions & 13 deletions lib/op-attrs/src/gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

namespace FlexFlow {

/* bool GatherAttrs::is_valid(ParallelTensorShape const &lhs,
* ParallelTensorShape const &rhs) const { */
/* if (lhs.num_dims() != rhs.num_dims()) { */
/* return false; */
/* } */
/* for (int i = 0; i < lhs.num_dims(); i++) { */
/* if (i != this->legion_dim && */
/* lhs.at(i).size < rhs.at(i).size) { */
/* return false; */
/* } */
/* } */
/* return true; */
/* } */
bool GatherAttrs::is_valid(ParallelTensorShape const &lhs,
ParallelTensorShape const &rhs) const {
if (lhs.dims.num_dims() != rhs.dims.num_dims()) {
return false;
}
for (auto i : lhs.dims) {
if (ff_dim_t(i.size) != this->dim &&
lhs.at(ff_dim_t(i.size)).size < rhs.at(ff_dim_t(i.size)).size) {
return false;
}
}
return true;
}

} // namespace FlexFlow
14 changes: 7 additions & 7 deletions lib/op-attrs/src/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

namespace FlexFlow {

/* ParallelTensorShape ReductionAttrs::output_shape(ParallelTensorShape const
* &input_shape) const { */
/* ParallelTensorShape output = input_shape; */
/* output.at(this->reduction_legion_dim).degree /= this->reduction_degree; */
/* output.at(this->reduction_legion_dim).size /= this->reduction_degree; */
/* return output; */
/* } */
ParallelTensorShape get_output_shape(ReductionAttrs const &attrs,
ParallelTensorShape const &input_shape) {
ParallelTensorShape output(input_shape.dims, input_shape.data_type);
output.at(attrs.reduction_dim).degree /= attrs.reduction_degree;
output.at(attrs.reduction_dim).size /= attrs.reduction_degree;
return output;
}

} // namespace FlexFlow
10 changes: 10 additions & 0 deletions lib/op-attrs/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
ff_add_test_executable(
NAME
op-attrs-test
SRC_PATTERNS
*.cc
DEPS
op-attrs
doctest
utils-test-common
)
2 changes: 2 additions & 0 deletions lib/op-attrs/test/main.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
#include "doctest/doctest.h"
15 changes: 15 additions & 0 deletions lib/op-attrs/test/test_valid.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "doctest/doctest.h"
#include "op-attrs/ops/gather.h"
#include <rapidcheck.h>
#include <vector>

using namespace FlexFlow;

TEST_CASE("GatherAttrs::is_valid") {

GatherAttrs g{ff_dim_t(2)};

TensorDims tds({2, 2});
ParallelTensorShape p(tds, DataType::FLOAT);
CHECK(g.is_valid(p, p));
};