Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions st2common/st2common/models/api/reactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@ def _get_impl(cls):

@classmethod
def get_by_name(cls, value):
cls._get_impl().get_by_name(value)
return cls._get_impl().get_by_name(value)

@classmethod
def get_by_id(cls, value):
cls._get_impl().get_by_id(value)
return cls._get_impl().get_by_id(value)

@classmethod
def get_all(cls):
cls._get_impl().get_all()
return cls._get_impl().get_all()

@classmethod
def add_or_update(cls, model_object):
cls.add_or_update(model_object)
return cls._get_impl().add_or_update(model_object)

@classmethod
def delete(cls, model_object):
return cls._get_impl().delete(model_object)


class TriggerSource(Access):
Expand Down
21 changes: 13 additions & 8 deletions st2common/st2common/models/db/reactor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import mongoengine as me
from st2common.models.db.stormbase import BaseDB
from st2common.models.db.stactioncontroller import StactionDB


class TriggerSourceDB(BaseDB):
Expand All @@ -25,7 +26,7 @@ class TriggerDB(BaseDB):
trigger_source: Source that owns this trigger type.
payload_info: Meta information of the expected payload.
"""
trigger_source = me.ReferenceField()
trigger_source = me.ReferenceField(TriggerSourceDB.__name__)
payload_info = me.ListField()


Expand All @@ -36,7 +37,7 @@ class TriggerInstanceDB(BaseDB):
payload (dict): payload specific to the occurrence.
occurrence_time (datetime): time of occurrence of the trigger.
"""
trigger = me.ReferenceField()
trigger = me.ReferenceField(TriggerDB.__name__)
payload = me.DictField()
occurrence_time = me.DateTimeField()

Expand All @@ -53,8 +54,8 @@ class RuleDB(BaseDB):
status: enabled or disabled. If disabled occurence of the trigger
does not lead to execution of a staction and vice-versa.
"""
trigger = me.ReferenceField()
staction = me.ReferenceField()
trigger = me.ReferenceField(TriggerDB.__name__)
staction = me.ReferenceField(StactionDB.__name__)
data_mapping = me.DictField()
status = me.StringField()

Expand All @@ -68,9 +69,9 @@ class RuleEnforcementDB(BaseDB):
staction_execution (Reference): The StactionExecution that was
created to record execution of a staction as part of this enforcement.
"""
rule = me.ReferenceField()
trigger_instance = me.ReferenceField()
staction_execution = me.ReferenceField()
rule = me.ReferenceField(RuleDB.__name__)
trigger_instance = me.ReferenceField(TriggerInstanceDB.__name__)
staction_execution = me.ReferenceField(StactionDB.__name__)


class MongoDBAccess(object):
Expand Down Expand Up @@ -98,7 +99,11 @@ def get_all(self):

@staticmethod
def add_or_update(model_object):
model_object.save()
return model_object.save()

@staticmethod
def delete(model_object):
model_object.delete()

# specialized access objects
triggersource_access = MongoDBAccess(TriggerSourceDB)
Expand Down
9 changes: 7 additions & 2 deletions st2common/st2common/models/db/stormbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ class BaseDB(me.Document):
name : name of the entity.
description : description of the entity.
id : unique identifier for the entity. If none is provided it
will be auto generate by the system.
will be auto generated by the system.
"""
name = me.StringField(required=True)
description = me.StringField()
id = me.ObjectIdField(primary_key=True, unique=True, required=True)
# ObjectIdField should be not have any constraints like required,
# unique etc for it to be auto-generated.
id = me.ObjectIdField()
# see http://docs.mongoengine.org/guide/defining-documents
# .html#document-inheritance
meta = {'allow_inheritance': True}
146 changes: 146 additions & 0 deletions st2common/tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from st2common.models.db import setup, teardown


SKIP_DELETE = False


class DbConnectionTest(unittest2.TestCase):
def setUp(self):
tests.parse_args()
Expand All @@ -14,8 +17,151 @@ def tearDown(self):
teardown()

def test_check_connect(self):
"""
Tests connectivity to the db server. Requires the db server to be
running.
"""
client = mongoengine.connection.get_connection()
self.assertEqual(client.host, cfg.CONF.database.host,
'Not connected to desired host.')
self.assertEqual(client.port, cfg.CONF.database.port,
'Not connected to desired port.')

from st2common.models.db.reactor import TriggerDB, TriggerInstanceDB, \
TriggerSourceDB, RuleEnforcementDB, RuleDB
from st2common.models.api.reactor import Trigger, TriggerInstance, \
TriggerSource, RuleEnforcement, Rule


class ReactorModelTest(unittest2.TestCase):
def setUp(self):
tests.parse_args()
setup()

def tearDown(self):
teardown()

def test_triggersource_crud(self):
saved = ReactorModelTest._create_save_triggersource()
retrieved = TriggerSource.get_by_id(saved.id)
self.assertEqual(saved.name, retrieved.name,
'Same TriggerSource was not returned.')
ReactorModelTest._delete([retrieved])
try:
retrieved = TriggerSource.get_by_id(saved.id)
except ValueError:
retrieved = None
self.assertIsNone(retrieved, 'managed to retrieve after failure.')

def test_trigger_crud(self):
triggersource = ReactorModelTest._create_save_triggersource()
saved = ReactorModelTest._create_save_trigger(triggersource)
retrieved = Trigger.get_by_id(saved.id)
self.assertEqual(saved.name, retrieved.name,
'Same trigger was not returned.')
ReactorModelTest._delete([retrieved, triggersource])
try:
retrieved = Trigger.get_by_id(saved.id)
except ValueError:
retrieved = None
self.assertIsNone(retrieved, 'managed to retrieve after failure.')

def test_triggerinstance_crud(self):
triggersource = ReactorModelTest._create_save_triggersource()
trigger = ReactorModelTest._create_save_trigger(triggersource)
saved = ReactorModelTest._create_save_triggerinstance(trigger)
retrieved = TriggerInstance.get_by_id(saved.id)
self.assertEqual(saved.name, retrieved.name,
'Same triggerinstance was not returned.')
ReactorModelTest._delete([retrieved, trigger, triggersource])
try:
retrieved = TriggerInstance.get_by_id(saved.id)
except ValueError:
retrieved = None
self.assertIsNone(retrieved, 'managed to retrieve after failure.')

def test_rule_crud(self):
triggersource = ReactorModelTest._create_save_triggersource()
trigger = ReactorModelTest._create_save_trigger(triggersource)
saved = ReactorModelTest._create_save_rule(trigger)
retrieved = Rule.get_by_id(saved.id)
self.assertEqual(saved.name, retrieved.name,
'Same rule was not returned.')
ReactorModelTest._delete([retrieved, trigger, triggersource])
try:
retrieved = Rule.get_by_id(saved.id)
except ValueError:
retrieved = None
self.assertIsNone(retrieved, 'managed to retrieve after failure.')

def test_ruleenforcement_crud(self):
triggersource = ReactorModelTest._create_save_triggersource()
trigger = ReactorModelTest._create_save_trigger(triggersource)
triggerinstance = ReactorModelTest._create_save_triggerinstance(trigger)
rule = ReactorModelTest._create_save_rule(trigger)
saved = ReactorModelTest._create_save_ruleenforcement(triggerinstance,
rule)
retrieved = RuleEnforcement.get_by_id(saved.id)
self.assertEqual(saved.name, retrieved.name,
'Same rule was not returned.')
ReactorModelTest._delete([retrieved,rule, triggerinstance, trigger,
triggersource])
try:
retrieved = Rule.get_by_id(saved.id)
except ValueError:
retrieved = None
self.assertIsNone(retrieved, 'managed to retrieve after failure.')

@staticmethod
def _create_save_triggersource():
created = TriggerSourceDB()
created.name = 'triggersource-1'
created.description = ''
return TriggerSource.add_or_update(created)

@staticmethod
def _create_save_trigger(triggersource):
created = TriggerDB()
created.name = 'trigger-1'
created.description = ''
created.payload_info = []
created.trigger_source = triggersource
return TriggerSource.add_or_update(created)

@staticmethod
def _create_save_triggerinstance(trigger):
created = TriggerInstanceDB()
created.name = 'triggerinstance-1'
created.description = ''
created.trigger = trigger
created.payload = {}
return TriggerInstance.add_or_update(created)

@staticmethod
def _create_save_rule(trigger, staction=None):
created = RuleDB()
created.name = 'rule-1'
created.description = ''
created.trigger = trigger
created.staction = staction
created.data_mapping = {}
return Rule.add_or_update(created)

@staticmethod
def _create_save_ruleenforcement(triggerinstance, rule,
stactionexecution=None):
created = RuleEnforcementDB()
created.name = 'ruleenforcement-1'
created.description = ''
created.rule = rule
created.trigger_instance = triggerinstance
created.staction_execution = stactionexecution
return RuleEnforcement.add_or_update(created)

@staticmethod
def _delete(model_objects):
global SKIP_DELETE
if SKIP_DELETE:
return
for model_object in model_objects:
model_object.delete()