diff --git a/docker/install/ubuntu_install_jax.sh b/docker/install/ubuntu_install_jax.sh index 87cb6f7dbe47..19149909161e 100644 --- a/docker/install/ubuntu_install_jax.sh +++ b/docker/install/ubuntu_install_jax.sh @@ -23,13 +23,13 @@ set -o pipefail # Install jax and jaxlib if [ "$1" == "cuda" ]; then pip3 install --upgrade \ - jaxlib==0.4.7 \ - "jax[cuda11_pip]==0.4.7" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + jaxlib~=0.4.9 \ + "jax[cuda11_pip]~=0.4.9" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html else pip3 install --upgrade \ - jaxlib==0.4.7 \ - "jax[cpu]==0.4.7" + jaxlib~=0.4.9 \ + "jax[cpu]~=0.4.9" fi # Install flax -pip3 install flax==0.6.8 +pip3 install flax~=0.6.9