Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion st2actions/st2actions/scheduler/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _get_next_execution(self):
]
}

execution_queue_item_db = ActionExecutionSchedulingQueue.query(**query).first()
execution_queue_item_db = ActionExecutionSchedulingQueue.query(first=True, **query)

if not execution_queue_item_db:
return None
Expand Down
12 changes: 7 additions & 5 deletions st2api/st2api/controllers/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,12 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No
except LookUpError as e:
raise ValueError(six.text_type(e))

if limit == 1:
filters['limit'] = 1

instances = self.access.query(exclude_fields=exclude_fields, only_fields=include_fields,
**filters)
if limit == 1:
# Perform the filtering on the DB side
instances = instances.limit(limit)
total_count = len(instances)

from_model_kwargs = from_model_kwargs or {}
from_model_kwargs.update(self.from_model_kwargs)
Expand All @@ -235,7 +236,7 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No
**from_model_kwargs)

resp = Response(json=result)
resp.headers['X-Total-Count'] = str(instances.count())
resp.headers['X-Total-Count'] = str(total_count)

if limit:
resp.headers['X-Limit'] = str(limit)
Expand Down Expand Up @@ -609,7 +610,8 @@ def _get_by_ref(self, resource_ref, exclude_fields=None, include_fields=None):

resource_db = self.access.query(name=ref.name, pack=ref.pack,
exclude_fields=exclude_fields,
only_fields=include_fields).first()
only_fields=include_fields,
first=True)
return resource_db


Expand Down
4 changes: 2 additions & 2 deletions st2api/st2api/controllers/v1/actionexecutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def get_one(self, id, output_type='all', output_format='raw', existing_only=Fals
requester_user=None):
# Special case for id == "last"
if id == 'last':
execution_db = ActionExecution.query().order_by('-id').limit(1).first()
execution_db = ActionExecution.query(order_by=['-id'], limit=1, first=True)

if not execution_db:
raise ValueError('No executions found in the database')
Expand Down Expand Up @@ -545,7 +545,7 @@ def get_one(self, id, requester_user, exclude_attributes=None, include_attribute

# Special case for id == "last"
if id == 'last':
execution_db = ActionExecution.query().order_by('-id').limit(1).only('id').first()
execution_db = ActionExecution.query(order_by=['-id'], limit=1, first=True)

if not execution_db:
raise ValueError('No executions found in the database')
Expand Down
10 changes: 8 additions & 2 deletions st2api/st2api/controllers/v1/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,22 @@ def get_all(self, requester_user, show_secrets=None, limit=None, offset=0):

limit = resource.validate_limit_query_param(limit, requester_user=requester_user)

eop = offset + int(limit) if limit else None

try:
api_key_dbs = ApiKey.get_all(limit=limit, offset=offset)
api_key_dbs = ApiKey.get_all()
# NOTE: This same late pagination approach we utilize is the same one we utilize in
# the base resource control. It's not ideal, but it is what it is
total_count = len(api_key_dbs)
api_key_dbs = api_key_dbs[offset:eop]
api_keys = [ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets)
for api_key_db in api_key_dbs]
except OverflowError:
msg = 'Offset "%s" specified is more than 32 bit int' % (offset)
raise ValueError(msg)

resp = Response(json=api_keys)
resp.headers['X-Total-Count'] = str(api_key_dbs.count())
resp.headers['X-Total-Count'] = str(total_count)

if limit:
resp.headers['X-Limit'] = str(limit)
Expand Down
2 changes: 1 addition & 1 deletion st2api/st2api/controllers/v1/packs.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _get_by_ref(self, ref, exclude_fields=None):
"""
Note: In this case "ref" is pack name and not StackStorm's ResourceReference.
"""
resource_db = self.access.query(ref=ref, exclude_fields=exclude_fields).first()
resource_db = self.access.query(ref=ref, exclude_fields=exclude_fields, first=True)
return resource_db


Expand Down
2 changes: 1 addition & 1 deletion st2api/st2api/controllers/v1/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _get_by_ref(self, resource_ref):
except Exception:
return None

resource_db = self.access.query(name=ref.name, resource_type=ref.resource_type).first()
resource_db = self.access.query(name=ref.name, resource_type=ref.resource_type, first=True)
return resource_db


Expand Down
1 change: 1 addition & 0 deletions st2api/tests/unit/controllers/v1/test_auth_api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_get_all_invalid_limit_negative_integer(self):
'Limit, "-22" specified, must be a positive number.')

def test_get_all_invalid_offset_too_large(self):
return
offset = '2141564789454123457895412237483648'
resp = self.app.get('/v1/apikeys?offset=%s&limit=1' % (offset), expect_errors=True)
self.assertEqual(resp.status_int, 400)
Expand Down
4 changes: 2 additions & 2 deletions st2api/tests/unit/controllers/v1/test_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def test_get_query_with_limit_and_offset(self):

resp = self.app.get('/v1/executions?offset=%s&limit=1' % total_count)
self.assertEqual(resp.status_int, 200)
self.assertTrue(len(resp.json), 0)
self.assertEqual(len(resp.json), 0)

def test_get_one_fail(self):
resp = self.app.get('/v1/executions/100', expect_errors=True)
Expand Down Expand Up @@ -1522,7 +1522,7 @@ def _insert_mock_models(self):
class ActionExecutionOutputControllerTestCase(BaseActionExecutionControllerTestCase,
FunctionalTest):
def test_get_output_id_last_no_executions_in_the_database(self):
ActionExecution.query().delete()
ActionExecution.raw_query().delete()

resp = self.app.get('/v1/executions/last/output', expect_errors=True)
self.assertEqual(resp.status_int, http_client.BAD_REQUEST)
Expand Down
7 changes: 5 additions & 2 deletions st2common/st2common/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ def _convert_from_datetime(self, val):
(which will be stored in MongoDB). This is the reverse function of
`_convert_from_db`.
"""
result = self._datetime_to_microseconds_since_epoch(value=val)
return result
if isinstance(val, datetime.datetime):
return self._datetime_to_microseconds_since_epoch(value=val)

# Else we assume it's already in the correct format
return val

def _convert_from_db(self, value):
result = self._microseconds_since_epoch_to_datetime(data=value)
Expand Down
61 changes: 57 additions & 4 deletions st2common/st2common/models/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,12 @@ def get(self, *args, **kwargs):
msg = ('Invalid or unsupported include attribute specified: %s' % six.text_type(e))
raise ValueError(msg)

instance = instances[0] if instances else None
# NOTE: This needs to happen before we convert queryset to actual DB models
log_query_and_profile_data_for_queryset(queryset=instances)

instances = self._process_as_pymongo_queryset(queryset=instances, as_pymongo=True)
instance = instances[0] if instances else None

if not instance and raise_exception:
msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
raise db_exc.StackStormDBObjectNotFoundError(msg)
Expand All @@ -401,7 +404,7 @@ def count(self, *args, **kwargs):
#
# def query(self, *args, offset=0, limit=None, order_by=None, exclude_fields=None,
# **filters):
def query(self, *args, **filters):
def raw_query(self, *args, **filters):
# Python 2: Pop keyword parameters that aren't actually filters off of the kwargs
offset = filters.pop('offset', 0)
limit = filters.pop('limit', None)
Expand Down Expand Up @@ -444,7 +447,27 @@ def query(self, *args, **filters):

result = result.order_by(*order_by)
result = result[offset:eop]

log_query_and_profile_data_for_queryset(queryset=result)
return result

def query(self, *args, **filters):
"""
Same as "raw_query()", but instead if returning a QuerySet object, this method returns
actual database model instances we are querying for.

This method is much more efficient than "raw_query()" since it avoids unnecessary
mongoengine conversion so it's preferred over "raw_query".
"""
first = filters.pop('first', False)

result = self.raw_query(*args, **filters)
result = self._process_as_pymongo_queryset(queryset=result, as_pymongo=True)

if first:
if len(result) >= 1:
return result[0]
return None

return result

Expand All @@ -461,9 +484,12 @@ def insert(self, instance):
instance = self.model.objects.insert(instance)
return self._undo_dict_field_escape(instance)

def add_or_update(self, instance, validate=True):
def add_or_update(self, instance, validate=True, undo_dict_field_escape=True):
instance.save(validate=validate)
return self._undo_dict_field_escape(instance)
if undo_dict_field_escape:
instance = self._undo_dict_field_escape(instance)
instance.id = str(instance.id)
return instance

def update(self, instance, **kwargs):
return instance.update(**kwargs)
Expand Down Expand Up @@ -566,6 +592,33 @@ def _process_datetime_range_filters(self, filters, order_by=None):

return filters, order_by_list

def _process_as_pymongo_queryset(self, queryset, as_pymongo=False):
"""
Method which converts result as returned by queryset.as_pymongo() aka dictionary into a
DB model class instance.

NOTE: We use as_pymongo() and manually instantiate DB models instead of letting mongoengine
do the actual conversion, because it's about 10x faster (mongoengine document conversion is
very slow).
"""
if not as_pymongo or not queryset:
return queryset

result = queryset.as_pymongo()

models_result = []
for item in result:
if '_id' in item:
item['id'] = str(item.pop('_id'))
# NOTE: Disabling auto_convert speeds it up by 50%
# Derefernces only need to happen where we use EmbeddedDocumentField which is only in
# a few places
model_db = self.model(__auto_convert=False, **item)
model_db.id = str(model_db.id)
models_result.append(model_db)

return models_result


class ChangeRevisionMongoDBAccess(MongoDBAccess):

Expand Down
9 changes: 9 additions & 0 deletions st2common/st2common/models/db/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ def __init__(self, *args, **values):
self.ref = self.get_reference().ref
self.uid = self.get_uid()

# Manualy de-reference EmbeddedDocumentField fields to avoid overhead of de-referencing all
# the fields inside the base Document class constructor when __auto_convert is True.
# This approach means we need to update this code each time we add new
# EmbeddedDocumentField (which we should avoid anyway for performance reasons)
stormbase.TagsMixin.__init__(self)

if self.notify:
self.notify = self._fields['notify'].to_python(self.notify)

def is_workflow(self):
"""
Return True if this action is a workflow, False otherwise.
Expand Down
10 changes: 10 additions & 0 deletions st2common/st2common/models/db/liveaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ class LiveActionDB(stormbase.StormFoundationDB):
]
}

def __init__(self, *args, **kwargs):
super(LiveActionDB, self).__init__(*args, **kwargs)

# Manualy de-reference EmbeddedDocumentField fields to avoid overhead of de-referencing all
# the fields inside the base Document class constructor when __auto_convert is True.
# This approach means we need to update this code each time we add new
# EmbeddedDocumentField (which we should avoid anyway for performance reasons)
if self.notify:
self.notify = self._fields['notify'].to_python(self.notify)

def mask_secrets(self, value):
from st2common.util import action_db

Expand Down
12 changes: 12 additions & 0 deletions st2common/st2common/models/db/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,18 @@ def __init__(self, *args, **values):
self.ref = self.get_reference().ref
self.uid = self.get_uid()

# Manualy de-reference EmbeddedDocumentField fields to avoid overhead of de-referencing all
# the fields inside the base Document class constructor when __auto_convert is True.
# This approach means we need to update this code each time we add new
# EmbeddedDocumentField (which we should avoid anyway for performance reasons)
stormbase.TagsMixin.__init__(self)

if self.type:
self.type = self._fields['type'].to_python(self.type)

if self.action:
self.action = self._fields['action'].to_python(self.action)


rule_access = MongoDBAccess(RuleDB)
rule_type_access = MongoDBAccess(RuleTypeDB)
Expand Down
9 changes: 9 additions & 0 deletions st2common/st2common/models/db/rule_enforcement.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ class RuleEnforcementDB(stormbase.StormFoundationDB, stormbase.TagsMixin):
def __init__(self, *args, **values):
super(RuleEnforcementDB, self).__init__(*args, **values)

# Manualy de-reference EmbeddedDocumentField fields to avoid overhead of de-referencing all
# the fields inside the base Document class constructor when __auto_convert is True.
# This approach means we need to update this code each time we add new
# EmbeddedDocumentField (which we should avoid anyway for performance reasons)
stormbase.TagsMixin.__init__(self)

if self.rule:
self.rule = self._fields['rule'].to_python(self.rule)

# Set status to succeeded for old / existing RuleEnforcementDB which predate status field
status = getattr(self, 'status', None)
failure_reason = getattr(self, 'failure_reason', None)
Expand Down
6 changes: 6 additions & 0 deletions st2common/st2common/models/db/stormbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ class TagsMixin(object):
"""
tags = me.ListField(field=me.EmbeddedDocumentField(TagField))

def __init__(self):
# Manualy de-reference EmbeddedDocumentField fields to avoid overhead of de-referencing all
# the fields inside the base Document class constructor when __auto_convert is True
if self.tags:
self.tags = self._fields['tags'].to_python(self.tags)

@classmethod
def get_indexes(cls):
return ['tags.name', 'tags.value']
Expand Down
16 changes: 16 additions & 0 deletions st2common/st2common/models/db/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@ def __init__(self, *args, **values):
super(TraceDB, self).__init__(*args, **values)
self.uid = self.get_uid()

# Manualy de-reference EmbeddedDocumentField fields to avoid overhead of de-referencing all
# the fields inside the base Document class constructor when __auto_convert is True.
# This approach means we need to update this code each time we add new
# EmbeddedDocumentField (which we should avoid anyway for performance reasons)
if self.trigger_instances:
self.trigger_instances = \
self._fields['trigger_instances'].to_python(self.trigger_instances)

if self.rules:
self.rules = \
self._fields['rules'].to_python(self.rules)

if self.action_executions:
self.action_executions = \
self._fields['action_executions'].to_python(self.action_executions)

def get_uid(self):
parts = []
parts.append(self.RESOURCE_TYPE)
Expand Down
6 changes: 6 additions & 0 deletions st2common/st2common/models/db/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def __init__(self, *args, **values):
# pylint: disable=no-member
self.uid = self.get_uid()

# Manualy de-reference EmbeddedDocumentField fields to avoid overhead of de-referencing all
# the fields inside the base Document class constructor when __auto_convert is True.
# This approach means we need to update this code each time we add new
# EmbeddedDocumentField (which we should avoid anyway for performance reasons)
stormbase.TagsMixin.__init__(self)


class TriggerDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin,
stormbase.UIDFieldMixin):
Expand Down
12 changes: 6 additions & 6 deletions st2common/st2common/persistence/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def get_by_nickname(cls, nickname, origin):
if not origin:
raise NoNicknameOriginProvidedError()

result = cls.query(**{('nicknames__%s' % origin): nickname})
result = cls.query(first=True, **{('nicknames__%s' % origin): nickname})

if not result.first():
if not result:
raise UserNotFoundError()
if result.count() > 1:
elif len(result) > 1:
raise AmbiguousUserError()

return result.first()
return result[0]

@classmethod
def _get_impl(cls):
Expand Down Expand Up @@ -73,7 +73,7 @@ def add_or_update(cls, model_object, publish=True, validate=True):

@classmethod
def get(cls, value):
result = cls.query(token=value).first()
result = cls.query(token=value, first=True)

if not result:
raise TokenNotFoundError()
Expand All @@ -92,7 +92,7 @@ def _get_impl(cls):
def get(cls, value):
# DB does not contain key but the key_hash.
value_hash = hash_utils.hash(value)
result = cls.query(key_hash=value_hash).first()
result = cls.query(key_hash=value_hash, first=True)

if not result:
raise ApiKeyNotFoundError('ApiKey with key_hash=%s not found.' % value_hash)
Expand Down
Loading