Merged
162 changes: 113 additions & 49 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -23,30 +23,47 @@
from typing import Dict
from typing import Iterable
from typing import Sequence
from typing import Union

import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult


class PytorchModelHandler(ModelHandler[torch.Tensor,
PredictionResult,
torch.nn.Module]):
""" Implementation of the ModelHandler interface for PyTorch.
def _load_model(
model_class: torch.nn.Module, state_dict_path, device, **model_params):
model = model_class(**model_params)
model.to(device)
file = FileSystems.open(state_dict_path, 'rb')
model.load_state_dict(torch.load(file))
model.eval()
return model

NOTE: This API and its implementation are under development and
do not provide backward compatibility guarantees.

def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
"""
Converts samples to a style matching the given device.

Note: A user may pass in device='GPU' but if GPU is not detected in the
environment it must be converted back to CPU.
"""
if examples.device != device:
examples = examples.to(device)
return examples


class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
PredictionResult,
torch.nn.Module]):
""" Implementation of the ModelHandler interface for PyTorch."""
def __init__(
self,
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
device: str = 'CPU'):
"""
Initializes a PytorchModelHandler
Initializes a PytorchModelHandlerTensor
:param state_dict_path: path to the saved dictionary of the model state.
:param model_class: class of the Pytorch model that defines the model
structure.
@@ -67,67 +84,114 @@ def __init__(

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
model = self._model_class(**self._model_params)
model.to(self._device)
file = FileSystems.open(self._state_dict_path, 'rb')
model.load_state_dict(torch.load(file))
model.eval()
return model

def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
return _load_model(
self._model_class,
self._state_dict_path,
self._device,
**self._model_params)
Contributor
What do you think about naming model parameters as a dictionary?

The advantage is that users can specify exactly what their parameters should be.

They would specify the parameters like this:

model_parameters = {
    'key_1': 'parameter_1'
}

Then in the future if optional parameters are added they won't collide.

Feel free to do that change in another PR if you think it's a good idea.

Contributor Author

are you referring to something like this? #21806
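A minimal sketch of the dictionary-style parameters discussed above, as they would be passed to the handler introduced in this diff. The parameter names and path are illustrative only; PytorchLinearRegression is the model class defined in the tests below.

model_params = {
    'input_dim': 1,   # illustrative parameter name/value
    'output_dim': 1,  # illustrative parameter name/value
}
model_handler = PytorchModelHandlerTensor(
    state_dict_path='my_state_dict_path',   # placeholder path
    model_class=PytorchLinearRegression,    # model class from the test file
    model_params=model_params)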


def run_inference(
self, batch: Sequence[torch.Tensor], model: torch.nn.Module,
**kwargs) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of Tensors and returns an Iterable of
Tensor Predictions.

This method stacks the list of Tensors in a vectorized format to optimize
the inference call.
"""
prediction_params = kwargs.get('prediction_params', {})
batched_tensors = torch.stack(batch)
batched_tensors = _convert_to_device(batched_tensors, self._device)
predictions = model(batched_tensors, **prediction_params)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""Returns the number of bytes of data for a batch of Tensors."""
return sum((el.element_size() for tensor in batch for el in tensor))

def get_metrics_namespace(self) -> str:
"""
Returns a namespace for metrics collected by the RunInference transform.
"""
return 'RunInferencePytorch'
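As a reference for the run_inference docstring above, here is a minimal sketch (not part of the diff) of the vectorized stacking applied to a batch of Tensors; the values are illustrative.

import torch

batch = [torch.Tensor([1.0]), torch.Tensor([5.0]), torch.Tensor([-3.0])]
batched_tensors = torch.stack(batch)  # shape (3, 1): one forward pass covers the whole batch
# predictions = model(batched_tensors); results are then zipped back to the input examples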


class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
PredictionResult,
torch.nn.Module]):
""" Implementation of the ModelHandler interface for PyTorch.

NOTE: This API and its implementation are under development and
do not provide backward compatibility guarantees.
"""
def __init__(
self,
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
device: str = 'CPU'):
"""
Converts samples to a style matching given device.
Initializes a PytorchModelHandlerKeyedTensor
:param state_dict_path: path to the saved dictionary of the model state.
:param model_class: class of the Pytorch model that defines the model
structure.
:param device: the device on which you wish to run the model. If
``device = GPU`` then a GPU device will be used if it is available.
Otherwise, it will be CPU.

Note: A user may pass in device='GPU' but if GPU is not detected in the
environment it must be converted back to CPU.
See https://pytorch.org/tutorials/beginner/saving_loading_models.html
for details
"""
if examples.device != self._device:
examples = examples.to(self._device)
return examples
self._state_dict_path = state_dict_path
if device == 'GPU' and torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self._model_class = model_class
self._model_params = model_params

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
return _load_model(
self._model_class,
self._state_dict_path,
self._device,
**self._model_params)

def run_inference(
self,
batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
batch: Sequence[Dict[str, torch.Tensor]],
model: torch.nn.Module,
**kwargs) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of Tensors and returns an Iterable of
Runs inferences on a batch of Keyed Tensors and returns an Iterable of
Tensor Predictions.

This method stacks the list of Tensors in a vectorized format to optimize
the inference call.
For the same key across all examples, this will stack all Tensor values
in a vectorized format to optimize the inference call.
"""
prediction_params = kwargs.get('prediction_params', {})

# If elements in `batch` are provided as a dictionaries from key to Tensors,
# then iterate through the batch list, and group Tensors to the same key
if isinstance(batch[0], dict):
key_to_tensor_list = defaultdict(list)
for example in batch:
for key, tensor in example.items():
key_to_tensor_list[key].append(tensor)
key_to_batched_tensors = {}
for key in key_to_tensor_list:
batched_tensors = torch.stack(key_to_tensor_list[key])
batched_tensors = self._convert_to_device(batched_tensors)
key_to_batched_tensors[key] = batched_tensors
predictions = model(**key_to_batched_tensors, **prediction_params)
else:
# If elements in `batch` are provided as Tensors, then do a regular stack
batched_tensors = torch.stack(batch)
batched_tensors = self._convert_to_device(batched_tensors)
predictions = model(batched_tensors, **prediction_params)
key_to_tensor_list = defaultdict(list)
for example in batch:
for key, tensor in example.items():
key_to_tensor_list[key].append(tensor)
key_to_batched_tensors = {}
for key in key_to_tensor_list:
batched_tensors = torch.stack(key_to_tensor_list[key])
batched_tensors = _convert_to_device(batched_tensors, self._device)
key_to_batched_tensors[key] = batched_tensors
predictions = model(**key_to_batched_tensors, **prediction_params)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
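A minimal sketch (outside the diff) of the per-key grouping and stacking described in the docstring above; the keys and values are illustrative.

from collections import defaultdict
import torch

batch = [
    {'k1': torch.Tensor([1.0]), 'k2': torch.Tensor([2.0])},
    {'k1': torch.Tensor([3.0]), 'k2': torch.Tensor([4.0])},
]
key_to_tensor_list = defaultdict(list)
for example in batch:
    for key, tensor in example.items():
        key_to_tensor_list[key].append(tensor)
key_to_batched_tensors = {
    key: torch.stack(tensors) for key, tensors in key_to_tensor_list.items()
}
# Each value now has shape (2, 1); the model is then called as model(**key_to_batched_tensors).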

def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""Returns the number of bytes of data for a batch of Tensors."""
"""Returns the number of bytes of data for a batch of Dict of Tensors."""
# If elements in `batch` are provided as a dictionaries from key to Tensors
if isinstance(batch[0], dict):
return sum(
(el.element_size() for tensor in batch for el in tensor.values()))
else:
# If elements in `batch` are provided as Tensors
return sum((el.element_size() for tensor in batch for el in tensor))
return sum(
(el.element_size() for tensor in batch for el in tensor.values()))

def get_metrics_namespace(self) -> str:
"""
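For context, a hedged sketch of how the renamed handler would be wired into a pipeline with RunInference, modeled on the tests below; the state-dict path, model class, and parameters are placeholders rather than part of this PR.

import apache_beam as beam
import torch
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

class LinearRegression(torch.nn.Module):
    # Stand-in model, analogous to PytorchLinearRegression in the tests below.
    def __init__(self, input_dim=1, output_dim=1):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

model_handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/my_state_dict.pth',  # placeholder path
    model_class=LinearRegression,
    model_params={'input_dim': 1, 'output_dim': 1})

with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | beam.Create([torch.Tensor([1.0]), torch.Tensor([2.0])])
        | RunInference(model_handler)
        | beam.Map(print))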
23 changes: 15 additions & 8 deletions sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -37,7 +37,8 @@
import torch
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
except ImportError:
raise unittest.SkipTest('PyTorch dependencies are not installed')

@@ -90,7 +91,13 @@
]


class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandler):
class TestPytorchModelHandlerForInferenceOnly(PytorchModelHandlerTensor):
def __init__(self, device):
self._device = device


class TestPytorchModelHandlerKeyedTensorForInferenceOnly(
PytorchModelHandlerKeyedTensor):
def __init__(self, device):
self._device = device

@@ -209,7 +216,7 @@ def forward(self, k1, k2):
('linear.bias', torch.Tensor([0.5]))]))
model.eval()

inference_runner = TestPytorchModelHandlerForInferenceOnly(
inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
torch.device('cpu'))
predictions = inference_runner.run_inference(KWARGS_TORCH_EXAMPLES, model)
for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS):
@@ -234,7 +241,7 @@ def test_run_inference_kwargs_prediction_params(self):
('linear.bias', torch.Tensor([0.5]))]))
model.eval()

inference_runner = TestPytorchModelHandlerForInferenceOnly(
inference_runner = TestPytorchModelHandlerKeyedTensorForInferenceOnly(
torch.device('cpu'))
predictions = inference_runner.run_inference(
batch=KWARGS_TORCH_EXAMPLES,
@@ -274,7 +281,7 @@ def test_pipeline_local_model_simple(self):
path = os.path.join(self.tmpdir, 'my_state_dict_path')
torch.save(state_dict, path)

model_handler = PytorchModelHandler(
model_handler = PytorchModelHandlerTensor(
state_dict_path=path,
model_class=PytorchLinearRegression,
model_params={
@@ -301,7 +308,7 @@ def test_pipeline_local_model_kwargs_prediction_params(self):
path = os.path.join(self.tmpdir, 'my_state_dict_path')
torch.save(state_dict, path)

model_handler = PytorchModelHandler(
model_handler = PytorchModelHandlerKeyedTensor(
state_dict_path=path,
model_class=PytorchLinearRegressionKwargsPredictionParams,
model_params={
@@ -334,7 +341,7 @@ def test_pipeline_gcs_model(self):

gs_pth = 'gs://apache-beam-ml/models/' \
'pytorch_lin_reg_model_2x+0.5_state_dict.pth'
model_handler = PytorchModelHandler(
model_handler = PytorchModelHandlerTensor(
state_dict_path=gs_pth,
model_class=PytorchLinearRegression,
model_params={
@@ -357,7 +364,7 @@ def test_invalid_input_type(self):
path = os.path.join(self.tmpdir, 'my_state_dict_path')
torch.save(state_dict, path)

model_handler = PytorchModelHandler(
model_handler = PytorchModelHandlerTensor(
state_dict_path=path,
model_class=PytorchLinearRegression,
model_params={
sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -225,9 +225,9 @@ def test_bad_input_type_raises(self):
with self.assertRaisesRegex(AssertionError,
'Unsupported serialization type'):
with tempfile.NamedTemporaryFile() as file:
model_loader = SklearnModelHandlerNumpy(
model_handler = SklearnModelHandlerNumpy(
model_uri=file.name, model_file_type=None)
model_loader.load_model()
model_handler.load_model()

@unittest.skipIf(platform.system() == 'Windows', 'BEAM-14359')
def test_pipeline_pandas(self):