diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 46938ad619d8..5428e8bf4cac 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -38,6 +38,18 @@
     'PytorchModelHandlerKeyedTensor',
 ]
 
+TensorInferenceFn = Callable[
+    [Sequence[torch.Tensor], torch.nn.Module, str, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+KeyedTensorInferenceFn = Callable[[
+    Sequence[Dict[str, torch.Tensor]],
+    torch.nn.Module,
+    str,
+    Optional[Dict[str, Any]]
+],
+                                  Iterable[PredictionResult]]
+
 
 def _load_model(
     model_class: torch.nn.Module, state_dict_path, device, **model_params):
@@ -100,6 +112,46 @@ def _convert_to_result(
   return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
 
+def default_tensor_inference_fn(
+    batch: Sequence[torch.Tensor],
+    model: torch.nn.Module,
+    device: str,
+    inference_args: Optional[Dict[str,
+                                  Any]] = None) -> Iterable[PredictionResult]:
+  # torch.no_grad() mitigates GPU memory issues
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    batched_tensors = torch.stack(batch)
+    batched_tensors = _convert_to_device(batched_tensors, device)
+    predictions = model(batched_tensors, **inference_args)
+    return _convert_to_result(batch, predictions)
+
+
+def make_tensor_model_fn(model_fn: str) -> TensorInferenceFn:
+  """
+  Produces a TensorInferenceFn that uses a method of the model other than
+  the forward() method.
+
+  Args:
+    model_fn: A string name of the method to be used. This is accessed through
+      getattr(model, model_fn)
+  """
+  def attr_fn(
+      batch: Sequence[torch.Tensor],
+      model: torch.nn.Module,
+      device: str,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    with torch.no_grad():
+      batched_tensors = torch.stack(batch)
+      batched_tensors = _convert_to_device(batched_tensors, device)
+      pred_fn = getattr(model, model_fn)
+      predictions = pred_fn(batched_tensors, **inference_args)
+      return _convert_to_result(batch, predictions)
+
+  return attr_fn
+
+
 class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
                                              PredictionResult,
                                              torch.nn.Module]):
@@ -108,7 +160,9 @@ def __init__(
       state_dict_path: str,
       model_class: Callable[..., torch.nn.Module],
       model_params: Dict[str, Any],
-      device: str = 'CPU'):
+      device: str = 'CPU',
+      *,
+      inference_fn: TensorInferenceFn = default_tensor_inference_fn):
     """Implementation of the ModelHandler interface for PyTorch.
 
     Example Usage::
@@ -127,6 +181,8 @@ def __init__(
       device: the device on which you wish to run the model. If
         ``device = GPU`` then a GPU device will be used if it is available.
         Otherwise, it will be CPU.
+      inference_fn: the inference function to use during RunInference.
+        default=default_tensor_inference_fn
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
@@ -140,6 +196,7 @@ def __init__(
       self._device = torch.device('cpu')
     self._model_class = model_class
     self._model_params = model_params
+    self._inference_fn = inference_fn
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
@@ -179,13 +236,7 @@ def run_inference(
     """
     inference_args = {} if not inference_args else inference_args
 
-    # torch.no_grad() mitigates GPU memory issues
-    # https://github.com/apache/beam/issues/22811
-    with torch.no_grad():
-      batched_tensors = torch.stack(batch)
-      batched_tensors = _convert_to_device(batched_tensors, self._device)
-      predictions = model(batched_tensors, **inference_args)
-      return _convert_to_result(batch, predictions)
+    return self._inference_fn(batch, model, self._device, inference_args)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
@@ -205,6 +256,69 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
 
 
+def default_keyed_tensor_inference_fn(
+    batch: Sequence[Dict[str, torch.Tensor]],
+    model: torch.nn.Module,
+    device: str,
+    inference_args: Optional[Dict[str,
+                                  Any]] = None) -> Iterable[PredictionResult]:
+  # If elements in `batch` are provided as dictionaries from key to Tensors,
+  # then iterate through the batch list, and group Tensors to the same key
+  key_to_tensor_list = defaultdict(list)
+
+  # torch.no_grad() mitigates GPU memory issues
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    for example in batch:
+      for key, tensor in example.items():
+        key_to_tensor_list[key].append(tensor)
+    key_to_batched_tensors = {}
+    for key in key_to_tensor_list:
+      batched_tensors = torch.stack(key_to_tensor_list[key])
+      batched_tensors = _convert_to_device(batched_tensors, device)
+      key_to_batched_tensors[key] = batched_tensors
+    predictions = model(**key_to_batched_tensors, **inference_args)
+
+  return _convert_to_result(batch, predictions)
+
+
+def make_keyed_tensor_model_fn(model_fn: str) -> KeyedTensorInferenceFn:
+  """
+  Produces a KeyedTensorInferenceFn that uses a method of the model other than
+  the forward() method.
+
+  Args:
+    model_fn: A string name of the method to be used. This is accessed through
+      getattr(model, model_fn)
+  """
+  def attr_fn(
+      batch: Sequence[Dict[str, torch.Tensor]],
+      model: torch.nn.Module,
+      device: str,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    # If elements in `batch` are provided as dictionaries from key to Tensors,
+    # then iterate through the batch list, and group Tensors to the same key
+    key_to_tensor_list = defaultdict(list)
+
+    # torch.no_grad() mitigates GPU memory issues
+    # https://github.com/apache/beam/issues/22811
+    with torch.no_grad():
+      for example in batch:
+        for key, tensor in example.items():
+          key_to_tensor_list[key].append(tensor)
+      key_to_batched_tensors = {}
+      for key in key_to_tensor_list:
+        batched_tensors = torch.stack(key_to_tensor_list[key])
+        batched_tensors = _convert_to_device(batched_tensors, device)
+        key_to_batched_tensors[key] = batched_tensors
+      pred_fn = getattr(model, model_fn)
+      predictions = pred_fn(**key_to_batched_tensors, **inference_args)
+      return _convert_to_result(batch, predictions)
+
+  return attr_fn
+
+
 @experimental(extra_message="No backwards-compatibility guarantees.")
 class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
                                                   PredictionResult,
@@ -214,7 +328,9 @@ def __init__(
       state_dict_path: str,
       model_class: Callable[..., torch.nn.Module],
      model_params: Dict[str, Any],
-      device: str = 'CPU'):
+      device: str = 'CPU',
+      *,
+      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn):
     """Implementation of the ModelHandler interface for PyTorch.
 
     Example Usage::
@@ -237,6 +353,8 @@ def __init__(
      device: the device on which you wish to run the model. If
         ``device = GPU`` then a GPU device will be used if it is available.
         Otherwise, it will be CPU.
+      inference_fn: the function to invoke on run_inference.
+        default = default_keyed_tensor_inference_fn
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
@@ -250,6 +368,7 @@ def __init__(
       self._device = torch.device('cpu')
     self._model_class = model_class
     self._model_params = model_params
+    self._inference_fn = inference_fn
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
@@ -289,24 +408,7 @@ def run_inference(
     """
     inference_args = {} if not inference_args else inference_args
 
-    # If elements in `batch` are provided as a dictionaries from key to Tensors,
-    # then iterate through the batch list, and group Tensors to the same key
-    key_to_tensor_list = defaultdict(list)
-
-    # torch.no_grad() mitigates GPU memory issues
-    # https://github.com/apache/beam/issues/22811
-    with torch.no_grad():
-      for example in batch:
-        for key, tensor in example.items():
-          key_to_tensor_list[key].append(tensor)
-      key_to_batched_tensors = {}
-      for key in key_to_tensor_list:
-        batched_tensors = torch.stack(key_to_tensor_list[key])
-        batched_tensors = _convert_to_device(batched_tensors, self._device)
-        key_to_batched_tensors[key] = batched_tensors
-      predictions = model(**key_to_batched_tensors, **inference_args)
-
-    return _convert_to_result(batch, predictions)
+    return self._inference_fn(batch, model, self._device, inference_args)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index 32036f43de86..d6d3a2934555 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -37,6 +37,10 @@
   import torch
   from apache_beam.ml.inference.base import PredictionResult
   from apache_beam.ml.inference.base import RunInference
+  from apache_beam.ml.inference.pytorch_inference import default_keyed_tensor_inference_fn
+  from apache_beam.ml.inference.pytorch_inference import default_tensor_inference_fn
+  from apache_beam.ml.inference.pytorch_inference import make_keyed_tensor_model_fn
+  from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn
   from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
   from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
 except ImportError:
@@ -97,6 +101,15 @@
                      for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1))
 ]
 
+KEYED_TORCH_HELPER_PREDICTIONS = [
+    PredictionResult(ex, pred) for ex,
+    pred in zip(
+        KEYED_TORCH_EXAMPLES,
+        torch.Tensor([(example['k1'] * 2.0 + 0.5) +
+                      (example['k2'] * 2.0 + 0.5) + 0.5
+                      for example in KEYED_TORCH_EXAMPLES]).reshape(-1, 1))
+]
+
 KEYED_TORCH_DICT_OUT_PREDICTIONS = [
     PredictionResult(
         p.example, {
@@ -106,14 +119,16 @@
 
 
 class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandlerTensor):
-  def __init__(self, device):
+  def __init__(self, device, *, inference_fn=default_tensor_inference_fn):
     self._device = device
+    self._inference_fn = inference_fn
 
 
 class TestPytorchModelHandlerKeyedTensorForInferenceOnly(
     PytorchModelHandlerKeyedTensor):
-  def __init__(self, device):
+  def __init__(self, device, *, inference_fn=default_keyed_tensor_inference_fn):
     self._device = device
+    self._inference_fn = inference_fn
 
 
 def _compare_prediction_result(x, y):
@@ -134,6 +149,16 @@ def _compare_prediction_result(x, y):
     return torch.equal(x.inference, y.inference)
 
 
+def custom_tensor_inference_fn(batch, model, device, inference_args):
+  predictions = [
+      PredictionResult(ex, pred) for ex,
+      pred in zip(
+          batch,
+          torch.Tensor([item * 2.0 + 1.5 for item in batch]).reshape(-1, 1))
+  ]
+  return predictions
+
+
 class PytorchLinearRegression(torch.nn.Module):
   def __init__(self, input_dim, output_dim):
     super().__init__()
@@ -143,6 +168,10 @@ def forward(self, x):
     out = self.linear(x)
     return out
 
+  def generate(self, x):
+    out = self.linear(x) + 0.5
+    return out
+
 
 class PytorchLinearRegressionDict(torch.nn.Module):
   def __init__(self, input_dim, output_dim):
@@ -231,6 +260,33 @@ def test_run_inference_multiple_tensor_features_dict_output(self):
     for actual, expected in zip(predictions, TWO_FEATURES_DICT_OUT_PREDICTIONS):
       self.assertEqual(actual, expected)
 
+  def test_run_inference_custom(self):
+    examples = [
+        torch.from_numpy(np.array([1], dtype="float32")),
+        torch.from_numpy(np.array([5], dtype="float32")),
+        torch.from_numpy(np.array([-3], dtype="float32")),
+        torch.from_numpy(np.array([10.0], dtype="float32")),
+    ]
+    expected_predictions = [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            examples,
+            torch.Tensor([example * 2.0 + 1.5
+                          for example in examples]).reshape(-1, 1))
+    ]
+
+    model = PytorchLinearRegression(input_dim=1, output_dim=1)
+    model.load_state_dict(
+        OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
+                     ('linear.bias', torch.Tensor([0.5]))]))
+    model.eval()
+
+    inference_runner = TestPytorchModelHandlerForInferenceOnly(
+        torch.device('cpu'), inference_fn=custom_tensor_inference_fn)
+    predictions = inference_runner.run_inference(examples, model)
+    for actual, expected in zip(predictions, expected_predictions):
+      self.assertEqual(actual, expected)
+
   def test_run_inference_keyed(self):
     """
     This tests for inputs that are passed as a dictionary from key to tensor
@@ -315,6 +371,77 @@ def test_inference_runner_inference_args(self):
     for actual, expected in zip(predictions, KEYED_TORCH_PREDICTIONS):
       self.assertEqual(actual, expected)
 
+  def test_run_inference_helper(self):
+    examples = [
+        torch.from_numpy(np.array([1], dtype="float32")),
+        torch.from_numpy(np.array([5], dtype="float32")),
+        torch.from_numpy(np.array([-3], dtype="float32")),
+        torch.from_numpy(np.array([10.0], dtype="float32")),
+    ]
+    expected_predictions = [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            examples,
+            torch.Tensor([example * 2.0 + 1.0
+                          for example in examples]).reshape(-1, 1))
+    ]
+
+    gen_fn = make_tensor_model_fn('generate')
+
+    model = PytorchLinearRegression(input_dim=1, output_dim=1)
+    model.load_state_dict(
+        OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
+                     ('linear.bias', torch.Tensor([0.5]))]))
+    model.eval()
+
+    inference_runner = TestPytorchModelHandlerForInferenceOnly(
+        torch.device('cpu'), inference_fn=gen_fn)
+    predictions = inference_runner.run_inference(examples, model)
+    for actual, expected in zip(predictions, expected_predictions):
+      self.assertEqual(actual, expected)
+
+  def test_run_inference_keyed_helper(self):
+    """
+    This tests for inputs that are passed as a dictionary from key to tensor
+    instead of a standard non-keyed tensor example.
+
+    Example:
+      Typical input format is
+      input = torch.tensor([1, 2, 3])
+
+      But Pytorch syntax allows inputs to have the form
+      input = {
+        'k1' : torch.tensor([1, 2, 3]),
+        'k2' : torch.tensor([4, 5, 6])
+      }
+    """
+    class PytorchLinearRegressionMultipleArgs(torch.nn.Module):
+      def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.linear = torch.nn.Linear(input_dim, output_dim)
+
+      def forward(self, k1, k2):
+        out = self.linear(k1) + self.linear(k2)
+        return out
+
+      def generate(self, k1, k2):
+        out = self.linear(k1) + self.linear(k2) + 0.5
+        return out
+
+    model = PytorchLinearRegressionMultipleArgs(input_dim=1, output_dim=1)
+    model.load_state_dict(
+        OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
+                     ('linear.bias', torch.Tensor([0.5]))]))
+    model.eval()
+
+    gen_fn = make_keyed_tensor_model_fn('generate')
+
+    inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
+        torch.device('cpu'), inference_fn=gen_fn)
+    predictions = inference_runner.run_inference(KEYED_TORCH_EXAMPLES, model)
+    for actual, expected in zip(predictions, KEYED_TORCH_HELPER_PREDICTIONS):
+      self.assertTrue(_compare_prediction_result(actual, expected))
+
   def test_num_bytes(self):
     inference_runner = TestPytorchModelHandlerForInferenceOnly(
         torch.device('cpu'))
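Usage note (not part of the diff): a minimal sketch of how a pipeline author could plug the new inference_fn hook into the handler. The toy LinearWithGenerate module and the state-dict path below are placeholders for illustration; PytorchModelHandlerTensor, make_tensor_model_fn, and RunInference are the APIs touched by this change.

import apache_beam as beam
import torch

from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn


class LinearWithGenerate(torch.nn.Module):
  # Toy model for illustration only; any module exposing a generate() method
  # works with make_tensor_model_fn('generate').
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.linear = torch.nn.Linear(input_dim, output_dim)

  def forward(self, x):
    return self.linear(x)

  def generate(self, x):
    return self.linear(x) + 0.5


model_handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/linear.pt',  # placeholder path
    model_class=LinearWithGenerate,
    model_params={'input_dim': 1, 'output_dim': 1},
    # Route RunInference through model.generate() instead of model.forward().
    inference_fn=make_tensor_model_fn('generate'))

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([torch.Tensor([1.0]), torch.Tensor([5.0])])
      | RunInference(model_handler))

A fully custom TensorInferenceFn (any callable taking batch, model, device, and inference_args and returning PredictionResults, like custom_tensor_inference_fn in the tests above) can be passed to inference_fn the same way.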