diff --git a/sdks/python/apache_beam/examples/inference/README.md b/sdks/python/apache_beam/examples/inference/README.md index f769e82bf22f..2dc1247934d0 100644 --- a/sdks/python/apache_beam/examples/inference/README.md +++ b/sdks/python/apache_beam/examples/inference/README.md @@ -334,6 +334,7 @@ To use this transform, you need a dataset and model for language modeling. ... ``` 2. Create a file named `MODEL_PATH` that contains the pickled file of a scikit-learn model trained on MNIST data. Please refer to this scikit-learn [model persistence documentation](https://scikit-learn.org/stable/model_persistence.html) on how to serialize models. +3. Update sklearn_examples_requirements.txt to match the version of sklearn used to train the model. Sklearn doesn't guarantee model compatibility between versions. ### Running `sklearn_mnist_classification.py` diff --git a/sdks/python/apache_beam/examples/inference/sklearn_examples_requirements.txt b/sdks/python/apache_beam/examples/inference/sklearn_examples_requirements.txt new file mode 100644 index 000000000000..7c41e37e01b7 --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/sklearn_examples_requirements.txt @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# This should match the saved version of your trained model. +# Beam's tests use sklearn 1.0.2 for their saved models. +scikit-learn==1.0.2 \ No newline at end of file diff --git a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py index 0d0adb0425a8..3aa2f362fa64 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py @@ -30,6 +30,7 @@ """ import argparse +import os from typing import Iterable import pandas @@ -137,6 +138,12 @@ def run( known_args, pipeline_args = parse_known_args(argv) pipeline_options = PipelineOptions(pipeline_args) pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + requirements_dir = os.path.dirname(os.path.realpath(__file__)) + # Pin to the version that we trained the model on. + # Sklearn doesn't guarantee compatibility between versions. 
+ pipeline_options.view_as( + SetupOptions + ).requirements_file = f'{requirements_dir}/sklearn_examples_requirements.txt' pipeline = test_pipeline if not test_pipeline: diff --git a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py index e748166e6fda..6f8ea929bbb6 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py @@ -26,6 +26,7 @@ import argparse import logging +import os from typing import Iterable from typing import List from typing import Tuple @@ -90,6 +91,12 @@ def run( known_args, pipeline_args = parse_known_args(argv) pipeline_options = PipelineOptions(pipeline_args) pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + requirements_dir = os.path.dirname(os.path.realpath(__file__)) + # Pin to the version that we trained the model on. + # Sklearn doesn't guarantee compatibility between versions. + pipeline_options.view_as( + SetupOptions + ).requirements_file = f'{requirements_dir}/sklearn_examples_requirements.txt' # In this example we pass keyed inputs to RunInference transform. # Therefore, we use KeyedModelHandler wrapper over SklearnModelHandlerNumpy.