6 changes: 4 additions & 2 deletions doc/frameworks/pytorch/using_pytorch.rst
@@ -154,7 +154,8 @@ directories ('train' and 'test').
pytorch_estimator = PyTorch('pytorch-train.py',
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='1.0.0',
framework_version='1.5.0',
py_version='py3',
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
'test': 's3://my-data-bucket/path/to/my/test/data'})
@@ -247,7 +248,8 @@ operation.
pytorch_estimator = PyTorch(entry_point='train_and_deploy.py',
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='1.0.0')
framework_version='1.5.0',
py_version='py3')
pytorch_estimator.fit('s3://my_bucket/my_training_data/')

# Deploy my estimator to a SageMaker Endpoint and get a Predictor
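A minimal sketch of the two configurations the updated constructor accepts (illustrative only; the script name, instance type, and image name below are placeholders):

from sagemaker.pytorch import PyTorch

# Explicit versions: both framework_version and py_version must now be given ...
pytorch_estimator = PyTorch('pytorch-train.py',
                            train_instance_type='ml.p3.2xlarge',
                            train_instance_count=1,
                            framework_version='1.5.0',
                            py_version='py3')

# ... unless a custom image is supplied instead, in which case both may be omitted:
byoc_estimator = PyTorch('pytorch-train.py',
                         train_instance_type='ml.p3.2xlarge',
                         train_instance_count=1,
                         image_name='custom-image:latest')

# Passing neither the versions nor image_name raises a ValueError.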
57 changes: 31 additions & 26 deletions src/sagemaker/pytorch/estimator.py
@@ -19,9 +19,9 @@
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
empty_framework_version_warning,
python_deprecation_warning,
is_version_equal_or_higher,
python_deprecation_warning,
validate_version_or_image_args,
)
from sagemaker.pytorch import defaults
from sagemaker.pytorch.model import PyTorchModel
@@ -40,10 +40,10 @@ class PyTorch(Framework):
def __init__(
self,
entry_point,
framework_version=None,
py_version=None,
source_dir=None,
hyperparameters=None,
py_version=defaults.PYTHON_VERSION,
framework_version=None,
image_name=None,
**kwargs
):
@@ -69,6 +69,13 @@ def __init__(
file which should be executed as the entry point to training.
If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
framework_version (str): PyTorch version you want to use for
executing your model training code. Defaults to ``None``. Required unless
``image_name`` is provided. List of supported versions:
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
py_version (str): Python version you want to use for executing your
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
unless ``image_name`` is provided.
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
with any other training source code dependencies aside from the entry
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -80,12 +87,6 @@ def __init__(
SageMaker. For convenience, this accepts other types for keys
and values, but ``str()`` will be called to convert them before
training.
py_version (str): Python version you want to use for executing your
model training code (default: 'py3'). One of 'py2' or 'py3'.
framework_version (str): PyTorch version you want to use for
executing your model training code. List of supported versions
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
If not specified, this will default to 0.4.
image_name (str): If specified, the estimator will use this image
for training and hosting, instead of selecting the appropriate
SageMaker official image based on framework_version and
@@ -95,6 +96,9 @@ def __init__(
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

If ``framework_version`` or ``py_version`` is ``None``, then
``image_name`` is required. If ``image_name`` is also ``None``, then a
``ValueError`` will be raised.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
constructor.

@@ -104,28 +108,25 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
validate_version_or_image_args(framework_version, py_version, image_name)
if py_version == "py2":
logger.warning(
empty_framework_version_warning(defaults.PYTORCH_VERSION, self.LATEST_VERSION)
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version or defaults.PYTORCH_VERSION
self.framework_version = framework_version
self.py_version = py_version

if "enable_sagemaker_metrics" not in kwargs:
# enable sagemaker metrics for PT v1.3 or greater:
if is_version_equal_or_higher([1, 3], self.framework_version):
if self.framework_version and is_version_equal_or_higher(
[1, 3], self.framework_version
):
kwargs["enable_sagemaker_metrics"] = True

super(PyTorch, self).__init__(
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
)

if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)

self.py_version = py_version

def create_model(
self,
model_server_workers=None,
@@ -177,12 +178,12 @@ def create_model(
self.model_data,
role or self.role,
entry_point or self.entry_point,
framework_version=self.framework_version,
py_version=self.py_version,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -210,15 +211,19 @@ class constructor
image_name = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_name)

if tag is None:
framework_version = None
else:
framework_version = framework_version_from_tag(tag)
init_params["framework_version"] = framework_version
init_params["py_version"] = py_version

if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params["image_name"] = image_name
return init_params

init_params["py_version"] = py_version
init_params["framework_version"] = framework_version_from_tag(tag)

training_job_name = init_params["base_job_name"]

if framework != cls.__framework_name__:
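Both constructors now delegate this argument check to ``validate_version_or_image_args`` from ``sagemaker.fw_utils``. A rough sketch of the assumed behavior, inferred from the docstrings above rather than from the helper's actual implementation:

def validate_version_or_image_args(framework_version, py_version, image):
    # Assumed logic: explicit versions are required unless a custom image is given.
    if (framework_version is None or py_version is None) and image is None:
        raise ValueError(
            "framework_version or py_version was None, and no image was provided. "
            "Either specify both framework_version and py_version, or specify an image."
        )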
38 changes: 19 additions & 19 deletions src/sagemaker/pytorch/model.py
@@ -21,7 +21,7 @@
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
validate_version_or_image_args,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch import defaults
@@ -66,9 +66,9 @@ def __init__(
model_data,
role,
entry_point,
image=None,
py_version=defaults.PYTHON_VERSION,
framework_version=None,
py_version=None,
image=None,
predictor_cls=PyTorchPredictor,
model_server_workers=None,
**kwargs
@@ -87,12 +87,16 @@ def __init__(
file which should be executed as the entry point to model
hosting. If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
image (str): A Docker image URI (default: None). If not specified, a
default image for PyTorch will be used.
py_version (str): Python version you want to use for executing your
model training code (default: 'py3').
framework_version (str): PyTorch version you want to use for
executing your model training code.
executing your model training code. Defaults to None. Required
unless ``image`` is provided.
py_version (str): Python version you want to use for executing your
model training code. Defaults to ``None``. Required unless
``image`` is provided.
image (str): A Docker image URI (default: None). If not specified, a
default image for PyTorch will be used. If ``framework_version``
or ``py_version`` is ``None``, then ``image`` is required. If ``image``
is also ``None``, then a ``ValueError`` will be raised.
predictor_cls (callable[str, sagemaker.session.Session]): A function
to call to create a predictor with an endpoint name and
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -109,22 +113,18 @@ def __init__(
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(PyTorchModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

if py_version == "py2":
validate_version_or_image_args(framework_version, py_version, image)
if py_version and py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version

if framework_version is None:
logger.warning(
empty_framework_version_warning(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION)
)
super(PyTorchModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

self.py_version = py_version
self.framework_version = framework_version or defaults.PYTORCH_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
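For context, a hedged sketch of building and deploying a PyTorchModel under the new requirements; the S3 path, entry point, and instance type are placeholders:

from sagemaker.pytorch import PyTorchModel

model = PyTorchModel(
    model_data='s3://my-bucket/model.tar.gz',
    role='SageMakerRole',
    entry_point='inference.py',
    framework_version='1.5.0',  # required now that there is no default version
    py_version='py3',           # required unless `image` is given
)
predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')

# With a custom serving image, the versions may be omitted:
byoc_model = PyTorchModel(
    model_data='s3://my-bucket/model.tar.gz',
    role='SageMakerRole',
    entry_point='inference.py',
    image='custom-image:latest',
)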
5 changes: 5 additions & 0 deletions tests/conftest.py
@@ -173,6 +173,11 @@ def pytorch_version(request):
return request.param


@pytest.fixture(scope="module", params=["py2", "py3"])
def pytorch_py_version(request):
return request.param


@pytest.fixture(scope="module", params=["0.20.0"])
def sklearn_version(request):
return request.param
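A short sketch of how a test could consume the new fixture (the test name is hypothetical):

def test_runs_for_each_python_version(pytorch_py_version):
    # pytest invokes this once per fixture param, i.e. once for "py2" and once for "py3".
    assert pytorch_py_version in ("py2", "py3")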
6 changes: 4 additions & 2 deletions tests/integ/test_airflow_config.py
@@ -608,13 +608,14 @@ def test_xgboost_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu

@pytest.mark.canary_quick
def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provided(
sagemaker_session, cpu_instance_type
sagemaker_session, cpu_instance_type, pytorch_full_version
):
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
estimator = PyTorch(
entry_point=PYTORCH_MNIST_SCRIPT,
role=ROLE,
framework_version="1.1.0",
framework_version=pytorch_full_version,
py_version="py3",
train_instance_count=2,
train_instance_type=cpu_instance_type,
hyperparameters={"epochs": 6, "backend": "gloo"},
@@ -639,6 +640,7 @@ def test_pytorch_12_airflow_config_uploads_data_source_to_s3_when_inputs_not_pro
entry_point=PYTORCH_MNIST_SCRIPT,
role=ROLE,
framework_version="1.2.0",
py_version="py3",
train_instance_count=2,
train_instance_type=cpu_instance_type,
hyperparameters={"epochs": 6, "backend": "gloo"},
6 changes: 4 additions & 2 deletions tests/integ/test_git.py
@@ -21,7 +21,6 @@

from tests.integ import lock as lock
from sagemaker.mxnet.estimator import MXNet
from sagemaker.pytorch.defaults import PYTORCH_VERSION
from sagemaker.pytorch.estimator import PyTorch
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.sklearn.model import SKLearnModel
@@ -56,11 +55,14 @@ def test_github(sagemaker_local_session):
script_path = "mnist.py"
data_path = os.path.join(DATA_DIR, "pytorch_mnist")
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}

# TODO: fails for newer pytorch versions when using MNIST from torchvision due to missing dataset
# "algo-1-v767u_1 | RuntimeError: Dataset not found. You can use download=True to download it"
pytorch = PyTorch(
entry_point=script_path,
role="SageMakerRole",
source_dir="pytorch",
framework_version=PYTORCH_VERSION,
framework_version="0.4", # hard-code to last known good pytorch for now (see TODO above)
py_version=PYTHON_VERSION,
train_instance_count=1,
train_instance_type="local",
20 changes: 9 additions & 11 deletions tests/integ/test_pytorch_train.py
@@ -98,11 +98,9 @@ def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
predictor.delete_endpoint()


@pytest.mark.skipif(
PYTHON_VERSION == "py2",
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
)
def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type):
def test_deploy_model(
pytorch_training_job, sagemaker_session, cpu_instance_type, pytorch_full_version
):
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
@@ -114,6 +112,8 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
model_data,
"SageMakerRole",
entry_point=MNIST_SCRIPT,
framework_version=pytorch_full_version,
py_version="py3",
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
Expand All @@ -125,10 +125,6 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
assert output.shape == (batch_size, 10)


@pytest.mark.skipif(
PYTHON_VERSION == "py2",
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
)
def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instance_type):
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())

@@ -139,6 +135,7 @@ def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instan
"SageMakerRole",
entry_point="mnist.py",
framework_version="1.4.0",
py_version="py3",
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
@@ -160,8 +157,9 @@ def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
pytorch = PyTorchModel(
model_data,
"SageMakerRole",
framework_version="1.3.1",
entry_point=EIA_SCRIPT,
framework_version="1.3.1",
py_version="py3",
sagemaker_session=sagemaker_session,
)
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
@@ -193,7 +191,7 @@ def _get_pytorch_estimator(
entry_point=entry_point,
role="SageMakerRole",
framework_version=pytorch_full_version,
py_version=PYTHON_VERSION,
py_version="py3",
train_instance_count=1,
train_instance_type=instance_type,
sagemaker_session=sagemaker_session,
3 changes: 3 additions & 0 deletions tests/integ/test_source_dirs.py
@@ -30,11 +30,14 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
with open(lib, "w") as f:
f.write("def question(to_anything): return 42")

# TODO: fails on newer versions of pytorch in call to np.load(BytesIO(stream.read()))
# "ValueError: Cannot load file containing pickled data when allow_pickle=False"
estimator = PyTorch(
entry_point="train.py",
role="SageMakerRole",
source_dir=source_dir,
dependencies=[lib],
framework_version="0.4", # hard-code to last known good pytorch for now (see TODO above)
py_version=PYTHON_VERSION,
train_instance_count=1,
train_instance_type="local",
2 changes: 1 addition & 1 deletion tests/integ/test_transformer.py
@@ -178,7 +178,7 @@ def test_transform_pytorch_vpc_custom_model_bucket(
entry_point=os.path.join(data_dir, "mnist.py"),
role="SageMakerRole",
framework_version=pytorch_full_version,
py_version=PYTHON_VERSION,
py_version="py3",
sagemaker_session=sagemaker_session,
vpc_config={"Subnets": subnet_ids, "SecurityGroupIds": [security_group_id]},
code_location="s3://{}".format(custom_bucket_name),
5 changes: 3 additions & 2 deletions tests/integ/test_tuner.py
@@ -819,15 +819,16 @@ def test_tuning_chainer(sagemaker_session, cpu_instance_type):
reason="This test has always failed, but the failure was masked by a bug. "
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
)
def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type):
def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type, pytorch_full_version):
mnist_dir = os.path.join(DATA_DIR, "pytorch_mnist")
mnist_script = os.path.join(mnist_dir, "mnist.py")

estimator = PyTorch(
entry_point=mnist_script,
role="SageMakerRole",
train_instance_count=1,
py_version=PYTHON_VERSION,
framework_version=pytorch_full_version,
py_version="py3",
train_instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
)