From 7b199e1a35fffd00ed0d0858f6aab3581ad42b9e Mon Sep 17 00:00:00 2001
From: riteshghorse
Date: Tue, 20 Jun 2023 13:55:21 -0400
Subject: [PATCH] exception handling for loading models

---
 .../apache_beam/ml/inference/tensorflow_inference.py | 11 ++++++++++-
 .../ml/inference/tensorflow_inference_test.py        |  6 ++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index d1f236dc53b1..991ae971d9e6 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -56,7 +56,15 @@ class ModelType(enum.Enum):
 
 
 def _load_model(model_uri, custom_weights, load_model_args):
-  model = tf.keras.models.load_model(hub.resolve(model_uri), **load_model_args)
+  try:
+    model = tf.keras.models.load_model(
+        hub.resolve(model_uri), **load_model_args)
+  except Exception as e:
+    raise ValueError(
+        "Unable to load the TensorFlow model: {exception}. Make sure you've "
+        "saved the model with TF2 format. Check out the list of TF2 Models on "
+        "TensorFlow Hub - https://tfhub.dev/s?subtype=module,placeholder&tf-version=tf2."  # pylint: disable=line-too-long
+        .format(exception=e))
   if custom_weights:
     model.load_weights(custom_weights)
   return model
@@ -156,6 +164,7 @@ def load_model(self) -> tf.Module:
           "Callable create_model_fn must be passed"
           "with ModelType.SAVED_WEIGHTS")
       return _load_model_from_weights(self._create_model_fn, self._model_uri)
+
     return _load_model(
         self._model_uri, self._custom_weights, self._load_model_args)
 
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
index 31dde5940102..52c525cc0eaf 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py
@@ -363,6 +363,12 @@ def test_predict_keyed_tensor(self):
     for actual, expected in zip(inferences, expected_predictions):
       self.assertTrue(_compare_tensor_prediction_result(actual[1], expected[1]))
 
+  def test_load_model_exception(self):
+    with self.assertRaises(ValueError):
+      tensorflow_inference._load_model(
+          "https://tfhub.dev/google/imagenet/mobilenet_v1_075_192/quantops/classification/3",  # pylint: disable=line-too-long
+          None, {})
+
 
 @pytest.mark.uses_tf
 class TFRunInferenceTestWithMocks(unittest.TestCase):
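--
A minimal, hypothetical sketch (not part of the patch) of how the new
ValueError surfaces to a pipeline author. It assumes Beam's existing
TFModelHandlerTensor and RunInference APIs and reuses the TF1-format Hub
URL from the test above; treat it as illustrative only.

    import tensorflow as tf
    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor

    # TF1-format Hub module: tf.keras.models.load_model cannot read it, so
    # _load_model now raises the descriptive ValueError when the handler
    # loads the model, instead of an opaque tf.keras / tensorflow_hub error.
    handler = TFModelHandlerTensor(
        model_uri="https://tfhub.dev/google/imagenet/mobilenet_v1_075_192/quantops/classification/3")

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create([[1.0, 2.0, 3.0]])
          | beam.Map(tf.convert_to_tensor)
          | RunInference(handler))

    # Expected failure at model-load time:
    #   ValueError: Unable to load the TensorFlow model: <original exception>.
    #   Make sure you've saved the model with TF2 format. ...

Wrapping the original exception in a ValueError keeps the TF2/TF Hub
guidance next to the underlying error text, so users see both the root
cause and the suggested fix.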