Skip to content

Conversation

@oliverholworthy
Copy link
Contributor

Update SoftmaxSampling to support dynamic dtypes.

Currently the output schema of the SoftmaxSampling operator returns static dtypes which does not match the data returned when the operator is called with different dtypes from these.

This update the compute_output_schema method to match the implementation of transform. It returns ids and scores with the corresponding dtypes of the input ids and scores.

Alternatively, we could update the transform method to coerce the outputs to match the static types. However, it feels to me that matching the types passed in could feel more natural, as it would match the behaviour or array and tensor processing libraries, which typically return the same type as came in unless specified explicitly with a cast or extra parameter.

The output types match the input dtypes passed. And the output schema
is computed from the input types passed.
@oliverholworthy oliverholworthy added the enhancement New feature or request label Mar 24, 2023
@oliverholworthy oliverholworthy added this to the Merlin 23.03 milestone Mar 24, 2023
@oliverholworthy oliverholworthy self-assigned this Mar 24, 2023
@rnyak rnyak requested a review from karlhigley March 24, 2023 15:29
@github-actions
Copy link

Documentation preview

https://nvidia-merlin.github.io/systems/review/pr-304

@karlhigley karlhigley merged commit d1d270e into main Mar 24, 2023
@rnyak rnyak deleted the softmax-dynamic-dtype branch March 24, 2023 19:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants