From 21ce4cbc59ec38c982fd85786bdd71ef9384cbc5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 29 Oct 2023 02:35:28 -0400 Subject: [PATCH] fix SpecifierSet behavior with prereleases By default, SpecifierSet doesn't allow prereleases, which is not our expected behavior. Signed-off-by: Jinzhe Zeng --- backend/find_tensorflow.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/find_tensorflow.py b/backend/find_tensorflow.py index aa75d5ecb4..6d7ce5087d 100644 --- a/backend/find_tensorflow.py +++ b/backend/find_tensorflow.py @@ -114,9 +114,9 @@ def get_tf_requirement(tf_version: str = "") -> dict: extra_requires = [] extra_select = {} - if not (tf_version == "" or tf_version in SpecifierSet(">=2.12")): + if not (tf_version == "" or tf_version in SpecifierSet(">=2.12", prereleases=True)): extra_requires.append("protobuf<3.20") - if tf_version == "" or tf_version in SpecifierSet(">=1.15"): + if tf_version == "" or tf_version in SpecifierSet(">=1.15", prereleases=True): extra_select["mpi"] = [ "horovod", "mpi4py", @@ -138,9 +138,9 @@ def get_tf_requirement(tf_version: str = "") -> dict: ], **extra_select, } - elif tf_version in SpecifierSet("<1.15") or tf_version in SpecifierSet( - ">=2.0,<2.1" - ): + elif tf_version in SpecifierSet( + "<1.15", prereleases=True + ) or tf_version in SpecifierSet(">=2.0,<2.1", prereleases=True): return { "cpu": [ f"tensorflow=={tf_version}",