From 691da7008c348c9b25c4472601e0e9acfc7e9df2 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Fri, 1 Mar 2024 15:13:31 +0200 Subject: [PATCH] jax: update config import --- arraycontext/pytest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 4fce5885..b1bbec95 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -189,7 +189,7 @@ def is_available(cls) -> bool: return False def __call__(self): - from jax.config import config + from jax import config from arraycontext import EagerJAXArrayContext config.update("jax_enable_x64", True) @@ -214,7 +214,7 @@ def is_available(cls) -> bool: return False def __call__(self): - from jax.config import config + from jax import config from arraycontext import PytatoJAXArrayContext config.update("jax_enable_x64", True)