Add custom inference function support to the PyTorch model handler #24062
Merged

Changes from all commits (22 commits)
2532890  Initial type def and function signature (jrmccluskey)
4ed9ca5  [Draft] Add custom inference fn support to Pytorch Model Handler (jrmccluskey)
cbe4297  Formatting (jrmccluskey)
a8d5150  Split out default (jrmccluskey)
a495835  Remove Keyed version for testing (jrmccluskey)
4e35ebd  Move device optimization (jrmccluskey)
0c4e292  Make default available for import, add to test classes (jrmccluskey)
5143a23  Remove incorrect default from keyed test (jrmccluskey)
05b1fcc  Keyed impl (jrmccluskey)
f63820b  Fix device arg (jrmccluskey)
a926725  custom inference test (jrmccluskey)
253d887  formatting (jrmccluskey)
b977863  Add helpers to define custom inference functions using model methods (jrmccluskey)
d992873  Trailing whitespace (jrmccluskey)
f95042d  Unit tests (jrmccluskey)
fb9fb9a  Fix incorrect getattr syntax (jrmccluskey)
047e873  Type typo (jrmccluskey)
2d178ed  Fix docstring (jrmccluskey)
be50381  Fix keyed helper, add basic generate route (jrmccluskey)
84249cc  Modify generate() to be different than forward() (jrmccluskey)
f3ce521  formatting (jrmccluskey)
dae57c5  Remove extra generate() def (jrmccluskey)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,6 +38,18 @@
     'PytorchModelHandlerKeyedTensor',
 ]
 
+TensorInferenceFn = Callable[
+    [Sequence[torch.Tensor], torch.nn.Module, str, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+KeyedTensorInferenceFn = Callable[[
+    Sequence[Dict[str, torch.Tensor]],
+    torch.nn.Module,
+    str,
+    Optional[Dict[str, Any]]
+],
+    Iterable[PredictionResult]]
+
 
 def _load_model(
     model_class: torch.nn.Module, state_dict_path, device, **model_params):
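For orientation, a user-supplied function that conforms to the new `TensorInferenceFn` alias might look like the sketch below. The function name, the softmax post-processing, and the direct construction of `PredictionResult` objects are illustrative choices rather than part of this change; note also that the handler passes its `torch.device` as the third argument even though the alias annotates it as `str`.

```python
from typing import Any, Dict, Iterable, Optional, Sequence

import torch
from apache_beam.ml.inference.base import PredictionResult


def softmax_tensor_inference_fn(
    batch: Sequence[torch.Tensor],
    model: torch.nn.Module,
    device: str,
    inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
  """Illustrative custom inference fn: attaches softmax scores to each input."""
  inference_args = inference_args or {}
  with torch.no_grad():
    # Stack the batch into one tensor and move it to the handler's device.
    batched_tensors = torch.stack(batch).to(device)
    scores = torch.nn.functional.softmax(
        model(batched_tensors, **inference_args), dim=-1)
  return [PredictionResult(x, y) for x, y in zip(batch, scores)]
```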
@@ -100,6 +112,46 @@ def _convert_to_result(
   return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
 
+def default_tensor_inference_fn(
+    batch: Sequence[torch.Tensor],
+    model: torch.nn.Module,
+    device: str,
+    inference_args: Optional[Dict[str,
+                                  Any]] = None) -> Iterable[PredictionResult]:
+  # torch.no_grad() mitigates GPU memory issues
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    batched_tensors = torch.stack(batch)
+    batched_tensors = _convert_to_device(batched_tensors, device)
+    predictions = model(batched_tensors, **inference_args)
+    return _convert_to_result(batch, predictions)
+
+
+def make_tensor_model_fn(model_fn: str) -> TensorInferenceFn:
+  """
+  Produces a TensorInferenceFn that uses a method of the model other than
+  the forward() method.
+
+  Args:
+    model_fn: A string name of the method to be used. This is accessed through
+      getattr(model, model_fn)
+  """
+  def attr_fn(
+      batch: Sequence[torch.Tensor],
+      model: torch.nn.Module,
+      device: str,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    with torch.no_grad():
+      batched_tensors = torch.stack(batch)
+      batched_tensors = _convert_to_device(batched_tensors, device)
+      pred_fn = getattr(model, model_fn)
+      predictions = pred_fn(batched_tensors, **inference_args)
+      return _convert_to_result(batch, predictions)
+
+  return attr_fn
+
+
 class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
                                              PredictionResult,
                                              torch.nn.Module]):

Contributor review comment on make_tensor_model_fn: "Could we add tests for these?"
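As a usage sketch for the helper above: a pipeline author can route inference through a method other than `forward()` by passing the result of `make_tensor_model_fn` as the new keyword-only `inference_fn` argument. The toy model class and the state_dict path below are hypothetical placeholders.

```python
import torch
from apache_beam.ml.inference.pytorch_inference import (
    PytorchModelHandlerTensor, make_tensor_model_fn)


class ToyGenerator(torch.nn.Module):
  """Hypothetical model whose generate() differs from forward()."""
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(10, 10)

  def forward(self, x):
    return self.linear(x)

  def generate(self, x):
    # A toy "generate" route: argmax over the linear layer's output.
    return torch.argmax(self.linear(x), dim=-1)


# The state_dict path is a placeholder; point it at a real saved state dict.
handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/toy_generator.pth',
    model_class=ToyGenerator,
    model_params={},
    inference_fn=make_tensor_model_fn('generate'))
```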
@@ -108,7 +160,9 @@ def __init__(
       state_dict_path: str,
       model_class: Callable[..., torch.nn.Module],
       model_params: Dict[str, Any],
-      device: str = 'CPU'):
+      device: str = 'CPU',
+      *,
+      inference_fn: TensorInferenceFn = default_tensor_inference_fn):
     """Implementation of the ModelHandler interface for PyTorch.
 
     Example Usage::
@@ -127,6 +181,8 @@ def __init__( | |
| device: the device on which you wish to run the model. If | ||
| ``device = GPU`` then a GPU device will be used if it is available. | ||
| Otherwise, it will be CPU. | ||
| inference_fn: the inference function to use during RunInference. | ||
| default=_default_tensor_inference_fn | ||
|
|
||
| **Supported Versions:** RunInference APIs in Apache Beam have been tested | ||
| with PyTorch 1.9 and 1.10. | ||
|
|
@@ -140,6 +196,7 @@ def __init__(
       self._device = torch.device('cpu')
     self._model_class = model_class
     self._model_params = model_params
+    self._inference_fn = inference_fn
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
@@ -179,13 +236,7 @@ def run_inference(
     """
     inference_args = {} if not inference_args else inference_args
 
-    # torch.no_grad() mitigates GPU memory issues
-    # https://github.com/apache/beam/issues/22811
-    with torch.no_grad():
-      batched_tensors = torch.stack(batch)
-      batched_tensors = _convert_to_device(batched_tensors, self._device)
-      predictions = model(batched_tensors, **inference_args)
-      return _convert_to_result(batch, predictions)
+    return self._inference_fn(batch, model, self._device, inference_args)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
@@ -205,6 +256,69 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
 
 
+def default_keyed_tensor_inference_fn(
+    batch: Sequence[Dict[str, torch.Tensor]],
+    model: torch.nn.Module,
+    device: str,
+    inference_args: Optional[Dict[str,
+                                  Any]] = None) -> Iterable[PredictionResult]:
+  # If elements in `batch` are provided as dictionaries from key to Tensors,
+  # then iterate through the batch list, and group Tensors to the same key
+  key_to_tensor_list = defaultdict(list)
+
+  # torch.no_grad() mitigates GPU memory issues
+  # https://github.com/apache/beam/issues/22811
+  with torch.no_grad():
+    for example in batch:
+      for key, tensor in example.items():
+        key_to_tensor_list[key].append(tensor)
+    key_to_batched_tensors = {}
+    for key in key_to_tensor_list:
+      batched_tensors = torch.stack(key_to_tensor_list[key])
+      batched_tensors = _convert_to_device(batched_tensors, device)
+      key_to_batched_tensors[key] = batched_tensors
+    predictions = model(**key_to_batched_tensors, **inference_args)
+
+    return _convert_to_result(batch, predictions)
+
+
+def make_keyed_tensor_model_fn(model_fn: str) -> KeyedTensorInferenceFn:
+  """
+  Produces a KeyedTensorInferenceFn that uses a method of the model other than
+  the forward() method.
+
+  Args:
+    model_fn: A string name of the method to be used. This is accessed through
+      getattr(model, model_fn)
+  """
+  def attr_fn(
+      batch: Sequence[torch.Tensor],
+      model: torch.nn.Module,
+      device: str,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    # If elements in `batch` are provided as dictionaries from key to Tensors,
+    # then iterate through the batch list, and group Tensors to the same key
+    key_to_tensor_list = defaultdict(list)
+
+    # torch.no_grad() mitigates GPU memory issues
+    # https://github.com/apache/beam/issues/22811
+    with torch.no_grad():
+      for example in batch:
+        for key, tensor in example.items():
+          key_to_tensor_list[key].append(tensor)
+      key_to_batched_tensors = {}
+      for key in key_to_tensor_list:
+        batched_tensors = torch.stack(key_to_tensor_list[key])
+        batched_tensors = _convert_to_device(batched_tensors, device)
+        key_to_batched_tensors[key] = batched_tensors
+      pred_fn = getattr(model, model_fn)
+      predictions = pred_fn(**key_to_batched_tensors, **inference_args)
+      return _convert_to_result(batch, predictions)
+
+  return attr_fn
+
+
 @experimental(extra_message="No backwards-compatibility guarantees.")
 class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
                                                   PredictionResult,
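A parallel sketch for the keyed handler, assuming a hypothetical model whose `generate()` accepts the same keyword tensors as `forward()` (here `input_ids` and `attention_mask`); the class and the state_dict path are placeholders.

```python
import torch
from apache_beam.ml.inference.pytorch_inference import (
    PytorchModelHandlerKeyedTensor, make_keyed_tensor_model_fn)


class ToyKeyedModel(torch.nn.Module):
  """Hypothetical model taking keyword tensors, with a separate generate()."""
  def __init__(self):
    super().__init__()
    self.embedding = torch.nn.Embedding(100, 16)
    self.linear = torch.nn.Linear(16, 100)

  def forward(self, input_ids, attention_mask):
    hidden = self.embedding(input_ids) * attention_mask.unsqueeze(-1)
    return self.linear(hidden)

  def generate(self, input_ids, attention_mask):
    return torch.argmax(self.forward(input_ids, attention_mask), dim=-1)


# The state_dict path is a placeholder; point it at a real saved state dict.
keyed_handler = PytorchModelHandlerKeyedTensor(
    state_dict_path='gs://my-bucket/toy_keyed_model.pth',
    model_class=ToyKeyedModel,
    model_params={},
    inference_fn=make_keyed_tensor_model_fn('generate'))
```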
@@ -214,7 +328,9 @@ def __init__(
       state_dict_path: str,
       model_class: Callable[..., torch.nn.Module],
       model_params: Dict[str, Any],
-      device: str = 'CPU'):
+      device: str = 'CPU',
+      *,
+      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn):
     """Implementation of the ModelHandler interface for PyTorch.
 
     Example Usage::
@@ -237,6 +353,8 @@ def __init__(
       device: the device on which you wish to run the model. If
         ``device = GPU`` then a GPU device will be used if it is available.
         Otherwise, it will be CPU.
+      inference_fn: the function to invoke on run_inference.
+        default = default_keyed_tensor_inference_fn
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
@@ -250,6 +368,7 @@ def __init__(
       self._device = torch.device('cpu')
     self._model_class = model_class
     self._model_params = model_params
+    self._inference_fn = inference_fn
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
@@ -289,24 +408,7 @@ def run_inference(
     """
     inference_args = {} if not inference_args else inference_args
 
-    # If elements in `batch` are provided as a dictionaries from key to Tensors,
-    # then iterate through the batch list, and group Tensors to the same key
-    key_to_tensor_list = defaultdict(list)
-
-    # torch.no_grad() mitigates GPU memory issues
-    # https://github.com/apache/beam/issues/22811
-    with torch.no_grad():
-      for example in batch:
-        for key, tensor in example.items():
-          key_to_tensor_list[key].append(tensor)
-      key_to_batched_tensors = {}
-      for key in key_to_tensor_list:
-        batched_tensors = torch.stack(key_to_tensor_list[key])
-        batched_tensors = _convert_to_device(batched_tensors, self._device)
-        key_to_batched_tensors[key] = batched_tensors
-      predictions = model(**key_to_batched_tensors, **inference_args)
-
-      return _convert_to_result(batch, predictions)
+    return self._inference_fn(batch, model, self._device, inference_args)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
Looking at this, my only concern is that doing the simple thing (calling `model.generate(...)` instead of `model(...)`, or something similar) is now harder to do performantly. I wonder if there's a way to still make that easy. One option would be to add something like a convenience function `make_tensor_inference_override_fn(model_function: string)` that generates a lambda. The user code would then look like `my_tensor_inference_fn = make_tensor_inference_override_fn('generate')`, which would set `my_tensor_inference_fn` to an anonymous function that calls `model.generate(...)` instead of `model(...)`. Does that make sense/sound reasonable?
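A minimal sketch of the kind of generated function being described, assuming the module-level `_convert_to_device` and `_convert_to_result` helpers plus the typing imports shown in the diff above; the factory name comes from the comment, and the body is a guess that closely mirrors the `attr_fn` closure the PR ultimately added:

```python
def make_tensor_inference_override_fn(
    model_function: str) -> TensorInferenceFn:
  """Sketch only: returns a TensorInferenceFn that calls a named model method."""
  def override_fn(
      batch: Sequence[torch.Tensor],
      model: torch.nn.Module,
      device: str,
      inference_args: Optional[Dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    with torch.no_grad():
      batched_tensors = _convert_to_device(torch.stack(batch), device)
      # Dispatch to e.g. model.generate(...) rather than model(...).
      predictions = getattr(model, model_function)(
          batched_tensors, **(inference_args or {}))
    return _convert_to_result(batch, predictions)

  return override_fn
```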
The convenience function kind of works, although you get a funky named function defined within the scope of the convenience function; no multi-line lambdas will do that. The function winds up looking something like the attr_fn closure in the diff above.

Although the more I think about it, the more I think we should just provide a `generate` function users can pass, since we know using that routing instead is a major motivating example for this feature. There's also something to be said for making `_convert_to_device` and `_convert_to_result` available to users as building blocks for their custom functions.
Yeah, that's what I was imagining here. Note that while this is funky for us, it's a smooth user experience because they don't have to be aware of the `getattr` call.

I agree, though I think `generate` is just an example and we want this to be as easy as possible for similar functions as well. I'd probably argue that a generic `make_tensor_inference_override_fn` (maybe with better naming 😅) does that better than us trying to hit all the functions a user could provide. We don't necessarily have to do that, but if you disagree let's get more people's voices/opinions involved.
We should make sure that whatever we do, we don't instantiate the model before pipeline startup. If the model is large, instantiating the model object takes a lot of time, and sometimes the pipeline won't start because we have hit a limit (https://cloud.google.com/dataflow/quotas#:~:text=The%20default%20disk%20size%20is,for%20Dataflow%20Shuffle%20batch%20pipelines).
The helper route isn't that bad; it's a little verbose on our end, but it should make user routing easier. A first run at it was just pushed.