diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index d7e0b7395deb..19c6f5b5046f 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -18,10 +18,8 @@ import enum import pickle import sys -from typing import Any from typing import Iterable from typing import Sequence -from typing import Union import numpy import pandas @@ -43,13 +41,26 @@ class ModelFileType(enum.Enum): JOBLIB = 2 -class SklearnModelHandler(ModelHandler[Union[numpy.ndarray, pandas.DataFrame], - PredictionResult, - BaseEstimator]): - """ Implementation of the ModelHandler interface for scikit-learn. +def _load_model(model_uri, file_type): + file = FileSystems.open(model_uri, 'rb') + if file_type == ModelFileType.PICKLE: + return pickle.load(file) + elif file_type == ModelFileType.JOBLIB: + if not joblib: + raise ImportError( + 'Could not import joblib in this execution environment. ' + 'For help with managing dependencies on Python workers.' + 'see https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/' # pylint: disable=line-too-long + ) + return joblib.load(file) + raise AssertionError('Unsupported serialization type.') - NOTE: This API and its implementation are under development and - do not provide backward compatibility guarantees. + +class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray, + PredictionResult, + BaseEstimator]): + """ Implementation of the ModelHandler interface for scikit-learn + using numpy arrays as input. """ def __init__( self, @@ -60,42 +71,44 @@ def __init__( def load_model(self) -> BaseEstimator: """Loads and initializes a model for processing.""" - file = FileSystems.open(self._model_uri, 'rb') - if self._model_file_type == ModelFileType.PICKLE: - return pickle.load(file) - elif self._model_file_type == ModelFileType.JOBLIB: - if not joblib: - raise ImportError( - 'Could not import joblib in this execution environment. ' - 'For help with managing dependencies on Python workers.' - 'see https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/' # pylint: disable=line-too-long - ) - return joblib.load(file) - raise AssertionError('Unsupported serialization type.') + return _load_model(self._model_uri, self._model_file_type) def run_inference( - self, - batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]], - model: BaseEstimator, + self, batch: Sequence[numpy.ndarray], model: BaseEstimator, **kwargs) -> Iterable[PredictionResult]: - # TODO(github.com/apache/beam/issues/21769): Use supplied input type hint. - if isinstance(batch[0], numpy.ndarray): - return SklearnModelHandler._predict_np_array(batch, model) - elif isinstance(batch[0], pandas.DataFrame): - return SklearnModelHandler._predict_pandas_dataframe(batch, model) - raise ValueError('Unsupported data type.') - - @staticmethod - def _predict_np_array(batch: Sequence[numpy.ndarray], - model: Any) -> Iterable[PredictionResult]: # vectorize data for better performance vectorized_batch = numpy.stack(batch, axis=0) predictions = model.predict(vectorized_batch) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] - @staticmethod - def _predict_pandas_dataframe(batch: Sequence[pandas.DataFrame], - model: Any) -> Iterable[PredictionResult]: + def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int: + """Returns the number of bytes of data for a batch.""" + return sum(sys.getsizeof(element) for element in batch) + + +class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame, + PredictionResult, + BaseEstimator]): + """ Implementation of the ModelHandler interface for scikit-learn that + supports pandas dataframes. + + NOTE: This API and its implementation are under development and + do not provide backward compatibility guarantees. + """ + def __init__( + self, + model_uri: str, + model_file_type: ModelFileType = ModelFileType.PICKLE): + self._model_uri = model_uri + self._model_file_type = model_file_type + + def load_model(self) -> BaseEstimator: + """Loads and initializes a model for processing.""" + return _load_model(self._model_uri, self._model_file_type) + + def run_inference( + self, batch: Sequence[pandas.DataFrame], model: BaseEstimator, + **kwargs) -> Iterable[PredictionResult]: # sklearn_inference currently only supports single rowed dataframes. for dataframe in batch: if dataframe.shape[0] != 1: @@ -112,12 +125,6 @@ def _predict_pandas_dataframe(batch: Sequence[pandas.DataFrame], inference in zip(splits, predictions) ] - def get_num_bytes( - self, batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]]) -> int: + def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int: """Returns the number of bytes of data for a batch.""" - if isinstance(batch[0], numpy.ndarray): - return sum(sys.getsizeof(element) for element in batch) - elif isinstance(batch[0], pandas.DataFrame): - data_frames: Sequence[pandas.DataFrame] = batch - return sum(df.memory_usage(deep=True).sum() for df in data_frames) - raise ValueError('Unsupported data type.') + return sum(df.memory_usage(deep=True).sum() for df in batch) diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index 1e788525dc04..2c63de25f992 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -41,7 +41,8 @@ from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.ml.inference.sklearn_inference import ModelFileType -from apache_beam.ml.inference.sklearn_inference import SklearnModelHandler +from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy +from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerPandas from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to @@ -130,7 +131,7 @@ def tearDown(self): def test_predict_output(self): fake_model = FakeModel() - inference_runner = SklearnModelHandler(model_uri='unused') + inference_runner = SklearnModelHandlerNumpy(model_uri='unused') batched_examples = [ numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9]) ] @@ -145,7 +146,7 @@ def test_predict_output(self): def test_data_vectorized(self): fake_model = FakeModel() - inference_runner = SklearnModelHandler(model_uri='unused') + inference_runner = SklearnModelHandlerNumpy(model_uri='unused') batched_examples = [ numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9]) ] @@ -154,8 +155,8 @@ def test_data_vectorized(self): inference_runner.run_inference(batched_examples, fake_model) self.assertEqual(1, fake_model.total_predict_calls) - def test_num_bytes(self): - inference_runner = SklearnModelHandler(model_uri='unused') + def test_num_bytes_numpy(self): + inference_runner = SklearnModelHandlerNumpy(model_uri='unused') batched_examples_int = [ numpy.array([1, 2, 3]), numpy.array([4, 5, 6]), numpy.array([7, 8, 9]) ] @@ -181,9 +182,8 @@ def test_pipeline_pickled(self): examples = [numpy.array([0, 0]), numpy.array([1, 1])] pcoll = pipeline | 'start' >> beam.Create(examples) - #TODO(BEAM-14305) Test against the public API. actual = pcoll | RunInference( - SklearnModelHandler(model_uri=temp_file_name)) + SklearnModelHandlerNumpy(model_uri=temp_file_name)) expected = [ PredictionResult(numpy.array([0, 0]), 0), PredictionResult(numpy.array([1, 1]), 1) @@ -200,10 +200,9 @@ def test_pipeline_joblib(self): examples = [numpy.array([0, 0]), numpy.array([1, 1])] pcoll = pipeline | 'start' >> beam.Create(examples) - #TODO(BEAM-14305) Test against the public API. actual = pcoll | RunInference( - SklearnModelHandler( + SklearnModelHandlerNumpy( model_uri=temp_file_name, model_file_type=ModelFileType.JOBLIB)) expected = [ PredictionResult(numpy.array([0, 0]), 0), @@ -217,9 +216,8 @@ def test_bad_file_raises(self): with TestPipeline() as pipeline: examples = [numpy.array([0, 0])] pcoll = pipeline | 'start' >> beam.Create(examples) - # TODO(BEAM-14305) Test against the public API. _ = pcoll | RunInference( - SklearnModelHandler(model_uri='/var/bad_file_name')) + SklearnModelHandlerNumpy(model_uri='/var/bad_file_name')) pipeline.run() @unittest.skipIf(platform.system() == 'Windows', 'BEAM-14359') @@ -227,7 +225,7 @@ def test_bad_input_type_raises(self): with self.assertRaisesRegex(AssertionError, 'Unsupported serialization type'): with tempfile.NamedTemporaryFile() as file: - model_loader = SklearnModelHandler( + model_loader = SklearnModelHandlerNumpy( model_uri=file.name, model_file_type=None) model_loader.load_model() @@ -241,7 +239,30 @@ def test_pipeline_pandas(self): splits = [dataframe.loc[[i]] for i in dataframe.index] pcoll = pipeline | 'start' >> beam.Create(splits) actual = pcoll | RunInference( - SklearnModelHandler(model_uri=temp_file_name)) + SklearnModelHandlerPandas(model_uri=temp_file_name)) + + expected = [ + PredictionResult(splits[0], 5), + PredictionResult(splits[1], 8), + PredictionResult(splits[2], 1), + PredictionResult(splits[3], 1), + PredictionResult(splits[4], 2), + ] + assert_that( + actual, equal_to(expected, equals_fn=_compare_dataframe_predictions)) + + @unittest.skipIf(platform.system() == 'Windows', 'BEAM-14359') + def test_pipeline_pandas_joblib(self): + temp_file_name = self.tmpdir + os.sep + 'pickled_file' + with open(temp_file_name, 'wb') as file: + joblib.dump(build_pandas_pipeline(), file) + with TestPipeline() as pipeline: + dataframe = pandas_dataframe() + splits = [dataframe.loc[[i]] for i in dataframe.index] + pcoll = pipeline | 'start' >> beam.Create(splits) + actual = pcoll | RunInference( + SklearnModelHandlerPandas( + model_uri=temp_file_name, model_file_type=ModelFileType.JOBLIB)) expected = [ PredictionResult(splits[0], 5), @@ -266,7 +287,8 @@ def test_pipeline_pandas_with_keys(self): pcoll = pipeline | 'start' >> beam.Create(keyed_rows) actual = pcoll | RunInference( - KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name))) + KeyedModelHandler( + SklearnModelHandlerPandas(model_uri=temp_file_name))) expected = [ ('0', PredictionResult(splits[0], 5)), ('1', PredictionResult(splits[1], 8)), @@ -277,18 +299,11 @@ def test_pipeline_pandas_with_keys(self): assert_that( actual, equal_to(expected, equals_fn=_compare_dataframe_predictions)) - def test_infer_invalid_data_type(self): - with self.assertRaises(ValueError): - unexpected_input_type = [[1, 2, 3, 4], [5, 6, 7, 8]] - inference_runner = SklearnModelHandler(model_uri='unused') - fake_model = FakeModel() - inference_runner.run_inference(unexpected_input_type, fake_model) - def test_infer_too_many_rows_in_dataframe(self): with self.assertRaises(ValueError): data_frame_too_many_rows = pandas_dataframe() - inference_runner = SklearnModelHandler(model_uri='unused') fake_model = FakeModel() + inference_runner = SklearnModelHandlerPandas(model_uri='unused') inference_runner.run_inference([data_frame_too_many_rows], fake_model)