diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py index 0e0cd521f28d..9412b6f9371b 100755 --- a/example/image-classification/common/fit.py +++ b/example/image-classification/common/fit.py @@ -237,6 +237,9 @@ def fit(args, network, data_loader, **kwargs): if args.network == 'alexnet': # AlexNet will not converge using Xavier initializer = mx.init.Normal() + # VGG does not tend to converge using Xavier-Gaussian + elif 'vgg' in args.network: + initializer = mx.init.Xavier() else: initializer = mx.init.Xavier( rnd_type='gaussian', factor_type="in", magnitude=2)