elementwise_util: don't cast the result of compute_fun back to the common type #9385
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9385
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job as of commit f1c5429 with merge base 644b7dd.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
hmm, don't understand why unittest / macos is failing on CI but not locally. maybe PR is out of sync? rebasing
elementwise_util: don't cast the result of compute_fun back to the common type

The compute function might return an entirely different type. For example, if we were applying a trigonometric function like acos to an input of type bool expecting an output of type float, we would get bad results because acos(0) = 1.57, but casting through bool would truncate that to 1.

Note that we don't need the pair of ET_CHECK_MSG I removed because we already check tensor dtypes on entry to the elementwise util functions; the checks were inconvenient because we now call get_store_common_to_tensor_fn without the actual common type.

ghstack-source-id: cfcbe8b
ghstack-comment-id: 2735017325
Pull Request resolved: #9385
There was an ASAN failure, which is now fixed.
need to add a regression test for the acos case as well
```cpp
template <typename CTYPE_COMMON, typename Op, typename... Args>
using op_call_result =
    std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
```
why do you need ignore_first_yield_second? why not use CTYPE_COMMON directly in here?
because I need the `...` to produce sizeof...(Args) instances of CTYPE_COMMON. If you have a suggestion for a better way to do that, I would love to hear it; this is the best I could do.
this breaks mul(int8, int8, out=long). I think we need to add a notion of "float ops" to the elementwise_util functions corresponding to TensorIterator's notion of "float ops", which essentially just means "set the compute dtype to a floating-point type"
this diff has been removed from the stack it thinks it's in
turns out this is unnecessary and the real problem is that "float ops" were not setting the compute dtype to a floating-point type |
elementwise_util: don't cast the result of compute_fun back to the common type

The compute function might return an entirely different type. For example, if we were applying a trigonometric function like acos to an input of type bool expecting an output of type float, we would get bad results because acos(0) = 1.57, but casting through bool would truncate that to 1.

Note that we don't need the pair of ET_CHECK_MSG I removed because we already check tensor dtypes on entry to the elementwise util functions; the checks were inconvenient because we now call get_store_common_to_tensor_fn without the actual common type.

ghstack-source-id: 72643ec
ghstack-comment-id: 2735017325
Pull Request resolved: pytorch/executorch#9385