Skip to content

Commit 804848a

Browse files
authored
fix: do not install tf-keras for cu11 (#3444)
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 2ee8a3b commit 804848a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

backend/find_tensorflow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def find_tensorflow() -> Tuple[Optional[str], List[str]]:
8383
# TypeError if submodule_search_locations are None
8484
# IndexError if submodule_search_locations is an empty list
8585
except (AttributeError, TypeError, IndexError):
86+
tf_version = ""
8687
if os.environ.get("CIBUILDWHEEL", "0") == "1":
8788
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
8889
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
@@ -99,9 +100,10 @@ def find_tensorflow() -> Tuple[Optional[str], List[str]]:
99100
"tensorflow-cpu>=2.5.0rc0,<2.15; platform_machine=='x86_64' and platform_system == 'Linux'",
100101
]
101102
)
103+
tf_version = "2.14.1"
102104
else:
103105
raise RuntimeError("Unsupported CUDA version")
104-
requires.extend(get_tf_requirement()["cpu"])
106+
requires.extend(get_tf_requirement(tf_version)["cpu"])
105107
# setuptools will re-find tensorflow after installing setup_requires
106108
tf_install_dir = None
107109
return tf_install_dir, requires

0 commit comments

Comments
 (0)