From 40427a7ab9360cc5dcfef1836eb0cc7c71fc108e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 31 Dec 2021 01:32:27 -0500 Subject: [PATCH] add unittest for `deepmd.common.cast_precision` --- source/tests/test_common.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/source/tests/test_common.py b/source/tests/test_common.py index 0f052663b2..adbde971e9 100644 --- a/source/tests/test_common.py +++ b/source/tests/test_common.py @@ -3,7 +3,8 @@ import unittest from pathlib import Path -from deepmd.common import expand_sys_str +from deepmd.common import expand_sys_str, cast_precision, GLOBAL_TF_FLOAT_PRECISION +from deepmd.env import tf # compute relative path # https://stackoverflow.com/questions/38083555/using-pathlibs-relative-to-for-directories-on-the-same-level @@ -46,3 +47,30 @@ def test_expand(self): ret = expand_sys_str('test_sys') ret.sort() self.assertEqual(ret, self.expected_out) + + +class TestCastPrecision(unittest.TestCase): + """This class tests `deepmd.common.cast_precision`.""" + @property + def precision(self): + if GLOBAL_TF_FLOAT_PRECISION == tf.float32: + return tf.float64 + return tf.float32 + + def test_cast_precision(self): + x = tf.zeros(1, dtype=GLOBAL_TF_FLOAT_PRECISION) + y = tf.zeros(1, dtype=tf.int64) + self.assertEqual(x.dtype, GLOBAL_TF_FLOAT_PRECISION) + self.assertEqual(y.dtype, tf.int64) + x, y, z = self._inner_method(x, y) + self.assertEqual(x.dtype, GLOBAL_TF_FLOAT_PRECISION) + self.assertEqual(y.dtype, tf.int64) + self.assertIsInstance(z, bool) + + @cast_precision + def _inner_method(self, x: tf.Tensor, y: tf.Tensor, z: bool = False) -> tf.Tensor: + # y and z should not be cast here + self.assertEqual(x.dtype, self.precision) + self.assertEqual(y.dtype, tf.int64) + self.assertIsInstance(z, bool) + return x, y, z