Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions cellmlmanip/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,24 @@ def add_number(self, value, units):
Creates and returns a :class:`NumberDummy` to represent a number with units in sympy expressions.

:param number: A number (anything convertible to float).
:param units: A string unit representation.
:param units: A `pint` units representation
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this also accepts strings, should we keep that option in the docs too? Or hide it just for use by our tests?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, and should it be double backticks? My reST is rusty!


:return: A :class:`NumberDummy` object.
"""
return NumberDummy(value, self.units.get_quantity(units))

# Check units
if not isinstance(units, self.units.ureg.Unit):
units = self.units.get_quantity(units)

return NumberDummy(value, units)

def add_variable(self, name, units, initial_value=None,
public_interface=None, private_interface=None, cmeta_id=None):
"""
Adds a variable to the model and returns a :class:`VariableDummy` to represent it in sympy expressions.

:param name: A string name.
:param units: A string units representation.
:param units: A `pint` units representation.
:param initial_value: An optional initial value.
:param public_interface: An optional public interface specifier (only required when parsing CellML).
:param private_interface: An optional private interface specifier (only required when parsing CellML).
Expand All @@ -116,10 +121,14 @@ def add_variable(self, name, units, initial_value=None,
if name in self._name_to_symbol:
raise ValueError('Variable %s already exists.' % name)

# Check units
if not isinstance(units, self.units.ureg.Unit):
units = self.units.get_quantity(units)

# Add variable
self._name_to_symbol[name] = var = VariableDummy(
name=name,
units=self.units.get_quantity(units),
units=units,
initial_value=initial_value,
public_interface=public_interface,
private_interface=private_interface,
Expand Down Expand Up @@ -172,7 +181,7 @@ def connect_variables(self, source_name: str, target_name: str):
factor = self.units.convert_to(1 * source.units, target.units).magnitude

# Dummy to represent this factor in equations, having units for conversion
factor_dummy = self.add_number(factor, str(target.units / source.units))
factor_dummy = self.add_number(factor, target.units / source.units)

# Add an equations making the connection with the required conversion
self.equations.append(sympy.Eq(target, source.assigned_to * factor_dummy))
Expand Down Expand Up @@ -395,6 +404,12 @@ def get_ontology_terms_by_symbol(self, symbol, namespace_uri=None):
ontology_terms.append(uri_parts[-1])
return ontology_terms

def get_units(self, name):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use this method above?

"""
Looks up and returns a pint `Unit` object with the given name.
"""
return self.units.get_quantity(name)

@property
def graph(self):
""" A ``networkx.DiGraph`` containing the model equations. """
Expand Down Expand Up @@ -456,8 +471,7 @@ def graph(self):
else:
# this variable is a parameter - add to graph and connect to lhs
rhs.type = 'parameter'
unit = rhs.units
dummy = self.add_number(rhs.initial_value, str(unit))
dummy = self.add_number(rhs.initial_value, rhs.units)
graph.add_node(rhs, equation=sympy.Eq(rhs, dummy), variable_type='parameter')
graph.add_edge(rhs, lhs)

Expand Down
5 changes: 4 additions & 1 deletion cellmlmanip/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ def _add_variables(self, component_element):
attributes['name'] = Parser._get_variable_name(component_element.get('name'),
attributes['name'])

# look up units
attributes['units'] = self.model.get_units(attributes['units'])

# model.add_variable() returns sympy dummy created for this variable - keep it
variable_lookup_symbol[attributes['name']] = self.model.add_variable(**attributes)

Expand Down Expand Up @@ -244,7 +247,7 @@ def symbol_generator(identifer):
# reuse transpiler so dummy symbols are kept across <math> elements
transpiler = Transpiler(
symbol_generator=symbol_generator,
number_generator=lambda x, y: self.model.add_number(x, y),
number_generator=lambda x, y: self.model.add_number(x, self.model.get_units(y)),
)

# for each math element
Expand Down
6 changes: 3 additions & 3 deletions cellmlmanip/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,10 @@ def _define_pint_unit(self, units_name, definition_string_or_instance):
self.custom_defined.add(units_name)

def get_quantity(self, unit_name):
"""Returns a pint.Unit with the given name from the UnitRegistry.
"""Returns a pint `Unit` with the given name from the UnitRegistry.
:param unit_name: string name of the unit
:return: pint.Unit
throws pint.UndefinedUnitError if te unit is not present in registry
:return: `Unit`
throws pint.UndefinedUnitError if the unit is not present in registry
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could use the Sphinx :raises pint.UndefinedUnitError: if... syntax here, actually?

"""
try:
return self.ureg.parse_expression(unit_name).units
Expand Down
60 changes: 49 additions & 11 deletions tests/test_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import sympy as sp

from cellmlmanip import parser
from cellmlmanip import parser, units
from cellmlmanip.model import Model, VariableDummy

from . import shared
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_get_equations_for(self):
"""

m = Model('simplification')
u = 'dimensionless'
u = m.get_units('dimensionless')
t = m.add_variable('t', u)
y1 = m.add_variable('y1', u, initial_value=10)
y2 = m.add_variable('y2', u, initial_value=20)
Expand Down Expand Up @@ -371,22 +371,15 @@ def test_remove_equation(self, local_hh_model):
t = model.get_symbol_by_ontology_term(shared.OXMETA, 'time')
equation = model.graph.nodes[sp.Derivative(v, t)]['equation']
model.remove_equation(equation)
equation = sp.Eq(v, model.add_number(-80, str(v.units)))
equation = sp.Eq(v, model.add_number(-80, v.units))
model.add_equation(equation)

# Check that V is no longer a state
v = model.get_symbol_by_ontology_term(shared.OXMETA, 'membrane_voltage')
assert v.type != 'state'

# TODO: Get dvdt_unit in a more sensible way
# See: https://github.com/ModellingWebLab/cellmlmanip/issues/133

# Now make V a state again
dvdt_units = 'unlikely_unit_name'
model.add_unit(dvdt_units, [
{'units': str(v.units)},
{'units': str(t.units), 'exponent': -1},
])
dvdt_units = v.units / t.units
model.remove_equation(equation)
equation = sp.Eq(sp.Derivative(v, t), model.add_number(0, dvdt_units))
model.add_equation(equation)
Expand Down Expand Up @@ -442,6 +435,51 @@ def test_add_variable(self, local_model):
with pytest.raises(ValueError, match='already exists'):
model.add_variable(name='varvar1', units=unit)

###################################################################
# Unit related functionality

def test_get_units(self):
""" Tests Model.get_units(). """

# Get predefined unit
m = Model('test')
m.get_units('volt')

# Non-existent unit
with pytest.raises(KeyError, match='Cannot find unit'):
m.get_units('towel')

def test_units(self, simple_units_model):
""" Tests units read and calculated from a model. """
symbol_a = simple_units_model.get_symbol_by_cmeta_id("a")
equation = simple_units_model.get_equations_for([symbol_a], strip_units=False)
assert simple_units_model.units.summarise_units(equation[0].lhs) == 'ms'
assert simple_units_model.units.summarise_units(equation[0].rhs) == 'ms'

symbol_b = simple_units_model.get_symbol_by_cmeta_id("b")
equation = simple_units_model.get_equations_for([symbol_b])
assert simple_units_model.units.summarise_units(equation[1].lhs) == 'per_ms'
assert simple_units_model.units.summarise_units(equation[1].rhs) == '1 / ms'
assert simple_units_model.units.is_unit_equal(simple_units_model.units.summarise_units(equation[1].lhs),
simple_units_model.units.summarise_units(equation[1].rhs))

def test_bad_units(self, bad_units_model):
""" Tests units read and calculated from an inconsistent model. """
symbol_a = bad_units_model.get_symbol_by_cmeta_id("a")
symbol_b = bad_units_model.get_symbol_by_cmeta_id("b")
equation = bad_units_model.get_equations_for([symbol_b], strip_units=False)
assert len(equation) == 2
assert equation[0].lhs == symbol_a
assert bad_units_model.units.summarise_units(equation[0].lhs) == 'ms'
with pytest.raises(units.UnitError):
# cellml file states a (ms) = 1 (ms) + 1 (second)
bad_units_model.units.summarise_units(equation[0].rhs)

assert equation[1].lhs == symbol_b
with pytest.raises(units.UnitError):
# cellml file states b (per_ms) = power(a (ms), 1 (second))
bad_units_model.units.summarise_units(equation[1].rhs)

###################################################################
# this section is for other functions

Expand Down
44 changes: 0 additions & 44 deletions tests/test_model_units.py

This file was deleted.