diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index d8ab31b8b708..959bce4778eb 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -23,7 +23,6 @@
 from typing import Dict
 from typing import Iterable
 from typing import Sequence
-from typing import Union
 
 import torch
 from apache_beam.io.filesystems import FileSystems
@@ -31,14 +30,32 @@
 from apache_beam.ml.inference.base import PredictionResult
 
 
-class PytorchModelHandler(ModelHandler[torch.Tensor,
-                                       PredictionResult,
-                                       torch.nn.Module]):
-  """ Implementation of the ModelHandler interface for PyTorch.
+def _load_model(
+    model_class: torch.nn.Module, state_dict_path, device, **model_params):
+  model = model_class(**model_params)
+  model.to(device)
+  file = FileSystems.open(state_dict_path, 'rb')
+  model.load_state_dict(torch.load(file))
+  model.eval()
+  return model
 
-  NOTE: This API and its implementation are under development and
-  do not provide backward compatibility guarantees.
+
+def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
   """
+  Converts samples to a style matching given device.
+
+  Note: A user may pass in device='GPU' but if GPU is not detected in the
+  environment it must be converted back to CPU.
+  """
+  if examples.device != device:
+    examples = examples.to(device)
+  return examples
+
+
+class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
+                                             PredictionResult,
+                                             torch.nn.Module]):
+  """ Implementation of the ModelHandler interface for PyTorch."""
   def __init__(
       self,
       state_dict_path: str,
@@ -46,7 +63,7 @@ def __init__(
       model_params: Dict[str, Any],
       device: str = 'CPU'):
     """
-    Initializes a PytorchModelHandler
+    Initializes a PytorchModelHandlerTensor
     :param state_dict_path: path to the saved dictionary of the model state.
     :param model_class: class of the Pytorch model that defines the model
       structure.
@@ -67,67 +84,114 @@ def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
-    model = self._model_class(**self._model_params)
-    model.to(self._device)
-    file = FileSystems.open(self._state_dict_path, 'rb')
-    model.load_state_dict(torch.load(file))
-    model.eval()
-    return model
-
-  def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
+    return _load_model(
+        self._model_class,
+        self._state_dict_path,
+        self._device,
+        **self._model_params)
+
+  def run_inference(
+      self, batch: Sequence[torch.Tensor], model: torch.nn.Module,
+      **kwargs) -> Iterable[PredictionResult]:
+    """
+    Runs inferences on a batch of Tensors and returns an Iterable of
+    Tensor Predictions.
+
+    This method stacks the list of Tensors in a vectorized format to optimize
+    the inference call.
+    """
+    prediction_params = kwargs.get('prediction_params', {})
+    batched_tensors = torch.stack(batch)
+    batched_tensors = _convert_to_device(batched_tensors, self._device)
+    predictions = model(batched_tensors, **prediction_params)
+    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+  def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
+    """Returns the number of bytes of data for a batch of Tensors."""
+    return sum((el.element_size() for tensor in batch for el in tensor))
+
+  def get_metrics_namespace(self) -> str:
+    """
+    Returns a namespace for metrics collected by the RunInference transform.
+    """
+    return 'RunInferencePytorch'
+
+
+class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
+                                                  PredictionResult,
+                                                  torch.nn.Module]):
+  """ Implementation of the ModelHandler interface for PyTorch.
+
+  NOTE: This API and its implementation are under development and
+  do not provide backward compatibility guarantees.
+  """
+  def __init__(
+      self,
+      state_dict_path: str,
+      model_class: Callable[..., torch.nn.Module],
+      model_params: Dict[str, Any],
+      device: str = 'CPU'):
     """
-    Converts samples to a style matching given device.
+    Initializes a PytorchModelHandlerKeyedTensor
+    :param state_dict_path: path to the saved dictionary of the model state.
+    :param model_class: class of the Pytorch model that defines the model
+      structure.
+    :param device: the device on which you wish to run the model. If
+      ``device = GPU`` then a GPU device will be used if it is available.
+      Otherwise, it will be CPU.
 
-    Note: A user may pass in device='GPU' but if GPU is not detected in the
-    environment it must be converted back to CPU.
+    See https://pytorch.org/tutorials/beginner/saving_loading_models.html
+    for details
     """
-    if examples.device != self._device:
-      examples = examples.to(self._device)
-    return examples
+    self._state_dict_path = state_dict_path
+    if device == 'GPU' and torch.cuda.is_available():
+      self._device = torch.device('cuda')
+    else:
+      self._device = torch.device('cpu')
+    self._model_class = model_class
+    self._model_params = model_params
+
+  def load_model(self) -> torch.nn.Module:
+    """Loads and initializes a Pytorch model for processing."""
+    return _load_model(
+        self._model_class,
+        self._state_dict_path,
+        self._device,
+        **self._model_params)
 
   def run_inference(
       self,
-      batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+      batch: Sequence[Dict[str, torch.Tensor]],
       model: torch.nn.Module,
       **kwargs) -> Iterable[PredictionResult]:
     """
-    Runs inferences on a batch of Tensors and returns an Iterable of
+    Runs inferences on a batch of Keyed Tensors and returns an Iterable of
     Tensor Predictions.
 
-    This method stacks the list of Tensors in a vectorized format to optimize
-    the inference call.
+    For the same key across all examples, this will stack all Tensors values
+    in a vectorized format to optimize the inference call.
     """
     prediction_params = kwargs.get('prediction_params', {})
 
     # If elements in `batch` are provided as a dictionaries from key to Tensors,
     # then iterate through the batch list, and group Tensors to the same key
-    if isinstance(batch[0], dict):
-      key_to_tensor_list = defaultdict(list)
-      for example in batch:
-        for key, tensor in example.items():
-          key_to_tensor_list[key].append(tensor)
-      key_to_batched_tensors = {}
-      for key in key_to_tensor_list:
-        batched_tensors = torch.stack(key_to_tensor_list[key])
-        batched_tensors = self._convert_to_device(batched_tensors)
-        key_to_batched_tensors[key] = batched_tensors
-      predictions = model(**key_to_batched_tensors, **prediction_params)
-    else:
-      # If elements in `batch` are provided as Tensors, then do a regular stack
-      batched_tensors = torch.stack(batch)
-      batched_tensors = self._convert_to_device(batched_tensors)
-      predictions = model(batched_tensors, **prediction_params)
+    key_to_tensor_list = defaultdict(list)
+    for example in batch:
+      for key, tensor in example.items():
+        key_to_tensor_list[key].append(tensor)
+    key_to_batched_tensors = {}
+    for key in key_to_tensor_list:
+      batched_tensors = torch.stack(key_to_tensor_list[key])
+      batched_tensors = _convert_to_device(batched_tensors, self._device)
+      key_to_batched_tensors[key] = batched_tensors
+    predictions = model(**key_to_batched_tensors, **prediction_params)
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
-    """Returns the number of bytes of data for a batch of Tensors."""
+    """Returns the number of bytes of data for a batch of Dict of Tensors."""
     # If elements in `batch` are provided as a dictionaries from key to Tensors
-    if isinstance(batch[0], dict):
-      return sum(
-          (el.element_size() for tensor in batch for el in tensor.values()))
-    else:
-      # If elements in `batch` are provided as Tensors
-      return sum((el.element_size() for tensor in batch for el in tensor))
+    return sum(
+        (el.element_size() for tensor in batch for el in tensor.values()))
 
   def get_metrics_namespace(self) -> str:
     """
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index ad51a4e77f7b..d852dd72bb74 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -37,7 +37,8 @@
   import torch
   from apache_beam.ml.inference.base import PredictionResult
   from apache_beam.ml.inference.base import RunInference
-  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
+  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
+  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
 except ImportError:
   raise unittest.SkipTest('PyTorch dependencies are not installed')
 
@@ -90,7 +91,13 @@
 ]
 
 
-class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandler):
+class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandlerTensor):
+  def __init__(self, device):
+    self._device = device
+
+
+class TestPytorchModelHandlerKeyedTensorForInferenceOnly(
+    PytorchModelHandlerKeyedTensor):
   def __init__(self, device):
     self._device = device
 
@@ -209,7 +216,7 @@ def forward(self, k1, k2):
                      ('linear.bias', torch.Tensor([0.5]))]))
     model.eval()
 
-    inference_runner = TestPytorchModelHandlerForInferenceOnly(
+    inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
         torch.device('cpu'))
     predictions = inference_runner.run_inference(KWARGS_TORCH_EXAMPLES, model)
     for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS):
@@ -234,7 +241,7 @@ def test_run_inference_kwargs_prediction_params(self):
                      ('linear.bias', torch.Tensor([0.5]))]))
     model.eval()
 
-    inference_runner = TestPytorchModelHandlerForInferenceOnly(
+    inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
        torch.device('cpu'))
     predictions = inference_runner.run_inference(
         batch=KWARGS_TORCH_EXAMPLES,
@@ -274,7 +281,7 @@ def test_pipeline_local_model_simple(self):
     path = os.path.join(self.tmpdir, 'my_state_dict_path')
     torch.save(state_dict, path)
 
-    model_handler = PytorchModelHandler(
+    model_handler = PytorchModelHandlerTensor(
         state_dict_path=path,
         model_class=PytorchLinearRegression,
         model_params={
@@ -301,7 +308,7 @@ def test_pipeline_local_model_kwargs_prediction_params(self):
     path = os.path.join(self.tmpdir, 'my_state_dict_path')
     torch.save(state_dict, path)
 
-    model_handler = PytorchModelHandler(
+    model_handler = PytorchModelHandlerKeyedTensor(
         state_dict_path=path,
         model_class=PytorchLinearRegressionKwargsPredictionParams,
         model_params={
@@ -334,7 +341,7 @@ def test_pipeline_gcs_model(self):
     gs_pth = 'gs://apache-beam-ml/models/' \
         'pytorch_lin_reg_model_2x+0.5_state_dict.pth'
 
-    model_handler = PytorchModelHandler(
+    model_handler = PytorchModelHandlerTensor(
         state_dict_path=gs_pth,
         model_class=PytorchLinearRegression,
         model_params={
@@ -357,7 +364,7 @@ def test_invalid_input_type(self):
     path = os.path.join(self.tmpdir, 'my_state_dict_path')
     torch.save(state_dict, path)
 
-    model_handler = PytorchModelHandler(
+    model_handler = PytorchModelHandlerTensor(
         state_dict_path=path,
         model_class=PytorchLinearRegression,
         model_params={
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 2c63de25f992..ecd81d204d6d 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -225,9 +225,9 @@ def test_bad_input_type_raises(self):
     with self.assertRaisesRegex(AssertionError,
                                 'Unsupported serialization type'):
       with tempfile.NamedTemporaryFile() as file:
-        model_loader = SklearnModelHandlerNumpy(
+        model_handler = SklearnModelHandlerNumpy(
            model_uri=file.name, model_file_type=None)
-        model_loader.load_model()
+        model_handler.load_model()
 
   @unittest.skipIf(platform.system() == 'Windows', 'BEAM-14359')
   def test_pipeline_pandas(self):
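For context, the sketch below shows how the PytorchModelHandlerTensor introduced by this diff is meant to be wired into a pipeline, mirroring test_pipeline_local_model_simple above. The PytorchLinearRegression module and its input_dim/output_dim parameters are illustrative assumptions (the diff truncates the model_params argument); only the handler and RunInference usage come from the patch itself.

# Sketch only: usage of the split handlers, modeled on the tests in this diff.
# PytorchLinearRegression and its constructor arguments are assumed here for
# illustration; any torch.nn.Module whose state dict was saved to
# state_dict_path would work the same way.
import torch
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor


class PytorchLinearRegression(torch.nn.Module):
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.linear = torch.nn.Linear(input_dim, output_dim)

  def forward(self, x):
    return self.linear(x)


# A state dict previously written with torch.save(model.state_dict(), path).
state_dict_path = '/tmp/my_state_dict_path'

# PytorchModelHandlerTensor expects each element to be a torch.Tensor;
# PytorchModelHandlerKeyedTensor (also added in this diff) expects
# Dict[str, torch.Tensor] elements instead.
model_handler = PytorchModelHandlerTensor(
    state_dict_path=state_dict_path,
    model_class=PytorchLinearRegression,
    model_params={'input_dim': 1, 'output_dim': 1})

examples = [torch.Tensor([1.0]), torch.Tensor([2.0]), torch.Tensor([3.0])]

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(examples)
      | RunInference(model_handler)  # yields PredictionResult(example, inference)
      | beam.Map(print))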