@@ -49,7 +49,8 @@ internal static void Run(int epochs, int timeout, string modelName)
4949 torch . cuda . is_available ( ) ? torch . CUDA :
5050 torch . CPU ;
5151
52- if ( device . type == DeviceType . CUDA ) {
52+ if ( device . type == DeviceType . CUDA )
53+ {
5354 _trainBatchSize *= 8 ;
5455 _testBatchSize *= 8 ;
5556 }
@@ -61,7 +62,8 @@ internal static void Run(int epochs, int timeout, string modelName)
6162 var sourceDir = _dataLocation ;
6263 var targetDir = Path . Combine ( _dataLocation , "test_data" ) ;
6364
64- if ( ! Directory . Exists ( targetDir ) ) {
65+ if ( ! Directory . Exists ( targetDir ) )
66+ {
6567 Directory . CreateDirectory ( targetDir ) ;
6668 Decompress . ExtractTGZ ( Path . Combine ( sourceDir , "cifar-10-binary.tar.gz" ) , targetDir ) ;
6769 }
@@ -70,40 +72,41 @@ internal static void Run(int epochs, int timeout, string modelName)
7072
7173 Module model = null ;
7274
73- switch ( modelName . ToLower ( ) ) {
74- case "alexnet" :
75- model = new AlexNet ( modelName , _numClasses , device ) ;
76- break ;
77- case "mobilenet" :
78- model = new MobileNet ( modelName , _numClasses , device ) ;
79- break ;
80- case "vgg11" :
81- case "vgg13" :
82- case "vgg16" :
83- case "vgg19" :
84- model = new VGG ( modelName , _numClasses , device ) ;
85- break ;
86- case "resnet18" :
87- model = ResNet . ResNet18 ( _numClasses , device ) ;
88- break ;
89- case "resnet34" :
90- _testBatchSize /= 4 ;
91- model = ResNet . ResNet34 ( _numClasses , device ) ;
92- break ;
93- case "resnet50" :
94- _trainBatchSize /= 6 ;
95- _testBatchSize /= 8 ;
96- model = ResNet . ResNet50 ( _numClasses , device ) ;
97- break ;
98- case "resnet101" :
99- _trainBatchSize /= 6 ;
100- _testBatchSize /= 8 ;
101- model = ResNet . ResNet101 ( _numClasses , device ) ;
102- break ;
103- case "resnet152" :
104- _testBatchSize /= 4 ;
105- model = ResNet . ResNet152 ( _numClasses , device ) ;
106- break ;
75+ switch ( modelName . ToLower ( ) )
76+ {
77+ case "alexnet" :
78+ model = new AlexNet ( modelName , _numClasses , device ) ;
79+ break ;
80+ case "mobilenet" :
81+ model = new MobileNet ( modelName , _numClasses , device ) ;
82+ break ;
83+ case "vgg11" :
84+ case "vgg13" :
85+ case "vgg16" :
86+ case "vgg19" :
87+ model = new VGG ( modelName , _numClasses , device ) ;
88+ break ;
89+ case "resnet18" :
90+ model = ResNet . ResNet18 ( _numClasses , device ) ;
91+ break ;
92+ case "resnet34" :
93+ _testBatchSize /= 4 ;
94+ model = ResNet . ResNet34 ( _numClasses , device ) ;
95+ break ;
96+ case "resnet50" :
97+ _trainBatchSize /= 6 ;
98+ _testBatchSize /= 8 ;
99+ model = ResNet . ResNet50 ( _numClasses , device ) ;
100+ break ;
101+ case "resnet101" :
102+ _trainBatchSize /= 6 ;
103+ _testBatchSize /= 8 ;
104+ model = ResNet . ResNet101 ( _numClasses , device ) ;
105+ break ;
106+ case "resnet152" :
107+ _testBatchSize /= 4 ;
108+ model = ResNet . ResNet152 ( _numClasses , device ) ;
109+ break ;
107110 }
108111
109112 var hflip = transforms . HorizontalFlip ( ) ;
@@ -116,19 +119,20 @@ internal static void Run(int epochs, int timeout, string modelName)
116119
117120 using ( var train = new CIFARReader ( targetDir , false , _trainBatchSize , shuffle : true , device : device , transforms : new ITransform [ ] { } ) )
118121 using ( var test = new CIFARReader ( targetDir , true , _testBatchSize , device : device ) )
119- using ( var optimizer = torch . optim . Adam ( model . parameters ( ) , 0.001 ) ) {
122+ using ( var optimizer = torch . optim . Adam ( model . parameters ( ) , 0.001 ) )
123+ {
120124
121125 Stopwatch totalSW = new Stopwatch ( ) ;
122126 totalSW . Start ( ) ;
123127
124- for ( var epoch = 1 ; epoch <= epochs ; epoch ++ ) {
128+ for ( var epoch = 1 ; epoch <= epochs ; epoch ++ )
129+ {
125130
126131 Stopwatch epchSW = new Stopwatch ( ) ;
127132 epchSW . Start ( ) ;
128133
129134 Train ( model , optimizer , nll_loss ( ) , train . Data ( ) , epoch , _trainBatchSize , train . Size ) ;
130135 Test ( model , nll_loss ( ) , test . Data ( ) , test . Size ) ;
131- GC . Collect ( ) ;
132136
133137 epchSW . Stop ( ) ;
134138 Console . WriteLine ( $ "Elapsed time for this epoch: { epchSW . Elapsed . TotalSeconds } s.") ;
@@ -160,35 +164,33 @@ private static void Train(
160164
161165 Console . WriteLine ( $ "Epoch: { epoch } ...") ;
162166
163- foreach ( var ( data , target ) in dataLoader ) {
167+ foreach ( var ( data , target ) in dataLoader )
168+ {
164169
165- optimizer . zero_grad ( ) ;
170+ using ( var d = torch . NewDisposeScope ( ) )
171+ {
172+ optimizer . zero_grad ( ) ;
166173
167- using var prediction = model . forward ( data ) ;
168- using var lsm = log_softmax ( prediction , 1 ) ;
169- using ( var output = loss ( lsm , target ) ) {
174+ var prediction = model . forward ( data ) ;
175+ var lsm = log_softmax ( prediction , 1 ) ;
176+ var output = loss ( lsm , target ) ;
170177
171178 output . backward ( ) ;
172179
173180 optimizer . step ( ) ;
174181
175182 total += target . shape [ 0 ] ;
176183
177- using ( var predicted = prediction . argmax ( 1 ) )
178- using ( var eq = predicted . eq ( target ) )
179- using ( var sum = eq . sum ( ) ) {
180- correct += sum . ToInt64 ( ) ;
181- }
184+ correct += prediction . argmax ( 1 ) . eq ( target ) . ToInt64 ( ) ;
182185
183- if ( batchId % _logInterval == 0 ) {
186+ if ( batchId % _logInterval == 0 )
187+ {
184188 var count = Math . Min ( batchId * batchSize , size ) ;
185189 Console . WriteLine ( $ "\r Train: epoch { epoch } [{ count } / { size } ] Loss: { output . ToSingle ( ) . ToString ( "0.000000" ) } | Accuracy: { ( ( float ) correct / total ) . ToString ( "0.000000" ) } ") ;
186190 }
187191
188192 batchId ++ ;
189193 }
190-
191- GC . Collect ( ) ;
192194 }
193195 }
194196
@@ -204,23 +206,20 @@ private static void Test(
204206 long correct = 0 ;
205207 int batchCount = 0 ;
206208
207- foreach ( var ( data , target ) in dataLoader ) {
209+ foreach ( var ( data , target ) in dataLoader )
210+ {
208211
209- using var prediction = model . forward ( data ) ;
210- using var lsm = log_softmax ( prediction , 1 ) ;
211- using ( var output = loss ( lsm , target ) ) {
212+ using ( var d = torch . NewDisposeScope ( ) )
213+ {
214+ var prediction = model . forward ( data ) ;
215+ var lsm = log_softmax ( prediction , 1 ) ;
216+ var output = loss ( lsm , target ) ;
212217
213218 testLoss += output . ToSingle ( ) ;
214219 batchCount += 1 ;
215220
216- using ( var predicted = prediction . argmax ( 1 ) )
217- using ( var eq = predicted . eq ( target ) )
218- using ( var sum = eq . sum ( ) ) {
219- correct += sum . ToInt64 ( ) ;
220- }
221+ correct += prediction . argmax ( 1 ) . eq ( target ) . ToInt64 ( ) ;
221222 }
222-
223- GC . Collect ( ) ;
224223 }
225224
226225 Console . WriteLine ( $ "\r Test set: Average loss { ( testLoss / batchCount ) . ToString ( "0.0000" ) } | Accuracy { ( ( float ) correct / size ) . ToString ( "0.0000" ) } ") ;
0 commit comments