
Element Unary Op #1257

Merged
lockshaw merged 21 commits into flexflow:repo-refactor from reyna-abhyankar:element-unary-op
Mar 7, 2024

Conversation

@reyna-abhyankar
Collaborator

@reyna-abhyankar reyna-abhyankar commented Jan 1, 2024

Description of changes:

Element Unary operator

Related Issues:

Linked Issues:

Issues closed by this PR:



Collaborator

@lockshaw lockshaw left a comment

Reviewed 6 of 6 files at r1, all commit messages.
Reviewable status: all files reviewed, 5 unresolved discussions (waiting on @reyna-abhyankar)


lib/kernels/include/kernels/element_unary_kernels.h line 13 at r1 (raw file):

struct ElementUnaryPerDeviceState {
  PerDeviceFFHandle handle;

I don't remember--should this actually be stored as part of the per device state or accessed using the task spec interface as some kind of argument?


lib/kernels/include/kernels/element_unary_kernels.h line 19 at r1 (raw file):

  OperatorType op_type;
  DataType data_type;
  float scalar;

These are part of the op attrs, not the per device state, right?


lib/kernels/src/cuda/element_unary_kernels.cu line 66 at r1 (raw file):

        break;
      default:
        assert(false);

It would be good to at least start to add an interface for doing more sophisticated error reporting (i.e., having messages) than just asserting out, even if the mechanism for actually handling and showing the error messages doesn't exist yet. This probably shouldn't be done as part of this PR, but it would be good to create an issue to begin tracking this.
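
As a rough illustration of what such an interface could look like, here is a minimal sketch in which the default branch carries a message instead of calling assert(false). The names KernelError and activation_name, and the reduced OperatorType enum, are hypothetical and not part of FlexFlow; the actual design is left to the tracking issue.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical, reduced stand-in for the real operator enum.
enum class OperatorType { RELU, SIGMOID, TANH, CONV2D };

// Hypothetical error type: carries a message instead of aborting.
struct KernelError : std::runtime_error {
  using std::runtime_error::runtime_error;
};

// Maps a unary activation to a name. Where the original code would hit
// `assert(false)`, this throws an error that says which op was rejected.
inline std::string activation_name(OperatorType op) {
  switch (op) {
    case OperatorType::RELU:
      return "relu";
    case OperatorType::SIGMOID:
      return "sigmoid";
    case OperatorType::TANH:
      return "tanh";
    default:
      throw KernelError("unsupported unary op type: " +
                        std::to_string(static_cast<int>(op)));
  }
}
```

A caller can then catch KernelError and decide how to surface the message, which is the part that does not exist yet.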


lib/kernels/src/cuda/element_unary_kernels.cu line 72 at r1 (raw file):

    checkCUDNN(
        cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape));
    // input_shape == output_shape

What is the purpose of this? I'm not sure what this comment is referring to


lib/op-attrs/include/op-attrs/ops/element_unary.h line 24 at r1 (raw file):

FF_VISITABLE_STRUCT(ElementUnaryAttrs, op);
CHECK_VALID_OP_ATTR(ElementUnaryAttrs);

I don't think this should be removed?

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: all files reviewed, 5 unresolved discussions (waiting on @lockshaw)


lib/kernels/include/kernels/element_unary_kernels.h line 13 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

I don't remember--should this actually be stored as part of the per device state or accessed using the task spec interface as some kind of argument?

Currently, it's kind of ... both? We get handle from argument accessor, and then initialize the device state with it. I think this makes sense, since the handle is needed by the kernel.


lib/kernels/include/kernels/element_unary_kernels.h line 19 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

These are part of the op attrs, not the per device state, right?

I think it's ok to subsume it in per device state since it kind of streamlines the kernel code. But either is ok; if it's in the attrs, then it needs to be passed into the kernel.


lib/kernels/src/cuda/element_unary_kernels.cu line 66 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

It would be good to at least start to add an interface for doing more sophisticated error reporting (i.e., having messages) than just asserting out, even if the mechanism for actually handling and showing the error messages doesn't exist yet. This probably shouldn't be done as part of this PR, but it would be good to create an issue to begin tracking this.

Agree.


lib/kernels/src/cuda/element_unary_kernels.cu line 72 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

What is the purpose of this? I'm not sure what this comment is referring to

I'm actually not sure why the comment is there either, I think it's pretty obvious that for element unary the input shape == output shape. I think I just wanted to preserve the comment from before about input domain == output domain. But honestly don't think we need it.


lib/op-attrs/include/op-attrs/ops/element_unary.h line 24 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

I don't think this should be removed?

Done.

Collaborator

@lockshaw lockshaw left a comment

Reviewed 2 of 2 files at r2, all commit messages.
Reviewable status: all files reviewed, 3 unresolved discussions (waiting on @reyna-abhyankar)


lib/kernels/include/kernels/element_unary_kernels.h line 13 at r1 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Currently, it's kind of ... both? We get handle from argument accessor, and then initialize the device state with it. I think this makes sense, since the handle is needed by the kernel.

In that case it seems best to remove it from the PerDeviceState and then only pass it in as necessary. Like, if we already have one mechanism to access the handle, then having a second one likely opens up the possibility of bugs/divergence between the two methods and makes it harder for beginners to reason about which to use. Thoughts?


lib/kernels/include/kernels/element_unary_kernels.h line 19 at r1 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

I think it's ok to subsume it in per device state since it kind of streamlines the kernel code. But either is ok; if it's in the attrs, then it needs to be passed into the kernel.

I'd argue that anything that can be passed as an argument should be, and that the device specific feature should only be used where necessary because its behavior is a bit more complex than the nice value semantics of arguments. Also, it's nicer to just have a single mechanism even if it leads to slightly longer code--see the comment above.
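
To make the suggestion concrete, here is a minimal host-only sketch of a kernel entry point that takes the attrs and the handle as plain arguments and keeps the per device state small. All types are simplified stand-ins (the real PerDeviceFFHandle and per device state hold library handles and descriptors), and forward_kernel is a hypothetical name.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Simplified stand-ins; the real FlexFlow types are richer.
enum class Op { EXP, SCALAR_MULTIPLY };
struct ElementUnaryAttrs {
  Op op_type;
  float scalar;
};
struct PerDeviceFFHandle {
  int device_id;
};

// The per device state keeps only what has to live with the device
// (descriptors in the real kernel); it is modeled as a size here.
struct ElementUnaryPerDeviceState {
  std::size_t num_elements;
};

// Attrs and handle arrive as ordinary arguments with value semantics.
// `input` is assumed to hold at least `state.num_elements` values.
inline std::vector<float> forward_kernel(ElementUnaryPerDeviceState const &state,
                                         ElementUnaryAttrs const &attrs,
                                         PerDeviceFFHandle const &handle,
                                         std::vector<float> const &input) {
  (void)handle; // a real kernel would use it for library calls
  std::vector<float> output(state.num_elements);
  for (std::size_t i = 0; i < state.num_elements; ++i) {
    output[i] = (attrs.op_type == Op::EXP) ? std::exp(input[i])
                                           : attrs.scalar * input[i];
  }
  return output;
}
```

With this shape there is exactly one place the kernel can get the op type, scalar, and handle from, at the cost of a longer parameter list.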


lib/kernels/src/cuda/element_unary_kernels.cu line 66 at r1 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Agree.

Can you create an issue, make sure it's added to the project, and then add a link to it here?

Collaborator

@lockshaw lockshaw left a comment

Reviewed 5 of 5 files at r3, 8 of 8 files at r4, all commit messages.
Reviewable status: all files reviewed, 3 unresolved discussions (waiting on @reyna-abhyankar)

a discussion (no related file):
Reminder of the counterargument for removing ElementScalarUnaryAttrs: removing it allows many more invalid states (operators that don't use scalars to have scalars, operators that need scalars to not have them) that create more edge cases around hashing, equality, etc. that need to be detected. In addition many of these issues can only be detected at runtime, whereas having two different Attr types allows all of these to be prevented at compile-time.



lib/kernels/src/hip/element_unary_kernels.cpp line 184 at r4 (raw file):

                                      m,
                                      attrs,
                                      hanlde,

Suggestion:

handle

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: 8 of 14 files reviewed, 3 unresolved discussions (waiting on @lockshaw)

a discussion (no related file):

Previously, lockshaw (Colin Unger) wrote…

Reminder of the counterargument for removing ElementScalarUnaryAttrs: removing it allows many more invalid states (operators that don't use scalars to have scalars, operators that need scalars to not have them) that create more edge cases around hashing, equality, etc. that need to be detected. In addition many of these issues can only be detected at runtime, whereas having two different Attr types allows all of these to be prevented at compile-time.

I've added support for this via inheritance. I think this works and we can keep the code in the rest of the element unary files non-redundant


Collaborator

@lockshaw lockshaw left a comment

Reviewed 5 of 6 files at r5, 1 of 1 files at r6, all commit messages.
Reviewable status: all files reviewed, 3 unresolved discussions (waiting on @lockshaw and @reyna-abhyankar)


lib/op-attrs/include/op-attrs/ops/element_unary.h line 20 at r6 (raw file):

struct ElementScalarUnaryAttrs : ElementUnaryAttrs {
  req<Op> op_type;
  req<float> scalar;

I don't think this does what you want it to. Inheritance in C++ is done by extension, so ElementScalarUnaryAttrs is effectively defined here as

struct ElementScalarUnaryAttrs {
  req<Op> op_type;
  req<optional<float>> scalar;
  req<Op> op_type;
  req<float> scalar;
};

which doesn't seem like what you want. I think what you want is more along the lines of

struct ElementUnaryAttrs {
  req<Op> op_type;
};

struct ElementScalarUnaryAttrs : public ElementUnaryAttrs {
  req<float> scalar;
};

Either of these approaches has the downside that I no longer have a type that denotes "a unary op attrs without a scalar", as now ElementUnaryAttrs represents "a unary op attrs with or without a scalar" and ElementScalarUnaryAttrs represents "a unary op attrs with a scalar". If you instead have two different and unrelated structs, then you're at least able to represent all three with ElementScalarUnaryAttrs, ElementUnaryAttrs, and variant<ElementUnaryAttrs, ElementScalarUnaryAttrs>. Of course dealing with the variant is a bit annoying at times, and so if you really think the reduction in code is worthwhile, I think the following would probably work (though I'm not sure the conceptual complexity is worth it--is subtyping really worth it here?)

struct GenericElementUnaryAttrs {
  req<Op> op_type;
};

struct NonscalarElementUnaryAttrs : public GenericElementUnaryAttrs { };
struct ScalarElementUnaryAttrs : public GenericElementUnaryAttrs {
  req<float> scalar;
};

and you'd have to check whether or not visitable works as expected with this inheritance structure (I think it will, but I'm not 100% certain). I'd also want a concrete example of the case you're worried about with the two-structs-and-a-variant solution to be sure that this more complex solution would actually solve it.

That said, I 100% get the overall impulse. Unfortunately, subtyping in C++ just kinda sucks. There are places (like the graph library) where we're making increased use of the combination of inheritance and visitable, but even then only in cases where the benefits of the additional complexity are clear and concrete--it's also easier to start with the two-structs-and-a-variant solution and move up to this solution if the additional expressiveness is really needed, but I think it's harder to remove the expressiveness of the subtyping if we realize it's more trouble than it's worth as existing code would likely rely on the behavior.
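
The "inheritance is done by extension" point can be checked with a small sketch. Plain members and std::optional stand in for FlexFlow's req<...> and optional, and the helper function name is hypothetical; the struct layout mirrors the r6 definition discussed in this comment.

```cpp
#include <optional>

enum class Op { EXP, POW };

// Plain members stand in for FlexFlow's req<...> wrapper.
struct ElementUnaryAttrs {
  Op op_type;
  std::optional<float> scalar;
};

// Redeclaring the members in the derived struct adds a second pair of
// fields; it does not replace the inherited ones.
struct ElementScalarUnaryAttrs : ElementUnaryAttrs {
  Op op_type;
  float scalar;
};

inline bool base_and_derived_fields_are_distinct() {
  ElementScalarUnaryAttrs attrs{};
  attrs.op_type = Op::POW; // unqualified name refers to the derived member
  attrs.scalar = 2.0f;     // likewise
  ElementUnaryAttrs &base = attrs;
  // The inherited members are untouched and still value-initialized.
  return base.op_type == Op::EXP && !base.scalar.has_value();
}
```

The derived object is also strictly larger than the base, since it stores both pairs of fields.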

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: 2 of 17 files reviewed, 2 unresolved discussions (waiting on @lockshaw)


lib/kernels/src/cuda/element_unary_kernels.cu line 66 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Can you create an issue, make sure it's added to the project, and then add a link to it here?

Done.


lib/op-attrs/include/op-attrs/ops/element_unary.h line 20 at r6 (raw file):

Previously, lockshaw (Colin Unger) wrote…

I don't think this does what you want it to. Inheritance in C++ is done by extension, so ElementScalarUnaryAttrs is effectively defined here as

struct ElementScalarUnaryAttrs {
  req<Op> op_type;
  req<optional<float>> scalar;
  req<Op> op_type;
  req<float> scalar;
};

which doesn't seem like what you want. I think what you want is more along the lines of

struct ElementUnaryAttrs {
  req<Op> op_type;
};

struct ElementScalarUnaryAttrs : public ElementUnaryAttrs {
  req<float> scalar;
};

Either of these approaches has the downside that I no longer have a type that denotes "a unary op attrs without a scalar", as now ElementUnaryAttrs represents "a unary op attrs with or without a scalar" and ElementScalarUnaryAttrs represents "a unary op attrs with a scalar". If you instead have two different and unrelated structs, then you're at least able to represent all three with ElementScalarUnaryAttrs, ElementUnaryAttrs, and variant<ElementUnaryAttrs, ElementScalarUnaryAttrs>. Of course dealing with the variant is a bit annoying at times, and so if you really think the reduction in code is worthwhile, I think the following would probably work (though I'm not sure the conceptual complexity is worth it--is subtyping really worth it here?)

struct GenericElementUnaryAttrs {
  req<Op> op_type;
};

struct NonscalarElementUnaryAttrs : public GenericElementUnaryAttrs { };
struct ScalarElementUnaryAttrs : public GenericElementUnaryAttrs {
  req<float> scalar;
};

and you'd have to check whether or not visitable works as expected with this inheritance structure (I think it will, but I'm not 100% certain). I'd also want a concrete example of the case you're worried about with the two-structs-and-a-variant solution to be sure that this more complex solution would actually solve it.

That said, I 100% get the overall impulse. Unfortunately, subtyping in C++ just kinda sucks. There are places (like the graph library) where we're making increased use of the combination of inheritance and visitable, but even then only in cases where the benefits of the additional complexity are clear and concrete--it's also easier to start with the two-structs-and-a-variant solution and move up to this solution if the additional expressiveness is really needed, but I think it's harder to remove the expressiveness of the subtyping if we realize it's more trouble than it's worth as existing code would likely rely on the behavior.

I changed it to the following and I use variant<x,y> as well when needed

Code snippet:

struct ElementUnaryAttrs {
  req<Op> op_type;
};

struct ElementScalarUnaryAttrs {
  req<Op> op_type;
  req<float> scalar;
};
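
For readers following along, here is a minimal sketch of how the two unrelated structs and a variant can be consumed. Plain members replace req<...>, the body of apply is invented for illustration (it ignores op_type and just picks exp or pow), and the assumption that ElementUnaryUnifiedAttrs is an alias for the variant is inferred from the discussion rather than copied from the PR.

```cpp
#include <cmath>
#include <type_traits>
#include <variant>

enum class Op { EXP, POW };

struct ElementUnaryAttrs {
  Op op_type;
};

struct ElementScalarUnaryAttrs {
  Op op_type;
  float scalar;
};

// Assumed definition of the unified type: a variant over both structs.
using ElementUnaryUnifiedAttrs =
    std::variant<ElementUnaryAttrs, ElementScalarUnaryAttrs>;

// Hypothetical consumer: the scalar is only reachable in the branch
// whose static type actually has one.
inline float apply(ElementUnaryUnifiedAttrs const &attrs, float x) {
  return std::visit(
      [x](auto const &a) -> float {
        using T = std::decay_t<decltype(a)>;
        if constexpr (std::is_same_v<T, ElementScalarUnaryAttrs>) {
          return std::pow(x, a.scalar);
        } else {
          return std::exp(x);
        }
      },
      attrs);
}
```

A scalar op without a scalar, or a non-scalar op with one, cannot be constructed here, which is the compile-time guarantee argued for earlier in the thread.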

Collaborator

@lockshaw lockshaw left a comment

Reviewed 15 of 15 files at r7, all commit messages.
Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @reyna-abhyankar)


lib/kernels/src/cuda/element_unary_kernels.cu line 66 at r1 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Done.

Not seeing link? What PR number is it?


lib/op-attrs/include/op-attrs/ops/element_unary.h line 18 at r7 (raw file):

struct ElementScalarUnaryAttrs {
  req<Op> op_type;

req only needed on the last field


lib/runtime/src/ops/element_unary.cc line 27 at r7 (raw file):

/* ElementUnary */
OpTaskInvocation init(ElementUnaryUnifiedAttrs const &attrs) {

This seems good. My only question is whether the right overloads of init, forward, etc. will be picked up when called on one of the individual types of attrs. If not, it might be good to add trivial wrappers that forward to the variant version. Otherwise, LGTM.


lib/substitutions/include/substitutions/get_attribute.h line 28 at r7 (raw file):

optional<OperatorAttributeValue> get_attribute(DropoutAttrs const &p,
                                               OperatorAttributeKey);
optional<OperatorAttributeValue> get_attribute(ElementBinaryAttrs const &p,

Why was this removed?

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: all files reviewed, 4 unresolved discussions (waiting on @lockshaw)


lib/kernels/src/cuda/element_unary_kernels.cu line 66 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Not seeing link? What PR number is it?

I accidentally hit Done when we were showing Abishek how to use Reviewable LOL. Here it is: #1316. I'll link it in the PR.


lib/op-attrs/include/op-attrs/ops/element_unary.h line 18 at r7 (raw file):

Previously, lockshaw (Colin Unger) wrote…

req only needed on the last field

You mean just req<float> scalar;? How come this is the case?


lib/runtime/src/ops/element_unary.cc line 27 at r7 (raw file):

Previously, lockshaw (Colin Unger) wrote…

This seems good. My only question is whether the right overloads of init, forward, etc. will be picked up when called on one of the individual types of attrs. If not, it might be good to add trivial wrappers that forward to the variant version. Otherwise, LGTM.

I think since attrs is a black box for the task invocations, it won't matter (i.e. the correct one will always be passed to the kernel call and then the kernel function will disambiguate it).


lib/substitutions/include/substitutions/get_attribute.h line 28 at r7 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why was this removed?

ElementBinary and ElementUnary are already present on lines 22 and 24, for some reason they were duplicated here.

Collaborator

@lockshaw lockshaw left a comment

Reviewable status: all files reviewed, 2 unresolved discussions (waiting on @reyna-abhyankar)


lib/op-attrs/include/op-attrs/ops/element_unary.h line 18 at r7 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

You mean just req<float> scalar;? How come this is the case?

Since C++ does not support named arguments, if the last member of the struct has to be passed to the constructor then every member of the struct has to be passed to the constructor, which is what we are looking for. There's nothing wrong with putting req on other fields, but it's unnecessary complication, and for consistency I just decided on the "req only on last field" rule.
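
A minimal sketch of why this works, assuming req<T> behaves like a wrapper with no default constructor. The req template below is a hypothetical stand-in written for this example, not FlexFlow's actual definition, and make_pow is an invented helper.

```cpp
#include <utility>

// Hypothetical stand-in for req<T>: it can only be built from a value
// and can never be default-constructed.
template <typename T>
struct req {
  req() = delete;
  req(T value) : value_(std::move(value)) {}
  operator T const &() const { return value_; }

private:
  T value_;
};

enum class Op { EXP, POW };

struct ElementScalarUnaryAttrs {
  Op op_type;        // plain field
  req<float> scalar; // req on the last field only
};

inline ElementScalarUnaryAttrs make_pow(float exponent) {
  // Initializers are positional, so reaching `scalar` means supplying
  // `op_type` first: both fields end up being required.
  return ElementScalarUnaryAttrs{Op::POW, exponent};
  // ElementScalarUnaryAttrs{Op::POW};  // does not compile: `scalar` would
  //                                    // need req<float>'s deleted default
  //                                    // constructor
}
```

The same reasoning rules out `ElementScalarUnaryAttrs attrs;` with no initializers at all, since the struct's implicit default constructor is deleted along with the member's.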


lib/runtime/src/ops/element_unary.cc line 27 at r7 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

I think since attrs is a black box for the task invocations, it won't matter (i.e. the correct one will always be passed to the kernel call and then the kernel function will disambiguate it).

I mean in local backing during the iteration over all of the nodes in the PCG calling forward--one of those nodes could be an ElementUnaryAttrs and I want to make sure that gets properly resolved to forward(ElementUnaryUnifiedAttrs const &)

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: 16 of 17 files reviewed, 3 unresolved discussions (waiting on @lockshaw)


lib/op-attrs/include/op-attrs/ops/element_unary.h line 18 at r7 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Since C++ does not support named arguments, if the last member of the struct has to be passed to the constructor then every member of the struct has to be passed to the constructor, which is what we are looking for. There's nothing wrong with putting req on other fields, but it's unnecessary complication, and for consistency I just decided on the "req only on last field" rule.

Done.


lib/runtime/src/ops/element_unary.cc line 27 at r7 (raw file):

Previously, lockshaw (Colin Unger) wrote…

I mean in local backing during the iteration over all of the nodes in the PCG calling forward--one of those nodes could be an ElementUnaryAttrs and I want to make sure that gets properly resolved to forward(ElementUnaryUnifiedAttrs const &)

Just tested this offline in a simple case where a caller invokes the function with either of the variant's alternative types, and it works.
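
A check along these lines might look like the sketch below; it is not the actual offline test. It assumes ElementUnaryUnifiedAttrs is a std::variant over the two attrs structs, uses plain int and float members, and gives forward an invented return value so the selected alternative is observable.

```cpp
#include <string>
#include <variant>

struct ElementUnaryAttrs {
  int op_type;
};

struct ElementScalarUnaryAttrs {
  int op_type;
  float scalar;
};

// Assumed definition of the unified type.
using ElementUnaryUnifiedAttrs =
    std::variant<ElementUnaryAttrs, ElementScalarUnaryAttrs>;

// Only the variant overload exists. A call with one of the individual
// structs goes through std::variant's converting constructor.
inline std::string forward(ElementUnaryUnifiedAttrs const &attrs) {
  return std::holds_alternative<ElementScalarUnaryAttrs>(attrs) ? "scalar"
                                                                : "plain";
}
```

Calling forward(ElementUnaryAttrs{0}) or forward(ElementScalarUnaryAttrs{1, 2.0f}) compiles and selects the matching alternative. This only covers direct calls; dispatch through templates that deduce the argument type would not apply the implicit conversion, so the PCG iteration path may still need its own check.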

Collaborator

@lockshaw lockshaw left a comment

Reviewed 1 of 1 files at r10, all commit messages.
Reviewable status: :shipit: complete! all files reviewed, all discussions resolved (waiting on @reyna-abhyankar)

@lockshaw lockshaw merged commit 3237169 into flexflow:repo-refactor Mar 7, 2024
@lockshaw lockshaw deleted the element-unary-op branch March 7, 2024 02:08
@lockshaw lockshaw linked an issue Mar 28, 2024 that may be closed by this pull request
@victorli2002 victorli2002 mentioned this pull request Mar 12, 2025
@victorli2002 victorli2002 mentioned this pull request Apr 21, 2025

Development

Successfully merging this pull request may close these issues.

Update ElementUnary operator

2 participants