From 653f3e67660f0a56d21bbe0b1242bba1653682bb Mon Sep 17 00:00:00 2001 From: vlado Date: Fri, 26 Oct 2018 10:24:58 -0600 Subject: [PATCH 1/2] Sample python bilinear initializer at integral points in y-direction --- python/mxnet/initializer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 357e75b3bdf5..9a06782daaf4 100755 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -216,7 +216,7 @@ def _init_bilinear(self, _, arr): c = (2 * f - 1 - f % 2) / (2. * f) for i in range(np.prod(shape)): x = i % shape[3] - y = (i / shape[3]) % shape[2] + y = (i // shape[3]) % shape[2] weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) arr[:] = weight.reshape(shape) @@ -656,7 +656,7 @@ def _init_weight(self, _, arr): c = (2 * f - 1 - f % 2) / (2. * f) for i in range(np.prod(shape)): x = i % shape[3] - y = (i / shape[3]) % shape[2] + y = (i // shape[3]) % shape[2] weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) arr[:] = weight.reshape(shape) From b43347aa1e7db0ef8141fbb77bd3c6f1d541d391 Mon Sep 17 00:00:00 2001 From: vlado Date: Wed, 21 Nov 2018 14:02:27 -0700 Subject: [PATCH 2/2] Add unit test for bilinear initializer --- tests/python/unittest/test_init.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/python/unittest/test_init.py b/tests/python/unittest/test_init.py index efd6ef36744f..c8bf01f48ca3 100644 --- a/tests/python/unittest/test_init.py +++ b/tests/python/unittest/test_init.py @@ -60,8 +60,17 @@ def check_rsp_const_init(init, val): check_rsp_const_init(mx.initializer.Zero(), 0.) check_rsp_const_init(mx.initializer.One(), 1.) +def test_bilinear_init(): + bili = mx.init.Bilinear() + bili_weight = mx.ndarray.empty((1,1,4,4)) + bili._init_weight(None, bili_weight) + bili_1d = np.array([[1/float(4), 3/float(4), 3/float(4), 1/float(4)]]) + bili_2d = bili_1d * np.transpose(bili_1d) + assert (bili_2d == bili_weight.asnumpy()).all() + if __name__ == '__main__': test_variable_init() test_default_init() test_aux_init() test_rsp_const_init() + test_bilinear_init()