156 changes: 129 additions & 27 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -38,6 +38,18 @@
'PytorchModelHandlerKeyedTensor',
]

TensorInferenceFn = Callable[
[Sequence[torch.Tensor], torch.nn.Module, str, Optional[Dict[str, Any]]],
Iterable[PredictionResult]]

KeyedTensorInferenceFn = Callable[[
Sequence[Dict[str, torch.Tensor]],
torch.nn.Module,
str,
Optional[Dict[str, Any]]
],
Iterable[PredictionResult]]


def _load_model(
model_class: torch.nn.Module, state_dict_path, device, **model_params):
@@ -100,6 +112,46 @@ def _convert_to_result(
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


def default_tensor_inference_fn(
batch: Sequence[torch.Tensor],
model: torch.nn.Module,
device: str,
inference_args: Optional[Dict[str,
Any]] = None) -> Iterable[PredictionResult]:
# torch.no_grad() mitigates GPU memory issues
# https://github.com/apache/beam/issues/22811
with torch.no_grad():
batched_tensors = torch.stack(batch)
batched_tensors = _convert_to_device(batched_tensors, device)
predictions = model(batched_tensors, **inference_args)
return _convert_to_result(batch, predictions)
Comment on lines +123 to +127
Contributor

Looking at this, my only concern is that doing the simple thing (calling model.generate(...) instead of model(...) or something similar) is now harder to do performantly.

I wonder if there's a way to still make that easy - one option would be to add something like a convenience function make_tensor_inference_override_fn(model_function: string) that generates a lambda. So the user code would then look like: my_tensor_inference_fn = make_tensor_inference_override_fn('generate'), which would set my_tensor_inference_fn to an anonymous function that does:

  with torch.no_grad():
    batched_tensors = torch.stack(batch)
    batched_tensors = _convert_to_device(batched_tensors, device)
    predictions = model.generate(batched_tensors, **inference_args)
    return _convert_to_result(batch, predictions)

Does that make sense/sound reasonable?

Contributor Author
@jrmccluskey Nov 11, 2022

The convenience function kind of works, although you get a funky named function defined within the scope of the convenience function, since Python has no multi-line lambdas. The function winds up looking something like this:

with torch.no_grad():
  batched_tensors = torch.stack(batch)
  batched_tensors = _convert_to_device(batched_tensors, device)
  pred_fn = getattr(model, model_fn)
  predictions = pred_fn(batched_tensors, **inference_args)
  return _convert_to_result(batch, predictions)

Although the more I think about it, the more I think we should just provide a generate function users can pass since we know using that routing instead is a major motivating example for this feature. There's also something to be said for making _convert_to_device and _convert_to_result available to users as building blocks for their custom functions.
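
For a model that exposes a generate() method, a user-written function along those lines could be a minimal sketch like the following, with plain tensor ops standing in for the module-private _convert_to_device and _convert_to_result helpers (all names here are illustrative, not part of this PR):

import torch
from typing import Any, Dict, Iterable, Optional, Sequence

from apache_beam.ml.inference.base import PredictionResult


def my_generate_inference_fn(
    batch: Sequence[torch.Tensor],
    model: torch.nn.Module,
    device: str,
    inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
  inference_args = inference_args or {}
  # torch.no_grad() mitigates GPU memory issues, same as the default fn.
  with torch.no_grad():
    batched_tensors = torch.stack(batch)
    # Stand-in for _convert_to_device(batched_tensors, device).
    batched_tensors = batched_tensors.to(device)
    # Route through generate() rather than forward().
    predictions = model.generate(batched_tensors, **inference_args)
    # Stand-in for _convert_to_result(batch, predictions).
    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]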

Contributor

> although you get a funky named function defined within the scope of the convenience function, since Python has no multi-line lambdas. The function winds up looking something like this:

Yeah, that's what I was imagining here. Note that while this is funky for us, it's a smooth user experience because they don't have to be aware of the getattr call.

> Although the more I think about it, the more I think we should just provide a generate function users can pass since we know using that routing instead is a major motivating example for this feature.

I agree, though I think generate is just an example and we want this to be as easy as possible for similar functions as well. I'd probably argue that a generic make_tensor_inference_override_fn (maybe with better naming 😅) does that better than us trying to hit all the functions a user could provide. We don't necessarily have to do that, but if you disagree let's get more people's voices/opinions involved.

Contributor

We should make sure that whatever we do, we don't instantiate the model before pipeline startup. If the model is large, instantiating the model object takes a lot of time, and sometimes the pipeline won't start because we have hit a limit (https://cloud.google.com/dataflow/quotas#:~:text=The%20default%20disk%20size%20is,for%20Dataflow%20Shuffle%20batch%20pipelines).

Contributor Author

The helper route isn't that bad - a little verbose on our end, but it should make user routing easier. First run at it was just pushed.
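
A minimal usage sketch of the make_tensor_model_fn helper added below (the bucket path, model class, and input collection are placeholders, not part of this PR):

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import (
    PytorchModelHandlerTensor, make_tensor_model_fn)

# Route inference through model.generate(...) instead of model.forward(...).
gen_fn = make_tensor_model_fn('generate')

model_handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/model_state_dict.pth',  # placeholder path
    model_class=MyGenerativeModel,  # placeholder model class
    model_params={},
    inference_fn=gen_fn)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(example_tensors)  # placeholder Sequence[torch.Tensor]
      | RunInference(model_handler))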



def make_tensor_model_fn(model_fn: str) -> TensorInferenceFn:
Contributor

Could we add tests for these make_XYZ_fn functions? Otherwise, LGTM

"""
Produces a TensorInferenceFn that uses a method of the model other than
the forward() method.

Args:
model_fn: A string name of the method to be used. This is accessed through
getattr(model, model_fn)
"""
def attr_fn(
batch: Sequence[torch.Tensor],
model: torch.nn.Module,
device: str,
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
with torch.no_grad():
batched_tensors = torch.stack(batch)
batched_tensors = _convert_to_device(batched_tensors, device)
pred_fn = getattr(model, model_fn)
predictions = pred_fn(batched_tensors, **inference_args)
return _convert_to_result(batch, predictions)

return attr_fn
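
A hedged sketch of the kind of unit test asked for above, using a toy module whose generate() behaves differently from forward() so the routing is observable (all names are illustrative):

import unittest

import torch

from apache_beam.ml.inference.pytorch_inference import make_tensor_model_fn


class _FakeModel(torch.nn.Module):
  def forward(self, x):
    return x + 1

  def generate(self, x):
    return x * 2


class MakeTensorModelFnTest(unittest.TestCase):
  def test_routes_to_named_method(self):
    inference_fn = make_tensor_model_fn('generate')
    batch = [torch.tensor([1.0]), torch.tensor([2.0])]
    results = list(inference_fn(batch, _FakeModel(), 'cpu', {}))
    # generate() doubles its input, so routing worked if we see 2.0 here.
    self.assertTrue(torch.equal(results[0].inference, torch.tensor([2.0])))


if __name__ == '__main__':
  unittest.main()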


class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
PredictionResult,
torch.nn.Module]):
@@ -108,7 +160,9 @@ def __init__(
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
device: str = 'CPU'):
device: str = 'CPU',
*,
inference_fn: TensorInferenceFn = default_tensor_inference_fn):
"""Implementation of the ModelHandler interface for PyTorch.

Example Usage::
@@ -127,6 +181,8 @@ def __init__(
device: the device on which you wish to run the model. If
``device = GPU`` then a GPU device will be used if it is available.
Otherwise, it will be CPU.
inference_fn: the inference function to use during RunInference.
default=default_tensor_inference_fn

**Supported Versions:** RunInference APIs in Apache Beam have been tested
with PyTorch 1.9 and 1.10.
@@ -140,6 +196,7 @@ def __init__(
self._device = torch.device('cpu')
self._model_class = model_class
self._model_params = model_params
self._inference_fn = inference_fn

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
@@ -179,13 +236,7 @@ def run_inference(
"""
inference_args = {} if not inference_args else inference_args

# torch.no_grad() mitigates GPU memory issues
# https://github.com/apache/beam/issues/22811
with torch.no_grad():
batched_tensors = torch.stack(batch)
batched_tensors = _convert_to_device(batched_tensors, self._device)
predictions = model(batched_tensors, **inference_args)
return _convert_to_result(batch, predictions)
return self._inference_fn(batch, model, self._device, inference_args)

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""
@@ -205,6 +256,69 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
pass


def default_keyed_tensor_inference_fn(
batch: Sequence[Dict[str, torch.Tensor]],
model: torch.nn.Module,
device: str,
inference_args: Optional[Dict[str,
Any]] = None) -> Iterable[PredictionResult]:
# If elements in `batch` are provided as dictionaries from key to Tensors,
# then iterate through the batch list and group Tensors by key.
key_to_tensor_list = defaultdict(list)

# torch.no_grad() mitigates GPU memory issues
# https://github.com/apache/beam/issues/22811
with torch.no_grad():
for example in batch:
for key, tensor in example.items():
key_to_tensor_list[key].append(tensor)
key_to_batched_tensors = {}
for key in key_to_tensor_list:
batched_tensors = torch.stack(key_to_tensor_list[key])
batched_tensors = _convert_to_device(batched_tensors, device)
key_to_batched_tensors[key] = batched_tensors
predictions = model(**key_to_batched_tensors, **inference_args)

return _convert_to_result(batch, predictions)
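
To illustrate the regrouping above with made-up keys: each element of the batch is a dict of tensors, and tensors sharing a key are stacked so the model is invoked once with keyword arguments.

import torch

batch = [
    {'input_ids': torch.tensor([1, 2]), 'mask': torch.tensor([1, 1])},
    {'input_ids': torch.tensor([3, 4]), 'mask': torch.tensor([1, 0])},
]
# After grouping and stacking, the single model call receives:
#   input_ids=tensor([[1, 2], [3, 4]])  # shape (2, 2)
#   mask=tensor([[1, 1], [1, 0]])       # shape (2, 2)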


def make_keyed_tensor_model_fn(model_fn: str) -> KeyedTensorInferenceFn:
"""
Produces a KeyedTensorInferenceFn that uses a method of the model other than
the forward() method.

Args:
model_fn: A string name of the method to be used. This is accessed through
getattr(model, model_fn)
"""
def attr_fn(
batch: Sequence[Dict[str, torch.Tensor]],
model: torch.nn.Module,
device: str,
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
# If elements in `batch` are provided as dictionaries from key to Tensors,
# then iterate through the batch list and group Tensors by key.
key_to_tensor_list = defaultdict(list)

# torch.no_grad() mitigates GPU memory issues
# https://github.com/apache/beam/issues/22811
with torch.no_grad():
for example in batch:
for key, tensor in example.items():
key_to_tensor_list[key].append(tensor)
key_to_batched_tensors = {}
for key in key_to_tensor_list:
batched_tensors = torch.stack(key_to_tensor_list[key])
batched_tensors = _convert_to_device(batched_tensors, device)
key_to_batched_tensors[key] = batched_tensors
pred_fn = getattr(model, model_fn)
predictions = pred_fn(**key_to_batched_tensors, **inference_args)
return _convert_to_result(batch, predictions)

return attr_fn
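
Mirroring the tensor case, a rough sketch of wiring this keyed helper into a handler (the path and model class are placeholders):

from apache_beam.ml.inference.pytorch_inference import (
    PytorchModelHandlerKeyedTensor, make_keyed_tensor_model_fn)

# Route keyed-tensor inference through model.generate(...).
keyed_gen_fn = make_keyed_tensor_model_fn('generate')

keyed_handler = PytorchModelHandlerKeyedTensor(
    state_dict_path='gs://my-bucket/seq2seq_state_dict.pth',  # placeholder
    model_class=MySeq2SeqModel,  # placeholder; takes keyword tensor inputs
    model_params={},
    inference_fn=keyed_gen_fn)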


@experimental(extra_message="No backwards-compatibility guarantees.")
class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
PredictionResult,
@@ -214,7 +328,9 @@ def __init__(
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
device: str = 'CPU'):
device: str = 'CPU',
*,
inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn):
"""Implementation of the ModelHandler interface for PyTorch.

Example Usage::
@@ -237,6 +353,8 @@ def __init__(
device: the device on which you wish to run the model. If
``device = GPU`` then a GPU device will be used if it is available.
Otherwise, it will be CPU.
inference_fn: the function to invoke on run_inference.
default = default_keyed_tensor_inference_fn

**Supported Versions:** RunInference APIs in Apache Beam have been tested
with PyTorch 1.9 and 1.10.
@@ -250,6 +368,7 @@ def __init__(
self._device = torch.device('cpu')
self._model_class = model_class
self._model_params = model_params
self._inference_fn = inference_fn

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
Expand Down Expand Up @@ -289,24 +408,7 @@ def run_inference(
"""
inference_args = {} if not inference_args else inference_args

# If elements in `batch` are provided as a dictionaries from key to Tensors,
# then iterate through the batch list, and group Tensors to the same key
key_to_tensor_list = defaultdict(list)

# torch.no_grad() mitigates GPU memory issues
# https://github.com/apache/beam/issues/22811
with torch.no_grad():
for example in batch:
for key, tensor in example.items():
key_to_tensor_list[key].append(tensor)
key_to_batched_tensors = {}
for key in key_to_tensor_list:
batched_tensors = torch.stack(key_to_tensor_list[key])
batched_tensors = _convert_to_device(batched_tensors, self._device)
key_to_batched_tensors[key] = batched_tensors
predictions = model(**key_to_batched_tensors, **inference_args)

return _convert_to_result(batch, predictions)
return self._inference_fn(batch, model, self._device, inference_args)

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""