Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdks/python/apache_beam/examples/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ To use this transform, you need a dataset and model for language modeling.
...
```
2. Create a file named `MODEL_PATH` that contains the pickled file of a scikit-learn model trained on MNIST data. Please refer to this scikit-learn [model persistence documentation](https://scikit-learn.org/stable/model_persistence.html) on how to serialize models.
3. Update sklearn_examples_requirements.txt to match the version of sklearn used to train the model. Sklearn doesn't guarantee model compatability between versions.


### Running `sklearn_mnist_classification.py`
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This should match the saved version of your trained model.
# Beam's tests use sklearn 1.0.2 for their saved models.
scikit-learn==1.0.2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will dependabot update this file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure - I added a comment so that it will be obvious we should reject PRs that try to bump this.

Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"""

import argparse
import os
from typing import Iterable

import pandas
Expand Down Expand Up @@ -137,6 +138,12 @@ def run(
known_args, pipeline_args = parse_known_args(argv)
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
requirements_dir = os.path.dirname(os.path.realpath(__file__))
# Pin to the version that we trained the model on.
# Sklearn doesn't guarantee compatability between versions.
pipeline_options.view_as(
SetupOptions
).requirements_file = f'{requirements_dir}/sklearn_examples_requirements.txt'

pipeline = test_pipeline
if not test_pipeline:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import argparse
import logging
import os
from typing import Iterable
from typing import List
from typing import Tuple
Expand Down Expand Up @@ -90,6 +91,12 @@ def run(
known_args, pipeline_args = parse_known_args(argv)
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
requirements_dir = os.path.dirname(os.path.realpath(__file__))
# Pin to the version that we trained the model on.
# Sklearn doesn't guarantee compatability between versions.
pipeline_options.view_as(
SetupOptions
).requirements_file = f'{requirements_dir}/sklearn_examples_requirements.txt'

# In this example we pass keyed inputs to RunInference transform.
# Therefore, we use KeyedModelHandler wrapper over SklearnModelHandlerNumpy.
Expand Down