Skip to content

Commit b94ff3a

Browse files
Merge pull request #13 from NiklasGustafsson/main
Update examples and tutorials to 0.95.4
2 parents 2166036 + 0b3f52b commit b94ff3a

26 files changed

+498
-238
lines changed

src/CSharp/CSharpExamples/AdversarialExampleGeneration.cs

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -148,24 +148,26 @@ private static double Test(
148148

149149
foreach (var (data, target) in dataLoader) {
150150

151-
data.requires_grad = true;
151+
using (var d = torch.NewDisposeScope())
152+
{
153+
data.requires_grad = true;
152154

153-
using (var output = model.forward(data))
154-
using (var loss = criterion(output, target)) {
155+
using (var output = model.forward(data))
156+
using (var loss = criterion(output, target))
157+
{
155158

156-
model.zero_grad();
157-
loss.backward();
159+
model.zero_grad();
160+
loss.backward();
158161

159-
var perturbed = Attack(data, ε, data.grad());
162+
var perturbed = Attack(data, ε, data.grad());
160163

161-
using (var final = model.forward(perturbed)) {
164+
using (var final = model.forward(perturbed))
165+
{
162166

163-
correct += final.argmax(1).eq(target).sum().ToInt32();
167+
correct += final.argmax(1).eq(target).sum().ToInt32();
168+
}
164169
}
165170
}
166-
167-
168-
GC.Collect();
169171
}
170172

171173
return (double)correct / size;

src/CSharp/CSharpExamples/CIFAR10.cs

Lines changed: 62 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,8 @@ internal static void Run(int epochs, int timeout, string modelName)
4949
torch.cuda.is_available() ? torch.CUDA :
5050
torch.CPU;
5151

52-
if (device.type == DeviceType.CUDA) {
52+
if (device.type == DeviceType.CUDA)
53+
{
5354
_trainBatchSize *= 8;
5455
_testBatchSize *= 8;
5556
}
@@ -61,7 +62,8 @@ internal static void Run(int epochs, int timeout, string modelName)
6162
var sourceDir = _dataLocation;
6263
var targetDir = Path.Combine(_dataLocation, "test_data");
6364

64-
if (!Directory.Exists(targetDir)) {
65+
if (!Directory.Exists(targetDir))
66+
{
6567
Directory.CreateDirectory(targetDir);
6668
Decompress.ExtractTGZ(Path.Combine(sourceDir, "cifar-10-binary.tar.gz"), targetDir);
6769
}
@@ -70,40 +72,41 @@ internal static void Run(int epochs, int timeout, string modelName)
7072

7173
Module model = null;
7274

73-
switch (modelName.ToLower()) {
74-
case "alexnet":
75-
model = new AlexNet(modelName, _numClasses, device);
76-
break;
77-
case "mobilenet":
78-
model = new MobileNet(modelName, _numClasses, device);
79-
break;
80-
case "vgg11":
81-
case "vgg13":
82-
case "vgg16":
83-
case "vgg19":
84-
model = new VGG(modelName, _numClasses, device);
85-
break;
86-
case "resnet18":
87-
model = ResNet.ResNet18(_numClasses, device);
88-
break;
89-
case "resnet34":
90-
_testBatchSize /= 4;
91-
model = ResNet.ResNet34(_numClasses, device);
92-
break;
93-
case "resnet50":
94-
_trainBatchSize /= 6;
95-
_testBatchSize /= 8;
96-
model = ResNet.ResNet50(_numClasses, device);
97-
break;
98-
case "resnet101":
99-
_trainBatchSize /= 6;
100-
_testBatchSize /= 8;
101-
model = ResNet.ResNet101(_numClasses, device);
102-
break;
103-
case "resnet152":
104-
_testBatchSize /= 4;
105-
model = ResNet.ResNet152(_numClasses, device);
106-
break;
75+
switch (modelName.ToLower())
76+
{
77+
case "alexnet":
78+
model = new AlexNet(modelName, _numClasses, device);
79+
break;
80+
case "mobilenet":
81+
model = new MobileNet(modelName, _numClasses, device);
82+
break;
83+
case "vgg11":
84+
case "vgg13":
85+
case "vgg16":
86+
case "vgg19":
87+
model = new VGG(modelName, _numClasses, device);
88+
break;
89+
case "resnet18":
90+
model = ResNet.ResNet18(_numClasses, device);
91+
break;
92+
case "resnet34":
93+
_testBatchSize /= 4;
94+
model = ResNet.ResNet34(_numClasses, device);
95+
break;
96+
case "resnet50":
97+
_trainBatchSize /= 6;
98+
_testBatchSize /= 8;
99+
model = ResNet.ResNet50(_numClasses, device);
100+
break;
101+
case "resnet101":
102+
_trainBatchSize /= 6;
103+
_testBatchSize /= 8;
104+
model = ResNet.ResNet101(_numClasses, device);
105+
break;
106+
case "resnet152":
107+
_testBatchSize /= 4;
108+
model = ResNet.ResNet152(_numClasses, device);
109+
break;
107110
}
108111

109112
var hflip = transforms.HorizontalFlip();
@@ -116,19 +119,20 @@ internal static void Run(int epochs, int timeout, string modelName)
116119

117120
using (var train = new CIFARReader(targetDir, false, _trainBatchSize, shuffle: true, device: device, transforms: new ITransform[] { }))
118121
using (var test = new CIFARReader(targetDir, true, _testBatchSize, device: device))
119-
using (var optimizer = torch.optim.Adam(model.parameters(), 0.001)) {
122+
using (var optimizer = torch.optim.Adam(model.parameters(), 0.001))
123+
{
120124

121125
Stopwatch totalSW = new Stopwatch();
122126
totalSW.Start();
123127

124-
for (var epoch = 1; epoch <= epochs; epoch++) {
128+
for (var epoch = 1; epoch <= epochs; epoch++)
129+
{
125130

126131
Stopwatch epchSW = new Stopwatch();
127132
epchSW.Start();
128133

129134
Train(model, optimizer, nll_loss(), train.Data(), epoch, _trainBatchSize, train.Size);
130135
Test(model, nll_loss(), test.Data(), test.Size);
131-
GC.Collect();
132136

133137
epchSW.Stop();
134138
Console.WriteLine($"Elapsed time for this epoch: {epchSW.Elapsed.TotalSeconds} s.");
@@ -160,35 +164,33 @@ private static void Train(
160164

161165
Console.WriteLine($"Epoch: {epoch}...");
162166

163-
foreach (var (data, target) in dataLoader) {
167+
foreach (var (data, target) in dataLoader)
168+
{
164169

165-
optimizer.zero_grad();
170+
using (var d = torch.NewDisposeScope())
171+
{
172+
optimizer.zero_grad();
166173

167-
using var prediction = model.forward(data);
168-
using var lsm = log_softmax(prediction, 1);
169-
using (var output = loss(lsm, target)) {
174+
var prediction = model.forward(data);
175+
var lsm = log_softmax(prediction, 1);
176+
var output = loss(lsm, target);
170177

171178
output.backward();
172179

173180
optimizer.step();
174181

175182
total += target.shape[0];
176183

177-
using (var predicted = prediction.argmax(1))
178-
using (var eq = predicted.eq(target))
179-
using (var sum = eq.sum()) {
180-
correct += sum.ToInt64();
181-
}
184+
correct += prediction.argmax(1).eq(target).ToInt64();
182185

183-
if (batchId % _logInterval == 0) {
186+
if (batchId % _logInterval == 0)
187+
{
184188
var count = Math.Min(batchId * batchSize, size);
185189
Console.WriteLine($"\rTrain: epoch {epoch} [{count} / {size}] Loss: {output.ToSingle().ToString("0.000000")} | Accuracy: { ((float)correct / total).ToString("0.000000") }");
186190
}
187191

188192
batchId++;
189193
}
190-
191-
GC.Collect();
192194
}
193195
}
194196

@@ -204,23 +206,20 @@ private static void Test(
204206
long correct = 0;
205207
int batchCount = 0;
206208

207-
foreach (var (data, target) in dataLoader) {
209+
foreach (var (data, target) in dataLoader)
210+
{
208211

209-
using var prediction = model.forward(data);
210-
using var lsm = log_softmax(prediction, 1);
211-
using (var output = loss(lsm, target)) {
212+
using (var d = torch.NewDisposeScope())
213+
{
214+
var prediction = model.forward(data);
215+
var lsm = log_softmax(prediction, 1);
216+
var output = loss(lsm, target);
212217

213218
testLoss += output.ToSingle();
214219
batchCount += 1;
215220

216-
using (var predicted = prediction.argmax(1))
217-
using (var eq = predicted.eq(target))
218-
using (var sum = eq.sum()) {
219-
correct += sum.ToInt64();
220-
}
221+
correct += prediction.argmax(1).eq(target).ToInt64();
221222
}
222-
223-
GC.Collect();
224223
}
225224

226225
Console.WriteLine($"\rTest set: Average loss {(testLoss / batchCount).ToString("0.0000")} | Accuracy {((float)correct / size).ToString("0.0000")}");

src/CSharp/CSharpExamples/CSharpExamples.csproj

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
</ItemGroup>
1818

1919
<ItemGroup>
20-
<PackageReference Include="TorchSharp-cpu" Version="0.95.3" />
20+
<PackageReference Include="TorchSharp-cpu" Version="0.95.4" />
2121
</ItemGroup>
2222

2323
<ItemGroup>

src/CSharp/CSharpExamples/MNIST.cs

Lines changed: 35 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,8 @@ internal static void Run(int epochs, int timeout, string dataset)
4747
{
4848
_epochs = epochs;
4949

50-
if (string.IsNullOrEmpty(dataset)) {
50+
if (string.IsNullOrEmpty(dataset))
51+
{
5152
dataset = "mnist";
5253
}
5354

@@ -67,15 +68,17 @@ internal static void Run(int epochs, int timeout, string dataset)
6768
var sourceDir = datasetPath;
6869
var targetDir = Path.Combine(datasetPath, "test_data");
6970

70-
if (!Directory.Exists(targetDir)) {
71+
if (!Directory.Exists(targetDir))
72+
{
7173
Directory.CreateDirectory(targetDir);
7274
Decompress.DecompressGZipFile(Path.Combine(sourceDir, "train-images-idx3-ubyte.gz"), targetDir);
7375
Decompress.DecompressGZipFile(Path.Combine(sourceDir, "train-labels-idx1-ubyte.gz"), targetDir);
7476
Decompress.DecompressGZipFile(Path.Combine(sourceDir, "t10k-images-idx3-ubyte.gz"), targetDir);
7577
Decompress.DecompressGZipFile(Path.Combine(sourceDir, "t10k-labels-idx1-ubyte.gz"), targetDir);
7678
}
7779

78-
if (device.type == DeviceType.CUDA) {
80+
if (device.type == DeviceType.CUDA)
81+
{
7982
_trainBatchSize *= 4;
8083
_testBatchSize *= 4;
8184
}
@@ -90,7 +93,8 @@ internal static void Run(int epochs, int timeout, string dataset)
9093
Console.WriteLine();
9194

9295
using (MNISTReader train = new MNISTReader(targetDir, "train", _trainBatchSize, device: device, shuffle: true, transform: normImage),
93-
test = new MNISTReader(targetDir, "t10k", _testBatchSize, device: device, transform: normImage)) {
96+
test = new MNISTReader(targetDir, "t10k", _testBatchSize, device: device, transform: normImage))
97+
{
9498

9599
TrainingLoop(dataset, timeout, device, model, train, test);
96100
}
@@ -105,7 +109,8 @@ internal static void TrainingLoop(string dataset, int timeout, Device device, Mo
105109
Stopwatch totalTime = new Stopwatch();
106110
totalTime.Start();
107111

108-
for (var epoch = 1; epoch <= _epochs; epoch++) {
112+
for (var epoch = 1; epoch <= _epochs; epoch++)
113+
{
109114

110115
Train(model, optimizer, nll_loss(reduction: Reduction.Mean), device, train, epoch, train.BatchSize, train.Size);
111116
Test(model, nll_loss(reduction: nn.Reduction.Sum), device, test, test.Size);
@@ -137,23 +142,28 @@ private static void Train(
137142
int batchId = 1;
138143

139144
Console.WriteLine($"Epoch: {epoch}...");
140-
foreach (var (data, target) in dataLoader) {
141-
optimizer.zero_grad();
142145

143-
var prediction = model.forward(data);
144-
var output = loss(prediction, target);
146+
foreach (var (data, target) in dataLoader)
147+
{
148+
using (var d = torch.NewDisposeScope())
149+
{
150+
optimizer.zero_grad();
145151

146-
output.backward();
152+
var prediction = model.forward(data);
153+
var output = loss(prediction, target);
147154

148-
optimizer.step();
155+
output.backward();
149156

150-
if (batchId % _logInterval == 0) {
151-
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.ToSingle():F4}");
152-
}
157+
optimizer.step();
158+
159+
if (batchId % _logInterval == 0)
160+
{
161+
Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.ToSingle():F4}");
162+
}
153163

154-
batchId++;
164+
batchId++;
155165

156-
GC.Collect();
166+
}
157167
}
158168
}
159169

@@ -169,17 +179,16 @@ private static void Test(
169179
double testLoss = 0;
170180
int correct = 0;
171181

172-
foreach (var (data, target) in dataLoader) {
173-
var prediction = model.forward(data);
174-
var output = loss(prediction, target);
175-
testLoss += output.ToSingle();
182+
foreach (var (data, target) in dataLoader)
183+
{
184+
using (var d = torch.NewDisposeScope())
185+
{
186+
var prediction = model.forward(data);
187+
var output = loss(prediction, target);
188+
testLoss += output.ToSingle();
176189

177-
var pred = prediction.argmax(1);
178-
correct += pred.eq(target).sum().ToInt32();
179-
180-
pred.Dispose();
181-
182-
GC.Collect();
190+
correct += prediction.argmax(1).eq(target).sum().ToInt32();
191+
}
183192
}
184193

185194
Console.WriteLine($"Size: {size}, Total: {size}");

0 commit comments

Comments
 (0)