Skip to content
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/ml/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/ml/inference/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mypy: ignore-errors
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this necessary now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TypeVar should have a bounded source. Right now, we want to keep it as generic as possible, so only this file is ignored by mypy for now; we will circle back to this later.

Jira related to this: https://issues.apache.org/jira/browse/BEAM-14217

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add relevant comments to BEAM-14217, and consider adding a `# TODO(BEAM-14217): ...`
here in a later change, if applicable.


from dataclasses import dataclass
from typing import Tuple
Expand Down Expand Up @@ -51,7 +52,7 @@ class RunInference(beam.PTransform):

TODO(BEAM-14046): Add and link to help documentation
"""
def __init__(self, model_loader: base.ModelLoader) -> beam.pvalue.PCollection:
def __init__(self, model_loader: base.ModelLoader):
self._model_loader = model_loader

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
Expand Down
11 changes: 6 additions & 5 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
from apache_beam.utils import shared

try:
# pylint: disable=g-import-not-at-top
# pylint: disable=wrong-import-order, wrong-import-position
import resource
except ImportError:
resource = None
resource = None # type: ignore[assignment]

_MICROSECOND_TO_MILLISECOND = 1000
_NANOSECOND_TO_MICROSECOND = 1000
Expand All @@ -59,15 +59,16 @@
class InferenceRunner():
"""Implements running inferences for a framework."""
def run_inference(self, batch: List[Any], model: Any) -> Iterable[Any]:
"""Runs inferences on a batch of examples and returns an Iterable of Predictions."""
"""Runs inferences on a batch of examples and
returns an Iterable of Predictions."""
raise NotImplementedError(type(self))

def get_num_bytes(self, batch: Any) -> int:
"""Returns the number of bytes of data for a batch."""
return len(pickle.dumps(batch))

def get_metrics_namespace(self) -> str:
"""Returns a namespace for metrics collected by the RunInference transform."""
"""Returns a namespace for metrics collected by RunInference transform."""
return 'RunInference'


Expand Down Expand Up @@ -249,7 +250,7 @@ def get_current_time_in_microseconds(self) -> int:
class _FineGrainedClock(_Clock):
def get_current_time_in_microseconds(self) -> int:
return int(
time.clock_gettime_ns(time.CLOCK_REALTIME) / # pytype: disable=module-attr
time.clock_gettime_ns(time.CLOCK_REALTIME) / # type: ignore[attr-defined]
_NANOSECOND_TO_MICROSECOND)


Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from typing import Iterable

import apache_beam as beam
import apache_beam.ml.inference.base as base
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.ml.inference import base
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
Expand Down
19 changes: 11 additions & 8 deletions sdks/python/apache_beam/ml/inference/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def run_inference(self, batch: List[torch.Tensor],
the inference call.
"""

batch = torch.stack(batch)
if batch.device != self._device:
batch = batch.to(self._device)
predictions = model(batch)
torch_batch = torch.stack(batch)
if torch_batch.device != self._device:
torch_batch = torch_batch.to(self._device)
predictions = model(torch_batch)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
Expand All @@ -75,10 +75,13 @@ def __init__(
model_params: Dict[str, Any],
device: str = 'CPU'):
"""
state_dict_path: path to the saved dictionary of the model state.
model_class: class of the Pytorch model that defines the model structure.
device: the device on which you wish to run the model. If ``device = GPU``
then device will be cuda if it is available. Otherwise, it will be cpu.
Initializes a PytorchModelLoader
:param state_dict_path: path to the saved dictionary of the model state.
:param model_class: class of the Pytorch model that defines the model
structure.
:param device: the device on which you wish to run the model. If
``device = GPU`` then a GPU device will be used if it is available.
Otherwise, it will be CPU.

See https://pytorch.org/tutorials/beginner/saving_loading_models.html
for details
Expand Down
11 changes: 6 additions & 5 deletions sdks/python/apache_beam/ml/inference/sklearn_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ class ModelFileType(enum.Enum):


class SklearnInferenceRunner(InferenceRunner):
def run_inference(self, batch: List[numpy.array],
model: Any) -> Iterable[numpy.array]:
def run_inference(self, batch: List[numpy.ndarray],
model: Any) -> Iterable[PredictionResult]:
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
predictions = model.predict(vectorized_batch)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

def get_num_bytes(self, batch: List[numpy.array]) -> int:
def get_num_bytes(self, batch: List[numpy.ndarray]) -> int:
"""Returns the number of bytes of data for a batch."""
return sum(sys.getsizeof(element) for element in batch)

Expand All @@ -71,8 +71,9 @@ def load_model(self):
elif self._model_file_type == ModelFileType.JOBLIB:
if not joblib:
raise ImportError(
'Could not import joblib in this execution'
' environment. For help with managing dependencies on Python workers see https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/'
'Could not import joblib in this execution environment. '
'For help with managing dependencies on Python workers.'
'see https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/' # pylint: disable=line-too-long
)
return joblib.load(file)
raise AssertionError('Unsupported serialization type.')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from sklearn import svm

import apache_beam as beam
import apache_beam.ml.inference.api as api
import apache_beam.ml.inference.base as base
from apache_beam.ml.inference import api
from apache_beam.ml.inference import base
from apache_beam.ml.inference.sklearn_inference import ModelFileType
from apache_beam.ml.inference.sklearn_inference import SklearnInferenceRunner
from apache_beam.ml.inference.sklearn_inference import SklearnModelLoader
Expand All @@ -49,7 +49,7 @@ class FakeModel:
def __init__(self):
self.total_predict_calls = 0

def predict(self, input_vector: numpy.array):
def predict(self, input_vector: numpy.ndarray):
self.total_predict_calls += 1
return numpy.sum(input_vector, axis=1)

Expand Down
11 changes: 0 additions & 11 deletions sdks/python/scripts/generate_pydoc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -235,17 +235,6 @@ nitpick_ignore = []
nitpick_ignore += [('py:class', iden) for iden in ignore_identifiers]
nitpick_ignore += [('py:obj', iden) for iden in ignore_identifiers]
nitpick_ignore += [('py:exc', iden) for iden in ignore_references]

# Monkey patch functools.wraps to retain original function argument signature
# for documentation.
# https://github.com/sphinx-doc/sphinx/issues/1711
import functools
def fake_wraps(wrapped):
def wrapper(decorator):
return wrapped
return wrapper

functools.wraps = fake_wraps
EOF

#=== index.rst ===#
Expand Down
1 change: 1 addition & 0 deletions sdks/python/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ deps =
sphinx_rtd_theme==0.4.3
docutils<0.18
Jinja2==3.0.3 # TODO(BEAM-14172): Sphinx version is too old.
torch
commands =
time {toxinidir}/scripts/generate_pydoc.sh

Expand Down