
Optimize the jax and pytorch compatibility #2417

@scarlehoff

Description


I don't know if you are familiar with Google's (infamous) graveyard of projects: https://killedbygoogle.com

A lot of Google's effort is moving from tensorflow to jax. Switching to jax by default would be an option, but I believe Google will eventually kill jax as well, once some middle manager comes up with the next shiny new thing.

As much as I hate Meta and all it represents, it has a much better track record of keeping projects alive, so my vote would be for pytorch.

(Just to clarify: we can already use jax or pytorch by changing the backend, but a number of operations were optimized and benchmarked using tensorflow, so there is room for improvement when using either of the other backends.)
