diff --git a/st2common/st2common/models/api/reactor.py b/st2common/st2common/models/api/reactor.py index 30c4e3f674..91c2fe1b2b 100644 --- a/st2common/st2common/models/api/reactor.py +++ b/st2common/st2common/models/api/reactor.py @@ -15,19 +15,23 @@ def _get_impl(cls): @classmethod def get_by_name(cls, value): - cls._get_impl().get_by_name(value) + return cls._get_impl().get_by_name(value) @classmethod def get_by_id(cls, value): - cls._get_impl().get_by_id(value) + return cls._get_impl().get_by_id(value) @classmethod def get_all(cls): - cls._get_impl().get_all() + return cls._get_impl().get_all() @classmethod def add_or_update(cls, model_object): - cls.add_or_update(model_object) + return cls._get_impl().add_or_update(model_object) + + @classmethod + def delete(cls, model_object): + return cls._get_impl().delete(model_object) class TriggerSource(Access): diff --git a/st2common/st2common/models/db/reactor.py b/st2common/st2common/models/db/reactor.py index 1b0964e8d7..c5fc972273 100644 --- a/st2common/st2common/models/db/reactor.py +++ b/st2common/st2common/models/db/reactor.py @@ -1,5 +1,6 @@ import mongoengine as me from st2common.models.db.stormbase import BaseDB +from st2common.models.db.stactioncontroller import StactionDB class TriggerSourceDB(BaseDB): @@ -25,7 +26,7 @@ class TriggerDB(BaseDB): trigger_source: Source that owns this trigger type. payload_info: Meta information of the expected payload. """ - trigger_source = me.ReferenceField() + trigger_source = me.ReferenceField(TriggerSourceDB.__name__) payload_info = me.ListField() @@ -36,7 +37,7 @@ class TriggerInstanceDB(BaseDB): payload (dict): payload specific to the occurrence. occurrence_time (datetime): time of occurrence of the trigger. """ - trigger = me.ReferenceField() + trigger = me.ReferenceField(TriggerDB.__name__) payload = me.DictField() occurrence_time = me.DateTimeField() @@ -53,8 +54,8 @@ class RuleDB(BaseDB): status: enabled or disabled. If disabled occurence of the trigger does not lead to execution of a staction and vice-versa. """ - trigger = me.ReferenceField() - staction = me.ReferenceField() + trigger = me.ReferenceField(TriggerDB.__name__) + staction = me.ReferenceField(StactionDB.__name__) data_mapping = me.DictField() status = me.StringField() @@ -68,9 +69,9 @@ class RuleEnforcementDB(BaseDB): staction_execution (Reference): The StactionExecution that was created to record execution of a staction as part of this enforcement. """ - rule = me.ReferenceField() - trigger_instance = me.ReferenceField() - staction_execution = me.ReferenceField() + rule = me.ReferenceField(RuleDB.__name__) + trigger_instance = me.ReferenceField(TriggerInstanceDB.__name__) + staction_execution = me.ReferenceField(StactionDB.__name__) class MongoDBAccess(object): @@ -98,7 +99,11 @@ def get_all(self): @staticmethod def add_or_update(model_object): - model_object.save() + return model_object.save() + + @staticmethod + def delete(model_object): + model_object.delete() # specialized access objects triggersource_access = MongoDBAccess(TriggerSourceDB) diff --git a/st2common/st2common/models/db/stormbase.py b/st2common/st2common/models/db/stormbase.py index d872f4f8bc..5e45290cf4 100644 --- a/st2common/st2common/models/db/stormbase.py +++ b/st2common/st2common/models/db/stormbase.py @@ -7,8 +7,13 @@ class BaseDB(me.Document): name : name of the entity. description : description of the entity. id : unique identifier for the entity. If none is provided it - will be auto generate by the system. + will be auto generated by the system. """ name = me.StringField(required=True) description = me.StringField() - id = me.ObjectIdField(primary_key=True, unique=True, required=True) + # ObjectIdField should be not have any constraints like required, + # unique etc for it to be auto-generated. + id = me.ObjectIdField() + # see http://docs.mongoengine.org/guide/defining-documents + # .html#document-inheritance + meta = {'allow_inheritance': True} diff --git a/st2common/tests/test_db.py b/st2common/tests/test_db.py index 06c61b44eb..dda3c4ccba 100644 --- a/st2common/tests/test_db.py +++ b/st2common/tests/test_db.py @@ -5,6 +5,9 @@ from st2common.models.db import setup, teardown +SKIP_DELETE = False + + class DbConnectionTest(unittest2.TestCase): def setUp(self): tests.parse_args() @@ -14,8 +17,151 @@ def tearDown(self): teardown() def test_check_connect(self): + """ + Tests connectivity to the db server. Requires the db server to be + running. + """ client = mongoengine.connection.get_connection() self.assertEqual(client.host, cfg.CONF.database.host, 'Not connected to desired host.') self.assertEqual(client.port, cfg.CONF.database.port, 'Not connected to desired port.') + +from st2common.models.db.reactor import TriggerDB, TriggerInstanceDB, \ + TriggerSourceDB, RuleEnforcementDB, RuleDB +from st2common.models.api.reactor import Trigger, TriggerInstance, \ + TriggerSource, RuleEnforcement, Rule + + +class ReactorModelTest(unittest2.TestCase): + def setUp(self): + tests.parse_args() + setup() + + def tearDown(self): + teardown() + + def test_triggersource_crud(self): + saved = ReactorModelTest._create_save_triggersource() + retrieved = TriggerSource.get_by_id(saved.id) + self.assertEqual(saved.name, retrieved.name, + 'Same TriggerSource was not returned.') + ReactorModelTest._delete([retrieved]) + try: + retrieved = TriggerSource.get_by_id(saved.id) + except ValueError: + retrieved = None + self.assertIsNone(retrieved, 'managed to retrieve after failure.') + + def test_trigger_crud(self): + triggersource = ReactorModelTest._create_save_triggersource() + saved = ReactorModelTest._create_save_trigger(triggersource) + retrieved = Trigger.get_by_id(saved.id) + self.assertEqual(saved.name, retrieved.name, + 'Same trigger was not returned.') + ReactorModelTest._delete([retrieved, triggersource]) + try: + retrieved = Trigger.get_by_id(saved.id) + except ValueError: + retrieved = None + self.assertIsNone(retrieved, 'managed to retrieve after failure.') + + def test_triggerinstance_crud(self): + triggersource = ReactorModelTest._create_save_triggersource() + trigger = ReactorModelTest._create_save_trigger(triggersource) + saved = ReactorModelTest._create_save_triggerinstance(trigger) + retrieved = TriggerInstance.get_by_id(saved.id) + self.assertEqual(saved.name, retrieved.name, + 'Same triggerinstance was not returned.') + ReactorModelTest._delete([retrieved, trigger, triggersource]) + try: + retrieved = TriggerInstance.get_by_id(saved.id) + except ValueError: + retrieved = None + self.assertIsNone(retrieved, 'managed to retrieve after failure.') + + def test_rule_crud(self): + triggersource = ReactorModelTest._create_save_triggersource() + trigger = ReactorModelTest._create_save_trigger(triggersource) + saved = ReactorModelTest._create_save_rule(trigger) + retrieved = Rule.get_by_id(saved.id) + self.assertEqual(saved.name, retrieved.name, + 'Same rule was not returned.') + ReactorModelTest._delete([retrieved, trigger, triggersource]) + try: + retrieved = Rule.get_by_id(saved.id) + except ValueError: + retrieved = None + self.assertIsNone(retrieved, 'managed to retrieve after failure.') + + def test_ruleenforcement_crud(self): + triggersource = ReactorModelTest._create_save_triggersource() + trigger = ReactorModelTest._create_save_trigger(triggersource) + triggerinstance = ReactorModelTest._create_save_triggerinstance(trigger) + rule = ReactorModelTest._create_save_rule(trigger) + saved = ReactorModelTest._create_save_ruleenforcement(triggerinstance, + rule) + retrieved = RuleEnforcement.get_by_id(saved.id) + self.assertEqual(saved.name, retrieved.name, + 'Same rule was not returned.') + ReactorModelTest._delete([retrieved,rule, triggerinstance, trigger, + triggersource]) + try: + retrieved = Rule.get_by_id(saved.id) + except ValueError: + retrieved = None + self.assertIsNone(retrieved, 'managed to retrieve after failure.') + + @staticmethod + def _create_save_triggersource(): + created = TriggerSourceDB() + created.name = 'triggersource-1' + created.description = '' + return TriggerSource.add_or_update(created) + + @staticmethod + def _create_save_trigger(triggersource): + created = TriggerDB() + created.name = 'trigger-1' + created.description = '' + created.payload_info = [] + created.trigger_source = triggersource + return TriggerSource.add_or_update(created) + + @staticmethod + def _create_save_triggerinstance(trigger): + created = TriggerInstanceDB() + created.name = 'triggerinstance-1' + created.description = '' + created.trigger = trigger + created.payload = {} + return TriggerInstance.add_or_update(created) + + @staticmethod + def _create_save_rule(trigger, staction=None): + created = RuleDB() + created.name = 'rule-1' + created.description = '' + created.trigger = trigger + created.staction = staction + created.data_mapping = {} + return Rule.add_or_update(created) + + @staticmethod + def _create_save_ruleenforcement(triggerinstance, rule, + stactionexecution=None): + created = RuleEnforcementDB() + created.name = 'ruleenforcement-1' + created.description = '' + created.rule = rule + created.trigger_instance = triggerinstance + created.staction_execution = stactionexecution + return RuleEnforcement.add_or_update(created) + + @staticmethod + def _delete(model_objects): + global SKIP_DELETE + if SKIP_DELETE: + return + for model_object in model_objects: + model_object.delete()