From d73a01bfc9c9693acf1f2e8a2467f182048a9f21 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 24 Feb 2020 11:28:14 -0800 Subject: [PATCH] Fix --- src/nnvm/plan_memory.cc | 25 ++++++++++++----------- tests/python/unittest/test_numpy_gluon.py | 21 +++++++++++++++++++ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc index 6c6e02d88757..3815f239f88c 100644 --- a/src/nnvm/plan_memory.cc +++ b/src/nnvm/plan_memory.cc @@ -38,21 +38,22 @@ namespace { // Return bytes of data flag. static int MXGetDTypeSize(int type_flag) { switch (type_flag) { - case kUint8: - case kInt8: + case mshadow::kUint8: + case mshadow::kInt8: + case mshadow::kBool: return 1; - case kFloat16: - case kBfloat16: - case kInt16: - case kUint16: + case mshadow::kFloat16: + case mshadow::kBfloat16: + case mshadow::kInt16: + case mshadow::kUint16: return 2; - case kFloat32: - case kInt32: - case kUint32: + case mshadow::kFloat32: + case mshadow::kInt32: + case mshadow::kUint32: return 4; - case kFloat64: - case kInt64: - case kUint64: + case mshadow::kFloat64: + case mshadow::kInt64: + case mshadow::kUint64: return 8; default: LOG(FATAL) << "unknown type_flag=" << type_flag; diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index 6ce9e18e2624..0d1e5fed59b3 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -400,6 +400,27 @@ def hybrid_forward(self, F, x, y): mx.np.random.normal(0, 1, (10, 5, 8))]) +@with_seed() +@use_np +def test_hybridize_boolean_dtype(): + class Foo(gluon.HybridBlock): + def __init__(self, prefix=None, params=None): + super(Foo, self).__init__(prefix=prefix, params=params) + + def hybrid_forward(self, F, valid_length): + mask = ((F.np.ones((10,)) / 2) < valid_length) + return mask + + valid_length = mx.np.random.uniform(size=(10,)) + foo = Foo() + out1 = foo(valid_length) + + foo = Foo() + foo.hybridize() + out2 = foo(valid_length) + + assert mx.test_utils.same(out1.asnumpy(), out2.asnumpy()) + if __name__ == '__main__': import nose