From 2d5f40b24038b55b39ac49bda378bea9685cddb9 Mon Sep 17 00:00:00 2001 From: Michael Clerx Date: Mon, 2 Dec 2019 20:01:23 +0000 Subject: [PATCH 1/2] Model.add_number and add_variable now allow unit objects instead of strings as arguments. --- cellmlmanip/model.py | 29 +++++++++++++---- cellmlmanip/parser.py | 5 ++- cellmlmanip/units.py | 6 ++-- tests/test_model_functions.py | 60 ++++++++++++++++++++++++++++------- tests/test_model_units.py | 44 ------------------------- 5 files changed, 78 insertions(+), 66 deletions(-) delete mode 100644 tests/test_model_units.py diff --git a/cellmlmanip/model.py b/cellmlmanip/model.py index 5e014f6c..34b57463 100644 --- a/cellmlmanip/model.py +++ b/cellmlmanip/model.py @@ -92,11 +92,16 @@ def add_number(self, value, units): Creates and returns a :class:`NumberDummy` to represent a number with units in sympy expressions. :param number: A number (anything convertible to float). - :param units: A string unit representation. + :param units: A `pint` units representation :return: A :class:`NumberDummy` object. """ - return NumberDummy(value, self.units.get_quantity(units)) + + # Check units + if not isinstance(units, self.units.ureg.Unit): + units = self.units.get_quantity(units) + + return NumberDummy(value, units) def add_variable(self, name, units, initial_value=None, public_interface=None, private_interface=None, cmeta_id=None): @@ -104,7 +109,7 @@ def add_variable(self, name, units, initial_value=None, Adds a variable to the model and returns a :class:`VariableDummy` to represent it in sympy expressions. :param name: A string name. - :param units: A string units representation. + :param units: A `pint` units representation. :param initial_value: An optional initial value. :param public_interface: An optional public interface specifier (only required when parsing CellML). :param private_interface: An optional private interface specifier (only required when parsing CellML). @@ -115,9 +120,14 @@ def add_variable(self, name, units, initial_value=None, if name in self._name_to_symbol: raise ValueError('Variable %s already exists.' % name) + # Check units + if not isinstance(units, self.units.ureg.Unit): + units = self.units.get_quantity(units) + + # Add variable self._name_to_symbol[name] = var = VariableDummy( name=name, - units=self.units.get_quantity(units), + units=units, initial_value=initial_value, public_interface=public_interface, private_interface=private_interface, @@ -161,7 +171,7 @@ def connect_variables(self, source_name: str, target_name: str): factor = self.units.convert_to(1 * source.units, target.units).magnitude # Dummy to represent this factor in equations, having units for conversion - factor_dummy = self.add_number(factor, str(target.units / source.units)) + factor_dummy = self.add_number(factor, target.units / source.units) # Add an equations making the connection with the required conversion self.equations.append(sympy.Eq(target, source.assigned_to * factor_dummy)) @@ -387,6 +397,12 @@ def get_ontology_terms_by_symbol(self, symbol, namespace_uri=None): ontology_terms.append(uri_parts[-1]) return ontology_terms + def get_units(self, name): + """ + Looks up and returns a pint `Unit` object with the given name. + """ + return self.units.get_quantity(name) + @property def graph(self): """ A ``networkx.DiGraph`` containing the model equations. """ @@ -448,8 +464,7 @@ def graph(self): else: # this variable is a parameter - add to graph and connect to lhs rhs.type = 'parameter' - unit = rhs.units - dummy = self.add_number(rhs.initial_value, str(unit)) + dummy = self.add_number(rhs.initial_value, rhs.units) graph.add_node(rhs, equation=sympy.Eq(rhs, dummy), variable_type='parameter') graph.add_edge(rhs, lhs) diff --git a/cellmlmanip/parser.py b/cellmlmanip/parser.py index 196336ed..e9e114b1 100644 --- a/cellmlmanip/parser.py +++ b/cellmlmanip/parser.py @@ -214,6 +214,9 @@ def _add_variables(self, component_element): attributes['name'] = Parser._get_variable_name(component_element.get('name'), attributes['name']) + # look up units + attributes['units'] = self.model.get_units(attributes['units']) + # model.add_variable() returns sympy dummy created for this variable - keep it variable_lookup_symbol[attributes['name']] = self.model.add_variable(**attributes) @@ -244,7 +247,7 @@ def symbol_generator(identifer): # reuse transpiler so dummy symbols are kept across elements transpiler = Transpiler( symbol_generator=symbol_generator, - number_generator=lambda x, y: self.model.add_number(x, y), + number_generator=lambda x, y: self.model.add_number(x, self.model.get_units(y)), ) # for each math element diff --git a/cellmlmanip/units.py b/cellmlmanip/units.py index cf0517cd..1543a46f 100644 --- a/cellmlmanip/units.py +++ b/cellmlmanip/units.py @@ -250,10 +250,10 @@ def _define_pint_unit(self, units_name, definition_string_or_instance): self.custom_defined.add(units_name) def get_quantity(self, unit_name): - """Returns a pint.Unit with the given name from the UnitRegistry. + """Returns a pint `Unit` with the given name from the UnitRegistry. :param unit_name: string name of the unit - :return: pint.Unit - throws pint.UndefinedUnitError if te unit is not present in registry + :return: `Unit` + throws pint.UndefinedUnitError if the unit is not present in registry """ try: return self.ureg.parse_expression(unit_name).units diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index 0cff687e..fdc2a4cd 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -3,7 +3,7 @@ import pytest import sympy as sp -from cellmlmanip import parser +from cellmlmanip import parser, units from cellmlmanip.model import Model, VariableDummy from . import shared @@ -160,7 +160,7 @@ def test_get_equations_for(self): """ m = Model('simplification') - u = 'dimensionless' + u = m.get_units('dimensionless') t = m.add_variable('t', u) y1 = m.add_variable('y1', u, initial_value=10) y2 = m.add_variable('y2', u, initial_value=20) @@ -395,24 +395,17 @@ def test_set_equation2(self, local_hh_model): assert v.type == 'state' # Now clamp it to -80mV - rhs = model.add_number(-80, str(v.units)) + rhs = model.add_number(-80, v.units) model.set_equation(v, rhs) # Check that V is no longer a state v = model.get_symbol_by_ontology_term(shared.OXMETA, 'membrane_voltage') assert v.type != 'state' - # TODO: Get dvdt_unit in a more sensible way - # See: https://github.com/ModellingWebLab/cellmlmanip/issues/133 - # Now make V a state again t = model.get_symbol_by_ontology_term(shared.OXMETA, 'time') lhs = sp.Derivative(v, t) - dvdt_units = 'unlikely_unit_name' - model.add_unit(dvdt_units, [ - {'units': str(v.units)}, - {'units': str(t.units), 'exponent': -1}, - ]) + dvdt_units = v.units / t.units rhs = model.add_number(0, dvdt_units) model.set_equation(lhs, rhs) @@ -467,6 +460,51 @@ def test_add_variable(self, local_model): with pytest.raises(ValueError, match='already exists'): model.add_variable(name='varvar1', units=unit) + ################################################################### + # Unit related functionality + + def test_get_units(self): + """ Tests Model.get_units(). """ + + # Get predefined unit + m = Model('test') + m.get_units('volt') + + # Non-existent unit + with pytest.raises(KeyError, match='Cannot find unit'): + m.get_units('towel') + + def test_units(self, simple_units_model): + """ Tests units read and calculated from a model. """ + symbol_a = simple_units_model.get_symbol_by_cmeta_id("a") + equation = simple_units_model.get_equations_for([symbol_a], strip_units=False) + assert simple_units_model.units.summarise_units(equation[0].lhs) == 'ms' + assert simple_units_model.units.summarise_units(equation[0].rhs) == 'ms' + + symbol_b = simple_units_model.get_symbol_by_cmeta_id("b") + equation = simple_units_model.get_equations_for([symbol_b]) + assert simple_units_model.units.summarise_units(equation[1].lhs) == 'per_ms' + assert simple_units_model.units.summarise_units(equation[1].rhs) == '1 / ms' + assert simple_units_model.units.is_unit_equal(simple_units_model.units.summarise_units(equation[1].lhs), + simple_units_model.units.summarise_units(equation[1].rhs)) + + def test_bad_units(self, bad_units_model): + """ Tests units read and calculated from an inconsistent model. """ + symbol_a = bad_units_model.get_symbol_by_cmeta_id("a") + symbol_b = bad_units_model.get_symbol_by_cmeta_id("b") + equation = bad_units_model.get_equations_for([symbol_b], strip_units=False) + assert len(equation) == 2 + assert equation[0].lhs == symbol_a + assert bad_units_model.units.summarise_units(equation[0].lhs) == 'ms' + with pytest.raises(units.UnitError): + # cellml file states a (ms) = 1 (ms) + 1 (second) + bad_units_model.units.summarise_units(equation[0].rhs) + + assert equation[1].lhs == symbol_b + with pytest.raises(units.UnitError): + # cellml file states b (per_ms) = power(a (ms), 1 (second)) + bad_units_model.units.summarise_units(equation[1].rhs) + ################################################################### # this section is for other functions diff --git a/tests/test_model_units.py b/tests/test_model_units.py deleted file mode 100644 index 79f184c4..00000000 --- a/tests/test_model_units.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest - -from cellmlmanip import units - - -# TODO some tests here are repeats and may not be necessary -class TestModelUnits: - def test_symbols(self, simple_units_model): - """ Tests the Model.get_symbol_by_cmeta_id function.""" - symbol = simple_units_model.get_symbol_by_cmeta_id("a") - assert symbol.is_Symbol - symbol = simple_units_model.get_symbol_by_cmeta_id("b") - assert symbol.is_Symbol - - def test_units(self, simple_units_model): - """ Tests units read and calculated from a model. """ - symbol_a = simple_units_model.get_symbol_by_cmeta_id("a") - equation = simple_units_model.get_equations_for([symbol_a], strip_units=False) - assert simple_units_model.units.summarise_units(equation[0].lhs) == 'ms' - assert simple_units_model.units.summarise_units(equation[0].rhs) == 'ms' - - symbol_b = simple_units_model.get_symbol_by_cmeta_id("b") - equation = simple_units_model.get_equations_for([symbol_b]) - assert simple_units_model.units.summarise_units(equation[1].lhs) == 'per_ms' - assert simple_units_model.units.summarise_units(equation[1].rhs) == '1 / ms' - assert simple_units_model.units.is_unit_equal(simple_units_model.units.summarise_units(equation[1].lhs), - simple_units_model.units.summarise_units(equation[1].rhs)) - - def test_bad_units(self, bad_units_model): - """ Tests units read and calculated from an inconsistent model. """ - symbol_a = bad_units_model.get_symbol_by_cmeta_id("a") - symbol_b = bad_units_model.get_symbol_by_cmeta_id("b") - equation = bad_units_model.get_equations_for([symbol_b], strip_units=False) - assert len(equation) == 2 - assert equation[0].lhs == symbol_a - assert bad_units_model.units.summarise_units(equation[0].lhs) == 'ms' - with pytest.raises(units.UnitError): - # cellml file states a (ms) = 1 (ms) + 1 (second) - bad_units_model.units.summarise_units(equation[0].rhs) - - assert equation[1].lhs == symbol_b - with pytest.raises(units.UnitError): - # cellml file states b (per_ms) = power(a (ms), 1 (second)) - bad_units_model.units.summarise_units(equation[1].rhs) From 2f08502ea396aec2759c05abe7a7c300d4b7b945 Mon Sep 17 00:00:00 2001 From: Michael Clerx Date: Fri, 6 Dec 2019 11:32:54 +0100 Subject: [PATCH 2/2] Fixed syntax error introduced in merging --- tests/test_model_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index 2cebf24c..a47e278f 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -371,7 +371,7 @@ def test_remove_equation(self, local_hh_model): t = model.get_symbol_by_ontology_term(shared.OXMETA, 'time') equation = model.graph.nodes[sp.Derivative(v, t)]['equation'] model.remove_equation(equation) - equation = sp.Eq(v, model.add_number(-80, v.units) + equation = sp.Eq(v, model.add_number(-80, v.units)) model.add_equation(equation) # Check that V is no longer a state