Conversation

@jwfromm
Contributor

@jwfromm jwfromm commented Aug 3, 2022

This PR introduces the multinomial random operator. It's a neat adaptation of random.uniform that allows weighted selection of indices from a probability tensor. This op is used in new DALL-E-like architectures to generate random images. The PR provides a TOPI implementation and tests, Relay integration, and an initial PyTorch integration. I did not implement sampling without replacement at this time, as it is difficult to express as a tensor operation.
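
For context, here is a minimal NumPy sketch of how weighted index selection can be built from uniform samples (an illustration of the sampling trick only, not the TOPI code; the helper name and the normalization step are assumptions):

import numpy as np

def multinomial_with_replacement(probs, num_samples, seed=0):
    # Normalize the weights and build an inclusive cumulative distribution.
    rng = np.random.default_rng(seed)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    cdf = np.cumsum(probs, axis=-1)
    # One uniform draw per requested sample.
    draws = rng.uniform(size=probs.shape[:-1] + (num_samples,))
    # Each sampled index is the number of CDF entries the draw exceeds.
    return (draws[..., None, :] > cdf[..., :, None]).sum(axis=-2)

# Index 2 should dominate, since it carries weight 0.8.
print(multinomial_with_replacement(np.array([0.1, 0.1, 0.8]), num_samples=5))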

@jwfromm
Contributor Author

jwfromm commented Aug 3, 2022

@sfvaroglu can you take a look at this PR?

@jwfromm jwfromm requested a review from tkonolige August 3, 2022 05:56
@sfvaroglu
Contributor

LGTM, thanks @jwfromm! Would be nice to have this in the ONNX importer, too :)

Contributor

@tkonolige tkonolige left a comment

Looks good @jwfromm. I would just like a little testing on the output to make sure it is actually a multinomial distribution. Let me know if you think that is too complicated.

assert not (
    replacement is False and num_samples > 1
), "Multinomial without replacement is not yet supported."
seed = np.random.randint(1e6)
Contributor

Ideally there would be one seed, set or initialized at runtime, that we pass through the entire graph. But I don't think we have the infrastructure for that yet. This is fine for now, but maybe you could add a comment about how to improve this in the future?

Contributor Author

Yeah, that's a good point; we'd have to use a global dictionary or something for that. I'll add a note. For now, this approach matches how we handle other RNG importer functions.
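
A purely hypothetical sketch of that idea (not TVM code; every name below is made up for illustration) would keep one importer-level base seed and derive per-op seeds from it, instead of calling np.random.randint for every op:

# Hypothetical: a single graph-level seed source shared by all random ops.
_GRAPH_SEED_STATE = {"base": 0, "counter": 0}

def next_op_seed():
    # Derive a deterministic per-op seed from the shared base seed.
    seed = _GRAPH_SEED_STATE["base"] + _GRAPH_SEED_STATE["counter"]
    _GRAPH_SEED_STATE["counter"] += 1
    return seed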


@tvm.testing.parametrize_targets
def test_multinomial(target, dev):
    def _verify_multinomial(size, num_samples):
Contributor

Could you do some rough checking of the expected value and variance of the distribution? It's always hard to tell if these random things are implemented correctly, but I think this would help.

Contributor Author

Is there a good way to do this without potentially introducing flakiness? I guess we could use a fixed seed. Would that be satisfactory?

Contributor

@octoJon octoJon Aug 3, 2022

You could generate a "large" sample of at least 10,000 values and then use a chi-squared test (scipy.stats.chisquare). You'd look at the p-value from that chi-squared test and compare it to an acceptably low threshold for flakiness -- for example, have this unit test fail if the p-value is smaller than 1e-6, which should only happen by chance in one run per million.
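
A rough sketch of that kind of check (illustrative only; the helper name, sample handling, and normalized probs are assumptions) could look like:

import numpy as np
from scipy import stats

def check_multinomial_distribution(sample_indices, probs, alpha=1e-6):
    # Compare observed index counts against the counts expected from probs.
    observed = np.bincount(sample_indices, minlength=len(probs))
    expected = np.asarray(probs) * len(sample_indices)
    _, p_value = stats.chisquare(observed, expected)
    # Fail only when the fit is implausibly bad, keeping flakiness near one run per million.
    assert p_value > alpha, f"sample does not match target distribution (p={p_value:.2e})"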

Contributor Author

@jwfromm jwfromm Aug 5, 2022

Thanks for the tip. I added a chi-squared test, which confirms that this function behaves as expected.

@jwfromm
Contributor Author

jwfromm commented Aug 6, 2022

@tkonolige, can you give this another look? I think it's all set to merge.

Member

@junrushao junrushao left a comment

LGTM!

Contributor

@tkonolige tkonolige left a comment

Thanks @jwfromm

@tkonolige tkonolige merged commit b79f950 into apache:main Aug 8, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
* Add multinomial operator.
* Implemented PyTorch integration with multinomial.
* Fixed test parametrization and added ONNX integration.
* Add statistical testing.
* Make get_type more flexible.
@jwfromm jwfromm deleted the torch_multinomial branch April 12, 2023 15:57