Skip to content

Conversation

@mofeing
Copy link
Collaborator

@mofeing mofeing commented May 27, 2024

Fixes #82

@wsmoses I don't know much about LLVM, so I don't know how to implement them. But I wrote the prototypes and some annotations on how to continue the implementation.

Copy link
Member

@wsmoses wsmoses left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should copy the gradpad test which computes derivatives

@wsmoses
Copy link
Member

wsmoses commented May 28, 2024

@mofeing look here for example of how to deal with the einsum config.

def InversePermutation : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{

I can help as well, but my question for you is offhand do you know the corresponding derivative rule for the config

@mofeing
Copy link
Collaborator Author

mofeing commented May 28, 2024

I can help as well, but my question for you is offhand do you know the corresponding derivative rule for the config

@wsmoses The reverse derivative for einsum is very easy. Basically it's just another einsum per input where you replace the computing . Imagine the following einsum:

$$ C = \mathtt{einsum}(''$ia,$ib->$ic'', A, B) $$

where $A,B,C$ are tensors and $ia,ib,ic$ denote the list of indices of each tensor. The string is the einsum config.

Then the reverse derivatives are the following:

$$ dA = \mathtt{einsum}(''$ib,$ic->$ia'', conj(B), dC) $$

$$ dB = \mathtt{einsum}(''$ia,$ic->$ib'', conj(A), dC) $$

I see that we need to do the following:

  1. Modify the einsum config like the patterns above.
  2. Apply the conjugate only if A or B respectively, are complex. The main reason is that complex.conj only works on complex types, not on f32/f64.

@wsmoses
Copy link
Member

wsmoses commented May 28, 2024

To start with let's not add complex stuff yet and do that in a subsequent PR. Other things will also need to be updated here as well so might as well do it generically for everything.

@wsmoses wsmoses closed this May 28, 2024
@wsmoses wsmoses reopened this May 28, 2024
@wsmoses
Copy link
Member

wsmoses commented May 28, 2024

[whoops meant to comment, not close...my b]

@wsmoses
Copy link
Member

wsmoses commented May 28, 2024

In any case, want to try to add the config update?

@mofeing
Copy link
Collaborator Author

mofeing commented May 28, 2024

Yeah, I see that GlobalExpr lets you write C++ code, right? I just don't know what I have access to and the methods of op inside that code.

For example, where is .getPermutation() defined in the code you pointed?

@wsmoses
Copy link
Member

wsmoses commented May 28, 2024 via email

@mofeing
Copy link
Collaborator Author

mofeing commented May 29, 2024

@wsmoses just uploaded the stablehlo.einsum test with complex data. it might still error due to the result not being exactly the same, but it should be easily fixable.

To do

  •  add activity support to stablehlo.einsum and stablehlo.unary_einsum diff rules
  • fix diff rule for stablehlo.einsum on complex data

@wsmoses
Copy link
Member

wsmoses commented May 31, 2024

@mofeing tests seem to fail, worth investigating?

@mofeing mofeing merged commit 9f26b9c into EnzymeAD:main Jun 1, 2024
@mofeing mofeing deleted the fix/diff-einsum branch June 1, 2024 16:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Missing differentiation rules for einsum, unary_einsum

2 participants