-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relay][PRNG] Support generating data of any shape in threefry_generate #8085
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
tkonolige
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! I've left a couple of comments. Also, could you add a test that makes sure the last 3 random number of a size 4*x+3 shape change between calls with different keys?
python/tvm/topi/random/kernel.py
Outdated
| with irb.if_scope(out_len % 4 != 0): | ||
| out_gen[7] = tmp[7] + tir.Cast(gen.dtype, out_len % 4) | ||
| with irb.else_scope(): | ||
| out_gen[7] = tmp[7] + tir.Cast(gen.dtype, out_len) # increment counter |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is incorrect. It should just be out_gen[7] = tmp[7] + tir.Cast(gen.dtype, out_len) because we used out_len values from the counter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tkonolige I changed the increment to out_len % 4 here because I already update the tmp[7] above with:
tmp[7] = tmp[7] + tir.Cast(gen.dtype, out_len // 4 * 4)The reason for the separate update is that I think we need to update the key before the second threefry.
However, I update the second update from tir.Cast(gen.dtype, out_len % 4) to tir.Cast(gen.dtype, 4), because 4 random numbers are actually generated (though partially discarded).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah, you are correct. Could you leave a comment that it was updated earlier.
python/tvm/topi/random/kernel.py
Outdated
|
|
||
| # Compute random values | ||
| _threefry(irb, tmp, 0, tmp, 4, out_array, 0, out_len // 4) | ||
| with irb.if_scope(out_len % 4 != 0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be if out_len.value % 4 != 0:
|
@tkonolige Thank you for your review~ I've updated the code based on them. Could you take a second look? As for the incorrect part of the update, I have some different opinion, see the comment above. |
|
@tkonolige I've added the comment and rebased the code to main branch to fix the shape limit in |
tkonolige
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Thank you!
|
@tqchen Could you give a help to this PR? Thank you~ 😄 |
This PR adds support for generating data of any shape in
threefry_generate.Thank you for your time on reviewing this pr.
cc @tkonolige