diff --git a/tests/python/train/test_dtype.py b/tests/python/train/test_dtype.py index b0a524815c6c..96912c09dbe0 100644 --- a/tests/python/train/test_dtype.py +++ b/tests/python/train/test_dtype.py @@ -99,7 +99,7 @@ def run_cifar10(train, val, use_module): devs = [mx.cpu(0)] net = get_net() mod = mx.mod.Module(net, context=devs) - optim_args = {'learning_rate': 0.05, 'wd': 0.00001, 'momentum': 0.9} + optim_args = {'learning_rate': 0.001, 'wd': 0.00001, 'momentum': 0.9} eval_metrics = ['accuracy'] if use_module: executor = mx.mod.Module(net, context=devs)