Skip to content
This repository was archived by the owner on Dec 17, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions handwritten_digit_recog/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
MAKE := $(MAKE) -e
DOCKER_BUILD_CMD = docker build

all:
$(MAKE) train serve website

train:
$(DOCKER_BUILD_CMD) --build-arg stage=train -f docker/Dockerfile.$(TYPE) -t $(REGISTRY)/dlrs-train-$(TYPE) .

serve:
$(DOCKER_BUILD_CMD) --build-arg stage=serve -f docker/Dockerfile.$(TYPE) -t $(REGISTRY)/dlrs-serve-$(TYPE) .

website:
$(DOCKER_BUILD_CMD) -f docker/Dockerfile.website -t $(REGISTRY)/dlrs-website .

.PHONY: all train serve website
50 changes: 50 additions & 0 deletions handwritten_digit_recog/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Handwritten Digit Recognition with DLRS

The content in this directory is for running a handwritten digit recognition example using the Deep Learning Reference Stack, Pytorch and MNIST.

#### Pre-requisites

* Docker

## Running on containers

Please follow these instructions to train the model and classify random handwritten digits on DLRS based Docker containers.

### Set up

Set TYPE and REGISTRY env variables
TYPE options: mkl or oss
REGISTRY options: registry name

```bash
export TYPE=<oss or mkl>
export REGISTRY=<your registry>
make
```

### Train

```bash
mkdir models
docker run --rm -ti -v ${PWD}/models:/workdir/models $REGISTRY/dlrs-train-$TYPE:latest "-s train"
```

### Serving the model for live classification

```bash
docker run -p 5059:5059 -it -v ${PWD}/models:/workdir/models $REGISTRY/dlrs-serve-$TYPE:latest "-s serve"
curl -i -X POST -d 'Classify' http://localhost:5059/digit_recognition/classify
```

### Website

We have created a simple website template for you to interact with.

```bash
docker run --rm -p 8080:5000 -it $REGISTRY/dlrs-website:latest --website_endpoint 0.0.0.0
Open localhost:8080 on a web browser
```

## Running on Kubeflow pipelines

We have created a Kubeflow Pipeline to run this example. Please go to [Kubeflow Pipelines](https://github.com/intel/stacks-usecase/kubeflow/pipelines) for more details.
15 changes: 15 additions & 0 deletions handwritten_digit_recog/docker/Dockerfile.mkl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
FROM clearlinux/stacks-pytorch-mkl:v0.5.0

ARG stage

ENV PATH=$PATH:/opt/conda/bin/ \
LD_LIBRARY_PATH=/usr/lib64:/opt/conda/lib \
STAGE=$stage

WORKDIR /workdir
COPY python/ python/
COPY scripts/entrypoint.sh .

RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]
14 changes: 14 additions & 0 deletions handwritten_digit_recog/docker/Dockerfile.oss
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
FROM clearlinux/stacks-pytorch-oss:v0.5.0

ARG stage
ENV STAGE=$stage

WORKDIR /workdir
COPY python/ python/
COPY scripts/entrypoint.sh .

RUN chmod +x entrypoint.sh

EXPOSE 5059

ENTRYPOINT ["./entrypoint.sh"]
9 changes: 9 additions & 0 deletions handwritten_digit_recog/docker/Dockerfile.website
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
FROM clearlinux/stacks-pytorch-oss:v0.5.0

WORKDIR /workdir/website
COPY website/ /workdir/website/

EXPOSE 5000

SHELL ["/bin/bash", "-c"]
ENTRYPOINT ["python", "app.py"]
63 changes: 63 additions & 0 deletions handwritten_digit_recog/python/classify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

from train import Net

# Load pre-trained model
model_path = "/workdir/models/mnist_cnn.pt"
device = torch.device("cpu")
model = Net().to(device)
model.load_state_dict(torch.load(model_path))

# Use a transform to normalize data (same as in training)
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Load training data
kwargs = {"num_workers": 1, "pin_memory": True}
test_loader = datasets.MNIST("data", download=True, train=False, transform=transform)
data_loader = torch.utils.data.DataLoader(test_loader, batch_size=64, shuffle=True)


def img_show(img, ps, probab):
ps = ps.data.numpy().squeeze()
fig, (ax1, ax2) = plt.subplots(figsize=(5, 3), ncols=2)
ax1.imshow(img.resize_(1, 28, 28).numpy().squeeze(), cmap="gray_r")
ax1.axis("off")
ax1.set_title("Random Test Image")
ax1.text(
5,
30,
"Predicted value: %s" % probab.index(max(probab)),
fontsize=10,
bbox={"facecolor": "cornsilk", "boxstyle": "round", "alpha": 0.5},
)
ax2.barh(np.arange(10), ps, color="gold")
ax2.set_aspect(0.1)
ax2.set_yticks(np.arange(10))
ax2.set_yticklabels(np.arange(10))
ax2.set_title("Probability Chart")
ax2.set_xlim(0, 1.1)
plt.grid(True)
plt.tight_layout()
plt.show()


# Function for classifying random handwritten numbers from the MNIST dataset
def classify(imgshow=False):
images, labels = next(iter(data_loader))
img = images[0].view(1, 1, 28, 28)
with torch.no_grad():
logps = model(img)
ps = torch.exp(logps)
probab = list(ps.numpy()[0])
if imgshow:
img_show(img, ps, probab)
return img, probab


if __name__ == "__main__":
classify(imgshow=True)
52 changes: 52 additions & 0 deletions handwritten_digit_recog/python/serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!usr/bin/env python
#
# Copyright (c) 2019 Intel Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A REST API for Pytorch MNIST Handwritten Digit Recognition"""

import flask

from classify import classify

app = flask.Flask("Handwritten Digit Recognition")

banner = {
"what": "Handwritten Digit Recognition",
"usage": {
"Client": "curl -i -X POST -d 'Classify' http://localhost:5059/digit_recognition/classify",
"Server": "docker run -d -p 5059:5059 stacks_handwritten_digit_recog",
},
}


@app.route("/digit_recognition/", methods=["GET"])
def index():
return flask.jsonify(banner), 201


@app.route("/digit_recognition/classify", methods=["POST"])
def digit_recog():
img, probab = classify(imgshow=False)
return flask.jsonify({"Prediction": probab.index(max(probab))}), 201


@app.errorhandler(404)
def not_found(error):
return flask.make_response(flask.jsonify({"error": "Not found"}), 404)


if __name__ == "__main__":
app.run(host="0.0.0.0", port=5059)
Loading