op-attrs is_valid and shape inference functionality #1014
stelleg wants to merge 5 commits into flexflow:repo-refactor from
Conversation
force-pushed 5ee94f6 to 290e4c9
force-pushed 290e4c9 to 733da61
force-pushed e0ad037 to 258e4b3
lockshaw left a comment
Tests for ConcatAttrs::is_valid?
Reviewed 6 of 6 files at r1, 5 of 5 files at r2, all commit messages.
Reviewable status: all files reviewed, 5 unresolved discussions (waiting on @stelleg)
lib/op-attrs/include/op-attrs/ops/concat.h line 17 at r1 (raw file):
FF_VISITABLE_STRUCT(ConcatAttrs, axis);
CHECK_VALID_OP_ATTR(ConcatAttrs);
Especially with visitable types, it's nice to use functions instead to keep the actual data definition easier to read.
Suggestion:
struct ConcatAttrs {
ff_dim_t axis;
};
FF_VISITABLE_STRUCT(ConcatAttrs, axis);
CHECK_VALID_OP_ATTR(ConcatAttrs);
bool is_valid(ConcatAttrs const &, std::vector<ParallelTensorShape> const &inputs);

lib/op-attrs/src/concat.cc line 7 at r1 (raw file):
bool ConcatAttrs::is_valid(
    std::vector<ParallelTensorShape> const &input) const {
  bool valid = true;
Tensor shape validity should be checked somewhere else (it's common to all of the attrs: an attr is only valid if all of its input shapes are valid), so this can be removed. Useful checks might be things like "is the number of arguments greater than 0" and "are the shapes actually concatenable along the given dimension". That said, implementing shape inference (i.e. get_output_shape) is currently much more important than implementing is_valid.
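A sketch of the kind of checks described above, using a plain stand-in for ParallelTensorShape (the names `Shape` and `concat_is_valid` are hypothetical; the real op-attrs types and signatures differ):

```cpp
#include <cstddef>
#include <vector>

// Minimal stand-in for ParallelTensorShape: just a list of dimension sizes.
using Shape = std::vector<int>;

// Hypothetical validity check for concat: at least one input, a concat axis
// that is in range, all inputs of the same rank, and all dimensions other
// than the concat axis matching exactly.
bool concat_is_valid(int axis, std::vector<Shape> const &inputs) {
  if (inputs.empty()) {
    return false;
  }
  std::size_t ndims = inputs[0].size();
  if (axis < 0 || static_cast<std::size_t>(axis) >= ndims) {
    return false;
  }
  for (Shape const &s : inputs) {
    if (s.size() != ndims) {
      return false;
    }
    for (std::size_t i = 0; i < ndims; i++) {
      if (i != static_cast<std::size_t>(axis) && s[i] != inputs[0][i]) {
        return false;
      }
    }
  }
  return true;
}
```

Note that per-input shape validity is deliberately not checked here, matching the comment above: that belongs in a common place shared by all attrs.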
lib/op-attrs/src/parallel_tensor_shape.cc line 40 at r1 (raw file):
}

bool ParallelTensorShape::is_valid() {
Why was this added? (also why not const?)
lib/op-attrs/test/test_valid.cc line 8 at r2 (raw file):
using namespace FlexFlow;

TEST_CASE("gather_valid") {
Suggestion:
TEST_CASE("GatherAttrs::is_valid") {

lib/op-attrs/test/test_valid.cc line 14 at r2 (raw file):
TensorDims tds({2,2});
ParallelTensorShape p(tds, DataType::FLOAT);
RC_ASSERT(g.is_valid(p, p));
Suggestion:
CHECK(g.is_valid(p, p));

stelleg
@lockshaw Added shape inference for the following ops:
Let me know if there are others you need pushed through. Cheers,
lockshaw left a comment
Reviewed 10 of 10 files at r3, all commit messages.
Reviewable status: all files reviewed, 11 unresolved discussions (waiting on @stelleg)
lib/op-attrs/src/cast.cc line 8 at r3 (raw file):
ParallelTensorShape const &input) {
  ParallelTensorShape output = input;
  output.data_type = attrs.dtype;
return?
lib/op-attrs/src/concat.cc line 18 at r3 (raw file):
std::vector<ParallelTensorShape> const &inputs) {
  ParallelTensorShape output = inputs[0];
  for (auto &i : inputs) {
Suggestion:
for (ParallelTensorShape const &i : inputs) {

lib/op-attrs/src/concat.cc line 19 at r3 (raw file):
ParallelTensorShape output = inputs[0];
for (auto &i : inputs) {
  output.at(attrs.axis).size += i.at(attrs.axis).size;
It might be clearer to use sum from containers.h here instead
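The sum-over-axis idea could look roughly like the sketch below, using `std::accumulate` as a stand-in for the `sum` helper in containers.h (whose exact signature isn't shown in this thread) and a plain `Shape` type standing in for ParallelTensorShape:

```cpp
#include <numeric>
#include <vector>

using Shape = std::vector<int>; // stand-in for ParallelTensorShape dims

// Hypothetical shape inference for concat: the output equals the first
// input, except along the concat axis, whose size is the sum of every
// input's size along that axis.
Shape concat_output_shape(int axis, std::vector<Shape> const &inputs) {
  Shape output = inputs.at(0);
  output[axis] = std::accumulate(
      inputs.begin(), inputs.end(), 0,
      [axis](int acc, Shape const &s) { return acc + s.at(axis); });
  return output;
}
```

This avoids the mutate-in-a-loop pattern and makes the "sum along the axis" intent explicit in one expression.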
lib/op-attrs/src/dropout.cc line 8 at r3 (raw file):
ParallelTensorShape const &input) {
  ParallelTensorShape output = input;
  return output;
Suggestion:
return input;

lib/op-attrs/src/element_binary.cc line 22 at r3 (raw file):
} else {
  assert(false && "Operands could not be broadcast together");
  exit(0);
Suggestion:
throw mk_runtime_error("Operands could not be broadcast together: in1={} in2={}", in1, in2);

lib/op-attrs/src/element_binary.cc line 23 at r3 (raw file):
assert(false && "Operands could not be broadcast together");
exit(0);
}
Is there really not a clearer way to rewrite this?
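One arguably clearer structure, sketched with a stand-in shape type (the real code uses ParallelTensorShape and ff_dim_t; `broadcast_dim` and `broadcast_shapes` are hypothetical names, and the sketch mirrors the quoted code's front-aligned indexing rather than NumPy's trailing alignment): factor out a per-dimension rule and iterate once.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int>; // stand-in for ParallelTensorShape

// Broadcast a single pair of dimension sizes: equal sizes pass through,
// a size of 1 broadcasts to the other size, anything else is an error.
int broadcast_dim(int a, int b) {
  if (a == b || b == 1) {
    return a;
  }
  if (a == 1) {
    return b;
  }
  throw std::runtime_error("Operands could not be broadcast together");
}

// Hypothetical rewrite of the quoted if-chain: treat a missing dimension
// as size 1 and delegate each position to broadcast_dim.
Shape broadcast_shapes(Shape const &in1, Shape const &in2) {
  std::size_t ndims = std::max(in1.size(), in2.size());
  Shape out(ndims);
  for (std::size_t i = 0; i < ndims; i++) {
    int d1 = i < in1.size() ? in1[i] : 1;
    int d2 = i < in2.size() ? in2[i] : 1;
    out[i] = broadcast_dim(d1, d2);
  }
  return out;
}
```

Collapsing the five-way branch into a two-line helper also gives the error path a single home, which pairs well with the `mk_runtime_error` suggestion above.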
Code quote:
ParallelTensorShape output = in1.num_dims() >= in2.num_dims() ? in1 : in2;
for (int i = 0; i < output.num_dims(); i++) {
if (i >= in1.num_dims()) {
output.at(ff_dim_t(i)) = in2.at(ff_dim_t(i));
} else if (i >= in2.num_dims()) {
output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i));
} else if (in1.at(ff_dim_t(i)).size == in2.at(ff_dim_t(i)).size) {
output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i));
} else if (in1.at(ff_dim_t(i)).size == 1) {
output.at(ff_dim_t(i)) = in2.at(ff_dim_t(i));
} else if (in2.at(ff_dim_t(i)).size == 1) {
output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i));
} else {
assert(false && "Operands could not be broadcast together");
exit(0);
}

lib/op-attrs/src/element_unary.cc line 8 at r3 (raw file):
ParallelTensorShape const &in) {
  ParallelTensorShape out = in;
  return out;
Suggestion:
return in;

lib/op-attrs/src/flat.cc line 26 at r3 (raw file):
    input.at(ff_dim_t(Input::WIDTH)).size;
output_shape.at(ff_dim_t(Output::CHANNEL)).degree =
    input.at(ff_dim_t(Input::CHANNEL)).degree;
Generalize this to arbitrary dimensions, not just 4
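The generalization could be a product over all non-batch dimensions rather than an explicit CHANNEL * HEIGHT * WIDTH. A sketch with a stand-in shape type (`flat_output_shape` is a hypothetical name; it ignores the parallel degrees that the real ParallelTensorShape also carries, and it assumes dimension 0 is the batch dimension):

```cpp
#include <functional>
#include <numeric>
#include <vector>

using Shape = std::vector<int>; // stand-in for ParallelTensorShape dims

// Hypothetical generalization of Flat's shape inference: keep the leading
// (batch) dimension and collapse every remaining dimension into a single
// flattened dimension whose size is the product of the collapsed sizes.
Shape flat_output_shape(Shape const &input) {
  int flattened = std::accumulate(input.begin() + 1, input.end(), 1,
                                  std::multiplies<int>());
  return Shape{input[0], flattened};
}
```

Written this way the same code handles 3-D, 4-D, or 5-D inputs instead of being hard-coded to the CHANNEL/HEIGHT/WIDTH triple.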
Code quote:
output_shape.at(ff_dim_t(Output::CHANNEL)).size =
input.at(ff_dim_t(Input::CHANNEL)).size *
input.at(ff_dim_t(Input::HEIGHT)).size *
input.at(ff_dim_t(Input::WIDTH)).size;
output_shape.at(ff_dim_t(Output::CHANNEL)).degree =
    input.at(ff_dim_t(Input::CHANNEL)).degree;

lib/op-attrs/src/reduction.cc line 7 at r3 (raw file):
ParallelTensorShape get_output_shape(ReductionAttrs const &attrs,
                                     ParallelTensorShape const &input_shape) {
  ParallelTensorShape output(input_shape.dims, input_shape.data_type);
Suggestion:
ParallelTensorShape output = input;

Move to #1183
Description of changes:
Fixing up is_valid and shape inference implementations and tests for repo-refactor.