
op-attrs is_valid and shape inference functionality#1014

Closed
stelleg wants to merge 5 commits into flexflow:repo-refactor from stelleg:repo-refactor

Conversation

@stelleg
Contributor

@stelleg stelleg commented Aug 21, 2023

Description of changes:
Fixing up is_valid and shape inference implementations and tests for repo-refactor.

Related Issues:

Linked Issues:

  • Issue #

Issues closed by this PR:

  • Closes #


@stelleg stelleg marked this pull request as draft August 21, 2023 22:39
@stelleg stelleg force-pushed the repo-refactor branch 2 times, most recently from 5ee94f6 to 290e4c9 on August 21, 2023 22:48
@stelleg stelleg requested a review from lockshaw August 21, 2023 22:50
Collaborator

@lockshaw lockshaw left a comment

Tests for ConcatAttrs::is_valid?

Reviewed 6 of 6 files at r1, 5 of 5 files at r2, all commit messages.
Reviewable status: all files reviewed, 5 unresolved discussions (waiting on @stelleg)


lib/op-attrs/include/op-attrs/ops/concat.h line 17 at r1 (raw file):

FF_VISITABLE_STRUCT(ConcatAttrs, axis);
CHECK_VALID_OP_ATTR(ConcatAttrs);

Especially with visitable types it's nice to use functions instead to keep the actual data definition easier to read

Suggestion:

struct ConcatAttrs {
  ff_dim_t axis;
};

FF_VISITABLE_STRUCT(ConcatAttrs, axis);
CHECK_VALID_OP_ATTR(ConcatAttrs);

bool is_valid(ConcatAttrs const &, std::vector<ParallelTensorShape> const &inputs);

lib/op-attrs/src/concat.cc line 7 at r1 (raw file):

bool ConcatAttrs::is_valid(
    std::vector<ParallelTensorShape> const &input) const {
  bool valid = true;

Tensor shape validity should be checked somewhere else (as it's common for all of the attrs--an attr is only valid if all of its input shapes are valid) so this can be removed. Useful checks might be things like "is the number of arguments greater than 0" and "are the shapes actually concatable along the given dimension". That said, implementing shape inference (i.e. get_output_shape) is much more important currently than implementing is_valid


lib/op-attrs/src/parallel_tensor_shape.cc line 40 at r1 (raw file):

}

bool ParallelTensorShape::is_valid() {

Why was this added? (also why not const?)


lib/op-attrs/test/test_valid.cc line 8 at r2 (raw file):

using namespace FlexFlow;

TEST_CASE("gather_valid") {

Suggestion:

TEST_CASE("GatherAttrs::is_valid") {

lib/op-attrs/test/test_valid.cc line 14 at r2 (raw file):

  TensorDims tds({2,2}); 
  ParallelTensorShape p(tds, DataType::FLOAT); 
  RC_ASSERT(g.is_valid(p, p));  

Suggestion:

  CHECK(g.is_valid(p, p));

@stelleg
Contributor Author

stelleg commented Sep 15, 2023

@lockshaw Added shape inference for the following ops:

lib/op-attrs/src/cast.cc
lib/op-attrs/src/combine.cc
lib/op-attrs/src/concat.cc
lib/op-attrs/src/dropout.cc
lib/op-attrs/src/element_binary.cc
lib/op-attrs/src/element_unary.cc
lib/op-attrs/src/flat.cc
lib/op-attrs/src/reduction.cc

Let me know if there are others you need pushed through.

Cheers,
George

@lockshaw lockshaw marked this pull request as ready for review September 18, 2023 22:05
@stelleg stelleg changed the title op-attrs is_valid functionality op-attrs is_valid and shape inference functionality Sep 25, 2023
Collaborator

@lockshaw lockshaw left a comment

Reviewed 10 of 10 files at r3, all commit messages.
Reviewable status: all files reviewed, 11 unresolved discussions (waiting on @stelleg)


lib/op-attrs/src/cast.cc line 8 at r3 (raw file):

                                     ParallelTensorShape const &input) {
  ParallelTensorShape output = input;
  output.data_type = attrs.dtype;

return?


lib/op-attrs/src/concat.cc line 18 at r3 (raw file):

                     std::vector<ParallelTensorShape> const &inputs) {
  ParallelTensorShape output = inputs[0];
  for (auto &i : inputs) {

Suggestion:

  for (ParallelTensorShape const &i : inputs) {

lib/op-attrs/src/concat.cc line 19 at r3 (raw file):

  ParallelTensorShape output = inputs[0];
  for (auto &i : inputs) {
    output.at(attrs.axis).size += i.at(attrs.axis).size;

It might be clearer to use sum from containers.h here instead


lib/op-attrs/src/dropout.cc line 8 at r3 (raw file):

                                     ParallelTensorShape const &input) {
  ParallelTensorShape output = input;
  return output;

Suggestion:

  return input;

lib/op-attrs/src/element_binary.cc line 22 at r3 (raw file):

    } else {
      assert(false && "Operands could not be broadcast together");
      exit(0);

Suggestion:

      throw mk_runtime_error("Operands could not be broadcast together: in1={} in2={}", in1, in2);

lib/op-attrs/src/element_binary.cc line 23 at r3 (raw file):

      assert(false && "Operands could not be broadcast together");
      exit(0);
    }

Is there really not a clearer way to rewrite this?

Code quote:

  ParallelTensorShape output = in1.num_dims() >= in2.num_dims() ? in1 : in2;
  for (int i = 0; i < output.num_dims(); i++) {
    if (i >= in1.num_dims()) {
      output.at(ff_dim_t(i)) = in2.at(ff_dim_t(i));
    } else if (i >= in2.num_dims()) {
      output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i));
    } else if (in1.at(ff_dim_t(i)).size == in2.at(ff_dim_t(i)).size) {
      output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i));
    } else if (in1.at(ff_dim_t(i)).size == 1) {
      output.at(ff_dim_t(i)) = in2.at(ff_dim_t(i));
    } else if (in2.at(ff_dim_t(i)).size == 1) {
      output.at(ff_dim_t(i)) = in1.at(ff_dim_t(i));
    } else {
      assert(false && "Operands could not be broadcast together");
      exit(0);
    }

lib/op-attrs/src/element_unary.cc line 8 at r3 (raw file):

                                     ParallelTensorShape const &in) {
  ParallelTensorShape out = in;
  return out;

Suggestion:

  return in;

lib/op-attrs/src/flat.cc line 26 at r3 (raw file):

      input.at(ff_dim_t(Input::WIDTH)).size;
  output_shape.at(ff_dim_t(Output::CHANNEL)).degree =
      input.at(ff_dim_t(Input::CHANNEL)).degree;

Generalize this to arbitrary dimensions, not just 4

Code quote:

  output_shape.at(ff_dim_t(Output::CHANNEL)).size =
      input.at(ff_dim_t(Input::CHANNEL)).size *
      input.at(ff_dim_t(Input::HEIGHT)).size *
      input.at(ff_dim_t(Input::WIDTH)).size;
  output_shape.at(ff_dim_t(Output::CHANNEL)).degree =
      input.at(ff_dim_t(Input::CHANNEL)).degree;

lib/op-attrs/src/reduction.cc line 7 at r3 (raw file):

ParallelTensorShape get_output_shape(ReductionAttrs const &attrs,
                                     ParallelTensorShape const &input_shape) {
  ParallelTensorShape output(input_shape.dims, input_shape.data_type);

Suggestion:

  ParallelTensorShape output = input_shape;

@reyna-abhyankar
Collaborator

Move to #1183



4 participants