diff --git a/scimath/units/assertion_utils.py b/scimath/units/assertion_utils.py new file mode 100644 index 0000000..89f7e67 --- /dev/null +++ b/scimath/units/assertion_utils.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +""" Utilities providing assertions to support unit tests involving UnitScalars +and UnitArrays. +""" +from nose.tools import assert_false, assert_true + +from scimath.units.compare_units import unit_arrays_almost_equal, \ + unit_scalars_almost_equal + + +def assert_unit_scalar_almost_equal(val1, val2, rtol=1.e-9, msg=None): + if msg is None: + msg = "{} and {} are not almost equal with precision {}" + msg = msg.format(val1, val2, rtol) + + assert_true(unit_scalars_almost_equal(val1, val2, rtol=rtol), msg=msg) + + +def assert_unit_scalar_not_almost_equal(val1, val2, rtol=1.e-9, msg=None): + if msg is None: + msg = "{} and {} unexpectedly almost equal with precision {}" + msg = msg.format(val1, val2, rtol) + + assert_false(unit_scalars_almost_equal(val1, val2, rtol=rtol), msg=msg) + + +def assert_unit_array_almost_equal(uarr1, uarr2, rtol=1e-9, msg=None): + if msg is None: + msg = "{} and {} are not almost equal with precision {}" + msg = msg.format(uarr1, uarr2, rtol) + + assert_true(unit_arrays_almost_equal(uarr1, uarr2, rtol=rtol), msg=msg) + + +def assert_unit_array_not_almost_equal(uarr1, uarr2, rtol=1e-9, msg=None): + if msg is None: + msg = "{} and {} are almost equal with precision {}" + msg = msg.format(uarr1, uarr2, rtol) + + assert_false(unit_arrays_almost_equal(uarr1, uarr2, rtol=rtol), msg=msg) diff --git a/scimath/units/compare_units.py b/scimath/units/compare_units.py new file mode 100644 index 0000000..094bc09 --- /dev/null +++ b/scimath/units/compare_units.py @@ -0,0 +1,81 @@ +""" Utilities around unit comparisons. +""" +import numpy as np + +from scimath.units.api import convert, UnitArray, UnitScalar +from scimath.units.unit import InvalidConversion + + +def unit_scalars_almost_equal(x1, x2, rtol=1e-9): + """ Returns whether 2 UnitScalars are almost equal. + + More precisely, what is tested is whether abs(a1-a2) < rtol*abs(a2), where + a1=float(x1) and a2=float(x2) after conversion to x1's units. + + Parameters + ---------- + x1 : UnitScalar + First unit scalar to compare. + + x2 : UnitScalar + Second unit scalar to compare. + + rtol : float + Relative precision of the comparison. + """ + if not isinstance(x1, UnitScalar): + msg = "x1 is supposed to be a UnitScalar but a {} was passed." + msg = msg.format(type(x1)) + raise ValueError(msg) + + if not isinstance(x2, UnitScalar): + msg = "x2 is supposed to be a UnitScalar but a {} was passed." + msg = msg.format(type(x2)) + raise ValueError(msg) + + a1 = float(x1) + try: + a2 = convert(float(x2), from_unit=x2.units, to_unit=x1.units) + except InvalidConversion: + return False + return np.abs(a1 - a2) < np.abs(rtol * a2) + + +def unit_arrays_almost_equal(uarr1, uarr2, rtol=1e-9): + """ Returns whether 2 UnitArrays are almost equal (must be the same shape). + + More precisely, what is tested is whether abs(a1-a2) < rtol*abs(a2) for all + values in the arrays, once uarr2 has been converted to uarr1's units. + + Parameters + ---------- + uarr1 : UnitArray + First unit array to compare. + + uarr2 : UnitArray + Second unit array to compare. + + rtol : float + Relative precision of the comparison. + """ + if not isinstance(uarr1, UnitArray): + msg = "uarr1 is supposed to be a UnitArray but a {} was passed." + msg = msg.format(type(uarr1)) + raise ValueError(msg) + + if not isinstance(uarr2, UnitArray): + msg = "uarr2 is supposed to be a UnitArray but a {} was passed." + msg = msg.format(type(uarr2)) + raise ValueError(msg) + + if uarr1.shape != uarr2.shape: + return False + + a1 = np.array(uarr1) + try: + a2 = convert(np.array(uarr2), from_unit=uarr2.units, + to_unit=uarr1.units) + except InvalidConversion: + return False + + return np.all(np.abs(a1 - a2) < np.abs(rtol * a2)) diff --git a/scimath/units/testing/__init__.py b/scimath/units/testing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scimath/units/tests/test_assertion_utils.py b/scimath/units/tests/test_assertion_utils.py new file mode 100644 index 0000000..47401a8 --- /dev/null +++ b/scimath/units/tests/test_assertion_utils.py @@ -0,0 +1,56 @@ +from unittest import TestCase + +from scimath.units.api import UnitArray, UnitScalar +from scimath.units.assertion_utils import assert_unit_array_almost_equal, \ + assert_unit_scalar_almost_equal + + +class TestAssertUnitScalarEqual(TestCase): + def test_same_unit_scalar(self): + assert_unit_scalar_almost_equal(UnitScalar(1, units="s"), + UnitScalar(1, units="s")) + + def test_equivalent_unit_scalar(self): + assert_unit_scalar_almost_equal(UnitScalar(1, units="m"), + UnitScalar(100, units="cm")) + + def test_not_close(self): + with self.assertRaises(AssertionError): + assert_unit_scalar_almost_equal(UnitScalar(1, units="m"), + UnitScalar(1.1, units="m")) + + def test_not_close_custom_msg(self): + a1 = UnitScalar(1, units="m") + a2 = UnitScalar(1.1, units="m") + with self.assertRaises(AssertionError): + assert_unit_scalar_almost_equal(a1, a2, rtol=1e-2, msg="BLAH") + + def test_unit_scalar_non_default_rtol(self): + assert_unit_scalar_almost_equal(UnitScalar(1, units="m"), + UnitScalar(1.01, units="m"), rtol=1e-1) + + +class TestAssertUnitArrayEqual(TestCase): + def test_same_unit_array(self): + assert_unit_array_almost_equal(UnitArray([1, 2], units="s"), + UnitArray([1, 2], units="s")) + + def test_equivalent_unit_array(self): + assert_unit_array_almost_equal(UnitArray([1, 2], units="m"), + UnitArray([100, 200], units="cm")) + + def test_not_close(self): + a1 = UnitArray([1.01, 2], units="s") + a2 = UnitArray([1, 2], units="s") + with self.assertRaises(AssertionError): + assert_unit_array_almost_equal(a1, a2) + + def test_not_close_custom_msg(self): + a1 = UnitArray([1.01, 2], units="s") + a2 = UnitArray([1, 2], units="s") + with self.assertRaises(AssertionError): + assert_unit_array_almost_equal(a1, a2, msg="BLAH") + + def test_unit_scalar_non_default_rtol(self): + assert_unit_array_almost_equal(UnitScalar(1, units="m"), + UnitScalar(1.01, units="m"), rtol=1e-1) diff --git a/scimath/units/tests/test_compare_units.py b/scimath/units/tests/test_compare_units.py new file mode 100644 index 0000000..3784315 --- /dev/null +++ b/scimath/units/tests/test_compare_units.py @@ -0,0 +1,115 @@ +from unittest import TestCase + +from scimath.units.api import dimensionless, UnitArray, UnitScalar +from scimath.units.compare_units import unit_arrays_almost_equal, \ + unit_scalars_almost_equal + + +class TestUnitScalarAlmostEqual(TestCase): + def test_values_identical(self): + val1 = UnitScalar(1., units="m") + self.assertTrue(unit_scalars_almost_equal(val1, val1)) + + def test_wrong_arg_type1(self): + val1 = 1 + val2 = UnitScalar(1., units="m") + with self.assertRaises(ValueError): + unit_scalars_almost_equal(val1, val2) + + def test_wrong_arg_type2(self): + val1 = UnitScalar(1., units="m") + val2 = 1 + with self.assertRaises(ValueError): + unit_scalars_almost_equal(val1, val2) + + def test_values_not_close(self): + val1 = UnitScalar(1., units="m") + val2 = UnitScalar(1.1, units="m") + self.assertFalse(unit_scalars_almost_equal(val1, val2)) + + val2 = UnitScalar(1.00001, units="m") + self.assertFalse(unit_scalars_almost_equal(val1, val2)) + + def test_values_identical_in_diff_units(self): + val1 = UnitScalar(1., units="m") + val2 = UnitScalar(100., units="cm") + self.assertTrue(unit_scalars_almost_equal(val1, val2)) + + def test_dimensionless(self): + val1 = UnitScalar(1., units=dimensionless) + val2 = UnitScalar(1., units="cm") + self.assertFalse(unit_scalars_almost_equal(val1, val2)) + + def test_2_dimensionless(self): + val1 = UnitScalar(1., units=dimensionless) + val2 = UnitScalar(1., units="BLAH") + val3 = UnitScalar(100., units="BLAH") + self.assertTrue(unit_scalars_almost_equal(val1, val1)) + self.assertTrue(unit_scalars_almost_equal(val1, val2)) + self.assertFalse(unit_scalars_almost_equal(val1, val3)) + + def test_values_close_enough(self): + val1 = UnitScalar(1., units="m") + val2 = val1 + UnitScalar(1.e-5, units="m") + self.assertFalse(unit_scalars_almost_equal(val1, val2)) + self.assertTrue(unit_scalars_almost_equal(val1, val2, rtol=1e-4)) + + +class TestUnitArraysAlmostEqual(TestCase): + def test_wrong_argument_type1(self): + val1 = 1 + val2 = UnitArray([1.], units="m") + with self.assertRaises(ValueError): + unit_arrays_almost_equal(val1, val2) + + def test_wrong_argument_type2(self): + val1 = UnitArray([1.], units="m") + val2 = 1 + with self.assertRaises(ValueError): + unit_arrays_almost_equal(val1, val2) + + def test_different_shape(self): + val1 = UnitArray([1.], units="m") + val2 = UnitArray([1., 2.], units="m") + self.assertFalse(unit_arrays_almost_equal(val1, val2)) + + def test_not_close_default_rtol(self): + val1 = UnitArray([1., 2.], units="m") + val2 = UnitArray([1., 2.1], units="m") + self.assertFalse(unit_arrays_almost_equal(val1, val2)) + + val2 = UnitArray([1., 2.000001], units="m") + self.assertFalse(unit_arrays_almost_equal(val1, val2)) + + def test_values_identical(self): + val1 = UnitArray([1., 2.], units="m") + self.assertTrue(unit_arrays_almost_equal(val1, val1)) + + def test_values_identical_in_diff_units(self): + val1 = UnitArray([1., 2.], units="m") + val2 = UnitArray([100., 200.], units="cm") + self.assertTrue(unit_arrays_almost_equal(val1, val2)) + + def test_dimensionless(self): + val1 = UnitArray([1.], units=dimensionless) + val2 = UnitArray([1.], units="cm") + self.assertFalse(unit_arrays_almost_equal(val1, val2)) + + def test_2_dimensionless(self): + val1 = UnitArray([1.], units=dimensionless) + val2 = UnitArray([1.], units="BLAH") + val3 = UnitArray([100.], units="BLAH") + self.assertTrue(unit_arrays_almost_equal(val1, val1)) + self.assertTrue(unit_arrays_almost_equal(val1, val2)) + self.assertFalse(unit_arrays_almost_equal(val1, val3)) + + def test_values_close_enough(self): + val1 = UnitArray([1., 2.], units="m") + val2 = val1 + UnitArray([1.e-5, 1.e-6], units="m") + self.assertFalse(unit_arrays_almost_equal(val1, val2)) + self.assertTrue(unit_arrays_almost_equal(val1, val2, rtol=1e-4)) + + def test_values_not_close_enough(self): + val1 = UnitArray([1., 2.], units="m") + val3 = val1 + UnitArray([1.e-2, 1.e-6], units="m") + self.assertFalse(unit_arrays_almost_equal(val1, val3, rtol=1e-4))