diff --git a/cellmlmanip/units.py b/cellmlmanip/units.py index c13cc00f..5ad4fb72 100644 --- a/cellmlmanip/units.py +++ b/cellmlmanip/units.py @@ -341,24 +341,51 @@ def summarise_units(self, expr): logger.debug('summarise_units(%s) ⟶ %s', expr, found.units) return found.units - def get_conversion_factor(self, quantity, to_unit): - """Returns the magnitude multiplier required to convert from_unit to to_unit + def get_conversion_factor(self, + to_unit, + from_unit=None, + quantity=None, + expression=None): + """Returns the magnitude multiplier required to convert a unit to the specified unit. + + Note this will work on either a unit, a quantity or an expression, but requires only + one of these arguments. + + :param to_unit: Unit object into which the units should be converted + :param from_unit: the Unit to be converted :param quantity: the Unit to be converted, multiplied by '1' to form a Quantity object - :param to_unit: Unit object into which the first units should be converted - :return the magnitude of the resulting conversion factor + :param expression: an expression from which the Unit is evaluated before conversion + + :return: the magnitude of the resulting conversion factor + + :throws: AssertionError if no target unit is specified or no source unit is specified """ - return self.convert_to(quantity, to_unit).magnitude + assert to_unit is not None, 'No unit given as target of conversion; to_unit argument is required' + assert quantity is not None or from_unit is not None or expression is not None, \ + 'No unit given as source of conversion; please use one of from_unit, quantity or expression' + assert [from_unit, quantity, expression].count(None) == 2, \ + 'Multiple target specified; please use only one of from_unit, quantity or expression' + + if from_unit is not None: + assert isinstance(from_unit, self.ureg.Unit), 'from_unit must be of type pint:Unit' + return self.convert_to(1 * from_unit, to_unit).magnitude + elif quantity is not None: + assert isinstance(quantity, self.ureg.Quantity), 'quantity must be of type pint:Quantity' + return self.convert_to(quantity, to_unit).magnitude + else: + assert isinstance(expression, sympy.Expr), 'expression must be of type Sympy expression' + return self.convert_to(1 * self.summarise_units(expression), to_unit).magnitude def dimensionally_equivalent(self, symbol1, symbol2): """Returns whether two expressions, symbol1 and symbol2, - are dimensionally_equivalent (same units ignorging a calling factor). + are dimensionally_equivalent (same units ignoring a calling factor). :param symbol1: the first expression to compare - :param unit2: the second expression to compare + :param symbol2: the second expression to compare :return True if units are equal (regardless of quantity), False otherwise """ try: - self.get_conversion_factor(1 * self.summarise_units(symbol1), - self.summarise_units(symbol2)) + self.get_conversion_factor(from_unit=self.summarise_units(symbol1), + to_unit=self.summarise_units(symbol2)) return True except pint.errors.DimensionalityError: return False diff --git a/tests/cellml_files/simple_model_units.cellml b/tests/cellml_files/simple_model_units.cellml index a89f8906..0f986900 100644 --- a/tests/cellml_files/simple_model_units.cellml +++ b/tests/cellml_files/simple_model_units.cellml @@ -10,6 +10,9 @@ + + + @@ -31,4 +34,15 @@ + + + + + + + b_1 + 5 + + + \ No newline at end of file diff --git a/tests/test_unit_conversion.py b/tests/test_unit_conversion.py index 0596a20d..e0d08c54 100644 --- a/tests/test_unit_conversion.py +++ b/tests/test_unit_conversion.py @@ -13,6 +13,11 @@ def model(): return cellmlmanip.load_model(os.path.join(os.path.dirname(__file__), 'cellml_files', "test_simple_odes.cellml")) +@pytest.fixture +def simple_model(): + return cellmlmanip.load_model(os.path.join(os.path.dirname(__file__), 'cellml_files', "simple_model_units.cellml")) + + def test_add_preferred_custom_unit_name(model): time_var = model.get_symbol_by_ontology_term(OXMETA, "time") assert str(model.units.summarise_units(time_var)) == "ms" @@ -23,3 +28,65 @@ def test_add_preferred_custom_unit_name(model): # again model.units.add_preferred_custom_unit_name('millisecond', [{'prefix': 'milli', 'units': 'second'}]) assert str(model.units.summarise_units(time_var)) == "millisecond" + + +def test_conversion_factor_original(simple_model): + simple_model.get_equation_graph(True) # set up the graph - it is not automatic + symbol_b1 = simple_model.get_symbol_by_cmeta_id("b_1") + equation = simple_model.get_equations_for([symbol_b1]) + factor = simple_model.units.get_conversion_factor(quantity=1 * simple_model.units.summarise_units(equation[0].lhs), + to_unit=simple_model.units.ureg('us').units) + assert factor == 1000 + + +def test_conversion_factor_bad_types(simple_model): + simple_model.get_equation_graph(True) # set up the graph - it is not automatic + symbol_b1 = simple_model.get_symbol_by_cmeta_id("b_1") + equation = simple_model.get_equations_for([symbol_b1]) + expression = equation[0].lhs + to_unit = simple_model.units.ureg('us').units + from_unit = simple_model.units.summarise_units(expression) + quantity = 1 * from_unit + # no source unit + with pytest.raises(AssertionError, match='^No unit given as source.*'): + simple_model.units.get_conversion_factor(to_unit=to_unit) + with pytest.raises(AssertionError, match='^No unit given as source.*'): + simple_model.units.get_conversion_factor(to_unit) + + # no target unit + with pytest.raises(TypeError): + simple_model.units.get_conversion_factor(from_unit=from_unit) + # multiple sources + with pytest.raises(AssertionError, match='^Multiple target.*'): + simple_model.units.get_conversion_factor(to_unit, from_unit=from_unit, quantity=quantity) + # incorrect types + with pytest.raises(AssertionError, match='^from_unit must be of type pint:Unit$'): + simple_model.units.get_conversion_factor(to_unit, from_unit=quantity) + with pytest.raises(AssertionError, match='^quantity must be of type pint:Quantity$'): + simple_model.units.get_conversion_factor(to_unit, quantity=from_unit) + with pytest.raises(AssertionError, match='^expression must be of type Sympy expression$'): + simple_model.units.get_conversion_factor(to_unit, expression=quantity) + + # unit to unit + assert simple_model.units.get_conversion_factor(to_unit=to_unit, from_unit=from_unit) == 1000 + # quantity to unit + assert simple_model.units.get_conversion_factor(to_unit=to_unit, quantity=quantity) == 1000 + # expression to unit + assert simple_model.units.get_conversion_factor(to_unit=to_unit, expression=expression) == 1000 + + +def test_conversion_factor_same_units(simple_model): + simple_model.get_equation_graph(True) # set up the graph - it is not automatic + symbol_b = simple_model.get_symbol_by_cmeta_id("b") + equation = simple_model.get_equations_for([symbol_b]) + expression = equation[1].rhs + to_unit = simple_model.units.ureg('per_ms').units + from_unit = simple_model.units.summarise_units(expression) + quantity = 1 * from_unit + # quantity to unit + assert simple_model.units.get_conversion_factor(to_unit=to_unit, quantity=quantity) == 1 + # unit to unit + assert simple_model.units.get_conversion_factor(to_unit=to_unit, from_unit=from_unit) == 1 + # expression to unit + assert simple_model.units.get_conversion_factor(to_unit=to_unit, expression=expression) == 1 + diff --git a/tests/test_units.py b/tests/test_units.py index 5cbacc58..fa0d3e06 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -103,12 +103,12 @@ def test_quantity_translation(self, quantity_store): def test_conversion_factor(self, quantity_store): ureg = quantity_store.ureg - assert quantity_store.get_conversion_factor(1 * ureg.ms, ureg.second) == 0.001 - assert quantity_store.get_conversion_factor(1 * ureg.volt, ureg.mV) == 1000.0 + assert quantity_store.get_conversion_factor(quantity=1 * ureg.ms, to_unit=ureg.second) == 0.001 + assert quantity_store.get_conversion_factor(quantity=1 * ureg.volt, to_unit=ureg.mV) == 1000.0 assert quantity_store.get_conversion_factor( - 1 * quantity_store.get_quantity('milli_mole'), - quantity_store.get_quantity('mole') + quantity=1 * quantity_store.get_quantity('milli_mole'), + to_unit=quantity_store.get_quantity('mole') ) == 0.001 def test_add_custom_unit(self):