diff --git a/src/CSharp/CSharpExamples/AdversarialExampleGeneration.cs b/src/CSharp/CSharpExamples/AdversarialExampleGeneration.cs
index d34bbea..91bc8ac 100644
--- a/src/CSharp/CSharpExamples/AdversarialExampleGeneration.cs
+++ b/src/CSharp/CSharpExamples/AdversarialExampleGeneration.cs
@@ -148,24 +148,26 @@ private static double Test(
foreach (var (data, target) in dataLoader) {
- data.requires_grad = true;
+ using (var d = torch.NewDisposeScope())
+ {
+ data.requires_grad = true;
- using (var output = model.forward(data))
- using (var loss = criterion(output, target)) {
+ using (var output = model.forward(data))
+ using (var loss = criterion(output, target))
+ {
- model.zero_grad();
- loss.backward();
+ model.zero_grad();
+ loss.backward();
- var perturbed = Attack(data, ε, data.grad());
+ var perturbed = Attack(data, ε, data.grad());
- using (var final = model.forward(perturbed)) {
+ using (var final = model.forward(perturbed))
+ {
- correct += final.argmax(1).eq(target).sum().ToInt32();
+ correct += final.argmax(1).eq(target).sum().ToInt32();
+ }
}
}
-
-
- GC.Collect();
}
return (double)correct / size;
diff --git a/src/CSharp/CSharpExamples/CIFAR10.cs b/src/CSharp/CSharpExamples/CIFAR10.cs
index 48de4f6..3162791 100644
--- a/src/CSharp/CSharpExamples/CIFAR10.cs
+++ b/src/CSharp/CSharpExamples/CIFAR10.cs
@@ -49,7 +49,8 @@ internal static void Run(int epochs, int timeout, string modelName)
torch.cuda.is_available() ? torch.CUDA :
torch.CPU;
- if (device.type == DeviceType.CUDA) {
+ if (device.type == DeviceType.CUDA)
+ {
_trainBatchSize *= 8;
_testBatchSize *= 8;
}
@@ -61,7 +62,8 @@ internal static void Run(int epochs, int timeout, string modelName)
var sourceDir = _dataLocation;
var targetDir = Path.Combine(_dataLocation, "test_data");
- if (!Directory.Exists(targetDir)) {
+ if (!Directory.Exists(targetDir))
+ {
Directory.CreateDirectory(targetDir);
Decompress.ExtractTGZ(Path.Combine(sourceDir, "cifar-10-binary.tar.gz"), targetDir);
}
@@ -70,40 +72,41 @@ internal static void Run(int epochs, int timeout, string modelName)
Module model = null;
- switch (modelName.ToLower()) {
- case "alexnet":
- model = new AlexNet(modelName, _numClasses, device);
- break;
- case "mobilenet":
- model = new MobileNet(modelName, _numClasses, device);
- break;
- case "vgg11":
- case "vgg13":
- case "vgg16":
- case "vgg19":
- model = new VGG(modelName, _numClasses, device);
- break;
- case "resnet18":
- model = ResNet.ResNet18(_numClasses, device);
- break;
- case "resnet34":
- _testBatchSize /= 4;
- model = ResNet.ResNet34(_numClasses, device);
- break;
- case "resnet50":
- _trainBatchSize /= 6;
- _testBatchSize /= 8;
- model = ResNet.ResNet50(_numClasses, device);
- break;
- case "resnet101":
- _trainBatchSize /= 6;
- _testBatchSize /= 8;
- model = ResNet.ResNet101(_numClasses, device);
- break;
- case "resnet152":
- _testBatchSize /= 4;
- model = ResNet.ResNet152(_numClasses, device);
- break;
+ switch (modelName.ToLower())
+ {
+ case "alexnet":
+ model = new AlexNet(modelName, _numClasses, device);
+ break;
+ case "mobilenet":
+ model = new MobileNet(modelName, _numClasses, device);
+ break;
+ case "vgg11":
+ case "vgg13":
+ case "vgg16":
+ case "vgg19":
+ model = new VGG(modelName, _numClasses, device);
+ break;
+ case "resnet18":
+ model = ResNet.ResNet18(_numClasses, device);
+ break;
+ case "resnet34":
+ _testBatchSize /= 4;
+ model = ResNet.ResNet34(_numClasses, device);
+ break;
+ case "resnet50":
+ _trainBatchSize /= 6;
+ _testBatchSize /= 8;
+ model = ResNet.ResNet50(_numClasses, device);
+ break;
+ case "resnet101":
+ _trainBatchSize /= 6;
+ _testBatchSize /= 8;
+ model = ResNet.ResNet101(_numClasses, device);
+ break;
+ case "resnet152":
+ _testBatchSize /= 4;
+ model = ResNet.ResNet152(_numClasses, device);
+ break;
}
var hflip = transforms.HorizontalFlip();
@@ -116,19 +119,20 @@ internal static void Run(int epochs, int timeout, string modelName)
using (var train = new CIFARReader(targetDir, false, _trainBatchSize, shuffle: true, device: device, transforms: new ITransform[] { }))
using (var test = new CIFARReader(targetDir, true, _testBatchSize, device: device))
- using (var optimizer = torch.optim.Adam(model.parameters(), 0.001)) {
+ using (var optimizer = torch.optim.Adam(model.parameters(), 0.001))
+ {
Stopwatch totalSW = new Stopwatch();
totalSW.Start();
- for (var epoch = 1; epoch <= epochs; epoch++) {
+ for (var epoch = 1; epoch <= epochs; epoch++)
+ {
Stopwatch epchSW = new Stopwatch();
epchSW.Start();
Train(model, optimizer, nll_loss(), train.Data(), epoch, _trainBatchSize, train.Size);
Test(model, nll_loss(), test.Data(), test.Size);
- GC.Collect();
epchSW.Stop();
Console.WriteLine($"Elapsed time for this epoch: {epchSW.Elapsed.TotalSeconds} s.");
@@ -160,13 +164,16 @@ private static void Train(
Console.WriteLine($"Epoch: {epoch}...");
- foreach (var (data, target) in dataLoader) {
+ foreach (var (data, target) in dataLoader)
+ {
- optimizer.zero_grad();
+ using (var d = torch.NewDisposeScope())
+ {
+ optimizer.zero_grad();
- using var prediction = model.forward(data);
- using var lsm = log_softmax(prediction, 1);
- using (var output = loss(lsm, target)) {
+ var prediction = model.forward(data);
+ var lsm = log_softmax(prediction, 1);
+ var output = loss(lsm, target);
output.backward();
@@ -174,21 +181,16 @@ private static void Train(
total += target.shape[0];
- using (var predicted = prediction.argmax(1))
- using (var eq = predicted.eq(target))
- using (var sum = eq.sum()) {
- correct += sum.ToInt64();
- }
+ correct += prediction.argmax(1).eq(target).ToInt64();
- if (batchId % _logInterval == 0) {
+ if (batchId % _logInterval == 0)
+ {
var count = Math.Min(batchId * batchSize, size);
Console.WriteLine($"\rTrain: epoch {epoch} [{count} / {size}] Loss: {output.ToSingle().ToString("0.000000")} | Accuracy: { ((float)correct / total).ToString("0.000000") }");
}
batchId++;
}
-
- GC.Collect();
}
}
@@ -204,23 +206,20 @@ private static void Test(
long correct = 0;
int batchCount = 0;
- foreach (var (data, target) in dataLoader) {
+ foreach (var (data, target) in dataLoader)
+ {
- using var prediction = model.forward(data);
- using var lsm = log_softmax(prediction, 1);
- using (var output = loss(lsm, target)) {
+ using (var d = torch.NewDisposeScope())
+ {
+ var prediction = model.forward(data);
+ var lsm = log_softmax(prediction, 1);
+ var output = loss(lsm, target);
testLoss += output.ToSingle();
batchCount += 1;
- using (var predicted = prediction.argmax(1))
- using (var eq = predicted.eq(target))
- using (var sum = eq.sum()) {
- correct += sum.ToInt64();
- }
+ correct += prediction.argmax(1).eq(target).ToInt64();
}
-
- GC.Collect();
}
Console.WriteLine($"\rTest set: Average loss {(testLoss / batchCount).ToString("0.0000")} | Accuracy {((float)correct / size).ToString("0.0000")}");
diff --git a/src/CSharp/CSharpExamples/CSharpExamples.csproj b/src/CSharp/CSharpExamples/CSharpExamples.csproj
index 5eb0500..31c26d9 100644
--- a/src/CSharp/CSharpExamples/CSharpExamples.csproj
+++ b/src/CSharp/CSharpExamples/CSharpExamples.csproj
@@ -17,7 +17,7 @@
-
+
diff --git a/src/CSharp/CSharpExamples/MNIST.cs b/src/CSharp/CSharpExamples/MNIST.cs
index 14e7fe3..cf4ddd4 100644
--- a/src/CSharp/CSharpExamples/MNIST.cs
+++ b/src/CSharp/CSharpExamples/MNIST.cs
@@ -47,7 +47,8 @@ internal static void Run(int epochs, int timeout, string dataset)
{
_epochs = epochs;
- if (string.IsNullOrEmpty(dataset)) {
+ if (string.IsNullOrEmpty(dataset))
+ {
dataset = "mnist";
}
@@ -67,7 +68,8 @@ internal static void Run(int epochs, int timeout, string dataset)
var sourceDir = datasetPath;
var targetDir = Path.Combine(datasetPath, "test_data");
- if (!Directory.Exists(targetDir)) {
+ if (!Directory.Exists(targetDir))
+ {
Directory.CreateDirectory(targetDir);
Decompress.DecompressGZipFile(Path.Combine(sourceDir, "train-images-idx3-ubyte.gz"), targetDir);
Decompress.DecompressGZipFile(Path.Combine(sourceDir, "train-labels-idx1-ubyte.gz"), targetDir);
@@ -75,7 +77,8 @@ internal static void Run(int epochs, int timeout, string dataset)
Decompress.DecompressGZipFile(Path.Combine(sourceDir, "t10k-labels-idx1-ubyte.gz"), targetDir);
}
- if (device.type == DeviceType.CUDA) {
+ if (device.type == DeviceType.CUDA)
+ {
_trainBatchSize *= 4;
_testBatchSize *= 4;
}
@@ -90,7 +93,8 @@ internal static void Run(int epochs, int timeout, string dataset)
Console.WriteLine();
using (MNISTReader train = new MNISTReader(targetDir, "train", _trainBatchSize, device: device, shuffle: true, transform: normImage),
- test = new MNISTReader(targetDir, "t10k", _testBatchSize, device: device, transform: normImage)) {
+ test = new MNISTReader(targetDir, "t10k", _testBatchSize, device: device, transform: normImage))
+ {
TrainingLoop(dataset, timeout, device, model, train, test);
}
@@ -105,7 +109,8 @@ internal static void TrainingLoop(string dataset, int timeout, Device device, Mo
Stopwatch totalTime = new Stopwatch();
totalTime.Start();
- for (var epoch = 1; epoch <= _epochs; epoch++) {
+ for (var epoch = 1; epoch <= _epochs; epoch++)
+ {
Train(model, optimizer, nll_loss(reduction: Reduction.Mean), device, train, epoch, train.BatchSize, train.Size);
Test(model, nll_loss(reduction: nn.Reduction.Sum), device, test, test.Size);
@@ -137,23 +142,28 @@ private static void Train(
int batchId = 1;
Console.WriteLine($"Epoch: {epoch}...");
- foreach (var (data, target) in dataLoader) {
- optimizer.zero_grad();
- var prediction = model.forward(data);
- var output = loss(prediction, target);
+ foreach (var (data, target) in dataLoader)
+ {
+ using (var d = torch.NewDisposeScope())
+ {
+ optimizer.zero_grad();
- output.backward();
+ var prediction = model.forward(data);
+ var output = loss(prediction, target);
- optimizer.step();
+ output.backward();
- if (batchId % _logInterval == 0) {
- Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.ToSingle():F4}");
- }
+ optimizer.step();
+
+ if (batchId % _logInterval == 0)
+ {
+ Console.WriteLine($"\rTrain: epoch {epoch} [{batchId * batchSize} / {size}] Loss: {output.ToSingle():F4}");
+ }
- batchId++;
+ batchId++;
- GC.Collect();
+ }
}
}
@@ -169,17 +179,16 @@ private static void Test(
double testLoss = 0;
int correct = 0;
- foreach (var (data, target) in dataLoader) {
- var prediction = model.forward(data);
- var output = loss(prediction, target);
- testLoss += output.ToSingle();
+ foreach (var (data, target) in dataLoader)
+ {
+ using (var d = torch.NewDisposeScope())
+ {
+ var prediction = model.forward(data);
+ var output = loss(prediction, target);
+ testLoss += output.ToSingle();
- var pred = prediction.argmax(1);
- correct += pred.eq(target).sum().ToInt32();
-
- pred.Dispose();
-
- GC.Collect();
+ correct += prediction.argmax(1).eq(target).sum().ToInt32();
+ }
}
Console.WriteLine($"Size: {size}, Total: {size}");
diff --git a/src/CSharp/CSharpExamples/SequenceToSequence.cs b/src/CSharp/CSharpExamples/SequenceToSequence.cs
index 66e38c9..039e2d6 100644
--- a/src/CSharp/CSharpExamples/SequenceToSequence.cs
+++ b/src/CSharp/CSharpExamples/SequenceToSequence.cs
@@ -63,7 +63,8 @@ internal static void Run(int epochs, int timeout)
var tokenizer = TorchText.Data.Utils.get_tokenizer("basic_english");
var counter = new TorchText.Vocab.Counter();
- foreach (var item in vocab_iter) {
+ foreach (var item in vocab_iter)
+ {
counter.update(tokenizer(item));
}
@@ -91,7 +92,8 @@ internal static void Run(int epochs, int timeout)
var totalTime = new Stopwatch();
totalTime.Start();
- foreach (var epoch in Enumerable.Range(1, epochs)) {
+ foreach (var epoch in Enumerable.Range(1, epochs))
+ {
var sw = new Stopwatch();
sw.Start();
@@ -101,7 +103,7 @@ internal static void Run(int epochs, int timeout)
var val_loss = evaluate(valid_data, model, loss, bptt, ntokens, optimizer);
sw.Stop();
- Console.WriteLine($"\nEnd of epoch: {epoch} | lr: {scheduler.LearningRate:0.00} | time: {sw.Elapsed.TotalSeconds:0.0}s | loss: {val_loss:0.00}\n");
+ Console.WriteLine($"\nEnd of epoch: {epoch} | lr: {optimizer.LearningRate:0.00} | time: {sw.Elapsed.TotalSeconds:0.0}s | loss: {val_loss:0.00}\n");
scheduler.step();
if (totalTime.Elapsed.TotalSeconds > timeout) break;
@@ -119,81 +121,94 @@ private static void train(int epoch, Tensor train_data, TransformerModel model,
var total_loss = 0.0f;
- var src_mask = model.GenerateSquareSubsequentMask(bptt);
+ using (var d = torch.NewDisposeScope())
+ {
+ var batch = 0;
+ var log_interval = 200;
- var batch = 0;
- var log_interval = 200;
+ var src_mask = model.GenerateSquareSubsequentMask(bptt);
- var tdlen = train_data.shape[0];
+ var tdlen = train_data.shape[0];
- for (int i = 0; i < tdlen - 1; batch++, i += bptt) {
- var (data, targets) = GetBatch(train_data, i, bptt);
- optimizer.zero_grad();
+ for (int i = 0; i < tdlen - 1; batch++, i += bptt)
+ {
- if (data.shape[0] != bptt) {
- src_mask.Dispose();
- src_mask = model.GenerateSquareSubsequentMask(data.shape[0]);
- }
+ var (data, targets) = GetBatch(train_data, i, bptt);
+ optimizer.zero_grad();
- using (var output = model.forward(data, src_mask)) {
- var loss = criterion(output.view(-1, ntokens), targets);
- loss.backward();
- torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5);
- optimizer.step();
+ if (data.shape[0] != bptt)
+ {
+ src_mask = model.GenerateSquareSubsequentMask(data.shape[0]);
+ }
- total_loss += loss.to(torch.CPU).item();
- }
+ using (var output = model.forward(data, src_mask))
+ {
+ var loss = criterion(output.view(-1, ntokens), targets);
+ loss.backward();
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5);
+ optimizer.step();
+
+ total_loss += loss.to(torch.CPU).item();
+ }
- GC.Collect();
+ if (batch % log_interval == 0 && batch > 0)
+ {
+ var cur_loss = total_loss / log_interval;
+ Console.WriteLine($"epoch: {epoch} | batch: {batch} / {tdlen / bptt} | loss: {cur_loss:0.00}");
+ total_loss = 0;
+ }
- if (batch % log_interval == 0 && batch > 0) {
- var cur_loss = total_loss / log_interval;
- Console.WriteLine($"epoch: {epoch} | batch: {batch} / {tdlen / bptt} | loss: {cur_loss:0.00}");
- total_loss = 0;
+ d.DisposeEverythingBut(src_mask);
}
}
-
- src_mask.Dispose();
}
private static double evaluate(Tensor eval_data, TransformerModel model, Loss criterion, int bptt, int ntokens, torch.optim.Optimizer optimizer)
{
model.Eval();
- var total_loss = 0.0f;
- var src_mask = model.GenerateSquareSubsequentMask(bptt);
- var batch = 0;
+ using (var d = torch.NewDisposeScope())
+ {
- for (int i = 0; i < eval_data.shape[0] - 1; batch++, i += bptt) {
+ var src_mask = model.GenerateSquareSubsequentMask(bptt);
- var (data, targets) = GetBatch(eval_data, i, bptt);
- if (data.shape[0] != bptt) {
- src_mask.Dispose();
- src_mask = model.GenerateSquareSubsequentMask(data.shape[0]);
- }
- using (var output = model.forward(data, src_mask)) {
- var loss = criterion(output.view(-1, ntokens), targets);
- total_loss += data.shape[0] * loss.to(torch.CPU).item();
- }
+ var total_loss = 0.0f;
+ var batch = 0;
- data.Dispose();
- targets.Dispose();
- GC.Collect();
- }
+ for (int i = 0; i < eval_data.shape[0] - 1; batch++, i += bptt)
+ {
+
+ var (data, targets) = GetBatch(eval_data, i, bptt);
+ if (data.shape[0] != bptt)
+ {
+ src_mask = model.GenerateSquareSubsequentMask(data.shape[0]);
+ }
+ using (var output = model.forward(data, src_mask))
+ {
+ var loss = criterion(output.view(-1, ntokens), targets);
+ total_loss += data.shape[0] * loss.to(torch.CPU).item();
+ }
- src_mask.Dispose();
+ data.Dispose();
+ targets.Dispose();
- return total_loss / eval_data.shape[0];
+ d.DisposeEverythingBut(src_mask);
+ }
+
+ return total_loss / eval_data.shape[0];
+ }
}
static Tensor ProcessInput(IEnumerable iter, Func> tokenizer, TorchText.Vocab.Vocab vocab)
{
List data = new List();
- foreach (var item in iter) {
+ foreach (var item in iter)
+ {
List itemData = new List();
- foreach (var token in tokenizer(item)) {
+ foreach (var token in tokenizer(item))
+ {
itemData.Add(vocab[token]);
}
data.Add(torch.tensor(itemData.ToArray(), torch.int64));
diff --git a/src/CSharp/CSharpExamples/TextClassification.cs b/src/CSharp/CSharpExamples/TextClassification.cs
index c2fbbcb..37c78dd 100644
--- a/src/CSharp/CSharpExamples/TextClassification.cs
+++ b/src/CSharp/CSharpExamples/TextClassification.cs
@@ -55,14 +55,16 @@ internal static void Run(int epochs, int timeout)
Console.WriteLine($"\tPreparing training and test data...");
- using (var reader = TorchText.Data.AG_NEWSReader.AG_NEWS("train", (Device)device, _dataLocation)) {
+ using (var reader = TorchText.Data.AG_NEWSReader.AG_NEWS("train", (Device)device, _dataLocation))
+ {
var dataloader = reader.Enumerate();
var tokenizer = TorchText.Data.Utils.get_tokenizer("basic_english");
var counter = new TorchText.Vocab.Counter();
- foreach (var (label, text) in dataloader) {
+ foreach (var (label, text) in dataloader)
+ {
counter.update(tokenizer(text));
}
@@ -82,7 +84,8 @@ internal static void Run(int epochs, int timeout)
var totalTime = new Stopwatch();
totalTime.Start();
- foreach (var epoch in Enumerable.Range(1, epochs)) {
+ foreach (var epoch in Enumerable.Range(1, epochs))
+ {
var sw = new Stopwatch();
sw.Start();
@@ -91,7 +94,7 @@ internal static void Run(int epochs, int timeout)
sw.Stop();
- Console.WriteLine($"\nEnd of epoch: {epoch} | lr: {scheduler.LearningRate:0.0000} | time: {sw.Elapsed.TotalSeconds:0.0}s\n");
+ Console.WriteLine($"\nEnd of epoch: {epoch} | lr: {optimizer.LearningRate:0.0000} | time: {sw.Elapsed.TotalSeconds:0.0}s\n");
scheduler.step();
if (totalTime.Elapsed.TotalSeconds > timeout) break;
@@ -99,7 +102,8 @@ internal static void Run(int epochs, int timeout)
totalTime.Stop();
- using (var test_reader = TorchText.Data.AG_NEWSReader.AG_NEWS("test", (Device)device, _dataLocation)) {
+ using (var test_reader = TorchText.Data.AG_NEWSReader.AG_NEWS("test", (Device)device, _dataLocation))
+ {
var sw = new Stopwatch();
sw.Start();
@@ -127,32 +131,33 @@ static void train(int epoch, IEnumerable<(Tensor, Tensor, Tensor)> train_data, T
var batch_count = train_data.Count();
- foreach (var (labels, texts, offsets) in train_data) {
+ using (var d = torch.NewDisposeScope())
+ {
+ foreach (var (labels, texts, offsets) in train_data)
+ {
- optimizer.zero_grad();
+ optimizer.zero_grad();
- using (var predicted_labels = model.forward(texts, offsets)) {
+ using (var predicted_labels = model.forward(texts, offsets))
+ {
- var loss = criterion(predicted_labels, labels);
- loss.backward();
- torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5);
- optimizer.step();
+ var loss = criterion(predicted_labels, labels);
+ loss.backward();
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5);
+ optimizer.step();
- total_acc += (predicted_labels.argmax(1) == labels).sum().to(torch.CPU).item();
- total_count += labels.size(0);
- }
+ total_acc += (predicted_labels.argmax(1) == labels).sum().to(torch.CPU).item();
+ total_count += labels.size(0);
+ }
- if (batch % log_interval == 0 && batch > 0) {
- var accuracy = total_acc / total_count;
- Console.WriteLine($"epoch: {epoch} | batch: {batch} / {batch_count} | accuracy: {accuracy:0.00}");
+ if (batch % log_interval == 0 && batch > 0)
+ {
+ var accuracy = total_acc / total_count;
+ Console.WriteLine($"epoch: {epoch} | batch: {batch} / {batch_count} | accuracy: {accuracy:0.00}");
+ }
+ batch += 1;
}
- batch += 1;
}
-
- // This data set is small enough that we can get away with
- // collecting memory only once per epoch.
-
- GC.Collect();
}
static double evaluate(IEnumerable<(Tensor, Tensor, Tensor)> test_data, TextClassificationModel model, Loss criterion)
@@ -162,17 +167,22 @@ static double evaluate(IEnumerable<(Tensor, Tensor, Tensor)> test_data, TextClas
double total_acc = 0.0;
long total_count = 0;
- foreach (var (labels, texts, offsets) in test_data) {
+ using (var d = torch.NewDisposeScope())
+ {
+ foreach (var (labels, texts, offsets) in test_data)
+ {
- using (var predicted_labels = model.forward(texts, offsets)) {
- var loss = criterion(predicted_labels, labels);
+ using (var predicted_labels = model.forward(texts, offsets))
+ {
+ var loss = criterion(predicted_labels, labels);
- total_acc += (predicted_labels.argmax(1) == labels).sum().to(torch.CPU).item();
- total_count += labels.size(0);
+ total_acc += (predicted_labels.argmax(1) == labels).sum().to(torch.CPU).item();
+ total_count += labels.size(0);
+ }
}
- }
- return total_acc / total_count;
+ return total_acc / total_count;
+ }
}
}
}
diff --git a/src/CSharp/Models/AlexNet.cs b/src/CSharp/Models/AlexNet.cs
index d3fd5c5..5736a46 100644
--- a/src/CSharp/Models/AlexNet.cs
+++ b/src/CSharp/Models/AlexNet.cs
@@ -61,11 +61,12 @@ public AlexNet(string name, int numClasses, Device device = null) : base(name)
public override Tensor forward(Tensor input)
{
- using (var f = features.forward(input))
- using (var avg = avgPool.forward(f))
+ var f = features.forward(input);
+ var avg = avgPool.forward(f);
- using (var x = avg.view(new long[] { avg.shape[0], 256 * 2 * 2 }))
- return classifier.forward(x);
+ var x = avg.view(new long[] { avg.shape[0], 256 * 2 * 2 });
+
+ return classifier.forward(x);
}
}
diff --git a/src/CSharp/Models/Models.csproj b/src/CSharp/Models/Models.csproj
index 1f8dd15..a288854 100644
--- a/src/CSharp/Models/Models.csproj
+++ b/src/CSharp/Models/Models.csproj
@@ -5,7 +5,7 @@
-
+
diff --git a/src/CSharp/Models/ResNet.cs b/src/CSharp/Models/ResNet.cs
index 6e26906..f7fd309 100644
--- a/src/CSharp/Models/ResNet.cs
+++ b/src/CSharp/Models/ResNet.cs
@@ -145,7 +145,7 @@ public BasicBlock (string name, int in_planes, int planes, int stride) : base(na
public override Tensor forward(Tensor t)
{
var x = layers.forward(t);
- using var y = shortcut.forward(t);
+ var y = shortcut.forward(t);
return x.add_(y).relu_();
}
@@ -186,7 +186,7 @@ public Bottleneck(string name, int in_planes, int planes, int stride) : base(nam
public override Tensor forward(Tensor t)
{
var x = layers.forward(t);
- using var y = shortcut.forward(t);
+ var y = shortcut.forward(t);
return x.add_(y).relu_();
}
diff --git a/src/CSharp/Models/SequenceToSequence.cs b/src/CSharp/Models/SequenceToSequence.cs
index 5ed11bb..b6cc0f6 100644
--- a/src/CSharp/Models/SequenceToSequence.cs
+++ b/src/CSharp/Models/SequenceToSequence.cs
@@ -44,7 +44,7 @@ public TransformerModel(long ntokens, long ninputs, long nheads, long nhidden, l
public Tensor GenerateSquareSubsequentMask(long size)
{
- using var mask = (torch.ones(new long[] { size, size }) == 1).triu().transpose(0, 1);
+ var mask = (torch.ones(new long[] { size, size }) == 1).triu().transpose(0, 1);
return mask.to_type(ScalarType.Float32)
.masked_fill(mask == 0, float.NegativeInfinity)
.masked_fill(mask == 1, 0.0f).to(device);
@@ -54,9 +54,9 @@ private void InitWeights()
{
var initrange = 0.1;
- init.uniform_(encoder.Weight, -initrange, initrange);
- init.zeros_(decoder.Bias);
- init.uniform_(decoder.Weight, -initrange, initrange);
+ init.uniform_(encoder.weight, -initrange, initrange);
+ init.zeros_(decoder.bias);
+ init.uniform_(decoder.weight, -initrange, initrange);
}
public override Tensor forward(Tensor t)
@@ -99,7 +99,7 @@ public PositionalEncoding(long dmodel, double dropout, int maxLen = 5000) : base
public override Tensor forward(Tensor t)
{
- using var x = t + pe[TensorIndex.Slice(null, t.shape[0]), TensorIndex.Slice()];
+ var x = t + pe[TensorIndex.Slice(null, t.shape[0]), TensorIndex.Slice()];
return dropout.forward(x);
}
}
diff --git a/src/CSharp/Models/TextClassification.cs b/src/CSharp/Models/TextClassification.cs
index c5ed8f9..3288e74 100644
--- a/src/CSharp/Models/TextClassification.cs
+++ b/src/CSharp/Models/TextClassification.cs
@@ -35,9 +35,9 @@ private void InitWeights()
{
var initrange = 0.5;
- init.uniform_(embedding.Weight, -initrange, initrange);
- init.uniform_(fc.Weight, -initrange, initrange);
- init.zeros_(fc.Bias);
+ init.uniform_(embedding.weight, -initrange, initrange);
+ init.uniform_(fc.weight, -initrange, initrange);
+ init.zeros_(fc.bias);
}
public override Tensor forward(Tensor t)
@@ -47,7 +47,7 @@ public override Tensor forward(Tensor t)
public override Tensor forward(Tensor input, Tensor offsets)
{
- using var t = embedding.forward(input, offsets);
+ var t = embedding.forward(input, offsets);
return fc.forward(t);
}
diff --git a/src/FSharp/FSharpExamples/AdversarialExampleGeneration.fs b/src/FSharp/FSharpExamples/AdversarialExampleGeneration.fs
index 7847dd7..4ac251e 100644
--- a/src/FSharp/FSharpExamples/AdversarialExampleGeneration.fs
+++ b/src/FSharp/FSharpExamples/AdversarialExampleGeneration.fs
@@ -66,6 +66,8 @@ let test (model:MNIST.Model) (eps:float) (dataLoader:MNISTReader) size =
for (input,labels) in dataLoader do
+ use d = torch.NewDisposeScope()
+
input.requires_grad <- true
begin // This is introduced in order to let a few tensors go out of scope before GC
@@ -80,8 +82,6 @@ let test (model:MNIST.Model) (eps:float) (dataLoader:MNISTReader) size =
correct <- correct + final.argmax(1L).eq(labels).sum().ToInt32()
end
- GC.Collect()
-
float correct / size
let run epochs =
diff --git a/src/FSharp/FSharpExamples/AlexNet.fs b/src/FSharp/FSharpExamples/AlexNet.fs
index f60b6f0..4a40fc2 100644
--- a/src/FSharp/FSharpExamples/AlexNet.fs
+++ b/src/FSharp/FSharpExamples/AlexNet.fs
@@ -99,6 +99,9 @@ let train (model:Model) (optimizer:Optimizer) (dataLoader: CIFARReader) epoch =
printfn $"Epoch: {epoch}..."
for (input,labels) in dataLoader.Data() do
+
+ use d = torch.NewDisposeScope()
+
optimizer.zero_grad()
begin
@@ -122,8 +125,6 @@ let train (model:Model) (optimizer:Optimizer) (dataLoader: CIFARReader) epoch =
batchID <- batchID + 1
end
- GC.Collect()
-
let test (model:Model) (dataLoader:CIFARReader) =
model.Eval()
@@ -135,6 +136,8 @@ let test (model:Model) (dataLoader:CIFARReader) =
for (input,labels) in dataLoader.Data() do
+ use d = torch.NewDisposeScope()
+
use estimate = input --> model
use output = loss estimate labels
testLoss <- testLoss + output.ToSingle()
@@ -160,7 +163,6 @@ let trainingLoop (model:Model) epochs trainData testData =
for epoch = 1 to epochs do
train model optimizer trainData epoch
test model testData
- GC.Collect()
sw.Stop()
diff --git a/src/FSharp/FSharpExamples/FSharpExamples.fsproj b/src/FSharp/FSharpExamples/FSharpExamples.fsproj
index c5333c3..a098ff5 100644
--- a/src/FSharp/FSharpExamples/FSharpExamples.fsproj
+++ b/src/FSharp/FSharpExamples/FSharpExamples.fsproj
@@ -19,7 +19,7 @@
-
+
diff --git a/src/FSharp/FSharpExamples/MNIST.fs b/src/FSharp/FSharpExamples/MNIST.fs
index 7035413..1d78341 100644
--- a/src/FSharp/FSharpExamples/MNIST.fs
+++ b/src/FSharp/FSharpExamples/MNIST.fs
@@ -97,6 +97,9 @@ let train (model:Model) (optimizer:Optimizer) (dataLoader: MNISTReader) epoch =
printfn $"Epoch: {epoch}..."
for (input,labels) in dataLoader do
+
+ use d = torch.NewDisposeScope()
+
optimizer.zero_grad()
begin // This is introduced in order to let a few tensors go out of scope before GC
@@ -112,8 +115,6 @@ let train (model:Model) (optimizer:Optimizer) (dataLoader: MNISTReader) epoch =
batchID <- batchID + 1
end
- GC.Collect()
-
let test (model:Model) (dataLoader:MNISTReader) =
model.Eval()
@@ -124,6 +125,8 @@ let test (model:Model) (dataLoader:MNISTReader) =
for (input,labels) in dataLoader do
+ use d = torch.NewDisposeScope()
+
begin // This is introduced in order to let a few tensors go out of scope before GC
use estimate = input --> model
use output = loss estimate labels
@@ -133,8 +136,6 @@ let test (model:Model) (dataLoader:MNISTReader) =
correct <- correct + pred.eq(labels).sum().ToInt32()
end
- GC.Collect()
-
printfn $"Size: {sz}, Total: {sz}"
printfn $"\rTest set: Average loss {(testLoss / sz):F4} | Accuracy {(float32 correct / sz):P2}"
diff --git a/src/FSharp/FSharpExamples/Program.fs b/src/FSharp/FSharpExamples/Program.fs
index 1ccdc6d..b2a93c0 100644
--- a/src/FSharp/FSharpExamples/Program.fs
+++ b/src/FSharp/FSharpExamples/Program.fs
@@ -29,11 +29,11 @@ let main args =
| true,t -> t
| false,_ -> 3600
- for idx = 0 to argumentParser.Count do
+ for idx = 0 to argumentParser.Count-1 do
let modelName = argumentParser.[idx]
- match modelName with
+ match modelName.ToLowerInvariant() with
| "mnist" -> FSharpExamples.MNIST.run epochs
| "fgsm" -> FSharpExamples.AdversarialExampleGeneration.run epochs
| "alexnet" -> FSharpExamples.AlexNet.run epochs
diff --git a/src/FSharp/FSharpExamples/Properties/launchSettings.json b/src/FSharp/FSharpExamples/Properties/launchSettings.json
new file mode 100644
index 0000000..61b3584
--- /dev/null
+++ b/src/FSharp/FSharpExamples/Properties/launchSettings.json
@@ -0,0 +1,8 @@
+{
+ "profiles": {
+ "FSharpExamples": {
+ "commandName": "Project",
+ "commandLineArgs": "-e 5 MNIST"
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/FSharp/FSharpExamples/SequenceToSequence.fs b/src/FSharp/FSharpExamples/SequenceToSequence.fs
index 9453995..76dc4f8 100644
--- a/src/FSharp/FSharpExamples/SequenceToSequence.fs
+++ b/src/FSharp/FSharpExamples/SequenceToSequence.fs
@@ -90,9 +90,9 @@ type TransformerModel(ntokens, device:torch.Device) as this =
do
let initrange = 0.1
- init.uniform_(encoder.Weight, -initrange, initrange) |> ignore
- init.zeros_(decoder.Bias) |> ignore
- init.uniform_(decoder.Weight, -initrange, initrange) |> ignore
+ init.uniform_(encoder.weight, -initrange, initrange) |> ignore
+ init.zeros_(decoder.bias) |> ignore
+ init.uniform_(decoder.weight, -initrange, initrange) |> ignore
this.RegisterComponents()
@@ -151,6 +151,8 @@ let train epoch (model:TransformerModel) (optimizer:Optimizer) (trainData:torch.
while i < tdlen - 2L do
+ use d = torch.NewDisposeScope()
+
begin
let data,targets = get_batch trainData i
use data = data
@@ -169,9 +171,7 @@ let train epoch (model:TransformerModel) (optimizer:Optimizer) (trainData:torch.
optimizer.step() |> ignore
total_loss <- total_loss + loss.cpu().item()
- end
-
- GC.Collect()
+ end
if (batch % logInterval = 0) && (batch > 0) then
let cur_loss = (total_loss / (float32 logInterval)).ToString("0.00")
@@ -197,6 +197,8 @@ let evaluate (model:TransformerModel) (evalData:torch.Tensor) ntokens =
while i < tdlen - 2L do
+ use d = torch.NewDisposeScope()
+
begin
let data,targets = get_batch evalData i
use data = data
@@ -211,8 +213,6 @@ let evaluate (model:TransformerModel) (evalData:torch.Tensor) ntokens =
total_loss <- total_loss + (float32 data.shape.[0]) * loss.cpu().item()
end
- GC.Collect()
-
batch <- batch + 1L
i <- i + bptt
@@ -259,7 +259,7 @@ let run epochs =
let val_loss = evaluate model valid_data ntokens
sw.Stop()
- let lrStr = scheduler.LearningRate.ToString("0.00")
+ let lrStr = optimizer.LearningRate.ToString("0.00")
let elapsed = sw.Elapsed.TotalSeconds.ToString("0.0")
let lossStr = val_loss.ToString("0.00")
diff --git a/src/FSharp/FSharpExamples/TextClassification.fs b/src/FSharp/FSharpExamples/TextClassification.fs
index 77e650c..fb3ffa8 100644
--- a/src/FSharp/FSharpExamples/TextClassification.fs
+++ b/src/FSharp/FSharpExamples/TextClassification.fs
@@ -55,9 +55,9 @@ type TextClassificationModel(vocabSize, embedDim, nClasses, device:torch.Device)
do
let initrange = 0.5
- init.uniform_(embedding.Weight, -initrange, initrange) |> ignore
- init.uniform_(fc.Weight, -initrange, initrange) |> ignore
- init.zeros_(fc.Bias) |> ignore
+ init.uniform_(embedding.weight, -initrange, initrange) |> ignore
+ init.uniform_(fc.weight, -initrange, initrange) |> ignore
+ init.zeros_(fc.bias) |> ignore
this.RegisterComponents()
@@ -80,6 +80,8 @@ let train epoch (trainData:IEnumerable)
let batch_count = trainData.Count()
for labels,texts,offsets in trainData do
+
+ use d = torch.NewDisposeScope()
optimizer.zero_grad()
@@ -99,8 +101,6 @@ let train epoch (trainData:IEnumerable)
batch <- batch + 1
- GC.Collect()
-
let evaluate (testData:IEnumerable) (model:TextClassificationModel) =
model.Eval()
@@ -149,7 +149,7 @@ let run epochs =
sw.Stop()
- let lrStr = scheduler.LearningRate.ToString("0.0000")
+ let lrStr = optimizer.LearningRate.ToString("0.0000")
let tsStr = sw.Elapsed.TotalSeconds.ToString("0.0")
printfn $"\nEnd of epoch: {epoch} | lr: {lrStr} | time: {tsStr}s\n"
scheduler.step() |> ignore
diff --git a/src/Utils/Examples.Utils.csproj b/src/Utils/Examples.Utils.csproj
index 357077f..aba3914 100644
--- a/src/Utils/Examples.Utils.csproj
+++ b/src/Utils/Examples.Utils.csproj
@@ -11,7 +11,7 @@
-
+
diff --git a/tutorials/CSharp/tutorial2.ipynb b/tutorials/CSharp/tutorial2.ipynb
index 2bb1732..aeb233e 100644
--- a/tutorials/CSharp/tutorial2.ipynb
+++ b/tutorials/CSharp/tutorial2.ipynb
@@ -623,7 +623,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "If you need to have some idea of how many tensors you have allocated in you process, there's a static property `TotalCount` that tells you just how many .NET tensors are active. That's not all tensors, just those that have a .NET representation. Temporaries used by the native library and not surfaces don't count.\n",
+ "If you need to have some idea of how many tensors you have allocated in you process, there's a static property `TotalCount` that tells you just how many .NET tensors are active. That's not all tensors, just those that have a .NET representation. Temporaries used by the native library and not surfaced to managed code don't count.\n",
"\n",
"The property is useful if you are diagnosing memory issues, for example in a training loop. If the number of tensors keeps growing, somewhere there's a missing Dispose() call."
]
diff --git a/tutorials/CSharp/tutorial5.ipynb b/tutorials/CSharp/tutorial5.ipynb
index 57764a8..a51e741 100644
--- a/tutorials/CSharp/tutorial5.ipynb
+++ b/tutorials/CSharp/tutorial5.ipynb
@@ -191,7 +191,110 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "It's worth noting that use of in-place operators cuts down significantly the number of temporaries that have to be disposed."
+ "### DisposeScope\n",
+ "\n",
+ "In version 0.95.4, TorchSharp introduced the notion of a `DisposeScope`, which deals with the dispose pattern systematically. It introduces the notion of a dynamic (runtime) scope, which controls the liftime of all tensors created while the scope is in effect. The lexical scope, i.e. the source code location where variables are held, etc., had no impact on the dynamic scope management.\n",
+ "\n",
+ "When any .NET tensor is created, it will be registered with the current dynamic scope, if there is one. Once registered, the tensor will be disposed automatically when the scope is disposed. It doesn't matter if the tensor is held in a variable declared outside or inside the scope.\n",
+ "\n",
+ "Try running the next cell with and without the 'using' line (comment it out), and notice how the tensor count grows if you don't have it, but stays constant when you do."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "dotnet_interactive": {
+ "language": "csharp"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "Console.WriteLine(torch.Tensor.TotalCount);\n",
+ "using (var d = torch.NewDisposeScope()) \n",
+ "{\n",
+ " var t3 = (a + b) * (a + c.cuda());\n",
+ " t3.print();\n",
+ "}\n",
+ "Console.WriteLine(torch.Tensor.TotalCount);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Sometimes, you really need a tensor instance to survive the end of the dynamic scope. It can be detached from the scope, or moved to a surrounding scope (they nest). If there is no surrounding scope, it's the same thing.\n",
+ "\n",
+ "For example, let's try this example -- you should get an exception that complains about an invalid tensor. It's because it was disposed before calling `print()`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "dotnet_interactive": {
+ "language": "csharp"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "torch.Tensor t3;\n",
+ "using (var d = torch.NewDisposeScope()) \n",
+ "{\n",
+ " t3 = (a + b) * (a + c.cuda());\n",
+ "}\n",
+ "t3.print();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The solution is to detach it before the end of the scope. Once it's been moved outside, the tensor needs to either land in an outer scope that will automatically dispose it, or it needs to be disposed explicitly in order to free native memory (unless waiting until GC kicks in is acceptable)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "dotnet_interactive": {
+ "language": "csharp"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "torch.Tensor t3;\n",
+ "using (var d = torch.NewDisposeScope()) \n",
+ "{\n",
+ " t3 = d.MoveToOuter((a + b) * (a + c.cuda()));\n",
+ "}\n",
+ "t3.print();\n",
+ "Console.WriteLine(t3.IsInvalid);\n",
+ "t3.Dispose();\n",
+ "Console.WriteLine(t3.IsInvalid);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "dotnet_interactive": {
+ "language": "csharp"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "torch.Tensor t3;\n",
+ "using (var d0 = torch.NewDisposeScope()) \n",
+ "{\n",
+ " using (var d1 = torch.NewDisposeScope()) \n",
+ " {\n",
+ " t3 = d1.MoveToOuter((a + b) * (a + c.cuda()));\n",
+ " }\n",
+ " t3.print();\n",
+ " Console.WriteLine(t3.IsInvalid);\n",
+ "}\n",
+ "Console.WriteLine(t3.IsInvalid);"
]
},
{
diff --git a/tutorials/CSharp/tutorial7.ipynb b/tutorials/CSharp/tutorial7.ipynb
index 45e6cb2..c36efec 100644
--- a/tutorials/CSharp/tutorial7.ipynb
+++ b/tutorials/CSharp/tutorial7.ipynb
@@ -43,9 +43,9 @@
"source": [
"To further complicate matters, it turns out that the learning rate shouldn't necessarily be constant. Training can go much better if the learning rate starts out relatively large and gets smaller as you get closer to the end.\n",
"\n",
- "There's a solution for this, called a Learning Rate Scheduler. An LRS instance has access to the internal state of the optimizer, and can modify the LR as it goes along. \n",
+ "There's a solution for this, called a Learning Rate Scheduler. An LRS instance has access to the internal state of the optimizer, and can modify the LR as it goes along. Some schedulers modify other optimizer state, too, such as the momentum (for optimizers that use momentum).\n",
"\n",
- "There are several algorithms for scheduling, but TorchSharp only implements the two most conceptually simple: StepLR and ExponentialLR. In this tutorial, we will only cover StepLR."
+ "There are several algorithms for scheduling, and TorchSharp implements a number of them."
]
},
{
diff --git a/tutorials/FSharp/tutorial5.ipynb b/tutorials/FSharp/tutorial5.ipynb
index df27637..cfd1962 100644
--- a/tutorials/FSharp/tutorial5.ipynb
+++ b/tutorials/FSharp/tutorial5.ipynb
@@ -204,6 +204,120 @@
"It's worth noting that use of in-place operators cuts down significantly the number of temporaries that have to be disposed."
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### DisposeScope\n",
+ "\n",
+ "In version 0.95.4, TorchSharp introduced the notion of a `DisposeScope`, which deals with the dispose pattern systematically. It introduces the notion of a dynamic (runtime) scope, which controls the liftime of all tensors created while the scope is in effect. The lexical scope, i.e. the source code location where variables are held, etc., had no impact on the dynamic scope management.\n",
+ "\n",
+ "When any .NET tensor is created, it will be registered with the current dynamic scope, if there is one. Once registered, the tensor will be disposed automatically when the scope is disposed. It doesn't matter if the tensor is held in a variable declared outside or inside the scope.\n",
+ "\n",
+ "Try running the next cell with and without the 'use d = ...' line (comment it out), and notice how the tensor count grows if you don't have it, but stays constant when you do."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "dotnet_interactive": {
+ "language": "fsharp"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "printf \"%d\\n\" torch.Tensor.TotalCount\n",
+ "do \n",
+ " use d = torch.NewDisposeScope()\n",
+ " let t3 = (a + b) * (a + c.cuda())\n",
+ " t3.print() |> ignore\n",
+ "\n",
+ "printf \"%d\\n\" torch.Tensor.TotalCount"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Sometimes, you really need a tensor instance to survive the end of the dynamic scope. It can be detached from the scope, or moved to a surrounding scope (they nest). If there is no surrounding scope, it's the same thing.\n",
+ "\n",
+ "For example, let's try this example -- you should get an exception that complains about an invalid tensor. It's because it was disposed before calling `print()`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "dotnet_interactive": {
+ "language": "fsharp"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "let mutable t3:torch.Tensor = null\n",
+ "\n",
+ "do \n",
+ " use d = torch.NewDisposeScope()\n",
+ " t3 <- (a + b) * (a + c.cuda())\n",
+ "\n",
+ "t3.print()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The solution is to detach it before the end of the scope. Once it's been moved outside, the tensor needs to either land in an outer scope that will automatically dispose it, or it needs to be disposed explicitly in order to free native memory (unless waiting until GC kicks in is acceptable)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "dotnet_interactive": {
+ "language": "fsharp"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "let mutable t3:torch.Tensor = null\n",
+ "\n",
+ "do \n",
+ " use d = torch.NewDisposeScope()\n",
+ " t3 <- d.MoveToOuter((a + b) * (a + c.cuda()))\n",
+ "\n",
+ "t3.print()\n",
+ "printf \"%b\\n\" t3.IsInvalid\n",
+ "t3.Dispose()\n",
+ "printf \"%b\\n\" t3.IsInvalid"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "dotnet_interactive": {
+ "language": "fsharp"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "let mutable t3:torch.Tensor = null\n",
+ "\n",
+ "do\n",
+ " use d0 = torch.NewDisposeScope()\n",
+ "\n",
+ " do \n",
+ " use d1 = torch.NewDisposeScope()\n",
+ " t3 <- d1.MoveToOuter((a + b) * (a + c.cuda()))\n",
+ "\n",
+ " t3.print()\n",
+ " printf \"%b\\n\" t3.IsInvalid\n",
+ "\n",
+ "printf \"%b\\n\" t3.IsInvalid"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
diff --git a/tutorials/FSharp/tutorial6.ipynb b/tutorials/FSharp/tutorial6.ipynb
index 2bc3aa8..ac23e59 100644
--- a/tutorials/FSharp/tutorial6.ipynb
+++ b/tutorials/FSharp/tutorial6.ipynb
@@ -690,9 +690,7 @@
"\n",
"The prior sections have described the most general way of constructing a model, that is, by creating a class that abstracts the logic of the model and explicitly calls each layer's `forward` methos. While it's not too complicated to do so, it's a lot of \"ceremony\" to accomplish something very regular.\n",
"\n",
- "Fortunately, for models, or components of models, that simply pass one tensor from layer to layer, there's a class to handle it. It's called `Sequential` and is created by passing a sequence of tuples. The first element of the tuple is the name of the layer (required), and the second is the component.\n",
- "\n",
- "The following model is equivalent to the Trivial model we've seen before. In fact, weights from Trivial can be loaded into it."
+ "Fortunately, for models, or components of models, that simply pass one tensor from layer to layer, there's a class to handle it. It's called `Sequential` and is created by passing a sequence of modules."
]
},
{
@@ -705,7 +703,7 @@
},
"outputs": [],
"source": [
- "let seq = nn.Sequential(struct (\"lin1\", upcast nn.Linear(1000L, 100L)), (\"relu1\", upcast nn.ReLU()), (\"lin2\", upcast nn.Linear(100L, 10L)))\n",
+ "let seq = nn.Sequential((\"lin1\", nn.Linear(1000L, 100L) :> nn.Module), (\"relu\", nn.ReLU() :> nn.Module), (\"lin2\", nn.Linear(100L, 10L) :> nn.Module))\n",
"seq.load(\"tutorial6.model.bin\")\n",
"predMax = seq.forward(dataBatch).argmax(1L)\n",
"refMax.eq(predMax).sum() / predMax.numel().ToScalar()"
diff --git a/tutorials/FSharp/tutorial7.ipynb b/tutorials/FSharp/tutorial7.ipynb
index 4aa2509..19d011c 100644
--- a/tutorials/FSharp/tutorial7.ipynb
+++ b/tutorials/FSharp/tutorial7.ipynb
@@ -39,7 +39,7 @@
"source": [
"To further complicate matters, it turns out that the learning rate shouldn't necessarily be constant. Training can go much better if the learning rate starts out relatively large and gets smaller as you get closer to the end.\n",
"\n",
- "There's a solution for this, called a Learning Rate Scheduler. An LRS instance has access to the internal state of the optimizer, and can modify the LR as it goes along. There are several algorithms for scheduling, of which TorchSharp currently implements a subset."
+ "There's a solution for this, called a Learning Rate Scheduler. An LRS instance has access to the internal state of the optimizer, and can modify the LR as it goes along. There are several algorithms for scheduling, of which TorchSharp currently implements a significant subset."
]
},
{
@@ -156,7 +156,7 @@
"let loss x y = nn.functional.mse_loss().Invoke(x,y)\n",
"\n",
"let optimizer = torch.optim.SGD(model.parameters(), learning_rate)\n",
- "let scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 25, 0.95)\n",
+ "let scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 25, 0.95, verbose=true)\n",
"\n",
"for epoch = 1 to 300 do\n",
"\n",
@@ -189,9 +189,7 @@
"source": [
"Well, that was underwhelming. The loss (in my case) went up a bit, so that's nothing to get excited about. For this trivial model, using a scheduler isn't going to make a huge difference, and it may not make much of a difference even for complex models. It's very hard to know until you try it, but now you know how to try it out. If you try this trivial example over and over, you will see that the results vary quite a bit. It's simply too simple.\n",
"\n",
- "Regardless, you can see from the verbose output that the learning rate is adjusted as the epochs proceed. \n",
- "\n",
- "Note: If you're using 0.93.9 and you see odd dips in the learning rate, that's a bug in the verbose printout logic, not the learning rate scheduler itself."
+ "Regardless, you can see from the verbose output that the learning rate is adjusted as the epochs proceed."
]
},
{