From c6f927555409fcc0f46931b51c47c108fde2c514 Mon Sep 17 00:00:00 2001 From: Michael Clerx Date: Mon, 2 Dec 2019 18:12:07 +0000 Subject: [PATCH 1/2] Replaced Model.set_equation(lhs, rhs) with Model.remove_equation(equation). Fixes #118. --- cellmlmanip/model.py | 47 +++++++++++++------------------ tests/test_model_functions.py | 53 +++++++++-------------------------- 2 files changed, 34 insertions(+), 66 deletions(-) diff --git a/cellmlmanip/model.py b/cellmlmanip/model.py index 11670566..cd89e335 100644 --- a/cellmlmanip/model.py +++ b/cellmlmanip/model.py @@ -112,9 +112,11 @@ def add_variable(self, name, units, initial_value=None, :return: A :class:`VariableDummy` object. """ + # Check for clashes if name in self._name_to_symbol: raise ValueError('Variable %s already exists.' % name) + # Add variable self._name_to_symbol[name] = var = VariableDummy( name=name, units=self.units.get_quantity(units), @@ -125,6 +127,9 @@ def add_variable(self, name, units, initial_value=None, cmeta_id=cmeta_id, ) + # Invalidate cached graphs + self._invalidate_cache() + return var def connect_variables(self, source_name: str, target_name: str): @@ -173,12 +178,19 @@ def connect_variables(self, source_name: str, target_name: str): logger.debug('Updated target: %s', target) + # Invalidate cached graphs + self._invalidate_cache() + return True # The source variable has not been assigned a symbol, so we can't make this connection logger.info('The source variable has not been assigned to a symbol ' '(i.e. expecting a connection): %s ⟶ %s', target.name, source.name) + + # Invalidate cached graphs + self._invalidate_cache() + return False def add_rdf(self, rdf: str): @@ -559,36 +571,17 @@ def find_symbols_and_derivatives(self, expression): symbols |= self.find_symbols_and_derivatives(expr.args) return symbols - def set_equation(self, lhs, rhs): + def remove_equation(self, equation): """ - Adds an equation defining the variable named in ``lhs``, or replaces an existing one. - - As with :meth:`add_equation()` the LHS must be either a variable symbol or a derivative, and all numbers and - variable symbols used in ``lhs`` and ``rhs`` must have been obtained from this model, e.g. via - :meth:`add_number()`, :meth:`add_variable()`, or :meth:`get_symbol_by_ontology_term()`. + Removes an equation from the model. - :param lhs: An LHS expression (either a symbol or a derivative). - :param rhs: The new RHS expression for this variable. + :param equation: The equation to remove. """ - # Get variable symbol named in the lhs - lhs_symbol = lhs - if lhs_symbol.is_Derivative: - lhs_symbol = lhs_symbol.free_symbols.pop() - assert isinstance(lhs_symbol, VariableDummy) - - # Check if the variable named in the lhs already has an equation - i_existing = None - for i, eq in enumerate(self.equations): - symbol = eq.lhs.free_symbols.pop() if eq.lhs.is_Derivative else eq.lhs - if symbol == lhs_symbol: - i_existing = i - break - - # Add or replace equation - if i_existing is None: - self.equations.append(sympy.Eq(lhs, rhs)) - else: - self.equations[i_existing] = sympy.Eq(lhs, rhs) + try: + i = self.equations.index(equation) + except ValueError: + raise KeyError('Equation not found in model ' + str(equation)) + del(self.equations[i]) # Invalidate cached equation graphs self._invalidate_cache() diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index 58dd7a74..b192f68e 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -329,35 +329,8 @@ def test_add_equation(self, local_model): assert eqn[2].lhs == symbol2 assert eqn[2].rhs == sp.Add(symbol, symbol1) - def test_set_equation(self, local_model): - """ Tests the Model.set_equation method. - """ - model = local_model - assert len(model.equations) == 1 - # so we are adding - # newvar2 = newvar + newvar1 - # but need to also add newvar1 = 2; newvar = 2 in order or the graph to resolve correctly - model.add_variable(name='newvar', units='mV') - symbol = model.get_symbol_by_name('newvar') - model.add_variable(name='newvar1', units='mV') - symbol1 = model.get_symbol_by_name('newvar1') - model.add_variable(name='newvar2', units='mV') - symbol2 = model.get_symbol_by_name('newvar2') - model.set_equation(symbol, 2.0) - model.set_equation(symbol1, 2.0) - model.set_equation(symbol2, sp.Add(symbol, symbol1)) - assert len(model.equations) == 4 - eqn = model.get_equations_for([symbol2]) - assert len(eqn) == 3 - assert eqn[0].lhs == symbol - assert eqn[0].rhs == 2.0 - assert eqn[1].lhs == symbol1 - assert eqn[1].rhs == 2.0 - assert eqn[2].lhs == symbol2 - assert eqn[2].rhs == sp.Add(symbol, symbol1) - - def test_set_equation2(self, local_hh_model): - """ Tests replacing an equation in a model. """ + def test_remove_equation(self, local_hh_model): + """ Tests the Model.remove_equation method. """ model = local_hh_model # Get model, assert that V is a state variable @@ -365,8 +338,11 @@ def test_set_equation2(self, local_hh_model): assert v.type == 'state' # Now clamp it to -80mV - rhs = model.add_number(-80, str(v.units)) - model.set_equation(v, rhs) + t = model.get_symbol_by_ontology_term(shared.OXMETA, 'time') + equation = model.graph.nodes[sp.Derivative(v, t)]['equation'] + model.remove_equation(equation) + equation = sp.Eq(v, model.add_number(-80, str(v.units))) + model.add_equation(equation) # Check that V is no longer a state v = model.get_symbol_by_ontology_term(shared.OXMETA, 'membrane_voltage') @@ -376,24 +352,23 @@ def test_set_equation2(self, local_hh_model): # See: https://github.com/ModellingWebLab/cellmlmanip/issues/133 # Now make V a state again - t = model.get_symbol_by_ontology_term(shared.OXMETA, 'time') - lhs = sp.Derivative(v, t) dvdt_units = 'unlikely_unit_name' model.add_unit(dvdt_units, [ {'units': str(v.units)}, {'units': str(t.units), 'exponent': -1}, ]) - rhs = model.add_number(0, dvdt_units) - model.set_equation(lhs, rhs) + model.remove_equation(equation) + equation = sp.Eq(sp.Derivative(v, t), model.add_number(0, dvdt_units)) + model.add_equation(equation) # Check that V is a state again v = model.get_symbol_by_ontology_term(shared.OXMETA, 'membrane_voltage') assert v.type == 'state' - # Set equation for a newly created variable - lhs = model.add_variable(name='an_incredibly_unlikely_variable_name', units=str(v.units)) - rhs = model.add_number(12, str(v.units)) - model.set_equation(lhs, rhs) + # Test removing non-existing equation + equation = sp.Eq(sp.Derivative(v, t), model.add_number(5, dvdt_units)) + with pytest.raises(KeyError, match='Equation not found'): + model.remove_equation(equation) def test_add_number(self, local_model): """ Tests the Model.add_number method. """ From df0d4e0add616474446adf3fe2bda2968ef09890 Mon Sep 17 00:00:00 2001 From: Michael Clerx Date: Fri, 6 Dec 2019 10:46:46 +0100 Subject: [PATCH 2/2] Small tweaks to Model class. --- cellmlmanip/model.py | 73 +++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 39 deletions(-) diff --git a/cellmlmanip/model.py b/cellmlmanip/model.py index 521b2efe..cf611dd1 100644 --- a/cellmlmanip/model.py +++ b/cellmlmanip/model.py @@ -146,52 +146,48 @@ def connect_variables(self, source_name: str, target_name: str): source = self._name_to_symbol[source_name] target = self._name_to_symbol[target_name] - # If the source variable has already been assigned a final symbol - if source.assigned_to: - - if target.assigned_to: - raise ValueError('Target already assigned to %s before assignment to %s' % - (target.assigned_to, source.assigned_to)) - - # If source/target variable is in the same unit - if source.units == target.units: - # Direct substitution is possible - target.assigned_to = source.assigned_to - # everywhere the target variable is used, replace with source variable - for index, equation in enumerate(self.equations): - self.equations[index] = equation.xreplace({target: source.assigned_to}) - # Otherwise, this connection requires a conversion - else: - # Get the scaling factor required to convert source units to target units - factor = self.units.convert_to(1 * source.units, target.units).magnitude - - # Dummy to represent this factor in equations, having units for conversion - factor_dummy = self.add_number(factor, str(target.units / source.units)) - - # Add an equations making the connection with the required conversion - self.equations.append(sympy.Eq(target, source.assigned_to * factor_dummy)) - - logger.info('Connection req. unit conversion: %s', self.equations[-1]) + # If the source variable has not been assigned a symbol, we can't make this connection + if not source.assigned_to: + logger.info('The source variable has not been assigned to a symbol ' + '(i.e. expecting a connection): %s ⟶ %s', + target.name, source.name) + return False + + # If target is already assigned this is an error + if target.assigned_to: + raise ValueError('Target already assigned to %s before assignment to %s' % + (target.assigned_to, source.assigned_to)) + + # If source/target variable is in the same unit + if source.units == target.units: + # Direct substitution is possible + target.assigned_to = source.assigned_to + # everywhere the target variable is used, replace with source variable + for index, equation in enumerate(self.equations): + self.equations[index] = equation.xreplace({target: source.assigned_to}) + + # Otherwise, this connection requires a conversion + else: + # Get the scaling factor required to convert source units to target units + factor = self.units.convert_to(1 * source.units, target.units).magnitude - # The assigned symbol for this variable is itself - target.assigned_to = target + # Dummy to represent this factor in equations, having units for conversion + factor_dummy = self.add_number(factor, str(target.units / source.units)) - logger.debug('Updated target: %s', target) + # Add an equations making the connection with the required conversion + self.equations.append(sympy.Eq(target, source.assigned_to * factor_dummy)) - # Invalidate cached graphs - self._invalidate_cache() + logger.info('Connection req. unit conversion: %s', self.equations[-1]) - return True + # The assigned symbol for this variable is itself + target.assigned_to = target - # The source variable has not been assigned a symbol, so we can't make this connection - logger.info('The source variable has not been assigned to a symbol ' - '(i.e. expecting a connection): %s ⟶ %s', - target.name, source.name) + logger.debug('Updated target: %s', target) # Invalidate cached graphs self._invalidate_cache() - return False + return True def add_rdf(self, rdf: str): """ Takes an RDF string and stores it in the model's RDF graph. """ @@ -588,10 +584,9 @@ def remove_equation(self, equation): :param equation: The equation to remove. """ try: - i = self.equations.index(equation) + self.equations.remove(equation) except ValueError: raise KeyError('Equation not found in model ' + str(equation)) - del(self.equations[i]) # Invalidate cached equation graphs self._invalidate_cache()