This repository contains a basic image classification model implemented in PyTorch.
The model uses the ResNet-101 architecture and is trained on the ImageNet dataset.
For more information on the implementation of this model, visit PyTorch’s Torchvision Models page.
To run the main.py script:
- Install Python 3.6 or greater.
- Install the version of PyTorch that matches your hardware (e.g., install the CPU-only build if you do not use a GPU).
- Install the Pillow imaging library.
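Before continuing, you can confirm that these requirements are met. The script below is a minimal sketch and is not part of this repository. It assumes the import names torch (for PyTorch) and PIL (for Pillow), and it only checks that the packages can be found. It does not check which version or build (CPU or GPU) is installed.

```python
import importlib.util
import sys

# Minimum Python version stated in this README.
MIN_PYTHON = (3, 6)

# Import names for the listed packages: PyTorch -> "torch", Pillow -> "PIL".
REQUIRED_MODULES = ("torch", "PIL")


def check_environment():
    """Return a dict mapping each requirement to True (found) or False (missing)."""
    report = {"python": sys.version_info[:2] >= MIN_PYTHON}
    for name in REQUIRED_MODULES:
        # find_spec returns None when a top-level module cannot be located.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for requirement, ok in check_environment().items():
        print("{:<8} {}".format(requirement, "OK" if ok else "MISSING"))
```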
Clone the repository:
Once your environment is set up and the requirements are installed, run the download_weights.py script to download the weights from the PyTorch Torchvision models library:
python download_weights.py
Run the main.py script, passing the image_name and labels arguments:
python main.py dog.jpg imagenet_classes.txt
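Both arguments are positional: the image path comes first, followed by the labels file. The sketch below shows one way a script like main.py could read them with argparse. It is a hypothetical illustration, and the repository's actual argument handling may differ.

```python
import argparse


def parse_args(argv=None):
    """Parse the two positional arguments (hypothetical sketch of main.py's CLI)."""
    parser = argparse.ArgumentParser(
        description="Classify an image with a pretrained ResNet-101."
    )
    # Order matters: image_name first, then labels.
    parser.add_argument("image_name", help="path to the image to classify, e.g. dog.jpg")
    parser.add_argument("labels", help="text file of class names, e.g. imagenet_classes.txt")
    # argv=None makes argparse read sys.argv[1:]; pass a list to parse explicitly.
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    print(args.image_name, args.labels)
```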
The imagenet_classes.txt file lists all ImageNet classes, and the ./data directory contains a sample image.
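After the model produces one score (logit) per class, those scores are typically converted to probabilities with a softmax and matched to the labels file by index. The pure-Python sketch below illustrates that step. It relies on two assumptions: imagenet_classes.txt contains one class name per line, and line i names class index i. The helper names load_labels, softmax, and top_k are hypothetical and do not come from main.py.

```python
import math


def load_labels(path):
    """Read one class name per line; line i is assumed to name class index i."""
    with open(path, encoding="utf-8") as f:
        labels = [line.strip() for line in f]
    # Drop blank lines at the end of the file without shifting earlier indices.
    while labels and not labels[-1]:
        labels.pop()
    return labels


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels[i], probs[i]) for i in ranked[:k]]
```

With PyTorch, the logits for a single image could be taken from the model output as a Python list with output[0].tolist(). In practice, torch.softmax and torch.topk perform the same computation directly on tensors.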