From 99cd87e8898348255e864540e43bab17ce0576d6 Mon Sep 17 00:00:00 2001 From: Mike Gray Date: Sat, 8 Jun 2024 20:40:52 -0500 Subject: [PATCH 1/2] feat: allow full tensorflow usage if available closes #159 --- openwakeword/model.py | 29 +++++++++++++++++++---------- openwakeword/utils.py | 8 ++++++-- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/openwakeword/model.py b/openwakeword/model.py index 6029963..8e9149d 100755 --- a/openwakeword/model.py +++ b/openwakeword/model.py @@ -119,16 +119,25 @@ def tflite_predict(tflite_interpreter, input_index, output_index, x): return tflite_interpreter.get_tensor(output_index)[None, ] except ImportError: - logging.warning("Tried to import the tflite runtime, but it was not found. " - "Trying to switching to onnxruntime instead, if appropriate models are available.") - if wakeword_models != [] and all(['.onnx' in i for i in wakeword_models]): - inference_framework = "onnx" - elif wakeword_models != [] and all([os.path.exists(i.replace('.tflite', '.onnx')) for i in wakeword_models]): - inference_framework = "onnx" - wakeword_models = [i.replace('.tflite', '.onnx') for i in wakeword_models] - else: - raise ValueError("Tried to import the tflite runtime for provided tflite models, but it was not found. " - "Please install it using `pip install tflite-runtime`") + try: + from tensorflow.lite.python import interpreter as tflite + + def tflite_predict(tflite_interpreter, input_index, output_index, x): + tflite_interpreter.set_tensor(input_index, x) + tflite_interpreter.invoke() + return tflite_interpreter.get_tensor(output_index)[None, ] + + except ImportError: + logging.warning("Tried to import the tflite runtime, but it was not found. " + "Trying to switching to onnxruntime instead, if appropriate models are available.") + if wakeword_models != [] and all(['.onnx' in i for i in wakeword_models]): + inference_framework = "onnx" + elif wakeword_models != [] and all([os.path.exists(i.replace('.tflite', '.onnx')) for i in wakeword_models]): + inference_framework = "onnx" + wakeword_models = [i.replace('.tflite', '.onnx') for i in wakeword_models] + else: + raise ValueError("Tried to import the tflite runtime for provided tflite models, but it was not found. " + "Please install it using `pip install tflite-runtime`") if inference_framework == "onnx": try: diff --git a/openwakeword/utils.py b/openwakeword/utils.py index 4964706..5845d98 100644 --- a/openwakeword/utils.py +++ b/openwakeword/utils.py @@ -96,8 +96,12 @@ def __init__(self, try: import tflite_runtime.interpreter as tflite except ImportError: - raise ValueError("Tried to import the TFLite runtime, but it was not found." - "Please install it using `pip install tflite-runtime`") + try: + from tensorflow.lite.python import interpreter as tflite + except ImportError: + raise ValueError("Tried to import the TFLite runtime, but it was not found." + "Neither was the TensorFlow interpreter." + "Please install TFLite runtime using `pip install tflite-runtime`") if melspec_model_path == "": melspec_model_path = os.path.join(pathlib.Path(__file__).parent.resolve(), From 172e25bfa2e3b06ec9f1bc62ffc8206ca0c8ee9b Mon Sep 17 00:00:00 2001 From: Mike Date: Sun, 17 Nov 2024 17:14:29 -0600 Subject: [PATCH 2/2] more explicit tensorflow checks --- openwakeword/model.py | 7 ++++--- openwakeword/utils.py | 6 ++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/openwakeword/model.py b/openwakeword/model.py index 8e9149d..621987a 100755 --- a/openwakeword/model.py +++ b/openwakeword/model.py @@ -119,15 +119,16 @@ def tflite_predict(tflite_interpreter, input_index, output_index, x): return tflite_interpreter.get_tensor(output_index)[None, ] except ImportError: - try: + from importlib.util import find_spec + if find_spec("tensorflow") is not None and find_spec("tflite_runtime") is None: + logging.warning("Tried to import the tflite runtime, but it was not found. Using tensorflow instead.") from tensorflow.lite.python import interpreter as tflite def tflite_predict(tflite_interpreter, input_index, output_index, x): tflite_interpreter.set_tensor(input_index, x) tflite_interpreter.invoke() return tflite_interpreter.get_tensor(output_index)[None, ] - - except ImportError: + else: logging.warning("Tried to import the tflite runtime, but it was not found. " "Trying to switching to onnxruntime instead, if appropriate models are available.") if wakeword_models != [] and all(['.onnx' in i for i in wakeword_models]): diff --git a/openwakeword/utils.py b/openwakeword/utils.py index 5845d98..e45c86c 100644 --- a/openwakeword/utils.py +++ b/openwakeword/utils.py @@ -96,9 +96,11 @@ def __init__(self, try: import tflite_runtime.interpreter as tflite except ImportError: - try: + from importlib.util import find_spec + if find_spec("tensorflow") is not None and find_spec("tflite_runtime") is None: + logging.warning("Tried to import the tflite runtime, but it was not found. Using tensorflow instead.") from tensorflow.lite.python import interpreter as tflite - except ImportError: + else: raise ValueError("Tried to import the TFLite runtime, but it was not found." "Neither was the TensorFlow interpreter." "Please install TFLite runtime using `pip install tflite-runtime`")