From 1a49cebb67ce0f172fae70b77fd44de172401520 Mon Sep 17 00:00:00 2001 From: Liam Sturge Date: Thu, 27 Jul 2023 11:13:07 +0100 Subject: [PATCH] Bump Flax and Jax versions to fix install error The Flax dependency Orbax (v0.1.8) has deprecated being able to install Orbax as a standalone package. Flax v0.6.8 attempts to install Orbax as a standalone package and raises an error about doing so. Going forward, the package orbax-checkpoint should be installed instead. Flax v0.6.8 does not recognize this and attempts to install Orbax instead of orbax-checkpoint and the installation fails. In order to resolve Jax installation issues, bumping the version of Flax to be at least 0.6.9, which resolves the problem. Flax >= 0.6.9 does not pin the version of orbax-checkpoint that it installs and the latest version requires Jax >= 0.4.9 to be installed so the two must be updated together. --- docker/install/ubuntu_install_jax.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docker/install/ubuntu_install_jax.sh b/docker/install/ubuntu_install_jax.sh index 87cb6f7dbe47..19149909161e 100644 --- a/docker/install/ubuntu_install_jax.sh +++ b/docker/install/ubuntu_install_jax.sh @@ -23,13 +23,13 @@ set -o pipefail # Install jax and jaxlib if [ "$1" == "cuda" ]; then pip3 install --upgrade \ - jaxlib==0.4.7 \ - "jax[cuda11_pip]==0.4.7" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + jaxlib~=0.4.9 \ + "jax[cuda11_pip]~=0.4.9" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html else pip3 install --upgrade \ - jaxlib==0.4.7 \ - "jax[cpu]==0.4.7" + jaxlib~=0.4.9 \ + "jax[cpu]~=0.4.9" fi # Install flax -pip3 install flax==0.6.8 +pip3 install flax~=0.6.9