Commit e9cc491

Add simple PyTorch example
1 parent e1cfdf2 commit e9cc491

File tree

2 files changed: +161 -0 lines changed

examples/PyTorch/main.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# Taken from https://github.com/pytorch/examples/blob/main/mnist/main.py
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from simvue import Run


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch, run):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            # Log the batch loss to Simvue (one metric series per epoch, e.g. train.loss.1)
            run.log_metrics({"train.loss.%d" % epoch: float(loss.item())}, step=batch_idx)
            if args.dry_run:
                break


def test(model, device, test_loader, epoch, run):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        test_accuracy))
    # Log per-epoch test loss and accuracy to Simvue
    run.log_metrics({'test.loss': test_loss,
                     'test.accuracy': test_accuracy}, step=epoch)


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    # Start a Simvue run to track this training session
    run = Run()
    run.init(tags=['PyTorch'])

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch, run)
        test(model, device, test_loader, epoch, run)
        scheduler.step()

    if args.save_model:
        # Save the trained model state with the Simvue run as an output artifact
        run.save(model.state_dict(), "output", name="mnist_cnn.pt")

    # Mark the Simvue run as complete
    run.close()


if __name__ == '__main__':
    main()
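
All of the Simvue-specific code above amounts to a handful of calls on a Run object: construct it, init() before training starts, log_metrics() inside the train and test loops, optionally save() for the trained weights, and close() at the end. As a minimal standalone sketch of that lifecycle (assuming the Simvue client is already configured with a server URL and token; the metric name here is illustrative, not one used by the example):

    # Minimal sketch of the Simvue run lifecycle used in main.py above.
    # Assumes client credentials are configured; 'demo.loss' is an
    # illustrative metric name.
    from simvue import Run

    run = Run()
    run.init(tags=['PyTorch'])                   # create and start the tracked run
    run.log_metrics({'demo.loss': 0.5}, step=0)  # record one time-series point
    run.close()                                  # flush and mark the run complete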

examples/PyTorch/requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
torch
torchvision
simvue
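
To try the example, the usual workflow should be to install the three dependencies above (pip install -r examples/PyTorch/requirements.txt) and then run python main.py from examples/PyTorch; the --epochs, --log-interval, and --save-model flags defined in main() control the run length, logging frequency, and whether the weights are saved. This is a sketch of the expected usage: a Simvue server URL and token must be configured for the client before the run can be created.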
