Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 47 additions & 59 deletions cellmlmanip/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ def add_variable(self, name, units, initial_value=None,

:return: A :class:`VariableDummy` object.
"""
# Check for clashes
if name in self._name_to_symbol:
raise ValueError('Variable %s already exists.' % name)

# Add variable
self._name_to_symbol[name] = var = VariableDummy(
name=name,
units=self.units.get_quantity(units),
Expand All @@ -125,6 +127,9 @@ def add_variable(self, name, units, initial_value=None,
cmeta_id=cmeta_id,
)

# Invalidate cached graphs
self._invalidate_cache()

return var

def connect_variables(self, source_name: str, target_name: str):
Expand All @@ -141,45 +146,48 @@ def connect_variables(self, source_name: str, target_name: str):
source = self._name_to_symbol[source_name]
target = self._name_to_symbol[target_name]

# If the source variable has already been assigned a final symbol
if source.assigned_to:

if target.assigned_to:
raise ValueError('Target already assigned to %s before assignment to %s' %
(target.assigned_to, source.assigned_to))

# If source/target variable is in the same unit
if source.units == target.units:
# Direct substitution is possible
target.assigned_to = source.assigned_to
# everywhere the target variable is used, replace with source variable
for index, equation in enumerate(self.equations):
self.equations[index] = equation.xreplace({target: source.assigned_to})
# Otherwise, this connection requires a conversion
else:
# Get the scaling factor required to convert source units to target units
factor = self.units.convert_to(1 * source.units, target.units).magnitude
# If the source variable has not been assigned a symbol, we can't make this connection
if not source.assigned_to:
logger.info('The source variable has not been assigned to a symbol '
'(i.e. expecting a connection): %s ⟶ %s',
target.name, source.name)
return False

# If target is already assigned this is an error
if target.assigned_to:
raise ValueError('Target already assigned to %s before assignment to %s' %
(target.assigned_to, source.assigned_to))

# If source/target variable is in the same unit
if source.units == target.units:
# Direct substitution is possible
target.assigned_to = source.assigned_to
# everywhere the target variable is used, replace with source variable
for index, equation in enumerate(self.equations):
self.equations[index] = equation.xreplace({target: source.assigned_to})

# Otherwise, this connection requires a conversion
else:
# Get the scaling factor required to convert source units to target units
factor = self.units.convert_to(1 * source.units, target.units).magnitude

# Dummy to represent this factor in equations, having units for conversion
factor_dummy = self.add_number(factor, str(target.units / source.units))
# Dummy to represent this factor in equations, having units for conversion
factor_dummy = self.add_number(factor, str(target.units / source.units))

# Add an equations making the connection with the required conversion
self.equations.append(sympy.Eq(target, source.assigned_to * factor_dummy))
# Add an equations making the connection with the required conversion
self.equations.append(sympy.Eq(target, source.assigned_to * factor_dummy))

logger.info('Connection req. unit conversion: %s', self.equations[-1])
logger.info('Connection req. unit conversion: %s', self.equations[-1])

# The assigned symbol for this variable is itself
target.assigned_to = target
# The assigned symbol for this variable is itself
target.assigned_to = target

logger.debug('Updated target: %s', target)
logger.debug('Updated target: %s', target)

return True
# Invalidate cached graphs
self._invalidate_cache()
Comment thread
MichaelClerx marked this conversation as resolved.

# The source variable has not been assigned a symbol, so we can't make this connection
logger.info('The source variable has not been assigned to a symbol '
'(i.e. expecting a connection): %s ⟶ %s',
target.name, source.name)
return False
return True

def add_rdf(self, rdf: str):
""" Takes an RDF string and stores it in the model's RDF graph. """
Expand Down Expand Up @@ -569,36 +577,16 @@ def find_symbols_and_derivatives(self, expression):
symbols |= self.find_symbols_and_derivatives(expr.args)
return symbols

def set_equation(self, lhs, rhs):
def remove_equation(self, equation):
"""
Adds an equation defining the variable named in ``lhs``, or replaces an existing one.

As with :meth:`add_equation()` the LHS must be either a variable symbol or a derivative, and all numbers and
variable symbols used in ``lhs`` and ``rhs`` must have been obtained from this model, e.g. via
:meth:`add_number()`, :meth:`add_variable()`, or :meth:`get_symbol_by_ontology_term()`.
Removes an equation from the model.

:param lhs: An LHS expression (either a symbol or a derivative).
:param rhs: The new RHS expression for this variable.
:param equation: The equation to remove.
"""
# Get variable symbol named in the lhs
lhs_symbol = lhs
if lhs_symbol.is_Derivative:
lhs_symbol = lhs_symbol.free_symbols.pop()
assert isinstance(lhs_symbol, VariableDummy)

# Check if the variable named in the lhs already has an equation
i_existing = None
for i, eq in enumerate(self.equations):
symbol = eq.lhs.free_symbols.pop() if eq.lhs.is_Derivative else eq.lhs
if symbol == lhs_symbol:
i_existing = i
break

# Add or replace equation
if i_existing is None:
self.equations.append(sympy.Eq(lhs, rhs))
else:
self.equations[i_existing] = sympy.Eq(lhs, rhs)
try:
Comment thread
MichaelClerx marked this conversation as resolved.
self.equations.remove(equation)
except ValueError:
Comment thread
MichaelClerx marked this conversation as resolved.
raise KeyError('Equation not found in model ' + str(equation))

# Invalidate cached equation graphs
self._invalidate_cache()
Expand Down
53 changes: 14 additions & 39 deletions tests/test_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,44 +359,20 @@ def test_add_equation(self, local_model):
assert eqn[2].lhs == symbol2
assert eqn[2].rhs == sp.Add(symbol, symbol1)

def test_set_equation(self, local_model):
""" Tests the Model.set_equation method.
"""
model = local_model
assert len(model.equations) == 1
# so we are adding
# newvar2 = newvar + newvar1
# but need to also add newvar1 = 2; newvar = 2 in order or the graph to resolve correctly
model.add_variable(name='newvar', units='mV')
symbol = model.get_symbol_by_name('newvar')
model.add_variable(name='newvar1', units='mV')
symbol1 = model.get_symbol_by_name('newvar1')
model.add_variable(name='newvar2', units='mV')
symbol2 = model.get_symbol_by_name('newvar2')
model.set_equation(symbol, 2.0)
model.set_equation(symbol1, 2.0)
model.set_equation(symbol2, sp.Add(symbol, symbol1))
assert len(model.equations) == 4
eqn = model.get_equations_for([symbol2])
assert len(eqn) == 3
assert eqn[0].lhs == symbol
assert eqn[0].rhs == 2.0
assert eqn[1].lhs == symbol1
assert eqn[1].rhs == 2.0
assert eqn[2].lhs == symbol2
assert eqn[2].rhs == sp.Add(symbol, symbol1)

def test_set_equation2(self, local_hh_model):
""" Tests replacing an equation in a model. """
def test_remove_equation(self, local_hh_model):
""" Tests the Model.remove_equation method. """

model = local_hh_model
# Get model, assert that V is a state variable
v = model.get_symbol_by_ontology_term(shared.OXMETA, 'membrane_voltage')
assert v.type == 'state'

# Now clamp it to -80mV
rhs = model.add_number(-80, str(v.units))
model.set_equation(v, rhs)
t = model.get_symbol_by_ontology_term(shared.OXMETA, 'time')
equation = model.graph.nodes[sp.Derivative(v, t)]['equation']
model.remove_equation(equation)
equation = sp.Eq(v, model.add_number(-80, str(v.units)))
model.add_equation(equation)

# Check that V is no longer a state
v = model.get_symbol_by_ontology_term(shared.OXMETA, 'membrane_voltage')
Expand All @@ -406,24 +382,23 @@ def test_set_equation2(self, local_hh_model):
# See: https://github.com/ModellingWebLab/cellmlmanip/issues/133

# Now make V a state again
t = model.get_symbol_by_ontology_term(shared.OXMETA, 'time')
lhs = sp.Derivative(v, t)
dvdt_units = 'unlikely_unit_name'
model.add_unit(dvdt_units, [
{'units': str(v.units)},
{'units': str(t.units), 'exponent': -1},
])
rhs = model.add_number(0, dvdt_units)
model.set_equation(lhs, rhs)
model.remove_equation(equation)
equation = sp.Eq(sp.Derivative(v, t), model.add_number(0, dvdt_units))
model.add_equation(equation)

# Check that V is a state again
v = model.get_symbol_by_ontology_term(shared.OXMETA, 'membrane_voltage')
assert v.type == 'state'

# Set equation for a newly created variable
lhs = model.add_variable(name='an_incredibly_unlikely_variable_name', units=str(v.units))
rhs = model.add_number(12, str(v.units))
model.set_equation(lhs, rhs)
# Test removing non-existing equation
equation = sp.Eq(sp.Derivative(v, t), model.add_number(5, dvdt_units))
with pytest.raises(KeyError, match='Equation not found'):
model.remove_equation(equation)

def test_add_number(self, local_model):
""" Tests the Model.add_number method. """
Expand Down