This repository was archived by the owner on Jan 16, 2025. It is now read-only.
Here is how I installed JAX v0.3.25 on Windows with Anaconda. The method is completely self-contained and does not rely on any external Windows installers from NVIDIA.
BTW, I could create a pull request with these extra docs if that would help others.
# Install Anaconda or Miniconda
conda create -n py310jax python=3.10 -y
conda activate py310jax
conda install -c conda-forge cudatoolkit=11.1 cudnn -y
# TensorFlow 2.10 was the last version to support CUDA GPUs natively on Windows.
pip install "tensorflow<2.11"
# Install jaxlib
# - Download file "jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl" from "https://whls.blob.core.windows.net/unstable/index.html"
pip install jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl
# Install matching version of jax
pip install jax==0.3.25
# Now we can run JAX-based Python code on Windows.
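To check that the install actually picked up the GPU, something like the following sketch can be run inside the activated `py310jax` environment. The helper name `describe_jax_backend` is my own and purely illustrative; it only relies on `jax.default_backend()` and `jax.devices()`. I would expect the backend to be reported as a GPU one if the CUDA wheel and the conda `cudatoolkit`/`cudnn` packages line up, and `cpu` otherwise, but treat that as an assumption rather than a guarantee.

```python
def describe_jax_backend():
    """Return (backend_name, device_list) if JAX can be imported, else None.

    Hypothetical helper for illustration; not part of JAX itself.
    """
    try:
        import jax
    except ImportError:
        # jax (or jaxlib) is missing from the current environment.
        return None
    # default_backend() names the platform JAX will run on by default;
    # devices() lists the devices available on that platform.
    return jax.default_backend(), [str(d) for d in jax.devices()]


info = describe_jax_backend()
if info is None:
    print("jax is not installed in this environment")
else:
    backend, devices = info
    print("backend:", backend)
    print("devices:", devices)
```

If the backend comes back as `cpu`, the usual suspects are a jaxlib wheel that does not match the Python version or a missing/mismatched `cudatoolkit` in the conda environment.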