[Relax] Clean up scatter_elements unknown dtype handling #18577
tlopex merged 1 commit into apache:main
Summary of Changes
Hello @guan404ming, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request refines the …
Code Review
This pull request cleans up the handling of unknown dtypes in scatter_elements by removing some TODO comments and adds test coverage for these scenarios. The changes are good, but I've found a few areas for improvement.
In src/relax/op/tensor/manipulate.cc, I've identified a potential bug where the wrong variable is used to check the dtype of updates, and a minor typo in a log message.
In tests/python/relax/test_op_manipulate.py, while the added tests are valuable, I've suggested adding another test case for when indices has an unknown dtype to ensure full coverage of the warning paths.
if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) {
  auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) {
    if (sinfo->IsUnknownDtype()) {
      // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?
While you're cleaning up this section, I noticed a potential bug in the usage of the diag_dtype lambda. On line 2406 (in the full file), it seems data_sinfo is used for both "data" and "updates", but it should probably be updates_sinfo for "updates". This looks like a copy-paste error. It should be diag_dtype(updates_sinfo, "updates"); to correctly check the updates tensor.
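The effect of that copy-paste error can be illustrated with a small Python model. This is only a sketch, not the TVM C++ code: `SInfo`, `collect_warnings`, and the `fixed` flag are hypothetical stand-ins, and the real helper logs a warning rather than appending to a list.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SInfo:
    """Hypothetical stand-in for TensorStructInfoNode; dtype=None means unknown."""

    dtype: Optional[str]

    def is_unknown_dtype(self) -> bool:
        return self.dtype is None


def collect_warnings(data: SInfo, updates: SInfo, fixed: bool) -> List[str]:
    """Return the tensor names a diag_dtype-style helper would warn about."""
    warned: List[str] = []

    def diag_dtype(sinfo: SInfo, name: str) -> None:
        # Mirrors the lambda in the diff: warn only when the dtype is unknown.
        if sinfo.is_unknown_dtype():
            warned.append(name)

    diag_dtype(data, "data")
    # The buggy variant re-checks `data` under the "updates" label.
    diag_dtype(updates if fixed else data, "updates")
    return warned


known, unknown = SInfo("float32"), SInfo(None)
buggy_result = collect_warnings(known, unknown, fixed=False)
fixed_result = collect_warnings(known, unknown, fixed=True)
```

In this model the buggy variant stays silent when only updates has an unknown dtype, and it reports "updates" whenever data is unknown, which is why passing updates_sinfo matters.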
}

if (indices_sinfo->IsUnknownDtype()) {
  // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?
# Test with unknown dtype for data
d_unknown = relax.Var("data", R.Tensor((4, 4)))
_check_inference(
    bb,
    relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"),
    relax.TensorStructInfo((4, 4), dtype=""),
)
# Test with unknown dtype for updates
u_unknown = relax.Var("updates", R.Tensor((2, 2)))
_check_inference(
    bb,
    relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"),
    relax.TensorStructInfo((4, 4), dtype="float32"),
)
These are good additions for testing unknown dtypes for data and updates. To improve test coverage further, could you also add a test case for when indices has an unknown dtype? This would cover the warning log path for indices_sinfo->IsUnknownDtype() in InferStructInfoScatterElements.
# Test with unknown dtype for data
d_unknown = relax.Var("data", R.Tensor((4, 4)))
_check_inference(
    bb,
    relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"),
    relax.TensorStructInfo((4, 4), dtype=""),
)
# Test with unknown dtype for updates
u_unknown = relax.Var("updates", R.Tensor((2, 2)))
_check_inference(
    bb,
    relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"),
    relax.TensorStructInfo((4, 4), dtype="float32"),
)
# Test with unknown dtype for indices
i_unknown = relax.Var("indices", R.Tensor((2, 2)))
_check_inference(
    bb,
    relax.op.scatter_elements(d0, i_unknown, u0, 0, "updates"),
    relax.TensorStructInfo((4, 4), dtype="float32"),
)
7918d90 to b4dec7c
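The expectations in the suggested test cases can be summarized with a toy inference rule that runs without TVM. This is a sketch under assumptions, not the behavior of InferStructInfoScatterElements itself: `infer_scatter_elements` and `run_case` are hypothetical helpers, the `"int64"` index dtype is assumed because the source does not show how `i0` is declared, and Python's `warnings` module stands in for the C++ warning log.

```python
import warnings
from typing import List, Tuple

UNKNOWN = ""  # the expected struct info above spells an unknown dtype as dtype=""

Shape = Tuple[int, ...]


def infer_scatter_elements(
    data: Tuple[Shape, str], indices_dtype: str, updates_dtype: str
) -> Tuple[Shape, str]:
    """Toy rule: warn on each unknown dtype, return data's shape and dtype."""
    shape, data_dtype = data
    named = (("data", data_dtype), ("updates", updates_dtype), ("indices", indices_dtype))
    for name, dtype in named:
        if dtype == UNKNOWN:
            warnings.warn(f"scatter_elements: dtype of {name} is unknown")
    # The output follows data, so only an unknown data dtype propagates.
    return shape, data_dtype


def run_case(
    data_dtype: str, indices_dtype: str, updates_dtype: str
) -> Tuple[Tuple[Shape, str], List[str]]:
    """Run the toy rule on a (4, 4) data tensor and capture warning messages."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        out = infer_scatter_elements(((4, 4), data_dtype), indices_dtype, updates_dtype)
    return out, [str(w.message) for w in caught]


# "int64" is an assumed index dtype, chosen only for illustration.
data_case = run_case(UNKNOWN, "int64", "float32")
updates_case = run_case("float32", "int64", UNKNOWN)
indices_case = run_case("float32", UNKNOWN, "float32")
```

Only the data case yields an unknown output dtype. The updates and indices cases keep float32 and differ only in which warning fires, so the indices case is the one that exercises the third warning path.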
Why