diff --git a/cellmlmanip/units.py b/cellmlmanip/units.py
index c13cc00f..5ad4fb72 100644
--- a/cellmlmanip/units.py
+++ b/cellmlmanip/units.py
@@ -341,24 +341,51 @@ def summarise_units(self, expr):
logger.debug('summarise_units(%s) ⟶ %s', expr, found.units)
return found.units
- def get_conversion_factor(self, quantity, to_unit):
- """Returns the magnitude multiplier required to convert from_unit to to_unit
+ def get_conversion_factor(self,
+ to_unit,
+ from_unit=None,
+ quantity=None,
+ expression=None):
+ """Returns the magnitude multiplier required to convert a unit to the specified unit.
+
+ Note this will work on either a unit, a quantity or an expression, but requires only
+ one of these arguments.
+
+ :param to_unit: Unit object into which the units should be converted
+ :param from_unit: the Unit to be converted
:param quantity: the Unit to be converted, multiplied by '1' to form a Quantity object
- :param to_unit: Unit object into which the first units should be converted
- :return the magnitude of the resulting conversion factor
+ :param expression: an expression from which the Unit is evaluated before conversion
+
+ :return: the magnitude of the resulting conversion factor
+
+ :throws: AssertionError if no target unit is specified or no source unit is specified
"""
- return self.convert_to(quantity, to_unit).magnitude
+ assert to_unit is not None, 'No unit given as target of conversion; to_unit argument is required'
+ assert quantity is not None or from_unit is not None or expression is not None, \
+ 'No unit given as source of conversion; please use one of from_unit, quantity or expression'
+ assert [from_unit, quantity, expression].count(None) == 2, \
+ 'Multiple target specified; please use only one of from_unit, quantity or expression'
+
+ if from_unit is not None:
+ assert isinstance(from_unit, self.ureg.Unit), 'from_unit must be of type pint:Unit'
+ return self.convert_to(1 * from_unit, to_unit).magnitude
+ elif quantity is not None:
+ assert isinstance(quantity, self.ureg.Quantity), 'quantity must be of type pint:Quantity'
+ return self.convert_to(quantity, to_unit).magnitude
+ else:
+ assert isinstance(expression, sympy.Expr), 'expression must be of type Sympy expression'
+ return self.convert_to(1 * self.summarise_units(expression), to_unit).magnitude
def dimensionally_equivalent(self, symbol1, symbol2):
"""Returns whether two expressions, symbol1 and symbol2,
- are dimensionally_equivalent (same units ignorging a calling factor).
+ are dimensionally_equivalent (same units ignoring a calling factor).
:param symbol1: the first expression to compare
- :param unit2: the second expression to compare
+ :param symbol2: the second expression to compare
:return True if units are equal (regardless of quantity), False otherwise
"""
try:
- self.get_conversion_factor(1 * self.summarise_units(symbol1),
- self.summarise_units(symbol2))
+ self.get_conversion_factor(from_unit=self.summarise_units(symbol1),
+ to_unit=self.summarise_units(symbol2))
return True
except pint.errors.DimensionalityError:
return False
diff --git a/tests/cellml_files/simple_model_units.cellml b/tests/cellml_files/simple_model_units.cellml
index a89f8906..0f986900 100644
--- a/tests/cellml_files/simple_model_units.cellml
+++ b/tests/cellml_files/simple_model_units.cellml
@@ -10,6 +10,9 @@
+
+
+
@@ -31,4 +34,15 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/test_unit_conversion.py b/tests/test_unit_conversion.py
index 0596a20d..e0d08c54 100644
--- a/tests/test_unit_conversion.py
+++ b/tests/test_unit_conversion.py
@@ -13,6 +13,11 @@ def model():
return cellmlmanip.load_model(os.path.join(os.path.dirname(__file__), 'cellml_files', "test_simple_odes.cellml"))
+@pytest.fixture
+def simple_model():
+ return cellmlmanip.load_model(os.path.join(os.path.dirname(__file__), 'cellml_files', "simple_model_units.cellml"))
+
+
def test_add_preferred_custom_unit_name(model):
time_var = model.get_symbol_by_ontology_term(OXMETA, "time")
assert str(model.units.summarise_units(time_var)) == "ms"
@@ -23,3 +28,65 @@ def test_add_preferred_custom_unit_name(model):
# again
model.units.add_preferred_custom_unit_name('millisecond', [{'prefix': 'milli', 'units': 'second'}])
assert str(model.units.summarise_units(time_var)) == "millisecond"
+
+
+def test_conversion_factor_original(simple_model):
+ simple_model.get_equation_graph(True) # set up the graph - it is not automatic
+ symbol_b1 = simple_model.get_symbol_by_cmeta_id("b_1")
+ equation = simple_model.get_equations_for([symbol_b1])
+ factor = simple_model.units.get_conversion_factor(quantity=1 * simple_model.units.summarise_units(equation[0].lhs),
+ to_unit=simple_model.units.ureg('us').units)
+ assert factor == 1000
+
+
+def test_conversion_factor_bad_types(simple_model):
+ simple_model.get_equation_graph(True) # set up the graph - it is not automatic
+ symbol_b1 = simple_model.get_symbol_by_cmeta_id("b_1")
+ equation = simple_model.get_equations_for([symbol_b1])
+ expression = equation[0].lhs
+ to_unit = simple_model.units.ureg('us').units
+ from_unit = simple_model.units.summarise_units(expression)
+ quantity = 1 * from_unit
+ # no source unit
+ with pytest.raises(AssertionError, match='^No unit given as source.*'):
+ simple_model.units.get_conversion_factor(to_unit=to_unit)
+ with pytest.raises(AssertionError, match='^No unit given as source.*'):
+ simple_model.units.get_conversion_factor(to_unit)
+
+ # no target unit
+ with pytest.raises(TypeError):
+ simple_model.units.get_conversion_factor(from_unit=from_unit)
+ # multiple sources
+ with pytest.raises(AssertionError, match='^Multiple target.*'):
+ simple_model.units.get_conversion_factor(to_unit, from_unit=from_unit, quantity=quantity)
+ # incorrect types
+ with pytest.raises(AssertionError, match='^from_unit must be of type pint:Unit$'):
+ simple_model.units.get_conversion_factor(to_unit, from_unit=quantity)
+ with pytest.raises(AssertionError, match='^quantity must be of type pint:Quantity$'):
+ simple_model.units.get_conversion_factor(to_unit, quantity=from_unit)
+ with pytest.raises(AssertionError, match='^expression must be of type Sympy expression$'):
+ simple_model.units.get_conversion_factor(to_unit, expression=quantity)
+
+ # unit to unit
+ assert simple_model.units.get_conversion_factor(to_unit=to_unit, from_unit=from_unit) == 1000
+ # quantity to unit
+ assert simple_model.units.get_conversion_factor(to_unit=to_unit, quantity=quantity) == 1000
+ # expression to unit
+ assert simple_model.units.get_conversion_factor(to_unit=to_unit, expression=expression) == 1000
+
+
+def test_conversion_factor_same_units(simple_model):
+ simple_model.get_equation_graph(True) # set up the graph - it is not automatic
+ symbol_b = simple_model.get_symbol_by_cmeta_id("b")
+ equation = simple_model.get_equations_for([symbol_b])
+ expression = equation[1].rhs
+ to_unit = simple_model.units.ureg('per_ms').units
+ from_unit = simple_model.units.summarise_units(expression)
+ quantity = 1 * from_unit
+ # quantity to unit
+ assert simple_model.units.get_conversion_factor(to_unit=to_unit, quantity=quantity) == 1
+ # unit to unit
+ assert simple_model.units.get_conversion_factor(to_unit=to_unit, from_unit=from_unit) == 1
+ # expression to unit
+ assert simple_model.units.get_conversion_factor(to_unit=to_unit, expression=expression) == 1
+
diff --git a/tests/test_units.py b/tests/test_units.py
index 5cbacc58..fa0d3e06 100644
--- a/tests/test_units.py
+++ b/tests/test_units.py
@@ -103,12 +103,12 @@ def test_quantity_translation(self, quantity_store):
def test_conversion_factor(self, quantity_store):
ureg = quantity_store.ureg
- assert quantity_store.get_conversion_factor(1 * ureg.ms, ureg.second) == 0.001
- assert quantity_store.get_conversion_factor(1 * ureg.volt, ureg.mV) == 1000.0
+ assert quantity_store.get_conversion_factor(quantity=1 * ureg.ms, to_unit=ureg.second) == 0.001
+ assert quantity_store.get_conversion_factor(quantity=1 * ureg.volt, to_unit=ureg.mV) == 1000.0
assert quantity_store.get_conversion_factor(
- 1 * quantity_store.get_quantity('milli_mole'),
- quantity_store.get_quantity('mole')
+ quantity=1 * quantity_store.get_quantity('milli_mole'),
+ to_unit=quantity_store.get_quantity('mole')
) == 0.001
def test_add_custom_unit(self):