From 27dbd4f3eae5a9e008b607cd49b4deef1e697d4d Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Fri, 10 Jun 2022 17:10:59 -0400 Subject: [PATCH 1/2] Start to split of Pytorch handlers --- .../ml/inference/pytorch_inference.py | 183 ++++++++++++++++++ .../ml/inference/pytorch_inference_test.py | 12 +- .../ml/inference/sklearn_inference_test.py | 4 +- 3 files changed, 192 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 3a4fb2926f81..4ba1f71ad496 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -29,6 +29,7 @@ from apache_beam.io.filesystems import FileSystems from apache_beam.ml.inference.api import PredictionResult from apache_beam.ml.inference.base import ModelHandler +from apache_beam.utils.annotations import experimental class PytorchModelHandler(ModelHandler[torch.Tensor, @@ -134,3 +135,185 @@ def get_metrics_namespace(self) -> str: Returns a namespace for metrics collected by the RunInference transform. """ return 'RunInferencePytorch' + + +@experimental() +class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, + PredictionResult, + torch.nn.Module]): + """ Implementation of the ModelHandler interface for PyTorch. + + NOTE: This API and its implementation are under development and + do not provide backward compatibility guarantees. + """ + def __init__( + self, + state_dict_path: str, + model_class: Callable[..., torch.nn.Module], + model_params: Dict[str, Any], + device: str = 'CPU'): + """ + Initializes a PytorchModelHandler + :param state_dict_path: path to the saved dictionary of the model state. + :param model_class: class of the Pytorch model that defines the model + structure. + :param device: the device on which you wish to run the model. If + ``device = GPU`` then a GPU device will be used if it is available. + Otherwise, it will be CPU. + + See https://pytorch.org/tutorials/beginner/saving_loading_models.html + for details + """ + self._state_dict_path = state_dict_path + if device == 'GPU' and torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self._model_class = model_class + self._model_params = model_params + + def load_model(self) -> torch.nn.Module: + """Loads and initializes a Pytorch model for processing.""" + model = self._model_class(**self._model_params) + model.to(self._device) + file = FileSystems.open(self._state_dict_path, 'rb') + model.load_state_dict(torch.load(file)) + model.eval() + return model + + def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor: + """ + Converts samples to a style matching given device. + + Note: A user may pass in device='GPU' but if GPU is not detected in the + environment it must be converted back to CPU. + """ + if examples.device != self._device: + examples = examples.to(self._device) + return examples + + def run_inference( + self, batch: List[torch.Tensor], model: torch.nn.Module, + **kwargs) -> Iterable[PredictionResult]: + """ + Runs inferences on a batch of Tensors and returns an Iterable of + Tensor Predictions. + + This method stacks the list of Tensors in a vectorized format to optimize + the inference call. 
+ """ + prediction_params = kwargs.get('prediction_params', {}) + batched_tensors = torch.stack(batch) + batched_tensors = self._convert_to_device(batched_tensors) + predictions = model(batched_tensors, **prediction_params) + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + + def get_num_bytes(self, batch: List[torch.Tensor]) -> int: + """Returns the number of bytes of data for a batch of Tensors.""" + return sum((el.element_size() for tensor in batch for el in tensor)) + + def get_metrics_namespace(self) -> str: + """ + Returns a namespace for metrics collected by the RunInference transform. + """ + return 'RunInferencePytorch' + + +@experimental() +class PytorchModelHandlerKeyedTensor(ModelHandler[torch.Tensor, + PredictionResult, + torch.nn.Module]): + """ Implementation of the ModelHandler interface for PyTorch. + + NOTE: This API and its implementation are under development and + do not provide backward compatibility guarantees. + """ + def __init__( + self, + state_dict_path: str, + model_class: Callable[..., torch.nn.Module], + model_params: Dict[str, Any], + device: str = 'CPU'): + """ + Initializes a PytorchModelHandler + :param state_dict_path: path to the saved dictionary of the model state. + :param model_class: class of the Pytorch model that defines the model + structure. + :param device: the device on which you wish to run the model. If + ``device = GPU`` then a GPU device will be used if it is available. + Otherwise, it will be CPU. + + See https://pytorch.org/tutorials/beginner/saving_loading_models.html + for details + """ + self._state_dict_path = state_dict_path + if device == 'GPU' and torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self._model_class = model_class + self._model_params = model_params + + def load_model(self) -> torch.nn.Module: + """Loads and initializes a Pytorch model for processing.""" + model = self._model_class(**self._model_params) + model.to(self._device) + file = FileSystems.open(self._state_dict_path, 'rb') + model.load_state_dict(torch.load(file)) + model.eval() + return model + + def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor: + """ + Converts samples to a style matching given device. + + Note: A user may pass in device='GPU' but if GPU is not detected in the + environment it must be converted back to CPU. + """ + if examples.device != self._device: + examples = examples.to(self._device) + return examples + + def run_inference( + self, + batch: List[Dict[str, torch.Tensor]], + model: torch.nn.Module, + **kwargs) -> Iterable[PredictionResult]: + """ + Runs inferences on a batch of Tensors and returns an Iterable of + Tensor Predictions. + + This method stacks the list of Tensors in a vectorized format to optimize + the inference call. 
+ """ + prediction_params = kwargs.get('prediction_params', {}) + + # If elements in `batch` are provided as a dictionaries from key to Tensors, + # then iterate through the batch list, and group Tensors to the same key + key_to_tensor_list = defaultdict(list) + for example in batch: + for key, tensor in example.items(): + key_to_tensor_list[key].append(tensor) + key_to_batched_tensors = {} + for key in key_to_tensor_list: + batched_tensors = torch.stack(key_to_tensor_list[key]) + batched_tensors = self._convert_to_device(batched_tensors) + key_to_batched_tensors[key] = batched_tensors + predictions = model(**key_to_batched_tensors, **prediction_params) + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + + def get_num_bytes(self, batch: List[torch.Tensor]) -> int: + """Returns the number of bytes of data for a batch of Tensors.""" + # If elements in `batch` are provided as a dictionaries from key to Tensors + if isinstance(batch[0], dict): + return sum( + (el.element_size() for tensor in batch for el in tensor.values())) + else: + # If elements in `batch` are provided as Tensors + return sum((el.element_size() for tensor in batch for el in tensor)) + + def get_metrics_namespace(self) -> str: + """ + Returns a namespace for metrics collected by the RunInference transform. + """ + return 'RunInferencePytorch' diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index 7f563d7cf4c4..c4401c5a9490 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -38,6 +38,8 @@ from apache_beam.ml.inference.api import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler + from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor + from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor except ImportError: raise unittest.SkipTest('PyTorch dependencies are not installed') @@ -274,7 +276,7 @@ def test_pipeline_local_model_simple(self): path = os.path.join(self.tmpdir, 'my_state_dict_path') torch.save(state_dict, path) - model_loader = PytorchModelHandler( + model_loader = PytorchModelHandlerTensor( state_dict_path=path, model_class=PytorchLinearRegression, model_params={ @@ -301,7 +303,7 @@ def test_pipeline_local_model_kwargs_prediction_params(self): path = os.path.join(self.tmpdir, 'my_state_dict_path') torch.save(state_dict, path) - model_loader = PytorchModelHandler( + model_loader = PytorchModelHandlerKeyedTensor( state_dict_path=path, model_class=PytorchLinearRegressionKwargsPredictionParams, model_params={ @@ -312,7 +314,7 @@ def test_pipeline_local_model_kwargs_prediction_params(self): prediction_params_side_input = ( pipeline | 'create side' >> beam.Create(prediction_params)) predictions = pcoll | RunInference( - model_loader=model_loader, + model_handler=model_loader, prediction_params=beam.pvalue.AsDict(prediction_params_side_input)) assert_that( predictions, @@ -334,7 +336,7 @@ def test_pipeline_gcs_model(self): gs_pth = 'gs://apache-beam-ml/models/' \ 'pytorch_lin_reg_model_2x+0.5_state_dict.pth' - model_loader = PytorchModelHandler( + model_loader = PytorchModelHandlerTensor( state_dict_path=gs_pth, model_class=PytorchLinearRegression, model_params={ @@ -357,7 +359,7 @@ def test_invalid_input_type(self): path = os.path.join(self.tmpdir, 
'my_state_dict_path') torch.save(state_dict, path) - model_loader = PytorchModelHandler( + model_loader = PytorchModelHandlerTensor( state_dict_path=path, model_class=PytorchLinearRegression, model_params={ diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index 91eb86e2de4b..828cfb69a53f 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -226,9 +226,9 @@ def test_bad_input_type_raises(self): with self.assertRaisesRegex(AssertionError, 'Unsupported serialization type'): with tempfile.NamedTemporaryFile() as file: - model_loader = SklearnModelHandler( + model_handler = SklearnModelHandler( model_uri=file.name, model_file_type=None) - model_loader.load_model() + model_handler.load_model() @unittest.skipIf(platform.system() == 'Windows', 'BEAM-14359') def test_pipeline_pandas(self): From 718669e9f75f620bded24db35edf7138620b0b81 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Mon, 13 Jun 2022 07:41:44 -0400 Subject: [PATCH 2/2] Remove old PytorchModelHandler --- .../ml/inference/pytorch_inference.py | 197 ++++-------------- .../ml/inference/pytorch_inference_test.py | 13 +- 2 files changed, 48 insertions(+), 162 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 4ba1f71ad496..c889d093be90 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -23,129 +23,39 @@ from typing import Dict from typing import Iterable from typing import List -from typing import Union import torch from apache_beam.io.filesystems import FileSystems from apache_beam.ml.inference.api import PredictionResult from apache_beam.ml.inference.base import ModelHandler -from apache_beam.utils.annotations import experimental -class PytorchModelHandler(ModelHandler[torch.Tensor, - PredictionResult, - torch.nn.Module]): - """ Implementation of the ModelHandler interface for PyTorch. - - NOTE: This API and its implementation are under development and - do not provide backward compatibility guarantees. - """ - def __init__( - self, - state_dict_path: str, - model_class: Callable[..., torch.nn.Module], - model_params: Dict[str, Any], - device: str = 'CPU'): - """ - Initializes a PytorchModelHandler - :param state_dict_path: path to the saved dictionary of the model state. - :param model_class: class of the Pytorch model that defines the model - structure. - :param device: the device on which you wish to run the model. If - ``device = GPU`` then a GPU device will be used if it is available. - Otherwise, it will be CPU. 
- - See https://pytorch.org/tutorials/beginner/saving_loading_models.html - for details - """ - self._state_dict_path = state_dict_path - if device == 'GPU' and torch.cuda.is_available(): - self._device = torch.device('cuda') - else: - self._device = torch.device('cpu') - self._model_class = model_class - self._model_params = model_params - - def load_model(self) -> torch.nn.Module: - """Loads and initializes a Pytorch model for processing.""" - model = self._model_class(**self._model_params) - model.to(self._device) - file = FileSystems.open(self._state_dict_path, 'rb') - model.load_state_dict(torch.load(file)) - model.eval() - return model - - def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor: - """ - Converts samples to a style matching given device. - - Note: A user may pass in device='GPU' but if GPU is not detected in the - environment it must be converted back to CPU. - """ - if examples.device != self._device: - examples = examples.to(self._device) - return examples - - def run_inference( - self, - batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]], - model: torch.nn.Module, - **kwargs) -> Iterable[PredictionResult]: - """ - Runs inferences on a batch of Tensors and returns an Iterable of - Tensor Predictions. +def _load_model( + model_class: torch.nn.Module, state_dict_path, device, **model_params): + model = model_class(**model_params) + model.to(device) + file = FileSystems.open(state_dict_path, 'rb') + model.load_state_dict(torch.load(file)) + model.eval() + return model - This method stacks the list of Tensors in a vectorized format to optimize - the inference call. - """ - prediction_params = kwargs.get('prediction_params', {}) - # If elements in `batch` are provided as a dictionaries from key to Tensors, - # then iterate through the batch list, and group Tensors to the same key - if isinstance(batch[0], dict): - key_to_tensor_list = defaultdict(list) - for example in batch: - for key, tensor in example.items(): - key_to_tensor_list[key].append(tensor) - key_to_batched_tensors = {} - for key in key_to_tensor_list: - batched_tensors = torch.stack(key_to_tensor_list[key]) - batched_tensors = self._convert_to_device(batched_tensors) - key_to_batched_tensors[key] = batched_tensors - predictions = model(**key_to_batched_tensors, **prediction_params) - else: - # If elements in `batch` are provided as Tensors, then do a regular stack - batched_tensors = torch.stack(batch) - batched_tensors = self._convert_to_device(batched_tensors) - predictions = model(batched_tensors, **prediction_params) - return [PredictionResult(x, y) for x, y in zip(batch, predictions)] - - def get_num_bytes(self, batch: List[torch.Tensor]) -> int: - """Returns the number of bytes of data for a batch of Tensors.""" - # If elements in `batch` are provided as a dictionaries from key to Tensors - if isinstance(batch[0], dict): - return sum( - (el.element_size() for tensor in batch for el in tensor.values())) - else: - # If elements in `batch` are provided as Tensors - return sum((el.element_size() for tensor in batch for el in tensor)) +def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor: + """ + Converts samples to a style matching given device. - def get_metrics_namespace(self) -> str: - """ - Returns a namespace for metrics collected by the RunInference transform. - """ - return 'RunInferencePytorch' + Note: A user may pass in device='GPU' but if GPU is not detected in the + environment it must be converted back to CPU. 
+ """ + if examples.device != device: + examples = examples.to(device) + return examples -@experimental() class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, PredictionResult, torch.nn.Module]): - """ Implementation of the ModelHandler interface for PyTorch. - - NOTE: This API and its implementation are under development and - do not provide backward compatibility guarantees. - """ + """ Implementation of the ModelHandler interface for PyTorch.""" def __init__( self, state_dict_path: str, @@ -153,7 +63,7 @@ def __init__( model_params: Dict[str, Any], device: str = 'CPU'): """ - Initializes a PytorchModelHandler + Initializes a PytorchModelHandlerTensor :param state_dict_path: path to the saved dictionary of the model state. :param model_class: class of the Pytorch model that defines the model structure. @@ -174,23 +84,11 @@ def __init__( def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" - model = self._model_class(**self._model_params) - model.to(self._device) - file = FileSystems.open(self._state_dict_path, 'rb') - model.load_state_dict(torch.load(file)) - model.eval() - return model - - def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor: - """ - Converts samples to a style matching given device. - - Note: A user may pass in device='GPU' but if GPU is not detected in the - environment it must be converted back to CPU. - """ - if examples.device != self._device: - examples = examples.to(self._device) - return examples + return _load_model( + self._model_class, + self._state_dict_path, + self._device, + **self._model_params) def run_inference( self, batch: List[torch.Tensor], model: torch.nn.Module, @@ -204,7 +102,7 @@ def run_inference( """ prediction_params = kwargs.get('prediction_params', {}) batched_tensors = torch.stack(batch) - batched_tensors = self._convert_to_device(batched_tensors) + batched_tensors = _convert_to_device(batched_tensors, self._device) predictions = model(batched_tensors, **prediction_params) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] @@ -219,8 +117,7 @@ def get_metrics_namespace(self) -> str: return 'RunInferencePytorch' -@experimental() -class PytorchModelHandlerKeyedTensor(ModelHandler[torch.Tensor, +class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor], PredictionResult, torch.nn.Module]): """ Implementation of the ModelHandler interface for PyTorch. @@ -235,7 +132,7 @@ def __init__( model_params: Dict[str, Any], device: str = 'CPU'): """ - Initializes a PytorchModelHandler + Initializes a PytorchModelHandlerKeyedTensor :param state_dict_path: path to the saved dictionary of the model state. :param model_class: class of the Pytorch model that defines the model structure. @@ -256,23 +153,11 @@ def __init__( def load_model(self) -> torch.nn.Module: """Loads and initializes a Pytorch model for processing.""" - model = self._model_class(**self._model_params) - model.to(self._device) - file = FileSystems.open(self._state_dict_path, 'rb') - model.load_state_dict(torch.load(file)) - model.eval() - return model - - def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor: - """ - Converts samples to a style matching given device. - - Note: A user may pass in device='GPU' but if GPU is not detected in the - environment it must be converted back to CPU. 
- """ - if examples.device != self._device: - examples = examples.to(self._device) - return examples + return _load_model( + self._model_class, + self._state_dict_path, + self._device, + **self._model_params) def run_inference( self, @@ -280,11 +165,11 @@ def run_inference( model: torch.nn.Module, **kwargs) -> Iterable[PredictionResult]: """ - Runs inferences on a batch of Tensors and returns an Iterable of + Runs inferences on a batch of Keyed Tensors and returns an Iterable of Tensor Predictions. - This method stacks the list of Tensors in a vectorized format to optimize - the inference call. + For the same key across all examples, this will stack all Tensors values + in a vectorized format to optimize the inference call. """ prediction_params = kwargs.get('prediction_params', {}) @@ -297,20 +182,16 @@ def run_inference( key_to_batched_tensors = {} for key in key_to_tensor_list: batched_tensors = torch.stack(key_to_tensor_list[key]) - batched_tensors = self._convert_to_device(batched_tensors) + batched_tensors = _convert_to_device(batched_tensors, self._device) key_to_batched_tensors[key] = batched_tensors predictions = model(**key_to_batched_tensors, **prediction_params) return [PredictionResult(x, y) for x, y in zip(batch, predictions)] def get_num_bytes(self, batch: List[torch.Tensor]) -> int: - """Returns the number of bytes of data for a batch of Tensors.""" + """Returns the number of bytes of data for a batch of Dict of Tensors.""" # If elements in `batch` are provided as a dictionaries from key to Tensors - if isinstance(batch[0], dict): - return sum( - (el.element_size() for tensor in batch for el in tensor.values())) - else: - # If elements in `batch` are provided as Tensors - return sum((el.element_size() for tensor in batch for el in tensor)) + return sum( + (el.element_size() for tensor in batch for el in tensor.values())) def get_metrics_namespace(self) -> str: """ diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py index c4401c5a9490..10cb4e9c8b80 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py @@ -37,7 +37,6 @@ import torch from apache_beam.ml.inference.api import PredictionResult from apache_beam.ml.inference.base import RunInference - from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor except ImportError: @@ -92,7 +91,13 @@ ] -class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandler): +class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandlerTensor): + def __init__(self, device): + self._device = device + + +class TestPytorchModelHandlerKeyedTensorForInferenceOnly( + PytorchModelHandlerKeyedTensor): def __init__(self, device): self._device = device @@ -211,7 +216,7 @@ def forward(self, k1, k2): ('linear.bias', torch.Tensor([0.5]))])) model.eval() - inference_runner = TestPytorchModelHandlerForInferenceOnly( + inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly( torch.device('cpu')) predictions = inference_runner.run_inference(KWARGS_TORCH_EXAMPLES, model) for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS): @@ -236,7 +241,7 @@ def test_run_inference_kwargs_prediction_params(self): ('linear.bias', torch.Tensor([0.5]))])) model.eval() - inference_runner = 
TestPytorchModelHandlerForInferenceOnly( + inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly( torch.device('cpu')) predictions = inference_runner.run_inference( batch=KWARGS_TORCH_EXAMPLES,
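
Illustrative usage of the split handlers (a minimal sketch, not part of the patch; ExampleLinearModel, its constructor arguments, and the state-dict path below are hypothetical placeholders — only the handler and RunInference wiring mirrors the tests above):

# Hypothetical usage sketch: wiring the new PytorchModelHandlerTensor into a
# pipeline. ExampleLinearModel, its parameters, and the state_dict path are
# placeholders, not part of this patch.
import apache_beam as beam
import torch

from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor


class ExampleLinearModel(torch.nn.Module):
  def __init__(self, input_dim=1, output_dim=1):
    super().__init__()
    self.linear = torch.nn.Linear(input_dim, output_dim)

  def forward(self, x):
    return self.linear(x)


tensor_handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/example_state_dict.pth',  # placeholder
    model_class=ExampleLinearModel,
    model_params={'input_dim': 1, 'output_dim': 1},
    device='CPU')

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([torch.Tensor([1.0]), torch.Tensor([2.0])])
      | RunInference(model_handler=tensor_handler))

Models whose forward() takes keyword tensor arguments would instead use PytorchModelHandlerKeyedTensor, which expects each batch element as a Dict[str, torch.Tensor] and stacks the tensors per key before calling model(**key_to_batched_tensors), as implemented in the second commit.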